Gamahea commited on
Commit
33cfbff
·
verified ·
1 Parent(s): 3faef55

Fix tensor serialization error - disable gradients for ZeroGPU

Browse files
backend/services/diffrhythm_service.py CHANGED
@@ -125,6 +125,10 @@ class DiffRhythmService:
125
  # Load weights
126
  ckpt = load_file(model_ckpt)
127
  self.model.load_state_dict(ckpt)
 
 
 
 
128
  # Note: Model will be moved to device inside GPU-decorated function
129
 
130
  # Load MuLan for style encoding (keep on CPU initially)
@@ -132,6 +136,10 @@ class DiffRhythmService:
132
  "OpenMuQ/MuQ-MuLan-large",
133
  cache_dir=os.path.join(self.model_path, "mulan")
134
  )
 
 
 
 
135
  # Note: MuLan will be moved to device inside GPU-decorated function
136
 
137
  # Load tokenizer
@@ -176,6 +184,10 @@ class DiffRhythmService:
176
 
177
  # Load decoder (keep on CPU initially)
178
  self.decoder = Generator(decoder_config, decoder_ckpt)
 
 
 
 
179
  # Note: Decoder will be moved to device inside GPU-decorated function
180
 
181
  logger.info("✅ DiffRhythm 2 model loaded successfully")
 
125
  # Load weights
126
  ckpt = load_file(model_ckpt)
127
  self.model.load_state_dict(ckpt)
128
+ self.model.eval() # Set to evaluation mode
129
+ # Disable gradients for all parameters to allow ZeroGPU serialization
130
+ for param in self.model.parameters():
131
+ param.requires_grad = False
132
  # Note: Model will be moved to device inside GPU-decorated function
133
 
134
  # Load MuLan for style encoding (keep on CPU initially)
 
136
  "OpenMuQ/MuQ-MuLan-large",
137
  cache_dir=os.path.join(self.model_path, "mulan")
138
  )
139
+ self.mulan.eval() # Set to evaluation mode
140
+ # Disable gradients
141
+ for param in self.mulan.parameters():
142
+ param.requires_grad = False
143
  # Note: MuLan will be moved to device inside GPU-decorated function
144
 
145
  # Load tokenizer
 
184
 
185
  # Load decoder (keep on CPU initially)
186
  self.decoder = Generator(decoder_config, decoder_ckpt)
187
+ self.decoder.eval() # Set to evaluation mode
188
+ # Disable gradients
189
+ for param in self.decoder.parameters():
190
+ param.requires_grad = False
191
  # Note: Decoder will be moved to device inside GPU-decorated function
192
 
193
  logger.info("✅ DiffRhythm 2 model loaded successfully")