"""Streamlit front end for searching a food-image embedding table.

The app loads CLIP and BLIP models plus a parquet file of rows and image
embeddings (via the project-local ``utils`` module), then offers three tabs:

* text -> image search (CLIP),
* image -> image search (CLIP),
* automatic image captioning (BLIP).
"""
import os

# Redirect Streamlit's home and XDG state directories to /tmp to fix
# `PermissionError: '/.streamlit'`. They are set before `streamlit` is
# imported below.
# NOTE(review): when launched via `streamlit run`, the CLI has already started
# Streamlit before this script body executes, so these may need to be exported
# in the process environment instead -- confirm for your deployment.
os.environ["STREAMLIT_HOME"] = "/tmp"
os.environ["XDG_STATE_HOME"] = "/tmp"

import streamlit as st  # noqa: E402  (must follow the env vars above)
import torch  # noqa: E402
from PIL import Image  # noqa: E402

from utils import (  # noqa: E402
    load_models,
    load_data,
    search_by_text,
    search_by_image,
    generate_caption,
)


def _open_rgb(uploaded_file):
    """Open an uploaded image file and normalise it to 3-channel RGB.

    PNG uploads may be RGBA or palette ("P") mode and JPEGs may be greyscale;
    converting once here hands the CLIP/BLIP helpers a consistent image mode
    and makes the preview show exactly what the models receive.
    """
    return Image.open(uploaded_file).convert("RGB")


def _show_results(results, n_columns):
    """Render search hits side by side, one hit per column.

    ``results`` is iterated as mappings with ``"image"`` and ``"label"`` keys.
    If fewer hits than ``n_columns`` are returned, the extra columns stay
    empty (``zip`` stops at the shorter sequence).
    """
    cols = st.columns(n_columns)
    for col, item in zip(cols, results):
        col.image(item["image"], caption=item["label"], use_column_width=True)


st.set_page_config(page_title="🍱 Food Search App", layout="wide")
st.title("🍽️ Food Image & Text Search App (Fast Boot)")

# --------- Simple settings panel ----------
with st.sidebar:
    st.header("⚙️ Settings")
    # Per the help text, the file may live in the project root or in data/;
    # the fallback itself is presumably implemented in utils.load_data.
    parquet_path = st.text_input(
        "Parquet path",
        value="food101_embeddings_20000.parquet",  # change to your own file name
        help="放在根目录或 data/ 目录都可以,程序会自动 fallback"
    )
    max_rows = st.number_input(
        "Max rows to load", min_value=100, max_value=20000, value=5000, step=1000
    )
    top_k = st.number_input("Top-K results", min_value=1, max_value=20, value=5)
    # Caption advises testing with 5k-10k rows before loading more.
    st.caption("建议先用 5k~10k 测试,确定没问题再加大。")

device = "cuda" if torch.cuda.is_available() else "cpu"

with st.spinner("🔄 Loading models (cached) ..."):
    clip_model, clip_processor, blip_model, blip_processor = load_models(device)

with st.spinner(f"📦 Loading parquet ({max_rows} rows, cached)..."):
    df, image_embeddings = load_data(parquet_path=parquet_path, max_rows=max_rows)

tab1, tab2, tab3 = st.tabs(["🔤 Text Search", "🖼️ Image Search", "📝 Describe Image"])

with tab1:
    st.subheader("Search by Text")
    query = st.text_input("Type a food description (e.g. 'spicy noodles'):")
    # Blank / whitespace-only queries are ignored.
    if st.button("Search", key="text_search") and query.strip():
        with st.spinner("Searching..."):
            results = search_by_text(
                query, clip_processor, clip_model, image_embeddings, df,
                top_k=top_k, device=device,
            )
        _show_results(results, top_k)

with tab2:
    st.subheader("Search by Image")
    uploaded_img = st.file_uploader(
        "Upload a food image", type=["jpg", "jpeg", "png"], key="img_search"
    )
    if uploaded_img is not None:
        query_image = _open_rgb(uploaded_img)
        st.image(query_image, caption="Uploaded image", use_column_width=True)
        if st.button("Find Similar Foods", key="search_image_button"):
            with st.spinner("Searching..."):
                results = search_by_image(
                    query_image, clip_processor, clip_model, image_embeddings, df,
                    top_k=top_k, device=device,
                )
            _show_results(results, top_k)

with tab3:
    st.subheader("Describe an Image (Auto Caption)")
    uploaded_caption_img = st.file_uploader(
        "Upload a food image", type=["jpg", "jpeg", "png"], key="caption_img"
    )
    if uploaded_caption_img is not None:
        caption_image = _open_rgb(uploaded_caption_img)
        st.image(caption_image, caption="Uploaded image", use_column_width=True)
        if st.button("Generate Description", key="caption_button"):
            with st.spinner("Generating..."):
                caption = generate_caption(
                    caption_image, blip_processor, blip_model, device=device
                )
            st.success(f"**Generated Caption:** {caption}")