Ma committed on
Commit 8a8bc41 · verified · 1 Parent(s): b99f59f

Update app.py

Files changed (1)
  1. app.py +36 -9
app.py CHANGED
@@ -1,3 +1,7 @@
+ import os
+ os.environ["STREAMLIT_HOME"] = "/tmp"   # fix PermissionError: '/.streamlit'
+ os.environ["XDG_STATE_HOME"] = "/tmp"
+
  import streamlit as st
  import torch
  from PIL import Image
@@ -10,13 +14,27 @@ from utils import (
  )

  st.set_page_config(page_title="🍱 Food Search App", layout="wide")
- st.title("🍽️ Food Image & Text Search App")
+ st.title("🍽️ Food Image & Text Search App (Fast Boot)")
+
+ # --------- simple settings panel ----------
+ with st.sidebar:
+     st.header("⚙️ Settings")
+     parquet_path = st.text_input(
+         "Parquet path",
+         value="food101_embeddings_20000.parquet",  # change this to your own file name
+         help="Either the repo root or the data/ directory works; the loader falls back automatically"
+     )
+     max_rows = st.number_input("Max rows to load", min_value=100, max_value=20000, value=5000, step=1000)
+     top_k = st.number_input("Top-K results", min_value=1, max_value=20, value=5)
+     st.caption("Try 5k~10k rows first; increase once everything works.")

  device = "cuda" if torch.cuda.is_available() else "cpu"

- with st.spinner("🔄 Loading models and data..."):
+ with st.spinner("🔄 Loading models (cached) ..."):
      clip_model, clip_processor, blip_model, blip_processor = load_models(device)
- df, image_embeddings = load_data()
+
+ with st.spinner(f"📦 Loading parquet ({max_rows} rows, cached)..."):
+     df, image_embeddings = load_data(parquet_path=parquet_path, max_rows=max_rows)

  tab1, tab2, tab3 = st.tabs(["🔤 Text Search", "🖼️ Image Search", "📝 Describe Image"])

@@ -24,8 +42,12 @@ with tab1:
      st.subheader("Search by Text")
      query = st.text_input("Type a food description (e.g. 'spicy noodles'):")
      if st.button("Search", key="text_search") and query.strip():
-         results = search_by_text(query, clip_processor, clip_model, image_embeddings, df, device=device)
-         cols = st.columns(5)
+         with st.spinner("Searching..."):
+             results = search_by_text(
+                 query, clip_processor, clip_model, image_embeddings, df,
+                 top_k=top_k, device=device
+             )
+         cols = st.columns(top_k)
          for col, item in zip(cols, results):
              col.image(item["image"], caption=item["label"], use_column_width=True)

@@ -36,8 +58,12 @@ with tab2:
      image = Image.open(uploaded_img)
      st.image(image, caption="Uploaded image", use_column_width=True)
      if st.button("Find Similar Foods", key="search_image_button"):
-         results = search_by_image(image, clip_processor, clip_model, image_embeddings, df, device=device)
-         cols = st.columns(5)
+         with st.spinner("Searching..."):
+             results = search_by_image(
+                 image, clip_processor, clip_model, image_embeddings, df,
+                 top_k=top_k, device=device
+             )
+         cols = st.columns(top_k)
          for col, item in zip(cols, results):
              col.image(item["image"], caption=item["label"], use_column_width=True)

@@ -48,5 +74,6 @@ with tab3:
      image = Image.open(uploaded_caption_img)
      st.image(image, caption="Uploaded image", use_column_width=True)
      if st.button("Generate Description", key="caption_button"):
-         caption = generate_caption(image, blip_processor, blip_model, device=device)
-         st.success(f"**Generated Caption:** {caption}")
+         with st.spinner("Generating..."):
+             caption = generate_caption(image, blip_processor, blip_model, device=device)
+         st.success(f"**Generated Caption:** {caption}")
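
The spinner label "🔄 Loading models (cached) ..." implies that load_models in utils keeps the CLIP and BLIP models in memory across Streamlit reruns. utils.py is not part of this diff, so the following is only a minimal sketch of a compatible implementation; the st.cache_resource decorator and the specific checkpoints ("openai/clip-vit-base-patch32", "Salesforce/blip-image-captioning-base") are assumptions, not something this commit confirms.

# Hypothetical sketch of utils.load_models -- not the repo's actual code.
import streamlit as st
from transformers import (
    CLIPModel, CLIPProcessor,
    BlipForConditionalGeneration, BlipProcessor,
)

@st.cache_resource(show_spinner=False)   # keep the models loaded across reruns
def load_models(device: str):
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    blip_model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    ).to(device).eval()
    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    return clip_model, clip_processor, blip_model, blip_processor

Using st.cache_resource (rather than st.cache_data) is the usual choice for model objects, since they should be shared rather than copied per rerun.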
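Likewise, the new call df, image_embeddings = load_data(parquet_path=parquet_path, max_rows=max_rows) and its "(cached)" spinner suggest a parquet loader that is cached and truncated to max_rows. This sketch assumes a per-row "embedding" column holding precomputed CLIP image vectors, uses st.cache_data, and implements the root-or-data/ fallback mentioned in the sidebar help text; none of these details are shown in the commit.

# Hypothetical sketch of utils.load_data -- column names are assumptions.
import os
import numpy as np
import pandas as pd
import streamlit as st
import torch

@st.cache_data(show_spinner=False)       # re-runs only when parquet_path or max_rows change
def load_data(parquet_path: str, max_rows: int = 5000):
    # Fall back to data/ if the file is not in the repo root.
    if not os.path.exists(parquet_path):
        parquet_path = os.path.join("data", parquet_path)
    df = pd.read_parquet(parquet_path).head(max_rows).reset_index(drop=True)
    # Stack the per-row embeddings into one (N, D) tensor and normalize for cosine search.
    embeddings = torch.tensor(np.stack(df["embedding"].to_numpy()), dtype=torch.float32)
    embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
    return df, embeddings

Because the cache key is (parquet_path, max_rows), changing either sidebar value triggers exactly one reload instead of one per interaction.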
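Finally, the search calls now pass top_k through to search_by_text / search_by_image so that st.columns(top_k) always matches the number of results. A hedged sketch of what the text branch could look like, assuming load_data returned L2-normalized embeddings and that the "image" column holds something st.image can render (bytes, a file path, or a PIL image):

# Hypothetical sketch of utils.search_by_text -- not the repo's actual code.
import torch

def search_by_text(query, clip_processor, clip_model, image_embeddings, df, top_k=5, device="cpu"):
    inputs = clip_processor(text=[query], return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_emb = clip_model.get_text_features(**inputs)
    text_emb = torch.nn.functional.normalize(text_emb, dim=-1).cpu()
    scores = image_embeddings @ text_emb.squeeze(0)      # cosine similarity against all rows
    top_idx = scores.topk(top_k).indices.tolist()
    # Return the dict keys the tabs consume: "image" and "label".
    return [{"image": df.iloc[i]["image"], "label": df.iloc[i]["label"]} for i in top_idx]

search_by_image would be the same except that it encodes the uploaded picture with clip_model.get_image_features on the processor's pixel_values instead of encoding text.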