import logging
from pathlib import Path
from typing import Any, Dict, List

import chromadb

# Module-level logger (preferred over the root logger) so callers can
# configure this module's output independently.
logger = logging.getLogger(__name__)


class VectorDatabase:
    """ChromaDB integration for vector storage and similarity search."""

    def __init__(self, persist_path: str, collection_name: str):
        """
        Initialize the vector database.

        Args:
            persist_path: Path to persist the database
            collection_name: Name of the collection to use
        """
        self.persist_path = persist_path
        self.collection_name = collection_name

        # Ensure persist directory exists
        Path(persist_path).mkdir(parents=True, exist_ok=True)

        # Initialize ChromaDB client with persistence
        self.client = chromadb.PersistentClient(path=persist_path)

        # Use get_or_create_collection instead of a get/except-ValueError
        # fallback: the exception type raised for a missing collection varies
        # across chromadb versions (ValueError in older releases, a dedicated
        # NotFoundError in newer ones), so the except clause was version-fragile.
        self.collection = self.client.get_or_create_collection(name=collection_name)

        logger.info(
            "Initialized VectorDatabase with collection '%s' at '%s'",
            collection_name,
            persist_path,
        )

    def get_collection(self):
        """Get the ChromaDB collection."""
        return self.collection

    def add_embeddings(
        self,
        embeddings: List[List[float]],
        chunk_ids: List[str],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ) -> bool:
        """
        Add embeddings to the vector database, skipping IDs already present.

        Args:
            embeddings: List of embedding vectors
            chunk_ids: List of unique chunk IDs
            documents: List of document contents
            metadatas: List of metadata dictionaries

        Returns:
            True if successful (including when every ID already existed).

        Raises:
            ValueError: If the input lists have mismatched lengths.
            Exception: Re-raises any error from the underlying ChromaDB add.
        """
        try:
            # Validate input lengths match
            if not (
                len(embeddings) == len(chunk_ids) == len(documents) == len(metadatas)
            ):
                raise ValueError("All input lists must have the same length")

            # Check for existing documents to prevent duplicates.
            # include=[] fetches only IDs, not payloads.
            try:
                existing = self.collection.get(ids=chunk_ids, include=[])
                existing_ids = set(existing.get("ids", []))
            except Exception:
                # Best-effort: if the lookup fails, fall back to adding all.
                existing_ids = set()

            # Only add documents that don't already exist
            new_embeddings = []
            new_chunk_ids = []
            new_documents = []
            new_metadatas = []

            for i, chunk_id in enumerate(chunk_ids):
                if chunk_id not in existing_ids:
                    new_embeddings.append(embeddings[i])
                    new_chunk_ids.append(chunk_id)
                    new_documents.append(documents[i])
                    new_metadatas.append(metadatas[i])

            if not new_embeddings:
                logger.info(
                    "All %d documents already exist in collection", len(chunk_ids)
                )
                return True

            # Add to ChromaDB collection
            self.collection.add(
                embeddings=new_embeddings,
                documents=new_documents,
                metadatas=new_metadatas,
                ids=new_chunk_ids,
            )

            logger.info(
                "Added %d new embeddings to collection '%s' (skipped %d duplicates)",
                len(new_embeddings),
                self.collection_name,
                len(chunk_ids) - len(new_embeddings),
            )
            return True

        except Exception as e:
            logger.error(f"Failed to add embeddings: {e}")
            # Bare raise preserves the original traceback (raise e rewrites it).
            raise

    def search(
        self, query_embedding: List[float], top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Search for similar embeddings.

        Args:
            query_embedding: Query vector to search for
            top_k: Number of results to return

        Returns:
            List of result dicts with keys "id", "document", "metadata",
            and "distance"; empty list on error or empty collection.
        """
        try:
            # Hoist the count: get_count() hits the collection, so call once
            # instead of once for the empty check and again for n_results.
            count = self.get_count()
            if count == 0:
                return []

            # Perform similarity search; ChromaDB errors if n_results
            # exceeds the collection size, hence the min().
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=min(top_k, count),
            )

            # Flatten ChromaDB's nested (per-query) result lists; we issued a
            # single query so only index [0] is populated.
            formatted_results = []
            if results["ids"] and len(results["ids"][0]) > 0:
                for i in range(len(results["ids"][0])):
                    result = {
                        "id": results["ids"][0][i],
                        "document": results["documents"][0][i],
                        "metadata": results["metadatas"][0][i],
                        "distance": results["distances"][0][i],
                    }
                    formatted_results.append(result)

            logger.info("Search returned %d results", len(formatted_results))
            return formatted_results

        except Exception as e:
            logger.error(f"Search failed: {e}")
            return []

    def get_count(self) -> int:
        """Get the number of embeddings in the collection (0 on error)."""
        try:
            return self.collection.count()
        except Exception as e:
            logger.error(f"Failed to get count: {e}")
            return 0

    def delete_collection(self) -> bool:
        """Delete the collection. Returns True on success, False on error."""
        try:
            self.client.delete_collection(name=self.collection_name)
            logger.info(f"Deleted collection '{self.collection_name}'")
            return True
        except Exception as e:
            logger.error(f"Failed to delete collection: {e}")
            return False

    def reset_collection(self) -> bool:
        """
        Reset the collection (delete and recreate).

        Returns:
            True on success, False on error.
        """
        try:
            # Best-effort delete: a missing collection is fine. The exception
            # type for "not found" varies across chromadb versions (ValueError
            # vs NotFoundError), so catching only ValueError missed the newer
            # type and made reset fail on a fresh database.
            try:
                self.client.delete_collection(name=self.collection_name)
            except Exception:
                # Collection doesn't exist, that's fine
                pass

            # Create new collection
            self.collection = self.client.create_collection(name=self.collection_name)
            logger.info(f"Reset collection '{self.collection_name}'")
            return True
        except Exception as e:
            logger.error(f"Failed to reset collection: {e}")
            return False

    def get_embedding_dimension(self) -> int:
        """
        Get the embedding dimension from existing data in the collection.

        Returns:
            The dimension of the stored embeddings, or 0 if the collection
            is empty, has no embeddings, or an error occurs.
        """
        try:
            count = self.get_count()
            if count == 0:
                return 0

            # Retrieve a single record's embedding to measure its length.
            record = self.collection.get(
                include=["embeddings"],
                limit=1,
            )

            # Newer chromadb returns embeddings as NumPy arrays, whose truth
            # value is ambiguous ("if record['embeddings']" raises ValueError)
            # — test with "is not None" and len() instead.
            embs = record.get("embeddings") if record else None
            if embs is not None and len(embs) > 0:
                return len(embs[0])
            return 0

        except Exception as e:
            logger.error(f"Failed to get embedding dimension: {e}")
            return 0

    def has_valid_embeddings(self, expected_dimension: int) -> bool:
        """
        Check if the collection has embeddings with the expected dimension.

        Args:
            expected_dimension: The expected embedding dimension

        Returns:
            True if collection has embeddings with correct dimension,
            False otherwise
        """
        try:
            actual_dimension = self.get_embedding_dimension()
            # A dimension of 0 means "no embeddings", which is never valid.
            return actual_dimension == expected_dimension and actual_dimension > 0
        except Exception as e:
            logger.error(f"Failed to validate embeddings: {e}")
            return False