Gamahea committed
Commit d5ccfff · Parent: b166366

Fix ZeroGPU compatibility - Dynamic device allocation


- Changed device initialization to always start on CPU
- Device detection now happens inside @spaces.GPU-decorated functions
- Models are moved to the GPU dynamically when ZeroGPU allocates resources
- Fixes the 'CUDA driver initialization failed' error

Changes:
- DiffRhythmService: dynamic device detection in _generate_with_diffrhythm2()
- LyricMindService: dynamic device detection in _generate_with_model()
- _tokenize_lyrics() now accepts a device parameter
- Added hf_oauth: true to README for HF authentication
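
For context: on ZeroGPU Spaces a GPU is attached only while a function decorated with @spaces.GPU is executing, so torch.cuda.is_available() returns False at startup. Below is a minimal sketch of the pattern this commit applies; build_model() and generate() are illustrative stand-ins, not the actual app code.

```python
import spaces
import torch

# Load on CPU at startup; on ZeroGPU no CUDA device exists yet.
model = build_model()  # hypothetical loader, stands in for the real services
model.to(torch.device("cpu"))

@spaces.GPU  # ZeroGPU attaches a GPU only for the duration of this call
def generate(prompt: str):
    # Detect the device inside the decorated function, where CUDA is visible
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    with torch.no_grad():
        return model(prompt)
```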

README.md CHANGED
@@ -8,6 +8,7 @@ sdk_version: "4.44.1"
 app_file: app.py
 pinned: false
 license: mit
+hf_oauth: true
 ---
 
 # LEMM - Let Everyone Make Music
backend/services/diffrhythm_service.py CHANGED
@@ -63,18 +63,10 @@ class DiffRhythmService:
         logger.info(f"Using device: {self.device}")
 
     def _get_device(self):
-        """Get compute device (CUDA or CPU)"""
-        # Try CUDA first (NVIDIA)
-        if torch.cuda.is_available():
-            logger.info("Using CUDA (NVIDIA GPU)")
-            return torch.device("cuda")
-
-        # Note: DirectML support disabled due to version conflicts with DiffRhythm2
-        # DiffRhythm2 requires torch>=2.4, but torch-directml requires torch==2.4.1
-        # For AMD GPU acceleration, consider using ROCm with compatible PyTorch build
-
-        # Fallback to CPU
-        logger.info("Using CPU (no GPU acceleration)")
+        """Get compute device - for ZeroGPU, always start with CPU"""
+        # For ZeroGPU Spaces, device allocation happens dynamically inside @spaces.GPU functions
+        # Always return CPU here - GPU allocation is handled by the decorator
+        logger.info("Using CPU for initialization (GPU allocated by @spaces.GPU decorator)")
         return torch.device("cpu")
 
     def _initialize_model(self):
@@ -278,19 +270,21 @@
         try:
             logger.info("Generating with DiffRhythm 2 model...")
 
-            # Move models to GPU (for ZeroGPU compatibility)
-            # This ensures models are on GPU only within the decorated function
-            if self.device.type != 'cpu':
-                self.model = self.model.to(self.device)
-                self.mulan = self.mulan.to(self.device)
-                self.decoder = self.decoder.to(self.device)
+            # For ZeroGPU, dynamically detect device inside GPU-decorated function
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            logger.info(f"GPU-decorated function using device: {device}")
+
+            # Move models to detected device (GPU if available via ZeroGPU)
+            self.model = self.model.to(device)
+            self.mulan = self.mulan.to(device)
+            self.decoder = self.decoder.to(device)
 
             # Prepare lyrics tokens
             if lyrics:
-                lyrics_token = self._tokenize_lyrics(lyrics)
+                lyrics_token = self._tokenize_lyrics(lyrics, device)
             else:
                 # For instrumental, use empty structure
-                lyrics_token = torch.tensor([500, 511], dtype=torch.long, device=self.device)  # [start][stop]
+                lyrics_token = torch.tensor([500, 511], dtype=torch.long, device=device)  # [start][stop]
 
             # Encode style prompt with optional reference audio blending
             with torch.no_grad():
@@ -303,7 +297,7 @@
                     ref_waveform = torchaudio.functional.resample(ref_waveform, ref_sr, 24000)
 
                     # Encode reference audio with MuLan
-                    ref_waveform = ref_waveform.to(self.device)
+                    ref_waveform = ref_waveform.to(device)
                     audio_style_embed = self.mulan(audios=ref_waveform.unsqueeze(0))
                     text_style_embed = self.mulan(texts=[prompt])
 
@@ -316,10 +310,10 @@
                 else:
                     style_prompt_embed = self.mulan(texts=[prompt])
 
-            style_prompt_embed = style_prompt_embed.to(self.device).squeeze(0)
+            style_prompt_embed = style_prompt_embed.to(device).squeeze(0)
 
             # Use FP16 if on GPU
-            if self.device.type != 'cpu':
+            if device.type != 'cpu':
                 self.model = self.model.half()
                 self.decoder = self.decoder.half()
                 style_prompt_embed = style_prompt_embed.half()
@@ -361,16 +355,20 @@
             logger.error(f"DiffRhythm 2 generation failed: {str(e)}")
             return self._generate_placeholder(duration, sample_rate)
 
-    def _tokenize_lyrics(self, lyrics: str) -> torch.Tensor:
+    def _tokenize_lyrics(self, lyrics: str, device: torch.device = None) -> torch.Tensor:
         """
         Tokenize lyrics for DiffRhythm 2
 
         Args:
             lyrics: Lyrics text
+            device: Target device for tensor
 
         Returns:
             Tokenized lyrics tensor
         """
+        if device is None:
+            device = torch.device("cpu")
+
         try:
             # Structure tags
             STRUCT_INFO = {
@@ -396,12 +394,12 @@
             # Add structure: [start] + lyrics + [stop]
             lyrics_tokens = [STRUCT_INFO['[start]']] + tokens + [STRUCT_INFO['[stop]']]
 
-            return torch.tensor(lyrics_tokens, dtype=torch.long, device=self.device)
+            return torch.tensor(lyrics_tokens, dtype=torch.long, device=device)
 
         except Exception as e:
             logger.error(f"Lyrics tokenization failed: {str(e)}")
             # Return minimal structure
-            return torch.tensor([500, 511], dtype=torch.long, device=self.device)
+            return torch.tensor([500, 511], dtype=torch.long, device=device)
 
     def _generate_placeholder(self, duration: int, sample_rate: int) -> np.ndarray:
         """
backend/services/lyricmind_service.py CHANGED
@@ -27,12 +27,11 @@ class LyricMindService:
         logger.info(f"Using device: {self.device}")
 
     def _get_device(self):
-        """Get compute device (AMD GPU via DirectML or CPU)"""
-        try:
-            from utils.amd_gpu import DEFAULT_DEVICE
-            return DEFAULT_DEVICE
-        except:
-            return torch.device("cpu")
+        """Get compute device - for ZeroGPU, always start with CPU"""
+        # For ZeroGPU Spaces, device allocation happens dynamically inside @spaces.GPU functions
+        # Always return CPU here - GPU allocation is handled by the decorator
+        logger.info("Using CPU for initialization (GPU allocated by @spaces.GPU decorator)")
+        return torch.device("cpu")
 
     def _initialize_model(self):
         """Lazy load the model when first needed"""
@@ -54,10 +53,10 @@
             self.model = AutoModelForCausalLM.from_pretrained(
                 fallback_path,
                 trust_remote_code=True,
-                torch_dtype=torch.float32  # Use FP32 for AMD GPU compatibility
+                torch_dtype=torch.float32  # Use FP32 for compatibility
            )
-            self.model.to(self.device)
-            logger.info("✅ Text generation model loaded successfully")
+            # Model stays on CPU initially - moved to GPU inside @spaces.GPU function
+            logger.info("✅ Text generation model loaded successfully (on CPU)")
        else:
            logger.warning("Text generation model not found, using placeholder")
 
@@ -148,6 +147,14 @@
         try:
             logger.info("Generating lyrics with AI model...")
 
+            # Dynamically detect device (for ZeroGPU compatibility)
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            logger.info(f"Using device for lyrics generation: {device}")
+
+            # Move model to device if not already there
+            if self.model.device != device:
+                self.model = self.model.to(device)
+
             # Create structured prompt with analysis context
             mood = analysis.get('mood', 'neutral')
             bpm = analysis.get('bpm', 120)
@@ -157,7 +164,7 @@
 
             # Tokenize
             inputs = self.tokenizer(full_prompt, return_tensors="pt")
-            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+            inputs = {k: v.to(device) for k, v in inputs.items()}
 
             # Calculate max length based on duration
             max_length = min(200 + inputs["input_ids"].shape[1], 512)