import os
os.environ["STREAMLIT_HOME"] = "/tmp"   # 修复 PermissionError: '/.streamlit'
os.environ["XDG_STATE_HOME"] = "/tmp"

import streamlit as st
import torch
from PIL import Image
from utils import (
    load_models,
    load_data,
    search_by_text,
    search_by_image,
    generate_caption,
)

st.set_page_config(page_title="🍱 Food Search App", layout="wide")
st.title("🍽️ Food Image & Text Search App (Fast Boot)")

# --------- Simple settings panel ----------
with st.sidebar:
    st.header("⚙️ Settings")
    parquet_path = st.text_input(
        "Parquet path",
        value="food101_embeddings_20000.parquet",  # 你可以改成你自己的文件名
        help="放在根目录或 data/ 目录都可以,程序会自动 fallback"
    )
    max_rows = st.number_input("Max rows to load", min_value=100, max_value=20000, value=5000, step=1000)
    top_k = st.number_input("Top-K results", min_value=1, max_value=20, value=5)
    st.caption("建议先用 5k~10k 测试,确定没问题再加大。")

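# Pick the compute device: GPU if available, otherwise fall back to CPU.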
device = "cuda" if torch.cuda.is_available() else "cpu"

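# Load the CLIP (retrieval) and BLIP (captioning) models; load_models is labeled as cached,
# so repeated Streamlit reruns should not reload the weights.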
with st.spinner("🔄 Loading models (cached) ..."):
    clip_model, clip_processor, blip_model, blip_processor = load_models(device)

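# Load the precomputed image embeddings and their metadata from the parquet file (also cached).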
with st.spinner(f"📦 Loading parquet ({max_rows} rows, cached)..."):
    df, image_embeddings = load_data(parquet_path=parquet_path, max_rows=max_rows)

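# Three workflows: text-to-image search, image-to-image search, and automatic captioning.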
tab1, tab2, tab3 = st.tabs(["🔤 Text Search", "🖼️ Image Search", "📝 Describe Image"])

with tab1:
    st.subheader("Search by Text")
    query = st.text_input("Type a food description (e.g. 'spicy noodles'):")
    if st.button("Search", key="text_search") and query.strip():
        with st.spinner("Searching..."):
            results = search_by_text(
                query, clip_processor, clip_model, image_embeddings, df,
                top_k=top_k, device=device
            )
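        # Show the top-k hits side by side, one column per result.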
        cols = st.columns(top_k)
        for col, item in zip(cols, results):
            col.image(item["image"], caption=item["label"], use_column_width=True)

with tab2:
    st.subheader("Search by Image")
    uploaded_img = st.file_uploader("Upload a food image", type=["jpg", "jpeg", "png"], key="img_search")
    if uploaded_img:
        image = Image.open(uploaded_img)
        st.image(image, caption="Uploaded image", use_column_width=True)
        if st.button("Find Similar Foods", key="search_image_button"):
            with st.spinner("Searching..."):
                results = search_by_image(
                    image, clip_processor, clip_model, image_embeddings, df,
                    top_k=top_k, device=device
                )
            cols = st.columns(top_k)
            for col, item in zip(cols, results):
                col.image(item["image"], caption=item["label"], use_column_width=True)

with tab3:
    st.subheader("Describe an Image (Auto Caption)")
    uploaded_caption_img = st.file_uploader("Upload a food image", type=["jpg", "jpeg", "png"], key="caption_img")
    if uploaded_caption_img:
        image = Image.open(uploaded_caption_img)
        st.image(image, caption="Uploaded image", use_column_width=True)
        if st.button("Generate Description", key="caption_button"):
            with st.spinner("Generating..."):
                caption = generate_caption(image, blip_processor, blip_model, device=device)
            st.success(f"**Generated Caption:** {caption}")