Spaces:
Runtime error
```python
import os
import zipfile
from io import BytesIO

import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import torch
from PIL import Image
from sklearn.neighbors import NearestNeighbors
from torchvision import transforms
from transformers import (
    CLIPProcessor,
    CLIPModel,
    BlipProcessor,
    BlipForConditionalGeneration,
)

# CLIP's standard normalization constants (ViT-B/32)
MEAN = [0.48145466, 0.4578275, 0.40821073]
STD = [0.26862954, 0.26130258, 0.27577711]


def load_models(device):
    """Load CLIP (embeddings) and BLIP (captioning) and move them to the given device."""
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    blip_model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    ).to(device)
    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    return clip_model, clip_processor, blip_model, blip_processor


def load_data(parquet_path="food101_embeddings_10000.parquet"):
    """Unpack the image archive if needed and load the precomputed CLIP embeddings."""
    if not os.path.exists("food_images"):
        with zipfile.ZipFile("food_images_10000.zip", "r") as zip_ref:
            zip_ref.extractall("food_images")
    df = pd.read_parquet(parquet_path)
    df["image_path"] = df["image_path"].apply(
        lambda p: os.path.join("food_images", os.path.basename(p))
    )
    embeddings = np.vstack(df["embedding"].to_numpy())
    return df, embeddings


def bytes_to_pil(byte_data):
    """Decode raw image bytes into an RGB PIL image."""
    return Image.open(BytesIO(byte_data)).convert("RGB")


def preprocess_image(image):
    """Resize and normalize a PIL image into a (1, 3, 224, 224) tensor for CLIP."""
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD),
    ])
    return transform(image).unsqueeze(0)


def search_by_text(text, processor, model, embeddings, df, top_k=5, device="cpu"):
    """Embed a text query with CLIP and return the top_k nearest images by cosine distance."""
    inputs = processor(text=[text], return_tensors="pt").to(device)
    with torch.no_grad():
        text_feat = model.get_text_features(**inputs).cpu().numpy()
    nn = NearestNeighbors(n_neighbors=top_k, metric="cosine").fit(embeddings)
    indices = nn.kneighbors(text_feat, return_distance=False)[0]
    return [
        {"label": df.iloc[i]["label_name"], "image": bytes_to_pil(df.iloc[i]["image_bytes"])}
        for i in indices
    ]


def search_by_image(uploaded_image, processor, model, embeddings, df, top_k=5, device="cpu"):
    """Embed an uploaded image with CLIP and return the top_k nearest images by cosine distance."""
    image_tensor = preprocess_image(uploaded_image).to(device)
    with torch.no_grad():
        img_feat = model.get_image_features(pixel_values=image_tensor).cpu().numpy()
    nn = NearestNeighbors(n_neighbors=top_k, metric="cosine").fit(embeddings)
    indices = nn.kneighbors(img_feat, return_distance=False)[0]
    return [
        {"label": df.iloc[i]["label_name"], "image": bytes_to_pil(df.iloc[i]["image_bytes"])}
        for i in indices
    ]


def generate_caption(uploaded_image, processor, model, device="cpu"):
    """Generate a caption for the uploaded image with BLIP."""
    image = uploaded_image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model.generate(**inputs)
    return processor.decode(output[0], skip_special_tokens=True)
```
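For reference, here is a minimal driver sketch showing how these helpers might be wired together outside the Space's UI. The query string and `top_k` value are illustrative; it assumes the `food101_embeddings_10000.parquet` and `food_images_10000.zip` files referenced above are present, along with the `label_name` and `image_bytes` columns used by the search functions.

```python
# Illustrative usage sketch, not part of the Space itself.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model, clip_processor, blip_model, blip_processor = load_models(device)
    df, embeddings = load_data()

    # Text-to-image search: returns a list of {label, PIL image} dicts.
    results = search_by_text(
        "a bowl of ramen noodles", clip_processor, clip_model,
        embeddings, df, top_k=3, device=device,
    )
    for r in results:
        print(r["label"])

    # Caption the best match with BLIP.
    print(generate_caption(results[0]["image"], blip_processor, blip_model, device=device))
```

Note that both search functions rebuild the `NearestNeighbors` index on every query; since the embeddings never change, fitting the index once at startup and reusing it would avoid redundant work.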