# Postgres vector migration (#83) — commit dca679b (Seth McKnight)
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Protocol, Union
import chromadb
from src.config import VECTOR_STORAGE_TYPE
from src.utils.memory_utils import log_memory_checkpoint, memory_monitor
def create_vector_database(
    persist_path: Optional[str] = None, collection_name: Optional[str] = None
):
    """
    Build the vector database backend selected by configuration.

    Args:
        persist_path: On-disk persistence location (only used by the
            ChromaDB backend; falls back to the configured default)
        collection_name: Collection/table name; backend-specific defaults
            apply when omitted

    Returns:
        A backend instance: ``PostgresVectorAdapter`` when
        VECTOR_STORAGE_TYPE is "postgres", otherwise ``VectorDatabase``.
    """
    if VECTOR_STORAGE_TYPE != "postgres":
        # ChromaDB is the default backend; pull fallbacks from config.
        from src.config import COLLECTION_NAME, VECTOR_DB_PERSIST_PATH

        return VectorDatabase(
            persist_path=persist_path or VECTOR_DB_PERSIST_PATH,
            collection_name=collection_name or COLLECTION_NAME,
        )

    # Postgres backend — imported lazily so the driver is only required
    # when this storage type is actually selected.
    from src.vector_db.postgres_adapter import PostgresVectorAdapter

    return PostgresVectorAdapter(
        table_name=collection_name or "document_embeddings"
    )
class VectorDatabase:
    """ChromaDB integration for vector storage and similarity search"""

    def __init__(
        self,
        persist_path: str,
        collection_name: str,
    ):
        """
        Initialize the vector database.

        Args:
            persist_path: Directory where ChromaDB persists its data
                (created if it does not already exist)
            collection_name: Name of the collection to use
        """
        self.persist_path = persist_path
        self.collection_name = collection_name

        # Ensure persist directory exists before the client touches it
        Path(persist_path).mkdir(parents=True, exist_ok=True)

        # Get chroma settings from config for memory optimization
        from chromadb.config import Settings

        from src.config import CHROMA_SETTINGS

        # Convert CHROMA_SETTINGS dict to Settings object
        chroma_settings = Settings(**CHROMA_SETTINGS)

        # Initialize ChromaDB client with persistence and memory optimization
        log_memory_checkpoint("vector_db_before_client_init")
        self.client = chromadb.PersistentClient(
            path=persist_path, settings=chroma_settings
        )
        log_memory_checkpoint("vector_db_after_client_init")

        # get_or_create_collection is stable across chromadb versions.
        # The old get_collection()/except ValueError pattern breaks on
        # newer clients, which raise chromadb.errors.NotFoundError for a
        # missing collection instead of ValueError.
        self.collection = self.client.get_or_create_collection(
            name=collection_name
        )

        logging.info(
            f"Initialized VectorDatabase with collection "
            f"'{collection_name}' at '{persist_path}'"
        )

    def get_collection(self):
        """Get the ChromaDB collection"""
        return self.collection

    @memory_monitor
    def add_embeddings_batch(
        self,
        batch_embeddings: List[List[List[float]]],
        batch_chunk_ids: List[List[str]],
        batch_documents: List[List[str]],
        batch_metadatas: List[List[Dict[str, Any]]],
    ) -> int:
        """
        Add embeddings in batches to prevent memory issues with large datasets.

        Args:
            batch_embeddings: List of embedding batches
            batch_chunk_ids: List of chunk ID batches
            batch_documents: List of document batches
            batch_metadatas: List of metadata batches

        Returns:
            Number of embeddings added across all batches
        """
        # Imported once here (not per loop iteration); used for explicit
        # cleanup between batches to keep peak memory down.
        import gc

        total_added = 0
        for i, (embeddings, chunk_ids, documents, metadatas) in enumerate(
            zip(
                batch_embeddings,
                batch_chunk_ids,
                batch_documents,
                batch_metadatas,
            )
        ):
            log_memory_checkpoint(f"before_add_batch_{i}")

            # add_embeddings may return True on success (or raise on failure)
            added = self.add_embeddings(
                embeddings=embeddings,
                chunk_ids=chunk_ids,
                documents=documents,
                metadatas=metadatas,
            )

            # If add_embeddings returns True, treat as all embeddings added.
            # bool is checked before int because bool is an int subclass.
            if isinstance(added, bool) and added:
                added_count = len(embeddings)
            elif isinstance(added, int):
                added_count = int(added)
            else:
                added_count = 0
            total_added += added_count

            logging.info(f"Added batch {i+1}/{len(batch_embeddings)}")

            # Force cleanup after each batch
            gc.collect()
            log_memory_checkpoint(f"after_add_batch_{i}")

        return total_added

    @memory_monitor
    def add_embeddings(
        self,
        embeddings: List[List[float]],
        chunk_ids: List[str],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ) -> bool:
        """
        Add embeddings to the collection.

        Args:
            embeddings: List of embedding vectors
            chunk_ids: List of chunk IDs
            documents: List of document texts
            metadatas: List of metadata dictionaries

        Returns:
            True when all embeddings were added (the method raises on
            failure rather than returning False).

        Raises:
            ValueError: If the input lists have mismatched lengths.
            Exception: Re-raised from the underlying collection on failure.
        """
        # Validate input lengths
        n = len(embeddings)
        if not (len(chunk_ids) == n and len(documents) == n and len(metadatas) == n):
            raise ValueError(
                f"Number of embeddings {n} must match number of ids {len(chunk_ids)}"
            )

        log_memory_checkpoint("before_add_embeddings")
        try:
            self.collection.add(
                embeddings=embeddings,
                documents=documents,
                metadatas=metadatas,
                ids=chunk_ids,
            )
            log_memory_checkpoint("after_add_embeddings")
            logging.info(f"Added {n} embeddings to collection")
            # Return boolean True for API compatibility tests
            return True
        except Exception as e:
            logging.error(f"Failed to add embeddings: {e}")
            # Re-raise to allow callers/tests to handle failures explicitly
            raise

    @memory_monitor
    def search(
        self, query_embedding: List[float], top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Search for similar embeddings.

        Args:
            query_embedding: Query vector to search for
            top_k: Number of results to return

        Returns:
            List of result dicts with keys "id", "document", "metadata",
            and "distance"; empty list on error or empty collection.
        """
        try:
            # Handle empty collection
            if self.get_count() == 0:
                return []

            # Perform similarity search; n_results is clamped so Chroma
            # never asks for more records than the collection holds.
            log_memory_checkpoint("vector_db_before_query")
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=min(top_k, self.get_count()),
            )
            log_memory_checkpoint("vector_db_after_query")

            # Format results (query returns parallel lists, one per query)
            formatted_results = []
            if results["ids"] and len(results["ids"][0]) > 0:
                for i in range(len(results["ids"][0])):
                    result = {
                        "id": results["ids"][0][i],
                        "document": results["documents"][0][i],
                        "metadata": results["metadatas"][0][i],
                        "distance": results["distances"][0][i],
                    }
                    formatted_results.append(result)

            logging.info(f"Search returned {len(formatted_results)} results")
            return formatted_results
        except Exception as e:
            logging.error(f"Search failed: {e}")
            return []

    def get_count(self) -> int:
        """Get the number of embeddings in the collection (0 on error)."""
        try:
            return self.collection.count()
        except Exception as e:
            logging.error(f"Failed to get count: {e}")
            return 0

    def delete_collection(self) -> bool:
        """Delete the collection. Returns True on success, False on error."""
        try:
            self.client.delete_collection(name=self.collection_name)
            logging.info(f"Deleted collection '{self.collection_name}'")
            return True
        except Exception as e:
            logging.error(f"Failed to delete collection: {e}")
            return False

    def reset_collection(self) -> bool:
        """Reset the collection (delete and recreate)."""
        try:
            # Best-effort delete of the existing collection. Broadened
            # from ValueError: newer chromadb clients raise NotFoundError
            # for a missing collection, which the old handler would miss.
            try:
                self.client.delete_collection(name=self.collection_name)
            except Exception:
                # Collection doesn't exist, that's fine
                logging.debug(
                    f"Collection '{self.collection_name}' not deleted "
                    f"(may not exist)"
                )

            # Create new collection
            self.collection = self.client.create_collection(name=self.collection_name)
            logging.info(f"Reset collection '{self.collection_name}'")
            return True
        except Exception as e:
            logging.error(f"Failed to reset collection: {e}")
            return False

    def get_embedding_dimension(self) -> int:
        """
        Get the embedding dimension from existing data in the collection.

        Returns 0 if the collection is empty or has no embeddings.
        """
        try:
            count = self.get_count()
            if count == 0:
                return 0

            # Retrieve one record to check its embedding dimension
            record = self.collection.get(
                ids=None,  # None returns all records, but we only need one
                include=["embeddings"],
                limit=1,
            )

            # Newer chromadb returns embeddings as numpy arrays, whose
            # truthiness raises ValueError ("ambiguous"); use explicit
            # None/length checks instead of a boolean test.
            embeddings = record.get("embeddings") if record else None
            if embeddings is not None and len(embeddings) > 0:
                return len(embeddings[0])
            return 0
        except Exception as e:
            logging.error(f"Failed to get embedding dimension: {e}")
            return 0

    def has_valid_embeddings(self, expected_dimension: int) -> bool:
        """
        Check if the collection has embeddings with the expected dimension.

        Args:
            expected_dimension: The expected embedding dimension

        Returns:
            True if collection has embeddings with correct dimension, False otherwise
        """
        try:
            actual_dimension = self.get_embedding_dimension()
            return actual_dimension == expected_dimension and actual_dimension > 0
        except Exception as e:
            logging.error(f"Failed to validate embeddings: {e}")
            return False