import logging import os from pathlib import Path from typing import Any, Dict, List, Optional import chromadb from src.config import VECTOR_STORAGE_TYPE from src.utils.memory_utils import log_memory_checkpoint, memory_monitor from src.vector_db.postgres_adapter import PostgresVectorAdapter def create_vector_database(persist_path: Optional[str] = None, collection_name: Optional[str] = None): """ Factory function to create the appropriate vector database implementation. Args: persist_path: Path for persistence (used by ChromaDB) collection_name: Name of the collection Returns: Vector database implementation """ # Allow runtime override via environment variable to make tests and # deploy-time configuration consistent. Prefer explicit env var when set. storage_type = os.getenv("VECTOR_STORAGE_TYPE") or VECTOR_STORAGE_TYPE if storage_type == "postgres": return PostgresVectorAdapter(table_name=collection_name or "document_embeddings") else: # Default to ChromaDB from src.config import COLLECTION_NAME, VECTOR_DB_PERSIST_PATH return VectorDatabase( persist_path=persist_path or VECTOR_DB_PERSIST_PATH, collection_name=collection_name or COLLECTION_NAME, ) class VectorDatabase: """ChromaDB integration for vector storage and similarity search""" def __init__( self, persist_path: str, collection_name: str, ): """ Initialize the vector database Args: persist_path: Path to persist the database collection_name: Name of the collection to use """ self.persist_path = persist_path self.collection_name = collection_name # Ensure persist directory exists Path(persist_path).mkdir(parents=True, exist_ok=True) # Get chroma settings from config for memory optimization from chromadb.config import Settings from src.config import CHROMA_SETTINGS # Convert CHROMA_SETTINGS dict to Settings object chroma_settings = Settings(**CHROMA_SETTINGS) # Initialize ChromaDB client with persistence and memory optimization log_memory_checkpoint("vector_db_before_client_init") self.client = chromadb.PersistentClient(path=persist_path, settings=chroma_settings) log_memory_checkpoint("vector_db_after_client_init") # Get or create collection try: self.collection = self.client.get_collection(name=collection_name) except ValueError: # Collection doesn't exist, create it self.collection = self.client.create_collection(name=collection_name) logging.info(f"Initialized VectorDatabase with collection " f"'{collection_name}' at '{persist_path}'") def get_collection(self): """Get the ChromaDB collection""" return self.collection @memory_monitor def add_embeddings_batch( self, batch_embeddings: List[List[List[float]]], batch_chunk_ids: List[List[str]], batch_documents: List[List[str]], batch_metadatas: List[List[Dict[str, Any]]], ) -> int: """ Add embeddings in batches to prevent memory issues with large datasets Args: batch_embeddings: List of embedding batches batch_chunk_ids: List of chunk ID batches batch_documents: List of document batches batch_metadatas: List of metadata batches Returns: Number of embeddings added """ total_added = 0 for i, (embeddings, chunk_ids, documents, metadatas) in enumerate( zip( batch_embeddings, batch_chunk_ids, batch_documents, batch_metadatas, ) ): log_memory_checkpoint(f"before_add_batch_{i}") # add_embeddings may return True on success (or raise on failure) added = self.add_embeddings( embeddings=embeddings, chunk_ids=chunk_ids, documents=documents, metadatas=metadatas, ) # If add_embeddings returns True, treat as all embeddings added if isinstance(added, bool) and added: added_count = len(embeddings) elif isinstance(added, int): added_count = int(added) else: added_count = 0 total_added += added_count logging.info(f"Added batch {i+1}/{len(batch_embeddings)}") # Force cleanup after each batch import gc gc.collect() log_memory_checkpoint(f"after_add_batch_{i}") return total_added @memory_monitor def add_embeddings( self, embeddings: List[List[float]], chunk_ids: List[str], documents: List[str], metadatas: List[Dict[str, Any]], ) -> int: """ Add embeddings to the collection Args: embeddings: List of embedding vectors chunk_ids: List of chunk IDs documents: List of document texts metadatas: List of metadata dictionaries Returns: Number of embeddings added """ # Validate input lengths n = len(embeddings) if not (len(chunk_ids) == n and len(documents) == n and len(metadatas) == n): raise ValueError(f"Number of embeddings {n} must match number of ids {len(chunk_ids)}") log_memory_checkpoint("before_add_embeddings") try: self.collection.add( embeddings=embeddings, documents=documents, metadatas=metadatas, ids=chunk_ids, ) log_memory_checkpoint("after_add_embeddings") logging.info(f"Added {n} embeddings to collection") # Return boolean True for API compatibility tests return True except Exception as e: logging.error(f"Failed to add embeddings: {e}") # Re-raise to allow callers/tests to handle failures explicitly raise @memory_monitor def search(self, query_embedding: List[float], top_k: int = 5) -> List[Dict[str, Any]]: """ Search for similar embeddings Args: query_embedding: Query vector to search for top_k: Number of results to return Returns: List of search results with metadata """ try: # Handle empty collection if self.get_count() == 0: return [] # Perform similarity search log_memory_checkpoint("vector_db_before_query") results = self.collection.query( query_embeddings=[query_embedding], n_results=min(top_k, self.get_count()), ) log_memory_checkpoint("vector_db_after_query") # Format results formatted_results = [] if results["ids"] and len(results["ids"][0]) > 0: for i in range(len(results["ids"][0])): result = { "id": results["ids"][0][i], "document": results["documents"][0][i], "metadata": results["metadatas"][0][i], "distance": results["distances"][0][i], } formatted_results.append(result) logging.info(f"Search returned {len(formatted_results)} results") return formatted_results except Exception as e: logging.error(f"Search failed: {e}") return [] def get_count(self) -> int: """Get the number of embeddings in the collection""" try: return self.collection.count() except Exception as e: logging.error(f"Failed to get count: {e}") return 0 def delete_collection(self) -> bool: """Delete the collection""" try: self.client.delete_collection(name=self.collection_name) logging.info(f"Deleted collection '{self.collection_name}'") return True except Exception as e: logging.error(f"Failed to delete collection: {e}") return False def reset_collection(self) -> bool: """Reset the collection (delete and recreate)""" try: # Delete existing collection try: self.client.delete_collection(name=self.collection_name) except ValueError: # Collection doesn't exist, that's fine pass # Create new collection self.collection = self.client.create_collection(name=self.collection_name) logging.info(f"Reset collection '{self.collection_name}'") return True except Exception as e: logging.error(f"Failed to reset collection: {e}") return False def get_embedding_dimension(self) -> int: """ Get the embedding dimension from existing data in the collection. Returns 0 if collection is empty or has no embeddings. """ try: count = self.get_count() if count == 0: return 0 # Retrieve one record to check its embedding dimension record = self.collection.get( ids=None, # None returns all records, but we only need one include=["embeddings"], limit=1, ) if record and "embeddings" in record and record["embeddings"]: return len(record["embeddings"][0]) return 0 except Exception as e: logging.error(f"Failed to get embedding dimension: {e}") return 0 def has_valid_embeddings(self, expected_dimension: int) -> bool: """ Check if the collection has embeddings with the expected dimension. Args: expected_dimension: The expected embedding dimension Returns: True if collection has embeddings with correct dimension, False otherwise """ try: actual_dimension = self.get_embedding_dimension() return actual_dimension == expected_dimension and actual_dimension > 0 except Exception as e: logging.error(f"Failed to validate embeddings: {e}") return False