# Tobias Pasquale
# style: Fix code formatting and linting issues for CI/CD compliance
# 7793bb6
# raw
# history blame
# 5.37 kB
import logging
from pathlib import Path
from typing import Any, Dict, List
import chromadb
class VectorDatabase:
    """ChromaDB integration for vector storage and similarity search.

    Wraps a single named collection of a ``chromadb.PersistentClient`` whose
    data lives under ``persist_path``.
    """

    def __init__(self, persist_path: str, collection_name: str):
        """
        Initialize the vector database.

        Args:
            persist_path: Directory in which ChromaDB persists its data; it is
                created (including parents) if it does not exist yet.
            collection_name: Name of the collection to use; it is created if
                it does not exist yet.
        """
        self.persist_path = persist_path
        self.collection_name = collection_name

        # Ensure persist directory exists
        Path(persist_path).mkdir(parents=True, exist_ok=True)

        # Initialize ChromaDB client with persistence
        self.client = chromadb.PersistentClient(path=persist_path)

        # get_or_create avoids relying on the exact exception a missing
        # collection raises, which is not stable across chromadb versions.
        self.collection = self.client.get_or_create_collection(name=collection_name)

        logging.info(
            f"Initialized VectorDatabase with collection "
            f"'{collection_name}' at '{persist_path}'"
        )

    def get_collection(self):
        """Return the underlying ChromaDB collection object."""
        return self.collection

    def add_embeddings(
        self,
        embeddings: List[List[float]],
        chunk_ids: List[str],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ) -> bool:
        """
        Add embeddings to the vector database.

        Args:
            embeddings: List of embedding vectors
            chunk_ids: List of unique chunk IDs
            documents: List of document contents
            metadatas: List of metadata dictionaries, one per chunk

        Returns:
            True once the batch has been added. An empty batch is a no-op
            that also returns True.

        Raises:
            ValueError: If the four input lists differ in length.
            Exception: Any error raised by the collection is logged and
                re-raised unchanged.
        """
        try:
            # Validate input lengths match
            if not (
                len(embeddings) == len(chunk_ids) == len(documents) == len(metadatas)
            ):
                raise ValueError("All input lists must have the same length")

            # ChromaDB's add() rejects an empty ID list, so an empty batch is
            # treated as "nothing to do" instead of an error.
            if not chunk_ids:
                return True

            # Add to ChromaDB collection
            self.collection.add(
                embeddings=embeddings,
                documents=documents,
                metadatas=metadatas,
                ids=chunk_ids,
            )

            logging.info(
                f"Added {len(embeddings)} embeddings to collection "
                f"'{self.collection_name}'"
            )
            return True

        except Exception as e:
            logging.error(f"Failed to add embeddings: {e}")
            # Bare raise re-raises the active exception as-is
            raise

    def search(
        self, query_embedding: List[float], top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Search for the embeddings most similar to ``query_embedding``.

        Args:
            query_embedding: Query vector to search for
            top_k: Maximum number of results to return

        Returns:
            Up to ``top_k`` dicts with keys ``id``, ``document``, ``metadata``
            and ``distance``. Returns an empty list when the collection is
            empty, when ``top_k`` is not positive, or when the query fails
            (the failure is logged).
        """
        try:
            # A non-positive top_k can never produce results; skip the query
            if top_k <= 0:
                return []

            # Read the count once: it both short-circuits an empty collection
            # and caps n_results, without a second round trip.
            count = self.get_count()
            if count == 0:
                return []

            # Perform similarity search
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=min(top_k, count),
            )

            # query() returns one result list per query embedding; exactly one
            # embedding was sent, so only index 0 is relevant.
            formatted_results: List[Dict[str, Any]] = []
            if results["ids"] and results["ids"][0]:
                formatted_results = [
                    {
                        "id": chunk_id,
                        "document": document,
                        "metadata": metadata,
                        "distance": distance,
                    }
                    for chunk_id, document, metadata, distance in zip(
                        results["ids"][0],
                        results["documents"][0],
                        results["metadatas"][0],
                        results["distances"][0],
                    )
                ]

            logging.info(f"Search returned {len(formatted_results)} results")
            return formatted_results

        except Exception as e:
            # Best-effort: search failures are logged and reported as no hits
            logging.error(f"Search failed: {e}")
            return []

    def get_count(self) -> int:
        """Return the number of embeddings in the collection (0 on error)."""
        try:
            return self.collection.count()
        except Exception as e:
            logging.error(f"Failed to get count: {e}")
            return 0

    def delete_collection(self) -> bool:
        """
        Delete the collection.

        Note: ``self.collection`` still references the deleted collection
        afterwards; call ``reset_collection()`` to create a fresh one.

        Returns:
            True if the delete succeeded, False otherwise (the error is logged).
        """
        try:
            self.client.delete_collection(name=self.collection_name)
            logging.info(f"Deleted collection '{self.collection_name}'")
            return True
        except Exception as e:
            logging.error(f"Failed to delete collection: {e}")
            return False

    def reset_collection(self) -> bool:
        """
        Reset the collection by deleting it (if present) and recreating it.

        Returns:
            True if a fresh collection was created, False otherwise (the
            error is logged).
        """
        try:
            # Best-effort delete: the collection may already be gone, e.g.
            # after delete_collection(). The exception type for a missing
            # collection varies across chromadb versions, so catch broadly.
            # If the delete failed while the collection still exists,
            # create_collection() below raises and we report False.
            try:
                self.client.delete_collection(name=self.collection_name)
            except Exception as e:
                logging.debug(
                    f"Collection '{self.collection_name}' not deleted before "
                    f"reset: {e}"
                )

            # Create new collection
            self.collection = self.client.create_collection(name=self.collection_name)

            logging.info(f"Reset collection '{self.collection_name}'")
            return True

        except Exception as e:
            logging.error(f"Failed to reset collection: {e}")
            return False