Gamahea commited on
Commit
0d353d2
·
verified ·
1 Parent(s): 33cfbff

Support T4 GPU and CPU - make ZeroGPU decorator conditional

Browse files
Files changed (1) hide show
  1. app.py +24 -8
app.py CHANGED
@@ -16,14 +16,24 @@ import time
16
  # Import spaces for ZeroGPU support
17
  try:
18
  import spaces
19
- HAS_SPACES = True
 
 
 
 
 
 
20
  except ImportError:
21
- HAS_SPACES = False
22
- # Create a dummy decorator for local development
23
- class spaces:
24
- @staticmethod
25
- def GPU(func):
26
- return func
 
 
 
 
27
 
28
  # Run DiffRhythm2 source setup if needed
29
  setup_script = Path(__file__).parent / "setup_diffrhythm2_src.sh"
@@ -47,6 +57,12 @@ logging.basicConfig(
47
  )
48
  logger = logging.getLogger(__name__)
49
 
 
 
 
 
 
 
50
  # Import services
51
  try:
52
  from services.diffrhythm_service import DiffRhythmService
@@ -105,7 +121,7 @@ def get_lyricmind_service():
105
  logger.info("LyricMind model loaded")
106
  return lyricmind_service
107
 
108
- @spaces.GPU
109
  def generate_lyrics(prompt: str, progress=gr.Progress()):
110
  """Generate lyrics from prompt using analysis"""
111
  try:
 
16
  # Import spaces for ZeroGPU support
17
  try:
18
  import spaces
19
+ # Check if we're actually on ZeroGPU (has device-api)
20
+ import requests
21
+ try:
22
+ requests.head("http://device-api.zero/", timeout=0.5)
23
+ HAS_ZEROGPU = True
24
+ except:
25
+ HAS_ZEROGPU = False
26
  except ImportError:
27
+ HAS_ZEROGPU = False
28
+
29
+ # Create appropriate decorator
30
+ if HAS_ZEROGPU:
31
+ # Use ZeroGPU decorator
32
+ GPU_DECORATOR = spaces.GPU
33
+ else:
34
+ # No-op decorator for regular GPU/CPU
35
+ def GPU_DECORATOR(func):
36
+ return func
37
 
38
  # Run DiffRhythm2 source setup if needed
39
  setup_script = Path(__file__).parent / "setup_diffrhythm2_src.sh"
 
57
  )
58
  logger = logging.getLogger(__name__)
59
 
60
+ # Log GPU mode
61
+ if HAS_ZEROGPU:
62
+ logger.info("🚀 ZeroGPU detected - using dynamic GPU allocation")
63
+ else:
64
+ logger.info("💻 Running on regular GPU/CPU - using static device allocation")
65
+
66
  # Import services
67
  try:
68
  from services.diffrhythm_service import DiffRhythmService
 
121
  logger.info("LyricMind model loaded")
122
  return lyricmind_service
123
 
124
+ @GPU_DECORATOR
125
  def generate_lyrics(prompt: str, progress=gr.Progress()):
126
  """Generate lyrics from prompt using analysis"""
127
  try: