# Provenance: commit f75da29 — "Comprehensive memory optimizations and
# embedding service updates (#76)" by Seth McKnight (7.94 kB).
import logging
from pathlib import Path
from typing import Any, Dict, List
import chromadb
class VectorDatabase:
    """ChromaDB integration for vector storage and similarity search.

    Wraps a persistent ChromaDB client and a single named collection,
    providing idempotent inserts, similarity search, and collection
    lifecycle helpers.
    """

    def __init__(self, persist_path: str, collection_name: str):
        """
        Initialize the vector database.

        Args:
            persist_path: Path to persist the database
            collection_name: Name of the collection to use
        """
        self.persist_path = persist_path
        self.collection_name = collection_name

        # Ensure persist directory exists before the client touches it
        Path(persist_path).mkdir(parents=True, exist_ok=True)

        # Initialize ChromaDB client with persistence
        self.client = chromadb.PersistentClient(path=persist_path)

        # get_or_create_collection is atomic and version-stable. The previous
        # try/except ValueError pattern breaks on newer chromadb releases,
        # which raise chromadb.errors.NotFoundError (not ValueError) from
        # get_collection, so the create fallback was never reached.
        self.collection = self.client.get_or_create_collection(name=collection_name)

        logging.info(
            f"Initialized VectorDatabase with collection "
            f"'{collection_name}' at '{persist_path}'"
        )

    def get_collection(self):
        """Get the underlying ChromaDB collection object."""
        return self.collection

    def add_embeddings(
        self,
        embeddings: List[List[float]],
        chunk_ids: List[str],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ) -> bool:
        """
        Add embeddings to the vector database, skipping IDs already present.

        Args:
            embeddings: List of embedding vectors
            chunk_ids: List of unique chunk IDs
            documents: List of document contents
            metadatas: List of metadata dictionaries

        Returns:
            True if successful (including the all-duplicates case).

        Raises:
            ValueError: If the input lists have mismatched lengths.
            Exception: Any failure from the underlying collection is
                logged and re-raised.
        """
        try:
            # Validate input lengths match
            if not (
                len(embeddings) == len(chunk_ids) == len(documents) == len(metadatas)
            ):
                raise ValueError("All input lists must have the same length")

            # Check for existing documents to prevent duplicates.
            # include=[] fetches only the IDs, not payloads.
            try:
                existing = self.collection.get(ids=chunk_ids, include=[])
                existing_ids = set(existing.get("ids", []))
            except Exception:
                # Best-effort lookup: if it fails, fall through and let
                # collection.add surface any real problem.
                existing_ids = set()

            # Only add documents that don't already exist
            new_embeddings = []
            new_chunk_ids = []
            new_documents = []
            new_metadatas = []
            for i, chunk_id in enumerate(chunk_ids):
                if chunk_id not in existing_ids:
                    new_embeddings.append(embeddings[i])
                    new_chunk_ids.append(chunk_id)
                    new_documents.append(documents[i])
                    new_metadatas.append(metadatas[i])

            if not new_embeddings:
                logging.info(
                    f"All {len(chunk_ids)} documents already exist in collection"
                )
                return True

            # Add to ChromaDB collection
            self.collection.add(
                embeddings=new_embeddings,
                documents=new_documents,
                metadatas=new_metadatas,
                ids=new_chunk_ids,
            )

            logging.info(
                f"Added {len(new_embeddings)} new embeddings to collection "
                f"'{self.collection_name}' "
                f"(skipped {len(chunk_ids) - len(new_embeddings)} duplicates)"
            )
            return True
        except Exception as e:
            logging.error(f"Failed to add embeddings: {e}")
            # Bare raise preserves the original traceback (raise e resets it).
            raise

    def search(
        self, query_embedding: List[float], top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Search for similar embeddings.

        Args:
            query_embedding: Query vector to search for
            top_k: Number of results to return

        Returns:
            List of dicts with keys "id", "document", "metadata", and
            "distance"; empty list on error or empty collection.
        """
        try:
            # Handle empty collection
            if self.get_count() == 0:
                return []

            # Perform similarity search; clamp n_results to the collection
            # size because chromadb errors when asked for more than exist.
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=min(top_k, self.get_count()),
            )

            # Flatten chromadb's batched response (one query => index [0])
            formatted_results = []
            if results["ids"] and len(results["ids"][0]) > 0:
                for i in range(len(results["ids"][0])):
                    result = {
                        "id": results["ids"][0][i],
                        "document": results["documents"][0][i],
                        "metadata": results["metadatas"][0][i],
                        "distance": results["distances"][0][i],
                    }
                    formatted_results.append(result)

            logging.info(f"Search returned {len(formatted_results)} results")
            return formatted_results
        except Exception as e:
            logging.error(f"Search failed: {e}")
            return []

    def get_count(self) -> int:
        """Get the number of embeddings in the collection (0 on error)."""
        try:
            return self.collection.count()
        except Exception as e:
            logging.error(f"Failed to get count: {e}")
            return 0

    def delete_collection(self) -> bool:
        """Delete the collection. Returns True on success."""
        try:
            self.client.delete_collection(name=self.collection_name)
            logging.info(f"Deleted collection '{self.collection_name}'")
            return True
        except Exception as e:
            logging.error(f"Failed to delete collection: {e}")
            return False

    def reset_collection(self) -> bool:
        """Reset the collection (delete and recreate). Returns True on success."""
        try:
            # Best-effort delete: the "missing collection" exception type
            # varies across chromadb versions (ValueError in older releases,
            # NotFoundError in newer ones), so catch broadly here.
            try:
                self.client.delete_collection(name=self.collection_name)
            except Exception:
                # Collection doesn't exist, that's fine
                pass

            # Create new collection
            self.collection = self.client.create_collection(name=self.collection_name)
            logging.info(f"Reset collection '{self.collection_name}'")
            return True
        except Exception as e:
            logging.error(f"Failed to reset collection: {e}")
            return False

    def get_embedding_dimension(self) -> int:
        """
        Get the embedding dimension from existing data in the collection.

        Returns:
            The dimension of stored embeddings, or 0 if the collection is
            empty or has no embeddings.
        """
        try:
            count = self.get_count()
            if count == 0:
                return 0

            # Retrieve one record to check its embedding dimension
            record = self.collection.get(
                ids=None,  # None returns all records; limit=1 keeps it cheap
                include=["embeddings"],
                limit=1,
            )
            embeddings = record.get("embeddings") if record else None
            # Newer chromadb returns embeddings as a NumPy array; array
            # truthiness raises ValueError for multi-element arrays, which
            # the old `if record["embeddings"]:` check tripped over (the
            # broad except then silently returned 0). Use explicit checks.
            if embeddings is not None and len(embeddings) > 0:
                return len(embeddings[0])
            return 0
        except Exception as e:
            logging.error(f"Failed to get embedding dimension: {e}")
            return 0

    def has_valid_embeddings(self, expected_dimension: int) -> bool:
        """
        Check if the collection has embeddings with the expected dimension.

        Args:
            expected_dimension: The expected embedding dimension

        Returns:
            True if collection has embeddings with correct dimension, False otherwise
        """
        try:
            actual_dimension = self.get_embedding_dimension()
            return actual_dimension == expected_dimension and actual_dimension > 0
        except Exception as e:
            logging.error(f"Failed to validate embeddings: {e}")
            return False