Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
6d31310
1
Parent(s):
b70bd0a
Fix checking cuda at import time
Browse files- wan/modules/t5.py +7 -1
wan/modules/t5.py
CHANGED
|
@@ -535,13 +535,19 @@ class T5EncoderModel:
|
|
| 535 |
self,
|
| 536 |
text_len,
|
| 537 |
dtype=torch.bfloat16,
|
| 538 |
-
device=
|
| 539 |
checkpoint_path=None,
|
| 540 |
tokenizer_path=None,
|
| 541 |
shard_fn=None,
|
| 542 |
):
|
| 543 |
self.text_len = text_len
|
| 544 |
self.dtype = dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
self.device = device
|
| 546 |
self.checkpoint_path = checkpoint_path
|
| 547 |
self.tokenizer_path = tokenizer_path
|
|
|
|
| 535 |
self,
|
| 536 |
text_len,
|
| 537 |
dtype=torch.bfloat16,
|
| 538 |
+
device=None,
|
| 539 |
checkpoint_path=None,
|
| 540 |
tokenizer_path=None,
|
| 541 |
shard_fn=None,
|
| 542 |
):
|
| 543 |
self.text_len = text_len
|
| 544 |
self.dtype = dtype
|
| 545 |
+
# Defer CUDA device access until GPU is available (for ZeroGPU compatibility)
|
| 546 |
+
if device is None:
|
| 547 |
+
if torch.cuda.is_available():
|
| 548 |
+
device = torch.cuda.current_device()
|
| 549 |
+
else:
|
| 550 |
+
device = torch.device("cpu")
|
| 551 |
self.device = device
|
| 552 |
self.checkpoint_path = checkpoint_path
|
| 553 |
self.tokenizer_path = tokenizer_path
|