from io import BytesIO
from enum import auto, Enum
from typing import List, Optional, Tuple, Union
import ast
import dataclasses

import requests
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image

from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, StoppingCriteria, TextStreamer
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.cache_utils import Cache

from .got_vision_b import build_GOT_vit_b

# Special tokens used to mark the image region inside the text prompt.
DEFAULT_IMAGE_TOKEN = '<image>'
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'


class SeparatorStyle(Enum):
    """Different separator styles."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "<|im_end|>"
    sep2: Optional[str] = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            # One separator closes every message.
            ret = self.system + self.sep + '\n'
            for role, message in self.messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            # Two separators alternate between user and assistant turns.
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.MPT:
            # ChatML-style layout: the role string carries its own formatting,
            # so no ": " is inserted between role and message.
            if self.system:
                ret = self.system + self.sep
            else:
                ret = ''
            for role, message in self.messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2)


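# A minimal sketch (not part of the original file) of how the MPT separator
# style assembles a ChatML-style prompt; the values mirror the Conversation
# defaults used in chat() further below.
def _example_mpt_prompt():
    conv = Conversation(
        system="<|im_start|>system\nYou are a helpful assistant.",
        roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
        messages=[],
        offset=0,
        sep_style=SeparatorStyle.MPT,
        sep="<|im_end|>",
    )
    conv.append_message(conv.roles[0], "OCR: ")
    conv.append_message(conv.roles[1], None)
    # The result is: system block, the user turn closed by <|im_end|>, then an
    # open assistant header for the model to complete.
    return conv.get_prompt()

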
class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
        # Keep only keywords that encode to a single token; these can be
        # matched directly against the last generated token id.
        self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1]
        self.tokenizer = tokenizer
        self.start_len = None
        self.input_ids = input_ids

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if self.start_len is None:
            # First call: record the prompt length so only newly generated
            # tokens are decoded below.
            self.start_len = self.input_ids.shape[1]
        else:
            # Fast path: compare the last token id against single-token keywords.
            for keyword_id in self.keyword_ids:
                if output_ids[0, -1] == keyword_id:
                    return True
            # Slow path: decode the generated suffix and search for any keyword.
            outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
            for keyword in self.keywords:
                if keyword in outputs:
                    return True
        return False


class GOTImageEvalProcessor:
    def __init__(self, image_size=384, mean=None, std=None):
        # Defaults are the CLIP normalization statistics.
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        self.normalize = transforms.Normalize(mean, std)

        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)


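# A minimal usage sketch (the file name is hypothetical): the processor maps a
# PIL image to a normalized (3, image_size, image_size) float tensor ready for
# the vision tower.
#
#     processor = GOTImageEvalProcessor(image_size=1024)
#     tensor = processor(Image.open('page.png').convert('RGB'))  # (3, 1024, 1024)

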
class GOTConfig(Qwen2Config):
    model_type = "GOT"


class GOTQwenModel(Qwen2Model):
    config_class = GOTConfig

    def __init__(self, config: Qwen2Config):
        super(GOTQwenModel, self).__init__(config)

        # High-resolution vision tower (a ViT-B encoder, per build_GOT_vit_b).
        self.vision_tower_high = build_GOT_vit_b()

        # Linear projector from vision features into the LLM embedding space.
        self.mm_projector_vary = nn.Linear(1024, 1024)

    def initialize_vision_modules(
        self,
        vision_tower,
        pretrained_stage1_model=None,
        freeze_vision_tower=False,
        use_im_start_end=False,
        vision_select_layer=-1,
        dtype=torch.float16,
        device="cuda"
    ):
        image_processor_high = GOTImageEvalProcessor(image_size=1024)

        self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)

        self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device)

        # Each 1024x1024 image is encoded into 256 feature tokens.
        image_token_len = 256

        self.config.vision_tower = vision_tower
        self.config.image_token_len = image_token_len

        # Note: image start/end markers are always enabled here, regardless of
        # the use_im_start_end argument.
        self.config.use_im_start_end = True

        self.config.vision_select_layer = vision_select_layer
        self.config.freeze_vision_tower = freeze_vision_tower

        return dict(
            image_processor_high=image_processor_high,
            image_token_len=image_token_len,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:

        # If the embedding matrix was extended with new image tokens, restore
        # the original (frozen) rows so that only the new rows are trained.
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)
        if orig_embeds_params is not None:
            with torch.no_grad():
                self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        vision_tower_high = getattr(self, 'vision_tower_high', None)

        # Run the vision branch only on the prefill pass (sequence length > 1)
        # or during training, and only when images are provided.
        if vision_tower_high is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
            use_im_start_end = getattr(self.config, "use_im_start_end", -1)

            vision_select_layer = getattr(self.config, "vision_select_layer", -1)
            im_patch_token = getattr(self.config, "im_patch_token", -1)
            im_start_token = getattr(self.config, "im_start_token", -1)
            im_end_token = getattr(self.config, "im_end_token", -1)
            freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False)

            # The config values above are overridden with the fixed token ids
            # of the GOT tokenizer (<img> = 151857, </img> = 151858,
            # <imgpad> = 151859).
            im_patch_token = 151859
            im_start_token = 151857
            im_end_token = 151858

            image_features = []

            for image in images:
                P, C, H, W = image.shape
                if P == 1:
                    # Single image: encode with the frozen vision tower and
                    # project into the LLM embedding space.
                    with torch.set_grad_enabled(False):
                        cnn_feature = vision_tower_high(image)
                        cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1)  # (1, 256, 1024)
                    image_feature = self.mm_projector_vary(cnn_feature)
                    image_features.append(image_feature)

                else:
                    # Multi-patch input (multi-crop mode): encode each patch
                    # separately, then concatenate along the token dimension.
                    image_patches = torch.unbind(image)
                    image_patches_features = []
                    for image_patch in image_patches:
                        image_p = torch.stack([image_patch])

                        with torch.set_grad_enabled(False):
                            cnn_feature_p = vision_tower_high(image_p)
                            cnn_feature_p = cnn_feature_p.flatten(2).permute(0, 2, 1)
                        image_feature_p = self.mm_projector_vary(cnn_feature_p)
                        image_patches_features.append(image_feature_p)
                    image_feature = torch.cat(image_patches_features, dim=1)
                    image_features.append(image_feature)

            dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
            dummy_image_features = dummy_image_features_2
            use_im_start_end = True
            new_input_embeds = []
            for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
                if (cur_input_ids == im_patch_token).sum() == 0:
                    # No image tokens in this sequence: add a zero-valued term
                    # so the projector still participates in the graph.
                    cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
                    new_input_embeds.append(cur_input_embeds)
                    continue

                if use_im_start_end:
                    if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
                        raise ValueError("The number of image start tokens and image end tokens should be the same.")

                    image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
                    for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
                        per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
                        num_patches = per_cur_image_features.shape[0]

                        if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
                            raise ValueError("The image end token should follow the image start token.")

                        # Splice the projected image features over the
                        # <imgpad> placeholder span between <img> and </img>.
                        cur_input_embeds = torch.cat(
                            (
                                cur_input_embeds[:image_start_token_pos+1],
                                per_cur_image_features,
                                cur_input_embeds[image_start_token_pos + num_patches + 1:]
                            ),
                            dim=0
                        )

                    new_input_embeds.append(cur_input_embeds)
                else:
                    raise NotImplementedError

            inputs_embeds = torch.stack(new_input_embeds, dim=0)

        return super(GOTQwenModel, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids=position_ids,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )


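# Prompt layout consumed by GOTQwenModel.forward above (a sketch; the token
# count follows image_token_len = 256):
#
#     <img> <imgpad> * 256 </img> \n OCR:
#
# The 256 <imgpad> embeddings are replaced in place by the projected vision
# features, so the LLM attends to image content through ordinary positions.

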
class GOTQwenForCausalLM(Qwen2ForCausalLM):
    config_class = GOTConfig

    def __init__(self, config):
        super(Qwen2ForCausalLM, self).__init__(config)
        self.model = GOTQwenModel(config)

        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            images=images,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        # Compute in float32 for numerical stability of the loss.
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens.
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Move labels to the logits device to enable model parallelism.
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

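    # A minimal sketch (illustrative only, with a toy vocabulary of 8) of the
    # shifted next-token loss computed above: logits at position t are scored
    # against the label at position t + 1.
    #
    #     logits = torch.randn(1, 5, 8)          # (batch, seq, vocab)
    #     labels = torch.randint(0, 8, (1, 5))   # (batch, seq)
    #     shift_logits = logits[..., :-1, :].contiguous().view(-1, 8)
    #     shift_labels = labels[..., 1:].contiguous().view(-1)
    #     loss = CrossEntropyLoss()(shift_logits, shift_labels)
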
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                # Legacy tuple cache: infer lengths from the key tensor shape.
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the attention_mask is longer than input_ids, some inputs
            # were passed exclusively via the cache (e.g. as input_embeds).
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If past_length is smaller than input_ids' length, input_ids
            # holds all tokens; discard those already covered by the cache.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise assume input_ids holds only unprocessed tokens.

            # If we are about to exceed the maximum cache length, crop the
            # attention mask accordingly.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Create position_ids on the fly for batch generation.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # If inputs_embeds are passed, use them only for the first generation step.
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs

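    # Generation flow (a sketch of how the pieces above interact): on the
    # first step, prepare_inputs_for_generation passes the full prompt
    # (including the 256 <imgpad> placeholders) together with `images`, so
    # forward() splices in the vision features; on subsequent steps only the
    # newly generated token is passed and the vision branch is skipped
    # (input_ids.shape[1] == 1).
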
    def initialize_vision_tokenizer(
        self,
        tokenizer,
        freeze_lm_model=False,
        pretrained_stage1_model=None,
        device="cuda"
    ):
        config = self.get_model().config

        self.resize_token_embeddings(len(tokenizer))

        # Hard-coded ids of the image special tokens in the GOT tokenizer.
        config.im_patch_token = 151859

        config.use_im_start_end = True

        if config.use_im_start_end:
            self.resize_token_embeddings(len(tokenizer))
            config.im_start_token, config.im_end_token = 151857, 151858

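    # The hard-coded ids above are assumed to match the GOT tokenizer's
    # extended vocabulary. A defensive variant would derive them from the
    # tokenizer instead:
    #
    #     config.im_patch_token = tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_PATCH_TOKEN)
    #     config.im_start_token = tokenizer.convert_tokens_to_ids(DEFAULT_IM_START_TOKEN)
    #     config.im_end_token = tokenizer.convert_tokens_to_ids(DEFAULT_IM_END_TOKEN)
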
    def load_image(self, image_file):
        if image_file.startswith(('http://', 'https://')):
            response = requests.get(image_file)
            image = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            image = Image.open(image_file).convert('RGB')
        return image

    def disable_torch_init(self):
        """
        Disable the redundant torch default initialization to accelerate model creation.
        """
        setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

    def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag=False):

        self.disable_torch_init()

        image_processor_high = GOTImageEvalProcessor(image_size=1024)

        use_im_start_end = True

        image_token_len = 256

        if gradio_input:
            # Gradio already supplies a PIL image.
            image = image_file.copy()
        else:
            image = self.load_image(image_file)

        w, h = image.size

        if ocr_type == 'format':
            qs = 'OCR with format: '
        else:
            qs = 'OCR: '

        if ocr_box:
            # Parse the box literal and rescale its coordinates to a 0-1000
            # range. ast.literal_eval replaces the original eval() so that
            # arbitrary input is never executed.
            bbox = ast.literal_eval(ocr_box)
            if len(bbox) == 2:
                bbox[0] = int(bbox[0] / w * 1000)
                bbox[1] = int(bbox[1] / h * 1000)
            if len(bbox) == 4:
                bbox[0] = int(bbox[0] / w * 1000)
                bbox[1] = int(bbox[1] / h * 1000)
                bbox[2] = int(bbox[2] / w * 1000)
                bbox[3] = int(bbox[3] / h * 1000)
            if ocr_type == 'format':
                qs = str(bbox) + ' ' + 'OCR with format: '
            else:
                qs = str(bbox) + ' ' + 'OCR: '

        if ocr_color:
            if ocr_type == 'format':
                qs = '[' + ocr_color + ']' + ' ' + 'OCR with format: '
            else:
                qs = '[' + ocr_color + ']' + ' ' + 'OCR: '

        if use_im_start_end:
            # Reserve image_token_len placeholder positions between the image
            # markers; forward() replaces them with projected vision features.
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        conv_mpt = Conversation(
            system="""<|im_start|>system
        You should follow the instructions carefully and explain your answers in detail.""",
            roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
            version="mpt",
            messages=(),
            offset=0,
            sep_style=SeparatorStyle.MPT,
            sep="<|im_end|>",
        )

        conv = conv_mpt.copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        if print_prompt:
            print(prompt)

        inputs = tokenizer([prompt])

        image_tensor_1 = image_processor_high(image)

        input_ids = torch.as_tensor(inputs.input_ids).cuda()

        # Stop generation at the conversation separator (<|im_end|>).
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        # Greedy decoding; the streamer is attached only in streaming mode.
        # The two original branches differed only in the streamer argument.
        with torch.autocast("cuda", dtype=torch.bfloat16):
            output_ids = self.generate(
                input_ids,
                images=[image_tensor_1.unsqueeze(0).half().cuda()],
                do_sample=False,
                num_beams=1,
                no_repeat_ngram_size=20,
                streamer=streamer if stream_flag else None,
                max_new_tokens=4096,
                stopping_criteria=[stopping_criteria]
            )

        # Decode only the newly generated tokens and strip the stop string.
        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()

        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()
        response_str = outputs

        if render:
            print('==============rendering===============')
            from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table

            if '**kern' in outputs:
                # Humdrum **kern music notation: render to SVG with verovio.
                import verovio
                tk = verovio.toolkit()
                tk.loadData(outputs)
                tk.setOptions({"pageWidth": 2100, "footer": 'none',
                               'barLineWidth': 0.5, 'beamMaxSlope': 15,
                               'staffLineWidth': 0.2, 'spacingStaff': 6})
                tk.getPageCount()
                svg = tk.renderToSVG()
                svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")

                svg_to_html(svg, save_render_file)

            if ocr_type == 'format' and '**kern' not in outputs:

                if '\\begin{tikzpicture}' not in outputs:
                    # Markdown/LaTeX output: embed it into the MMD HTML template.
                    html_path_2 = save_render_file
                    right_num = outputs.count('\\right')
                    left_num = outputs.count('\\left')

                    if right_num != left_num:
                        # Unbalanced \left/\right pairs break rendering; strip them.
                        outputs = outputs.replace('\\left(', '(').replace('\\right)', ')').replace('\\left[', '[').replace('\\right]', ']').replace('\\left{', '{').replace('\\right}', '}').replace('\\left|', '|').replace('\\right|', '|').replace('\\left.', '.').replace('\\right.', '.')

                    outputs = outputs.replace('"', '``').replace('$', '')

                    # Re-emit each line as a quoted JS string literal joined by '+'.
                    outputs_list = outputs.split('\n')
                    gt = ''
                    for out in outputs_list:
                        gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'

                    gt = gt[:-2]

                    lines = content_mmd_to_html
                    lines = lines.split("const text =")
                    new_web = lines[0] + 'const text =' + gt + lines[1]

						                else: | 
					
					
						
						| 
							 | 
						                    html_path_2 = save_render_file | 
					
					
						
						| 
							 | 
						                    outputs = outputs.translate(translation_table) | 
					
					
						
						| 
							 | 
						                    outputs_list = outputs.split('\n') | 
					
					
						
						| 
							 | 
						                    gt= '' | 
					
					
						
						| 
							 | 
						                    for out in outputs_list: | 
					
					
						
						| 
							 | 
						                        if out: | 
					
					
						
						| 
							 | 
						                            if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out: | 
					
					
						
						| 
							 | 
						                                while out[-1] == ' ': | 
					
					
						
						| 
							 | 
						                                    out = out[:-1] | 
					
					
						
						| 
							 | 
						                                    if out is None: | 
					
					
						
						| 
							 | 
						                                        break | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						                                if out: | 
					
					
						
						| 
							 | 
						                                    if out[-1] != ';': | 
					
					
						
						| 
							 | 
						                                        gt += out[:-1] + ';\n' | 
					
					
						
						| 
							 | 
						                                    else: | 
					
					
						
						| 
							 | 
						                                        gt += out + '\n' | 
					
					
						
						| 
							 | 
						                            else: | 
					
					
						
						| 
							 | 
						                                gt += out + '\n' | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                    lines = tik_html | 
					
					
						
						| 
							 | 
						                    lines = lines.split("const text =") | 
					
					
						
						| 
							 | 
						                    new_web = lines[0] + gt + lines[1] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                with open(html_path_2, 'w') as web_f_new: | 
					
					
						
						| 
							 | 
						                    web_f_new.write(new_web) | 
					
					
						
						| 
							 | 
						        return response_str | 
					
					
						


    def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
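        """Tile `image` into square `image_size` x `image_size` patches.

        Chooses the (cols, rows) grid with between `min_num` and `max_num`
        tiles whose aspect ratio is closest to the input's, resizes the image
        to fill that grid exactly, and crops out each cell. When more than one
        tile is produced and `use_thumbnail` is set, a resized copy of the
        whole image is appended as a global view.
        """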
					
					
						
        def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
            best_ratio_diff = float('inf')
            best_ratio = (1, 1)
            area = width * height
            for ratio in target_ratios:
                target_aspect_ratio = ratio[0] / ratio[1]
                ratio_diff = abs(aspect_ratio - target_aspect_ratio)
                if ratio_diff < best_ratio_diff:
                    best_ratio_diff = ratio_diff
                    best_ratio = ratio
                elif ratio_diff == best_ratio_diff:
                    # On a tie, prefer the larger grid when the image fills at
                    # least half of its pixel budget.
                    if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                        best_ratio = ratio
            return best_ratio

        orig_width, orig_height = image.size
        aspect_ratio = orig_width / orig_height

        target_ratios = set(
            (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
            if min_num <= i * j <= max_num)
        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
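        # With the defaults (min_num=1, max_num=6) the candidates are all
        # (cols, rows) grids of one to six tiles, e.g. (1, 1), (1, 2), (2, 1),
        # (2, 2), (2, 3), (1, 6), sorted by tile count.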
					
					
						

        # Pick the grid whose aspect ratio best matches the input image.
        target_aspect_ratio = find_closest_aspect_ratio(
            aspect_ratio, target_ratios, orig_width, orig_height, image_size)

        target_width = image_size * target_aspect_ratio[0]
        target_height = image_size * target_aspect_ratio[1]
        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

        resized_img = image.resize((target_width, target_height))
        processed_images = []
        for i in range(blocks):
            # Crop tile i out of the grid in row-major order.
            box = (
                (i % (target_width // image_size)) * image_size,
                (i // (target_width // image_size)) * image_size,
                ((i % (target_width // image_size)) + 1) * image_size,
                ((i // (target_width // image_size)) + 1) * image_size
            )
            split_img = resized_img.crop(box)
            processed_images.append(split_img)
        assert len(processed_images) == blocks
        if use_thumbnail and len(processed_images) != 1:
            thumbnail_img = image.resize((image_size, image_size))
            processed_images.append(thumbnail_img)
        return processed_images
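
    # Example: a 2048x1024 input has aspect ratio 2.0, so the (2, 1) grid wins;
    # the call returns two 1024x1024 tiles plus the global thumbnail.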
					
					
						


    def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag=False):
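        """OCR an image by tiling it into crops and decoding them in one pass.

        Returns the generated text; when `render` is set, an HTML preview is
        also written to `save_render_file`.
        """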
					
					
						
        self.disable_torch_init()
        multi_page = False

        image_processor_high = GOTImageEvalProcessor(image_size=1024)

        use_im_start_end = True
        image_token_len = 256

        image_list = []

        if multi_page:
            qs = 'OCR with format across multi pages: '
            patches = image_file
            sub_images = []
            for sub_image in patches:
                sub_images.append(self.load_image(sub_image))
            ll = len(patches)

        else:
            if ocr_type == 'format':
                qs = 'OCR with format upon the patch reference: '
            else:
                qs = 'OCR upon the patch reference: '
            if gradio_input:
                img = image_file.copy()
            else:
                img = self.load_image(image_file)
            sub_images = self.dynamic_preprocess(img)
            ll = len(sub_images)

        for image in sub_images:
            image_tensor_1 = image_processor_high(image)
            image_list.append(image_tensor_1)

        image_list = torch.stack(image_list)
        print('====new images batch size======:  \n', image_list.shape)
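
        # Each crop contributes image_token_len (256) <imgpad> placeholders
        # between <img> and </img>, so the prompt reserves ll * 256 positions
        # into which the vision features are spliced during the forward pass.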
					
					
						
        if use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len * ll + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        conv_mpt = Conversation(
            system="""<|im_start|>system
        You should follow the instructions carefully and explain your answers in detail.""",
            roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
            version="mpt",
            messages=(),
            offset=0,
            sep_style=SeparatorStyle.MPT,
            sep="<|im_end|>",
        )

        conv = conv_mpt.copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        if print_prompt:
            print(prompt)

        inputs = tokenizer([prompt])
        input_ids = torch.as_tensor(inputs.input_ids).cuda()

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
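
        # Decoding halts when the generated tail matches the conversation
        # separator ('<|im_end|>') caught by KeywordsStoppingCriteria, or
        # after max_new_tokens.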
					
					
						
        # Greedy decoding (no sampling, single beam); the streamer is attached
        # only when streaming output was requested.
        gen_kwargs = dict(
            images=[image_list.half().cuda()],
            do_sample=False,
            num_beams=1,
            max_new_tokens=4096,
            stopping_criteria=[stopping_criteria],
        )
        if stream_flag:
            gen_kwargs['streamer'] = streamer
        with torch.autocast("cuda", dtype=torch.bfloat16):
            output_ids = self.generate(input_ids, **gen_kwargs)

        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()

        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()
        response_str = outputs

        if render:
            print('==============rendering===============')
            from .render_tools import content_mmd_to_html
            html_path_2 = save_render_file
            right_num = outputs.count('\\right')
            left_num = outputs.count('\\left')

            if right_num != left_num:
                # Unbalanced \left/\right pairs break MathJax, so strip the
                # size modifiers and keep the bare delimiters.
                outputs = outputs.replace('\\left(', '(').replace('\\right)', ')').replace('\\left[', '[').replace('\\right]', ']').replace('\\left{', '{').replace('\\right}', '}').replace('\\left|', '|').replace('\\right|', '|').replace('\\left.', '.').replace('\\right.', '.')

            outputs = outputs.replace('"', '``').replace('$', '')

            outputs_list = outputs.split('\n')
            gt = ''
            for out in outputs_list:
                gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'

            gt = gt[:-2]

            lines = content_mmd_to_html
            lines = lines.split("const text =")
            new_web = lines[0] + 'const text =' + gt + lines[1]

            with open(html_path_2, 'w') as web_f_new:
                web_f_new.write(new_web)

        return response_str
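

# Usage sketch (illustrative; the repo id and file paths are assumptions --
# adapt them to your checkpoint). Mirrors the chat_crop signature above:
#
#   from transformers import AutoModel, AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
#   model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True).eval().cuda()
#   text = model.chat_crop(tokenizer, 'page.png', ocr_type='format',
#                          render=True, save_render_file='./page.html')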