Spaces:
Runtime error
Runtime error
import os

# Redirect Streamlit's home and the XDG state directory to a writable location.
# This works around "PermissionError: '/.streamlit'", i.e. the home directory
# resolved to "/" and was not writable by the process.
# NOTE(review): this code runs inside the app script, after `streamlit run` has
# already started. Streamlit appears to build its `.streamlit` folder from `~`
# (HOME) rather than from STREAMLIT_HOME. If the error persists, set HOME (or
# these variables) in the container/Space environment instead. TODO confirm.
os.environ.update(
    STREAMLIT_HOME="/tmp",
    XDG_STATE_HOME="/tmp",
)
| import streamlit as st | |
| import torch | |
| from PIL import Image | |
| from utils import ( | |
| load_models, | |
| load_data, | |
| search_by_text, | |
| search_by_image, | |
| generate_caption, | |
| ) | |
# Page-wide settings come before any other st.* call. At least older Streamlit
# releases only accept set_page_config as the first Streamlit command in the script.
st.set_page_config(
    layout="wide",
    page_title="🍱 Food Search App",
)
st.title("🍽️ Food Image & Text Search App (Fast Boot)")
# --------- Simple settings panel (rendered in the sidebar) ----------
# Each widget is attached through st.sidebar.<widget>(...). That places it in the
# same sidebar container a `with st.sidebar:` block would.
st.sidebar.header("⚙️ Settings")
parquet_path = st.sidebar.text_input(
    label="Parquet path",
    # Default embeddings file name; change it to your own file if needed.
    value="food101_embeddings_20000.parquet",
    # Help text (Chinese) says the file may sit in the project root or in data/,
    # and that the loader falls back automatically.
    help="放在根目录或 data/ 目录都可以,程序会自动 fallback",
)
# Every numeric argument below is an int, so Streamlit returns ints. top_k is
# later used as a column count for st.columns, which needs an int.
max_rows = st.sidebar.number_input(
    label="Max rows to load",
    min_value=100,
    max_value=20000,
    value=5000,
    step=1000,
)
top_k = st.sidebar.number_input(
    label="Top-K results",
    min_value=1,
    max_value=20,
    value=5,
)
# Caption (Chinese): "Test with 5k-10k rows first; increase once everything works."
st.sidebar.caption("建议先用 5k~10k 测试,确定没问题再加大。")
# Run the models on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Both loaders come from utils. The spinner text advertises them as cached, but
# the caching itself is not visible here (presumably st.cache_* inside utils).
with st.spinner("🔄 Loading models (cached) ..."):
    (
        clip_model,
        clip_processor,
        blip_model,
        blip_processor,
    ) = load_models(device)

with st.spinner(f"📦 Loading parquet ({max_rows} rows, cached)..."):
    df, image_embeddings = load_data(
        max_rows=max_rows,
        parquet_path=parquet_path,
    )
# Three independent tabs; each one's body is rendered inside its `with` block.
tab1, tab2, tab3 = st.tabs(["🔤 Text Search", "🖼️ Image Search", "📝 Describe Image"])

with tab1:
    st.subheader("Search by Text")
    query = st.text_input("Type a food description (e.g. 'spicy noodles'):")
    if st.button("Search", key="text_search"):
        cleaned_query = query.strip()
        if not cleaned_query:
            # Tell the user why nothing happened instead of silently ignoring
            # a click on an empty / whitespace-only query.
            st.warning("Please enter a food description first.")
        else:
            with st.spinner("Searching..."):
                results = search_by_text(
                    cleaned_query, clip_processor, clip_model, image_embeddings, df,
                    top_k=top_k, device=device,
                )
            # One column per requested result. zip() stops at the shorter of the
            # two, so fewer results than top_k just leaves trailing columns empty.
            # Each result is expected to carry "image" and "label" keys.
            cols = st.columns(top_k)
            for col, item in zip(cols, results):
                # NOTE(review): recent Streamlit releases deprecate use_column_width
                # in favour of use_container_width. Kept for older pinned versions.
                col.image(item["image"], caption=item["label"], use_column_width=True)
with tab2:
    st.subheader("Search by Image")
    uploaded_img = st.file_uploader(
        "Upload a food image", type=["jpg", "jpeg", "png"], key="img_search"
    )
    if uploaded_img is not None:
        try:
            # Image.open only reads the header. convert() forces a full decode, so a
            # corrupt or truncated upload fails here instead of inside the model
            # call. It also normalises RGBA/palette/grayscale PNGs to 3-channel RGB.
            image = Image.open(uploaded_img).convert("RGB")
        except OSError as err:
            # PIL.UnidentifiedImageError subclasses OSError, so this also covers
            # uploads that are not images at all.
            st.error(f"Could not read the uploaded file as an image: {err}")
        else:
            st.image(image, caption="Uploaded image", use_column_width=True)
            if st.button("Find Similar Foods", key="search_image_button"):
                with st.spinner("Searching..."):
                    results = search_by_image(
                        image, clip_processor, clip_model, image_embeddings, df,
                        top_k=top_k, device=device,
                    )
                # One column per requested result. zip() tolerates fewer results than top_k.
                cols = st.columns(top_k)
                for col, item in zip(cols, results):
                    # NOTE(review): use_column_width is deprecated in recent Streamlit
                    # releases (use_container_width). Kept for older pinned versions.
                    col.image(item["image"], caption=item["label"], use_column_width=True)
with tab3:
    st.subheader("Describe an Image (Auto Caption)")
    uploaded_caption_img = st.file_uploader(
        "Upload a food image", type=["jpg", "jpeg", "png"], key="caption_img"
    )
    if uploaded_caption_img is not None:
        try:
            # convert() forces a full decode, surfacing corrupt or truncated uploads
            # here, and hands the captioning model a 3-channel RGB image.
            image = Image.open(uploaded_caption_img).convert("RGB")
        except OSError as err:
            # PIL.UnidentifiedImageError subclasses OSError (non-image uploads).
            st.error(f"Could not read the uploaded file as an image: {err}")
        else:
            st.image(image, caption="Uploaded image", use_column_width=True)
            if st.button("Generate Description", key="caption_button"):
                with st.spinner("Generating..."):
                    caption = generate_caption(
                        image, blip_processor, blip_model, device=device
                    )
                st.success(f"**Generated Caption:** {caption}")