Spaces:
Paused
Paused
Update utils.py
Browse files
utils.py
CHANGED
|
@@ -78,6 +78,17 @@ def load_tokenizer_and_model(base_model, load_8bit=False):
|
|
| 78 |
model.eval()
|
| 79 |
return tokenizer,model, device
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
# Greedy Search
|
| 82 |
def greedy_search(input_ids: torch.Tensor,
|
| 83 |
model: torch.nn.Module,
|
|
|
|
| 78 |
model.eval()
|
| 79 |
return tokenizer,model, device
|
| 80 |
|
| 81 |
+
|
| 82 |
+
def load_tokenizer(base_model):
|
| 83 |
+
if torch.cuda.is_available():
|
| 84 |
+
device = "cuda"
|
| 85 |
+
else:
|
| 86 |
+
device = "cpu"
|
| 87 |
+
|
| 88 |
+
tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast = True)
|
| 89 |
+
return tokenizer
|
| 90 |
+
|
| 91 |
+
|
| 92 |
# Greedy Search
|
| 93 |
def greedy_search(input_ids: torch.Tensor,
|
| 94 |
model: torch.nn.Module,
|