import os
import zipfile
from io import BytesIO

import numpy as np
import pandas as pd
import torch
from PIL import Image
from sklearn.neighbors import NearestNeighbors
from torchvision import transforms
from transformers import (
    CLIPProcessor,
    CLIPModel,
    BlipProcessor,
    BlipForConditionalGeneration,
)

# CLIP's image normalization statistics (used by openai/clip-vit-base-patch32).
MEAN = [0.48145466, 0.4578275, 0.40821073]
STD = [0.26862954, 0.26130258, 0.27577711]

def load_models(device):
    """Load CLIP (text/image embeddings) and BLIP (captioning) onto the given device."""
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    return clip_model, clip_processor, blip_model, blip_processor


def load_data(parquet_path="food101_embeddings_10000.parquet"):
    """Load the precomputed embedding table and unpack the image archive if needed."""
    # Extract the bundled images on first run.
    if not os.path.exists("food_images"):
        with zipfile.ZipFile("food_images_10000.zip", "r") as zip_ref:
            zip_ref.extractall("food_images")

    df = pd.read_parquet(parquet_path)
    # Rewrite stored paths so they point at the locally extracted images.
    df["image_path"] = df["image_path"].apply(lambda p: os.path.join("food_images", os.path.basename(p)))
    # Stack per-row embedding lists into one (n_samples, dim) array for nearest-neighbor search.
    embeddings = np.vstack(df["embedding"].to_numpy())
    return df, embeddings

def bytes_to_pil(byte_data):
    """Decode raw image bytes into an RGB PIL image."""
    return Image.open(BytesIO(byte_data)).convert("RGB")

def preprocess_image(image):
    """Resize, tensorize, and normalize a PIL image with CLIP's statistics, adding a batch dimension."""
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD),
    ])
    return transform(image).unsqueeze(0)

def search_by_text(text, processor, model, embeddings, df, top_k=5, device="cpu"):
    """Embed a text query with CLIP and return the top_k closest images by cosine distance."""
    inputs = processor(text=[text], return_tensors="pt").to(device)
    with torch.no_grad():
        text_feat = model.get_text_features(**inputs).cpu().numpy()
    # Fit a cosine-distance index over the precomputed image embeddings and query it.
    nn = NearestNeighbors(n_neighbors=top_k, metric="cosine").fit(embeddings)
    indices = nn.kneighbors(text_feat, return_distance=False)[0]
    return [
        {"label": df.iloc[i]["label_name"], "image": bytes_to_pil(df.iloc[i]["image_bytes"])}
        for i in indices
    ]

def search_by_image(uploaded_image, processor, model, embeddings, df, top_k=5, device="cpu"):
    """Embed an uploaded image with CLIP and return the top_k closest images by cosine distance."""
    image_tensor = preprocess_image(uploaded_image).to(device)
    with torch.no_grad():
        img_feat = model.get_image_features(image_tensor).cpu().numpy()
    nn = NearestNeighbors(n_neighbors=top_k, metric="cosine").fit(embeddings)
    indices = nn.kneighbors(img_feat, return_distance=False)[0]
    return [
        {"label": df.iloc[i]["label_name"], "image": bytes_to_pil(df.iloc[i]["image_bytes"])}
        for i in indices
    ]

def generate_caption(uploaded_image, processor, model, device="cpu"):
    """Generate a natural-language caption for an image with BLIP."""
    image = uploaded_image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model.generate(**inputs)
    return processor.decode(output[0], skip_special_tokens=True)
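

# Minimal usage sketch (an assumption, not part of the original module): wires the helpers
# together for a text query. Assumes the parquet file and image zip referenced above are
# present in the working directory and that the dataframe contains "label_name" and
# "image_bytes" columns, as the search functions expect.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model, clip_processor, blip_model, blip_processor = load_models(device)
    df, embeddings = load_data()

    # Text-to-image search: find dishes that match the query.
    results = search_by_text("a bowl of ramen", clip_processor, clip_model, embeddings, df, top_k=3, device=device)
    for r in results:
        print(r["label"], r["image"].size)

    # Caption the closest match with BLIP.
    print(generate_caption(results[0]["image"], blip_processor, blip_model, device=device))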