Spaces:
Running on Zero
Fix tensor serialization error - disable gradients for ZeroGPU
Browse files
backend/services/diffrhythm_service.py
CHANGED
|
@@ -125,6 +125,10 @@ class DiffRhythmService:
|
|
| 125 |
# Load weights
|
| 126 |
ckpt = load_file(model_ckpt)
|
| 127 |
self.model.load_state_dict(ckpt)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
# Note: Model will be moved to device inside GPU-decorated function
|
| 129 |
|
| 130 |
# Load MuLan for style encoding (keep on CPU initially)
|
|
@@ -132,6 +136,10 @@ class DiffRhythmService:
|
|
| 132 |
"OpenMuQ/MuQ-MuLan-large",
|
| 133 |
cache_dir=os.path.join(self.model_path, "mulan")
|
| 134 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
# Note: MuLan will be moved to device inside GPU-decorated function
|
| 136 |
|
| 137 |
# Load tokenizer
|
|
@@ -176,6 +184,10 @@ class DiffRhythmService:
|
|
| 176 |
|
| 177 |
# Load decoder (keep on CPU initially)
|
| 178 |
self.decoder = Generator(decoder_config, decoder_ckpt)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
# Note: Decoder will be moved to device inside GPU-decorated function
|
| 180 |
|
| 181 |
logger.info("✅ DiffRhythm 2 model loaded successfully")
|
|
|
|
| 125 |
# Load weights
|
| 126 |
ckpt = load_file(model_ckpt)
|
| 127 |
self.model.load_state_dict(ckpt)
|
| 128 |
+
self.model.eval() # Set to evaluation mode
|
| 129 |
+
# Disable gradients for all parameters to allow ZeroGPU serialization
|
| 130 |
+
for param in self.model.parameters():
|
| 131 |
+
param.requires_grad = False
|
| 132 |
# Note: Model will be moved to device inside GPU-decorated function
|
| 133 |
|
| 134 |
# Load MuLan for style encoding (keep on CPU initially)
|
|
|
|
| 136 |
"OpenMuQ/MuQ-MuLan-large",
|
| 137 |
cache_dir=os.path.join(self.model_path, "mulan")
|
| 138 |
)
|
| 139 |
+
self.mulan.eval() # Set to evaluation mode
|
| 140 |
+
# Disable gradients
|
| 141 |
+
for param in self.mulan.parameters():
|
| 142 |
+
param.requires_grad = False
|
| 143 |
# Note: MuLan will be moved to device inside GPU-decorated function
|
| 144 |
|
| 145 |
# Load tokenizer
|
|
|
|
| 184 |
|
| 185 |
# Load decoder (keep on CPU initially)
|
| 186 |
self.decoder = Generator(decoder_config, decoder_ckpt)
|
| 187 |
+
self.decoder.eval() # Set to evaluation mode
|
| 188 |
+
# Disable gradients
|
| 189 |
+
for param in self.decoder.parameters():
|
| 190 |
+
param.requires_grad = False
|
| 191 |
# Note: Decoder will be moved to device inside GPU-decorated function
|
| 192 |
|
| 193 |
logger.info("✅ DiffRhythm 2 model loaded successfully")
|