JeffreyXiang committed
Commit 523e08e · 1 Parent(s): eff131c
trellis2/models/sparse_structure_flow.py CHANGED
@@ -4,8 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-from ..trainers.utils import str_to_dtype
-from ..modules.utils import convert_module_to, manual_cast
+from ..modules.utils import convert_module_to, manual_cast, str_to_dtype
 from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock
 from ..modules.attention import RotaryPositionEmbedder
trellis2/models/structured_latent_flow.py CHANGED
@@ -4,8 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-from ..trainers.utils import str_to_dtype
-from ..modules.utils import convert_module_to, manual_cast
+from ..modules.utils import convert_module_to, manual_cast, str_to_dtype
 from ..modules.transformer import AbsolutePositionEmbedder
 from ..modules import sparse as sp
 from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock
trellis2/modules/utils.py CHANGED
@@ -72,3 +72,16 @@ def manual_cast(tensor, dtype):
     if not torch.is_autocast_enabled():
         return tensor.type(dtype)
     return tensor
+
+
+def str_to_dtype(dtype_str: str):
+    return {
+        'f16': torch.float16,
+        'fp16': torch.float16,
+        'float16': torch.float16,
+        'bf16': torch.bfloat16,
+        'bfloat16': torch.bfloat16,
+        'f32': torch.float32,
+        'fp32': torch.float32,
+        'float32': torch.float32,
+    }[dtype_str]
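For reference, a minimal sketch of the relocated helper in action. The function body is copied from the new code in trellis2/modules/utils.py above; the surrounding usage (the 'bf16' input and the tensor cast) is a hypothetical illustration, not code from this commit:

```python
import torch

def str_to_dtype(dtype_str: str):
    # As added in trellis2/modules/utils.py: map common string
    # spellings to torch dtypes; unknown strings raise KeyError.
    return {
        'f16': torch.float16,
        'fp16': torch.float16,
        'float16': torch.float16,
        'bf16': torch.bfloat16,
        'bfloat16': torch.bfloat16,
        'f32': torch.float32,
        'fp32': torch.float32,
        'float32': torch.float32,
    }[dtype_str]

# Hypothetical usage: resolve a dtype string (e.g. from a config) and
# cast a tensor, as manual_cast does when autocast is not enabled.
dtype = str_to_dtype('bf16')
x = torch.randn(4, 4).type(dtype)
print(x.dtype)  # torch.bfloat16
```

Moving str_to_dtype next to convert_module_to and manual_cast lets both model files import all three dtype utilities from ..modules.utils, dropping their dependency on the trainers package.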