vaibhavpandeyvpz commited on
Commit
6d31310
·
1 Parent(s): b70bd0a

Fix checking cuda at import time

Browse files
Files changed (1) hide show
  1. wan/modules/t5.py +7 -1
wan/modules/t5.py CHANGED
@@ -535,13 +535,19 @@ class T5EncoderModel:
535
  self,
536
  text_len,
537
  dtype=torch.bfloat16,
538
- device=torch.cuda.current_device(),
539
  checkpoint_path=None,
540
  tokenizer_path=None,
541
  shard_fn=None,
542
  ):
543
  self.text_len = text_len
544
  self.dtype = dtype
 
 
 
 
 
 
545
  self.device = device
546
  self.checkpoint_path = checkpoint_path
547
  self.tokenizer_path = tokenizer_path
 
535
  self,
536
  text_len,
537
  dtype=torch.bfloat16,
538
+ device=None,
539
  checkpoint_path=None,
540
  tokenizer_path=None,
541
  shard_fn=None,
542
  ):
543
  self.text_len = text_len
544
  self.dtype = dtype
545
+ # Defer CUDA device access until GPU is available (for ZeroGPU compatibility)
546
+ if device is None:
547
+ if torch.cuda.is_available():
548
+ device = torch.cuda.current_device()
549
+ else:
550
+ device = torch.device("cpu")
551
  self.device = device
552
  self.checkpoint_path = checkpoint_path
553
  self.tokenizer_path = tokenizer_path