Ma committed on
Commit 67dab4d · verified · 1 Parent(s): b723a16

Upload 3 files

Files changed (3)
  1. app.py +52 -0
  2. requirements.txt +8 -3
  3. utils.py +61 -0
app.py ADDED
@@ -0,0 +1,52 @@
+ import streamlit as st
+ import torch
+ from PIL import Image
+ from utils import (
+     load_models,
+     load_data,
+     search_by_text,
+     search_by_image,
+     generate_caption,
+ )
+
+ st.set_page_config(page_title="🍱 Food Search App", layout="wide")
+ st.title("🍽️ Food Image & Text Search App")
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ with st.spinner("🔄 Loading models and data..."):
+     clip_model, clip_processor, blip_model, blip_processor = load_models(device)
+     df, image_embeddings = load_data()
+
+ tab1, tab2, tab3 = st.tabs(["🔀 Text Search", "🖼️ Image Search", "📝 Describe Image"])
+
+ with tab1:
+     st.subheader("Search by Text")
+     query = st.text_input("Type a food description (e.g. 'spicy noodles'):")
+     if st.button("Search", key="text_search") and query.strip():
+         results = search_by_text(query, clip_processor, clip_model, image_embeddings, df, device=device)
+         cols = st.columns(5)
+         for col, item in zip(cols, results):
+             col.image(item["image"], caption=item["label"], use_column_width=True)
+
+ with tab2:
+     st.subheader("Search by Image")
+     uploaded_img = st.file_uploader("Upload a food image", type=["jpg", "jpeg", "png"], key="img_search")
+     if uploaded_img:
+         image = Image.open(uploaded_img)
+         st.image(image, caption="Uploaded image", use_column_width=True)
+         if st.button("Find Similar Foods", key="search_image_button"):
+             results = search_by_image(image, clip_processor, clip_model, image_embeddings, df, device=device)
+             cols = st.columns(5)
+             for col, item in zip(cols, results):
+                 col.image(item["image"], caption=item["label"], use_column_width=True)
+
+ with tab3:
+     st.subheader("Describe an Image (Auto Caption)")
+     uploaded_caption_img = st.file_uploader("Upload a food image", type=["jpg", "jpeg", "png"], key="caption_img")
+     if uploaded_caption_img:
+         image = Image.open(uploaded_caption_img)
+         st.image(image, caption="Uploaded image", use_column_width=True)
+         if st.button("Generate Description", key="caption_button"):
+             caption = generate_caption(image, blip_processor, blip_model, device=device)
+             st.success(f"**Generated Caption:** {caption}")
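Note: Streamlit reruns app.py from top to bottom on every widget interaction, so the load_models and load_data calls above execute again on each rerun. A minimal caching sketch (an addition for illustration, not part of this commit; it assumes Streamlit >= 1.18, where st.cache_resource and st.cache_data are available) could look like:

import streamlit as st
import torch
from utils import load_models, load_data

device = "cuda" if torch.cuda.is_available() else "cpu"

@st.cache_resource  # keep the heavyweight CLIP/BLIP objects alive across reruns
def get_models():
    return load_models(device)

@st.cache_data  # cache the dataframe and embedding matrix returned by load_data
def get_data():
    return load_data()

clip_model, clip_processor, blip_model, blip_processor = get_models()
df, image_embeddings = get_data()

With this, the models and the parquet data are loaded once per process instead of on every button click.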
requirements.txt CHANGED
@@ -1,3 +1,8 @@
- altair
- pandas
- streamlit
+ streamlit
+ transformers
+ torch
+ datasets
+ scikit-learn
+ torchvision
+ pyarrow
+ Pillow
utils.py ADDED
@@ -0,0 +1,61 @@
+ import torch
+ import numpy as np
+ import pyarrow.parquet as pq
+ from PIL import Image
+ from io import BytesIO
+ from sklearn.neighbors import NearestNeighbors
+ from torchvision import transforms
+ from transformers import (
+     CLIPProcessor,
+     CLIPModel,
+     BlipProcessor,
+     BlipForConditionalGeneration
+ )
+
+ MEAN = [0.48145466, 0.4578275, 0.40821073]
+ STD = [0.26862954, 0.26130258, 0.27577711]
+
+ def load_models(device):
+     clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+     clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+     blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
+     blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+     return clip_model, clip_processor, blip_model, blip_processor
+
+ def load_data(parquet_path="food101_embeddings.parquet"):
+     table = pq.read_table(parquet_path)
+     df = table.to_pandas()
+     embeddings = np.vstack(df["embedding"].to_numpy())
+     return df, embeddings
+
+ def bytes_to_pil(byte_data):
+     return Image.open(BytesIO(byte_data)).convert("RGB")
+
+ def preprocess_image(image):
+     transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=MEAN, std=STD),
+     ])
+     return transform(image).unsqueeze(0)
+
+ def search_by_text(text, processor, model, embeddings, df, top_k=5, device="cpu"):
+     inputs = processor(text=[text], return_tensors="pt").to(device)
+     with torch.no_grad():
+         text_feat = model.get_text_features(**inputs).cpu().numpy()
+     nn = NearestNeighbors(n_neighbors=top_k, metric="cosine").fit(embeddings)
+     return [{"label": df.iloc[i]["label_name"], "image": bytes_to_pil(df.iloc[i]["image_bytes"])} for i in nn.kneighbors(text_feat, return_distance=False)[0]]
+
+ def search_by_image(uploaded_image, processor, model, embeddings, df, top_k=5, device="cpu"):
+     image_tensor = preprocess_image(uploaded_image).to(device)
+     with torch.no_grad():
+         img_feat = model.get_image_features(image_tensor).cpu().numpy()
+     nn = NearestNeighbors(n_neighbors=top_k, metric="cosine").fit(embeddings)
+     return [{"label": df.iloc[i]["label_name"], "image": bytes_to_pil(df.iloc[i]["image_bytes"])} for i in nn.kneighbors(img_feat, return_distance=False)[0]]
+
+ def generate_caption(uploaded_image, processor, model, device="cpu"):
+     image = uploaded_image.convert("RGB")
+     inputs = processor(images=image, return_tensors="pt").to(device)
+     with torch.no_grad():
+         output = model.generate(**inputs)
+     return processor.decode(output[0], skip_special_tokens=True)
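load_data expects a food101_embeddings.parquet file with embedding, label_name, and image_bytes columns, but that file is not part of this commit. A rough offline-indexing sketch (an assumption: it uses the food101 dataset from the datasets library and the same CLIP checkpoint; only the column names and file name are taken from utils.py) might be:

import io
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from transformers import CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Small slice of Food-101 for illustration; a real index would embed far more images.
ds = load_dataset("food101", split="validation[:500]")
label_names = ds.features["label"].names

rows = []
for example in ds:
    image = example["image"].convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        feat = model.get_image_features(**inputs).cpu().numpy()[0]
    buf = io.BytesIO()
    image.save(buf, format="JPEG")
    rows.append({
        "label_name": label_names[example["label"]],
        "image_bytes": buf.getvalue(),
        "embedding": feat.astype(np.float32),
    })

pd.DataFrame(rows).to_parquet("food101_embeddings.parquet")

The cosine metric used by NearestNeighbors in utils.py handles unnormalized vectors, so the raw CLIP features can be stored as-is.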