"""SearchService - Semantic document search functionality with optional caching.

Provides semantic search capabilities using embeddings and a vector similarity
database. Includes a small, bounded in-memory result cache to avoid repeated
embedding + vector DB work for identical queries (post expansion) with the same
parameters.
"""

import logging
import math
from collections import deque
from copy import deepcopy
from typing import Any, Deque, Dict, List, Optional

from src.embedding.embedding_service import EmbeddingService
from src.search.query_expander import QueryExpander
from src.vector_store.vector_db import VectorDatabase

logger = logging.getLogger(__name__)


class SearchService:
    """Semantic search service for finding relevant documents using embeddings.

    Combines text embedding generation with vector similarity search to return
    semantically relevant chunks. A lightweight FIFO cache (default capacity 50)
    reduces duplicate work for popular queries.
    """

    def __init__(
        self,
        vector_db: Optional[VectorDatabase],
        embedding_service: Optional[EmbeddingService],
        enable_query_expansion: bool = True,
        cache_capacity: int = 50,
    ) -> None:
        """Create the search service.

        Args:
            vector_db: Vector store queried via ``search(query_embedding=..., top_k=...)``.
                Required.
            embedding_service: Service whose ``embed_text`` turns query text into
                an embedding. Required.
            enable_query_expansion: When True, queries are passed through a
                ``QueryExpander`` before embedding.
            cache_capacity: Maximum number of cached result lists. Values below 1
                are clamped to 1.

        Raises:
            ValueError: If ``vector_db`` or ``embedding_service`` is None.
        """
        if vector_db is None:
            raise ValueError("vector_db cannot be None")
        if embedding_service is None:
            raise ValueError("embedding_service cannot be None")

        self.vector_db = vector_db
        self.embedding_service = embedding_service
        self.enable_query_expansion = enable_query_expansion

        # Query expansion
        if self.enable_query_expansion:
            self.query_expander = QueryExpander()
            logger.info("SearchService initialized with query expansion enabled")
        else:
            self.query_expander = None
            logger.info("SearchService initialized without query expansion")

        # Cache internals. The dict holds the results. The deque records
        # insertion order, so the oldest entry can be evicted in O(1).
        # This is FIFO, not LRU: a cache hit does not refresh an entry's position.
        self._cache_capacity = max(1, cache_capacity)
        self._result_cache: Dict[str, List[Dict[str, Any]]] = {}
        self._result_cache_order: Deque[str] = deque()
        self._cache_hits = 0
        self._cache_misses = 0

    # ---------------------- Public API ----------------------
    def search(self, query: str, top_k: int = 5, threshold: float = 0.0) -> List[Dict[str, Any]]:
        """Perform semantic search.

        Args:
            query: Raw user query.
            top_k: Number of results to request from the vector DB (>0).
            threshold: Minimum similarity (0-1). Similarity is min-max normalized
                within the returned result set (best hit = 1.0, worst = 0.0), so
                it is a relative score, not an absolute one.

        Returns:
            List of formatted result dictionaries. It may hold fewer than
            ``top_k`` entries after threshold filtering. Cached results are
            returned as deep copies, so callers may mutate them freely.

        Raises:
            ValueError: If ``query`` is blank, ``top_k`` is not positive, or
                ``threshold`` is outside [0, 1].
            Exception: Any error from embedding or the vector DB is logged and
                re-raised unchanged.
        """
        if not query or not query.strip():
            raise ValueError("Query cannot be empty")
        if top_k <= 0:
            raise ValueError("top_k must be positive")
        if not (0.0 <= threshold <= 1.0):
            raise ValueError("threshold must be between 0 and 1")

        processed_query = query.strip()
        if self.enable_query_expansion and self.query_expander:
            expanded_query = self.query_expander.expand_query(processed_query)
            logger.debug(
                "Query expanded from '%s' to '%s'",
                processed_query,
                expanded_query[:120],
            )
            processed_query = expanded_query

        # The key is built after expansion, so expansion runs even on cache hits.
        cache_key = self._make_cache_key(processed_query, top_k, threshold)
        if cache_key in self._result_cache:
            self._cache_hits += 1
            cached = self._result_cache[cache_key]
            logger.debug(
                "Search cache HIT key=%s hits=%d misses=%d size=%d",
                cache_key,
                self._cache_hits,
                self._cache_misses,
                len(self._result_cache_order),
            )
            return deepcopy(cached)  # defensive copy

        # Cache miss: perform embedding + vector search
        try:
            query_embedding = self.embedding_service.embed_text(processed_query)
            raw_results = self.vector_db.search(query_embedding=query_embedding, top_k=top_k)
            formatted = self._format_search_results(raw_results, threshold)
        except Exception as e:  # pragma: no cover - propagate after logging
            logger.error("Search failed for query '%s': %s", query, e)
            raise

        # Store a private copy so that callers mutating the returned list
        # cannot corrupt the cache. Evict the oldest entry once over capacity.
        self._cache_misses += 1
        self._result_cache[cache_key] = deepcopy(formatted)
        self._result_cache_order.append(cache_key)
        if len(self._result_cache_order) > self._cache_capacity:
            oldest = self._result_cache_order.popleft()
            self._result_cache.pop(oldest, None)

        logger.debug(
            "Search cache MISS key=%s hits=%d misses=%d size=%d",
            cache_key,
            self._cache_hits,
            self._cache_misses,
            len(self._result_cache_order),
        )

        logger.info("Search completed: %d results returned", len(formatted))
        return formatted

    def get_cache_stats(self) -> Dict[str, Any]:
        """Return cache statistics for monitoring and tests."""
        return {
            "hits": self._cache_hits,
            "misses": self._cache_misses,
            "size": len(self._result_cache_order),
            "capacity": self._cache_capacity,
        }

    def clear_cache(self) -> None:
        """Drop all cached results; hit/miss counters are kept.

        The cache is never invalidated automatically. Call this after the
        underlying vector store changes, so that stale results are not served.
        """
        self._result_cache.clear()
        self._result_cache_order.clear()

    # ---------------------- Internal Helpers ----------------------
    def _make_cache_key(self, processed_query: str, top_k: int, threshold: float) -> str:
        """Build the cache key for a (post-expansion) query and its parameters.

        The threshold is encoded with ``repr`` of its float value, so distinct
        thresholds never share a key. Fixed-precision rounding would let, for
        example, 0.1231 and 0.1234 collide, and each would then be served the
        other's differently-filtered results. ``top_k`` and the float repr
        never contain ``|``, so the key stays unambiguous even if the query does.
        """
        # NOTE(review): the query is lower-cased for the key, but the embedding
        # is computed from the original-case text. Queries differing only in
        # case therefore share a cached result. Confirm this is intended.
        return f"{processed_query.lower()}|{top_k}|{float(threshold)!r}"

    def _format_search_results(self, raw_results: List[Dict[str, Any]], threshold: float) -> List[Dict[str, Any]]:
        """Convert raw vector DB results into standardized output filtered by threshold.

        Similarity is a min-max normalization of distance over the finite
        distances in ``raw_results``. The closest result gets 1.0 and the
        farthest gets 0.0. If all finite distances are equal, every such result
        gets 1.0. A result whose ``distance`` is missing, None, or non-finite
        gets similarity 0.0 and does not distort the range of the others.
        Input order is preserved.

        Args:
            raw_results: Dicts read via the keys ``id``, ``document``,
                ``distance`` and ``metadata`` (all optional).
            threshold: Minimum normalized similarity to keep a result.

        Returns:
            Result dicts with ``chunk_id``, ``content``, ``similarity_score``,
            ``distance`` and ``metadata``.
        """
        if not raw_results:
            return []

        distances: List[float] = []
        for r in raw_results:
            distance = r.get("distance")
            distances.append(float("inf") if distance is None else distance)

        # Normalize only over finite distances: a single inf/NaN would otherwise
        # make the range infinite, giving every result similarity 1.0.
        finite = [d for d in distances if math.isfinite(d)]
        min_distance = min(finite) if finite else 0.0
        max_distance = max(finite) if finite else 0.0
        span = max_distance - min_distance

        formatted: List[Dict[str, Any]] = []
        for r, distance in zip(raw_results, distances):
            if not math.isfinite(distance):
                similarity = 0.0
            elif span > 0:
                similarity = 1.0 - (distance - min_distance) / span
            else:
                similarity = 1.0
            similarity = max(0.0, min(1.0, similarity))

            if similarity >= threshold:
                formatted.append(
                    {
                        "chunk_id": r.get("id", ""),
                        "content": r.get("document", ""),
                        "similarity_score": similarity,
                        "distance": distance,
                        "metadata": r.get("metadata", {}),
                    }
                )

        logger.debug(
            "Formatted %d results above threshold %.2f "
            "(distance range %.2f - %.2f)",
            len(formatted),
            threshold,
            min_distance,
            max_distance,
        )
        return formatted