# thai-clip-pin / app.py
import gradio as gr
import torch
import os
import requests
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModel, AutoProcessor, CLIPModel
from thai2transformers.preprocess import process_transformers
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import Filter, FieldCondition, MatchValue
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load models: OpenAI CLIP ViT-B/32 encodes images; OpenThaiGPT's CamemBERT-based
# CLIP text encoder (with projection head) embeds Thai text into the shared space.
image_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
image_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
text_processor = AutoProcessor.from_pretrained("openthaigpt/CLIPTextCamembertModelWithProjection", trust_remote_code=True)
text_model = AutoModel.from_pretrained("openthaigpt/CLIPTextCamembertModelWithProjection", trust_remote_code=True).to(device)
# Qdrant setup
url = os.environ.get("QDRANT_URL")
api_key = os.environ.get("QDRANT_API_KEY")
qdrant_client = QdrantClient(url=url, api_key=api_key)
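
# The Qdrant credentials above are expected to come from environment variables
# (e.g. Hugging Face Space secrets); without them the client cannot reach the database.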

def generate_error_image(message="No image"):
    """Return a plain placeholder image with a short error message drawn on it."""
    img = Image.new("RGB", (256, 256), color=(240, 240, 240))
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("arial.ttf", 18)
    except OSError:
        # Fall back to PIL's built-in bitmap font when Arial is unavailable
        font = ImageFont.load_default()
    draw.text((10, 120), message[:30], fill="red", font=font)
    return img

def get_image_embedding(image: Image.Image):
    inputs = image_processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        image_embeddings = image_model.get_image_features(**inputs)
        # L2-normalise so cosine similarity reduces to a dot product
        image_embeddings /= image_embeddings.norm(dim=1, keepdim=True)
    return image_embeddings[0].cpu().numpy().tolist()

def get_text_embedding(text: str):
    # Apply thai2transformers text normalisation before tokenising
    try:
        processed = process_transformers(text)
    except Exception as e:
        # Signal failure to the caller, which turns it into an error tile
        print(f"Preprocessing error: {e}")
        return None
    inputs = text_processor(text=processed, return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        text_embeddings = text_model(**inputs).text_embeds
        text_embeddings /= text_embeddings.norm(dim=1, keepdim=True)
    return text_embeddings[0].cpu().numpy().tolist()

def retrieve_from_qdrant(query_vector, modality, limit=10):
    # Nearest-neighbour search, restricted to points whose payload "modality"
    # matches the requested result type ("image" or "text")
    return qdrant_client.query_points(
        collection_name="thai2transformers_clip",
        query=query_vector,
        with_payload=True,
        query_filter=Filter(
            must=[
                FieldCondition(
                    key="modality",
                    match=MatchValue(value=modality)
                )
            ]
        ),
        limit=limit
    ).points
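
# Assumed collection layout (not created in this script): "thai2transformers_clip"
# holds one 512-dimensional CLIP embedding per point, for both images and Thai
# captions, with payload fields "modality" ("image" or "text"), "image_url", and
# "name" that the retrieval and display code relies on.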

def load_image_from_url(url):
    try:
        # Fetch image from URL
        request = requests.get(url)
        request.raise_for_status()  # Raise an error for bad responses
        return Image.open(BytesIO(request.content)).convert("RGB")
    except Exception as e:
        print(f"Error fetching image from {url}: {e}")
        return None

def multimodal_query(text_input, text_target, image_input, image_url, mode):
    # Resolve the query vector and the target modality for the selected mode
    if mode == "image-to-image":
        image = image_input or (load_image_from_url(image_url) if image_url else None)
        if image is None:
            return [(generate_error_image("No image"), "❌ No image provided.")]
        query_vector = get_image_embedding(image)
        modality = "image"
    elif mode == "image-to-text":
        image = image_input or (load_image_from_url(image_url) if image_url else None)
        if image is None:
            return [(generate_error_image("No image"), "❌ No image provided.")]
        query_vector = get_image_embedding(image)
        modality = "text"
    elif mode == "text-to-image":
        if not text_input:
            return [(generate_error_image("No text"), "No text provided.")]
        query_vector = get_text_embedding(text_input)
        modality = "image"
    elif mode == "text-to-text":
        if not text_input:
            return [(generate_error_image("No text"), "No text provided.")]
        query_vector = get_text_embedding(text_input)
        modality = "text"
    else:
        return [(generate_error_image("Invalid mode"), "Invalid mode selected.")]

    # Text preprocessing can fail; show an error tile instead of querying Qdrant
    if query_vector is None:
        return [(generate_error_image("Embedding error"), "Could not embed the query text.")]

    results = retrieve_from_qdrant(query_vector, modality)
    outputs = []
    for res in results:
        try:
            img_url = res.payload.get("image_url")
            caption = res.payload.get("name", "")
            if img_url:
                img = Image.open(BytesIO(requests.get(img_url).content)).resize((256, 256))
            else:
                img = generate_error_image("No image URL")
            outputs.append((img, caption[:40]))
        except Exception as e:
            fallback_img = generate_error_image("Error loading image")
            outputs.append((fallback_img, f"Error: {str(e)[:30]}"))
    return outputs

# Gradio UI (labels are in Thai; English translations are given in the comments)
with gr.Blocks(title="🔄 Multimodal Query System") as demo:
    # "Search with images and text (CLIP + Qdrant) — supports image-to-image,
    # image-to-text, text-to-image and text-to-text"
    gr.Markdown("## 🔎 ค้นหาด้วยรูปภาพและข้อความ (CLIP + Qdrant)\nรองรับทั้ง image-to-image, image-to-text, text-to-image และ text-to-text")
    with gr.Row():
        with gr.Column():
            mode = gr.Dropdown(label="โหมดการค้นหา", choices=["image-to-image", "image-to-text", "text-to-image", "text-to-text"], value="image-to-image")  # Search mode
            image_input = gr.Image(type="pil", label="อัปโหลดภาพ")  # Upload an image
            image_url = gr.Textbox(label="หรือใส่ URL ของภาพ")  # ...or paste an image URL
            text_input = gr.Textbox(label="ข้อความที่ใช้ค้นหา (text input)")  # Query text
            text_target = gr.Textbox(label="เป้าหมาย (ใช้ในบางกรณี)", visible=False)  # Target text (used in some cases); hidden
            search_btn = gr.Button("🔍 ค้นหา")  # Search
        with gr.Column():
            gallery = gr.Gallery(label="ผลลัพธ์", columns=5, height=600)  # Results

    search_btn.click(
        fn=multimodal_query,
        inputs=[text_input, text_target, image_input, image_url, mode],
        outputs=gallery
    )
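
# Standard Gradio entry point, assuming this script is run directly as the
# Space's app file.
demo.launch()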