lvyufeng committed
Commit 195cfa9 · verified · 1 Parent(s): 0a34c88

Upload folder using huggingface_hub

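The commit message indicates these files were pushed with the `huggingface_hub` upload API. A minimal sketch of how such a commit is typically produced (the local folder path and repo id below are placeholders, not taken from this page):

```python
from huggingface_hub import HfApi

api = HfApi()  # assumes a token from `huggingface-cli login` or the HF_TOKEN env var
api.upload_folder(
    folder_path="./DeepSeek-OCR",      # placeholder: local directory with the files below
    repo_id="lvyufeng/DeepSeek-OCR",   # placeholder: target model repository
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)
```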
.gitattributes CHANGED
@@ -1,35 +1,54 @@
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
  *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
  *.ftz filter=lfs diff=lfs merge=lfs -text
  *.gz filter=lfs diff=lfs merge=lfs -text
  *.h5 filter=lfs diff=lfs merge=lfs -text
  *.joblib filter=lfs diff=lfs merge=lfs -text
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
  *.model filter=lfs diff=lfs merge=lfs -text
  *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
  *.onnx filter=lfs diff=lfs merge=lfs -text
  *.ot filter=lfs diff=lfs merge=lfs -text
  *.parquet filter=lfs diff=lfs merge=lfs -text
  *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
  *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *.tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.db* filter=lfs diff=lfs merge=lfs -text
+ *.ark* filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
+
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.gguf* filter=lfs diff=lfs merge=lfs -text
+ *.ggml filter=lfs diff=lfs merge=lfs -text
+ *.llamafile* filter=lfs diff=lfs merge=lfs -text
+ *.pt2 filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+
+ model-00001-of-000001.safetensors filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ assets/fig1.png filter=lfs diff=lfs merge=lfs -text
+ assets/show1.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/show2.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/show3.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/show4.jpg filter=lfs diff=lfs merge=lfs -text
.ipynb_checkpoints/README-checkpoint.md ADDED
@@ -0,0 +1,122 @@
+ ---
+ pipeline_tag: image-text-to-text
+ language:
+ - multilingual
+ tags:
+ - deepseek
+ - vision-language
+ - ocr
+ - custom_code
+ license: mit
+ ---
+ <div align="center">
+ <img src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/logo.svg?raw=true" width="60%" alt="DeepSeek AI" />
+ </div>
+ <hr>
+ <div align="center">
+ <a href="https://www.deepseek.com/" target="_blank">
+ <img alt="Homepage" src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/badge.svg?raw=true" />
+ </a>
+ <a href="https://huggingface.co/deepseek-ai/DeepSeek-OCR" target="_blank">
+ <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DeepSeek%20AI-ffc107?color=ffc107&logoColor=white" />
+ </a>
+
+ </div>
+
+ <div align="center">
+
+ <a href="https://discord.gg/Tc7c45Zzu5" target="_blank">
+ <img alt="Discord" src="https://img.shields.io/badge/Discord-DeepSeek%20AI-7289da?logo=discord&logoColor=white&color=7289da" />
+ </a>
+ <a href="https://twitter.com/deepseek_ai" target="_blank">
+ <img alt="Twitter Follow" src="https://img.shields.io/badge/Twitter-deepseek_ai-white?logo=x&logoColor=white" />
+ </a>
+
+ </div>
+
+
+
+ <p align="center">
+ <a href="https://github.com/deepseek-ai/DeepSeek-OCR"><b>🌟 Github</b></a> |
+ <a href="https://huggingface.co/deepseek-ai/DeepSeek-OCR"><b>📥 Model Download</b></a> |
+ <a href="https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek_OCR_paper.pdf"><b>📄 Paper Link</b></a> |
+ <a href=""><b>📄 Arxiv Paper Link</b></a> |
+ </p>
+ <h2>
+ <p align="center">
+ <a href="">DeepSeek-OCR: Contexts Optical Compression</a>
+ </p>
+ </h2>
+ <p align="center">
+ <img src="assets/fig1.png" style="width: 1000px" align=center>
+ </p>
+ <p align="center">
+ <a href="">Explore the boundaries of visual-text compression.</a>
+ </p>
+
+ ## Usage
+ Inference using Hugging Face transformers on NVIDIA GPUs. Requirements tested on Python 3.12.9 + CUDA 11.8:
+
+ ```
+ torch==2.6.0
+ transformers==4.46.3
+ tokenizers==0.20.3
+ einops
+ addict
+ easydict
+ pip install flash-attn==2.7.3 --no-build-isolation
+ ```
+
+ ```python
+ from transformers import AutoModel, AutoTokenizer
+ import torch
+ import os
+ os.environ["CUDA_VISIBLE_DEVICES"] = '0'
+ model_name = 'deepseek-ai/DeepSeek-OCR'
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ model = AutoModel.from_pretrained(model_name, _attn_implementation='flash_attention_2', trust_remote_code=True, use_safetensors=True)
+ model = model.eval().cuda().to(torch.bfloat16)
+
+ # prompt = "<image>\nFree OCR. "
+ prompt = "<image>\n<|grounding|>Convert the document to markdown. "
+ image_file = 'your_image.jpg'
+ output_path = 'your/output/dir'
+
+ # infer(self, tokenizer, prompt='', image_file='', output_path = ' ', base_size = 1024, image_size = 640, crop_mode = True, test_compress = False, save_results = False):
+
+ # Tiny: base_size = 512, image_size = 512, crop_mode = False
+ # Small: base_size = 640, image_size = 640, crop_mode = False
+ # Base: base_size = 1024, image_size = 1024, crop_mode = False
+ # Large: base_size = 1280, image_size = 1280, crop_mode = False
+
+ # Gundam: base_size = 1024, image_size = 640, crop_mode = True
+
+ res = model.infer(tokenizer, prompt=prompt, image_file=image_file, output_path = output_path, base_size = 1024, image_size = 640, crop_mode=True, save_results = True, test_compress = True)
+ ```
+
+ ## vLLM
+ Refer to [🌟GitHub](https://github.com/deepseek-ai/DeepSeek-OCR/) for guidance on model inference acceleration and PDF processing, etc.
+
+ ## Visualizations
+ <table>
+ <tr>
+ <td><img src="assets/show1.jpg" style="width: 500px"></td>
+ <td><img src="assets/show2.jpg" style="width: 500px"></td>
+ </tr>
+ <tr>
+ <td><img src="assets/show3.jpg" style="width: 500px"></td>
+ <td><img src="assets/show4.jpg" style="width: 500px"></td>
+ </tr>
+ </table>
+
+
+ ## Acknowledgement
+
+ We would like to thank [Vary](https://github.com/Ucas-HaoranWei/Vary/), [GOT-OCR2.0](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/), [MinerU](https://github.com/opendatalab/MinerU), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [OneChart](https://github.com/LingyvKong/OneChart), [Slow Perception](https://github.com/Ucas-HaoranWei/Slow-Perception) for their valuable models and ideas.
+
+ We also appreciate the benchmarks: [Fox](https://github.com/ucaslcl/Fox), [OmniDocBench](https://github.com/opendatalab/OmniDocBench).
+
+
+ ## Citation
+ Coming soon!
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 DeepSeek
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,132 @@
- ---
- license: mit
- ---
+ ---
+ pipeline_tag: image-text-to-text
+ language:
+ - multilingual
+ tags:
+ - mindspore
+ - mindnlp
+ - deepseek
+ - vision-language
+ - ocr
+ - custom_code
+ license: mit
+ ---
+ <div align="center">
+ <img src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/logo.svg?raw=true" width="60%" alt="DeepSeek AI" />
+ </div>
+ <hr>
+ <div align="center">
+ <a href="https://www.deepseek.com/" target="_blank">
+ <img alt="Homepage" src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/badge.svg?raw=true" />
+ </a>
+ <a href="https://huggingface.co/lvyufeng/DeepSeek-OCR" target="_blank">
+ <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DeepSeek%20AI-ffc107?color=ffc107&logoColor=white" />
+ </a>
+
+ </div>
+
+
+
+
+ <p align="center">
+ <a href="https://github.com/mindspore-lab/mindnlp/tree/master/examples/transformers/inference/deepseek-ocr"><b>🌟 Github</b></a> |
+ <a href="https://huggingface.co/lvyufeng/DeepSeek-OCR"><b>📥 Model Download</b></a> |
+ <a href="https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek_OCR_paper.pdf"><b>📄 Paper Link</b></a> |
+ <a href=""><b>📄 Arxiv Paper Link</b></a> |
+ </p>
+ <h2>
+ <p align="center">
+ <a href="">DeepSeek-OCR: Contexts Optical Compression</a>
+ </p>
+ </h2>
+ <p align="center">
+ <a href="">Explore the boundaries of visual-text compression.</a>
+ </p>
+
+ The official DeepSeek-OCR release pins transformers to 4.46.3 and has not been adapted to newer versions. This community edition therefore modifies the modeling.py module so that no transformers downgrade is required. It has also been adapted for MindSpore + MindNLP compatibility, and users are welcome to run it on Ascend hardware.
+
+ Feel free to choose an attention implementation such as FlashAttention or SDPA to take advantage of the latest optimizations in transformers for a performance boost.
+
+ ## MindSpore Usage
+ Inference using Hugging Face transformers with MindSpore + MindNLP (e.g., on Ascend hardware). Requirements tested on Python 3.12.9:
+
+ ```
+ mindspore==2.7.0
+ mindnlp==0.5.0rc3
+ transformers==4.57.1
+ tokenizers
+ einops
+ addict
+ easydict
+ ```
+
+ ```python
+ import os
+ import mindnlp
+ import mindspore
+ from transformers import AutoModel, AutoTokenizer
+
+ model_name = 'lvyufeng/DeepSeek-OCR-Community-Latest'
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ model = AutoModel.from_pretrained(model_name, dtype=mindspore.float16, _attn_implementation='sdpa', trust_remote_code=True, use_safetensors=True, device_map='auto')
+ model = model.eval()
+
+ # prompt = "<image>\nFree OCR. "
+ prompt = "<image>\n<|grounding|>Convert the document to markdown. "
+ image_file = 'your_image.jpg'
+ output_path = 'your/output/dir'
+
+ # infer(self, tokenizer, prompt='', image_file='', output_path = ' ', base_size = 1024, image_size = 640, crop_mode = True, test_compress = False, save_results = False):
+
+ # Tiny: base_size = 512, image_size = 512, crop_mode = False
+ # Small: base_size = 640, image_size = 640, crop_mode = False
+ # Base: base_size = 1024, image_size = 1024, crop_mode = False
+ # Large: base_size = 1280, image_size = 1280, crop_mode = False
+
+ # Gundam: base_size = 1024, image_size = 640, crop_mode = True
+
+ res = model.infer(tokenizer, prompt=prompt, image_file=image_file, output_path = output_path, base_size = 1024, image_size = 640, crop_mode=True, save_results = True, test_compress = True)
+ ```
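The size presets documented in the comments above can be collected into a small helper so that a mode is picked by name. This is a hypothetical convenience wrapper around the repository's `model.infer()` call, not code shipped with the model:

```python
# Hypothetical helper: maps the documented mode names onto infer() arguments.
RESOLUTION_PRESETS = {
    "tiny":   dict(base_size=512,  image_size=512,  crop_mode=False),
    "small":  dict(base_size=640,  image_size=640,  crop_mode=False),
    "base":   dict(base_size=1024, image_size=1024, crop_mode=False),
    "large":  dict(base_size=1280, image_size=1280, crop_mode=False),
    "gundam": dict(base_size=1024, image_size=640,  crop_mode=True),
}

def infer_with_preset(model, tokenizer, prompt, image_file, output_path, mode="gundam"):
    """Run model.infer() with one of the documented resolution presets."""
    preset = RESOLUTION_PRESETS[mode]
    return model.infer(tokenizer, prompt=prompt, image_file=image_file,
                       output_path=output_path, save_results=True, **preset)
```

For example, `infer_with_preset(model, tokenizer, prompt, image_file, output_path, mode="base")` reproduces the Base setting above.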
+
+
+ ## PyTorch Usage
+ Inference using Hugging Face transformers on NVIDIA GPUs. Requirements tested on Python 3.12.9 + CUDA 11.8:
+
+ ```
+ torch
+ transformers==4.57.1
+ tokenizers
+ einops
+ addict
+ easydict
+ pip install flash-attn
+ ```
+
+ ```python
+ from transformers import AutoModel, AutoTokenizer
+ import torch
+ import os
+ os.environ["CUDA_VISIBLE_DEVICES"] = '0'
+ model_name = 'deepseek-ai/DeepSeek-OCR'
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ model = AutoModel.from_pretrained(model_name, dtype=torch.bfloat16, _attn_implementation='flash_attention_2', trust_remote_code=True, use_safetensors=True, device_map='auto')
+ model = model.eval()
+ # prompt = "<image>\nFree OCR. "
+ prompt = "<image>\n<|grounding|>Convert the document to markdown. "
+ image_file = 'your_image.jpg'
+ output_path = 'your/output/dir'
+ # infer(self, tokenizer, prompt='', image_file='', output_path = ' ', base_size = 1024, image_size = 640, crop_mode = True, test_compress = False, save_results = False):
+ # Tiny: base_size = 512, image_size = 512, crop_mode = False
+ # Small: base_size = 640, image_size = 640, crop_mode = False
+ # Base: base_size = 1024, image_size = 1024, crop_mode = False
+ # Large: base_size = 1280, image_size = 1280, crop_mode = False
+ # Gundam: base_size = 1024, image_size = 640, crop_mode = True
+ res = model.infer(tokenizer, prompt=prompt, image_file=image_file, output_path = output_path, base_size = 1024, image_size = 640, crop_mode=True, save_results = True, test_compress = True)
+ ```
+
+ ## Acknowledgement
+
+ We would like to thank [Vary](https://github.com/Ucas-HaoranWei/Vary/), [GOT-OCR2.0](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/), [MinerU](https://github.com/opendatalab/MinerU), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [OneChart](https://github.com/LingyvKong/OneChart), [Slow Perception](https://github.com/Ucas-HaoranWei/Slow-Perception) for their valuable models and ideas.
+
+ We also appreciate the benchmarks: [Fox](https://github.com/ucaslcl/Fox), [OmniDocBench](https://github.com/opendatalab/OmniDocBench).
assets/fig1.png ADDED

Git LFS Details

  • SHA256: 3a709eb3fdb51cf6d8a546ffb8efe3c80bef61ae0183c2c8476a4c4a41efa3f1
  • Pointer size: 131 Bytes
  • Size of remote file: 396 kB
assets/show1.jpg ADDED

Git LFS Details

  • SHA256: 887e88e60e5833bc10a2cd7edb89ea7e6992abaae5e1550b027c611b8b8456f2
  • Pointer size: 131 Bytes
  • Size of remote file: 117 kB
assets/show2.jpg ADDED

Git LFS Details

  • SHA256: 81d08f7f33d9d39b95dd9b8162506659e6822d621b9829a208f3830c34c2b4d0
  • Pointer size: 131 Bytes
  • Size of remote file: 216 kB
assets/show3.jpg ADDED

Git LFS Details

  • SHA256: cd24b0cfc7b6c0b1b34bd1aa55bc385e746298fdd82410db6c0d4e0bf69085c0
  • Pointer size: 131 Bytes
  • Size of remote file: 247 kB
assets/show4.jpg ADDED

Git LFS Details

  • SHA256: 2fe88eacc470c34d00225151372d3770948864f3d9cfaae16afa15b2432d7793
  • Pointer size: 131 Bytes
  • Size of remote file: 269 kB
config.json ADDED
@@ -0,0 +1,118 @@
+ {
+ "_name_or_path": "deepseek-ai/DeepSeek-OCR",
+ "candidate_resolutions": [
+ [
+ 1024,
+ 1024
+ ]
+ ],
+ "global_view_pos": "head",
+ "architectures": [
+ "DeepseekOCRForCausalLM"
+ ],
+ "auto_map": {
+ "AutoConfig": "modeling_deepseekocr.DeepseekOCRConfig",
+ "AutoModel": "modeling_deepseekocr.DeepseekOCRForCausalLM"
+ },
+ "language_config": {
+ "architectures": [
+ "DeepseekV2ForCausalLM"
+ ],
+ "auto_map": {
+ "AutoConfig": "configuration_deepseekv2.DeepseekV2Config",
+ "AutoModel": "modeling_deepseek.DeepseekV2Model",
+ "AutoModelForCausalLM": "modeling_deepseek.DeepseekV2ForCausalLM"
+ },
+ "bos_token_id": 0,
+ "eos_token_id": 1,
+ "first_k_dense_replace": 1,
+ "hidden_size": 1280,
+ "intermediate_size": 6848,
+ "kv_lora_rank": null,
+ "lm_head": true,
+ "max_position_embeddings": 8192,
+ "moe_intermediate_size": 896,
+ "n_group": 1,
+ "n_routed_experts": 64,
+ "n_shared_experts": 2,
+ "num_attention_heads": 10,
+ "num_experts_per_tok": 6,
+ "num_hidden_layers": 12,
+ "num_key_value_heads": 10,
+ "q_lora_rank": null,
+ "qk_nope_head_dim": 0,
+ "qk_rope_head_dim": 0,
+ "rm_head": false,
+ "topk_group": 1,
+ "topk_method": "greedy",
+ "torch_dtype": "bfloat16",
+ "use_mla": false,
+ "v_head_dim": 0,
+ "vocab_size": 129280
+ },
+ "model_type": "deepseek_vl_v2",
+ "projector_config": {
+ "input_dim": 2048,
+ "model_type": "mlp_projector",
+ "n_embed": 1280,
+ "projector_type": "linear"
+ },
+ "tile_tag": "2D",
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.46.3",
+ "vision_config": {
+ "image_size": 1024,
+ "mlp_ratio": 3.7362,
+ "model_name": "deeplip_b_l",
+ "model_type": "vision",
+ "width": {
+ "clip-l-14-224": {
+ "heads": 16,
+ "image_size": 224,
+ "layers": 24,
+ "patch_size": 14,
+ "width": 1024
+ },
+ "sam_vit_b": {
+ "downsample_channels": [
+ 512,
+ 1024
+ ],
+ "global_attn_indexes": [
+ 2,
+ 5,
+ 8,
+ 11
+ ],
+ "heads": 12,
+ "layers": 12,
+ "width": 768
+ }
+ }
+ },
+ "bos_token_id": 0,
+ "eos_token_id": 1,
+ "first_k_dense_replace": 1,
+ "hidden_size": 1280,
+ "intermediate_size": 6848,
+ "kv_lora_rank": null,
+ "lm_head": true,
+ "max_position_embeddings": 8192,
+ "moe_intermediate_size": 896,
+ "n_group": 1,
+ "n_routed_experts": 64,
+ "n_shared_experts": 2,
+ "num_attention_heads": 10,
+ "num_experts_per_tok": 6,
+ "num_hidden_layers": 12,
+ "num_key_value_heads": 10,
+ "q_lora_rank": null,
+ "qk_nope_head_dim": 0,
+ "qk_rope_head_dim": 0,
+ "rm_head": false,
+ "topk_group": 1,
+ "topk_method": "greedy",
+ "use_mla": false,
+ "v_head_dim": 0,
+ "vocab_size": 129280
+ }
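Because the repository ships custom code, loading this config through transformers requires `trust_remote_code=True` so that the `auto_map` entry above (`modeling_deepseekocr.DeepseekOCRConfig`) can be resolved. A short sketch; the repo id is the one referenced elsewhere in this commit and may differ for mirrors:

```python
from transformers import AutoConfig

# trust_remote_code=True lets AutoConfig follow the auto_map entry shipped with the repo.
config = AutoConfig.from_pretrained("lvyufeng/DeepSeek-OCR", trust_remote_code=True)
print(config.model_type)  # expected: "deepseek_vl_v2"
print(config.vocab_size)  # expected: 129280, matching the JSON above
```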
configuration.json ADDED
@@ -0,0 +1 @@
+ {"framework": "pytorch", "task": "image-text-to-text", "allow_remote": true}
configuration_deepseek_v2.py ADDED
@@ -0,0 +1,210 @@
1
+ from transformers.configuration_utils import PretrainedConfig
2
+ from transformers.utils import logging
3
+
4
+ logger = logging.get_logger(__name__)
5
+
6
+ DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
7
+ class DeepseekV2Config(PretrainedConfig):
8
+ r"""
9
+ This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate an DeepSeek
10
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
11
+ defaults will yield a similar configuration to that of the DeepSeek-V2 with multi-latent attention.
12
+
13
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
14
+ documentation from [`PretrainedConfig`] for more information.
15
+
16
+
17
+ Args:
18
+ vocab_size (`int`, *optional*, defaults to 102400):
19
+ Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
20
+ `inputs_ids` passed when calling [`DeepseekV2Model`]
21
+ hidden_size (`int`, *optional*, defaults to 4096):
22
+ Dimension of the hidden representations.
23
+ intermediate_size (`int`, *optional*, defaults to 11008):
24
+ Dimension of the MLP representations.
25
+ moe_intermediate_size (`int`, *optional*, defaults to 1407):
26
+ Dimension of the MoE representations.
27
+ num_hidden_layers (`int`, *optional*, defaults to 32):
28
+ Number of hidden layers in the Transformer decoder.
29
+ num_attention_heads (`int`, *optional*, defaults to 32):
30
+ Number of attention heads for each attention layer in the Transformer decoder.
31
+ n_shared_experts (`int`, *optional*, defaults to None):
32
+ Number of shared experts, None means dense model.
33
+ n_routed_experts (`int`, *optional*, defaults to None):
34
+ Number of routed experts, None means dense model.
35
+ routed_scaling_factor (`float`, *optional*, defaults to 1.0):
36
+ Scaling factor or routed experts.
37
+ topk_method (`str`, *optional*, defaults to `gready`):
38
+ Topk method used in routed gate.
39
+ n_group (`int`, *optional*, defaults to None):
40
+ Number of groups for routed experts.
41
+ topk_group (`int`, *optional*, defaults to None):
42
+ Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
43
+ num_experts_per_tok (`int`, *optional*, defaults to None):
44
+ Number of selected experts, None means dense model.
45
+ moe_layer_freq (`int`, *optional*, defaults to 1):
46
+ The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
47
+ first_k_dense_replace (`int`, *optional*, defaults to 0):
48
+ Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
49
+ \--k dense layers--/
50
+ norm_topk_prob (`bool`, *optional*, defaults to False):
51
+ Whether to normalize the weights of the routed experts.
52
+ scoring_func (`str`, *optional*, defaults to 'softmax'):
53
+ Method of computing expert weights.
54
+ aux_loss_alpha (`float`, *optional*, defaults to 0.001):
55
+ Auxiliary loss weight coefficient.
56
+ seq_aux = (`bool`, *optional*, defaults to True):
57
+ Whether to compute the auxiliary loss for each individual sample.
58
+ num_key_value_heads (`int`, *optional*):
59
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
60
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
61
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
62
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
63
+ by meanpooling all the original heads within that group. For more details checkout [this
64
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
65
+ `num_attention_heads`.
66
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
67
+ The non-linear activation function (function or string) in the decoder.
68
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
69
+ The maximum sequence length that this model might ever be used with.
70
+ initializer_range (`float`, *optional*, defaults to 0.02):
71
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
72
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
73
+ The epsilon used by the rms normalization layers.
74
+ use_cache (`bool`, *optional*, defaults to `True`):
75
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
76
+ relevant if `config.is_decoder=True`.
77
+ pad_token_id (`int`, *optional*):
78
+ Padding token id.
79
+ bos_token_id (`int`, *optional*, defaults to 1):
80
+ Beginning of stream token id.
81
+ eos_token_id (`int`, *optional*, defaults to 2):
82
+ End of stream token id.
83
+ pretraining_tp (`int`, *optional*, defaults to 1):
84
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
85
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
86
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
87
+ issue](https://github.com/pytorch/pytorch/issues/76232).
88
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
89
+ Whether to tie weight embeddings
90
+ rope_theta (`float`, *optional*, defaults to 10000.0):
91
+ The base period of the RoPE embeddings.
92
+ rope_scaling (`Dict`, *optional*):
93
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
94
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
95
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
96
+ `max_position_embeddings` to the expected new maximum.
97
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
98
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
99
+ attention_dropout (`float`, *optional*, defaults to 0.0):
100
+ The dropout ratio for the attention probabilities.
101
+ use_mla (`bool`, *optional*, defaults to `True`): Use multi-latent attention or multi-head attention. If True,
102
+ the model will use multi-latent attention, otherwise, it will use multi-head attention.
103
+
104
+ ```python
105
+ >>> from transformers import DeepseekV2Model, DeepseekV2Config
106
+
107
+ >>> # Initializing a Deepseek-V2 style configuration
108
+ >>> configuration = DeepseekV2Config()
109
+
110
+ >>> # Accessing the model configuration
111
+ >>> configuration = model.config
112
+ ```"""
113
+
114
+ model_type = "deepseek_v2"
115
+ keys_to_ignore_at_inference = ["past_key_values"]
116
+
117
+ def __init__(
118
+ self,
119
+ vocab_size=102400,
120
+ hidden_size=4096,
121
+ intermediate_size=11008,
122
+ moe_intermediate_size = 1407,
123
+ num_hidden_layers=30,
124
+ num_attention_heads=32,
125
+ num_key_value_heads=32,
126
+ n_shared_experts = None,
127
+ n_routed_experts = None,
128
+ ep_size = 1,
129
+ routed_scaling_factor = 1.0,
130
+ kv_lora_rank = 512,
131
+ q_lora_rank = 1536,
132
+ qk_rope_head_dim = 64,
133
+ v_head_dim = 128,
134
+ qk_nope_head_dim = 128,
135
+ topk_method = 'gready',
136
+ n_group = None,
137
+ topk_group = None,
138
+ num_experts_per_tok = None,
139
+ moe_layer_freq = 1,
140
+ first_k_dense_replace = 0,
141
+ norm_topk_prob = False,
142
+ scoring_func = 'softmax',
143
+ aux_loss_alpha = 0.001,
144
+ seq_aux = True,
145
+ hidden_act="silu",
146
+ max_position_embeddings=2048,
147
+ initializer_range=0.02,
148
+ rms_norm_eps=1e-6,
149
+ use_cache=True,
150
+ pad_token_id=None,
151
+ bos_token_id=100000,
152
+ eos_token_id=100001,
153
+ pretraining_tp=1,
154
+ tie_word_embeddings=False,
155
+ rope_theta=10000.0,
156
+ rope_scaling=None,
157
+ attention_bias=False,
158
+ attention_dropout=0.0,
159
+ use_mla=True,
160
+ **kwargs,
161
+ ):
162
+ self.vocab_size = vocab_size
163
+ self.max_position_embeddings = max_position_embeddings
164
+ self.hidden_size = hidden_size
165
+ self.intermediate_size = intermediate_size
166
+ self.moe_intermediate_size = moe_intermediate_size
167
+ self.num_hidden_layers = num_hidden_layers
168
+ self.num_attention_heads = num_attention_heads
169
+ self.n_shared_experts = n_shared_experts
170
+ self.n_routed_experts = n_routed_experts
171
+ self.ep_size = ep_size
172
+ self.routed_scaling_factor = routed_scaling_factor
173
+ self.kv_lora_rank = kv_lora_rank
174
+ self.q_lora_rank = q_lora_rank
175
+ self.qk_rope_head_dim = qk_rope_head_dim
176
+ self.v_head_dim = v_head_dim
177
+ self.qk_nope_head_dim = qk_nope_head_dim
178
+ self.topk_method = topk_method
179
+ self.n_group = n_group
180
+ self.topk_group = topk_group
181
+ self.num_experts_per_tok = num_experts_per_tok
182
+ self.moe_layer_freq = moe_layer_freq
183
+ self.first_k_dense_replace = first_k_dense_replace
184
+ self.norm_topk_prob = norm_topk_prob
185
+ self.scoring_func = scoring_func
186
+ self.aux_loss_alpha = aux_loss_alpha
187
+ self.seq_aux = seq_aux
188
+ # for backward compatibility
189
+ if num_key_value_heads is None:
190
+ num_key_value_heads = num_attention_heads
191
+
192
+ self.num_key_value_heads = num_key_value_heads
193
+ self.hidden_act = hidden_act
194
+ self.initializer_range = initializer_range
195
+ self.rms_norm_eps = float(rms_norm_eps)
196
+ self.pretraining_tp = pretraining_tp
197
+ self.use_cache = use_cache
198
+ self.rope_theta = rope_theta
199
+ self.rope_scaling = rope_scaling
200
+ self.attention_bias = attention_bias
201
+ self.attention_dropout = attention_dropout
202
+ self.use_mla = use_mla
203
+
204
+ super().__init__(
205
+ pad_token_id=pad_token_id,
206
+ bos_token_id=bos_token_id,
207
+ eos_token_id=eos_token_id,
208
+ tie_word_embeddings=tie_word_embeddings,
209
+ **kwargs,
210
+ )
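For reference, the `language_config` block of `config.json` above corresponds roughly to instantiating this class as follows. This is a sketch using the values shown there; the import path assumes the file name added in this commit (`configuration_deepseek_v2.py`), while `config.json` itself points at a `configuration_deepseekv2` module:

```python
from configuration_deepseek_v2 import DeepseekV2Config

# Values copied from the language_config section of config.json above.
lm_config = DeepseekV2Config(
    vocab_size=129280,
    hidden_size=1280,
    intermediate_size=6848,
    moe_intermediate_size=896,
    num_hidden_layers=12,
    num_attention_heads=10,
    num_key_value_heads=10,
    n_shared_experts=2,
    n_routed_experts=64,
    num_experts_per_tok=6,
    first_k_dense_replace=1,
    topk_method="greedy",
    n_group=1,
    topk_group=1,
    max_position_embeddings=8192,
    q_lora_rank=None,
    kv_lora_rank=None,
    use_mla=False,
    bos_token_id=0,
    eos_token_id=1,
)
print(lm_config.use_mla)  # False: plain multi-head attention instead of MLA
```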
conversation.py ADDED
@@ -0,0 +1,280 @@
1
+ """
2
+ From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
3
+ """
4
+
5
+ import dataclasses
6
+ from enum import IntEnum, auto
7
+ from typing import Any, Dict, List
8
+
9
+
10
+ class SeparatorStyle(IntEnum):
11
+ """Separator styles."""
12
+
13
+ DeepSeek = auto()
14
+ DeepSeekV2 = auto()
15
+ PLAIN = auto()
16
+ ALIGNMENT = auto()
17
+
18
+
19
+ @dataclasses.dataclass
20
+ class Conversation:
21
+ """A class that manages prompt templates and keeps all conversation history."""
22
+
23
+ # The name of this template
24
+ name: str
25
+ # The template of the system prompt
26
+ system_template: str = "{system_message}"
27
+ # The system message
28
+ system_message: str = ""
29
+ # The names of two roles
30
+ roles: List[str] = (("USER", "ASSISTANT"),)
31
+ # All messages. Each item is (role, message).
32
+ messages: List[List[str]] = ()
33
+ # The number of few shot examples
34
+ offset: int = 0
35
+ # The separator style and configurations
36
+ sep_style: SeparatorStyle = SeparatorStyle.DeepSeek
37
+ sep: str = "\n"
38
+ sep2: str = None
39
+ # Stop criteria (the default one is EOS token)
40
+ stop_str: str = None
41
+ # Stops generation if meeting any token in this list
42
+ stop_token_ids: List[int] = None
43
+
44
+ def get_prompt(self) -> str:
45
+ """Get the prompt for generation."""
46
+ system_prompt = self.system_template.format(system_message=self.system_message)
47
+ if self.sep_style == SeparatorStyle.DeepSeek:
48
+ seps = [self.sep, self.sep2]
49
+ if system_prompt == "" or system_prompt is None:
50
+ ret = ""
51
+ else:
52
+ ret = system_prompt + seps[0]
53
+ for i, (role, message) in enumerate(self.messages):
54
+ if message:
55
+ ret += role + ": " + message + seps[i % 2]
56
+ else:
57
+ ret += role + ":"
58
+ return ret
59
+ elif self.sep_style == SeparatorStyle.DeepSeekV2:
60
+ seps = [self.sep, self.sep2]
61
+ if system_prompt == "" or system_prompt is None:
62
+ ret = ""
63
+ else:
64
+ ret = system_prompt + seps[0]
65
+ for i, (role, message) in enumerate(self.messages):
66
+ if message:
67
+ if role == "User":
68
+ ret += "<|sft▁begin|>\n" + message + self.sep #<|sft▁begin|>User Input<|sft▁end|>\nResponse<|end▁of▁sentence|>
69
+ else:
70
+ ret += message + self.sep2
71
+ else:
72
+ ret = ret
73
+ return ret
74
+
75
+ elif self.sep_style == SeparatorStyle.PLAIN:
76
+ seps = [self.sep, self.sep2]
77
+ ret = ""
78
+ for i, (role, message) in enumerate(self.messages):
79
+ if message:
80
+ if type(message) is tuple:
81
+ message, _, _ = message
82
+ if i % 2 == 0:
83
+ ret += message + seps[i % 2]
84
+ else:
85
+ ret += message + seps[i % 2]
86
+ else:
87
+ ret += ""
88
+ return ret
89
+ elif self.sep_style == SeparatorStyle.ALIGNMENT:
90
+ seps = [self.sep, self.sep2]
91
+ ret = ""
92
+ for i, (role, message) in enumerate(self.messages):
93
+ if message:
94
+ if type(message) is tuple:
95
+ message, _, _ = message
96
+ if i % 2 == 0:
97
+ ret += '<image>\n' + seps[i % 2]
98
+ else:
99
+ ret += message + seps[i % 2]
100
+ else:
101
+ ret += ""
102
+ return ret
103
+ else:
104
+ raise ValueError(f"Invalid style: {self.sep_style}")
105
+
106
+ def set_system_message(self, system_message: str):
107
+ """Set the system message."""
108
+ self.system_message = system_message
109
+
110
+ def append_message(self, role: str, message: str):
111
+ """Append a new message."""
112
+ self.messages.append([role, message])
113
+
114
+ def update_last_message(self, message: str):
115
+ """Update the last output.
116
+
117
+ The last message is typically set to be None when constructing the prompt,
118
+ so we need to update it in-place after getting the response from a model.
119
+ """
120
+ self.messages[-1][1] = message
121
+
122
+ def reset_message(self):
123
+ """Reset a new message."""
124
+ self.messages = []
125
+
126
+ def to_gradio_chatbot(self):
127
+ """Convert the conversation to gradio chatbot format."""
128
+ ret = []
129
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
130
+ if i % 2 == 0:
131
+ ret.append([msg, None])
132
+ else:
133
+ ret[-1][-1] = msg
134
+ return ret
135
+
136
+ def to_openai_api_messages(self):
137
+ """Convert the conversation to OpenAI chat completion format."""
138
+ system_prompt = self.system_template.format(system_message=self.system_message)
139
+ ret = [{"role": "system", "content": system_prompt}]
140
+
141
+ for i, (_, msg) in enumerate(self.messages[self.offset :]):
142
+ if i % 2 == 0:
143
+ ret.append({"role": "user", "content": msg})
144
+ else:
145
+ if msg is not None:
146
+ ret.append({"role": "assistant", "content": msg})
147
+ return ret
148
+
149
+ def copy(self):
150
+ return Conversation(
151
+ name=self.name,
152
+ system_template=self.system_template,
153
+ system_message=self.system_message,
154
+ roles=self.roles,
155
+ messages=[[x, y] for x, y in self.messages],
156
+ offset=self.offset,
157
+ sep_style=self.sep_style,
158
+ sep=self.sep,
159
+ sep2=self.sep2,
160
+ stop_str=self.stop_str,
161
+ stop_token_ids=self.stop_token_ids,
162
+ )
163
+
164
+ def dict(self):
165
+ return {
166
+ "template_name": self.name,
167
+ "system_message": self.system_message,
168
+ "roles": self.roles,
169
+ "messages": self.messages,
170
+ "offset": self.offset,
171
+ }
172
+
173
+
174
+ # A global registry for all conversation templates
175
+ conv_templates: Dict[str, Conversation] = {}
176
+
177
+
178
+ def register_conv_template(template: Conversation, override: bool = False):
179
+ """Register a new conversation template."""
180
+ if not override:
181
+ assert template.name not in conv_templates, f"{template.name} has been registered."
182
+
183
+ conv_templates[template.name] = template
184
+
185
+
186
+ def get_conv_template(name: str) -> Conversation:
187
+ """Get a conversation template."""
188
+ return conv_templates[name].copy()
189
+
190
+
191
+ register_conv_template(
192
+ Conversation(
193
+ name="deepseek",
194
+ system_template="{system_message}",
195
+ # system_message="You are a helpful assistant. Please answer truthfully and write out your "
196
+ # "thinking step by step to be sure you get the right answer.",
197
+ system_message="",
198
+ roles=("<|User|>", "<|Assistant|>"),
199
+ messages=(),
200
+ offset=0,
201
+ sep_style=SeparatorStyle.DeepSeek,
202
+ sep="\n\n",
203
+ sep2="<|end▁of▁sentence|>",
204
+ stop_token_ids=[100001],
205
+ stop_str=["User:", "<|end▁of▁sentence|>"]
206
+ )
207
+ )
208
+ register_conv_template(
209
+ Conversation(
210
+ name="deepseekv2",
211
+ system_template="{system_message}",
212
+ # system_message="You are a helpful assistant. Please answer truthfully and write out your "
213
+ # "thinking step by step to be sure you get the right answer.",
214
+ system_message="",
215
+ roles=("<|User|>", "<|Assistant|>"),
216
+ messages=(),
217
+ offset=0,
218
+ sep_style=SeparatorStyle.DeepSeek,
219
+ sep="",
220
+ sep2="<|end▁of▁sentence|>",
221
+ stop_token_ids=[100001],
222
+ stop_str=["User:", "<|end▁of▁sentence|>"]
223
+ )
224
+ )
225
+
226
+
227
+ register_conv_template(
228
+ Conversation(
229
+ name="plain",
230
+ system_template="",
231
+ system_message="",
232
+ roles=("", ""),
233
+ messages=(),
234
+ offset=0,
235
+ sep_style=SeparatorStyle.PLAIN,
236
+ sep="",
237
+ sep2="",
238
+ stop_token_ids=[100001],
239
+ stop_str=['</s>'],
240
+ )
241
+ )
242
+
243
+
244
+ register_conv_template(
245
+ Conversation(
246
+ name="alignment",
247
+ system_template="",
248
+ system_message="",
249
+ roles=("", ""),
250
+ messages=(),
251
+ offset=0,
252
+ sep_style=SeparatorStyle.ALIGNMENT,
253
+ sep="",
254
+ sep2="",
255
+ stop_token_ids=[100001],
256
+ stop_str=['</s>'],
257
+ )
258
+ )
259
+
260
+
261
+ if __name__ == "__main__":
262
+ print("deepseek template:")
263
+ conv = get_conv_template("deepseek")
264
+ conv.append_message(conv.roles[0], "Hello!")
265
+ conv.append_message(conv.roles[1], "Hi! This is Tony.")
266
+ conv.append_message(conv.roles[0], "Who are you?")
267
+ conv.append_message(conv.roles[1], "I am a helpful assistant.")
268
+ conv.append_message(conv.roles[0], "How are you?")
269
+ conv.append_message(conv.roles[1], None)
270
+ print(conv.get_prompt())
271
+
272
+ print("deepseekv2 template:")
273
+ conv = get_conv_template("deepseekv2")
274
+ conv.append_message(conv.roles[0], "Hello!")
275
+ conv.append_message(conv.roles[1], "Hi! This is Tony.")
276
+ conv.append_message(conv.roles[0], "Who are you?")
277
+ conv.append_message(conv.roles[1], "I am a helpful assistant.")
278
+ conv.append_message(conv.roles[0], "How are you?")
279
+ conv.append_message(conv.roles[1], None)
280
+ print(conv.get_prompt())
deepencoder.py ADDED
@@ -0,0 +1,1058 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import copy
5
+
6
+ from contextlib import nullcontext
7
+ import math
8
+ from typing import Optional, Tuple
9
+ # from megatron.model import LayerNorm
10
+
11
+ from einops import rearrange
12
+ from easydict import EasyDict as adict
13
+
14
+
15
+ from typing import Optional, Tuple, Type
16
+ from functools import partial
17
+
18
+
19
+
20
+ class MlpProjector(nn.Module):
21
+
22
+ def __init__(self, cfg):
23
+
24
+ super().__init__()
25
+
26
+ self.cfg = cfg
27
+
28
+ if cfg.projector_type == "identity":
29
+ modules = nn.Identity()
30
+
31
+ elif cfg.projector_type == "linear":
32
+ modules = nn.Linear(cfg.input_dim, cfg.n_embed)
33
+
34
+ elif cfg.projector_type == "mlp_gelu":
35
+ mlp_depth = cfg.get("depth", 1)
36
+ modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
37
+ for _ in range(1, mlp_depth):
38
+ modules.append(nn.GELU())
39
+ modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
40
+ modules = nn.Sequential(*modules)
41
+
42
+ elif cfg.projector_type == "normlayer_downsample_mlp_gelu":
43
+ mlp_depth = cfg.get("depth", 1)
44
+ mlp_ratio = cfg.get("mlp_ratio", 1)
45
+ modules = [
46
+ nn.LayerNorm(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio),
47
+ nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)
48
+ ]
49
+ for _ in range(1, mlp_depth - 1):
50
+ modules.append(nn.GELU())
51
+ modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
52
+ modules.append(nn.GELU())
53
+ modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
54
+ modules = nn.Sequential(*modules)
55
+
56
+ elif cfg.projector_type == "downsample_mlp_gelu":
57
+ mlp_depth = cfg.get("depth", 1)
58
+ mlp_ratio = cfg.get("mlp_ratio", 1)
59
+ modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)]
60
+ for _ in range(1, mlp_depth - 1):
61
+ modules.append(nn.GELU())
62
+ modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
63
+ modules.append(nn.GELU())
64
+ modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
65
+ modules = nn.Sequential(*modules)
66
+
67
+ elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
68
+ mlp_depth = cfg.get("depth", 1)
69
+ self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
70
+ self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
71
+
72
+ modules = []
73
+ for _ in range(1, mlp_depth):
74
+ modules.append(nn.GELU())
75
+ modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
76
+ modules = nn.Sequential(*modules)
77
+
78
+ elif cfg.projector_type == "hybrid_split_feature_mlp_gelu":
79
+ mlp_depth = cfg.get("depth", 1)
80
+ channel_div = cfg.get("channel_div", 0.5)
81
+ self.high_up_proj = nn.Linear(cfg.input_dim[0], int(cfg.n_embed * channel_div))
82
+ self.low_up_proj = nn.Linear(cfg.input_dim[1], cfg.n_embed - int(cfg.n_embed * channel_div))
83
+
84
+ modules = []
85
+ for _ in range(1, mlp_depth):
86
+ modules.append(nn.GELU())
87
+ modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
88
+ modules = nn.Sequential(*modules)
89
+
90
+ elif cfg.projector_type == "low_high_split_mlp_gelu":
91
+ mlp_depth = cfg.get("depth", 1)
92
+ modules = []
93
+ for _ in range(1, mlp_depth):
94
+ modules.append(nn.GELU())
95
+ modules.append(nn.Linear(cfg.n_embed // 2, cfg.n_embed // 2))
96
+ modules = nn.Sequential(*modules)
97
+ self.high_layers = nn.Sequential(*modules)
98
+ self.low_layers = copy.deepcopy(modules)
99
+
100
+ else:
101
+ raise ValueError(f"Unknown projector type: {cfg.projector_type}")
102
+
103
+ if cfg.get("token_pooling", False):
104
+ self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim)
105
+
106
+ if cfg.get("conv_fusion_high_low_features", False):
107
+ self.fusion_layer = nn.Linear(cfg.input_dim, cfg.input_dim)
108
+ self.layers = modules
109
+
110
+ def forward(self, x):
111
+ if self.cfg.get("token_pooling", False):
112
+ batch_size, wxh, channels = x.shape
113
+ w = h = int(wxh**0.5)
114
+ x = x.view(batch_size, w, h, channels)
115
+ x = x.permute(0, 3, 1, 2)
116
+ # import ipdb; ipdb.set_trace()
117
+ patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
118
+ batch_size, channels, h_patches, w_patches, _, _ = patches.size()
119
+ # 在通道维度上拼接
120
+ patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)
121
+
122
+ # 通过线性层
123
+ patches = patches.permute(0, 2, 1, 3).contiguous()
124
+ patches = patches.view(batch_size, h_patches * w_patches, channels * 4)
125
+
126
+ x = self.token_pooling_layer(patches)
127
+
128
+ if self.cfg.get("conv_fusion_high_low_features", False):
129
+ x = self.fusion_layer(x[:, 0]) + x[:, 1]
130
+
131
+ if self.cfg.projector_type == 'low_high_hybrid_split_mlp_gelu':
132
+ high_x, low_x = x[0], x[1]
133
+ high_x = self.high_up_proj(high_x)
134
+ low_x = self.low_up_proj(low_x)
135
+ x = torch.concat([high_x, low_x], dim=-1)
136
+
137
+ if self.cfg.projector_type == 'hybrid_split_feature_mlp_gelu':
138
+ high_x = x[...,:self.cfg.input_dim[0]]
139
+ low_x = x[...,self.cfg.input_dim[0]:]
140
+ high_x = self.high_up_proj(high_x)
141
+ low_x = self.low_up_proj(low_x)
142
+ x = torch.concat([high_x, low_x], dim=-1)
143
+
144
+ if self.cfg.projector_type == 'low_high_split_mlp_gelu':
145
+ high_x, low_x = x[0], x[1]
146
+ high_x = self.high_layers(high_x)
147
+ low_x = self.low_layers(low_x)
148
+ x = torch.concat([high_x, low_x], dim=-1)
149
+ return x
150
+
151
+ if self.cfg.projector_type == 'downsample_mlp_gelu' or self.cfg.projector_type == 'normlayer_downsample_mlp_gelu':
152
+ bs, hw, input_dim = x.shape
153
+ h = w = int((hw) ** 0.5)
154
+
155
+ """compute padding"""
156
+ if h % self.cfg.downsample_ratio:
157
+ pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
158
+ else:
159
+ pad = 0
160
+ x = x.reshape(bs, h, w, input_dim)
161
+ if pad > 0:
162
+ x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
163
+
164
+ """4 to 1 concat"""
165
+ x = x.permute(0, 3, 1, 2) # B, C, H, W
166
+ x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio, padding=0) # B, C*4, HW // 4
167
+ x = x.permute(0, 2, 1)
168
+
169
+ return self.layers(x)
170
+
171
+ @staticmethod
172
+ def get_flops_per_sample(cfg):
173
+ if cfg.projector_type == "linear":
174
+ fwd = 2 * cfg.input_dim * cfg.n_embed
175
+
176
+ elif "mlp_gelu" in cfg.projector_type :
177
+ mlp_depth = cfg.get("depth", 1)
178
+ downsample_ratio = cfg.get("downsample_ratio", 1)
179
+ input_dim = sum(cfg.input_dim) if isinstance(cfg.input_dim, list) else cfg.input_dim
180
+ input_dim = input_dim * downsample_ratio * downsample_ratio
181
+ fwd = 2 * input_dim * cfg.n_embed + (mlp_depth - 1) * 2 * cfg.n_embed * cfg.n_embed
182
+ else:
183
+ fwd = 0
184
+
185
+ return fwd * 3
186
+
187
+
188
+ #===================clip============================================================
189
+
190
+ class LayerNormfp32(torch.nn.LayerNorm):
191
+ """Subclass torch's LayerNorm to handle fp16."""
192
+
193
+ def forward(self, x: torch.Tensor):
194
+ orig_type = x.dtype
195
+ ret = super().forward(x.type(torch.float32))
196
+ return ret.type(orig_type)
197
+
198
+
199
+ def get_abs_pos(abs_pos, tgt_size):
200
+ # abs_pos: L, C
201
+ # tgt_size: M
202
+ # return: M, C
203
+
204
+ # print(tgt_size)
205
+ # print(abs_pos.shape)
206
+ # exit()
207
+ dim = abs_pos.size(-1)
208
+ # print(dim)
209
+ abs_pos_new = abs_pos.squeeze(0)
210
+ cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]
211
+
212
+
213
+
214
+ src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
215
+ tgt_size = int(math.sqrt(tgt_size))
216
+ dtype = abs_pos.dtype
217
+
218
+ if src_size != tgt_size:
219
+ old_pos_embed = old_pos_embed.view(1, src_size, src_size, dim).permute(0, 3, 1,
220
+ 2).contiguous()
221
+ old_pos_embed = old_pos_embed.to(torch.float32)
222
+ new_pos_embed = F.interpolate(
223
+ old_pos_embed,
224
+ size=(tgt_size, tgt_size),
225
+ mode='bicubic',
226
+ antialias=True,
227
+ align_corners=False,
228
+ ).to(dtype)
229
+ new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
230
+ new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
231
+ vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
232
+ vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim)
233
+ return vision_pos_embed
234
+ else:
235
+ return abs_pos
236
+
237
+ @torch.jit.script
238
+ def quick_gelu(x):
239
+ return x * torch.sigmoid(1.702 * x)
240
+
241
+
242
+
243
+ class CLIPVisionEmbeddings(nn.Module):
244
+ def __init__(self, hidden_size=1024, image_size=224, patch_size=14, num_channels=3):
245
+ super().__init__()
246
+ self.embed_dim = hidden_size
247
+ self.image_size = image_size
248
+ self.patch_size = patch_size
249
+
250
+ self.class_embedding = torch.nn.Parameter(torch.randn(self.embed_dim))
251
+
252
+ self.patch_embedding = torch.nn.Conv2d(
253
+ in_channels=num_channels,
254
+ out_channels=self.embed_dim,
255
+ kernel_size=self.patch_size,
256
+ stride=self.patch_size,
257
+ bias=False,
258
+ )
259
+
260
+ self.num_patches = (self.image_size // self.patch_size) ** 2
261
+ self.num_positions = self.num_patches + 1
262
+ self.position_embedding = torch.nn.Embedding(self.num_positions, self.embed_dim)
263
+ self.register_buffer(
264
+ "position_ids", torch.arange(self.num_positions).expand((1, -1))
265
+ )
266
+
267
+ def forward(self, pixel_values, patch_embeds):
268
+ batch_size = pixel_values.shape[0]
269
+ # patch_embeds = self.patch_embedding(
270
+ # pixel_values
271
+ # ) # shape = [*, width, grid, grid]
272
+
273
+
274
+ if patch_embeds is not None:
275
+ patch_embeds = patch_embeds
276
+ # print(patch_embeds.shape)
277
+ else:
278
+ patch_embeds = self.patch_embedding(pixel_values)
279
+ # print(111111)
280
+ # shape = [*, width, grid, grid]
281
+ # patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
282
+
283
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
284
+
285
+
286
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
287
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
288
+
289
+ # x = torch.cat([cls_token, x], dim=1)
290
+ embeddings = embeddings + get_abs_pos(self.position_embedding(self.position_ids), embeddings.size(1))
291
+ # embeddings = embeddings + self.position_embedding(self.position_ids)
292
+ return embeddings
293
+
294
+
295
+ class NoTPFeedForward(nn.Module):
296
+ def __init__(
297
+ self,
298
+ cfg,
299
+ dim: int,
300
+ hidden_dim: int,
301
+ ):
302
+ super().__init__()
303
+
304
+ self.fc1 = torch.nn.Linear(dim, hidden_dim, bias=True)
305
+ self.fc2 = torch.nn.Linear(hidden_dim, dim, bias=True)
306
+
307
+ def forward(self, x):
308
+ output = self.fc2(quick_gelu(self.fc1(x)))
309
+ return output
310
+
311
+
312
+
313
+
314
+ class NoTPAttention(torch.nn.Module):
315
+ def __init__(self, cfg):
316
+ super().__init__()
317
+ self.num_heads = cfg.num_attention_heads
318
+ self.n_local_heads = cfg.num_attention_heads
319
+ self.head_dim = cfg.hidden_size // cfg.num_attention_heads
320
+ self.max_seq_len = cfg.seq_length
321
+ self.use_flash_attention = cfg.use_flash_attn
322
+
323
+ self.qkv_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size * 3, bias=True)
324
+ self.out_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)
325
+
326
+ # self.core_attention = CoreAttention(cfg, AttnType.self_attn)
327
+
328
+ self.attn_drop = cfg.attention_dropout
329
+
330
+ def forward(
331
+ self,
332
+ x: torch.Tensor,
333
+ ):
334
+ bsz, seqlen, _ = x.shape
335
+ xqkv = self.qkv_proj(x)
336
+ xqkv = xqkv.view(bsz, seqlen, 3, self.num_heads, self.head_dim)
337
+
338
+ if self.use_flash_attention:
339
+
340
+ xq, xk, xv = torch.split(xqkv, 1, dim=2)
341
+ xq = xq.squeeze(2)
342
+ xk = xk.squeeze(2)
343
+ xv = xv.squeeze(2)
344
+ # xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
345
+
346
+ # (B, num_head, S, head_size)
347
+ xq = xq.permute(0, 2, 1, 3)
348
+ xk = xk.permute(0, 2, 1, 3)
349
+ xv = xv.permute(0, 2, 1, 3)
350
+ # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
351
+ output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
352
+ output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
353
+ # output = output.permute(0, 2, 1, 3).contiguous().view(bsz, seqlen, -1)
354
+ else:
355
+ # print(22222)
356
+ xq, xk, xv = torch.split(xqkv, 1, dim=2)
357
+ xq = xq.squeeze(2)
358
+ xk = xk.squeeze(2)
359
+ xv = xv.squeeze(2)
360
+ # xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
361
+
362
+ # (B, num_head, S, head_size)
363
+ xq = xq.permute(0, 2, 1, 3)
364
+ xk = xk.permute(0, 2, 1, 3)
365
+ xv = xv.permute(0, 2, 1, 3)
366
+ # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
367
+ output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
368
+ output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
369
+ # output = output.permute(0, 2, 1, 3).contiguous().view(bsz, seqlen, -1)
370
+ output = self.out_proj(output)
371
+ return output
372
+
373
+ class NoTPTransformerBlock(nn.Module):
374
+ def __init__(self, cfg, layer_id: int, multiple_of=256):
375
+ super().__init__()
376
+
377
+ self.n_heads = cfg.num_attention_heads
378
+ self.dim = cfg.hidden_size
379
+ self.head_dim = cfg.hidden_size // cfg.num_attention_heads
380
+ self.self_attn = NoTPAttention(cfg)
381
+ self.mlp = NoTPFeedForward(
382
+ cfg, dim=cfg.hidden_size, hidden_dim=cfg.ffn_hidden_size
383
+ )
384
+ self.layer_id = layer_id
385
+ self.layer_norm1 = torch.nn.LayerNorm(
386
+ cfg.hidden_size, eps=cfg.layernorm_epsilon
387
+ )
388
+ self.layer_norm2 = torch.nn.LayerNorm(
389
+ cfg.hidden_size, eps=cfg.layernorm_epsilon
390
+ )
391
+
392
+ def forward(self, x: torch.Tensor):
393
+ residual = self.self_attn.forward(self.layer_norm1(x))
394
+ h = x + residual
395
+ out = h + self.mlp.forward(self.layer_norm2(h))
396
+ return out
397
+
398
+
399
+ class NoTPTransformer(nn.Module):
400
+ def __init__(self, cfg):
401
+ super().__init__()
402
+
403
+ self.cfg = cfg
404
+ # self.recompute_list = self.cfg.get("recompute_list", [])
405
+ self.num_layers = cfg.num_layers # _get_num_layers(cfg)
406
+
407
+ self.layers = torch.nn.ModuleList()
408
+ for layer_id in range(self.num_layers):
409
+ self.layers.append(
410
+ NoTPTransformerBlock(
411
+ cfg,
412
+ layer_id + 1,
413
+ )
414
+ )
415
+
416
+ def forward(
417
+ self,
418
+ hidden_states,
419
+ ):
420
+
421
+ for lid, layer in enumerate(self.layers):
422
+ # if lid in self.recompute_list:
423
+ # def custom(layer_id):
424
+ # def custom_forward(*args, **kwargs):
425
+ # x_ = self.layers[layer_id](*args, **kwargs)
426
+ # return x_
427
+
428
+ # return custom_forward
429
+
430
+ # assert hidden_states.requires_grad == True, logger.warning(
431
+ # "When using recalculation, the input must have grad fn"
432
+ # )
433
+ # hidden_states = tensor_parallel.checkpoint(
434
+ # custom(lid),
435
+ # False,
436
+ # hidden_states.contiguous()
437
+ # )
438
+ # else:
439
+ hidden_states = layer(hidden_states)
440
+
441
+ return hidden_states
442
+
443
+
444
+ # from megatron.core.tensor_parallel.layers import non_tensor_paralleled, local_dp_reduce, local_dp_scatter
445
+
446
+ class VitModel(nn.Module):
447
+ def __init__(
448
+ self,
449
+ cfg,
450
+ freeze_embed=False,
451
+ freeze_pre_norm=False
452
+ ) -> None:
453
+ super().__init__()
454
+
455
+ self.embeddings = CLIPVisionEmbeddings(hidden_size=cfg.hidden_size, image_size=cfg.image_size, patch_size=cfg.patch_size)
456
+
457
+ if freeze_embed:
458
+ for name, param in self.embeddings.named_parameters():
459
+ param.requires_grad = False
460
+
461
+ self.transformer = NoTPTransformer(cfg=cfg)
462
+
463
+ if cfg.get("fp32norm", False):
464
+ logger.info("Load fp32 layernorm for ViT.")
465
+ self.pre_layrnorm = LayerNormfp32(
466
+ cfg.hidden_size,
467
+ eps=cfg.get("pre_layernorm_epsilon", 1e-5),
468
+ )
469
+ else:
470
+ self.pre_layrnorm = torch.nn.LayerNorm(
471
+ cfg.hidden_size,
472
+ eps=cfg.get("pre_layernorm_epsilon", 1e-5),
473
+ )
474
+
475
+ # self.pre_layrnorm = RMSNorm(
476
+ # cfg.hidden_size,
477
+ # eps=cfg.get("pre_layernorm_epsilon", 1e-5),
478
+ # sequence_parallel=False,
479
+ # use_fp32=True,
480
+ # use_optimus=True,
481
+ # )
482
+
483
+ if freeze_pre_norm:
484
+ for name, param in self.pre_layrnorm.named_parameters():
485
+ param.requires_grad = False
486
+
487
+ for p in self.parameters():
488
+ p.micro_dp = True
489
+
490
+ def set_input_tensor(self, input_tensor):
491
+ if not isinstance(input_tensor, list):
492
+ input_tensor = [input_tensor]
493
+ self.transformer.set_input_tensor(input_tensor[0])
494
+
495
+ def __str__(self) -> str:
496
+ return "open_clip"
497
+
498
+ def forward(
499
+ self,
500
+ x,
501
+ patch_embeds
502
+ ):
503
+ x = self.embeddings(x, patch_embeds)
504
+ hidden_states = self.pre_layrnorm(x)
505
+
506
+ # hidden_states, dis = local_dp_scatter(hidden_states)
507
+ output = self.transformer(hidden_states)
508
+
509
+ # output = local_dp_reduce(output, dis)
510
+
511
+ return output
512
+
513
+
514
+ vit_model_cfg = adict(
515
+ num_layers=24,
516
+ hidden_size=1024,
517
+ num_heads = 16,
518
+ num_attention_heads=16,
519
+ ffn_hidden_size=4096,
520
+ seq_length=256,
521
+ max_position_embeddings=256,
522
+ use_flash_attn=False,
523
+ understand_projector_stride=2,
524
+ hidden_dropout = 0.0,
525
+ attention_dropout = 0.0,
526
+ no_persist_layer_norm = False,
527
+ layernorm_epsilon = 1e-5,
528
+ pre_layernorm_epsilon = 1e-5,
529
+ image_size = 224,
530
+ patch_size = 14,
531
+ recompute_list = []
532
+ )
533
+
534
+ def build_clip_l():
535
+ return VitModel(
536
+ cfg=vit_model_cfg,
537
+ freeze_embed=False,
538
+ freeze_pre_norm=False,
539
+ )
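A quick sanity check on the sizes wired into vit_model_cfg (a sketch; the numbers are read off defaults elsewhere in this file, not measured): the SAM encoder below runs at 1024 px with patch 16, giving a 64x64 map, and its two stride-2 convolutions (net_2, net_3) reduce that to 16x16 = 256 tokens, which matches seq_length. With the 224/14 CLIP grid the position table holds 1 + 256 entries, so get_abs_pos normally returns it unchanged.

sam_grid = 1024 // 16                 # 64x64 feature map out of the SAM trunk
clip_grid = sam_grid // 2 // 2        # 16x16 after net_2 and net_3
assert clip_grid * clip_grid == vit_model_cfg.seq_length   # 256 visual tokens per view
assert (224 // 14) ** 2 + 1 == 257                         # CLS + patch position entries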
540
+
541
+
542
+
543
+
544
+
545
+ #=========================Sam-Vary=================================
546
+
547
+
548
+ def get_abs_pos_sam(abs_pos, tgt_size):
549
+
550
+ dtype = abs_pos.dtype
551
+
552
+ src_size = abs_pos.size(1)
553
+
554
+ if src_size != tgt_size:
555
+ old_pos_embed = abs_pos.permute(0, 3, 1, 2)
556
+ old_pos_embed = old_pos_embed.to(torch.float32)
557
+ new_pos_embed = F.interpolate(
558
+ old_pos_embed,
559
+ size=(tgt_size, tgt_size),
560
+ mode='bicubic',
561
+ antialias=True,
562
+ align_corners=False,
563
+ ).to(dtype)
564
+ new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
565
+ return new_pos_embed
566
+ else:
567
+ return abs_pos
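A minimal sketch of the resize above, assuming the SAM pretraining grid of 64x64 (1024 px / patch 16) and a hypothetical 32x32 target:

import torch

pos = torch.zeros(1, 64, 64, 768)        # (1, grid, grid, embed_dim) as stored in pos_embed
resized = get_abs_pos_sam(pos, 32)       # bicubic, antialiased resize done in float32
# resized.shape == (1, 32, 32, 768); if tgt_size == 64 the tensor is returned untouched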
568
+
569
+
570
+
571
+
572
+ class MLPBlock(nn.Module):
573
+ def __init__(
574
+ self,
575
+ embedding_dim: int,
576
+ mlp_dim: int,
577
+ act: Type[nn.Module] = nn.GELU,
578
+ ) -> None:
579
+ super().__init__()
580
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
581
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
582
+ self.act = act()
583
+
584
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
585
+ return self.lin2(self.act(self.lin1(x)))
586
+
587
+
588
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
589
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
590
+ class LayerNorm2d(nn.Module):
591
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
592
+ super().__init__()
593
+ self.weight = nn.Parameter(torch.ones(num_channels))
594
+ self.bias = nn.Parameter(torch.zeros(num_channels))
595
+ self.eps = eps
596
+
597
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
598
+ u = x.mean(1, keepdim=True)
599
+ s = (x - u).pow(2).mean(1, keepdim=True)
600
+ x = (x - u) / torch.sqrt(s + self.eps)
601
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
602
+ return x
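LayerNorm2d normalizes over the channel axis at every spatial position, i.e. it is a channels-first LayerNorm. A small equivalence sketch (default affine parameters, so weights are ones and biases zeros):

import torch

x = torch.randn(2, 256, 8, 8)
out = LayerNorm2d(256)(x)
ref = torch.nn.LayerNorm(256, eps=1e-6)(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
# torch.allclose(out, ref, atol=1e-5) should hold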
603
+
604
+
605
+ # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
606
+ class ImageEncoderViT(nn.Module):
607
+ def __init__(
608
+ self,
609
+ img_size: int = 1024,
610
+ patch_size: int = 16,
611
+ in_chans: int = 3,
612
+ embed_dim: int = 768,
613
+ depth: int = 12,
614
+ num_heads: int = 12,
615
+ mlp_ratio: float = 4.0,
616
+ out_chans: int = 256,
617
+ qkv_bias: bool = True,
618
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
619
+ act_layer: Type[nn.Module] = nn.GELU,
620
+ use_abs_pos: bool = True,
621
+ use_rel_pos: bool = False,
622
+ rel_pos_zero_init: bool = True,
623
+ window_size: int = 0,
624
+ global_attn_indexes: Tuple[int, ...] = (),
625
+ ) -> None:
626
+ """
627
+ Args:
628
+ img_size (int): Input image size.
629
+ patch_size (int): Patch size.
630
+ in_chans (int): Number of input image channels.
631
+ embed_dim (int): Patch embedding dimension.
632
+ depth (int): Depth of ViT.
633
+ num_heads (int): Number of attention heads in each ViT block.
634
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
635
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
636
+ norm_layer (nn.Module): Normalization layer.
637
+ act_layer (nn.Module): Activation layer.
638
+ use_abs_pos (bool): If True, use absolute positional embeddings.
639
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
640
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
641
+ window_size (int): Window size for window attention blocks.
642
+ global_attn_indexes (list): Indexes for blocks using global attention.
643
+ """
644
+ super().__init__()
645
+ self.img_size = img_size
646
+
647
+ self.patch_embed = PatchEmbed(
648
+ kernel_size=(patch_size, patch_size),
649
+ stride=(patch_size, patch_size),
650
+ in_chans=in_chans,
651
+ embed_dim=embed_dim,
652
+ )
653
+
654
+ self.pos_embed: Optional[nn.Parameter] = None
655
+ if use_abs_pos:
656
+ # Initialize absolute positional embedding with pretrain image size.
657
+ self.pos_embed = nn.Parameter(
658
+ torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
659
+ )
660
+
661
+ self.blocks = nn.ModuleList()
662
+ for i in range(depth):
663
+ block = Block(
664
+ dim=embed_dim,
665
+ num_heads=num_heads,
666
+ mlp_ratio=mlp_ratio,
667
+ qkv_bias=qkv_bias,
668
+ norm_layer=norm_layer,
669
+ act_layer=act_layer,
670
+ use_rel_pos=use_rel_pos,
671
+ rel_pos_zero_init=rel_pos_zero_init,
672
+ window_size=window_size if i not in global_attn_indexes else 0,
673
+ input_size=(img_size // patch_size, img_size // patch_size),
674
+ )
675
+ self.blocks.append(block)
676
+
677
+ self.neck = nn.Sequential(
678
+ nn.Conv2d(
679
+ embed_dim,
680
+ out_chans,
681
+ kernel_size=1,
682
+ bias=False,
683
+ ),
684
+ LayerNorm2d(out_chans),
685
+ nn.Conv2d(
686
+ out_chans,
687
+ out_chans,
688
+ kernel_size=3,
689
+ padding=1,
690
+ bias=False,
691
+ ),
692
+ LayerNorm2d(out_chans),
693
+ )
694
+
695
+ self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
696
+ self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
697
+
698
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
699
+ x = self.patch_embed(x)
700
+ if self.pos_embed is not None:
701
+ # x = x + self.pos_embed
702
+ x = x + get_abs_pos_sam(self.pos_embed, x.size(1))
703
+
704
+ for blk in self.blocks:
705
+ x = blk(x)
706
+
707
+ x = self.neck(x.permute(0, 3, 1, 2))
708
+ x2 = self.net_2(x)
709
+ x3 = self.net_3(x2.clone())
710
+
711
+ return x3
712
+
713
+
714
+ class Block(nn.Module):
715
+ """Transformer block with support for window attention and residual propagation."""
716
+
717
+ def __init__(
718
+ self,
719
+ dim: int,
720
+ num_heads: int,
721
+ mlp_ratio: float = 4.0,
722
+ qkv_bias: bool = True,
723
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
724
+ act_layer: Type[nn.Module] = nn.GELU,
725
+ use_rel_pos: bool = False,
726
+ rel_pos_zero_init: bool = True,
727
+ window_size: int = 0,
728
+ input_size: Optional[Tuple[int, int]] = None,
729
+ ) -> None:
730
+ """
731
+ Args:
732
+ dim (int): Number of input channels.
733
+ num_heads (int): Number of attention heads in each ViT block.
734
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
735
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
736
+ norm_layer (nn.Module): Normalization layer.
737
+ act_layer (nn.Module): Activation layer.
738
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
739
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
740
+ window_size (int): Window size for window attention blocks. If it equals 0, then
741
+ use global attention.
742
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
743
+ positional parameter size.
744
+ """
745
+ super().__init__()
746
+ self.norm1 = norm_layer(dim)
747
+ self.attn = Attention(
748
+ dim,
749
+ num_heads=num_heads,
750
+ qkv_bias=qkv_bias,
751
+ use_rel_pos=use_rel_pos,
752
+ rel_pos_zero_init=rel_pos_zero_init,
753
+ input_size=input_size if window_size == 0 else (window_size, window_size),
754
+ )
755
+
756
+ self.norm2 = norm_layer(dim)
757
+ self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
758
+
759
+ self.window_size = window_size
760
+
761
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
762
+ shortcut = x
763
+ x = self.norm1(x)
764
+ # Window partition
765
+ if self.window_size > 0:
766
+ H, W = x.shape[1], x.shape[2]
767
+ x, pad_hw = window_partition(x, self.window_size)
768
+
769
+ x = self.attn(x)
770
+ # Reverse window partition
771
+ if self.window_size > 0:
772
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
773
+
774
+ x = shortcut + x
775
+ x = x + self.mlp(self.norm2(x))
776
+
777
+ return x
778
+
779
+
780
+ class Attention(nn.Module):
781
+ """Multi-head Attention block with relative position embeddings."""
782
+
783
+ def __init__(
784
+ self,
785
+ dim: int,
786
+ num_heads: int = 8,
787
+ qkv_bias: bool = True,
788
+ use_rel_pos: bool = False,
789
+ rel_pos_zero_init: bool = True,
790
+ input_size: Optional[Tuple[int, int]] = None,
791
+ ) -> None:
792
+ """
793
+ Args:
794
+ dim (int): Number of input channels.
795
+ num_heads (int): Number of attention heads.
796
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
797
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
798
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
799
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
800
+ positional parameter size.
801
+ """
802
+ super().__init__()
803
+ self.num_heads = num_heads
804
+ head_dim = dim // num_heads
805
+ self.scale = head_dim**-0.5
806
+
807
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
808
+ self.proj = nn.Linear(dim, dim)
809
+
810
+ self.use_rel_pos = use_rel_pos
811
+ if self.use_rel_pos:
812
+ assert (
813
+ input_size is not None
814
+ ), "Input size must be provided if using relative positional encoding."
815
+ # initialize relative positional embeddings
816
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
817
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
818
+
819
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
820
+ B, H, W, _ = x.shape
821
+ # qkv with shape (3, B, nHead, H * W, C)
822
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
823
+ # q, k, v with shape (B * nHead, H * W, C)
824
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
825
+
826
+ rel_h, rel_w = None, None
827
+ if self.use_rel_pos:
828
+ rel_h, rel_w = add_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
829
+
830
+ q = q.view(B, self.num_heads, H * W, -1)
831
+ k = k.view(B, self.num_heads, H * W, -1)
832
+ v = v.view(B, self.num_heads, H * W, -1)
833
+
834
+ if self.use_rel_pos:
835
+ rel_h = rel_h.view(B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
836
+ rel_w = rel_w.view(B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
837
+ attn_bias = (rel_h + rel_w).view(B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4))
838
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
839
+ # x = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w)
840
+ else:
841
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
842
+
843
+ x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
844
+
845
+ x = self.proj(x)
846
+
847
+ return x
848
+
849
+
850
+ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
851
+ """
852
+ Partition into non-overlapping windows with padding if needed.
853
+ Args:
854
+ x (tensor): input tokens with [B, H, W, C].
855
+ window_size (int): window size.
856
+
857
+ Returns:
858
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
859
+ (Hp, Wp): padded height and width before partition
860
+ """
861
+ B, H, W, C = x.shape
862
+
863
+ pad_h = (window_size - H % window_size) % window_size
864
+ pad_w = (window_size - W % window_size) % window_size
865
+ if pad_h > 0 or pad_w > 0:
866
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
867
+ Hp, Wp = H + pad_h, W + pad_w
868
+
869
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
870
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
871
+ return windows, (Hp, Wp)
872
+
873
+
874
+ def window_unpartition(
875
+ windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
876
+ ) -> torch.Tensor:
877
+ """
878
+ Unpartition windows back into original sequences and remove padding.
879
+ Args:
880
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
881
+ window_size (int): window size.
882
+ pad_hw (Tuple): padded height and width (Hp, Wp).
883
+ hw (Tuple): original height and width (H, W) before padding.
884
+
885
+ Returns:
886
+ x: unpartitioned sequences with [B, H, W, C].
887
+ """
888
+ Hp, Wp = pad_hw
889
+ H, W = hw
890
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
891
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
892
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
893
+
894
+ if Hp > H or Wp > W:
895
+ x = x[:, :H, :W, :].contiguous()
896
+ return x
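The two helpers above are inverses once the padding is cropped away. A round-trip sketch with a size that does not divide the window (30 is padded up to 42 for window 14):

import torch

x = torch.randn(1, 30, 30, 64)
wins, pad_hw = window_partition(x, 14)           # wins: (9, 14, 14, 64), pad_hw == (42, 42)
x_back = window_unpartition(wins, 14, pad_hw, (30, 30))
# torch.equal(x_back, x) holds: the zero padding is discarded on the way back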
897
+
898
+
899
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
900
+ """
901
+ Get relative positional embeddings according to the relative positions of
902
+ query and key sizes.
903
+ Args:
904
+ q_size (int): size of query q.
905
+ k_size (int): size of key k.
906
+ rel_pos (Tensor): relative position embeddings (L, C).
907
+
908
+ Returns:
909
+ Extracted positional embeddings according to relative positions.
910
+ """
911
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
912
+ # Interpolate rel pos if needed.
913
+ if rel_pos.shape[0] != max_rel_dist:
914
+ # Interpolate rel pos.
915
+ dtype = rel_pos.dtype
916
+ rel_pos = rel_pos.to(torch.float32)
917
+ rel_pos_resized = F.interpolate(
918
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
919
+ size=max_rel_dist,
920
+ mode="linear",
921
+ ).to(dtype)
922
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
923
+ else:
924
+ rel_pos_resized = rel_pos
925
+
926
+ # Scale the coords with short length if shapes for q and k are different.
927
+ q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max(k_size / q_size, 1.0)
928
+ k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max(q_size / k_size, 1.0)
929
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
930
+
931
+ return rel_pos_resized[relative_coords.long()]
932
+
933
+
934
+ def add_decomposed_rel_pos(
935
+ q: torch.Tensor,
936
+ rel_pos_h: torch.Tensor,
937
+ rel_pos_w: torch.Tensor,
938
+ q_size: Tuple[int, int],
939
+ k_size: Tuple[int, int],
940
+ ) -> torch.Tensor:
941
+ """
942
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
943
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
944
+ Args:
945
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
946
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
947
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
948
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
949
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
950
+
951
+ Returns:
952
+ rel_h, rel_w (Tensor): decomposed relative position terms for the height and width axes; their broadcast sum is added to the attention logits in Attention.forward.
953
+ """
954
+ q_h, q_w = q_size
955
+ k_h, k_w = k_size
956
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
957
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
958
+
959
+ B, _, dim = q.shape
960
+ r_q = q.reshape(B, q_h, q_w, dim)
961
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
962
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
963
+ rel_h = rel_h.unsqueeze(-1)
964
+ rel_w = rel_w.unsqueeze(-2)
965
+ rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
966
+ rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)
967
+
968
+ return rel_h, rel_w
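Unlike the upstream ViTDet helper, this variant returns the two decomposed terms instead of adding them to an attention map; Attention.forward above broadcasts rel_h + rel_w into the SDPA bias. A shape sketch with an illustrative 14x14 window and head_dim 64:

import torch

q = torch.randn(24, 14 * 14, 64)                 # (B * nHead, q_h * q_w, head_dim)
rel_pos = torch.zeros(2 * 14 - 1, 64)            # (2 * size - 1, head_dim)
rel_h, rel_w = add_decomposed_rel_pos(q, rel_pos, rel_pos, (14, 14), (14, 14))
# rel_h: (24, 196, 14, 1), rel_w: (24, 196, 1, 14)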
969
+
970
+
971
+ class PatchEmbed(nn.Module):
972
+ """
973
+ Image to Patch Embedding.
974
+ """
975
+
976
+ def __init__(
977
+ self,
978
+ kernel_size: Tuple[int, int] = (16, 16),
979
+ stride: Tuple[int, int] = (16, 16),
980
+ padding: Tuple[int, int] = (0, 0),
981
+ in_chans: int = 3,
982
+ embed_dim: int = 768,
983
+ ) -> None:
984
+ """
985
+ Args:
986
+ kernel_size (Tuple): kernel size of the projection layer.
987
+ stride (Tuple): stride of the projection layer.
988
+ padding (Tuple): padding size of the projection layer.
989
+ in_chans (int): Number of input image channels.
990
+ embed_dim (int): Patch embedding dimension.
991
+ """
992
+ super().__init__()
993
+
994
+ self.proj = nn.Conv2d(
995
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
996
+ )
997
+
998
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
999
+ x = self.proj(x)
1000
+ # B C H W -> B H W C
1001
+ x = x.permute(0, 2, 3, 1)
1002
+ return x
1003
+
1004
+
1005
+ def build_sam_vit_b(checkpoint=None):
1006
+ return _build_sam(
1007
+ encoder_embed_dim=768,
1008
+ encoder_depth=12,
1009
+ encoder_num_heads=12,
1010
+ encoder_global_attn_indexes=[2, 5, 8, 11],
1011
+ checkpoint=checkpoint,
1012
+ )
1013
+
1014
+ def build_sam_fast_vit_b(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
1015
+ image_encoder = build_sam_vit_b(checkpoint).eval().to(dtype)
1016
+ # sam = _apply_eval_dtype_sam(sam, dtype)
1017
+ image_encoder = torch.compile(image_encoder, mode=compile_mode)
1018
+ return image_encoder
1019
+
1020
+
1021
+ def _build_sam(
1022
+ encoder_embed_dim,
1023
+ encoder_depth,
1024
+ encoder_num_heads,
1025
+ encoder_global_attn_indexes,
1026
+ checkpoint=None,
1027
+ ):
1028
+ prompt_embed_dim = 256
1029
+ image_size = 1024
1030
+ vit_patch_size = 16
1031
+ image_embedding_size = image_size // vit_patch_size
1032
+ image_encoder=ImageEncoderViT(
1033
+ depth=encoder_depth,
1034
+ embed_dim=encoder_embed_dim,
1035
+ img_size=image_size,
1036
+ mlp_ratio=4,
1037
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
1038
+ num_heads=encoder_num_heads,
1039
+ patch_size=vit_patch_size,
1040
+ qkv_bias=True,
1041
+ use_rel_pos=True,
1042
+ global_attn_indexes=encoder_global_attn_indexes,
1043
+ window_size=14,
1044
+ out_chans=prompt_embed_dim,
1045
+ )
1046
+ image_encoder.eval()
1047
+ if checkpoint is not None:
1048
+ # with open(checkpoint, "rb") as f:
1049
+ state_dict = torch.load(checkpoint)
1050
+ # print(state_dict.keys())
1051
+ # for key in state_dict:
1052
+ # image_encoder.load_state_dict({k[14:]: v for k, v in state_dict.items() if 'image_encoder' in k}, strict=False)
1053
+ # ocr-anyting
1054
+ # image_encoder.load_state_dict(state_dict, strict=True)
1055
+ # tob
1056
+ image_encoder.load_state_dict({k[30:]: v for k, v in state_dict.items() if 'vision_tower_high' in k}, strict=True)
1057
+ print(checkpoint)
1058
+ return image_encoder
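A usage sketch for the encoder built above (no checkpoint passed, so weights stay randomly initialized; this is a shape check only):

import torch

encoder = build_sam_vit_b()
with torch.no_grad():
    feats = encoder(torch.randn(1, 3, 1024, 1024))
# feats.shape == (1, 1024, 16, 16): 64x64 patches -> neck (256 ch) -> net_2 -> net_3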
model-00001-of-000001.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1169e7cdc28ff2fb6186556acb2175db148ad26a62097df4c45a17e523180d3f
3
+ size 6672547120
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_deepseekocr.py ADDED
@@ -0,0 +1,1044 @@
1
+ import os
2
+ import math
3
+ import re
4
+ from tqdm import tqdm
5
+ from abc import ABC
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ from addict import Dict
9
+ from PIL import Image, ImageOps, ImageDraw, ImageFont
10
+ import numpy as np
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.nn import CrossEntropyLoss
15
+ from torchvision import transforms
16
+
17
+ from transformers.cache_utils import Cache
18
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
19
+ from transformers import DeepseekV2Model, DeepseekV2ForCausalLM
20
+ from transformers import DeepseekV2Config
21
+ from transformers.models.deepseek_v2.modeling_deepseek_v2 import (
22
+ DeepseekV2Attention, DeepseekV2MLP, DeepseekV2MoE, DeepseekV2RMSNorm, DeepseekV2DecoderLayer)
23
+ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding
24
+ from transformers import TextStreamer
25
+ from .deepencoder import build_sam_vit_b, build_clip_l, MlpProjector
26
+ from .conversation import get_conv_template
27
+
28
+
29
+ def load_image(image_path):
30
+
31
+ try:
32
+ image = Image.open(image_path)
33
+
34
+ corrected_image = ImageOps.exif_transpose(image)
35
+
36
+ return corrected_image
37
+
38
+ except Exception as e:
39
+ print(f"error: {e}")
40
+ try:
41
+ return Image.open(image_path)
42
+ except:
43
+ return None
44
+
45
+
46
+ def re_match(text):
47
+ pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
48
+ matches = re.findall(pattern, text, re.DOTALL)
49
+
50
+ # pattern1 = r'<\|ref\|>.*?<\|/ref\|>\n'
51
+ # new_text1 = re.sub(pattern1, '', text, flags=re.DOTALL)
52
+
53
+ mathes_image = []
54
+ mathes_other = []
55
+ for a_match in matches:
56
+ if '<|ref|>image<|/ref|>' in a_match[0]:
57
+ mathes_image.append(a_match[0])
58
+ else:
59
+ mathes_other.append(a_match[0])
60
+ return matches, mathes_image, mathes_other
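A worked example of the grounding-tag regex (an illustrative string following the <|ref|> / <|det|> format the model emits):

sample = '<|ref|>title<|/ref|><|det|>[[12, 34, 500, 80]]<|/det|> Heading text'
matches, image_refs, other_refs = re_match(sample)
# matches[0] == (whole_tag, 'title', '[[12, 34, 500, 80]]'); image_refs is empty,
# other_refs holds the full tag because its <|ref|> label is not 'image'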
61
+
62
+
63
+ def extract_coordinates_and_label(ref_text, image_width, image_height):
64
+
65
+ try:
66
+ label_type = ref_text[1]
67
+ cor_list = eval(ref_text[2])
68
+ except Exception as e:
69
+ print(e)
70
+ return None
71
+
72
+ return (label_type, cor_list)
73
+
74
+
75
+ def draw_bounding_boxes(image, refs, output_path):
76
+
77
+ image_width, image_height = image.size
78
+
79
+ img_draw = image.copy()
80
+ draw = ImageDraw.Draw(img_draw)
81
+
82
+ overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
83
+ draw2 = ImageDraw.Draw(overlay)
84
+
85
+ # try:
86
+ # except IOError:
87
+ # try:
88
+ # font = ImageFont.truetype("DejaVuSans.ttf", 20)
89
+ # except IOError:
90
+ font = ImageFont.load_default()
91
+
92
+ img_idx = 0
93
+
94
+ for i, ref in enumerate(refs):
95
+ try:
96
+ result = extract_coordinates_and_label(ref, image_width, image_height)
97
+ if result:
98
+ label_type, points_list = result
99
+
100
+ color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
101
+
102
+ color_a = color + (20, )
103
+ for points in points_list:
104
+ x1, y1, x2, y2 = points
105
+
106
+ x1 = int(x1 / 999 * image_width)
107
+ y1 = int(y1 / 999 * image_height)
108
+
109
+ x2 = int(x2 / 999 * image_width)
110
+ y2 = int(y2 / 999 * image_height)
111
+
112
+ if label_type == 'image':
113
+ try:
114
+ cropped = image.crop((x1, y1, x2, y2))
115
+ cropped.save(f"{output_path}/images/{img_idx}.jpg")
116
+ except Exception as e:
117
+ print(e)
118
+ pass
119
+ img_idx += 1
120
+
121
+ try:
122
+ if label_type == 'title':
123
+ draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
124
+ draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
125
+ else:
126
+ draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
127
+ draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
128
+ text_x = x1
129
+ text_y = max(0, y1 - 15)
130
+
131
+
132
+ text_bbox = draw.textbbox((0, 0), label_type, font=font)
133
+ text_width = text_bbox[2] - text_bbox[0]
134
+ text_height = text_bbox[3] - text_bbox[1]
135
+ draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
136
+ fill=(255, 255, 255, 30))
137
+
138
+ draw.text((text_x, text_y), label_type, font=font, fill=color)
139
+ except:
140
+ pass
141
+ except:
142
+ continue
143
+ img_draw.paste(overlay, (0, 0), overlay)
144
+ return img_draw
145
+
146
+
147
+ def process_image_with_refs(image, ref_texts, output_path):
148
+
149
+ result_image = draw_bounding_boxes(image, ref_texts, output_path)
150
+
151
+ return result_image
152
+
153
+
154
+
155
+
156
+
157
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
158
+ best_ratio_diff = float('inf')
159
+ best_ratio = (1, 1)
160
+ area = width * height
161
+ for ratio in target_ratios:
162
+ target_aspect_ratio = ratio[0] / ratio[1]
163
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
164
+ if ratio_diff < best_ratio_diff:
165
+ best_ratio_diff = ratio_diff
166
+ best_ratio = ratio
167
+ elif ratio_diff == best_ratio_diff:
168
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
169
+ best_ratio = ratio
170
+ # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
171
+ return best_ratio
172
+
173
+
174
+ def dynamic_preprocess(image, min_num=2, max_num=9, image_size=640, use_thumbnail=False):
175
+ orig_width, orig_height = image.size
176
+ aspect_ratio = orig_width / orig_height
177
+
178
+ # calculate the existing image aspect ratio
179
+ target_ratios = set(
180
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
181
+ i * j <= max_num and i * j >= min_num)
182
+ # print(target_ratios)
183
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
184
+
185
+ # find the closest aspect ratio to the target
186
+ target_aspect_ratio = find_closest_aspect_ratio(
187
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
188
+
189
+ # print(target_aspect_ratio)
190
+ # calculate the target width and height
191
+ target_width = image_size * target_aspect_ratio[0]
192
+ target_height = image_size * target_aspect_ratio[1]
193
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
194
+
195
+ # resize the image
196
+ resized_img = image.resize((target_width, target_height))
197
+ processed_images = []
198
+ for i in range(blocks):
199
+ box = (
200
+ (i % (target_width // image_size)) * image_size,
201
+ (i // (target_width // image_size)) * image_size,
202
+ ((i % (target_width // image_size)) + 1) * image_size,
203
+ ((i // (target_width // image_size)) + 1) * image_size
204
+ )
205
+ # split the image
206
+ split_img = resized_img.crop(box)
207
+ processed_images.append(split_img)
208
+ assert len(processed_images) == blocks
209
+ if use_thumbnail and len(processed_images) != 1:
210
+ thumbnail_img = image.resize((image_size, image_size))
211
+ processed_images.append(thumbnail_img)
212
+ return processed_images, target_aspect_ratio
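A worked example for dynamic_preprocess (the tile count follows from the aspect-ratio search above; a 1600x800 page is closest to the 2x1 grid among grids of 2 to 9 tiles):

from PIL import Image

page = Image.new('RGB', (1600, 800))
tiles, grid = dynamic_preprocess(page, image_size=640)
# grid == (2, 1) and len(tiles) == 2: two 640x640 crops of the resized 1280x640 image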
213
+
214
+
215
+
216
+ def normalize_transform(mean, std):
217
+ if mean is None and std is None:
218
+ transform = None
219
+ elif mean is None and std is not None:
220
+ mean = [0.] * len(std)
221
+ transform = transforms.Normalize(mean=mean, std=std)
222
+ elif mean is not None and std is None:
223
+ std = [1.] * len(mean)
224
+ transform = transforms.Normalize(mean=mean, std=std)
225
+ else:
226
+ transform = transforms.Normalize(mean=mean, std=std)
227
+
228
+ return transform
229
+
230
+
231
+
232
+ def format_messages(
233
+ conversations: List[Dict[str, str]],
234
+ sft_format: str = "deepseek",
235
+ system_prompt: str = "",
236
+ ):
237
+ """
238
+ Applies the SFT template to the conversation.
239
+
240
+ Args:
241
+ conversations (List[Dict]): A List of messages.
242
+ sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
243
+ system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".
244
+
245
+ Returns:
246
+ sft_prompt (str): The formatted text.
247
+ """
248
+
249
+ conv = get_conv_template(sft_format)
250
+ conv.set_system_message(system_prompt)
251
+ for message in conversations:
252
+ conv.append_message(message["role"], message["content"].strip())
253
+ sft_prompt = conv.get_prompt().strip()
254
+
255
+ return sft_prompt
256
+
257
+
258
+ def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False):
259
+ t = tokenizer.encode(text, add_special_tokens=False)
260
+ bos_id = 0
261
+ eos_id = 1
262
+ if bos:
263
+ t = [bos_id] + t
264
+ if eos:
265
+ t = t + [eos_id]
266
+
267
+ return t
268
+
269
+ def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
270
+ """
271
+
272
+ Args:
273
+ conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is:
274
+ [
275
+ {
276
+ "role": "User",
277
+ "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.",
278
+ "images": ["./examples/table_datasets.png"]
279
+ },
280
+ {"role": "Assistant", "content": ""},
281
+ ]
282
+
283
+ Returns:
284
+ pil_images (List[PIL.Image.Image]): the list of PIL images.
285
+
286
+ """
287
+
288
+ pil_images = []
289
+
290
+ for message in conversations:
291
+ if "images" not in message:
292
+ continue
293
+
294
+ for image_path in message["images"]:
295
+ # print('----------------')
296
+ # print(image_path)
297
+ # print('----------------')
298
+ # exit()
299
+
300
+ # pil_img = Image.open(image_path)
301
+ pil_img = load_image(image_path)
302
+ pil_img = pil_img.convert("RGB")
303
+ pil_images.append(pil_img)
304
+
305
+ return pil_images
306
+
307
+
308
+ class BaseTransform(ABC):
309
+
310
+ def set_rng(self, *args, **kwargs):
311
+ pass
312
+
313
+ def __call__(self, *args, **kwargs) -> torch.Tensor:
314
+ pass
315
+
316
+ @property
317
+ def default_shape(self):
318
+ raise NotImplementedError
319
+
320
+
321
+ class BasicImageTransform(BaseTransform):
322
+ def __init__(
323
+ self,
324
+ mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
325
+ std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
326
+ normalize: bool = True
327
+ ):
328
+ self.mean = mean
329
+ self.std = std
330
+
331
+ transform_pipelines = [
332
+ transforms.ToTensor()
333
+ ]
334
+
335
+ normalize = normalize_transform(mean, std) if normalize else nn.Identity()
336
+ if normalize is not None:
337
+ transform_pipelines.append(normalize)
338
+
339
+ self.transform = transforms.Compose(transform_pipelines)
340
+
341
+ def __call__(self, x):
342
+ x = self.transform(x)
343
+ return x
344
+
345
+ class NoEOSTextStreamer(TextStreamer):
346
+ def on_finalized_text(self, text: str, stream_end: bool = False):
347
+
348
+ eos_text = self.tokenizer.decode([self.tokenizer.eos_token_id], skip_special_tokens=False)
349
+ text = text.replace(eos_text, "\n")
350
+ print(text, flush=True, end="")
351
+
352
+
353
+ def decoder_layer_init(self, config: DeepseekV2Config, layer_idx: int):
354
+ nn.Module.__init__(self)
355
+ self.hidden_size = config.hidden_size
356
+
357
+ if config.use_mla:
358
+ self.self_attn = DeepseekV2Attention(config=config, layer_idx=layer_idx)
359
+ else:
360
+ config.head_dim = config.hidden_size // config.num_attention_heads
361
+ self.self_attn = LlamaAttention(config, layer_idx)
362
+ self.mlp = DeepseekV2MoE(config) if layer_idx >= config.first_k_dense_replace else DeepseekV2MLP(config)
363
+
364
+ self.input_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
365
+ self.post_attention_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
366
+
367
+
368
+ DeepseekV2DecoderLayer.__init__ = decoder_layer_init
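# The assignment above monkey-patches DeepseekV2DecoderLayer.__init__: when config.use_mla is
# False, a plain LlamaAttention (head_dim = hidden_size // num_attention_heads) replaces
# DeepseekV2Attention, while layers from first_k_dense_replace onward still use the MoE MLP.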
369
+
370
+ class DeepseekOCRConfig(DeepseekV2Config):
371
+ model_type = "DeepseekOCR"
372
+
373
+ class DeepseekOCRModel(DeepseekV2Model):
374
+ config_class = DeepseekOCRConfig
375
+
376
+ def __init__(self, config: DeepseekV2Config):
377
+ super(DeepseekOCRModel, self).__init__(config)
378
+
379
+ self.sam_model = build_sam_vit_b()
380
+ self.vision_model = build_clip_l()
381
+ # self.conv_2 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=2, stride=2)
382
+ n_embed = 1280
383
+ self.projector = MlpProjector(Dict(projector_type="linear", input_dim=2048, n_embed=n_embed))
384
+ embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
385
+ self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
386
+ self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
387
+
388
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
389
+
390
+ def forward(
391
+ self,
392
+ input_ids: torch.LongTensor = None,
393
+ attention_mask: Optional[torch.Tensor] = None,
394
+ position_ids: Optional[torch.LongTensor] = None,
395
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
396
+ inputs_embeds: Optional[torch.FloatTensor] = None,
397
+ use_cache: Optional[bool] = None,
398
+ output_attentions: Optional[bool] = None,
399
+ output_hidden_states: Optional[bool] = None,
400
+ images: Optional[torch.FloatTensor] = None,
401
+ images_seq_mask: Optional[torch.FloatTensor] = None,
402
+ images_spatial_crop: Optional[torch.FloatTensor] = None,
403
+ return_dict: Optional[bool] = None,
404
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
405
+
406
+
407
+
408
+ if inputs_embeds is None:
409
+ # inputs_embeds = self.embed_tokens(input_ids)
410
+ inputs_embeds = self.get_input_embeddings()(input_ids)
411
+
412
+
413
+
414
+ sam_model = getattr(self, 'sam_model', None)
415
+ # sam_model = self.sam_model
416
+ vision_model = getattr(self, 'vision_model', None)
417
+
418
+
419
+
420
+ if sam_model is not None and (input_ids.shape[1] != 1 or self.training) and torch.sum(images[0][1]).item() != 0:
421
+
422
+ idx = 0
423
+
424
+ # sam_model = torch.jit.script(sam_model)
425
+
426
+ # start_time = time.time()
427
+ for image, crop_shape in zip(images, images_spatial_crop):
428
+ images_in_this_batch = []
429
+
430
+ patches = image[0]
431
+ image_ori = image[1]
432
+
433
+ with torch.no_grad():
434
+ # with torch.inference_mode():
435
+
436
+ if torch.sum(patches).item() != 0:
437
+ # P, C, H, W = patches.shape
438
+ crop_flag = 1
439
+ local_features_1 = sam_model(patches)
440
+
441
+ local_features_2 = vision_model(patches, local_features_1)
442
+ # vit_time = time.time()
443
+ local_features = torch.cat((local_features_2[:, 1:], local_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
444
+ local_features = self.projector(local_features)
445
+
446
+
447
+ global_features_1 = sam_model(image_ori)
448
+ global_features_2 = vision_model(image_ori, global_features_1)
449
+ global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
450
+ global_features = self.projector(global_features)
451
+
452
+ print('=====================')
453
+ print('BASE: ', global_features.shape)
454
+ print('PATCHES: ', local_features.shape)
455
+ print('=====================')
456
+
457
+ _, hw, n_dim = global_features.shape
458
+ h = w = int(hw ** 0.5)
459
+
460
+ _2, hw2, n_dim2 = local_features.shape
461
+ h2 = w2 = int(hw2 ** 0.5)
462
+
463
+ width_crop_num, height_crop_num = crop_shape[0], crop_shape[1]
464
+
465
+ global_features = global_features.view(h, w, n_dim)
466
+
467
+ global_features = torch.cat(
468
+ [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
469
+ )
470
+
471
+ global_features = global_features.view(-1, n_dim)
472
+
473
+
474
+ local_features = local_features.view(height_crop_num, width_crop_num, h2, w2, n_dim2).permute(0, 2, 1, 3, 4).reshape(height_crop_num*h2, width_crop_num*w2, n_dim2)
475
+ local_features = torch.cat(
476
+ [local_features, self.image_newline[None, None, :].expand(height_crop_num * h2, 1, n_dim2)], dim=1
477
+ )
478
+ local_features = local_features.view(-1, n_dim2)
479
+
480
+ global_local_features = torch.cat([local_features, global_features, self.view_seperator[None, :]], dim=0)
481
+
482
+ # end_time = time.time()
483
+
484
+ # print('sam: ', sam_time - start_time)
485
+ # print('vit: ', vit_time - sam_time)
486
+ # print('all: ', end_time - start_time)
487
+
488
+ # exit()
489
+
490
+ else:
491
+ global_features_1 = sam_model(image_ori)
492
+ global_features_2 = vision_model(image_ori, global_features_1)
493
+ global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
494
+ global_features = self.projector(global_features)
495
+ print('=====================')
496
+ print('BASE: ', global_features.shape)
497
+ print('NO PATCHES')
498
+ print('=====================')
499
+ _, hw, n_dim = global_features.shape
500
+ h = w = int(hw ** 0.5)
501
+
502
+
503
+ global_features = global_features.view(h, w, n_dim)
504
+
505
+ global_features = torch.cat(
506
+ [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
507
+ )
508
+
509
+ global_features = global_features.view(-1, n_dim)
510
+
511
+ global_local_features = torch.cat([global_features, self.view_seperator[None, :]], dim=0)
512
+
513
+ images_in_this_batch.append(global_local_features)
514
+
515
+
516
+ if images_in_this_batch:
517
+ images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
518
+ # exit()
519
+
520
+ inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
521
+
522
+ idx += 1
523
+
524
+ return super(DeepseekOCRModel, self).forward(
525
+ input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
526
+ inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids,
527
+ output_attentions=output_attentions, output_hidden_states=output_hidden_states,
528
+ return_dict=return_dict
529
+ )
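Before the decoder sees them, the projected vision tokens are laid out as an h x w grid, an image_newline embedding is appended to every row, and a view_seperator closes each view. A minimal sketch with assumed sizes (16x16 global grid, n_embed = 1280):

import torch

h = w = 16
n_embed = 1280
grid = torch.randn(h * w, n_embed).view(h, w, n_embed)
newline = torch.randn(n_embed)                                   # stand-in for image_newline
rows = torch.cat([grid, newline[None, None, :].expand(h, 1, n_embed)], dim=1)
tokens = rows.view(-1, n_embed)                                  # (h * (w + 1), n_embed) = 272 rows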
530
+
531
+
532
+ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
533
+
534
+ config_class = DeepseekOCRConfig
535
+ # supports_gradient_checkpointing = True
536
+
537
+ def __init__(self, config):
538
+ super(DeepseekV2ForCausalLM, self).__init__(config)
539
+ self.model = DeepseekOCRModel(config)
540
+
541
+ self.vocab_size = config.vocab_size
542
+
543
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
544
+
545
+ # self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
546
+
547
+ # Initialize weights and apply final processing
548
+ self.post_init()
549
+
550
+ def get_model(self):
551
+ return self.model
552
+
553
+
554
+ def forward(
555
+ self,
556
+ input_ids: torch.LongTensor = None,
557
+ attention_mask: Optional[torch.Tensor] = None,
558
+ position_ids: Optional[torch.LongTensor] = None,
559
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
560
+ inputs_embeds: Optional[torch.FloatTensor] = None,
561
+ labels: Optional[torch.LongTensor] = None,
562
+ use_cache: Optional[bool] = None,
563
+ output_attentions: Optional[bool] = None,
564
+ output_hidden_states: Optional[bool] = None,
565
+ images: Optional[torch.FloatTensor] = None,
566
+ images_seq_mask: Optional[torch.FloatTensor] = None,
567
+ images_spatial_crop: Optional[torch.FloatTensor] = None,
568
+ return_dict: Optional[bool] = None,
569
+
570
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
571
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
572
+ output_hidden_states = (
573
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
574
+ )
575
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
576
+
577
+
578
+
579
+ outputs = self.model(
580
+ input_ids=input_ids,
581
+ past_key_values=past_key_values,
582
+ attention_mask=attention_mask,
583
+ position_ids=position_ids,
584
+ inputs_embeds=inputs_embeds,
585
+ use_cache=use_cache,
586
+ output_attentions=output_attentions,
587
+ output_hidden_states=output_hidden_states,
588
+ images=images,
589
+ images_seq_mask = images_seq_mask,
590
+ images_spatial_crop = images_spatial_crop,
591
+ return_dict=return_dict
592
+
593
+ )
594
+
595
+ hidden_states = outputs[0]
596
+ logits = self.lm_head(hidden_states)
597
+ logits = logits.float()
598
+
599
+ # logits
600
+
601
+ loss = None
602
+ if labels is not None:
603
+ # Shift so that tokens < n predict n
604
+ shift_logits = logits[..., :-1, :].contiguous()
605
+ shift_labels = labels[..., 1:].contiguous()
606
+ # Flatten the tokens
607
+ loss_fct = CrossEntropyLoss()
608
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
609
+ shift_labels = shift_labels.view(-1)
610
+ # Enable model parallelism
611
+ shift_labels = shift_labels.to(shift_logits.device)
612
+ loss = loss_fct(shift_logits, shift_labels)
613
+
614
+ if not return_dict:
615
+ output = (logits,) + outputs[1:]
616
+ return (loss,) + output if loss is not None else output
617
+
618
+ return CausalLMOutputWithPast(
619
+ loss=loss,
620
+ logits=logits,
621
+ past_key_values=outputs.past_key_values,
622
+ hidden_states=outputs.hidden_states,
623
+ attentions=outputs.attentions,
624
+ )
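The loss above uses the standard next-token shift: the logit at position t is scored against the label at position t + 1. A tiny self-contained sketch:

import torch

logits = torch.randn(1, 5, 10)                       # (batch, seq, vocab)
labels = torch.randint(0, 10, (1, 5))
loss = torch.nn.CrossEntropyLoss()(
    logits[:, :-1].reshape(-1, 10), labels[:, 1:].reshape(-1)
)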
625
+
626
+
627
+ def prepare_inputs_for_generation(
628
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
629
+ ):
630
+ # Omit tokens covered by past_key_values
631
+ past_length = 0
632
+ if past_key_values is not None:
633
+ if isinstance(past_key_values, Cache):
634
+ cache_length = past_key_values.get_seq_length()
635
+ past_length = past_key_values.get_seq_length()
636
+ max_cache_length = None
637
+ else:
638
+ cache_length = past_length = past_key_values[0][0].shape[2]
639
+ max_cache_length = None
640
+
641
+ # Keep only the unprocessed tokens:
642
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
643
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
644
+ # input)
645
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
646
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
647
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
648
+ # input_ids based on the past_length.
649
+ elif past_length < input_ids.shape[1]:
650
+ input_ids = input_ids[:, past_length:]
651
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
652
+
653
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
654
+ if (
655
+ max_cache_length is not None
656
+ and attention_mask is not None
657
+ and cache_length + input_ids.shape[1] > max_cache_length
658
+ ):
659
+ attention_mask = attention_mask[:, -max_cache_length:]
660
+
661
+ position_ids = kwargs.get("position_ids", None)
662
+ if attention_mask is not None and position_ids is None:
663
+ # create position_ids on the fly for batch generation
664
+ position_ids = attention_mask.long().cumsum(-1) - 1
665
+ position_ids.masked_fill_(attention_mask == 0, 1)
666
+ if past_key_values:
667
+ position_ids = position_ids[:, -input_ids.shape[1] :]
668
+
669
+ # if self.generation_config.cache_implementation == "static":
670
+ # # generation with static cache
671
+ # cache_position = kwargs.get("cache_position", None)
672
+ # if cache_position is None:
673
+ # past_length = 0
674
+ # else:
675
+ # past_length = cache_position[-1] + 1
676
+ # input_ids = input_ids[:, past_length:]
677
+ # position_ids = position_ids[:, past_length:]
678
+
679
+ # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
680
+ # same goes for position ids. Could also help with continued generation.
681
+ cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
682
+
683
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
684
+ if inputs_embeds is not None and past_key_values is None:
685
+ model_inputs = {"inputs_embeds": inputs_embeds}
686
+ else:
687
+ model_inputs = {"input_ids": input_ids}
688
+
689
+ model_inputs.update(
690
+ {
691
+ "position_ids": position_ids,
692
+ "past_key_values": past_key_values,
693
+ "use_cache": kwargs.get("use_cache"),
694
+ "attention_mask": attention_mask,
695
+ "images": kwargs.get("images", None),
696
+ "images_seq_mask": kwargs.get("images_seq_mask", None),
697
+ "images_spatial_crop": kwargs.get("images_spatial_crop", None),
698
+ }
699
+ )
700
+ return model_inputs
701
+
702
+
703
+ def disable_torch_init(self):
704
+ """
705
+ Disable the redundant torch default initialization to accelerate model creation.
706
+ """
707
+ import torch
708
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
709
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
710
+
711
+
712
+
713
+ def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
714
+ self.disable_torch_init()
715
+
716
+ os.makedirs(output_path, exist_ok=True)
717
+ os.makedirs(f'{output_path}/images', exist_ok=True)
718
+
719
+ if prompt and image_file:
720
+ conversation = [
721
+ {
722
+ "role": "<|User|>",
723
+ # "content": "<image>\n<|grounding|>Given the layout of the image. ",
724
+ "content": f'{prompt}',
725
+ # "content": "君不见黄河之水天上来的下一句是什么?",
726
+ # "content": "<image>\nFree OCR. ",
727
+ # "content": "<image>\nParse the figure. ",
728
+ # "content": "<image>\nExtract the text in the image. ",
729
+ "images": [f'{image_file}'],
730
+ },
731
+ {"role": "<|Assistant|>", "content": ""},
732
+ ]
733
+
734
+ elif prompt:
735
+ conversation = [
736
+ {
737
+ "role": "<|User|>",
738
+ # "content": "<image>\n<|grounding|>Given the layout of the image. ",
739
+ "content": f'{prompt}',
740
+ # "content": "君不见黄河之水天上来的下一句是什么?",
741
+ # "content": "<image>\nFree OCR. ",
742
+ # "content": "<image>\nParse the figure. ",
743
+ # "content": "<image>\nExtract the text in the image. ",
744
+ # "images": [f'{image_file}'],
745
+ },
746
+ {"role": "<|Assistant|>", "content": ""},
747
+ ]
748
+ else:
749
+ assert False, 'prompt must not be empty!'
750
+
751
+ prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='')
752
+
753
+ patch_size = 16
754
+ downsample_ratio = 4
755
+ images = load_pil_images(conversation)
756
+
757
+ valid_img_tokens = 0
758
+ ratio = 1
759
+
760
+ image_draw = images[0].copy()
761
+
762
+ w,h = image_draw.size
763
+ # print(w, h)
764
+ ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))
765
+
766
+
767
+ image_transform=BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)
768
+ images_seq_mask = []
769
+
770
+ image_token = '<image>'
771
+ image_token_id = 128815
772
+ text_splits = prompt.split(image_token)
773
+
774
+ images_list, images_crop_list, images_seq_mask = [], [], []
775
+ tokenized_str = []
776
+ images_spatial_crop = []
777
+ for text_sep, image in zip(text_splits, images):
778
+
779
+ tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
780
+ tokenized_str += tokenized_sep
781
+ images_seq_mask += [False] * len(tokenized_sep)
782
+
783
+ if crop_mode:
784
+
785
+ if image.size[0] <= 640 and image.size[1] <= 640:
786
+ crop_ratio = [1, 1]
787
+
788
+ else:
789
+ if crop_mode:
790
+ # best_width, best_height = select_best_resolution(image.size, self.candidate_resolutions)
791
+ images_crop_raw, crop_ratio = dynamic_preprocess(image)
792
+ else:
793
+ # best_width, best_height = self.image_size, self.image_size
794
+ crop_ratio = [1, 1]
795
+
796
+ """process the global view"""
797
+ # image = image.resize((base_size, base_size))
798
+ global_view = ImageOps.pad(image, (base_size, base_size),
799
+ color=tuple(int(x * 255) for x in image_transform.mean))
800
+
801
+ if base_size == 1024:
802
+                     valid_img_tokens += int(256 * ratio)
+                 elif base_size == 1280:
+                     valid_img_tokens += int(400 * ratio)
+                 # elif base_size == 640:
+                 #     valid_img_tokens += int(100 * ratio)
+
+                 images_list.append(image_transform(global_view).to(torch.bfloat16))
+
+                 # global_view_tensor = image_transform(global_view).to(torch.bfloat16)
+
+                 width_crop_num, height_crop_num = crop_ratio
+
+                 images_spatial_crop.append([width_crop_num, height_crop_num])
+
+                 if width_crop_num > 1 or height_crop_num > 1:
+                     """process the local views"""
+
+                     for i in range(len(images_crop_raw)):
+                         images_crop_list.append(image_transform(images_crop_raw[i]).to(torch.bfloat16))
+
+                     if image_size == 640:
+                         valid_img_tokens += len(images_crop_list) * 100
+
+                 num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
+                 num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)
+
+                 """add image tokens"""
+
+                 tokenized_image = ([image_token_id] * num_queries_base + [image_token_id]) * num_queries_base
+                 tokenized_image += [image_token_id]
+                 if width_crop_num > 1 or height_crop_num > 1:
+                     tokenized_image += ([image_token_id] * (num_queries * width_crop_num) + [image_token_id]) * (
+                         num_queries * height_crop_num)
+                 tokenized_str += tokenized_image
+                 images_seq_mask += [True] * len(tokenized_image)
+                 # num_image_tokens.append(len(tokenized_image))
+
+             else:
+                 # best_width, best_height = self.image_size, self.image_size
+                 # print(image.size, (best_width, best_height))  # check the select_best_resolutions func
+
+                 """process the global view"""
+                 if image_size <= 640:
+                     print('directly resize')
+                     image = image.resize((image_size, image_size))
+                 # else:
+                 global_view = ImageOps.pad(image, (image_size, image_size),
+                                            color=tuple(int(x * 255) for x in image_transform.mean))
+                 images_list.append(image_transform(global_view).to(torch.bfloat16))
+
+                 if base_size == 1024:
+                     valid_img_tokens += int(256 * ratio)
+                 elif base_size == 1280:
+                     valid_img_tokens += int(400 * ratio)
+                 elif base_size == 640:
+                     valid_img_tokens += int(100 * 1)
+                 elif base_size == 512:
+                     valid_img_tokens += int(64 * 1)
+
+                 width_crop_num, height_crop_num = 1, 1
+
+                 images_spatial_crop.append([width_crop_num, height_crop_num])
+
+                 """add image tokens"""
+                 num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
+
+                 tokenized_image = ([image_token_id] * num_queries + [image_token_id]) * num_queries
+                 tokenized_image += [image_token_id]
+                 # tokenized_image += ([self.image_token_id] * (num_queries * width_crop_num) + [self.image_token_id]) * (
+                 #     num_queries * height_crop_num)
+                 tokenized_str += tokenized_image
+                 images_seq_mask += [True] * len(tokenized_image)
+                 # num_image_tokens.append(len(tokenized_image))
+
+         """process the last text split"""
+         tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
+         tokenized_str += tokenized_sep
+         images_seq_mask += [False] * len(tokenized_sep)
+
+         """add the bos tokens"""
+         bos_id = 0
+         tokenized_str = [bos_id] + tokenized_str
+         images_seq_mask = [False] + images_seq_mask
+
+         input_ids = torch.LongTensor(tokenized_str)
+
+         images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
+
+         if len(images_list) == 0:
+             images_ori = torch.zeros((1, 3, image_size, image_size))
+             images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
+             images_crop = torch.zeros((1, 3, base_size, base_size))
+
+         else:
+             images_ori = torch.stack(images_list, dim=0)
+             images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
+             if images_crop_list:
+                 images_crop = torch.stack(images_crop_list, dim=0)
+             else:
+                 images_crop = torch.zeros((1, 3, base_size, base_size))
+
+         if not eval_mode:
+             streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
+             with torch.autocast("cuda", dtype=torch.bfloat16):
+                 with torch.no_grad():
+                     output_ids = self.generate(
+                         input_ids.unsqueeze(0).cuda(),
+                         images=[(images_crop.cuda(), images_ori.cuda())],
+                         images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
+                         images_spatial_crop = images_spatial_crop,
+                         # do_sample=False,
+                         # num_beams = 1,
+                         temperature=0.0,
+                         eos_token_id=tokenizer.eos_token_id,
+                         streamer=streamer,
+                         max_new_tokens=8192,
+                         no_repeat_ngram_size = 20,
+                         use_cache = True
+                     )
+
+         else:
+             with torch.autocast("cuda", dtype=torch.bfloat16):
+                 with torch.no_grad():
+                     output_ids = self.generate(
+                         input_ids.unsqueeze(0).cuda(),
+                         images=[(images_crop.cuda(), images_ori.cuda())],
+                         images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
+                         images_spatial_crop = images_spatial_crop,
+                         # do_sample=False,
+                         # num_beams = 1,
+                         temperature=0.0,
+                         eos_token_id=tokenizer.eos_token_id,
+                         max_new_tokens=8192,
+                         no_repeat_ngram_size = 35,
+                         use_cache = True
+                     )
+
+         if '<image>' in conversation[0]['content'] and eval_mode:
+             outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+             stop_str = '<|end▁of▁sentence|>'
+             if outputs.endswith(stop_str):
+                 outputs = outputs[:-len(stop_str)]
+             # re_match
+             outputs = outputs.strip()
+
+             return outputs
+
+         if '<image>' in conversation[0]['content'] and test_compress:
+             outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+             pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
+             print('='*50)
+             print('image size: ', (w, h))
+             print('valid image tokens: ', int(valid_img_tokens))
+             print('output texts tokens (valid): ', pure_texts_outputs_token_length)
+             print('compression ratio: ', round(pure_texts_outputs_token_length/valid_img_tokens, 2))
+             print('='*50)
+
+         if '<image>' in conversation[0]['content'] and save_results:
+             outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+             stop_str = '<|end▁of▁sentence|>'
+
+             print('='*15 + 'save results:' + '='*15)
+
+             # # # # conv.messages[-1][-1] = outputs
+             if outputs.endswith(stop_str):
+                 outputs = outputs[:-len(stop_str)]
+             outputs = outputs.strip()
+
+             matches_ref, matches_images, mathes_other = re_match(outputs)
+             # print(matches_ref)
+             result = process_image_with_refs(image_draw, matches_ref, output_path)
+
+             for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
+                 outputs = outputs.replace(a_match_image, '![](images/' + str(idx) + '.jpg)\n')
+
+             for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
+                 outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')
+
+             # if 'structural formula' in conversation[0]['content']:
+             #     outputs = '<smiles>' + outputs + '</smiles>'
+             with open(f'{output_path}/result.mmd', 'w', encoding = 'utf-8') as afile:
+                 afile.write(outputs)
+
+             if 'line_type' in outputs:
+                 import matplotlib.pyplot as plt
+                 lines = eval(outputs)['Line']['line']
+
+                 line_type = eval(outputs)['Line']['line_type']
+                 # print(lines)
+
+                 endpoints = eval(outputs)['Line']['line_endpoint']
+
+                 fig, ax = plt.subplots(figsize=(3,3), dpi=200)
+                 ax.set_xlim(-15, 15)
+                 ax.set_ylim(-15, 15)
+
+                 for idx, line in enumerate(lines):
+                     try:
+                         p0 = eval(line.split(' -- ')[0])
+                         p1 = eval(line.split(' -- ')[-1])
+
+                         if line_type[idx] == '--':
+                             ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')
+                         else:
+                             ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth = 0.8, color = 'k')
+
+                         ax.scatter(p0[0], p0[1], s=5, color = 'k')
+                         ax.scatter(p1[0], p1[1], s=5, color = 'k')
+                     except:
+                         pass
+
+                 for endpoint in endpoints:
+                     label = endpoint.split(': ')[0]
+                     (x, y) = eval(endpoint.split(': ')[1])
+                     ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
+                                 fontsize=5, fontweight='light')
+
+                 plt.savefig(f'{output_path}/geo.jpg')
+                 plt.close()
+
+             result.save(f"{output_path}/result_with_boxes.jpg")
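
Note on the token accounting above: valid_img_tokens counts 256 tokens per 1024-base global view, 400 per 1280, 100 per 640 crop and 64 per 512 view (optionally scaled by ratio), while the <image> placeholder sequence built from num_queries additionally carries one extra token per row plus a trailing separator. The following is a minimal sketch of that arithmetic; it is not part of the uploaded file and assumes the patch_size=16 and downsample_ratio=4 values declared in processor_config.json below.

    import math

    def count_image_placeholder_tokens(image_size, base_size,
                                       width_crop_num=1, height_crop_num=1,
                                       patch_size=16, downsample_ratio=4):
        # Mirrors the crop-mode branch above: queries per side for the local
        # tiles (image_size) and for the global view (base_size).
        num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
        num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)

        # Global view: num_queries_base rows of (num_queries_base + 1) tokens,
        # plus one trailing separator token.
        tokens = (num_queries_base + 1) * num_queries_base + 1
        if width_crop_num > 1 or height_crop_num > 1:
            # Local crops: same layout, scaled by the crop grid.
            tokens += (num_queries * width_crop_num + 1) * (num_queries * height_crop_num)
        return tokens

    print(count_image_placeholder_tokens(640, 1024))        # 273 (global view only)
    print(count_image_placeholder_tokens(640, 1024, 2, 3))  # 273 + 630 = 903 (2x3 crop grid)

For the default 640/1024 setting this gives 273 placeholder tokens for the global view and 630 more for a 2x3 crop grid, which is what the tokenized_image construction above produces.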
processor_config.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "add_special_token": false,
+   "candidate_resolutions": [
+     [
+       1024,
+       1024
+     ]
+   ],
+   "downsample_ratio": 4,
+   "ignore_id": -100,
+   "image_mean": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "image_std": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "image_token": "<image>",
+   "mask_prompt": false,
+   "normalize": true,
+   "pad_token": "<\uff5c\u2581pad\u2581\uff5c>",
+   "patch_size": 16,
+   "processor_class": "DeepseekVLV2Processor",
+   "sft_format": "deepseek"
+ }
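
These settings line up with the modeling code above: a 1024x1024 candidate resolution with patch_size 16 and downsample_ratio 4 gives 16 queries per side, i.e. the 256 "valid" image tokens counted for a 1024 global view. A small sanity check, sketched here rather than shipped in the commit, assuming it is run from the repository root:

    import json
    import math

    with open("processor_config.json") as f:   # assumed relative to the repo root
        cfg = json.load(f)

    base = cfg["candidate_resolutions"][0][0]  # 1024
    queries_per_side = math.ceil((base // cfg["patch_size"]) / cfg["downsample_ratio"])
    print(queries_per_side, queries_per_side ** 2)      # 16, 256
    print(cfg["image_token"], cfg["processor_class"])   # <image> DeepseekVLV2Processor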
special_tokens_map.json ADDED
@@ -0,0 +1,39 @@
+ {
+   "additional_special_tokens": [
+     {
+       "content": "<|User|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false
+     },
+     {
+       "content": "<|Assistant|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false
+     }
+   ],
+   "bos_token": {
+     "content": "<|begin▁of▁sentence|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "<|end▁of▁sentence|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|▁pad▁|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
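
The strings declared here are the ones the inference code depends on: the eos token <|end▁of▁sentence|> is what infer() strips from decoded output, and the hard-coded bos_id = 0 is presumably the id of the bos token below (an assumption; the actual mapping lives in tokenizer.json). A quick check, sketched with a placeholder repository path:

    from transformers import AutoTokenizer

    # Placeholder path: point this at a local clone of the repo or its hub id.
    tok = AutoTokenizer.from_pretrained("path/to/this/repo")

    print(tok.bos_token)                  # <|begin▁of▁sentence|>
    print(tok.eos_token)                  # <|end▁of▁sentence|>
    print(tok.pad_token)                  # <|▁pad▁|>
    print(tok.additional_special_tokens)  # ['<|User|>', '<|Assistant|>']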
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff