update chat template in model
Browse files
- modeling_internlm2.py +6 -6
modeling_internlm2.py
CHANGED
|
@@ -1138,12 +1138,12 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
|
|
| 1138 |
def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=""):
|
| 1139 |
prompt = ""
|
| 1140 |
if meta_instruction:
|
| 1141 |
-
prompt += f"""<s
|
| 1142 |
else:
|
| 1143 |
prompt += "<s>"
|
| 1144 |
for record in history:
|
| 1145 |
-
prompt += f"""
|
| 1146 |
-
prompt += f"""
|
| 1147 |
return tokenizer([prompt], return_tensors="pt")
|
| 1148 |
|
| 1149 |
@torch.no_grad()
|
|
@@ -1165,7 +1165,7 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
|
|
| 1165 |
inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
|
| 1166 |
inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
|
| 1167 |
# also add end-of-assistant token in eos token id to avoid unnecessary generation
|
| 1168 |
-
eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(["
|
| 1169 |
outputs = self.generate(
|
| 1170 |
**inputs,
|
| 1171 |
streamer=streamer,
|
|
@@ -1178,7 +1178,7 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
|
|
| 1178 |
)
|
| 1179 |
outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
|
| 1180 |
response = tokenizer.decode(outputs, skip_special_tokens=True)
|
| 1181 |
-
response = response.split("
|
| 1182 |
history = history + [(query, response)]
|
| 1183 |
return response, history
|
| 1184 |
|
|
@@ -1231,7 +1231,7 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
|
|
| 1231 |
return
|
| 1232 |
|
| 1233 |
token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
|
| 1234 |
-
if token.strip() != "
|
| 1235 |
self.response = self.response + token
|
| 1236 |
history = self.history + [(self.query, self.response)]
|
| 1237 |
self.queue.put((self.response, history))
|
|
|
|
| 1138 |
def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=""):
|
| 1139 |
prompt = ""
|
| 1140 |
if meta_instruction:
|
| 1141 |
+
prompt += f"""<s><|im_start|>system\n{meta_instruction}<|im_end|>\n"""
|
| 1142 |
else:
|
| 1143 |
prompt += "<s>"
|
| 1144 |
for record in history:
|
| 1145 |
+
prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n"""
|
| 1146 |
+
prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"""
|
| 1147 |
return tokenizer([prompt], return_tensors="pt")
|
| 1148 |
|
| 1149 |
@torch.no_grad()
|
|
|
|
| 1165 |
inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
|
| 1166 |
inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
|
| 1167 |
# also add end-of-assistant token in eos token id to avoid unnecessary generation
|
| 1168 |
+
eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(["<|im_end|>"])[0]]
|
| 1169 |
outputs = self.generate(
|
| 1170 |
**inputs,
|
| 1171 |
streamer=streamer,
|
|
|
|
| 1178 |
)
|
| 1179 |
outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
|
| 1180 |
response = tokenizer.decode(outputs, skip_special_tokens=True)
|
| 1181 |
+
response = response.split("<|im_end|>")[0]
|
| 1182 |
history = history + [(query, response)]
|
| 1183 |
return response, history
|
| 1184 |
|
|
|
|
| 1231 |
return
|
| 1232 |
|
| 1233 |
token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
|
| 1234 |
+
if token.strip() != "<|im_end|>":
|
| 1235 |
self.response = self.response + token
|
| 1236 |
history = self.history + [(self.query, self.response)]
|
| 1237 |
self.queue.put((self.response, history))
|