Spaces:
Runtime error
Runtime error
File size: 3,394 Bytes
8a8bc41 67dab4d 8a8bc41 67dab4d 8a8bc41 67dab4d 8a8bc41 67dab4d 8a8bc41 67dab4d 8a8bc41 67dab4d 8a8bc41 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
# Module setup for the Streamlit food search app: redirect writable state
# directories to /tmp, then import the UI, ML and project-local helpers.
import os
# These must be set BEFORE `streamlit` is imported below, which is why the
# env assignments sit between the imports.
# NOTE(review): Streamlit appears to resolve its `.streamlit` folder from the
# user's home directory, so `STREAMLIT_HOME` may not be honoured. This script
# also executes inside an already-running Streamlit server. If the
# PermissionError persists, set HOME (or these variables) in the container /
# Space configuration instead — verify.
os.environ["STREAMLIT_HOME"] = "/tmp" # fixes PermissionError: '/.streamlit'
os.environ["XDG_STATE_HOME"] = "/tmp"
import streamlit as st
import torch
from PIL import Image
# Project-local helpers; their contracts are only known from the call sites
# further down in this script.
from utils import (
    load_models,
    load_data,
    search_by_text,
    search_by_image,
    generate_caption,
)
# --------- Page chrome (must be the first Streamlit call) ----------
st.set_page_config(layout="wide", page_title="🍱 Food Search App")
st.title("🍽️ Food Image & Text Search App (Fast Boot)")

# --------- Simple settings panel (sidebar) ----------
# Widgets are placed straight onto the sidebar container. Their values
# (parquet_path, max_rows, top_k) drive the loading and search code below.
st.sidebar.header("⚙️ Settings")
parquet_path = st.sidebar.text_input(
    "Parquet path",
    # Help text (Chinese): "Either the repo root or data/ works; the program
    # falls back automatically" — the fallback presumably lives in load_data.
    help="放在根目录或 data/ 目录都可以,程序会自动 fallback",
    value="food101_embeddings_20000.parquet",  # change to your own file name if needed
)
max_rows = st.sidebar.number_input(
    "Max rows to load",
    step=1000,
    value=5000,
    min_value=100,
    max_value=20000,
)
top_k = st.sidebar.number_input(
    "Top-K results",
    value=5,
    min_value=1,
    max_value=20,
)
# Caption (Chinese): "Test with 5k–10k rows first; increase once it works."
st.sidebar.caption("建议先用 5k~10k 测试,确定没问题再加大。")
# --------- Model & data loading ----------
# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The spinner text says "cached"; caching presumably happens inside the utils
# helpers (e.g. st.cache_resource / st.cache_data) — not visible here, confirm.
with st.spinner("🔄 Loading models (cached) ..."):
    clip_model, clip_processor, blip_model, blip_processor = load_models(device)

with st.spinner(f"📦 Loading parquet ({max_rows} rows, cached)..."):
    df, image_embeddings = load_data(max_rows=max_rows, parquet_path=parquet_path)
tab1, tab2, tab3 = st.tabs(["🔤 Text Search", "🖼️ Image Search", "📝 Describe Image"])

# --------- Tab 1: search images by a free-text description ----------
with tab1:
    st.subheader("Search by Text")
    query = st.text_input("Type a food description (e.g. 'spicy noodles'):")
    if st.button("Search", key="text_search"):
        # Validate after the click so an empty / whitespace-only query gets
        # visible feedback instead of the button silently doing nothing.
        cleaned_query = query.strip()
        if not cleaned_query:
            st.warning("Please enter a food description before searching.")
        else:
            with st.spinner("Searching..."):
                results = search_by_text(
                    cleaned_query, clip_processor, clip_model, image_embeddings, df,
                    top_k=top_k, device=device,
                )
            # One column per requested result; zip stops early if fewer
            # results come back than columns were created.
            cols = st.columns(top_k)
            for col, item in zip(cols, results):
                # NOTE(review): newer Streamlit releases deprecate
                # `use_column_width` in favour of `use_container_width`;
                # switch once the pinned Streamlit version is confirmed.
                col.image(item["image"], caption=item["label"], use_column_width=True)
# --------- Tab 2: find dataset images similar to an uploaded image ----------
with tab2:
    st.subheader("Search by Image")
    uploaded_img = st.file_uploader("Upload a food image", type=["jpg", "jpeg", "png"], key="img_search")
    if uploaded_img:
        try:
            # convert() forces a full decode, so corrupt/truncated files fail
            # here. It also normalises RGBA/palette PNGs and grayscale JPEGs to
            # 3-channel RGB before the image reaches the CLIP helpers.
            image = Image.open(uploaded_img).convert("RGB")
        except OSError:
            # The uploader filters by file extension only. PIL raises
            # UnidentifiedImageError (an OSError subclass) for content it
            # cannot read. Report it instead of crashing the whole run.
            image = None
            st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG.")
        if image is not None:
            st.image(image, caption="Uploaded image", use_column_width=True)
            if st.button("Find Similar Foods", key="search_image_button"):
                with st.spinner("Searching..."):
                    results = search_by_image(
                        image, clip_processor, clip_model, image_embeddings, df,
                        top_k=top_k, device=device,
                    )
                # One column per requested result; zip stops early if fewer
                # results come back than columns were created.
                cols = st.columns(top_k)
                for col, item in zip(cols, results):
                    col.image(item["image"], caption=item["label"], use_column_width=True)
# --------- Tab 3: generate a text caption for an uploaded image ----------
with tab3:
    st.subheader("Describe an Image (Auto Caption)")
    uploaded_caption_img = st.file_uploader("Upload a food image", type=["jpg", "jpeg", "png"], key="caption_img")
    if uploaded_caption_img:
        try:
            # convert() forces a full decode, so corrupt/truncated files fail
            # here. It also normalises RGBA/palette PNGs and grayscale JPEGs to
            # 3-channel RGB before the image reaches the BLIP helper.
            image = Image.open(uploaded_caption_img).convert("RGB")
        except OSError:
            # The uploader filters by file extension only. PIL raises
            # UnidentifiedImageError (an OSError subclass) for content it
            # cannot read. Report it instead of crashing the whole run.
            image = None
            st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG.")
        if image is not None:
            st.image(image, caption="Uploaded image", use_column_width=True)
            if st.button("Generate Description", key="caption_button"):
                with st.spinner("Generating..."):
                    caption = generate_caption(image, blip_processor, blip_model, device=device)
                st.success(f"**Generated Caption:** {caption}")
|