import json
from pathlib import Path
from typing import List, Dict, Optional

import numpy as np
from sentence_transformers import SentenceTransformer

# ---------- Global embedder (loaded once, CPU-safe) ----------
_EMBEDDER: Optional[SentenceTransformer] = None


def _get_embedder() -> SentenceTransformer:
    global _EMBEDDER
    if _EMBEDDER is None:
        # Explicit device="cpu" avoids any device_map/meta init paths.
        # Use the canonical model id to avoid redirect surprises.
        _EMBEDDER = SentenceTransformer(
            "sentence-transformers/all-MiniLM-L6-v2",
            device="cpu",
        )
        # Optional: shorten for speed on Spaces; keep accuracy reasonable
        _EMBEDDER.max_seq_length = 256
    return _EMBEDDER


def load_index(env: Dict):
    import faiss

    index_path = Path(env["INDEX_DIR"]) / "faiss.index"
    meta_path = Path(env["INDEX_DIR"]) / "meta.json"
    if not index_path.exists():
        raise RuntimeError("Index not found. Run ingest first.")
    index = faiss.read_index(str(index_path))
    with open(meta_path, "r") as f:
        metas = json.load(f)
    return index, metas
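

# --- Illustrative only (not in the original file): a minimal ingest sketch. ---
# Assumes an inner-product FlatIP index over the unit-normalized vectors
# produced by embed() below, and that meta.json is a JSON list of dicts whose
# order matches the index rows, which is what search() relies on.
# The "text" key on each doc is a hypothetical field name for this sketch.
def build_index_sketch(docs: List[Dict], env: Dict) -> None:
    import faiss

    out_dir = Path(env["INDEX_DIR"])
    out_dir.mkdir(parents=True, exist_ok=True)
    vecs = embed([d["text"] for d in docs])   # (n, d) float32, unit norm
    index = faiss.IndexFlatIP(vecs.shape[1])  # inner product == cosine on unit vectors
    index.add(vecs)
    faiss.write_index(index, str(out_dir / "faiss.index"))
    metas = [{k: v for k, v in d.items() if k != "text"} for d in docs]
    with open(out_dir / "meta.json", "w") as f:
        json.dump(metas, f)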


def embed(texts: List[str]) -> np.ndarray:
    emb = _get_embedder()
    vecs = emb.encode(
        texts,
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=False,
        batch_size=32,
    )
    # FAISS expects float32
    if vecs.dtype != np.float32:
        vecs = vecs.astype(np.float32, copy=False)
    return vecs


def search(q: str, env: Dict, top_k: int = 15, filters: Optional[Dict] = None) -> List[Dict]:
    index, metas = load_index(env)
    qv = embed([q])  # shape (1, d) float32
    # Defensive: ensure index dim matches query dim
    if hasattr(index, "d") and index.d != qv.shape[1]:
        raise RuntimeError(
            f"FAISS index dim {index.d} != embedding dim {qv.shape[1]}"
        )
    scores, idxs = index.search(qv, top_k)  # both shaped (1, top_k)
    results = []
    f_geo = (filters or {}).get("geo")
    f_cats = (filters or {}).get("categories")
    for score, idx in zip(scores[0], idxs[0]):
        if idx == -1:  # FAISS pads with -1 when fewer than top_k hits exist
            continue
        m = dict(metas[idx])  # copy so we don't mutate the cached list
        if f_geo and m.get("geo") not in f_geo:
            continue
        if f_cats and not set(f_cats).intersection(m.get("categories", [])):
            continue
        m["score"] = float(score)
        results.append(m)
    return results
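

# --- Illustrative usage sketch: the path, query, and filter values below are
# assumptions for demonstration, not values taken from the original code. ---
if __name__ == "__main__":
    env = {"INDEX_DIR": "data/index"}  # hypothetical location of faiss.index / meta.json
    hits = search(
        "example query text",
        env,
        top_k=5,
        filters={"geo": ["US-CA"], "categories": ["energy"]},  # example filter values
    )
    for h in hits:
        print(f"{h['score']:.3f}", h)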