Update transformers_inference.py
Browse files
transformers_inference.py
CHANGED
|
@@ -10,7 +10,7 @@ tokenizer = LlamaTokenizer.from_pretrained('teknium/OpenHermes-2.5-Mistral-7B',
|
|
| 10 |
model = MistralForCausalLM.from_pretrained(
|
| 11 |
"teknium/OpenHermes-2.5-Mistral-7B",
|
| 12 |
torch_dtype=torch.float16,
|
| 13 |
-
device_map=
|
| 14 |
load_in_8bit=False,
|
| 15 |
load_in_4bit=True,
|
| 16 |
use_flash_attention_2=True
|
|
|
|
| 10 |
model = MistralForCausalLM.from_pretrained(
|
| 11 |
"teknium/OpenHermes-2.5-Mistral-7B",
|
| 12 |
torch_dtype=torch.float16,
|
| 13 |
+
device_map="auto",#{'': 'cuda:0'},
|
| 14 |
load_in_8bit=False,
|
| 15 |
load_in_4bit=True,
|
| 16 |
use_flash_attention_2=True
|