Gamahea commited on
Commit
2984541
·
verified ·
1 Parent(s): 0d353d2

Support T4 GPU and CPU - make ZeroGPU decorator conditional

Browse files
backend/services/diffrhythm_service.py CHANGED
@@ -17,14 +17,28 @@ import json
17
  # Import spaces for ZeroGPU support
18
  try:
19
  import spaces
20
- HAS_SPACES = True
 
 
 
 
 
 
21
  except ImportError:
22
- HAS_SPACES = False
23
- # Create a dummy decorator for local development
24
- class spaces:
25
- @staticmethod
26
- def GPU(func):
 
 
 
 
 
27
  return func
 
 
 
28
 
29
  # Configure espeak-ng path for phonemizer (required by g2p module)
30
  # Note: Environment configuration handled by hf_config.py for HuggingFace Spaces
@@ -276,7 +290,7 @@ class DiffRhythmService:
276
  logger.error(f"Music generation failed: {str(e)}", exc_info=True)
277
  raise RuntimeError(f"Failed to generate music: {str(e)}")
278
 
279
- @spaces.GPU(duration=120)
280
  def _generate_with_diffrhythm2(
281
  self,
282
  prompt: str,
 
17
  # Import spaces for ZeroGPU support
18
  try:
19
  import spaces
20
+ # Check if we're actually on ZeroGPU (has device-api)
21
+ import requests
22
+ try:
23
+ requests.head("http://device-api.zero/", timeout=0.5)
24
+ HAS_ZEROGPU = True
25
+ except:
26
+ HAS_ZEROGPU = False
27
  except ImportError:
28
+ HAS_ZEROGPU = False
29
+
30
+ # Create appropriate decorator
31
+ if HAS_ZEROGPU:
32
+ # Use ZeroGPU decorator
33
+ GPU_DECORATOR = spaces.GPU
34
+ else:
35
+ # No-op decorator for regular GPU/CPU
36
+ def GPU_DECORATOR(duration=None):
37
+ def decorator(func):
38
  return func
39
+ if callable(duration): # Called as @GPU_DECORATOR without parentheses
40
+ return duration
41
+ return decorator
42
 
43
  # Configure espeak-ng path for phonemizer (required by g2p module)
44
  # Note: Environment configuration handled by hf_config.py for HuggingFace Spaces
 
290
  logger.error(f"Music generation failed: {str(e)}", exc_info=True)
291
  raise RuntimeError(f"Failed to generate music: {str(e)}")
292
 
293
+ @GPU_DECORATOR(duration=120)
294
  def _generate_with_diffrhythm2(
295
  self,
296
  prompt: str,