Upload folder using huggingface_hub
- .gitattributes +28 -9
- .ipynb_checkpoints/README-checkpoint.md +122 -0
- LICENSE +21 -0
- README.md +132 -3
- assets/fig1.png +3 -0
- assets/show1.jpg +3 -0
- assets/show2.jpg +3 -0
- assets/show3.jpg +3 -0
- assets/show4.jpg +3 -0
- config.json +118 -0
- configuration.json +1 -0
- configuration_deepseek_v2.py +210 -0
- conversation.py +280 -0
- deepencoder.py +1058 -0
- model-00001-of-000001.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_deepseekocr.py +1044 -0
- processor_config.json +28 -0
- special_tokens_map.json +39 -0
- tokenizer.json +0 -0
- tokenizer_config.json +0 -0
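
The commit title indicates the files were pushed with the `huggingface_hub` client. A minimal sketch of how such a commit is typically produced (the local folder path and target repo id below are placeholders, not values taken from this commit):

```python
# Sketch: upload a local model folder as a single commit with huggingface_hub.
# folder_path and repo_id are placeholders, not values from this commit.
from huggingface_hub import HfApi

api = HfApi()  # uses the token from `huggingface-cli login` or the HF_TOKEN env var
api.upload_folder(
    folder_path="./DeepSeek-OCR",           # local directory with weights, configs, and code
    repo_id="your-username/DeepSeek-OCR",   # placeholder target repository
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)
```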
.gitattributes
CHANGED
```diff
@@ -1,35 +1,54 @@
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
+*.bin.* filter=lfs diff=lfs merge=lfs -text
 *.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
 *.ftz filter=lfs diff=lfs merge=lfs -text
 *.gz filter=lfs diff=lfs merge=lfs -text
 *.h5 filter=lfs diff=lfs merge=lfs -text
 *.joblib filter=lfs diff=lfs merge=lfs -text
 *.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
 *.model filter=lfs diff=lfs merge=lfs -text
 *.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
 *.onnx filter=lfs diff=lfs merge=lfs -text
 *.ot filter=lfs diff=lfs merge=lfs -text
 *.parquet filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
+*.zstandard filter=lfs diff=lfs merge=lfs -text
+*.tfevents* filter=lfs diff=lfs merge=lfs -text
+*.db* filter=lfs diff=lfs merge=lfs -text
+*.ark* filter=lfs diff=lfs merge=lfs -text
+**/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
+**/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
+**/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
+
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.gguf* filter=lfs diff=lfs merge=lfs -text
+*.ggml filter=lfs diff=lfs merge=lfs -text
+*.llamafile* filter=lfs diff=lfs merge=lfs -text
+*.pt2 filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+
+model-00001-of-000001.safetensors filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -textassets/fig1.png filter=lfs diff=lfs merge=lfs -text
+assets/show1.jpg filter=lfs diff=lfs merge=lfs -text
+assets/show2.jpg filter=lfs diff=lfs merge=lfs -text
+assets/show3.jpg filter=lfs diff=lfs merge=lfs -text
+assets/show4.jpg filter=lfs diff=lfs merge=lfs -text
```
.ipynb_checkpoints/README-checkpoint.md
ADDED
````markdown
---
pipeline_tag: image-text-to-text
language:
- multilingual
tags:
- deepseek
- vision-language
- ocr
- custom_code
license: mit
---
<div align="center">
  <img src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/logo.svg?raw=true" width="60%" alt="DeepSeek AI" />
</div>
<hr>
<div align="center">
  <a href="https://www.deepseek.com/" target="_blank">
    <img alt="Homepage" src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/badge.svg?raw=true" />
  </a>
  <a href="https://huggingface.co/deepseek-ai/DeepSeek-OCR" target="_blank">
    <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DeepSeek%20AI-ffc107?color=ffc107&logoColor=white" />
  </a>
</div>

<div align="center">
  <a href="https://discord.gg/Tc7c45Zzu5" target="_blank">
    <img alt="Discord" src="https://img.shields.io/badge/Discord-DeepSeek%20AI-7289da?logo=discord&logoColor=white&color=7289da" />
  </a>
  <a href="https://twitter.com/deepseek_ai" target="_blank">
    <img alt="Twitter Follow" src="https://img.shields.io/badge/Twitter-deepseek_ai-white?logo=x&logoColor=white" />
  </a>
</div>

<p align="center">
  <a href="https://github.com/deepseek-ai/DeepSeek-OCR"><b>🌟 Github</b></a> |
  <a href="https://huggingface.co/deepseek-ai/DeepSeek-OCR"><b>📥 Model Download</b></a> |
  <a href="https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek_OCR_paper.pdf"><b>📄 Paper Link</b></a> |
  <a href=""><b>📄 Arxiv Paper Link</b></a> |
</p>
<h2>
<p align="center">
  <a href="">DeepSeek-OCR: Contexts Optical Compression</a>
</p>
</h2>
<p align="center">
  <img src="assets/fig1.png" style="width: 1000px" align=center>
</p>
<p align="center">
  <a href="">Explore the boundaries of visual-text compression.</a>
</p>

## Usage
Inference using Hugging Face Transformers on NVIDIA GPUs. Requirements tested on Python 3.12.9 + CUDA 11.8:

```
torch==2.6.0
transformers==4.46.3
tokenizers==0.20.3
einops
addict
easydict
pip install flash-attn==2.7.3 --no-build-isolation
```

```python
from transformers import AutoModel, AutoTokenizer
import torch
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
model_name = 'deepseek-ai/DeepSeek-OCR'

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, _attn_implementation='flash_attention_2', trust_remote_code=True, use_safetensors=True)
model = model.eval().cuda().to(torch.bfloat16)

# prompt = "<image>\nFree OCR. "
prompt = "<image>\n<|grounding|>Convert the document to markdown. "
image_file = 'your_image.jpg'
output_path = 'your/output/dir'

# infer(self, tokenizer, prompt='', image_file='', output_path = ' ', base_size = 1024, image_size = 640, crop_mode = True, test_compress = False, save_results = False):

# Tiny: base_size = 512, image_size = 512, crop_mode = False
# Small: base_size = 640, image_size = 640, crop_mode = False
# Base: base_size = 1024, image_size = 1024, crop_mode = False
# Large: base_size = 1280, image_size = 1280, crop_mode = False

# Gundam: base_size = 1024, image_size = 640, crop_mode = True

res = model.infer(tokenizer, prompt=prompt, image_file=image_file, output_path = output_path, base_size = 1024, image_size = 640, crop_mode=True, save_results = True, test_compress = True)
```

## vLLM
Refer to [🌟GitHub](https://github.com/deepseek-ai/DeepSeek-OCR/) for guidance on model inference acceleration, PDF processing, etc.

## Visualizations
<table>
  <tr>
    <td><img src="assets/show1.jpg" style="width: 500px"></td>
    <td><img src="assets/show2.jpg" style="width: 500px"></td>
  </tr>
  <tr>
    <td><img src="assets/show3.jpg" style="width: 500px"></td>
    <td><img src="assets/show4.jpg" style="width: 500px"></td>
  </tr>
</table>


## Acknowledgement

We would like to thank [Vary](https://github.com/Ucas-HaoranWei/Vary/), [GOT-OCR2.0](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/), [MinerU](https://github.com/opendatalab/MinerU), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [OneChart](https://github.com/LingyvKong/OneChart), [Slow Perception](https://github.com/Ucas-HaoranWei/Slow-Perception) for their valuable models and ideas.

We also appreciate the benchmarks: [Fox](https://github.com/ucaslcl/Fox), [OmniDocBench](https://github.com/opendatalab/OmniDocBench).


## Citation
Coming soon!
````
LICENSE
ADDED
```
MIT License

Copyright (c) 2023 DeepSeek

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
README.md
CHANGED
The previous content (a lone `---` front-matter marker followed by two blank lines) is replaced with:

````markdown
---
pipeline_tag: image-text-to-text
language:
- multilingual
tags:
- mindspore
- mindnlp
- deepseek
- vision-language
- ocr
- custom_code
license: mit
---
<div align="center">
  <img src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/logo.svg?raw=true" width="60%" alt="DeepSeek AI" />
</div>
<hr>
<div align="center">
  <a href="https://www.deepseek.com/" target="_blank">
    <img alt="Homepage" src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/badge.svg?raw=true" />
  </a>
  <a href="https://huggingface.co/lvyufeng/DeepSeek-OCR" target="_blank">
    <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DeepSeek%20AI-ffc107?color=ffc107&logoColor=white" />
  </a>
</div>

<p align="center">
  <a href="https://github.com/mindspore-lab/mindnlp/tree/master/examples/transformers/inference/deepseek-ocr"><b>🌟 Github</b></a> |
  <a href="https://huggingface.co/lvyufeng/DeepSeek-OCR"><b>📥 Model Download</b></a> |
  <a href="https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek_OCR_paper.pdf"><b>📄 Paper Link</b></a> |
  <a href=""><b>📄 Arxiv Paper Link</b></a> |
</p>
<h2>
<p align="center">
  <a href="">DeepSeek-OCR: Contexts Optical Compression</a>
</p>
</h2>
<p align="center">
  <a href="">Explore the boundaries of visual-text compression.</a>
</p>

The official release of DeepSeek-OCR pins transformers to 4.46.3 and has not been adapted to newer versions. This community edition modifies the modeling module so that no transformers downgrade is required. It has also been adapted for MindSpore + MindNLP, so it can be used on Ascend hardware.

Feel free to choose an attention implementation such as FlashAttention or SDPA to take advantage of the latest optimizations in transformers.

## MindSpore Usage
Inference using the Hugging Face Transformers API with MindSpore + MindNLP (for example, on Ascend hardware). Requirements tested on Python 3.12.9:

```
mindspore==2.7.0
mindnlp==0.5.0rc3
transformers==4.57.1
tokenizers
einops
addict
easydict
```

```python
import os
import mindnlp
import mindspore
from transformers import AutoModel, AutoTokenizer

model_name = 'lvyufeng/DeepSeek-OCR-Community-Latest'

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, dtype=mindspore.float16, _attn_implementation='sdpa', trust_remote_code=True, use_safetensors=True, device_map='auto')
model = model.eval()

# prompt = "<image>\nFree OCR. "
prompt = "<image>\n<|grounding|>Convert the document to markdown. "
image_file = 'your_image.jpg'
output_path = 'your/output/dir'

# infer(self, tokenizer, prompt='', image_file='', output_path = ' ', base_size = 1024, image_size = 640, crop_mode = True, test_compress = False, save_results = False):

# Tiny: base_size = 512, image_size = 512, crop_mode = False
# Small: base_size = 640, image_size = 640, crop_mode = False
# Base: base_size = 1024, image_size = 1024, crop_mode = False
# Large: base_size = 1280, image_size = 1280, crop_mode = False

# Gundam: base_size = 1024, image_size = 640, crop_mode = True

res = model.infer(tokenizer, prompt=prompt, image_file=image_file, output_path = output_path, base_size = 1024, image_size = 640, crop_mode=True, save_results = True, test_compress = True)
```


## Pytorch Usage
Inference using Hugging Face Transformers on NVIDIA GPUs. Requirements tested on Python 3.12.9 + CUDA 11.8:

```
torch
transformers==4.57.1
tokenizers
einops
addict
easydict
pip install flash-attn
```

```python
from transformers import AutoModel, AutoTokenizer
import torch
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
model_name = 'deepseek-ai/DeepSeek-OCR'
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# dtype and device_map belong on the model, not the tokenizer
model = AutoModel.from_pretrained(model_name, dtype=torch.bfloat16, _attn_implementation='flash_attention_2', trust_remote_code=True, use_safetensors=True, device_map='auto')
model = model.eval()
# prompt = "<image>\nFree OCR. "
prompt = "<image>\n<|grounding|>Convert the document to markdown. "
image_file = 'your_image.jpg'
output_path = 'your/output/dir'
# infer(self, tokenizer, prompt='', image_file='', output_path = ' ', base_size = 1024, image_size = 640, crop_mode = True, test_compress = False, save_results = False):
# Tiny: base_size = 512, image_size = 512, crop_mode = False
# Small: base_size = 640, image_size = 640, crop_mode = False
# Base: base_size = 1024, image_size = 1024, crop_mode = False
# Large: base_size = 1280, image_size = 1280, crop_mode = False
# Gundam: base_size = 1024, image_size = 640, crop_mode = True
res = model.infer(tokenizer, prompt=prompt, image_file=image_file, output_path = output_path, base_size = 1024, image_size = 640, crop_mode=True, save_results = True, test_compress = True)
```

## Acknowledgement

We would like to thank [Vary](https://github.com/Ucas-HaoranWei/Vary/), [GOT-OCR2.0](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/), [MinerU](https://github.com/opendatalab/MinerU), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [OneChart](https://github.com/LingyvKong/OneChart), [Slow Perception](https://github.com/Ucas-HaoranWei/Slow-Perception) for their valuable models and ideas.

We also appreciate the benchmarks: [Fox](https://github.com/ucaslcl/Fox), [OmniDocBench](https://github.com/opendatalab/OmniDocBench).
````
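
The named resolution modes documented above are just preset combinations of `base_size`, `image_size`, and `crop_mode` passed to `model.infer()`. A small helper like the following (hypothetical, not part of the repository) keeps the presets in one place:

```python
# Hypothetical convenience wrapper around the documented resolution presets;
# not part of the repository, just a restatement of the comments above.
PRESETS = {
    "tiny":   dict(base_size=512,  image_size=512,  crop_mode=False),
    "small":  dict(base_size=640,  image_size=640,  crop_mode=False),
    "base":   dict(base_size=1024, image_size=1024, crop_mode=False),
    "large":  dict(base_size=1280, image_size=1280, crop_mode=False),
    "gundam": dict(base_size=1024, image_size=640,  crop_mode=True),
}

def ocr_image(model, tokenizer, image_file, output_path, mode="gundam",
              prompt="<image>\n<|grounding|>Convert the document to markdown. "):
    """Run model.infer() with one of the documented resolution presets."""
    return model.infer(tokenizer, prompt=prompt, image_file=image_file,
                       output_path=output_path, save_results=True, **PRESETS[mode])
```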
assets/fig1.png
ADDED
Git LFS Details
assets/show1.jpg
ADDED
Git LFS Details
assets/show2.jpg
ADDED
Git LFS Details
assets/show3.jpg
ADDED
Git LFS Details
assets/show4.jpg
ADDED
Git LFS Details
config.json
ADDED
```json
{
  "_name_or_path": "deepseek-ai/DeepSeek-OCR",
  "candidate_resolutions": [
    [
      1024,
      1024
    ]
  ],
  "global_view_pos": "head",
  "architectures": [
    "DeepseekOCRForCausalLM"
  ],
  "auto_map": {
    "AutoConfig": "modeling_deepseekocr.DeepseekOCRConfig",
    "AutoModel": "modeling_deepseekocr.DeepseekOCRForCausalLM"
  },
  "language_config": {
    "architectures": [
      "DeepseekV2ForCausalLM"
    ],
    "auto_map": {
      "AutoConfig": "configuration_deepseekv2.DeepseekV2Config",
      "AutoModel": "modeling_deepseek.DeepseekV2Model",
      "AutoModelForCausalLM": "modeling_deepseek.DeepseekV2ForCausalLM"
    },
    "bos_token_id": 0,
    "eos_token_id": 1,
    "first_k_dense_replace": 1,
    "hidden_size": 1280,
    "intermediate_size": 6848,
    "kv_lora_rank": null,
    "lm_head": true,
    "max_position_embeddings": 8192,
    "moe_intermediate_size": 896,
    "n_group": 1,
    "n_routed_experts": 64,
    "n_shared_experts": 2,
    "num_attention_heads": 10,
    "num_experts_per_tok": 6,
    "num_hidden_layers": 12,
    "num_key_value_heads": 10,
    "q_lora_rank": null,
    "qk_nope_head_dim": 0,
    "qk_rope_head_dim": 0,
    "rm_head": false,
    "topk_group": 1,
    "topk_method": "greedy",
    "torch_dtype": "bfloat16",
    "use_mla": false,
    "v_head_dim": 0,
    "vocab_size": 129280
  },
  "model_type": "deepseek_vl_v2",
  "projector_config": {
    "input_dim": 2048,
    "model_type": "mlp_projector",
    "n_embed": 1280,
    "projector_type": "linear"
  },
  "tile_tag": "2D",
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.3",
  "vision_config": {
    "image_size": 1024,
    "mlp_ratio": 3.7362,
    "model_name": "deeplip_b_l",
    "model_type": "vision",
    "width": {
      "clip-l-14-224": {
        "heads": 16,
        "image_size": 224,
        "layers": 24,
        "patch_size": 14,
        "width": 1024
      },
      "sam_vit_b": {
        "downsample_channels": [
          512,
          1024
        ],
        "global_attn_indexes": [
          2,
          5,
          8,
          11
        ],
        "heads": 12,
        "layers": 12,
        "width": 768
      }
    }
  },
  "bos_token_id": 0,
  "eos_token_id": 1,
  "first_k_dense_replace": 1,
  "hidden_size": 1280,
  "intermediate_size": 6848,
  "kv_lora_rank": null,
  "lm_head": true,
  "max_position_embeddings": 8192,
  "moe_intermediate_size": 896,
  "n_group": 1,
  "n_routed_experts": 64,
  "n_shared_experts": 2,
  "num_attention_heads": 10,
  "num_experts_per_tok": 6,
  "num_hidden_layers": 12,
  "num_key_value_heads": 10,
  "q_lora_rank": null,
  "qk_nope_head_dim": 0,
  "qk_rope_head_dim": 0,
  "rm_head": false,
  "topk_group": 1,
  "topk_method": "greedy",
  "use_mla": false,
  "v_head_dim": 0,
  "vocab_size": 129280
}
```
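
Because `auto_map` points `AutoConfig`/`AutoModel` at `modeling_deepseekocr.py` inside this repository, the configuration only loads when remote code is trusted. A minimal sketch of inspecting it:

```python
# Sketch: load the repository's custom configuration.
# trust_remote_code=True is required because auto_map resolves to code in the repo itself.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-OCR", trust_remote_code=True)
print(type(config).__name__)  # DeepseekOCRConfig, defined in modeling_deepseekocr.py
print(config.model_type)
```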
configuration.json
ADDED
```json
{"framework": "pytorch", "task": "image-text-to-text", "allow_remote": true}
```
configuration_deepseek_v2.py
ADDED
````python
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class DeepseekV2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate an DeepSeek
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the DeepSeek-V2 with multi-latent attention.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 102400):
            Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`DeepseekV2Model`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        moe_intermediate_size (`int`, *optional*, defaults to 1407):
            Dimension of the MoE representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        n_shared_experts (`int`, *optional*, defaults to None):
            Number of shared experts, None means dense model.
        n_routed_experts (`int`, *optional*, defaults to None):
            Number of routed experts, None means dense model.
        routed_scaling_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor or routed experts.
        topk_method (`str`, *optional*, defaults to `gready`):
            Topk method used in routed gate.
        n_group (`int`, *optional*, defaults to None):
            Number of groups for routed experts.
        topk_group (`int`, *optional*, defaults to None):
            Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
        num_experts_per_tok (`int`, *optional*, defaults to None):
            Number of selected experts, None means dense model.
        moe_layer_freq (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
        first_k_dense_replace (`int`, *optional*, defaults to 0):
            Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
                                                            \--k dense layers--/
        norm_topk_prob (`bool`, *optional*, defaults to False):
            Whether to normalize the weights of the routed experts.
        scoring_func (`str`, *optional*, defaults to 'softmax'):
            Method of computing expert weights.
        aux_loss_alpha (`float`, *optional*, defaults to 0.001):
            Auxiliary loss weight coefficient.
        seq_aux = (`bool`, *optional*, defaults to True):
            Whether to compute the auxiliary loss for each individual sample.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum.
        attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        use_mla (`bool`, *optional*, defaults to `True`): Use multi-latent attention or multi-head attention. If True,
            the model will use multi-latent attention, otherwise, it will use multi-head attention.

    ```python
    >>> from transformers import DeepseekV2Model, DeepseekV2Config

    >>> # Initializing a Deepseek-V2 style configuration
    >>> configuration = DeepseekV2Config()

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "deepseek_v2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=102400,
        hidden_size=4096,
        intermediate_size=11008,
        moe_intermediate_size = 1407,
        num_hidden_layers=30,
        num_attention_heads=32,
        num_key_value_heads=32,
        n_shared_experts = None,
        n_routed_experts = None,
        ep_size = 1,
        routed_scaling_factor = 1.0,
        kv_lora_rank = 512,
        q_lora_rank = 1536,
        qk_rope_head_dim = 64,
        v_head_dim = 128,
        qk_nope_head_dim = 128,
        topk_method = 'gready',
        n_group = None,
        topk_group = None,
        num_experts_per_tok = None,
        moe_layer_freq = 1,
        first_k_dense_replace = 0,
        norm_topk_prob = False,
        scoring_func = 'softmax',
        aux_loss_alpha = 0.001,
        seq_aux = True,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=100000,
        eos_token_id=100001,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        use_mla=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = float(rms_norm_eps)
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.use_mla = use_mla

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
````
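
For reference, the `language_config` block in `config.json` above corresponds to constructing this class with the following non-default arguments. A minimal sketch, assuming it is run from a checkout of this repository so the module can be imported directly:

```python
# Sketch: build the decoder config with the values from this repo's config.json
# (language_config block); defaults cover the remaining arguments.
from configuration_deepseek_v2 import DeepseekV2Config

lm_config = DeepseekV2Config(
    vocab_size=129280,
    hidden_size=1280,
    intermediate_size=6848,
    moe_intermediate_size=896,
    num_hidden_layers=12,
    num_attention_heads=10,
    num_key_value_heads=10,
    n_shared_experts=2,
    n_routed_experts=64,
    num_experts_per_tok=6,
    first_k_dense_replace=1,
    n_group=1,
    topk_group=1,
    topk_method="greedy",
    max_position_embeddings=8192,
    use_mla=False,
    kv_lora_rank=None,
    q_lora_rank=None,
    qk_nope_head_dim=0,
    qk_rope_head_dim=0,
    v_head_dim=0,
    bos_token_id=0,
    eos_token_id=1,
)
print(lm_config.num_hidden_layers, lm_config.n_routed_experts)  # 12 64
```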
conversation.py
ADDED
```python
"""
From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
"""

import dataclasses
from enum import IntEnum, auto
from typing import Any, Dict, List


class SeparatorStyle(IntEnum):
    """Separator styles."""

    DeepSeek = auto()
    DeepSeekV2 = auto()
    PLAIN = auto()
    ALIGNMENT = auto()


@dataclasses.dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history."""

    # The name of this template
    name: str
    # The template of the system prompt
    system_template: str = "{system_message}"
    # The system message
    system_message: str = ""
    # The names of two roles
    roles: List[str] = (("USER", "ASSISTANT"),)
    # All messages. Each item is (role, message).
    messages: List[List[str]] = ()
    # The number of few shot examples
    offset: int = 0
    # The separator style and configurations
    sep_style: SeparatorStyle = SeparatorStyle.DeepSeek
    sep: str = "\n"
    sep2: str = None
    # Stop criteria (the default one is EOS token)
    stop_str: str = None
    # Stops generation if meeting any token in this list
    stop_token_ids: List[int] = None

    def get_prompt(self) -> str:
        """Get the prompt for generation."""
        system_prompt = self.system_template.format(system_message=self.system_message)
        if self.sep_style == SeparatorStyle.DeepSeek:
            seps = [self.sep, self.sep2]
            if system_prompt == "" or system_prompt is None:
                ret = ""
            else:
                ret = system_prompt + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.DeepSeekV2:
            seps = [self.sep, self.sep2]
            if system_prompt == "" or system_prompt is None:
                ret = ""
            else:
                ret = system_prompt + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if role == "User":
                        ret += "<|sft▁begin|>\n" + message + self.sep  # <|sft▁begin|>User Input<|sft▁end|>\nResponse<|end▁of▁sentence|>
                    else:
                        ret += message + self.sep2
                else:
                    ret = ret
            return ret

        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i % 2 == 0:
                        ret += message + seps[i % 2]
                    else:
                        ret += message + seps[i % 2]
                else:
                    ret += ""
            return ret
        elif self.sep_style == SeparatorStyle.ALIGNMENT:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i % 2 == 0:
                        ret += '<image>\n' + seps[i % 2]
                    else:
                        ret += message + seps[i % 2]
                else:
                    ret += ""
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def reset_message(self):
        """Reset a new message."""
        self.messages = []

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format."""
        system_prompt = self.system_template.format(system_message=self.system_message)
        ret = [{"role": "system", "content": system_prompt}]

        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append({"role": "user", "content": msg})
            else:
                if msg is not None:
                    ret.append({"role": "assistant", "content": msg})
        return ret

    def copy(self):
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
        )

    def dict(self):
        return {
            "template_name": self.name,
            "system_message": self.system_message,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
        }


# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """Register a new conversation template."""
    if not override:
        assert template.name not in conv_templates, f"{template.name} has been registered."

    conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
    """Get a conversation template."""
    return conv_templates[name].copy()


register_conv_template(
    Conversation(
        name="deepseek",
        system_template="{system_message}",
        # system_message="You are a helpful assistant. Please answer truthfully and write out your "
        # "thinking step by step to be sure you get the right answer.",
        system_message="",
        roles=("<|User|>", "<|Assistant|>"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.DeepSeek,
        sep="\n\n",
        sep2="<|end▁of▁sentence|>",
        stop_token_ids=[100001],
        stop_str=["User:", "<|end▁of▁sentence|>"]
    )
)
register_conv_template(
    Conversation(
        name="deepseekv2",
        system_template="{system_message}",
        # system_message="You are a helpful assistant. Please answer truthfully and write out your "
        # "thinking step by step to be sure you get the right answer.",
        system_message="",
        roles=("<|User|>", "<|Assistant|>"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.DeepSeek,
        sep="",
        sep2="<|end▁of▁sentence|>",
        stop_token_ids=[100001],
        stop_str=["User:", "<|end▁of▁sentence|>"]
    )
)


register_conv_template(
    Conversation(
        name="plain",
        system_template="",
        system_message="",
        roles=("", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.PLAIN,
        sep="",
        sep2="",
        stop_token_ids=[100001],
        stop_str=['</s>'],
    )
)


register_conv_template(
    Conversation(
        name="alignment",
        system_template="",
        system_message="",
        roles=("", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ALIGNMENT,
        sep="",
        sep2="",
        stop_token_ids=[100001],
        stop_str=['</s>'],
    )
)


if __name__ == "__main__":
    print("deepseek template:")
    conv = get_conv_template("deepseek")
    conv.append_message(conv.roles[0], "Hello!")
    conv.append_message(conv.roles[1], "Hi! This is Tony.")
    conv.append_message(conv.roles[0], "Who are you?")
    conv.append_message(conv.roles[1], "I am a helpful assistant.")
    conv.append_message(conv.roles[0], "How are you?")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())

    print("deepseekv2 template:")
    conv = get_conv_template("deepseekv2")
    conv.append_message(conv.roles[0], "Hello!")
    conv.append_message(conv.roles[1], "Hi! This is Tony.")
    conv.append_message(conv.roles[0], "Who are you?")
    conv.append_message(conv.roles[1], "I am a helpful assistant.")
    conv.append_message(conv.roles[0], "How are you?")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())
```
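
As a quick illustration (assuming the module is importable from the current directory), the `deepseek` template renders the README's OCR instruction like this:

```python
# Sketch: render an OCR-style prompt with the registered "deepseek" template.
from conversation import get_conv_template

conv = get_conv_template("deepseek")
conv.append_message(conv.roles[0], "<image>\n<|grounding|>Convert the document to markdown. ")
conv.append_message(conv.roles[1], None)  # leave the assistant slot open for generation
print(repr(conv.get_prompt()))
# '<|User|>: <image>\n<|grounding|>Convert the document to markdown. \n\n<|Assistant|>:'
```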
deepencoder.py
ADDED
The file is 1,058 lines long; only the beginning is shown below.

```python
import torch.nn as nn
import torch
import torch.nn.functional as F
import copy

from contextlib import nullcontext
import math
from typing import Optional, Tuple
# from megatron.model import LayerNorm

from einops import rearrange
from easydict import EasyDict as adict


from typing import Optional, Tuple, Type
from functools import partial



class MlpProjector(nn.Module):

    def __init__(self, cfg):

        super().__init__()

        self.cfg = cfg

        if cfg.projector_type == "identity":
            modules = nn.Identity()

        elif cfg.projector_type == "linear":
            modules = nn.Linear(cfg.input_dim, cfg.n_embed)

        elif cfg.projector_type == "mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "normlayer_downsample_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            mlp_ratio = cfg.get("mlp_ratio", 1)
            modules = [
                nn.LayerNorm(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio),
                nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)
            ]
            for _ in range(1, mlp_depth - 1):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
            modules.append(nn.GELU())
            modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "downsample_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            mlp_ratio = cfg.get("mlp_ratio", 1)
            modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)]
            for _ in range(1, mlp_depth - 1):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
            modules.append(nn.GELU())
            modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
            self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)

            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "hybrid_split_feature_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            channel_div = cfg.get("channel_div", 0.5)
            self.high_up_proj = nn.Linear(cfg.input_dim[0], int(cfg.n_embed * channel_div))
```
|
| 82 |
+
self.low_up_proj = nn.Linear(cfg.input_dim[1], cfg.n_embed - int(cfg.n_embed * channel_div))
|
| 83 |
+
|
| 84 |
+
modules = []
|
| 85 |
+
for _ in range(1, mlp_depth):
|
| 86 |
+
modules.append(nn.GELU())
|
| 87 |
+
modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
|
| 88 |
+
modules = nn.Sequential(*modules)
|
| 89 |
+
|
| 90 |
+
elif cfg.projector_type == "low_high_split_mlp_gelu":
|
| 91 |
+
mlp_depth = cfg.get("depth", 1)
|
| 92 |
+
modules = []
|
| 93 |
+
for _ in range(1, mlp_depth):
|
| 94 |
+
modules.append(nn.GELU())
|
| 95 |
+
modules.append(nn.Linear(cfg.n_embed // 2, cfg.n_embed // 2))
|
| 96 |
+
modules = nn.Sequential(*modules)
|
| 97 |
+
self.high_layers = nn.Sequential(*modules)
|
| 98 |
+
self.low_layers = copy.deepcopy(modules)
|
| 99 |
+
|
| 100 |
+
else:
|
| 101 |
+
raise ValueError(f"Unknown projector type: {cfg.projector_type}")
|
| 102 |
+
|
| 103 |
+
if cfg.get("token_pooling", False):
|
| 104 |
+
self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim)
|
| 105 |
+
|
| 106 |
+
if cfg.get("conv_fusion_high_low_features", False):
|
| 107 |
+
self.fusion_layer = nn.Linear(cfg.input_dim, cfg.input_dim)
|
| 108 |
+
self.layers = modules
|
| 109 |
+
|
| 110 |
+
def forward(self, x):
|
| 111 |
+
if self.cfg.get("token_pooling", False):
|
| 112 |
+
batch_size, wxh, channels = x.shape
|
| 113 |
+
w = h = int(wxh**0.5)
|
| 114 |
+
x = x.view(batch_size, w, h, channels)
|
| 115 |
+
x = x.permute(0, 3, 1, 2)
|
| 116 |
+
# import ipdb; ipdb.set_trace()
|
| 117 |
+
patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
|
| 118 |
+
batch_size, channels, h_patches, w_patches, _, _ = patches.size()
|
| 119 |
+
# concatenate along the channel dimension
|
| 120 |
+
patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)
|
| 121 |
+
|
| 122 |
+
# pass through the linear layer
|
| 123 |
+
patches = patches.permute(0, 2, 1, 3).contiguous()
|
| 124 |
+
patches = patches.view(batch_size, h_patches * w_patches, channels * 4)
|
| 125 |
+
|
| 126 |
+
x = self.token_pooling_layer(patches)
|
| 127 |
+
|
| 128 |
+
if self.cfg.get("conv_fusion_high_low_features", False):
|
| 129 |
+
x = self.fusion_layer(x[:, 0]) + x[:, 1]
|
| 130 |
+
|
| 131 |
+
if self.cfg.projector_type == 'low_high_hybrid_split_mlp_gelu':
|
| 132 |
+
high_x, low_x = x[0], x[1]
|
| 133 |
+
high_x = self.high_up_proj(high_x)
|
| 134 |
+
low_x = self.low_up_proj(low_x)
|
| 135 |
+
x = torch.concat([high_x, low_x], dim=-1)
|
| 136 |
+
|
| 137 |
+
if self.cfg.projector_type == 'hybrid_split_feature_mlp_gelu':
|
| 138 |
+
high_x = x[...,:self.cfg.input_dim[0]]
|
| 139 |
+
low_x = x[...,self.cfg.input_dim[0]:]
|
| 140 |
+
high_x = self.high_up_proj(high_x)
|
| 141 |
+
low_x = self.low_up_proj(low_x)
|
| 142 |
+
x = torch.concat([high_x, low_x], dim=-1)
|
| 143 |
+
|
| 144 |
+
if self.cfg.projector_type == 'low_high_split_mlp_gelu':
|
| 145 |
+
high_x, low_x = x[0], x[1]
|
| 146 |
+
high_x = self.high_layers(high_x)
|
| 147 |
+
low_x = self.low_layers(low_x)
|
| 148 |
+
x = torch.concat([high_x, low_x], dim=-1)
|
| 149 |
+
return x
|
| 150 |
+
|
| 151 |
+
if self.cfg.projector_type == 'downsample_mlp_gelu' or self.cfg.projector_type == 'normlayer_downsample_mlp_gelu':
|
| 152 |
+
bs, hw, input_dim = x.shape
|
| 153 |
+
h = w = int((hw) ** 0.5)
|
| 154 |
+
|
| 155 |
+
"""compute padding"""
|
| 156 |
+
if h % self.cfg.downsample_ratio:
|
| 157 |
+
pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
|
| 158 |
+
else:
|
| 159 |
+
pad = 0
|
| 160 |
+
x = x.reshape(bs, h, w, input_dim)
|
| 161 |
+
if pad > 0:
|
| 162 |
+
x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
|
| 163 |
+
|
| 164 |
+
"""4 to 1 concat"""
|
| 165 |
+
x = x.permute(0, 3, 1, 2) # B, C, H, W
|
| 166 |
+
x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio, padding=0) # B, C*4, HW // 4
|
| 167 |
+
x = x.permute(0, 2, 1)
|
| 168 |
+
|
| 169 |
+
return self.layers(x)
|
| 170 |
+
|
| 171 |
+
@staticmethod
|
| 172 |
+
def get_flops_per_sample(cfg):
|
| 173 |
+
if cfg.projector_type == "linear":
|
| 174 |
+
fwd = 2 * cfg.input_dim * cfg.n_embed
|
| 175 |
+
|
| 176 |
+
elif "mlp_gelu" in cfg.projector_type :
|
| 177 |
+
mlp_depth = cfg.get("depth", 1)
|
| 178 |
+
downsample_ratio = cfg.get("downsample_ratio", 1)
|
| 179 |
+
input_dim = sum(cfg.input_dim) if isinstance(cfg.input_dim, list) else cfg.input_dim
|
| 180 |
+
input_dim = input_dim * downsample_ratio * downsample_ratio
|
| 181 |
+
fwd = 2 * input_dim * cfg.n_embed + (mlp_depth - 1) * 2 * cfg.n_embed * cfg.n_embed
|
| 182 |
+
else:
|
| 183 |
+
fwd = 0
|
| 184 |
+
|
| 185 |
+
return fwd * 3
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
#===================clip============================================================
|
| 189 |
+
|
| 190 |
+
class LayerNormfp32(torch.nn.LayerNorm):
|
| 191 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
| 192 |
+
|
| 193 |
+
def forward(self, x: torch.Tensor):
|
| 194 |
+
orig_type = x.dtype
|
| 195 |
+
ret = super().forward(x.type(torch.float32))
|
| 196 |
+
return ret.type(orig_type)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def get_abs_pos(abs_pos, tgt_size):
|
| 200 |
+
# abs_pos: L, C
|
| 201 |
+
# tgt_size: M
|
| 202 |
+
# return: M, C
|
| 203 |
+
|
| 204 |
+
# print(tgt_size)
|
| 205 |
+
# print(abs_pos.shape)
|
| 206 |
+
# exit()
|
| 207 |
+
dim = abs_pos.size(-1)
|
| 208 |
+
# print(dim)
|
| 209 |
+
abs_pos_new = abs_pos.squeeze(0)
|
| 210 |
+
cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
|
| 215 |
+
tgt_size = int(math.sqrt(tgt_size))
|
| 216 |
+
dtype = abs_pos.dtype
|
| 217 |
+
|
| 218 |
+
if src_size != tgt_size:
|
| 219 |
+
old_pos_embed = old_pos_embed.view(1, src_size, src_size, dim).permute(0, 3, 1,
|
| 220 |
+
2).contiguous()
|
| 221 |
+
old_pos_embed = old_pos_embed.to(torch.float32)
|
| 222 |
+
new_pos_embed = F.interpolate(
|
| 223 |
+
old_pos_embed,
|
| 224 |
+
size=(tgt_size, tgt_size),
|
| 225 |
+
mode='bicubic',
|
| 226 |
+
antialias=True,
|
| 227 |
+
align_corners=False,
|
| 228 |
+
).to(dtype)
|
| 229 |
+
new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
|
| 230 |
+
new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
|
| 231 |
+
vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
|
| 232 |
+
vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim)
|
| 233 |
+
return vision_pos_embed
|
| 234 |
+
else:
|
| 235 |
+
return abs_pos
|
| 236 |
+
|
| 237 |
+
@torch.jit.script
|
| 238 |
+
def quick_gelu(x):
|
| 239 |
+
return x * torch.sigmoid(1.702 * x)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class CLIPVisionEmbeddings(nn.Module):
|
| 244 |
+
def __init__(self, hidden_size=1024, image_size=224, patch_size=14, num_channels=3):
|
| 245 |
+
super().__init__()
|
| 246 |
+
self.embed_dim = hidden_size
|
| 247 |
+
self.image_size = image_size
|
| 248 |
+
self.patch_size = patch_size
|
| 249 |
+
|
| 250 |
+
self.class_embedding = torch.nn.Parameter(torch.randn(self.embed_dim))
|
| 251 |
+
|
| 252 |
+
self.patch_embedding = torch.nn.Conv2d(
|
| 253 |
+
in_channels=num_channels,
|
| 254 |
+
out_channels=self.embed_dim,
|
| 255 |
+
kernel_size=self.patch_size,
|
| 256 |
+
stride=self.patch_size,
|
| 257 |
+
bias=False,
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
self.num_patches = (self.image_size // self.patch_size) ** 2
|
| 261 |
+
self.num_positions = self.num_patches + 1
|
| 262 |
+
self.position_embedding = torch.nn.Embedding(self.num_positions, self.embed_dim)
|
| 263 |
+
self.register_buffer(
|
| 264 |
+
"position_ids", torch.arange(self.num_positions).expand((1, -1))
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
def forward(self, pixel_values, patch_embeds):
|
| 268 |
+
batch_size = pixel_values.shape[0]
|
| 269 |
+
# patch_embeds = self.patch_embedding(
|
| 270 |
+
# pixel_values
|
| 271 |
+
# ) # shape = [*, width, grid, grid]
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
if patch_embeds is not None:
|
| 275 |
+
patch_embeds = patch_embeds
|
| 276 |
+
# print(patch_embeds.shape)
|
| 277 |
+
else:
|
| 278 |
+
patch_embeds = self.patch_embedding(pixel_values)
|
| 279 |
+
# print(111111)
|
| 280 |
+
# shape = [*, width, grid, grid]
|
| 281 |
+
# patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
| 282 |
+
|
| 283 |
+
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
| 287 |
+
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
| 288 |
+
|
| 289 |
+
# x = torch.cat([cls_token, x], dim=1)
|
| 290 |
+
embeddings = embeddings + get_abs_pos(self.position_embedding(self.position_ids), embeddings.size(1))
|
| 291 |
+
# embeddings = embeddings + self.position_embedding(self.position_ids)
|
| 292 |
+
return embeddings
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class NoTPFeedForward(nn.Module):
|
| 296 |
+
def __init__(
|
| 297 |
+
self,
|
| 298 |
+
cfg,
|
| 299 |
+
dim: int,
|
| 300 |
+
hidden_dim: int,
|
| 301 |
+
):
|
| 302 |
+
super().__init__()
|
| 303 |
+
|
| 304 |
+
self.fc1 = torch.nn.Linear(dim, hidden_dim, bias=True)
|
| 305 |
+
self.fc2 = torch.nn.Linear(hidden_dim, dim, bias=True)
|
| 306 |
+
|
| 307 |
+
def forward(self, x):
|
| 308 |
+
output = self.fc2(quick_gelu(self.fc1(x)))
|
| 309 |
+
return output
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class NoTPAttention(torch.nn.Module):
|
| 315 |
+
def __init__(self, cfg):
|
| 316 |
+
super().__init__()
|
| 317 |
+
self.num_heads = cfg.num_attention_heads
|
| 318 |
+
self.n_local_heads = cfg.num_attention_heads
|
| 319 |
+
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
|
| 320 |
+
self.max_seq_len = cfg.seq_length
|
| 321 |
+
self.use_flash_attention = cfg.use_flash_attn
|
| 322 |
+
|
| 323 |
+
self.qkv_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size * 3, bias=True)
|
| 324 |
+
self.out_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)
|
| 325 |
+
|
| 326 |
+
# self.core_attention = CoreAttention(cfg, AttnType.self_attn)
|
| 327 |
+
|
| 328 |
+
self.attn_drop = cfg.attention_dropout
|
| 329 |
+
|
| 330 |
+
def forward(
|
| 331 |
+
self,
|
| 332 |
+
x: torch.Tensor,
|
| 333 |
+
):
|
| 334 |
+
bsz, seqlen, _ = x.shape
|
| 335 |
+
xqkv = self.qkv_proj(x)
|
| 336 |
+
xqkv = xqkv.view(bsz, seqlen, 3, self.num_heads, self.head_dim)
|
| 337 |
+
|
| 338 |
+
if self.use_flash_attention:
|
| 339 |
+
|
| 340 |
+
xq, xk, xv = torch.split(xqkv, 1, dim=2)
|
| 341 |
+
xq = xq.squeeze(2)
|
| 342 |
+
xk = xk.squeeze(2)
|
| 343 |
+
xv = xv.squeeze(2)
|
| 344 |
+
# xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
|
| 345 |
+
|
| 346 |
+
# (B, num_head, S, head_size)
|
| 347 |
+
xq = xq.permute(0, 2, 1, 3)
|
| 348 |
+
xk = xk.permute(0, 2, 1, 3)
|
| 349 |
+
xv = xv.permute(0, 2, 1, 3)
|
| 350 |
+
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
| 351 |
+
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
|
| 352 |
+
output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
|
| 353 |
+
# output = output.permute(0, 2, 1, 3).contiguous().view(bsz, seqlen, -1)
|
| 354 |
+
else:
|
| 355 |
+
# print(22222)
|
| 356 |
+
xq, xk, xv = torch.split(xqkv, 1, dim=2)
|
| 357 |
+
xq = xq.squeeze(2)
|
| 358 |
+
xk = xk.squeeze(2)
|
| 359 |
+
xv = xv.squeeze(2)
|
| 360 |
+
# xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
|
| 361 |
+
|
| 362 |
+
# (B, num_head, S, head_size)
|
| 363 |
+
xq = xq.permute(0, 2, 1, 3)
|
| 364 |
+
xk = xk.permute(0, 2, 1, 3)
|
| 365 |
+
xv = xv.permute(0, 2, 1, 3)
|
| 366 |
+
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
| 367 |
+
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
|
| 368 |
+
output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
|
| 369 |
+
# output = output.permute(0, 2, 1, 3).contiguous().view(bsz, seqlen, -1)
|
| 370 |
+
output = self.out_proj(output)
|
| 371 |
+
return output
|
| 372 |
+
|
| 373 |
+
class NoTPTransformerBlock(nn.Module):
|
| 374 |
+
def __init__(self, cfg, layer_id: int, multiple_of=256):
|
| 375 |
+
super().__init__()
|
| 376 |
+
|
| 377 |
+
self.n_heads = cfg.num_attention_heads
|
| 378 |
+
self.dim = cfg.hidden_size
|
| 379 |
+
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
|
| 380 |
+
self.self_attn = NoTPAttention(cfg)
|
| 381 |
+
self.mlp = NoTPFeedForward(
|
| 382 |
+
cfg, dim=cfg.hidden_size, hidden_dim=cfg.ffn_hidden_size
|
| 383 |
+
)
|
| 384 |
+
self.layer_id = layer_id
|
| 385 |
+
self.layer_norm1 = torch.nn.LayerNorm(
|
| 386 |
+
cfg.hidden_size, eps=cfg.layernorm_epsilon
|
| 387 |
+
)
|
| 388 |
+
self.layer_norm2 = torch.nn.LayerNorm(
|
| 389 |
+
cfg.hidden_size, eps=cfg.layernorm_epsilon
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
def forward(self, x: torch.Tensor):
|
| 393 |
+
residual = self.self_attn.forward(self.layer_norm1(x))
|
| 394 |
+
h = x + residual
|
| 395 |
+
out = h + self.mlp.forward(self.layer_norm2(h))
|
| 396 |
+
return out
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
class NoTPTransformer(nn.Module):
|
| 400 |
+
def __init__(self, cfg):
|
| 401 |
+
super().__init__()
|
| 402 |
+
|
| 403 |
+
self.cfg = cfg
|
| 404 |
+
# self.recompute_list = self.cfg.get("recompute_list", [])
|
| 405 |
+
self.num_layers = cfg.num_layers # _get_num_layers(cfg)
|
| 406 |
+
|
| 407 |
+
self.layers = torch.nn.ModuleList()
|
| 408 |
+
for layer_id in range(self.num_layers):
|
| 409 |
+
self.layers.append(
|
| 410 |
+
NoTPTransformerBlock(
|
| 411 |
+
cfg,
|
| 412 |
+
layer_id + 1,
|
| 413 |
+
)
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
def forward(
|
| 417 |
+
self,
|
| 418 |
+
hidden_states,
|
| 419 |
+
):
|
| 420 |
+
|
| 421 |
+
for lid, layer in enumerate(self.layers):
|
| 422 |
+
# if lid in self.recompute_list:
|
| 423 |
+
# def custom(layer_id):
|
| 424 |
+
# def custom_forward(*args, **kwargs):
|
| 425 |
+
# x_ = self.layers[layer_id](*args, **kwargs)
|
| 426 |
+
# return x_
|
| 427 |
+
|
| 428 |
+
# return custom_forward
|
| 429 |
+
|
| 430 |
+
# assert hidden_states.requires_grad == True, logger.warning(
|
| 431 |
+
# "When using recalculation, the input must have grad fn"
|
| 432 |
+
# )
|
| 433 |
+
# hidden_states = tensor_parallel.checkpoint(
|
| 434 |
+
# custom(lid),
|
| 435 |
+
# False,
|
| 436 |
+
# hidden_states.contiguous()
|
| 437 |
+
# )
|
| 438 |
+
# else:
|
| 439 |
+
hidden_states = layer(hidden_states)
|
| 440 |
+
|
| 441 |
+
return hidden_states
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
# from megatron.core.tensor_parallel.layers import non_tensor_paralleled, local_dp_reduce, local_dp_scatter
|
| 445 |
+
|
| 446 |
+
class VitModel(nn.Module):
|
| 447 |
+
def __init__(
|
| 448 |
+
self,
|
| 449 |
+
cfg,
|
| 450 |
+
freeze_embed=False,
|
| 451 |
+
freeze_pre_norm=False
|
| 452 |
+
) -> None:
|
| 453 |
+
super().__init__()
|
| 454 |
+
|
| 455 |
+
self.embeddings = CLIPVisionEmbeddings(hidden_size=cfg.hidden_size, image_size=cfg.image_size, patch_size=cfg.patch_size)
|
| 456 |
+
|
| 457 |
+
if freeze_embed:
|
| 458 |
+
for name, param in self.embeddings.named_parameters():
|
| 459 |
+
param.requires_grad = False
|
| 460 |
+
|
| 461 |
+
self.transformer = NoTPTransformer(cfg=cfg)
|
| 462 |
+
|
| 463 |
+
if cfg.get("fp32norm", False):
|
| 464 |
+
logger.info("Load fp32 layernorm for ViT.")
|
| 465 |
+
self.pre_layrnorm = LayerNormfp32(
|
| 466 |
+
cfg.hidden_size,
|
| 467 |
+
eps=cfg.get("pre_layernorm_epsilon", 1e-5),
|
| 468 |
+
)
|
| 469 |
+
else:
|
| 470 |
+
self.pre_layrnorm = torch.nn.LayerNorm(
|
| 471 |
+
cfg.hidden_size,
|
| 472 |
+
eps=cfg.get("pre_layernorm_epsilon", 1e-5),
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
# self.pre_layrnorm = RMSNorm(
|
| 476 |
+
# cfg.hidden_size,
|
| 477 |
+
# eps=cfg.get("pre_layernorm_epsilon", 1e-5),
|
| 478 |
+
# sequence_parallel=False,
|
| 479 |
+
# use_fp32=True,
|
| 480 |
+
# use_optimus=True,
|
| 481 |
+
# )
|
| 482 |
+
|
| 483 |
+
if freeze_pre_norm:
|
| 484 |
+
for name, param in self.pre_layrnorm.named_parameters():
|
| 485 |
+
param.requires_grad = False
|
| 486 |
+
|
| 487 |
+
for p in self.parameters():
|
| 488 |
+
p.micro_dp = True
|
| 489 |
+
|
| 490 |
+
def set_input_tensor(self, input_tensor):
|
| 491 |
+
if not isinstance(input_tensor, list):
|
| 492 |
+
input_tensor = [input_tensor]
|
| 493 |
+
self.transformer.set_input_tensor(input_tensor[0])
|
| 494 |
+
|
| 495 |
+
def __str__(self) -> str:
|
| 496 |
+
return "open_clip"
|
| 497 |
+
|
| 498 |
+
def forward(
|
| 499 |
+
self,
|
| 500 |
+
x,
|
| 501 |
+
patch_embeds
|
| 502 |
+
):
|
| 503 |
+
x = self.embeddings(x, patch_embeds)
|
| 504 |
+
hidden_states = self.pre_layrnorm(x)
|
| 505 |
+
|
| 506 |
+
# hidden_states, dis = local_dp_scatter(hidden_states)
|
| 507 |
+
output = self.transformer(hidden_states)
|
| 508 |
+
|
| 509 |
+
# output = local_dp_reduce(output, dis)
|
| 510 |
+
|
| 511 |
+
return output
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
vit_model_cfg = adict(
|
| 515 |
+
num_layers=24,
|
| 516 |
+
hidden_size=1024,
|
| 517 |
+
num_heads = 16,
|
| 518 |
+
num_attention_heads=16,
|
| 519 |
+
ffn_hidden_size=4096,
|
| 520 |
+
seq_length=256,
|
| 521 |
+
max_position_embeddings=256,
|
| 522 |
+
use_flash_attn=False,
|
| 523 |
+
understand_projector_stride=2,
|
| 524 |
+
hidden_dropout = 0.0,
|
| 525 |
+
attention_dropout = 0.0,
|
| 526 |
+
no_persist_layer_norm = False,
|
| 527 |
+
layernorm_epsilon = 1e-5,
|
| 528 |
+
pre_layernorm_epsilon = 1e-5,
|
| 529 |
+
image_size = 224,
|
| 530 |
+
patch_size = 14,
|
| 531 |
+
recompute_list = []
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
def build_clip_l():
|
| 535 |
+
return VitModel(
|
| 536 |
+
cfg=vit_model_cfg,
|
| 537 |
+
freeze_embed=False,
|
| 538 |
+
freeze_pre_norm=False,
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
#=========================Sam-Vary=================================
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
def get_abs_pos_sam(abs_pos, tgt_size):
|
| 549 |
+
|
| 550 |
+
dtype = abs_pos.dtype
|
| 551 |
+
|
| 552 |
+
src_size = abs_pos.size(1)
|
| 553 |
+
|
| 554 |
+
if src_size != tgt_size:
|
| 555 |
+
old_pos_embed = abs_pos.permute(0, 3, 1, 2)
|
| 556 |
+
old_pos_embed = old_pos_embed.to(torch.float32)
|
| 557 |
+
new_pos_embed = F.interpolate(
|
| 558 |
+
old_pos_embed,
|
| 559 |
+
size=(tgt_size, tgt_size),
|
| 560 |
+
mode='bicubic',
|
| 561 |
+
antialias=True,
|
| 562 |
+
align_corners=False,
|
| 563 |
+
).to(dtype)
|
| 564 |
+
new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
|
| 565 |
+
return new_pos_embed
|
| 566 |
+
else:
|
| 567 |
+
return abs_pos
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
class MLPBlock(nn.Module):
|
| 573 |
+
def __init__(
|
| 574 |
+
self,
|
| 575 |
+
embedding_dim: int,
|
| 576 |
+
mlp_dim: int,
|
| 577 |
+
act: Type[nn.Module] = nn.GELU,
|
| 578 |
+
) -> None:
|
| 579 |
+
super().__init__()
|
| 580 |
+
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
|
| 581 |
+
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
|
| 582 |
+
self.act = act()
|
| 583 |
+
|
| 584 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 585 |
+
return self.lin2(self.act(self.lin1(x)))
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
|
| 589 |
+
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
|
| 590 |
+
class LayerNorm2d(nn.Module):
|
| 591 |
+
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
| 592 |
+
super().__init__()
|
| 593 |
+
self.weight = nn.Parameter(torch.ones(num_channels))
|
| 594 |
+
self.bias = nn.Parameter(torch.zeros(num_channels))
|
| 595 |
+
self.eps = eps
|
| 596 |
+
|
| 597 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 598 |
+
u = x.mean(1, keepdim=True)
|
| 599 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
| 600 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
| 601 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
| 602 |
+
return x
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
|
| 606 |
+
class ImageEncoderViT(nn.Module):
|
| 607 |
+
def __init__(
|
| 608 |
+
self,
|
| 609 |
+
img_size: int = 1024,
|
| 610 |
+
patch_size: int = 16,
|
| 611 |
+
in_chans: int = 3,
|
| 612 |
+
embed_dim: int = 768,
|
| 613 |
+
depth: int = 12,
|
| 614 |
+
num_heads: int = 12,
|
| 615 |
+
mlp_ratio: float = 4.0,
|
| 616 |
+
out_chans: int = 256,
|
| 617 |
+
qkv_bias: bool = True,
|
| 618 |
+
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
| 619 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
| 620 |
+
use_abs_pos: bool = True,
|
| 621 |
+
use_rel_pos: bool = False,
|
| 622 |
+
rel_pos_zero_init: bool = True,
|
| 623 |
+
window_size: int = 0,
|
| 624 |
+
global_attn_indexes: Tuple[int, ...] = (),
|
| 625 |
+
) -> None:
|
| 626 |
+
"""
|
| 627 |
+
Args:
|
| 628 |
+
img_size (int): Input image size.
|
| 629 |
+
patch_size (int): Patch size.
|
| 630 |
+
in_chans (int): Number of input image channels.
|
| 631 |
+
embed_dim (int): Patch embedding dimension.
|
| 632 |
+
depth (int): Depth of ViT.
|
| 633 |
+
num_heads (int): Number of attention heads in each ViT block.
|
| 634 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 635 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
| 636 |
+
norm_layer (nn.Module): Normalization layer.
|
| 637 |
+
act_layer (nn.Module): Activation layer.
|
| 638 |
+
use_abs_pos (bool): If True, use absolute positional embeddings.
|
| 639 |
+
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
| 640 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
| 641 |
+
window_size (int): Window size for window attention blocks.
|
| 642 |
+
global_attn_indexes (list): Indexes for blocks using global attention.
|
| 643 |
+
"""
|
| 644 |
+
super().__init__()
|
| 645 |
+
self.img_size = img_size
|
| 646 |
+
|
| 647 |
+
self.patch_embed = PatchEmbed(
|
| 648 |
+
kernel_size=(patch_size, patch_size),
|
| 649 |
+
stride=(patch_size, patch_size),
|
| 650 |
+
in_chans=in_chans,
|
| 651 |
+
embed_dim=embed_dim,
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
self.pos_embed: Optional[nn.Parameter] = None
|
| 655 |
+
if use_abs_pos:
|
| 656 |
+
# Initialize absolute positional embedding with pretrain image size.
|
| 657 |
+
self.pos_embed = nn.Parameter(
|
| 658 |
+
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
self.blocks = nn.ModuleList()
|
| 662 |
+
for i in range(depth):
|
| 663 |
+
block = Block(
|
| 664 |
+
dim=embed_dim,
|
| 665 |
+
num_heads=num_heads,
|
| 666 |
+
mlp_ratio=mlp_ratio,
|
| 667 |
+
qkv_bias=qkv_bias,
|
| 668 |
+
norm_layer=norm_layer,
|
| 669 |
+
act_layer=act_layer,
|
| 670 |
+
use_rel_pos=use_rel_pos,
|
| 671 |
+
rel_pos_zero_init=rel_pos_zero_init,
|
| 672 |
+
window_size=window_size if i not in global_attn_indexes else 0,
|
| 673 |
+
input_size=(img_size // patch_size, img_size // patch_size),
|
| 674 |
+
)
|
| 675 |
+
self.blocks.append(block)
|
| 676 |
+
|
| 677 |
+
self.neck = nn.Sequential(
|
| 678 |
+
nn.Conv2d(
|
| 679 |
+
embed_dim,
|
| 680 |
+
out_chans,
|
| 681 |
+
kernel_size=1,
|
| 682 |
+
bias=False,
|
| 683 |
+
),
|
| 684 |
+
LayerNorm2d(out_chans),
|
| 685 |
+
nn.Conv2d(
|
| 686 |
+
out_chans,
|
| 687 |
+
out_chans,
|
| 688 |
+
kernel_size=3,
|
| 689 |
+
padding=1,
|
| 690 |
+
bias=False,
|
| 691 |
+
),
|
| 692 |
+
LayerNorm2d(out_chans),
|
| 693 |
+
)
|
| 694 |
+
|
| 695 |
+
self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
|
| 696 |
+
self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
|
| 697 |
+
|
| 698 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 699 |
+
x = self.patch_embed(x)
|
| 700 |
+
if self.pos_embed is not None:
|
| 701 |
+
# x = x + self.pos_embed
|
| 702 |
+
x = x + get_abs_pos_sam(self.pos_embed, x.size(1))
|
| 703 |
+
|
| 704 |
+
for blk in self.blocks:
|
| 705 |
+
x = blk(x)
|
| 706 |
+
|
| 707 |
+
x = self.neck(x.permute(0, 3, 1, 2))
|
| 708 |
+
x2 = self.net_2(x)
|
| 709 |
+
x3 = self.net_3(x2.clone())
|
| 710 |
+
|
| 711 |
+
return x3
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
class Block(nn.Module):
|
| 715 |
+
"""Transformer blocks with support of window attention and residual propagation blocks"""
|
| 716 |
+
|
| 717 |
+
def __init__(
|
| 718 |
+
self,
|
| 719 |
+
dim: int,
|
| 720 |
+
num_heads: int,
|
| 721 |
+
mlp_ratio: float = 4.0,
|
| 722 |
+
qkv_bias: bool = True,
|
| 723 |
+
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
| 724 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
| 725 |
+
use_rel_pos: bool = False,
|
| 726 |
+
rel_pos_zero_init: bool = True,
|
| 727 |
+
window_size: int = 0,
|
| 728 |
+
input_size: Optional[Tuple[int, int]] = None,
|
| 729 |
+
) -> None:
|
| 730 |
+
"""
|
| 731 |
+
Args:
|
| 732 |
+
dim (int): Number of input channels.
|
| 733 |
+
num_heads (int): Number of attention heads in each ViT block.
|
| 734 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 735 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
| 736 |
+
norm_layer (nn.Module): Normalization layer.
|
| 737 |
+
act_layer (nn.Module): Activation layer.
|
| 738 |
+
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
| 739 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
| 740 |
+
window_size (int): Window size for window attention blocks. If it equals 0, then
|
| 741 |
+
use global attention.
|
| 742 |
+
input_size (tuple(int, int) or None): Input resolution for calculating the relative
|
| 743 |
+
positional parameter size.
|
| 744 |
+
"""
|
| 745 |
+
super().__init__()
|
| 746 |
+
self.norm1 = norm_layer(dim)
|
| 747 |
+
self.attn = Attention(
|
| 748 |
+
dim,
|
| 749 |
+
num_heads=num_heads,
|
| 750 |
+
qkv_bias=qkv_bias,
|
| 751 |
+
use_rel_pos=use_rel_pos,
|
| 752 |
+
rel_pos_zero_init=rel_pos_zero_init,
|
| 753 |
+
input_size=input_size if window_size == 0 else (window_size, window_size),
|
| 754 |
+
)
|
| 755 |
+
|
| 756 |
+
self.norm2 = norm_layer(dim)
|
| 757 |
+
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
|
| 758 |
+
|
| 759 |
+
self.window_size = window_size
|
| 760 |
+
|
| 761 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 762 |
+
shortcut = x
|
| 763 |
+
x = self.norm1(x)
|
| 764 |
+
# Window partition
|
| 765 |
+
if self.window_size > 0:
|
| 766 |
+
H, W = x.shape[1], x.shape[2]
|
| 767 |
+
x, pad_hw = window_partition(x, self.window_size)
|
| 768 |
+
|
| 769 |
+
x = self.attn(x)
|
| 770 |
+
# Reverse window partition
|
| 771 |
+
if self.window_size > 0:
|
| 772 |
+
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
|
| 773 |
+
|
| 774 |
+
x = shortcut + x
|
| 775 |
+
x = x + self.mlp(self.norm2(x))
|
| 776 |
+
|
| 777 |
+
return x
|
| 778 |
+
|
| 779 |
+
|
| 780 |
+
class Attention(nn.Module):
|
| 781 |
+
"""Multi-head Attention block with relative position embeddings."""
|
| 782 |
+
|
| 783 |
+
def __init__(
|
| 784 |
+
self,
|
| 785 |
+
dim: int,
|
| 786 |
+
num_heads: int = 8,
|
| 787 |
+
qkv_bias: bool = True,
|
| 788 |
+
use_rel_pos: bool = False,
|
| 789 |
+
rel_pos_zero_init: bool = True,
|
| 790 |
+
input_size: Optional[Tuple[int, int]] = None,
|
| 791 |
+
) -> None:
|
| 792 |
+
"""
|
| 793 |
+
Args:
|
| 794 |
+
dim (int): Number of input channels.
|
| 795 |
+
num_heads (int): Number of attention heads.
|
| 796 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
| 797 |
+
rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
| 798 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
| 799 |
+
input_size (tuple(int, int) or None): Input resolution for calculating the relative
|
| 800 |
+
positional parameter size.
|
| 801 |
+
"""
|
| 802 |
+
super().__init__()
|
| 803 |
+
self.num_heads = num_heads
|
| 804 |
+
head_dim = dim // num_heads
|
| 805 |
+
self.scale = head_dim**-0.5
|
| 806 |
+
|
| 807 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 808 |
+
self.proj = nn.Linear(dim, dim)
|
| 809 |
+
|
| 810 |
+
self.use_rel_pos = use_rel_pos
|
| 811 |
+
if self.use_rel_pos:
|
| 812 |
+
assert (
|
| 813 |
+
input_size is not None
|
| 814 |
+
), "Input size must be provided if using relative positional encoding."
|
| 815 |
+
# initialize relative positional embeddings
|
| 816 |
+
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
|
| 817 |
+
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
|
| 818 |
+
|
| 819 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 820 |
+
B, H, W, _ = x.shape
|
| 821 |
+
# qkv with shape (3, B, nHead, H * W, C)
|
| 822 |
+
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
| 823 |
+
# q, k, v with shape (B * nHead, H * W, C)
|
| 824 |
+
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
|
| 825 |
+
|
| 826 |
+
rel_h, rel_w = None, None
|
| 827 |
+
if self.use_rel_pos:
|
| 828 |
+
rel_h, rel_w = add_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
|
| 829 |
+
|
| 830 |
+
q = q.view(B, self.num_heads, H * W, -1)
|
| 831 |
+
k = k.view(B, self.num_heads, H * W, -1)
|
| 832 |
+
v = v.view(B, self.num_heads, H * W, -1)
|
| 833 |
+
|
| 834 |
+
if self.use_rel_pos:
|
| 835 |
+
rel_h = rel_h.view(B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
|
| 836 |
+
rel_w = rel_w.view(B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
|
| 837 |
+
attn_bias = (rel_h + rel_w).view(B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4))
|
| 838 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
|
| 839 |
+
# x = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w)
|
| 840 |
+
else:
|
| 841 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
| 842 |
+
|
| 843 |
+
x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
|
| 844 |
+
|
| 845 |
+
x = self.proj(x)
|
| 846 |
+
|
| 847 |
+
return x
|
| 848 |
+
|
| 849 |
+
|
| 850 |
+
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
| 851 |
+
"""
|
| 852 |
+
Partition into non-overlapping windows with padding if needed.
|
| 853 |
+
Args:
|
| 854 |
+
x (tensor): input tokens with [B, H, W, C].
|
| 855 |
+
window_size (int): window size.
|
| 856 |
+
|
| 857 |
+
Returns:
|
| 858 |
+
windows: windows after partition with [B * num_windows, window_size, window_size, C].
|
| 859 |
+
(Hp, Wp): padded height and width before partition
|
| 860 |
+
"""
|
| 861 |
+
B, H, W, C = x.shape
|
| 862 |
+
|
| 863 |
+
pad_h = (window_size - H % window_size) % window_size
|
| 864 |
+
pad_w = (window_size - W % window_size) % window_size
|
| 865 |
+
if pad_h > 0 or pad_w > 0:
|
| 866 |
+
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
|
| 867 |
+
Hp, Wp = H + pad_h, W + pad_w
|
| 868 |
+
|
| 869 |
+
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
| 870 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
| 871 |
+
return windows, (Hp, Wp)
|
| 872 |
+
|
| 873 |
+
|
| 874 |
+
def window_unpartition(
|
| 875 |
+
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
|
| 876 |
+
) -> torch.Tensor:
|
| 877 |
+
"""
|
| 878 |
+
Window unpartition into original sequences and removing padding.
|
| 879 |
+
Args:
|
| 880 |
+
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
|
| 881 |
+
window_size (int): window size.
|
| 882 |
+
pad_hw (Tuple): padded height and width (Hp, Wp).
|
| 883 |
+
hw (Tuple): original height and width (H, W) before padding.
|
| 884 |
+
|
| 885 |
+
Returns:
|
| 886 |
+
x: unpartitioned sequences with [B, H, W, C].
|
| 887 |
+
"""
|
| 888 |
+
Hp, Wp = pad_hw
|
| 889 |
+
H, W = hw
|
| 890 |
+
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
| 891 |
+
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
|
| 892 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
|
| 893 |
+
|
| 894 |
+
if Hp > H or Wp > W:
|
| 895 |
+
x = x[:, :H, :W, :].contiguous()
|
| 896 |
+
return x
|
| 897 |
+
|
| 898 |
+
|
| 899 |
+
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
| 900 |
+
"""
|
| 901 |
+
Get relative positional embeddings according to the relative positions of
|
| 902 |
+
query and key sizes.
|
| 903 |
+
Args:
|
| 904 |
+
q_size (int): size of query q.
|
| 905 |
+
k_size (int): size of key k.
|
| 906 |
+
rel_pos (Tensor): relative position embeddings (L, C).
|
| 907 |
+
|
| 908 |
+
Returns:
|
| 909 |
+
Extracted positional embeddings according to relative positions.
|
| 910 |
+
"""
|
| 911 |
+
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
| 912 |
+
# Interpolate rel pos if needed.
|
| 913 |
+
if rel_pos.shape[0] != max_rel_dist:
|
| 914 |
+
# Interpolate rel pos.
|
| 915 |
+
dtype = rel_pos.dtype
|
| 916 |
+
rel_pos = rel_pos.to(torch.float32)
|
| 917 |
+
rel_pos_resized = F.interpolate(
|
| 918 |
+
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
| 919 |
+
size=max_rel_dist,
|
| 920 |
+
mode="linear",
|
| 921 |
+
).to(dtype)
|
| 922 |
+
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
| 923 |
+
else:
|
| 924 |
+
rel_pos_resized = rel_pos
|
| 925 |
+
|
| 926 |
+
# Scale the coords with short length if shapes for q and k are different.
|
| 927 |
+
q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max(k_size / q_size, 1.0)
|
| 928 |
+
k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max(q_size / k_size, 1.0)
|
| 929 |
+
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
| 930 |
+
|
| 931 |
+
return rel_pos_resized[relative_coords.long()]
|
| 932 |
+
|
| 933 |
+
|
| 934 |
+
def add_decomposed_rel_pos(
|
| 935 |
+
q: torch.Tensor,
|
| 936 |
+
rel_pos_h: torch.Tensor,
|
| 937 |
+
rel_pos_w: torch.Tensor,
|
| 938 |
+
q_size: Tuple[int, int],
|
| 939 |
+
k_size: Tuple[int, int],
|
| 940 |
+
) -> torch.Tensor:
|
| 941 |
+
"""
|
| 942 |
+
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
|
| 943 |
+
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
|
| 944 |
+
Args:
|
| 945 |
+
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
|
| 946 |
+
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
|
| 947 |
+
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
|
| 948 |
+
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
|
| 949 |
+
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
|
| 950 |
+
|
| 951 |
+
Returns:
|
| 952 |
+
attn (Tensor): attention map with added relative positional embeddings.
|
| 953 |
+
"""
|
| 954 |
+
q_h, q_w = q_size
|
| 955 |
+
k_h, k_w = k_size
|
| 956 |
+
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
|
| 957 |
+
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
|
| 958 |
+
|
| 959 |
+
B, _, dim = q.shape
|
| 960 |
+
r_q = q.reshape(B, q_h, q_w, dim)
|
| 961 |
+
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
|
| 962 |
+
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
|
| 963 |
+
rel_h = rel_h.unsqueeze(-1)
|
| 964 |
+
rel_w = rel_w.unsqueeze(-2)
|
| 965 |
+
rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
|
| 966 |
+
rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)
|
| 967 |
+
|
| 968 |
+
return rel_h, rel_w
|
| 969 |
+
|
| 970 |
+
|
| 971 |
+
class PatchEmbed(nn.Module):
|
| 972 |
+
"""
|
| 973 |
+
Image to Patch Embedding.
|
| 974 |
+
"""
|
| 975 |
+
|
| 976 |
+
def __init__(
|
| 977 |
+
self,
|
| 978 |
+
kernel_size: Tuple[int, int] = (16, 16),
|
| 979 |
+
stride: Tuple[int, int] = (16, 16),
|
| 980 |
+
padding: Tuple[int, int] = (0, 0),
|
| 981 |
+
in_chans: int = 3,
|
| 982 |
+
embed_dim: int = 768,
|
| 983 |
+
) -> None:
|
| 984 |
+
"""
|
| 985 |
+
Args:
|
| 986 |
+
kernel_size (Tuple): kernel size of the projection layer.
|
| 987 |
+
stride (Tuple): stride of the projection layer.
|
| 988 |
+
padding (Tuple): padding size of the projection layer.
|
| 989 |
+
in_chans (int): Number of input image channels.
|
| 990 |
+
embed_dim (int): Patch embedding dimension.
|
| 991 |
+
"""
|
| 992 |
+
super().__init__()
|
| 993 |
+
|
| 994 |
+
self.proj = nn.Conv2d(
|
| 995 |
+
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
|
| 996 |
+
)
|
| 997 |
+
|
| 998 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 999 |
+
x = self.proj(x)
|
| 1000 |
+
# B C H W -> B H W C
|
| 1001 |
+
x = x.permute(0, 2, 3, 1)
|
| 1002 |
+
return x
|
| 1003 |
+
|
| 1004 |
+
|
| 1005 |
+
def build_sam_vit_b(checkpoint=None):
|
| 1006 |
+
return _build_sam(
|
| 1007 |
+
encoder_embed_dim=768,
|
| 1008 |
+
encoder_depth=12,
|
| 1009 |
+
encoder_num_heads=12,
|
| 1010 |
+
encoder_global_attn_indexes=[2, 5, 8, 11],
|
| 1011 |
+
checkpoint=checkpoint,
|
| 1012 |
+
)
|
| 1013 |
+
|
| 1014 |
+
def build_sam_fast_vit_b(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
|
| 1015 |
+
image_encoder = build_sam_vit_b(checkpoint).eval().to(dtype)
|
| 1016 |
+
# sam = _apply_eval_dtype_sam(sam, dtype)
|
| 1017 |
+
image_encoder = torch.compile(image_encoder, mode=compile_mode)
|
| 1018 |
+
return image_encoder
|
| 1019 |
+
|
| 1020 |
+
|
| 1021 |
+
def _build_sam(
|
| 1022 |
+
encoder_embed_dim,
|
| 1023 |
+
encoder_depth,
|
| 1024 |
+
encoder_num_heads,
|
| 1025 |
+
encoder_global_attn_indexes,
|
| 1026 |
+
checkpoint=None,
|
| 1027 |
+
):
|
| 1028 |
+
prompt_embed_dim = 256
|
| 1029 |
+
image_size = 1024
|
| 1030 |
+
vit_patch_size = 16
|
| 1031 |
+
image_embedding_size = image_size // vit_patch_size
|
| 1032 |
+
image_encoder=ImageEncoderViT(
|
| 1033 |
+
depth=encoder_depth,
|
| 1034 |
+
embed_dim=encoder_embed_dim,
|
| 1035 |
+
img_size=image_size,
|
| 1036 |
+
mlp_ratio=4,
|
| 1037 |
+
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
|
| 1038 |
+
num_heads=encoder_num_heads,
|
| 1039 |
+
patch_size=vit_patch_size,
|
| 1040 |
+
qkv_bias=True,
|
| 1041 |
+
use_rel_pos=True,
|
| 1042 |
+
global_attn_indexes=encoder_global_attn_indexes,
|
| 1043 |
+
window_size=14,
|
| 1044 |
+
out_chans=prompt_embed_dim,
|
| 1045 |
+
)
|
| 1046 |
+
image_encoder.eval()
|
| 1047 |
+
if checkpoint is not None:
|
| 1048 |
+
# with open(checkpoint, "rb") as f:
|
| 1049 |
+
state_dict = torch.load(checkpoint)
|
| 1050 |
+
# print(state_dict.keys())
|
| 1051 |
+
# for key in state_dict:
|
| 1052 |
+
# image_encoder.load_state_dict({k[14:]: v for k, v in state_dict.items() if 'image_encoder' in k}, strict=False)
|
| 1053 |
+
# ocr-anything
|
| 1054 |
+
# image_encoder.load_state_dict(state_dict, strict=True)
|
| 1055 |
+
# tob
|
| 1056 |
+
image_encoder.load_state_dict({k[30:]: v for k, v in state_dict.items() if 'vision_tower_high' in k}, strict=True)
|
| 1057 |
+
print(checkpoint)
|
| 1058 |
+
return image_encoder
|
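The second half of deepencoder.py above finishes the two vision towers (`build_sam_vit_b` for the SAM/Vary branch, `build_clip_l` for the CLIP-L branch configured by `vit_model_cfg`) and the `MlpProjector` that maps vision features into the language-model embedding space. The sketch below wires these pieces together on randomly initialized weights, purely as a shape-orientation aid: it is not the repository's actual composition (that glue code lives in modeling_deepseekocr.py), and the projector config values (`input_dim`, `n_embed`, `depth`) are assumptions.

```python
import torch
from easydict import EasyDict as adict

# Shape-only sketch on random weights; the real wiring is in modeling_deepseekocr.py
# and the real projector configuration comes from config.json.
sam_encoder = build_sam_vit_b()   # (B, 3, 1024, 1024) -> (B, 1024, 16, 16) after neck + net_2 + net_3
clip_encoder = build_clip_l()     # VitModel built from vit_model_cfg above

# Assumed values for illustration only.
projector_cfg = adict(projector_type="mlp_gelu", input_dim=1024, n_embed=1280, depth=2)
projector = MlpProjector(projector_cfg)

pixels = torch.randn(1, 3, 1024, 1024)
sam_feats = sam_encoder(pixels)                  # (1, 1024, 16, 16)
# CLIPVisionEmbeddings accepts precomputed patch embeddings in (B, C, H, W) form,
# so a 1024-channel feature map can be fed in place of the conv patch embedding.
clip_tokens = clip_encoder(pixels, sam_feats)    # (1, 257, 1024), CLS token included
vision_embeds = projector(clip_tokens[:, 1:])    # (1, 256, 1280)
print(vision_embeds.shape)
```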
model-00001-of-000001.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1169e7cdc28ff2fb6186556acb2175db148ad26a62097df4c45a17e523180d3f
|
| 3 |
+
size 6672547120
|
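The weights ship as a single LFS-tracked safetensors shard of roughly 6.7 GB; the pointer file above records only its hash and size. A minimal loading sketch follows, assuming the standard transformers remote-code path and a hypothetical repo id; see the repository README for the documented usage.

```python
from transformers import AutoModel, AutoTokenizer

repo = "deepseek-ai/DeepSeek-OCR"  # assumed repo id; substitute the actual model id
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
# trust_remote_code=True is needed because modeling_deepseekocr.py and deepencoder.py
# ship with the checkpoint rather than with the transformers library itself.
model = AutoModel.from_pretrained(repo, trust_remote_code=True, torch_dtype="auto").eval()
```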
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
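model.safetensors.index.json is too large for the diff viewer to render, but its role is the standard one for Hugging Face checkpoints: a `metadata.total_size` entry plus a `weight_map` that maps each tensor name to the shard file holding it (here, every entry should point at the single model-00001-of-000001.safetensors shard). A quick local inspection, sketched under that assumption:

```python
import json

# Peek at the shard index; keys follow the usual sharded-checkpoint layout.
with open("model.safetensors.index.json") as f:
    index = json.load(f)

print(index["metadata"]["total_size"])            # total parameter bytes across shards
print(len(index["weight_map"]))                   # number of tensor-name -> shard entries
print(sorted(set(index["weight_map"].values())))  # should list the single shard file
```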
modeling_deepseekocr.py
ADDED
|
@@ -0,0 +1,1044 @@
|
| 1 |
+
import os
|
| 2 |
+
import math
|
| 3 |
+
import re
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
from abc import ABC
|
| 6 |
+
from typing import List, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
from addict import Dict
|
| 9 |
+
from PIL import Image, ImageOps, ImageDraw, ImageFont
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from torch.nn import CrossEntropyLoss
|
| 15 |
+
from torchvision import transforms
|
| 16 |
+
|
| 17 |
+
from transformers.cache_utils import Cache
|
| 18 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 19 |
+
from transformers import DeepseekV2Model, DeepseekV2ForCausalLM
|
| 20 |
+
from transformers import DeepseekV2Config
|
| 21 |
+
from transformers.models.deepseek_v2.modeling_deepseek_v2 import (
|
| 22 |
+
DeepseekV2Attention, DeepseekV2MLP, DeepseekV2MoE, DeepseekV2RMSNorm, DeepseekV2DecoderLayer)
|
| 23 |
+
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding
|
| 24 |
+
from transformers import TextStreamer
|
| 25 |
+
from .deepencoder import build_sam_vit_b, build_clip_l, MlpProjector
|
| 26 |
+
from .conversation import get_conv_template
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_image(image_path):
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
image = Image.open(image_path)
|
| 33 |
+
|
| 34 |
+
corrected_image = ImageOps.exif_transpose(image)
|
| 35 |
+
|
| 36 |
+
return corrected_image
|
| 37 |
+
|
| 38 |
+
except Exception as e:
|
| 39 |
+
print(f"error: {e}")
|
| 40 |
+
try:
|
| 41 |
+
return Image.open(image_path)
|
| 42 |
+
except:
|
| 43 |
+
return None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def re_match(text):
|
| 47 |
+
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
|
| 48 |
+
matches = re.findall(pattern, text, re.DOTALL)
|
| 49 |
+
|
| 50 |
+
# pattern1 = r'<\|ref\|>.*?<\|/ref\|>\n'
|
| 51 |
+
# new_text1 = re.sub(pattern1, '', text, flags=re.DOTALL)
|
| 52 |
+
|
| 53 |
+
mathes_image = []
|
| 54 |
+
mathes_other = []
|
| 55 |
+
for a_match in matches:
|
| 56 |
+
if '<|ref|>image<|/ref|>' in a_match[0]:
|
| 57 |
+
mathes_image.append(a_match[0])
|
| 58 |
+
else:
|
| 59 |
+
mathes_other.append(a_match[0])
|
| 60 |
+
return matches, mathes_image, mathes_other
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def extract_coordinates_and_label(ref_text, image_width, image_height):
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
label_type = ref_text[1]
|
| 67 |
+
cor_list = eval(ref_text[2])
|
| 68 |
+
except Exception as e:
|
| 69 |
+
print(e)
|
| 70 |
+
return None
|
| 71 |
+
|
| 72 |
+
return (label_type, cor_list)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def draw_bounding_boxes(image, refs, ouput_path):
|
| 76 |
+
|
| 77 |
+
image_width, image_height = image.size
|
| 78 |
+
|
| 79 |
+
img_draw = image.copy()
|
| 80 |
+
draw = ImageDraw.Draw(img_draw)
|
| 81 |
+
|
| 82 |
+
overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
|
| 83 |
+
draw2 = ImageDraw.Draw(overlay)
|
| 84 |
+
|
| 85 |
+
# try:
|
| 86 |
+
# except IOError:
|
| 87 |
+
# try:
|
| 88 |
+
# font = ImageFont.truetype("DejaVuSans.ttf", 20)
|
| 89 |
+
# except IOError:
|
| 90 |
+
font = ImageFont.load_default()
|
| 91 |
+
|
| 92 |
+
img_idx = 0
|
| 93 |
+
|
| 94 |
+
for i, ref in enumerate(refs):
|
| 95 |
+
try:
|
| 96 |
+
result = extract_coordinates_and_label(ref, image_width, image_height)
|
| 97 |
+
if result:
|
| 98 |
+
label_type, points_list = result
|
| 99 |
+
|
| 100 |
+
color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
|
| 101 |
+
|
| 102 |
+
color_a = color + (20, )
|
| 103 |
+
for points in points_list:
|
| 104 |
+
x1, y1, x2, y2 = points
|
| 105 |
+
|
| 106 |
+
x1 = int(x1 / 999 * image_width)
|
| 107 |
+
y1 = int(y1 / 999 * image_height)
|
| 108 |
+
|
| 109 |
+
x2 = int(x2 / 999 * image_width)
|
| 110 |
+
y2 = int(y2 / 999 * image_height)
|
| 111 |
+
|
| 112 |
+
if label_type == 'image':
|
| 113 |
+
try:
|
| 114 |
+
cropped = image.crop((x1, y1, x2, y2))
|
| 115 |
+
cropped.save(f"{ouput_path}/images/{img_idx}.jpg")
|
| 116 |
+
except Exception as e:
|
| 117 |
+
print(e)
|
| 118 |
+
pass
|
| 119 |
+
img_idx += 1
|
| 120 |
+
|
| 121 |
+
try:
|
| 122 |
+
if label_type == 'title':
|
| 123 |
+
draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
|
| 124 |
+
draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
|
| 125 |
+
else:
|
| 126 |
+
draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
|
| 127 |
+
draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
|
| 128 |
+
text_x = x1
|
| 129 |
+
text_y = max(0, y1 - 15)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
text_bbox = draw.textbbox((0, 0), label_type, font=font)
|
| 133 |
+
text_width = text_bbox[2] - text_bbox[0]
|
| 134 |
+
text_height = text_bbox[3] - text_bbox[1]
|
| 135 |
+
draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
|
| 136 |
+
fill=(255, 255, 255, 30))
|
| 137 |
+
|
| 138 |
+
draw.text((text_x, text_y), label_type, font=font, fill=color)
|
| 139 |
+
except:
|
| 140 |
+
pass
|
| 141 |
+
except:
|
| 142 |
+
continue
|
| 143 |
+
img_draw.paste(overlay, (0, 0), overlay)
|
| 144 |
+
return img_draw
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def process_image_with_refs(image, ref_texts, output_path):
|
| 148 |
+
|
| 149 |
+
result_image = draw_bounding_boxes(image, ref_texts, output_path)
|
| 150 |
+
|
| 151 |
+
return result_image
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
|
| 158 |
+
best_ratio_diff = float('inf')
|
| 159 |
+
best_ratio = (1, 1)
|
| 160 |
+
area = width * height
|
| 161 |
+
for ratio in target_ratios:
|
| 162 |
+
target_aspect_ratio = ratio[0] / ratio[1]
|
| 163 |
+
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
| 164 |
+
if ratio_diff < best_ratio_diff:
|
| 165 |
+
best_ratio_diff = ratio_diff
|
| 166 |
+
best_ratio = ratio
|
| 167 |
+
elif ratio_diff == best_ratio_diff:
|
| 168 |
+
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
| 169 |
+
best_ratio = ratio
|
| 170 |
+
# print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
|
| 171 |
+
return best_ratio
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def dynamic_preprocess(image, min_num=2, max_num=9, image_size=640, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images, target_aspect_ratio

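# --- Illustrative usage sketch (added for clarity; not part of the original file) ---
# Shows how the tiling above behaves for a hypothetical 1600x800 page: with image_size=640 and
# 2 <= i*j <= 9 candidate grids, the grid closest to the 2.0 aspect ratio is (2, 1), so the page
# is resized to 1280x640 and split into two 640x640 crops. The blank image is only a stand-in.
def _example_dynamic_preprocess():
    from PIL import Image
    page = Image.new('RGB', (1600, 800))  # hypothetical page-sized placeholder
    tiles, grid = dynamic_preprocess(page, image_size=640)
    assert grid == (2, 1) and len(tiles) == 2
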
def normalize_transform(mean, std):
    if mean is None and std is None:
        transform = None
    elif mean is None and std is not None:
        mean = [0.] * len(std)
        transform = transforms.Normalize(mean=mean, std=std)
    elif mean is not None and std is None:
        std = [1.] * len(mean)
        transform = transforms.Normalize(mean=mean, std=std)
    else:
        transform = transforms.Normalize(mean=mean, std=std)

    return transform

def format_messages(
    conversations: List[Dict[str, str]],
    sft_format: str = "deepseek",
    system_prompt: str = "",
):
    """
    Applies the SFT template to a conversation.

    Args:
        conversations (List[Dict]): A list of messages.
        sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
        system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".

    Returns:
        sft_prompt (str): The formatted text.
    """

    conv = get_conv_template(sft_format)
    conv.set_system_message(system_prompt)
    for message in conversations:
        conv.append_message(message["role"], message["content"].strip())
    sft_prompt = conv.get_prompt().strip()

    return sft_prompt

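# --- Illustrative usage sketch (added for clarity; not part of the original file) ---
# A minimal two-turn conversation run through the 'plain' template, mirroring how infer()
# below builds its prompt. The exact rendered string depends on the template defined in
# conversation.py, so no particular output is asserted here.
def _example_format_messages():
    conversation = [
        {"role": "<|User|>", "content": "<image>\nFree OCR. "},
        {"role": "<|Assistant|>", "content": ""},
    ]
    return format_messages(conversations=conversation, sft_format='plain', system_prompt='')
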
def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False):
    t = tokenizer.encode(text, add_special_tokens=False)
    bos_id = 0
    eos_id = 1
    if bos:
        t = [bos_id] + t
    if eos:
        t = t + [eos_id]

    return t

def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
    """

    Args:
        conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is:
            [
                {
                    "role": "User",
                    "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.",
                    "images": ["./examples/table_datasets.png"]
                },
                {"role": "Assistant", "content": ""},
            ]

    Returns:
        pil_images (List[PIL.Image.Image]): the list of PIL images.

    """

    pil_images = []

    for message in conversations:
        if "images" not in message:
            continue

        for image_path in message["images"]:
            pil_img = load_image(image_path)
            pil_img = pil_img.convert("RGB")
            pil_images.append(pil_img)

    return pil_images

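# --- Illustrative usage sketch (added for clarity; not part of the original file) ---
# load_pil_images only looks at the optional "images" field of each message; the path below
# is a placeholder and must point at a real image file for the call to succeed.
def _example_load_pil_images():
    conversation = [
        {"role": "<|User|>", "content": "<image>\nFree OCR. ", "images": ["./examples/page.png"]},
        {"role": "<|Assistant|>", "content": ""},
    ]
    return load_pil_images(conversation)  # -> [<PIL.Image.Image ...>]
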
class BaseTransform(ABC):

    def set_rng(self, *args, **kwargs):
        pass

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        pass

    @property
    def default_shape(self):
        raise NotImplementedError

class BasicImageTransform(BaseTransform):
    def __init__(
        self,
        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        normalize: bool = True
    ):
        self.mean = mean
        self.std = std

        transform_pipelines = [
            transforms.ToTensor()
        ]

        normalize = normalize_transform(mean, std) if normalize else nn.Identity()
        if normalize is not None:
            transform_pipelines.append(normalize)

        self.transform = transforms.Compose(transform_pipelines)

    def __call__(self, x):
        x = self.transform(x)
        return x

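# --- Illustrative usage sketch (added for clarity; not part of the original file) ---
# With mean=std=(0.5, 0.5, 0.5), ToTensor() maps pixels to [0, 1] and Normalize then maps them
# to [-1, 1]; a 640x640 RGB crop becomes a (3, 640, 640) float tensor.
def _example_basic_image_transform():
    from PIL import Image
    transform = BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)
    tensor = transform(Image.new('RGB', (640, 640)))  # blank stand-in for a crop
    assert tuple(tensor.shape) == (3, 640, 640)
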
class NoEOSTextStreamer(TextStreamer):
    def on_finalized_text(self, text: str, stream_end: bool = False):
        eos_text = self.tokenizer.decode([self.tokenizer.eos_token_id], skip_special_tokens=False)
        text = text.replace(eos_text, "\n")
        print(text, flush=True, end="")

def decoder_layer_init(self, config: DeepseekV2Config, layer_idx: int):
    nn.Module.__init__(self)
    self.hidden_size = config.hidden_size

    if config.use_mla:
        self.self_attn = DeepseekV2Attention(config=config, layer_idx=layer_idx)
    else:
        config.head_dim = config.hidden_size // config.num_attention_heads
        self.self_attn = LlamaAttention(config, layer_idx)
    self.mlp = DeepseekV2MoE(config) if layer_idx >= config.first_k_dense_replace else DeepseekV2MLP(config)

    self.input_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.post_attention_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)


DeepseekV2DecoderLayer.__init__ = decoder_layer_init

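# Note (added for clarity): the assignment above monkey-patches DeepseekV2DecoderLayer so that
# every decoder layer built from here on picks its attention module from config.use_mla
# (DeepseekV2Attention vs. plain LlamaAttention) and uses a MoE MLP from layer
# config.first_k_dense_replace onward, without modifying the original class definition.
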
class DeepseekOCRConfig(DeepseekV2Config):
    model_type = "DeepseekOCR"

class DeepseekOCRModel(DeepseekV2Model):
    config_class = DeepseekOCRConfig

    def __init__(self, config: DeepseekV2Config):
        super(DeepseekOCRModel, self).__init__(config)

        self.sam_model = build_sam_vit_b()
        self.vision_model = build_clip_l()
        n_embed = 1280
        self.projector = MlpProjector(Dict(projector_type="linear", input_dim=2048, n_embed=n_embed))
        embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
        self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
        self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)

        self.rotary_emb = LlamaRotaryEmbedding(config=config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        sam_model = getattr(self, 'sam_model', None)
        vision_model = getattr(self, 'vision_model', None)

        if sam_model is not None and (input_ids.shape[1] != 1 or self.training) and torch.sum(images[0][1]).item() != 0:

            idx = 0

            for image, crop_shape in zip(images, images_spatial_crop):
                images_in_this_batch = []

                patches = image[0]
                image_ori = image[1]

                with torch.no_grad():

                    if torch.sum(patches).item() != 0:
                        crop_flag = 1
                        local_features_1 = sam_model(patches)

                        local_features_2 = vision_model(patches, local_features_1)
                        local_features = torch.cat((local_features_2[:, 1:], local_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                        local_features = self.projector(local_features)

                        global_features_1 = sam_model(image_ori)
                        global_features_2 = vision_model(image_ori, global_features_1)
                        global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                        global_features = self.projector(global_features)

                        print('=====================')
                        print('BASE: ', global_features.shape)
                        print('PATCHES: ', local_features.shape)
                        print('=====================')

                        _, hw, n_dim = global_features.shape
                        h = w = int(hw ** 0.5)

                        _2, hw2, n_dim2 = local_features.shape
                        h2 = w2 = int(hw2 ** 0.5)

                        width_crop_num, height_crop_num = crop_shape[0], crop_shape[1]

                        global_features = global_features.view(h, w, n_dim)

                        global_features = torch.cat(
                            [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
                        )

                        global_features = global_features.view(-1, n_dim)

                        local_features = local_features.view(height_crop_num, width_crop_num, h2, w2, n_dim2).permute(0, 2, 1, 3, 4).reshape(height_crop_num * h2, width_crop_num * w2, n_dim2)
                        local_features = torch.cat(
                            [local_features, self.image_newline[None, None, :].expand(height_crop_num * h2, 1, n_dim2)], dim=1
                        )
                        local_features = local_features.view(-1, n_dim2)

                        global_local_features = torch.cat([local_features, global_features, self.view_seperator[None, :]], dim=0)

                    else:
                        global_features_1 = sam_model(image_ori)
                        global_features_2 = vision_model(image_ori, global_features_1)
                        global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                        global_features = self.projector(global_features)
                        print('=====================')
                        print('BASE: ', global_features.shape)
                        print('NO PATCHES')
                        print('=====================')
                        _, hw, n_dim = global_features.shape
                        h = w = int(hw ** 0.5)

                        global_features = global_features.view(h, w, n_dim)

                        global_features = torch.cat(
                            [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
                        )

                        global_features = global_features.view(-1, n_dim)

                        global_local_features = torch.cat([global_features, self.view_seperator[None, :]], dim=0)

                    images_in_this_batch.append(global_local_features)

                if images_in_this_batch:
                    images_in_this_batch = torch.cat(images_in_this_batch, dim=0)

                    inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)

                idx += 1

        return super(DeepseekOCRModel, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids=position_ids,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):

    config_class = DeepseekOCRConfig
    # supports_gradient_checkpointing = True

    def __init__(self, config):
        super(DeepseekV2ForCausalLM, self).__init__(config)
        self.model = DeepseekOCRModel(config)

        self.vocab_size = config.vocab_size

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            images=images,
            images_seq_mask=images_seq_mask,
            images_spatial_crop=images_spatial_crop,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        # Omit tokens covered by past_key_values
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.get_seq_length()
                max_cache_length = None
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            #     input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            #     input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
        # same goes for position ids. Could also help with continued generation.
        cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
                "images_seq_mask": kwargs.get("images_seq_mask", None),
                "images_spatial_crop": kwargs.get("images_spatial_crop", None),
            }
        )
        return model_inputs

    def disable_torch_init(self):
        """
        Disable the redundant torch default initialization to accelerate model creation.
        """
        import torch
        setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

    def infer(self, tokenizer, prompt='', image_file='', output_path='', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
        self.disable_torch_init()

        os.makedirs(output_path, exist_ok=True)
        os.makedirs(f'{output_path}/images', exist_ok=True)

        if prompt and image_file:
            conversation = [
                {
                    "role": "<|User|>",
                    # "content": "<image>\n<|grounding|>Given the layout of the image. ",
                    "content": f'{prompt}',
                    # "content": "<image>\nFree OCR. ",
                    # "content": "<image>\nParse the figure. ",
                    # "content": "<image>\nExtract the text in the image. ",
                    "images": [f'{image_file}'],
                },
                {"role": "<|Assistant|>", "content": ""},
            ]

        elif prompt:
            conversation = [
                {
                    "role": "<|User|>",
                    "content": f'{prompt}',
                },
                {"role": "<|Assistant|>", "content": ""},
            ]
        else:
            assert False, 'prompt is none!'

        prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='')

        patch_size = 16
        downsample_ratio = 4
        images = load_pil_images(conversation)

        valid_img_tokens = 0
        ratio = 1

        image_draw = images[0].copy()

        w, h = image_draw.size
        ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))

        image_transform = BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)
        images_seq_mask = []

        image_token = '<image>'
        image_token_id = 128815
        text_splits = prompt.split(image_token)

        images_list, images_crop_list, images_seq_mask = [], [], []
        tokenized_str = []
        images_spatial_crop = []
        for text_sep, image in zip(text_splits, images):

            tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
            tokenized_str += tokenized_sep
            images_seq_mask += [False] * len(tokenized_sep)

            if crop_mode:

                if image.size[0] <= 640 and image.size[1] <= 640:
                    crop_ratio = [1, 1]

                else:
                    if crop_mode:
                        images_crop_raw, crop_ratio = dynamic_preprocess(image)
                    else:
                        crop_ratio = [1, 1]

                """process the global view"""
                global_view = ImageOps.pad(image, (base_size, base_size),
                                           color=tuple(int(x * 255) for x in image_transform.mean))

                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)

                images_list.append(image_transform(global_view).to(torch.bfloat16))

                width_crop_num, height_crop_num = crop_ratio

                images_spatial_crop.append([width_crop_num, height_crop_num])

                if width_crop_num > 1 or height_crop_num > 1:
                    """process the local views"""
                    for i in range(len(images_crop_raw)):
                        images_crop_list.append(image_transform(images_crop_raw[i]).to(torch.bfloat16))

                if image_size == 640:
                    valid_img_tokens += len(images_crop_list) * 100

                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
                num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)

                """add image tokens"""
                tokenized_image = ([image_token_id] * num_queries_base + [image_token_id]) * num_queries_base
                tokenized_image += [image_token_id]
                if width_crop_num > 1 or height_crop_num > 1:
                    tokenized_image += ([image_token_id] * (num_queries * width_crop_num) + [image_token_id]) * (
                        num_queries * height_crop_num)
                tokenized_str += tokenized_image
                images_seq_mask += [True] * len(tokenized_image)

            else:
                """process the global view"""
                if image_size <= 640:
                    print('directly resize')
                    image = image.resize((image_size, image_size))
                global_view = ImageOps.pad(image, (image_size, image_size),
                                           color=tuple(int(x * 255) for x in image_transform.mean))
                images_list.append(image_transform(global_view).to(torch.bfloat16))

                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)
                elif base_size == 640:
                    valid_img_tokens += int(100 * 1)
                elif base_size == 512:
                    valid_img_tokens += int(64 * 1)

                width_crop_num, height_crop_num = 1, 1

                images_spatial_crop.append([width_crop_num, height_crop_num])

                """add image tokens"""
                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)

                tokenized_image = ([image_token_id] * num_queries + [image_token_id]) * num_queries
                tokenized_image += [image_token_id]
                tokenized_str += tokenized_image
                images_seq_mask += [True] * len(tokenized_image)

        """process the last text split"""
        tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
        tokenized_str += tokenized_sep
        images_seq_mask += [False] * len(tokenized_sep)

        """add the bos tokens"""
        bos_id = 0
        tokenized_str = [bos_id] + tokenized_str
        images_seq_mask = [False] + images_seq_mask

        input_ids = torch.LongTensor(tokenized_str)

        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)

        if len(images_list) == 0:
            images_ori = torch.zeros((1, 3, image_size, image_size))
            images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
            images_crop = torch.zeros((1, 3, base_size, base_size))

        else:
            images_ori = torch.stack(images_list, dim=0)
            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
            if images_crop_list:
                images_crop = torch.stack(images_crop_list, dim=0)
            else:
                images_crop = torch.zeros((1, 3, base_size, base_size))

        if not eval_mode:
            streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
            with torch.autocast("cuda", dtype=torch.bfloat16):
                with torch.no_grad():
                    output_ids = self.generate(
                        input_ids.unsqueeze(0).cuda(),
                        images=[(images_crop.cuda(), images_ori.cuda())],
                        images_seq_mask=images_seq_mask.unsqueeze(0).cuda(),
                        images_spatial_crop=images_spatial_crop,
                        temperature=0.0,
                        eos_token_id=tokenizer.eos_token_id,
                        streamer=streamer,
                        max_new_tokens=8192,
                        no_repeat_ngram_size=20,
                        use_cache=True
                    )

        else:
            with torch.autocast("cuda", dtype=torch.bfloat16):
                with torch.no_grad():
                    output_ids = self.generate(
                        input_ids.unsqueeze(0).cuda(),
                        images=[(images_crop.cuda(), images_ori.cuda())],
                        images_seq_mask=images_seq_mask.unsqueeze(0).cuda(),
                        images_spatial_crop=images_spatial_crop,
                        temperature=0.0,
                        eos_token_id=tokenizer.eos_token_id,
                        max_new_tokens=8192,
                        no_repeat_ngram_size=35,
                        use_cache=True
                    )

        if '<image>' in conversation[0]['content'] and eval_mode:
            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
            stop_str = '<|end▁of▁sentence|>'
            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()

            return outputs

        if '<image>' in conversation[0]['content'] and test_compress:
            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
            pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
            print('=' * 50)
            print('image size: ', (w, h))
            print('valid image tokens: ', int(valid_img_tokens))
            print('output texts tokens (valid): ', pure_texts_outputs_token_length)
            print('compression ratio: ', round(pure_texts_outputs_token_length / valid_img_tokens, 2))
            print('=' * 50)

        if '<image>' in conversation[0]['content'] and save_results:
            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
            stop_str = '<|end▁of▁sentence|>'

            print('=' * 15 + 'save results:' + '=' * 15)

            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()

            matches_ref, matches_images, mathes_other = re_match(outputs)
            result = process_image_with_refs(image_draw, matches_ref, output_path)

            # replace each grounded image reference with a markdown link to the saved crop
            for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
                outputs = outputs.replace(a_match_image, f'![](images/{idx}.jpg)\n')

            for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
                outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')

            # if 'structural formula' in conversation[0]['content']:
            #     outputs = '<smiles>' + outputs + '</smiles>'
            with open(f'{output_path}/result.mmd', 'w', encoding='utf-8') as afile:
                afile.write(outputs)

            if 'line_type' in outputs:
                import matplotlib.pyplot as plt
                lines = eval(outputs)['Line']['line']

                line_type = eval(outputs)['Line']['line_type']

                endpoints = eval(outputs)['Line']['line_endpoint']

                fig, ax = plt.subplots(figsize=(3, 3), dpi=200)
                ax.set_xlim(-15, 15)
                ax.set_ylim(-15, 15)

                for idx, line in enumerate(lines):
                    try:
                        p0 = eval(line.split(' -- ')[0])
                        p1 = eval(line.split(' -- ')[-1])

                        if line_type[idx] == '--':
                            ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')
                        else:
                            ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')

                        ax.scatter(p0[0], p0[1], s=5, color='k')
                        ax.scatter(p1[0], p1[1], s=5, color='k')
                    except:
                        pass

                for endpoint in endpoints:

                    label = endpoint.split(': ')[0]
                    (x, y) = eval(endpoint.split(': ')[1])
                    ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
                                fontsize=5, fontweight='light')

                plt.savefig(f'{output_path}/geo.jpg')
                plt.close()

            result.save(f"{output_path}/result_with_boxes.jpg")
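

# --- Illustrative usage sketch (added for clarity; not part of the original file) ---
# A minimal end-to-end call path under stated assumptions: the repo id, attention settings,
# prompt string, and file paths below are placeholders, and a CUDA device is required because
# infer() moves tensors to the GPU.
def _example_infer():
    import torch
    from transformers import AutoModel, AutoTokenizer

    model_name = 'deepseek-ai/DeepSeek-OCR'  # assumed repo id; substitute your local path if needed
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True,
                                      torch_dtype=torch.bfloat16, use_safetensors=True)
    model = model.eval().cuda()

    prompt = "<image>\n<|grounding|>Convert the document to markdown. "  # one of the prompt styles noted above
    return model.infer(tokenizer, prompt=prompt, image_file='your_image.jpg',
                       output_path='outputs/', base_size=1024, image_size=640,
                       crop_mode=True, save_results=True, test_compress=True)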
processor_config.json
ADDED
@@ -0,0 +1,28 @@
{
  "add_special_token": false,
  "candidate_resolutions": [
    [
      1024,
      1024
    ]
  ],
  "downsample_ratio": 4,
  "ignore_id": -100,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "image_token": "<image>",
  "mask_prompt": false,
  "normalize": true,
  "pad_token": "<\uff5c\u2581pad\u2581\uff5c>",
  "patch_size": 16,
  "processor_class": "DeepseekVLV2Processor",
  "sft_format": "deepseek"
}
special_tokens_map.json
ADDED
@@ -0,0 +1,39 @@
{
  "additional_special_tokens": [
    {
      "content": "<|User|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false
    },
    {
      "content": "<|Assistant|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false
    }
  ],
  "bos_token": {
    "content": "<|begin▁of▁sentence|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "<|end▁of▁sentence|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<|▁pad▁|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff

tokenizer_config.json
ADDED
The diff for this file is too large to render. See raw diff