Spaces:
Sleeping
Sleeping
| import logging | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
| import chromadb | |
class VectorDatabase:
    """ChromaDB integration for vector storage and similarity search."""

    def __init__(self, persist_path: str, collection_name: str):
        """
        Initialize the vector database.

        Creates the persistence directory if needed, opens a persistent
        ChromaDB client on it and binds ``self.collection`` to the named
        collection, creating the collection when it does not exist yet.

        Args:
            persist_path: Path to persist the database
            collection_name: Name of the collection to use
        """
        self.persist_path = persist_path
        self.collection_name = collection_name

        # Ensure persist directory exists
        Path(persist_path).mkdir(parents=True, exist_ok=True)

        # Initialize ChromaDB client with persistence
        self.client = chromadb.PersistentClient(path=persist_path)

        # get_or_create_collection avoids relying on the exception type that
        # get_collection raises for a missing collection (it is not a
        # ValueError in every ChromaDB release).
        self.collection = self.client.get_or_create_collection(name=collection_name)

        logging.info(
            f"Initialized VectorDatabase with collection "
            f"'{collection_name}' at '{persist_path}'"
        )

    def get_collection(self):
        """Get the ChromaDB collection currently bound to this instance."""
        return self.collection

    def _existing_ids(self, chunk_ids: List[str]) -> set:
        """Return the subset of ``chunk_ids`` already stored (best effort)."""
        if not chunk_ids:
            return set()
        try:
            existing = self.collection.get(ids=chunk_ids, include=[])
            return set(existing.get("ids") or [])
        except Exception as e:
            # Best effort: if the lookup fails, treat every id as new and let
            # the subsequent add() report any real problem.
            logging.warning(f"Could not check for existing documents: {e}")
            return set()

    def add_embeddings(
        self,
        embeddings: List[List[float]],
        chunk_ids: List[str],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ) -> bool:
        """
        Add embeddings to the vector database.

        IDs that are already stored in the collection, and repeated IDs
        inside the same call (first occurrence wins), are skipped.

        Args:
            embeddings: List of embedding vectors
            chunk_ids: List of unique chunk IDs
            documents: List of document contents
            metadatas: List of metadata dictionaries

        Returns:
            True if successful (including when there was nothing new to add)

        Raises:
            ValueError: If the input lists differ in length
            Exception: Any error raised by the underlying collection is
                logged and re-raised
        """
        try:
            # Validate input lengths match
            if not (
                len(embeddings) == len(chunk_ids) == len(documents) == len(metadatas)
            ):
                raise ValueError("All input lists must have the same length")

            # Seed the "seen" set with stored IDs so that one membership test
            # rejects both pre-existing and in-batch duplicates.
            seen_ids = self._existing_ids(chunk_ids)

            new_embeddings = []
            new_chunk_ids = []
            new_documents = []
            new_metadatas = []
            for chunk_id, embedding, document, metadata in zip(
                chunk_ids, embeddings, documents, metadatas
            ):
                if chunk_id in seen_ids:
                    continue
                seen_ids.add(chunk_id)
                new_embeddings.append(embedding)
                new_chunk_ids.append(chunk_id)
                new_documents.append(document)
                new_metadatas.append(metadata)

            if not new_chunk_ids:
                logging.info(
                    f"All {len(chunk_ids)} documents already exist in collection"
                )
                return True

            # Add to ChromaDB collection
            self.collection.add(
                embeddings=new_embeddings,
                documents=new_documents,
                metadatas=new_metadatas,
                ids=new_chunk_ids,
            )
            logging.info(
                f"Added {len(new_embeddings)} new embeddings to collection "
                f"'{self.collection_name}' "
                f"(skipped {len(chunk_ids) - len(new_embeddings)} duplicates)"
            )
            return True
        except Exception as e:
            logging.error(f"Failed to add embeddings: {e}")
            raise

    def search(
        self, query_embedding: List[float], top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Search for similar embeddings.

        Args:
            query_embedding: Query vector to search for
            top_k: Number of results to return

        Returns:
            List of dicts with keys ``id``, ``document``, ``metadata`` and
            ``distance``. An empty list is returned when the collection is
            empty, when ``top_k`` is not positive, or when the query fails
            (the failure is logged, not raised).
        """
        try:
            # Count once: it is both the emptiness check and the result cap
            count = self.get_count()
            if count == 0 or top_k <= 0:
                return []

            # Perform similarity search
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=min(top_k, count),
            )

            # Results are nested per query; a single query was sent, so only
            # index 0 of each field is relevant.
            formatted_results = []
            if results["ids"] and len(results["ids"][0]) > 0:
                formatted_results = [
                    {
                        "id": result_id,
                        "document": document,
                        "metadata": metadata,
                        "distance": distance,
                    }
                    for result_id, document, metadata, distance in zip(
                        results["ids"][0],
                        results["documents"][0],
                        results["metadatas"][0],
                        results["distances"][0],
                    )
                ]

            logging.info(f"Search returned {len(formatted_results)} results")
            return formatted_results
        except Exception as e:
            logging.error(f"Search failed: {e}")
            return []

    def get_count(self) -> int:
        """Get the number of embeddings in the collection (0 on error)."""
        try:
            return self.collection.count()
        except Exception as e:
            logging.error(f"Failed to get count: {e}")
            return 0

    def delete_collection(self) -> bool:
        """
        Delete the collection.

        ``self.collection`` is not rebound here; call ``reset_collection``
        to obtain a fresh, usable collection afterwards.

        Returns:
            True if the collection was deleted, False if deletion failed
        """
        try:
            self.client.delete_collection(name=self.collection_name)
            logging.info(f"Deleted collection '{self.collection_name}'")
            return True
        except Exception as e:
            logging.error(f"Failed to delete collection: {e}")
            return False

    def reset_collection(self) -> bool:
        """
        Reset the collection (delete and recreate).

        Returns:
            True if a fresh collection was created, False otherwise
        """
        try:
            # Deleting is best effort: a missing collection is fine, and the
            # exception type used for "not found" is not a ValueError in
            # every ChromaDB release. If the delete failed for a real reason
            # the collection still exists, create_collection below fails and
            # the method reports False.
            try:
                self.client.delete_collection(name=self.collection_name)
            except Exception as e:
                logging.debug(
                    f"Ignoring delete failure for '{self.collection_name}': {e}"
                )

            # Create new collection
            self.collection = self.client.create_collection(name=self.collection_name)
            logging.info(f"Reset collection '{self.collection_name}'")
            return True
        except Exception as e:
            logging.error(f"Failed to reset collection: {e}")
            return False

    def get_embedding_dimension(self) -> int:
        """
        Get the embedding dimension from existing data in the collection.

        Returns 0 if collection is empty, has no embeddings, or the lookup
        fails.
        """
        try:
            if self.get_count() == 0:
                return 0

            # Retrieve one record to check its embedding dimension
            record = self.collection.get(include=["embeddings"], limit=1)
            if not record or "embeddings" not in record:
                return 0

            embeddings = record["embeddings"]
            # Explicit None/len checks instead of truthiness: the embeddings
            # may come back as a NumPy array, whose truth value is ambiguous.
            if embeddings is None or len(embeddings) == 0:
                return 0
            return len(embeddings[0])
        except Exception as e:
            logging.error(f"Failed to get embedding dimension: {e}")
            return 0

    def has_valid_embeddings(self, expected_dimension: int) -> bool:
        """
        Check if the collection has embeddings with the expected dimension.

        Args:
            expected_dimension: The expected embedding dimension

        Returns:
            True if collection has embeddings with correct dimension, False otherwise
        """
        try:
            actual_dimension = self.get_embedding_dimension()
            return actual_dimension > 0 and actual_dimension == expected_dimension
        except Exception as e:
            logging.error(f"Failed to validate embeddings: {e}")
            return False