Spaces: Build error
import gradio as gr
import torch
import os
import requests
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModel, AutoProcessor, CLIPModel
from thai2transformers.preprocess import process_transformers
from qdrant_client import QdrantClient
from qdrant_client.http.models import Filter, FieldCondition, MatchValue

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load models
image_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
image_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
text_processor = AutoProcessor.from_pretrained("openthaigpt/CLIPTextCamembertModelWithProjection", trust_remote_code=True)
text_model = AutoModel.from_pretrained("openthaigpt/CLIPTextCamembertModelWithProjection", trust_remote_code=True).to(device)
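# Note: this pairs the OpenAI CLIP ViT-B/32 image encoder with a Thai CamemBERT text
# encoder; the "WithProjection" checkpoint is assumed to project Thai text into the same
# embedding space as CLIP so both modalities can be searched against one collection.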

# Qdrant setup
url = os.environ.get("QDRANT_URL")
api_key = os.environ.get("QDRANT_API_KEY")
qdrant_client = QdrantClient(url=url, api_key=api_key)
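# QDRANT_URL and QDRANT_API_KEY must be set in the environment (on Hugging Face Spaces,
# add them as repository secrets); if they are missing, the client will not point at the
# intended cluster and the queries below will fail.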


def generate_error_image(message="No image"):
    """Return a 256x256 placeholder image with a short error message drawn on it."""
    img = Image.new("RGB", (256, 256), color=(240, 240, 240))
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("arial.ttf", 18)
    except OSError:
        # Fall back to PIL's built-in font when Arial is not available (e.g. on Linux hosts).
        font = ImageFont.load_default()
    draw.text((10, 120), message[:30], fill="red", font=font)
    return img


def get_image_embedding(image: Image.Image):
    """Encode a PIL image with CLIP and return an L2-normalized vector as a plain list."""
    inputs = image_processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        image_embeddings = image_model.get_image_features(**inputs)
        image_embeddings /= image_embeddings.norm(dim=1, keepdim=True)
    return image_embeddings[0].cpu().numpy().tolist()


def get_text_embedding(text: str):
    """Encode Thai text with the CamemBERT CLIP text encoder and return a normalized vector."""
    try:
        processed = process_transformers(text)
    except Exception as e:
        # Surface preprocessing failures in the UI instead of returning a value that
        # would later be sent to Qdrant as a malformed query vector.
        raise gr.Error(f"Preprocessing error: {e}")
    inputs = text_processor(text=processed, return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        text_embeddings = text_model(**inputs).text_embeds
        text_embeddings /= text_embeddings.norm(dim=1, keepdim=True)
    return text_embeddings[0].cpu().numpy().tolist()
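# Both embedding helpers return plain Python lists so they can be passed directly as
# Qdrant query vectors. The collection's vector size must match the encoders' output
# dimension (512 for CLIP ViT-B/32; the Thai text projection is assumed to match).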


def retrieve_from_qdrant(query_vector, modality, limit=10):
    """Query the collection for the nearest points of the requested modality."""
    return qdrant_client.query_points(
        collection_name="thai2transformers_clip",
        query=query_vector,
        with_payload=True,
        query_filter=Filter(
            must=[
                FieldCondition(
                    key="modality",
                    match=MatchValue(value=modality)
                )
            ]
        ),
        limit=limit
    ).points
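# Assumed payload schema for the "thai2transformers_clip" collection (inferred from how
# results are used below, not verified against the collection itself): each point carries
# a "modality" field ("image" or "text") for filtering, plus "image_url" and "name" used
# for display in the gallery.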


def load_image_from_url(url):
    """Fetch an image over HTTP and return it as an RGB PIL image, or None on failure."""
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()  # Raise an error for bad HTTP status codes
        return Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as e:
        print(f"Error fetching image from {url}: {e}")
        return None


def multimodal_query(text_input, text_target, image_input, image_url, mode):
    if mode == "image-to-image":
        image = image_input or (load_image_from_url(image_url) if image_url else None)
        if image is None:
            return [(generate_error_image("No image provided"), "❌ No image provided.")]
        query_vector = get_image_embedding(image)
        modality = "image"
    elif mode == "image-to-text":
        image = image_input or (load_image_from_url(image_url) if image_url else None)
        if image is None:
            return [(generate_error_image("No image provided"), "❌ No image provided.")]
        query_vector = get_image_embedding(image)
        modality = "text"
    elif mode == "text-to-image":
        if not text_input:
            return [(generate_error_image("No text provided"), "No text provided.")]
        query_vector = get_text_embedding(text_input)
        modality = "image"
    elif mode == "text-to-text":
        if not text_input:
            return [(generate_error_image("No text provided"), "No text provided.")]
        query_vector = get_text_embedding(text_input)
        modality = "text"
    else:
        return [(generate_error_image("Invalid mode"), "Invalid mode selected.")]

    results = retrieve_from_qdrant(query_vector, modality)
    outputs = []
    for res in results:
        try:
            img_url = res.payload.get("image_url")
            caption = res.payload.get("name", "")
            if img_url:
                img = Image.open(BytesIO(requests.get(img_url, timeout=10).content)).resize((256, 256))
            else:
                img = generate_error_image("No image URL")
            outputs.append((img, caption[:40]))
        except Exception as e:
            fallback_img = generate_error_image("Error loading image")
            outputs.append((fallback_img, f"Error: {str(e)[:30]}"))
    return outputs


# Gradio UI (labels are in Thai; English translations given in comments)
with gr.Blocks(title="🔄 Multimodal Query System") as demo:
    # Heading: "Search with images and text (CLIP + Qdrant)"; subtitle: "supports
    # image-to-image, image-to-text, text-to-image and text-to-text"
    gr.Markdown("## 🔎 ค้นหาด้วยรูปภาพและข้อความ (CLIP + Qdrant)\nรองรับทั้ง image-to-image, image-to-text, text-to-image และ text-to-text")
    with gr.Row():
        with gr.Column():
            mode = gr.Dropdown(label="โหมดการค้นหา", choices=["image-to-image", "image-to-text", "text-to-image", "text-to-text"], value="image-to-image")  # "Search mode"
            image_input = gr.Image(type="pil", label="อัปโหลดภาพ")  # "Upload an image"
            image_url = gr.Textbox(label="หรือใส่ URL ของภาพ")  # "Or enter an image URL"
            text_input = gr.Textbox(label="ข้อความที่ใช้ค้นหา (text input)")  # "Text to search with"
            text_target = gr.Textbox(label="เป้าหมาย (ใช้ในบางกรณี)", visible=False)  # "Target (used in some cases)"
            search_btn = gr.Button("🔍 ค้นหา")  # "Search"
        with gr.Column():
            gallery = gr.Gallery(label="ผลลัพธ์", columns=5, height=600)  # "Results"

    search_btn.click(
        fn=multimodal_query,
        inputs=[text_input, text_target, image_input, image_url, mode],
        outputs=gallery
    )

# Launch the app when the script is run directly (needed for the UI to be served,
# e.g. as app.py on a Hugging Face Space).
demo.launch()
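A common cause of the "Build error" status on a Space running code like this is a missing or incomplete requirements.txt. The following is a minimal sketch covering only the packages imported above; version pins are omitted, and compatibility between thai2transformers and recent transformers releases is an assumption that would need to be checked.

gradio
torch
transformers
thai2transformers
qdrant-client
requests
Pillow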