import logging
import re
import unicodedata

import chromadb
from pydantic import BaseModel

from config import SanatanConfig
from embeddings import get_embedding
from metadata import MetadataFilter, MetadataWhereClause

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def _normalize(text: str) -> str:
    """Return *text* NFKC-normalized and lower-cased for loose literal matching."""
    return unicodedata.normalize("NFKC", text).lower()


class SanatanDatabase:
    """Wrapper around a persistent ChromaDB client holding scripture collections.

    The storage path comes from ``SanatanConfig.dbStorePath``. Every method
    uses ``get_or_create_collection``, so asking about an unknown collection
    creates it (empty) rather than raising.
    """

    def __init__(self) -> None:
        self.chroma_client = chromadb.PersistentClient(path=SanatanConfig.dbStorePath)

    def _get_collection(self, collection_name: str):
        """Return the named collection, creating it if it does not exist yet."""
        return self.chroma_client.get_or_create_collection(name=collection_name)

    @staticmethod
    def _embed(texts, collection_name: str):
        """Embed *texts* with the embedding model configured for *collection_name*."""
        return get_embedding(
            texts, SanatanConfig().get_embedding_for_collection(collection_name)
        )

    def does_data_exist(self, collection_name: str) -> bool:
        """Return True if the collection holds at least one record.

        Note: creates the collection if it does not exist.
        """
        num_rows = self._get_collection(collection_name).count()
        logger.info("num_rows in %s = %d", collection_name, num_rows)
        return num_rows > 0

    def load(self, collection_name: str, ids, documents, embeddings, metadatas):
        """Add pre-embedded records to the collection (created if missing).

        The four sequences are passed straight to ``collection.add`` and are
        expected to be parallel (same length, same order).
        """
        self._get_collection(collection_name).add(
            ids=ids,
            documents=documents,
            embeddings=embeddings,
            metadatas=metadatas,
        )

    def search(self, collection_name: str, query: str, n_results=2):
        """Semantic vector search for *query* in *collection_name*.

        Returns the raw Chroma query response (lists nested per query).
        """
        logger.info("Vector Semantic Search for [%s] in [%s]", query, collection_name)
        return self._get_collection(collection_name).query(
            query_embeddings=self._embed([query], collection_name),
            n_results=n_results,
        )

    @staticmethod
    def _literal_matches(regex, document, metadata) -> bool:
        """True if *regex* matches the normalized document or any string metadata value.

        List-valued metadata fields are searched element by element. ``None``
        documents/metadata (possible in Chroma results) simply do not match.
        """
        if isinstance(document, str) and regex.search(_normalize(document)):
            return True
        for value in (metadata or {}).values():
            candidates = value if isinstance(value, list) else [value]
            if any(
                isinstance(v, str) and regex.search(_normalize(v)) for v in candidates
            ):
                return True
        return False

    def search_for_literal(
        self, collection_name: str, literal_to_search_for: str, n_results=2
    ):
        """Find records containing *literal_to_search_for* verbatim.

        First tries Chroma's native ``$contains`` document filter (exact,
        case-sensitive). If that yields nothing, scans every record and
        matches the NFKC-normalized, lower-cased literal against both the
        document text and string metadata values.

        Returns a Chroma-shaped dict with ``documents``, ``ids`` and
        ``metadatas`` keys, each a single-element list of result lists.
        The fallback path does not include ``distances``.
        """
        logger.info(
            "Searching literally for [%s] in [%s]",
            literal_to_search_for,
            collection_name,
        )
        collection = self._get_collection(collection_name)

        # 1. Try native contains. Embeddings must go through `query_embeddings`
        # (`query_texts` expects raw strings); embedding the literal itself ranks
        # the $contains hits by relevance to it.
        response = collection.query(
            query_embeddings=self._embed([literal_to_search_for], collection_name),
            where_document={"$contains": literal_to_search_for},
            n_results=n_results,
        )
        if response["documents"] and any(response["documents"]):
            return response

        # 2. Regex fallback (normalized) over documents and metadata.
        logger.info("⚠ No luck. Falling back to regex for %s", literal_to_search_for)
        regex = re.compile(re.escape(_normalize(literal_to_search_for)))
        logger.info("regex = %s", regex)

        # `get()` returns flat, parallel lists (one entry per record).
        all_docs = collection.get(include=["documents", "metadatas"])
        ids = all_docs.get("ids") or []
        documents = all_docs.get("documents") or [None] * len(ids)
        metadatas = all_docs.get("metadatas") or [None] * len(ids)

        matched_docs = []
        for doc_id, document, metadata in zip(ids, documents, metadatas):
            if len(matched_docs) >= n_results:
                break
            if self._literal_matches(regex, document, metadata):
                matched_docs.append(
                    {"id": doc_id, "document": document, "metadata": metadata}
                )

        return {
            "documents": [[d["document"] for d in matched_docs]],
            "ids": [[d["id"] for d in matched_docs]],
            "metadatas": [[d["metadata"] for d in matched_docs]],
        }

    def search_by_metadata(
        self,
        collection_name: str,
        query: str,
        metadata_where_clause: MetadataWhereClause,
        n_results=2,
    ):
        """Semantic search restricted by a metadata ``where`` clause.

        The clause is converted with ``metadata_where_clause.to_chroma_where()``
        and passed as Chroma's ``where`` filter, e.g.
        ``{"azhwar_name": {"$in": ["Thirumangai Azhwar"]}}``.
        Returns the raw Chroma query response.
        """
        logger.info(
            "Searching by metadata for [%s] in [%s] with metadata_filters=%s",
            query,
            collection_name,
            metadata_where_clause,
        )
        return self._get_collection(collection_name).query(
            query_embeddings=self._embed([query], collection_name),
            where=metadata_where_clause.to_chroma_where(),
            n_results=n_results,
        )

    def count(self, collection_name: str):
        """Return (and log) the number of records in the collection."""
        total_count = self._get_collection(collection_name).count()
        logger.info("Total records in [%s] = %d", collection_name, total_count)
        return total_count

    def test_sanity(self):
        """Raise if any configured scripture collection is empty."""
        for scripture in SanatanConfig().scriptures:
            name = scripture["collection_name"]
            if self.count(collection_name=name) == 0:
                # Inner key uses a local to avoid same-quote nesting in the
                # f-string, which is a SyntaxError before Python 3.12.
                raise Exception(f"No data in collection {name}")


if __name__ == "__main__":
    logging.basicConfig()
    collection_name = "divya_prabandham"
    database = SanatanDatabase()
    print("count = ", database.count(collection_name))
    while True:
        try:
            query = input("Search for: ")
        except EOFError:
            break
        if query.strip() == "":
            break
        metadata_where_clause = MetadataWhereClause(
            filters=[
                MetadataFilter(
                    metadata_field="prabandham_code",
                    metadata_search_operator="$eq",
                    metadata_value="TVM",
                ),
                MetadataFilter(
                    metadata_field="decade",
                    metadata_search_operator="$gte",
                    metadata_value=10,
                ),
            ]
        )
        response = database.search_by_metadata(
            collection_name=collection_name,
            query=query,
            metadata_where_clause=metadata_where_clause,
            n_results=1,
        )
        print("Matches", response)