import glob
import io
import os
import zipfile

import numpy as np
import pandas as pd
import streamlit as st
import torch
from PIL import Image
from sklearn.neighbors import NearestNeighbors
from torchvision import transforms
from transformers import (
    CLIPProcessor,
    CLIPModel,
    BlipProcessor,
    BlipForConditionalGeneration,
)

# CLIP's image normalization statistics.
MEAN = [0.48145466, 0.4578275, 0.40821073]
STD = [0.26862954, 0.26130258, 0.27577711]


# ---------------- Models ----------------
@st.cache_resource(show_spinner=False)
def load_models(device):
    """Load CLIP (retrieval) and BLIP (captioning) once per session."""
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    blip_model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    ).to(device)
    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    return clip_model, clip_processor, blip_model, blip_processor


# ---------------- Data ----------------
def _auto_find_parquet(path_hint: str):
    # Prefer the path that was passed in.
    if os.path.exists(path_hint):
        return path_hint
    # Next, try the data/ directory.
    candidate = os.path.join("data", os.path.basename(path_hint))
    if os.path.exists(candidate):
        return candidate
    # Finally, search the working tree for any parquet file and take the newest.
    files = sorted(glob.glob("**/*.parquet", recursive=True), key=os.path.getmtime, reverse=True)
    if files:
        return files[0]
    raise FileNotFoundError(
        f"Cannot find parquet file. Tried '{path_hint}', data/, and glob('**/*.parquet')"
    )


@st.cache_data(show_spinner=False)
def load_data(parquet_path="food101_embeddings_10000.parquet", image_zip="food_images_10000.zip"):
    # Extract the image zip (runs only once; later calls see the folder on disk).
    image_folder = "food_images_10000"
    if not os.path.exists(image_folder):
        with zipfile.ZipFile(image_zip, "r") as zip_ref:
            zip_ref.extractall(image_folder)
        print("Extraction complete:", image_folder)

    # Load the embeddings parquet, resolving the path via _auto_find_parquet.
    df = pd.read_parquet(_auto_find_parquet(parquet_path))

    # Build image_path if the parquet does not already provide one.
    if "image_path" not in df.columns:
        df["image_path"] = df["idx"].apply(lambda i: os.path.join(image_folder, f"{i:05d}.jpg"))

    embeddings = np.vstack(df["embedding"].to_numpy())
    return df, embeddings


# ---------------- Helpers ----------------
def _bytes_to_pil(b: bytes):
    return Image.open(io.BytesIO(b)).convert("RGB")


def _preprocess_image(image: Image.Image):
    # Mirror CLIP's preprocessing: resize, tensorize, normalize, add a batch dim.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD),
    ])
    return transform(image.convert("RGB")).unsqueeze(0)


def _row_to_image(row):
    if "image_bytes" in row and isinstance(row["image_bytes"], (bytes, bytearray)):
        return _bytes_to_pil(row["image_bytes"])
    elif "image_path" in row and os.path.exists(row["image_path"]):
        return Image.open(row["image_path"]).convert("RGB")
    else:
        # Fall back to a blank placeholder when the image cannot be found.
        return Image.new("RGB", (224, 224), color=(240, 240, 240))


# ---------------- Search ----------------
def search_by_text(text, processor, model, embeddings, df, top_k=5, device="cpu"):
    inputs = processor(text=[text], return_tensors="pt").to(device)
    with torch.no_grad():
        text_feat = model.get_text_features(**inputs).cpu().numpy()
    nn = NearestNeighbors(n_neighbors=top_k, metric="cosine").fit(embeddings)
    _, indices = nn.kneighbors(text_feat)
    out = []
    for i in indices[0]:
        row = df.iloc[i]
        img = _row_to_image(row)
        out.append({"label": row.get("label_name", str(row.get("label", ""))), "image": img})
    return out
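
# Note: search_by_text above and search_by_image below refit a NearestNeighbors
# index on every query. Because the metric is cosine, the same top-k ranking can
# be computed with one normalized matrix product. This is an optional sketch;
# the helper name `_top_k_cosine` is hypothetical and the search functions are
# left unchanged:
def _top_k_cosine(query_feat, embeddings, top_k=5):
    # Normalize rows, then rank by dot product, which equals cosine similarity.
    q = query_feat / np.linalg.norm(query_feat, axis=1, keepdims=True)
    e = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sims = e @ q[0]
    return np.argsort(-sims)[:top_k]
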
def search_by_image(uploaded_image, processor, model, embeddings, df, top_k=5, device="cpu"):
    image_tensor = _preprocess_image(uploaded_image).to(device)
    with torch.no_grad():
        img_feat = model.get_image_features(pixel_values=image_tensor).cpu().numpy()
    nn = NearestNeighbors(n_neighbors=top_k, metric="cosine").fit(embeddings)
    _, indices = nn.kneighbors(img_feat)
    out = []
    for i in indices[0]:
        row = df.iloc[i]
        img = _row_to_image(row)
        out.append({"label": row.get("label_name", str(row.get("label", ""))), "image": img})
    return out


def generate_caption(uploaded_image, processor, model, device="cpu"):
    image = uploaded_image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model.generate(**inputs)
    return processor.decode(output[0], skip_special_tokens=True)
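
# A minimal sketch of how the functions above could be wired into a Streamlit
# page. The title, widget labels, and two-mode layout are assumptions, not the
# original app's UI:
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model, clip_processor, blip_model, blip_processor = load_models(device)
    df, embeddings = load_data()

    st.title("Food-101 Search")
    query = st.text_input("Describe a dish")
    uploaded = st.file_uploader("...or upload a photo", type=["jpg", "jpeg", "png"])

    if uploaded is not None:
        # Image mode: caption the upload with BLIP, then retrieve with CLIP.
        image = Image.open(uploaded).convert("RGB")
        st.caption(generate_caption(image, blip_processor, blip_model, device))
        results = search_by_image(image, clip_processor, clip_model, embeddings, df, device=device)
    elif query:
        results = search_by_text(query, clip_processor, clip_model, embeddings, df, device=device)
    else:
        results = []

    for r in results:
        st.image(r["image"], caption=r["label"], width=224)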