nicholasKluge commited on
Commit
ac01c8d
·
1 Parent(s): 5716a02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  import gradio as gr
4
  from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
5
 
6
- model_id = "nicholasKluge/Aira-Instruct-124M"
7
  rewardmodel_id = "nicholasKluge/RewardModel"
8
  toxicitymodel_id = "nicholasKluge/ToxicityModel"
9
 
@@ -11,7 +11,9 @@ token = "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"
11
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
- model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token)
 
 
15
  rewardModel = AutoModelForSequenceClassification.from_pretrained(rewardmodel_id, use_auth_token=token)
16
  toxicityModel = AutoModelForSequenceClassification.from_pretrained(toxicitymodel_id, use_auth_token=token)
17
 
@@ -19,7 +21,7 @@ model.eval()
19
  rewardModel.eval()
20
  toxicityModel.eval()
21
 
22
- model.to(device)
23
  rewardModel.to(device)
24
  toxicityModel.to(device)
25
 
 
3
  import gradio as gr
4
  from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
5
 
6
+ model_id = "nicholasKluge/Aira-Instruct-1.5B" # "nicholasKluge/Aira-Instruct-124M"
7
  rewardmodel_id = "nicholasKluge/RewardModel"
8
  toxicitymodel_id = "nicholasKluge/ToxicityModel"
9
 
 
11
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
+ model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token, device_map="auto", load_in_8bit=True)
15
+
16
+ #model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token)
17
  rewardModel = AutoModelForSequenceClassification.from_pretrained(rewardmodel_id, use_auth_token=token)
18
  toxicityModel = AutoModelForSequenceClassification.from_pretrained(toxicitymodel_id, use_auth_token=token)
19
 
 
21
  rewardModel.eval()
22
  toxicityModel.eval()
23
 
24
+ #model.to(device)
25
  rewardModel.to(device)
26
  toxicityModel.to(device)
27