Spaces:
Sleeping
Sleeping
"""SearchService - Semantic document search functionality with optional caching.

Provides semantic search capabilities using embeddings and a vector similarity
database. Includes a small, bounded in-memory result cache to avoid repeated
embedding + vector DB work for identical queries (post expansion) with the same
parameters.
"""
| import logging | |
| from copy import deepcopy | |
| from typing import Any, Dict, List, Optional | |
| from src.embedding.embedding_service import EmbeddingService | |
| from src.search.query_expander import QueryExpander | |
| from src.vector_store.vector_db import VectorDatabase | |
| logger = logging.getLogger(__name__) | |
class SearchService:
    """Semantic search service for finding relevant documents using embeddings.

    Combines text embedding generation with vector similarity search to return
    semantically relevant chunks. A lightweight FIFO cache (default capacity 50)
    reduces duplicate work for popular queries.
    """

    def __init__(
        self,
        vector_db: Optional[VectorDatabase],
        embedding_service: Optional[EmbeddingService],
        enable_query_expansion: bool = True,
        cache_capacity: int = 50,
    ) -> None:
        """Initialize the search service.

        Args:
            vector_db: Vector similarity database. Required at runtime despite
                the Optional annotation; None raises ValueError.
            embedding_service: Text embedding provider. Required at runtime
                despite the Optional annotation; None raises ValueError.
            enable_query_expansion: When True, queries are expanded via
                QueryExpander before embedding.
            cache_capacity: Maximum number of cached result sets (clamped to >= 1).

        Raises:
            ValueError: If vector_db or embedding_service is None.
        """
        if vector_db is None:
            raise ValueError("vector_db cannot be None")
        if embedding_service is None:
            raise ValueError("embedding_service cannot be None")
        self.vector_db = vector_db
        self.embedding_service = embedding_service
        self.enable_query_expansion = enable_query_expansion

        # Query expansion
        if self.enable_query_expansion:
            self.query_expander = QueryExpander()
            logger.info("SearchService initialized with query expansion enabled")
        else:
            self.query_expander = None
            logger.info("SearchService initialized without query expansion")

        # Cache internals. Python dicts preserve insertion order (guaranteed
        # since 3.7), so the cache dict itself doubles as the FIFO queue:
        # next(iter(cache)) yields the oldest key in O(1). This replaces the
        # previous parallel order list, whose pop(0) cost O(n) per eviction.
        self._cache_capacity = max(1, cache_capacity)
        self._result_cache: Dict[str, List[Dict[str, Any]]] = {}
        self._cache_hits = 0
        self._cache_misses = 0

    # ---------------------- Public API ----------------------

    def search(self, query: str, top_k: int = 5, threshold: float = 0.0) -> List[Dict[str, Any]]:
        """Perform semantic search.

        Args:
            query: Raw user query.
            top_k: Number of results to return (>0).
            threshold: Minimum similarity (0-1).

        Returns:
            List of formatted result dictionaries (see _format_search_results).

        Raises:
            ValueError: If query is empty/whitespace, top_k is not positive,
                or threshold is outside [0, 1].
        """
        if not query or not query.strip():
            raise ValueError("Query cannot be empty")
        if top_k <= 0:
            raise ValueError("top_k must be positive")
        if not (0.0 <= threshold <= 1.0):
            raise ValueError("threshold must be between 0 and 1")

        processed_query = query.strip()
        if self.enable_query_expansion and self.query_expander:
            expanded_query = self.query_expander.expand_query(processed_query)
            logger.debug(
                "Query expanded from '%s' to '%s'",
                processed_query,
                expanded_query[:120],
            )
            processed_query = expanded_query

        cache_key = self._make_cache_key(processed_query, top_k, threshold)
        # Single dict lookup (was `in` + `[]`). Cached values are always lists,
        # never None, so the None sentinel is unambiguous.
        cached = self._result_cache.get(cache_key)
        if cached is not None:
            self._cache_hits += 1
            logger.debug(
                "Search cache HIT key=%s hits=%d misses=%d size=%d",
                cache_key,
                self._cache_hits,
                self._cache_misses,
                len(self._result_cache),
            )
            # Defensive copy so callers cannot mutate the cached entry.
            return deepcopy(cached)

        # Cache miss: perform embedding + vector search
        try:
            query_embedding = self.embedding_service.embed_text(processed_query)
            raw_results = self.vector_db.search(query_embedding=query_embedding, top_k=top_k)
            formatted = self._format_search_results(raw_results, threshold)
        except Exception as e:  # pragma: no cover - propagate after logging
            logger.error("Search failed for query '%s': %s", query, e)
            raise

        # Store in cache (FIFO eviction). Failed searches are neither cached
        # nor counted as misses, so stats reflect successful work only.
        self._cache_misses += 1
        self._result_cache[cache_key] = deepcopy(formatted)
        if len(self._result_cache) > self._cache_capacity:
            # Oldest insertion is the first key in iteration order.
            oldest = next(iter(self._result_cache))
            self._result_cache.pop(oldest, None)
        logger.debug(
            "Search cache MISS key=%s hits=%d misses=%d size=%d",
            cache_key,
            self._cache_hits,
            self._cache_misses,
            len(self._result_cache),
        )
        logger.info("Search completed: %d results returned", len(formatted))
        return formatted

    def get_cache_stats(self) -> Dict[str, Any]:
        """Return cache statistics for monitoring and tests."""
        return {
            "hits": self._cache_hits,
            "misses": self._cache_misses,
            "size": len(self._result_cache),
            "capacity": self._cache_capacity,
        }

    # ---------------------- Internal Helpers ----------------------

    def _make_cache_key(self, processed_query: str, top_k: int, threshold: float) -> str:
        # Lower-casing makes caching case-insensitive: "Foo" and "foo" share an
        # entry even if their embeddings would differ. Threshold is rendered to
        # 3 decimals, so thresholds differing only below 0.001 share a key.
        return f"{processed_query.lower()}|{top_k}|{threshold:.3f}"

    def _format_search_results(self, raw_results: List[Dict[str, Any]], threshold: float) -> List[Dict[str, Any]]:
        """Convert raw vector DB results into standardized output filtered by threshold.

        Distances are min-max normalized within this result batch and inverted
        into similarity scores clamped to [0, 1]; entries whose similarity falls
        below `threshold` are dropped.
        """
        if not raw_results:
            return []
        # Missing distances default to +inf so they normalize to similarity 0.
        distances = [r.get("distance", float("inf")) for r in raw_results]
        # raw_results is non-empty here, so min/max are always well-defined
        # (the previous `if distances else` fallbacks were dead code).
        min_distance = min(distances)
        max_distance = max(distances)
        formatted: List[Dict[str, Any]] = []
        for r in raw_results:
            distance = r.get("distance", float("inf"))
            if max_distance > min_distance:
                normalized = (distance - min_distance) / (max_distance - min_distance)
                similarity = 1.0 - normalized
            else:
                # All distances identical: every hit is a "perfect" match.
                similarity = 1.0 if distance == min_distance else 0.0
            similarity = max(0.0, min(1.0, similarity))
            if similarity >= threshold:
                formatted.append(
                    {
                        "chunk_id": r.get("id", ""),
                        "content": r.get("document", ""),
                        "similarity_score": similarity,
                        "distance": distance,
                        "metadata": r.get("metadata", {}),
                    }
                )
        logger.debug(
            "Formatted %d results above threshold %.2f " "(distance range %.2f - %.2f)",
            len(formatted),
            threshold,
            min_distance,
            max_distance,
        )
        return formatted