weege007 committed
Commit 0a5453e · verified · 1 Parent(s): 21b7933

Update modeling_deepseekocr.py

Files changed (1)
  1. modeling_deepseekocr.py  +17 -11
modeling_deepseekocr.py CHANGED

@@ -383,6 +383,7 @@ class DeepseekOCRModel(DeepseekV2Model):
         images_seq_mask: Optional[torch.FloatTensor] = None,
         images_spatial_crop: Optional[torch.FloatTensor] = None,
         return_dict: Optional[bool] = None,
+        verbose: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
 
 
@@ -432,10 +433,11 @@ class DeepseekOCRModel(DeepseekV2Model):
                 global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                 global_features = self.projector(global_features)
 
-                print('=====================')
-                print('BASE: ', global_features.shape)
-                print('PATCHES: ', local_features.shape)
-                print('=====================')
+                if verbose:
+                    print('=====================')
+                    print('BASE: ', global_features.shape)
+                    print('PATCHES: ', local_features.shape)
+                    print('=====================')
 
                 _, hw, n_dim = global_features.shape
                 h = w = int(hw ** 0.5)
@@ -475,10 +477,12 @@ class DeepseekOCRModel(DeepseekV2Model):
                 global_features_2 = vision_model(image_ori, global_features_1)
                 global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                 global_features = self.projector(global_features)
-                print('=====================')
-                print('BASE: ', global_features.shape)
-                print('NO PATCHES')
-                print('=====================')
+
+                if verbose:
+                    print('=====================')
+                    print('BASE: ', global_features.shape)
+                    print('NO PATCHES')
+                    print('=====================')
                 _, hw, n_dim = global_features.shape
                 h = w = int(hw ** 0.5)
 
@@ -700,7 +704,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
 
 
-    def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False, streamer=None):
+    def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False, streamer=None, verbose=True):
         self.disable_torch_init()
 
         if len(output_path) > 0 :
@@ -926,7 +930,8 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                     streamer=streamer,
                     max_new_tokens=8192,
                     no_repeat_ngram_size = 20,
-                    use_cache = True
+                    use_cache = True,
+                    verbose = verbose
                     )
 
             else:
@@ -943,7 +948,8 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                     eos_token_id=tokenizer.eos_token_id,
                     max_new_tokens=8192,
                     no_repeat_ngram_size = 35,
-                    use_cache = True
+                    use_cache = True,
+                    verbose = verbose
                     )
 
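
For context, a minimal usage sketch of the change (not part of this commit): the new `verbose` flag on `infer` is forwarded through `self.generate(...)` into the model's `forward`, so `verbose=False` suppresses the BASE/PATCHES shape printouts that were previously emitted unconditionally, while the default `verbose=True` keeps the old behaviour. The repo id, prompt, and image path below are illustrative assumptions, not values taken from this diff.

import torch
from transformers import AutoModel, AutoTokenizer

# Assumed repo id and inputs -- adjust to your setup.
model_name = 'deepseek-ai/DeepSeek-OCR'
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, use_safetensors=True)
model = model.eval().cuda().to(torch.bfloat16)

# verbose=False silences the debug prints; verbose=True (the default)
# reproduces the pre-commit behaviour of printing the feature shapes.
res = model.infer(
    tokenizer,
    prompt='<image>\nFree OCR.',      # example prompt, not from this diff
    image_file='your_image.jpg',      # example path
    output_path='outputs/',
    base_size=1024,
    image_size=640,
    crop_mode=True,
    save_results=True,
    verbose=False,
)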