PlengRKO committed
Commit 81f16d0 · verified · 1 Parent(s): 6913d10

Set up model

app.py ADDED
@@ -0,0 +1,157 @@
+ import gradio as gr
+ import torch
+ import os
+ import requests
+ from io import BytesIO
+ from PIL import Image, ImageDraw, ImageFont
+ from transformers import AutoModel, AutoProcessor, CLIPModel
+ from thai2transformers.preprocess import process_transformers
+ from qdrant_client import QdrantClient
+ from qdrant_client.http import models
+ from qdrant_client.http.models import Filter, FieldCondition, MatchValue
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load models
+ image_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
+ image_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+
+ text_processor = AutoProcessor.from_pretrained("openthaigpt/CLIPTextCamembertModelWithProjection", trust_remote_code=True)
+ text_model = AutoModel.from_pretrained("openthaigpt/CLIPTextCamembertModelWithProjection", trust_remote_code=True).to(device)
+
+ # Qdrant setup
+ url = os.environ.get("QDRANT_URL")
+ api_key = os.environ.get("QDRANT_API_KEY")
+ qdrant_client = QdrantClient(url=url, api_key=api_key)
+
+
+ def generate_error_image(message="No image"):
+     """Return a plain 256x256 placeholder image carrying a short error message."""
+     img = Image.new("RGB", (256, 256), color=(240, 240, 240))
+     draw = ImageDraw.Draw(img)
+     try:
+         font = ImageFont.truetype("arial.ttf", 18)
+     except Exception:  # fall back if arial.ttf is not available on the host
+         font = ImageFont.load_default()
+     draw.text((10, 120), message[:30], fill="red", font=font)
+     return img
+
+
+ def get_image_embedding(image: Image.Image):
+     """Encode an image with CLIP ViT-B/32 and return a normalized 512-dim vector."""
+     inputs = image_processor(images=image, return_tensors="pt")
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+     with torch.no_grad():
+         image_embeddings = image_model.get_image_features(**inputs)
+         image_embeddings /= image_embeddings.norm(dim=1, keepdim=True)
+     return image_embeddings[0].cpu().numpy().tolist()
+
+
+ def get_text_embedding(text: str):
+     """Encode Thai text with the Camembert CLIP text model and return a normalized vector."""
+     try:
+         processed = process_transformers(text)
+     except Exception as e:
+         # Returned in gallery format so the caller can surface the error directly.
+         return [(None, f"Preprocessing error: {str(e)}")]
+     inputs = text_processor(text=processed, return_tensors="pt", padding=True)
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+     with torch.no_grad():
+         text_embeddings = text_model(**inputs).text_embeds
+         text_embeddings /= text_embeddings.norm(dim=1, keepdim=True)
+     return text_embeddings[0].cpu().numpy().tolist()
+
+
+ def retrieve_from_qdrant(query_vector, modality, limit=10):
+     """Query the collection, keeping only points whose payload matches the requested modality."""
+     return qdrant_client.query_points(
+         collection_name="thai2transformers_clip",
+         query=query_vector,
+         with_payload=True,
+         query_filter=Filter(
+             must=[
+                 FieldCondition(
+                     key="modality",
+                     match=MatchValue(value=modality)
+                 )
+             ]
+         ),
+         limit=limit
+     ).points
+
+
+ def load_image_from_url(url):
+     try:
+         # Fetch the image from the URL
+         request = requests.get(url, timeout=10)
+         request.raise_for_status()  # Raise an error for bad responses
+         return Image.open(BytesIO(request.content)).convert("RGB")
+     except Exception as e:
+         print(f"Error fetching image from {url}: {e}")
+         return None
+
+
+ def multimodal_query(text_input, text_target, image_input, image_url, mode):
+     if mode == "image-to-image":
+         image = image_input or (load_image_from_url(image_url) if image_url else None)
+         if image is None:
+             return [(Image.new("RGB", (256, 256), color="white"), "❌ No image provided.")]
+         query_vector = get_image_embedding(image)
+         modality = "image"
+
+     elif mode == "image-to-text":
+         image = image_input or (load_image_from_url(image_url) if image_url else None)
+         if image is None:
+             return [(Image.new("RGB", (256, 256), color="white"), "❌ No image provided.")]
+         query_vector = get_image_embedding(image)
+         modality = "text"
+
+     elif mode == "text-to-image":
+         if not text_input:
+             return [(None, "No text provided.")]
+         query_vector = get_text_embedding(text_input)
+         if query_vector and isinstance(query_vector[0], tuple):
+             return query_vector  # preprocessing failed; error is already gallery-formatted
+         modality = "image"
+
+     elif mode == "text-to-text":
+         if not text_input:
+             return [(None, "No text provided.")]
+         query_vector = get_text_embedding(text_input)
+         if query_vector and isinstance(query_vector[0], tuple):
+             return query_vector  # preprocessing failed; error is already gallery-formatted
+         modality = "text"
+
+     else:
+         return [(None, "Invalid mode selected.")]
+
+     results = retrieve_from_qdrant(query_vector, modality)
+     outputs = []
+
+     for res in results:
+         try:
+             img_url = res.payload.get("image_url")
+             caption = res.payload.get("name", "")
+             if img_url:
+                 img = Image.open(BytesIO(requests.get(img_url, timeout=10).content)).resize((256, 256))
+             else:
+                 img = generate_error_image("No image URL")
+             outputs.append((img, caption[:40]))
+         except Exception as e:
+             fallback_img = generate_error_image("Error loading image")
+             outputs.append((fallback_img, f"Error: {str(e)[:30]}"))
+
+     return outputs
+
+
+ # Gradio UI (labels are in Thai)
+ with gr.Blocks(title="🔄 Multimodal Query System") as demo:
+     # Heading: "Search with images and text (CLIP + Qdrant): supports image-to-image, image-to-text, text-to-image and text-to-text"
+     gr.Markdown("## 🔎 ค้นหาด้วยรูปภาพและข้อความ (CLIP + Qdrant)\nรองรับทั้ง image-to-image, image-to-text, text-to-image และ text-to-text")
+
+     with gr.Row():
+         with gr.Column():
+             mode = gr.Dropdown(label="โหมดการค้นหา", choices=["image-to-image", "image-to-text", "text-to-image", "text-to-text"], value="image-to-image")  # "search mode"
+             image_input = gr.Image(type="pil", label="อัปโหลดภาพ")  # "upload an image"
+             image_url = gr.Textbox(label="หรือใส่ URL ของภาพ")  # "or enter an image URL"
+             text_input = gr.Textbox(label="ข้อความที่ใช้ค้นหา (text input)")  # "query text"
+             text_target = gr.Textbox(label="เป้าหมาย (ใช้ในบางกรณี)", visible=False)  # "target (used in some cases)"
+             search_btn = gr.Button("🔍 ค้นหา")  # "search"
+
+         with gr.Column():
+             gallery = gr.Gallery(label="ผลลัพธ์", columns=5, height=600)  # "results"
+
+     search_btn.click(
+         fn=multimodal_query,
+         inputs=[text_input, text_target, image_input, image_url, mode],
+         outputs=gallery
+     )
+
+ if __name__ == "__main__":
+     demo.launch()  # required so the Space actually serves the interface
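
Note: app.py assumes a Qdrant collection named thai2transformers_clip already exists, with 512-dimensional vectors and payload fields modality, image_url and name. Below is a minimal indexing sketch under those assumptions; the point data is hypothetical, and the vectors would come from the get_image_embedding / get_text_embedding helpers in app.py.

    # build_index.py: a minimal sketch, not part of this commit
    import os
    from qdrant_client import QdrantClient
    from qdrant_client.http import models

    client = QdrantClient(url=os.environ["QDRANT_URL"], api_key=os.environ["QDRANT_API_KEY"])

    # Collection that app.py queries: 512-dim vectors, cosine distance assumed
    client.recreate_collection(
        collection_name="thai2transformers_clip",
        vectors_config=models.VectorParams(size=512, distance=models.Distance.COSINE),
    )

    # Hypothetical point; the vector should come from get_image_embedding()/get_text_embedding()
    client.upsert(
        collection_name="thai2transformers_clip",
        points=[
            models.PointStruct(
                id=1,
                vector=[0.0] * 512,
                payload={"modality": "image", "image_url": "https://example.com/img.jpg", "name": "example item"},
            )
        ],
    )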
configuration_clip_camembert.py ADDED
@@ -0,0 +1,51 @@
+ from transformers import CamembertConfig
+
+
+ class CLIPTextCamembertConfig(CamembertConfig):
+     # ref: https://huggingface.co/airesearch/wangchanberta-base-att-spm-uncased/blob/main/config.json
+     model_type = "clip_text_camembert"
+
+     def __init__(
+         self,
+         vocab_size=25005,
+         hidden_size=768,
+         intermediate_size=3072,
+         projection_dim=512,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         max_position_embeddings=512,
+         hidden_act="gelu",
+         layer_norm_eps=1e-12,
+         attention_dropout=0.1,
+         initializer_range=0.02,
+         initializer_factor=1.0,
+         pad_token_id=1,
+         bos_token_id=0,
+         eos_token_id=2,
+         type_vocab_size=1,
+         **kwargs,
+     ):
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             **kwargs,
+         )
+
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.projection_dim = projection_dim
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.max_position_embeddings = max_position_embeddings
+         self.layer_norm_eps = layer_norm_eps
+         self.hidden_act = hidden_act
+         self.initializer_range = initializer_range
+         self.initializer_factor = initializer_factor
+         self.attention_dropout = attention_dropout
+         self.type_vocab_size = type_vocab_size
+         self.auto_map = {
+             "AutoConfig": "configuration_clip_camembert.CLIPTextCamembertConfig",
+             "AutoModel": "modeling_clip_camembert.CLIPTextCamembertModelWithProjection",
+         }
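
Note: the defaults mirror the WangchanBERTa encoder (vocab_size 25005, hidden_size 768) with a 512-dim CLIP-style projection, and the auto_map block is what lets the trust_remote_code=True calls in app.py resolve these custom classes. A quick sanity-check sketch, assuming access to the hub repo that app.py loads:

    # a minimal sketch; assumes network access to the openthaigpt repo referenced in app.py
    from transformers import AutoModel, AutoProcessor

    repo = "openthaigpt/CLIPTextCamembertModelWithProjection"
    processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo, trust_remote_code=True)

    print(type(model).__name__)         # expected: CLIPTextCamembertModelWithProjection
    print(model.config.projection_dim)  # expected: 512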
modeling_clip_camembert.py ADDED
@@ -0,0 +1,71 @@
+ from .configuration_clip_camembert import CLIPTextCamembertConfig
+ from transformers import (
+     CamembertModel,
+     CLIPTextModelWithProjection,
+ )
+ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
+ import torch
+ from torch import nn
+ from typing import Any, Optional, Tuple, Union
+
+
+ class CLIPTextCamembertModelWithProjection(CLIPTextModelWithProjection):
+     config_class = CLIPTextCamembertConfig
+
+     def __init__(self, config: CLIPTextCamembertConfig):
+         super().__init__(config)
+
+         # Replace the CLIP text tower with a Camembert (WangchanBERTa-style) encoder
+         self.text_model = CamembertModel(config)
+
+         self.text_projection = nn.Linear(
+             config.hidden_size, config.projection_dim, bias=False
+         )
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, CLIPTextModelOutput]:
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         text_outputs = self.text_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         # Pooled representation from the Camembert encoder
+         pooled_output = text_outputs[1]
+
+         text_embeds = self.text_projection(pooled_output)
+
+         if not return_dict:
+             outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
+             return tuple(output for output in outputs if output is not None)
+
+         return CLIPTextModelOutput(
+             text_embeds=text_embeds,
+             last_hidden_state=text_outputs.last_hidden_state,
+             hidden_states=text_outputs.hidden_states,
+             attentions=text_outputs.attentions,
+         )
+
+     def converter_weight(
+         self, path_model="airesearch/wangchanberta-base-att-spm-uncased"
+     ):
+         r"""
+         Convert and load encoder weights from airesearch/wangchanberta-base-att-spm-uncased.
+         """
+         pretrained_state_dict = CamembertModel.from_pretrained(path_model).state_dict()
+         # Load the new state dictionary into the custom model
+         self.text_model.load_state_dict(pretrained_state_dict)
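
Note: converter_weight copies a pretrained WangchanBERTa encoder into text_model, leaving text_projection at its random initialization (so the projection head would still need training). A local bootstrap sketch, assuming the two modules are packaged together (the relative import on the first line requires this) and using a hypothetical package name and token ids:

    # a minimal sketch, assuming the layout clip_camembert/{configuration,modeling}_clip_camembert.py
    import torch
    from clip_camembert.configuration_clip_camembert import CLIPTextCamembertConfig
    from clip_camembert.modeling_clip_camembert import CLIPTextCamembertModelWithProjection

    config = CLIPTextCamembertConfig()
    model = CLIPTextCamembertModelWithProjection(config)

    # Copy pretrained encoder weights into the Camembert text tower
    model.converter_weight("airesearch/wangchanberta-base-att-spm-uncased")

    with torch.no_grad():
        dummy_ids = torch.tensor([[0, 5, 10, 2]])  # hypothetical token ids (bos ... eos)
        out = model(input_ids=dummy_ids)
    print(out.text_embeds.shape)                   # expected: torch.Size([1, 512])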
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ torch
+ torchvision
+ torchaudio
+ thai2transformers==0.1.2
+ pythainlp
+ transformers
+ pillow
+ qdrant-client
+ requests
+ gradio
+ numpy
+ matplotlib
+ ftfy