Upload folder using huggingface_hub
Browse files
- modeling_internlm_xcomposer2.py +57 -22
modeling_internlm_xcomposer2.py
CHANGED
|
@@ -35,6 +35,7 @@ from transformers import (
|
|
| 35 |
StoppingCriteriaList,
|
| 36 |
set_seed,
|
| 37 |
)
|
|
|
|
| 38 |
from transformers.generation.streamers import BaseStreamer
|
| 39 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 40 |
from transformers.utils import (
|
|
@@ -52,6 +53,8 @@ from .modeling_internlm2 import (
|
|
| 52 |
)
|
| 53 |
|
| 54 |
_CONFIG_FOR_DOC = "InternLMXcomposer2Config"
|
|
|
|
|
|
|
| 55 |
|
| 56 |
image_extensions = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
|
| 57 |
video_extensions = {".mp4", ".avi", ".mkv", ".mov", ".wmv"}
|
|
@@ -103,7 +106,7 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
|
|
| 103 |
self.model = InternLM2Model(config)
|
| 104 |
self.vocab_size = config.vocab_size
|
| 105 |
self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 106 |
-
self.tokenizer = None
|
| 107 |
self.hd_num = 25
|
| 108 |
self.font = get_font()
|
| 109 |
|
|
@@ -245,12 +248,12 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
|
|
| 245 |
self.max_length = max_length
|
| 246 |
prompt = ""
|
| 247 |
if meta_instruction:
|
| 248 |
-
prompt +=
|
| 249 |
-
f"""[UNUSED_TOKEN_146]system\n{meta_instruction}[UNUSED_TOKEN_145]\n"""
|
| 250 |
-
)
|
| 251 |
for record in history:
|
| 252 |
-
prompt += f"""
|
| 253 |
-
prompt +=
|
|
|
|
|
|
|
| 254 |
|
| 255 |
image_nums = len(image)
|
| 256 |
if image_nums == 1 and prompt.find("<ImageHere>") == -1:
|
|
@@ -587,7 +590,7 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
|
|
| 587 |
shift_labels = labels[..., 1:].contiguous()
|
| 588 |
# Flatten the tokens
|
| 589 |
loss_fct = CrossEntropyLoss()
|
| 590 |
-
shift_logits = shift_logits.view(-1, self.
|
| 591 |
shift_labels = shift_labels.view(-1)
|
| 592 |
# Enable model parallelism
|
| 593 |
shift_labels = shift_labels.to(shift_logits.device)
|
|
@@ -676,12 +679,14 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
|
|
| 676 |
):
|
| 677 |
prompt = ""
|
| 678 |
if meta_instruction:
|
| 679 |
-
prompt += f"""<s>
|
| 680 |
else:
|
| 681 |
prompt += "<s>"
|
| 682 |
for record in history:
|
| 683 |
-
prompt += f"""
|
| 684 |
-
prompt +=
|
|
|
|
|
|
|
| 685 |
return tokenizer([prompt], return_tensors="pt")
|
| 686 |
|
| 687 |
@torch.no_grad()
|
|
@@ -724,7 +729,7 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
|
|
| 724 |
# also add end-of-assistant token in eos token id to avoid unnecessary generation
|
| 725 |
eos_token_id = [
|
| 726 |
tokenizer.eos_token_id,
|
| 727 |
-
tokenizer.convert_tokens_to_ids([
|
| 728 |
]
|
| 729 |
outputs = self.generate(
|
| 730 |
**inputs,
|
|
@@ -745,7 +750,7 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
|
|
| 745 |
else:
|
| 746 |
outputs = outputs[0].cpu().tolist()
|
| 747 |
response = tokenizer.decode(outputs, skip_special_tokens=True)
|
| 748 |
-
response = response.split(
|
| 749 |
history = history + [(query, response)]
|
| 750 |
return response, history
|
| 751 |
|
|
@@ -807,8 +812,8 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
|
|
| 807 |
response = generate[0].tolist()
|
| 808 |
response = self.tokenizer.decode(response, skip_special_tokens=True) # type: ignore
|
| 809 |
# remove eoa
|
| 810 |
-
response = response.replace(
|
| 811 |
-
response = response.replace(
|
| 812 |
|
| 813 |
return response
|
| 814 |
|
|
@@ -847,8 +852,8 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
|
|
| 847 |
response = generate[0].tolist()
|
| 848 |
response = self.tokenizer.decode(response, skip_special_tokens=True) # type: ignore
|
| 849 |
# remove eoa
|
| 850 |
-
response = response.replace(
|
| 851 |
-
out = response.replace(
|
| 852 |
image_type = "random"
|
| 853 |
pattern = r"""https://source\.unsplash\.com/random/(\d+)x(\d+)/\?([^'"]+)"""
|
| 854 |
if image_type == "placeholder":
|
|
@@ -900,8 +905,8 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
|
|
| 900 |
response = generate[0].tolist()
|
| 901 |
response = self.tokenizer.decode(response, skip_special_tokens=True) # type: ignore
|
| 902 |
# remove eoa
|
| 903 |
-
response = response.replace(
|
| 904 |
-
html = response.replace(
|
| 905 |
|
| 906 |
if seed != -1:
|
| 907 |
set_random_seed(seed, set_cudnn=True)
|
|
@@ -923,8 +928,8 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
|
|
| 923 |
response = generate[0].tolist()
|
| 924 |
response = self.tokenizer.decode(response, skip_special_tokens=True) # type: ignore
|
| 925 |
# remove eoa
|
| 926 |
-
response = response.replace(
|
| 927 |
-
js = response.replace(
|
| 928 |
|
| 929 |
if re.search(r"</script>", html):
|
| 930 |
js = re.findall(r"<script>([\s\S]*?)<\/script>", js)
|
|
@@ -983,8 +988,8 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
|
|
| 983 |
response = generate[0].tolist()
|
| 984 |
response = self.tokenizer.decode(response, skip_special_tokens=True) # type: ignore
|
| 985 |
# remove eoa
|
| 986 |
-
response = response.replace(
|
| 987 |
-
out = response.replace(
|
| 988 |
image_type = "random"
|
| 989 |
pattern = r"""https://source\.unsplash\.com/random/(\d+)x(\d+)/\?([^'"]+)"""
|
| 990 |
if image_type == "placeholder":
|
|
@@ -995,3 +1000,33 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
|
|
| 995 |
with open(task.replace(" ", "_") + ".html", "w") as f:
|
| 996 |
f.write(out)
|
| 997 |
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
StoppingCriteriaList,
|
| 36 |
set_seed,
|
| 37 |
)
|
| 38 |
+
from transformers import PreTrainedTokenizer
|
| 39 |
from transformers.generation.streamers import BaseStreamer
|
| 40 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 41 |
from transformers.utils import (
|
|
|
|
| 53 |
)
|
| 54 |
|
| 55 |
_CONFIG_FOR_DOC = "InternLMXcomposer2Config"
|
| 56 |
+
FROM_TOKEN_1 = "[UNUSED_TOKEN_146]"
|
| 57 |
+
FROM_TOKEN_2 = "[UNUSED_TOKEN_145]"
|
| 58 |
|
| 59 |
image_extensions = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
|
| 60 |
video_extensions = {".mp4", ".avi", ".mkv", ".mov", ".wmv"}
|
|
|
|
| 106 |
self.model = InternLM2Model(config)
|
| 107 |
self.vocab_size = config.vocab_size
|
| 108 |
self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 109 |
+
self.tokenizer: PreTrainedTokenizer = None # type: ignore
|
| 110 |
self.hd_num = 25
|
| 111 |
self.font = get_font()
|
| 112 |
|
|
|
|
| 248 |
self.max_length = max_length
|
| 249 |
prompt = ""
|
| 250 |
if meta_instruction:
|
| 251 |
+
prompt += f"""{FROM_TOKEN_1}system\n{meta_instruction}{FROM_TOKEN_2}\n"""
|
|
|
|
|
|
|
| 252 |
for record in history:
|
| 253 |
+
prompt += f"""{FROM_TOKEN_1}user\n{record[0]}{FROM_TOKEN_2}\n{FROM_TOKEN_1}assistant\n{record[1]}{FROM_TOKEN_2}\n"""
|
| 254 |
+
prompt += (
|
| 255 |
+
f"""{FROM_TOKEN_1}user\n{query}{FROM_TOKEN_2}\n{FROM_TOKEN_1}assistant\n"""
|
| 256 |
+
)
|
| 257 |
|
| 258 |
image_nums = len(image)
|
| 259 |
if image_nums == 1 and prompt.find("<ImageHere>") == -1:
|
|
|
|
| 590 |
shift_labels = labels[..., 1:].contiguous()
|
| 591 |
# Flatten the tokens
|
| 592 |
loss_fct = CrossEntropyLoss()
|
| 593 |
+
shift_logits = shift_logits.view(-1, self.vocab_size)
|
| 594 |
shift_labels = shift_labels.view(-1)
|
| 595 |
# Enable model parallelism
|
| 596 |
shift_labels = shift_labels.to(shift_logits.device)
|
|
|
|
| 679 |
):
|
| 680 |
prompt = ""
|
| 681 |
if meta_instruction:
|
| 682 |
+
prompt += f"""<s>{FROM_TOKEN_1}system\n{meta_instruction}{FROM_TOKEN_2}\n"""
|
| 683 |
else:
|
| 684 |
prompt += "<s>"
|
| 685 |
for record in history:
|
| 686 |
+
prompt += f"""{FROM_TOKEN_1}user\n{record[0]}{FROM_TOKEN_2}\n{FROM_TOKEN_1}assistant\n{record[1]}{FROM_TOKEN_2}\n"""
|
| 687 |
+
prompt += (
|
| 688 |
+
f"""{FROM_TOKEN_1}user\n{query}{FROM_TOKEN_2}\n{FROM_TOKEN_1}assistant\n"""
|
| 689 |
+
)
|
| 690 |
return tokenizer([prompt], return_tensors="pt")
|
| 691 |
|
| 692 |
@torch.no_grad()
|
|
|
|
| 729 |
# also add end-of-assistant token in eos token id to avoid unnecessary generation
|
| 730 |
eos_token_id = [
|
| 731 |
tokenizer.eos_token_id,
|
| 732 |
+
tokenizer.convert_tokens_to_ids([FROM_TOKEN_2])[0],
|
| 733 |
]
|
| 734 |
outputs = self.generate(
|
| 735 |
**inputs,
|
|
|
|
| 750 |
else:
|
| 751 |
outputs = outputs[0].cpu().tolist()
|
| 752 |
response = tokenizer.decode(outputs, skip_special_tokens=True)
|
| 753 |
+
response = response.split(FROM_TOKEN_2)[0]
|
| 754 |
history = history + [(query, response)]
|
| 755 |
return response, history
|
| 756 |
|
|
|
|
| 812 |
response = generate[0].tolist()
|
| 813 |
response = self.tokenizer.decode(response, skip_special_tokens=True) # type: ignore
|
| 814 |
# remove eoa
|
| 815 |
+
response = response.replace(FROM_TOKEN_2, "")
|
| 816 |
+
response = response.replace(FROM_TOKEN_1, "")
|
| 817 |
|
| 818 |
return response
|
| 819 |
|
|
|
|
| 852 |
response = generate[0].tolist()
|
| 853 |
response = self.tokenizer.decode(response, skip_special_tokens=True) # type: ignore
|
| 854 |
# remove eoa
|
| 855 |
+
response = response.replace(FROM_TOKEN_2, "")
|
| 856 |
+
out = response.replace(FROM_TOKEN_1, "")
|
| 857 |
image_type = "random"
|
| 858 |
pattern = r"""https://source\.unsplash\.com/random/(\d+)x(\d+)/\?([^'"]+)"""
|
| 859 |
if image_type == "placeholder":
|
|
|
|
| 905 |
response = generate[0].tolist()
|
| 906 |
response = self.tokenizer.decode(response, skip_special_tokens=True) # type: ignore
|
| 907 |
# remove eoa
|
| 908 |
+
response = response.replace(FROM_TOKEN_2, "")
|
| 909 |
+
html = response.replace(FROM_TOKEN_1, "")
|
| 910 |
|
| 911 |
if seed != -1:
|
| 912 |
set_random_seed(seed, set_cudnn=True)
|
|
|
|
| 928 |
response = generate[0].tolist()
|
| 929 |
response = self.tokenizer.decode(response, skip_special_tokens=True) # type: ignore
|
| 930 |
# remove eoa
|
| 931 |
+
response = response.replace(FROM_TOKEN_2, "")
|
| 932 |
+
js = response.replace(FROM_TOKEN_1, "")
|
| 933 |
|
| 934 |
if re.search(r"</script>", html):
|
| 935 |
js = re.findall(r"<script>([\s\S]*?)<\/script>", js)
|
|
|
|
| 988 |
response = generate[0].tolist()
|
| 989 |
response = self.tokenizer.decode(response, skip_special_tokens=True) # type: ignore
|
| 990 |
# remove eoa
|
| 991 |
+
response = response.replace(FROM_TOKEN_2, "")
|
| 992 |
+
out = response.replace(FROM_TOKEN_1, "")
|
| 993 |
image_type = "random"
|
| 994 |
pattern = r"""https://source\.unsplash\.com/random/(\d+)x(\d+)/\?([^'"]+)"""
|
| 995 |
if image_type == "placeholder":
|
|
|
|
| 1000 |
with open(task.replace(" ", "_") + ".html", "w") as f:
|
| 1001 |
f.write(out)
|
| 1002 |
return out
|
| 1003 |
+
|
| 1004 |
+
def add_tokens(self, new_tokens: list[str]):
    """Add ``new_tokens`` to the tokenizer and resize the model to match.

    The input embeddings are resized through
    ``self.model.resize_token_embeddings`` and the output projection
    (``self.output``) is replaced by a new bias-free ``nn.Linear`` with one
    row per tokenizer entry. Rows that exist in both the old and the new
    projection keep their trained weights; any additional rows keep the
    default ``nn.Linear`` initialisation.

    Args:
        new_tokens: Token strings to register with ``self.tokenizer``.

    Raises:
        RuntimeError: If ``self.tokenizer`` has not been set yet.
    """
    if self.tokenizer is None:
        raise RuntimeError(
            "self.tokenizer must be set before calling add_tokens()"
        )

    self.tokenizer.add_tokens(new_tokens)  # type: ignore
    new_vocab_size = len(self.tokenizer)
    self.model.resize_token_embeddings(new_vocab_size)
    self.vocab_size = new_vocab_size

    # self.output has to be resized as well, without losing its weights.
    old_weight = self.output.weight
    # Build the new head on the same device/dtype as the old one so the
    # row copy below never crosses devices (the head may not live on
    # self.device when the model is sharded across devices).
    new_output = nn.Linear(
        self.model.config.hidden_size,
        new_vocab_size,
        bias=False,
        dtype=old_weight.dtype,
        device=old_weight.device,
    )

    # Copy only the rows both layers share: the tokenizer can be smaller
    # than an already padded output layer, in which case a full-size copy
    # would fail with a shape mismatch.
    shared_rows = min(old_weight.shape[0], new_vocab_size)
    with torch.no_grad():
        new_output.weight[:shared_rows].copy_(old_weight[:shared_rows])

    # NOTE: no numerical "same behavior" check is done here. Probing two
    # bias-free linear layers with an all-zero input always yields zeros,
    # so such a check cannot detect anything; the row copy above is exact.
    self.output = new_output
|