# src/search/search_service.py
"""SearchService - Semantic document search functionality with optional caching.
Provides semantic search capabilities using embeddings and a vector similarity
database. Includes a small, bounded in-memory result cache to avoid repeated
embedding + vector DB work for identical queries (post expansion) with the same
parameters.
"""
import logging
from copy import deepcopy
from typing import Any, Dict, List, Optional

from src.embedding.embedding_service import EmbeddingService
from src.search.query_expander import QueryExpander
from src.vector_store.vector_db import VectorDatabase

logger = logging.getLogger(__name__)


class SearchService:
    """Semantic search service for finding relevant documents using embeddings.

    Combines text embedding generation with vector similarity search to return
    semantically relevant chunks. A lightweight FIFO cache (default capacity 50)
    reduces duplicate work for popular queries.
    """
def __init__(
self,
vector_db: Optional[VectorDatabase],
embedding_service: Optional[EmbeddingService],
enable_query_expansion: bool = True,
cache_capacity: int = 50,
) -> None:
if vector_db is None:
raise ValueError("vector_db cannot be None")
if embedding_service is None:
raise ValueError("embedding_service cannot be None")
self.vector_db = vector_db
self.embedding_service = embedding_service
self.enable_query_expansion = enable_query_expansion
# Query expansion
if self.enable_query_expansion:
self.query_expander = QueryExpander()
logger.info("SearchService initialized with query expansion enabled")
else:
self.query_expander = None
logger.info("SearchService initialized without query expansion")
# Cache internals
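        # A dict holds cached results keyed by (query, top_k, threshold); a
        # parallel list records insertion order so the oldest key is evicted
        # first (FIFO) once capacity is exceeded.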
self._cache_capacity = max(1, cache_capacity)
self._result_cache: Dict[str, List[Dict[str, Any]]] = {}
self._result_cache_order: List[str] = []
self._cache_hits = 0
self._cache_misses = 0

    # ---------------------- Public API ----------------------
def search(self, query: str, top_k: int = 5, threshold: float = 0.0) -> List[Dict[str, Any]]:
"""Perform semantic search.
Args:
query: Raw user query.
top_k: Number of results to return (>0).
threshold: Minimum similarity (0-1).
Returns:
List of formatted result dictionaries.
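
        Example (illustrative; assumes the service is wired to a populated
        vector DB and a working embedding service):
            >>> results = service.search("vector similarity", top_k=3)  # doctest: +SKIP
            >>> sorted(results[0].keys())  # doctest: +SKIP
            ['chunk_id', 'content', 'distance', 'metadata', 'similarity_score']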
"""
if not query or not query.strip():
raise ValueError("Query cannot be empty")
if top_k <= 0:
raise ValueError("top_k must be positive")
if not (0.0 <= threshold <= 1.0):
raise ValueError("threshold must be between 0 and 1")
processed_query = query.strip()
if self.enable_query_expansion and self.query_expander:
expanded_query = self.query_expander.expand_query(processed_query)
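            # What expansion produces depends on QueryExpander; illustratively,
            # "ml" might become "ml machine learning" (hypothetical, not the
            # actual expansion rules).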
logger.debug(
"Query expanded from '%s' to '%s'",
processed_query,
expanded_query[:120],
)
processed_query = expanded_query
cache_key = self._make_cache_key(processed_query, top_k, threshold)
if cache_key in self._result_cache:
self._cache_hits += 1
cached = self._result_cache[cache_key]
logger.debug(
"Search cache HIT key=%s hits=%d misses=%d size=%d",
cache_key,
self._cache_hits,
self._cache_misses,
len(self._result_cache_order),
)
return deepcopy(cached) # defensive copy
# Cache miss: perform embedding + vector search
try:
query_embedding = self.embedding_service.embed_text(processed_query)
raw_results = self.vector_db.search(query_embedding=query_embedding, top_k=top_k)
formatted = self._format_search_results(raw_results, threshold)
except Exception as e: # pragma: no cover - propagate after logging
logger.error("Search failed for query '%s': %s", query, e)
raise
# Store in cache (FIFO eviction)
self._cache_misses += 1
self._result_cache[cache_key] = deepcopy(formatted)
self._result_cache_order.append(cache_key)
if len(self._result_cache_order) > self._cache_capacity:
oldest = self._result_cache_order.pop(0)
self._result_cache.pop(oldest, None)
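        # FIFO illustration: with capacity 2, caching keys A, B and then C
        # evicts A (the oldest), regardless of how often A was read.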
logger.debug(
"Search cache MISS key=%s hits=%d misses=%d size=%d",
cache_key,
self._cache_hits,
self._cache_misses,
len(self._result_cache_order),
)
logger.info("Search completed: %d results returned", len(formatted))
return formatted

    def get_cache_stats(self) -> Dict[str, Any]:
"""Return cache statistics for monitoring and tests."""
return {
"hits": self._cache_hits,
"misses": self._cache_misses,
"size": len(self._result_cache_order),
"capacity": self._cache_capacity,
}

    # ---------------------- Internal Helpers ----------------------
def _make_cache_key(self, processed_query: str, top_k: int, threshold: float) -> str:
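        # Example: _make_cache_key("What is RAG?", 5, 0.25) -> "what is rag?|5|0.250"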
return f"{processed_query.lower()}|{top_k}|{threshold:.3f}"

    def _format_search_results(
        self, raw_results: List[Dict[str, Any]], threshold: float
    ) -> List[Dict[str, Any]]:
"""Convert raw vector DB results into standardized output filtered by threshold."""
if not raw_results:
return []
distances = [r.get("distance", float("inf")) for r in raw_results]
min_distance = min(distances) if distances else 0.0
max_distance = max(distances) if distances else 1.0
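        # Min-max normalization maps the smallest distance to similarity 1.0 and
        # the largest to 0.0; e.g. distances [0.2, 0.5, 0.8] yield similarities
        # [1.0, 0.5, 0.0].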
formatted: List[Dict[str, Any]] = []
for r in raw_results:
distance = r.get("distance", float("inf"))
if max_distance > min_distance:
normalized = (distance - min_distance) / (max_distance - min_distance)
similarity = 1.0 - normalized
else:
similarity = 1.0 if distance == min_distance else 0.0
similarity = max(0.0, min(1.0, similarity))
if similarity >= threshold:
formatted.append(
{
"chunk_id": r.get("id", ""),
"content": r.get("document", ""),
"similarity_score": similarity,
"distance": distance,
"metadata": r.get("metadata", {}),
}
)
logger.debug(
"Formatted %d results above threshold %.2f " "(distance range %.2f - %.2f)",
len(formatted),
threshold,
min_distance,
max_distance,
)
return formatted
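

# ---------------------- Usage sketch (illustrative) ----------------------
# A minimal wiring example, kept as comments because the constructor arguments
# for VectorDatabase and EmbeddingService below are hypothetical; the real
# signatures live in their respective modules.
#
#     db = VectorDatabase()            # hypothetical construction
#     embedder = EmbeddingService()    # hypothetical construction
#     service = SearchService(db, embedder, cache_capacity=100)
#     hits = service.search("semantic search", top_k=3, threshold=0.2)
#     print(service.get_cache_stats())  # e.g. {'hits': 0, 'misses': 1, 'size': 1, 'capacity': 100}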