Gamahea committed on
Commit 12fedcc · 1 Parent(s): 92cdb9d

Add ZeroGPU decorator to DiffRhythm2 generation for HF Spaces compatibility
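In short, the change follows Hugging Face's ZeroGPU pattern: on ZeroGPU Spaces a GPU is attached only while a function decorated with @spaces.GPU is running, so models are loaded on CPU at startup and moved to the device inside the decorated call. A minimal sketch of that pattern, assuming the spaces package available on HF Spaces (the MusicService/generate names are illustrative, not the actual service code):

import torch
import spaces  # provided on Hugging Face Spaces; exposes the spaces.GPU decorator


class MusicService:
    def __init__(self):
        # Load weights on CPU at startup; no GPU is attached yet on ZeroGPU.
        self.model = torch.nn.Linear(16, 16)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @spaces.GPU(duration=60)  # request a GPU for up to ~60 s per call
    def generate(self, x: torch.Tensor) -> torch.Tensor:
        # Move weights to the GPU only inside the decorated call, where it exists.
        if self.device.type != "cpu":
            self.model = self.model.to(self.device)
        return self.model(x.to(self.device))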

backend/services/diffrhythm_service.py CHANGED
@@ -14,6 +14,18 @@ import torch
 import torchaudio
 import json
 
+# Import spaces for ZeroGPU support
+try:
+    import spaces
+    HAS_SPACES = True
+except ImportError:
+    HAS_SPACES = False
+    # Create a dummy decorator for local development
+    class spaces:
+        @staticmethod
+        def GPU(func):
+            return func
+
 # Configure espeak-ng path for phonemizer (required by g2p module)
 # Note: Environment configuration handled by hf_config.py for HuggingFace Spaces
 # or by launch scripts for local development
@@ -103,7 +115,7 @@ class DiffRhythmService:
 
         model_config['use_flex_attn'] = False
 
-        # Create model
+        # Create model (keep on CPU initially for ZeroGPU compatibility)
         self.model = CFM(
             transformer=DiT(**model_config),
             num_channels=model_config['mel_dim'],
@@ -113,13 +125,14 @@ class DiffRhythmService:
         # Load weights
         ckpt = load_file(model_ckpt)
         self.model.load_state_dict(ckpt)
-        self.model = self.model.to(self.device)
+        # Note: Model will be moved to device inside GPU-decorated function
 
-        # Load MuLan for style encoding
+        # Load MuLan for style encoding (keep on CPU initially)
         self.mulan = MuQMuLan.from_pretrained(
             "OpenMuQ/MuQ-MuLan-large",
             cache_dir=os.path.join(self.model_path, "mulan")
-        ).to(self.device)
+        )
+        # Note: MuLan will be moved to device inside GPU-decorated function
 
         # Load tokenizer
         from g2p.g2p_generation import chn_eng_g2p
@@ -147,7 +160,7 @@ class DiffRhythmService:
             'g2p': chn_eng_g2p
         }
 
-        # Load decoder (BigVGAN vocoder)
+        # Load decoder (BigVGAN vocoder) - keep on CPU initially
         decoder_ckpt = hf_hub_download(
             repo_id=repo_id,
             filename="decoder.bin",
@@ -161,8 +174,9 @@ class DiffRhythmService:
             local_files_only=False,
         )
 
+        # Load decoder (keep on CPU initially)
         self.decoder = Generator(decoder_config, decoder_ckpt)
-        self.decoder = self.decoder.to(self.device)
+        # Note: Decoder will be moved to device inside GPU-decorated function
 
         logger.info("✅ DiffRhythm 2 model loaded successfully")
 
@@ -239,6 +253,7 @@ class DiffRhythmService:
             logger.error(f"Music generation failed: {str(e)}", exc_info=True)
             raise RuntimeError(f"Failed to generate music: {str(e)}")
 
+    @spaces.GPU(duration=60)
     def _generate_with_diffrhythm2(
         self,
         prompt: str,
@@ -263,6 +278,13 @@ class DiffRhythmService:
         try:
            logger.info("Generating with DiffRhythm 2 model...")
 
+            # Move models to GPU (for ZeroGPU compatibility)
+            # This ensures models are on GPU only within the decorated function
+            if self.device.type != 'cpu':
+                self.model = self.model.to(self.device)
+                self.mulan = self.mulan.to(self.device)
+                self.decoder = self.decoder.to(self.device)
+
             # Prepare lyrics tokens
             if lyrics:
                 lyrics_token = self._tokenize_lyrics(lyrics)
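One caveat with the local-development fallback above: the dummy spaces.GPU is defined as GPU(func), which only supports the bare @spaces.GPU form, while the method is decorated with @spaces.GPU(duration=60). In an environment without the spaces package, evaluating that decorator would raise a TypeError when the module is imported. A fallback that accepts both forms could look like the sketch below (an illustrative alternative, not part of this commit):

try:
    import spaces
    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False

    # Dummy stand-in so the same code runs locally without the spaces package.
    class spaces:
        @staticmethod
        def GPU(func=None, **kwargs):
            if callable(func):        # used as @spaces.GPU
                return func
            def wrap(f):              # used as @spaces.GPU(duration=60)
                return f
            return wrap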