Update modeling_deepseekocr.py

modeling_deepseekocr.py (+17 -11)
@@ -383,6 +383,7 @@ class DeepseekOCRModel(DeepseekV2Model):
         images_seq_mask: Optional[torch.FloatTensor] = None,
         images_spatial_crop: Optional[torch.FloatTensor] = None,
         return_dict: Optional[bool] = None,
+        verbose: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:


@@ -432,10 +433,11 @@ class DeepseekOCRModel(DeepseekV2Model):
                 global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                 global_features = self.projector(global_features)

-                print('=====================')
-                print('BASE: ', global_features.shape)
-                print('PATCHES: ', local_features.shape)
-                print('=====================')
+                if verbose:
+                    print('=====================')
+                    print('BASE: ', global_features.shape)
+                    print('PATCHES: ', local_features.shape)
+                    print('=====================')

                 _, hw, n_dim = global_features.shape
                 h = w = int(hw ** 0.5)

@@ -475,10 +477,12 @@ class DeepseekOCRModel(DeepseekV2Model):
                 global_features_2 = vision_model(image_ori, global_features_1)
                 global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                 global_features = self.projector(global_features)
-                print('=====================')
-                print('BASE: ', global_features.shape)
-                print('NO PATCHES')
-                print('=====================')
+
+                if verbose:
+                    print('=====================')
+                    print('BASE: ', global_features.shape)
+                    print('NO PATCHES')
+                    print('=====================')
                 _, hw, n_dim = global_features.shape
                 h = w = int(hw ** 0.5)
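The hunks above add a `verbose` flag to `DeepseekOCRModel.forward` and gate the previously unconditional shape-debug prints behind it. A minimal sketch of that pattern, using a simplified stand-in module rather than the actual DeepSeek-OCR code:

```python
import torch
from typing import Optional

class TinyVisionStub(torch.nn.Module):
    """Stand-in for the vision/projector pipeline; only the
    verbose-gated debug printing from the diff is reproduced."""

    def forward(self, x: torch.Tensor, verbose: Optional[bool] = None):
        feats = x.flatten(2).permute(0, 2, 1)  # (B, C, H, W) -> (B, HW, C), as in the diff
        if verbose:
            print('=====================')
            print('BASE: ', feats.shape)
            print('=====================')
        return feats

m = TinyVisionStub()
m(torch.randn(1, 16, 8, 8), verbose=True)   # prints the shape banner
m(torch.randn(1, 16, 8, 8), verbose=False)  # silent: the new opt-out
```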
@@ -700,7 +704,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):



-    def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False, streamer=None):
+    def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False, streamer=None, verbose=True):
         self.disable_torch_init()

         if len(output_path) > 0 :

@@ -926,7 +930,8 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                     streamer=streamer,
                     max_new_tokens=8192,
                     no_repeat_ngram_size = 20,
-                    use_cache = True
+                    use_cache = True,
+                    verbose = verbose
                 )

             else:

@@ -943,7 +948,8 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                     eos_token_id=tokenizer.eos_token_id,
                     max_new_tokens=8192,
                     no_repeat_ngram_size = 35,
-                    use_cache = True
+                    use_cache = True,
+                    verbose = verbose
                 )
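Taken together, the change threads one `verbose` switch from the public `infer` API down to the vision forward pass: `infer` forwards `verbose = verbose` into `self.generate(...)`, and transformers passes the extra keyword through as a model kwarg to `DeepseekOCRModel.forward`, where it gates the prints. A minimal usage sketch, assuming the model is loaded with `trust_remote_code` as in the model card (model id, prompt, and paths are illustrative placeholders):

```python
from transformers import AutoModel, AutoTokenizer

model_name = 'deepseek-ai/DeepSeek-OCR'  # assumed checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).eval()

# verbose=True (the default) keeps the old behavior and prints the
# BASE/PATCHES banners; verbose=False suppresses the debug output.
res = model.infer(
    tokenizer,
    prompt='<image>\nFree OCR. ',   # placeholder prompt
    image_file='example.png',       # placeholder path
    output_path='out/',
    base_size=1024,
    image_size=640,
    crop_mode=True,
    verbose=False,
)
```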