KpLBaTMaN committed
Commit · 90aefec
1 Parent(s): 8e8b282

code

modeling_GOT.py CHANGED (+66 -59)
@@ -15,6 +15,7 @@ import dataclasses
 import numpy as np
 import cv2
 from io import BytesIO
+import contextlib
 ###
 
 DEFAULT_IMAGE_TOKEN = "<image>"
@@ -501,15 +502,24 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
         setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
 
-    def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag=False):
-
+    def chat(
+        self,
+        tokenizer,
+        image_file,
+        ocr_type,
+        ocr_box='',
+        ocr_color='',
+        render=False,
+        save_render_file=None,
+        print_prompt=False,
+        gradio_input=False,
+        stream_flag=False,
+        device="cuda"  # new parameter to specify the device
+    ):
         self.disable_torch_init()
 
-
-        image_processor_high = GOTImageEvalProcessor(image_size=1024)
-
+        image_processor_high = GOTImageEvalProcessor(image_size=1024)
         use_im_start_end = True
-
         image_token_len = 256
 
         if gradio_input:
@@ -518,7 +528,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
             image = self.load_image(image_file)
 
         w, h = image.size
-
+
         if ocr_type == 'format':
             qs = 'OCR with format: '
         else:
@@ -527,13 +537,13 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         if ocr_box:
            bbox = eval(ocr_box)
            if len(bbox) == 2:
-                bbox[0] = int(bbox[0]/w*1000)
-                bbox[1] = int(bbox[1]/h*1000)
+                bbox[0] = int(bbox[0] / w * 1000)
+                bbox[1] = int(bbox[1] / h * 1000)
            if len(bbox) == 4:
-                bbox[0] = int(bbox[0]/w*1000)
-                bbox[1] = int(bbox[1]/h*1000)
-                bbox[2] = int(bbox[2]/w*1000)
-                bbox[3] = int(bbox[3]/h*1000)
+                bbox[0] = int(bbox[0] / w * 1000)
+                bbox[1] = int(bbox[1] / h * 1000)
+                bbox[2] = int(bbox[2] / w * 1000)
+                bbox[3] = int(bbox[3] / h * 1000)
            if ocr_type == 'format':
                qs = str(bbox) + ' ' + 'OCR with format: '
            else:
@@ -546,15 +556,13 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                qs = '[' + ocr_color + ']' + ' ' + 'OCR: '
 
        if use_im_start_end:
-            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
+            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
 
-
        conv_mpt = Conversation(
            system="""<|im_start|>system
-        You should follow the instructions carefully and explain your answers in detail.""",
-            # system = None,
+        You should follow the instructions carefully and explain your answers in detail.""",
            roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
            version="mpt",
            messages=(),
@@ -572,43 +580,47 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
            print(prompt)
 
        inputs = tokenizer([prompt])
-
-        image_tensor_1 = image_processor_high(image)
-
-        input_ids = torch.as_tensor(inputs.input_ids).cuda()
+        input_ids = torch.as_tensor(inputs.input_ids).to(device)
 
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
+        image_tensor_1 = image_processor_high(image)
+
+        # Use autocast only when on CUDA, otherwise use a null context for CPU
+        if device == "cuda":
+            autocast_context = torch.autocast("cuda", dtype=torch.bfloat16)
+        else:
+            autocast_context = contextlib.nullcontext()
+
        if stream_flag:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with autocast_context:
                output_ids = self.generate(
                    input_ids,
-                    images=[image_tensor_1.unsqueeze(0).half().cuda()],
+                    images=[image_tensor_1.unsqueeze(0).half().to(device)],
                    do_sample=False,
-                    num_beams = 1,
-                    no_repeat_ngram_size = 20,
+                    num_beams=1,
+                    no_repeat_ngram_size=20,
                    streamer=streamer,
                    max_new_tokens=4096,
                    stopping_criteria=[stopping_criteria]
-                    )
+                )
        else:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with autocast_context:
                output_ids = self.generate(
                    input_ids,
-                    images=[image_tensor_1.unsqueeze(0).half().cuda()],
+                    images=[image_tensor_1.unsqueeze(0).half().to(device)],
                    do_sample=False,
-                    num_beams = 1,
-                    no_repeat_ngram_size = 20,
-                    # streamer=streamer,
+                    num_beams=1,
+                    no_repeat_ngram_size=20,
                    max_new_tokens=4096,
                    stopping_criteria=[stopping_criteria]
-                    )
-
+                )
+
        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
-
+
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()
@@ -622,46 +634,44 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                import verovio
                tk = verovio.toolkit()
                tk.loadData(outputs)
-                tk.setOptions({"pageWidth": 2100, "footer": 'none',
-                               'barLineWidth': 0.5, 'beamMaxSlope': 15,
-                               'staffLineWidth': 0.2, 'spacingStaff': 6})
+                tk.setOptions({
+                    "pageWidth": 2100,
+                    "footer": 'none',
+                    'barLineWidth': 0.5,
+                    'beamMaxSlope': 15,
+                    'staffLineWidth': 0.2,
+                    'spacingStaff': 6
+                })
                tk.getPageCount()
                svg = tk.renderToSVG()
                svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
-
                svg_to_html(svg, save_render_file)
 
            if ocr_type == 'format' and '**kern' not in outputs:
-
-
-                if '\\begin{tikzpicture}' not in outputs:
+                if '\\begin{tikzpicture}' not in outputs:
                    html_path_2 = save_render_file
                    right_num = outputs.count('\\right')
-                    left_num = outputs.count('\left')
-
+                    left_num = outputs.count('\\left')
                    if right_num != left_num:
-                        outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
-
-
+                        outputs = outputs.replace('\left(', '(').replace('\\right)', ')')\
+                            .replace('\left[', '[').replace('\\right]', ']')\
+                            .replace('\left{', '{').replace('\\right}', '}')\
+                            .replace('\left|', '|').replace('\\right|', '|')\
+                            .replace('\left.', '.').replace('\\right.', '.')
                    outputs = outputs.replace('"', '``').replace('$', '')
-
                    outputs_list = outputs.split('\n')
-                    gt= ''
+                    gt = ''
                    for out in outputs_list:
-                        gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
-
+                        gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
                    gt = gt[:-2]
-
-
                    lines = content_mmd_to_html
                    lines = lines.split("const text =")
-                    new_web = lines[0] + 'const text =' + gt + lines[1]
-
+                    new_web = lines[0] + 'const text =' + gt + lines[1]
                else:
                    html_path_2 = save_render_file
                    outputs = outputs.translate(translation_table)
                    outputs_list = outputs.split('\n')
-                    gt= ''
+                    gt = ''
                    for out in outputs_list:
                        if out:
                            if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
@@ -669,7 +679,6 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                                    out = out[:-1]
                                    if out is None:
                                        break
-
                                if out:
                                    if out[-1] != ';':
                                        gt += out[:-1] + ';\n'
@@ -677,14 +686,12 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                                        gt += out + '\n'
                            else:
                                gt += out + '\n'
-
-
                lines = tik_html
                lines = lines.split("const text =")
                new_web = lines[0] + gt + lines[1]
-
                with open(html_path_2, 'w') as web_f_new:
                    web_f_new.write(new_web)
+
        return response_str
 
    def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
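
A minimal caller-side sketch of the updated chat signature, assuming the checkpoint is loaded with trust_remote_code=True; the repo id and image path below are placeholders, not part of this commit:

# Hypothetical usage sketch; repo id and image path are illustrative only.
import torch
from transformers import AutoModel, AutoTokenizer

repo = "your-namespace/GOT-OCR2_0"  # placeholder: swap in the actual model repo
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True, low_cpu_mem_usage=True)
model = model.eval().to(device)

# chat() now accepts a device argument; on CPU the patched code wraps generate()
# in contextlib.nullcontext() instead of CUDA bfloat16 autocast.
text = model.chat(tokenizer, "sample_page.jpg", ocr_type="ocr", device=device)
print(text)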