# foodsearch101 / utils.py
import os
import zipfile
from io import BytesIO

import numpy as np
import pandas as pd
import torch
from PIL import Image
from sklearn.neighbors import NearestNeighbors
from torchvision import transforms
from transformers import (
    CLIPProcessor,
    CLIPModel,
    BlipProcessor,
    BlipForConditionalGeneration,
)
# CLIP's image normalization constants (the mean/std used by openai/clip-vit-base-patch32).
MEAN = [0.48145466, 0.4578275, 0.40821073]
STD = [0.26862954, 0.26130258, 0.27577711]

def load_models(device):
    """Load CLIP (for embeddings) and BLIP (for captioning) onto `device`."""
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    return clip_model, clip_processor, blip_model, blip_processor
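
# Example usage (a minimal sketch, not part of the original module; device
# selection is our assumption):
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   clip_model, clip_processor, blip_model, blip_processor = load_models(device)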

def load_data(parquet_path="food101_embeddings_10000.parquet"):
    """Load the precomputed embedding table and unpack the image archive once."""
    # Extract the bundled images only on the first run.
    if not os.path.exists("food_images"):
        with zipfile.ZipFile("food_images_10000.zip", "r") as zip_ref:
            zip_ref.extractall("food_images")
    df = pd.read_parquet(parquet_path)
    # Remap the stored paths to the local extraction directory.
    df["image_path"] = df["image_path"].apply(lambda p: os.path.join("food_images", os.path.basename(p)))
    embeddings = np.vstack(df["embedding"].to_numpy())
    return df, embeddings
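
# The parquet file is expected to carry the columns this module reads:
# "image_path" (str), "embedding" (a fixed-length float vector per row),
# "label_name" (str), and "image_bytes" (raw encoded image bytes). The exact
# dtypes are inferred from usage here, not documented upstream.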

def bytes_to_pil(byte_data):
    """Decode raw image bytes into an RGB PIL image."""
    return Image.open(BytesIO(byte_data)).convert("RGB")

def preprocess_image(image):
    """Resize, tensorize, and CLIP-normalize a PIL image; returns a (1, 3, 224, 224) batch."""
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD),
    ])
    return transform(image).unsqueeze(0)

def search_by_text(text, processor, model, embeddings, df, top_k=5, device="cpu"):
    """Embed a text query with CLIP and return the top_k nearest images by cosine distance."""
    inputs = processor(text=[text], return_tensors="pt").to(device)
    with torch.no_grad():
        text_feat = model.get_text_features(**inputs).cpu().numpy()
    nn = NearestNeighbors(n_neighbors=top_k, metric="cosine").fit(embeddings)
    indices = nn.kneighbors(text_feat, return_distance=False)[0]
    return [
        {"label": df.iloc[i]["label_name"], "image": bytes_to_pil(df.iloc[i]["image_bytes"])}
        for i in indices
    ]

def search_by_image(uploaded_image, processor, model, embeddings, df, top_k=5, device="cpu"):
    """Embed an uploaded image with CLIP and return the top_k nearest images by cosine distance."""
    image_tensor = preprocess_image(uploaded_image).to(device)
    with torch.no_grad():
        img_feat = model.get_image_features(pixel_values=image_tensor).cpu().numpy()
    nn = NearestNeighbors(n_neighbors=top_k, metric="cosine").fit(embeddings)
    indices = nn.kneighbors(img_feat, return_distance=False)[0]
    return [
        {"label": df.iloc[i]["label_name"], "image": bytes_to_pil(df.iloc[i]["image_bytes"])}
        for i in indices
    ]
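
# Note: both search functions above refit NearestNeighbors on every query. For a
# dataset of this size that is cheap, but for repeated queries the index can be
# built once and reused; a minimal sketch (variable names are ours, not part of
# the original module):
#   nn_index = NearestNeighbors(n_neighbors=5, metric="cosine").fit(embeddings)
#   indices = nn_index.kneighbors(query_feat, return_distance=False)[0]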

def generate_caption(uploaded_image, processor, model, device="cpu"):
    """Generate a short BLIP caption for an uploaded PIL image."""
    image = uploaded_image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model.generate(**inputs)
    return processor.decode(output[0], skip_special_tokens=True)
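

if __name__ == "__main__":
    # Minimal end-to-end smoke test (a sketch; assumes the parquet and zip files
    # named in load_data() sit alongside this script, and the query text is ours).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model, clip_processor, blip_model, blip_processor = load_models(device)
    df, embeddings = load_data()
    for hit in search_by_text("a plate of sushi", clip_processor, clip_model, embeddings, df, device=device):
        print(hit["label"])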