import pandas as pd import numpy as np import random from typing import Literal import chromadb import re, unicodedata from config import SanatanConfig from embeddings import get_embedding import logging from pydantic import BaseModel from metadata import MetadataFilter, MetadataWhereClause from modules.db.relevance import validate_relevance_queryresult from tqdm import tqdm logging.basicConfig() logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) class SanatanDatabase: def __init__(self) -> None: self.chroma_client = chromadb.PersistentClient(path=SanatanConfig.dbStorePath) def does_data_exist(self, collection_name: str) -> bool: collection = self.chroma_client.get_or_create_collection(name=collection_name) num_rows = collection.count() logger.info("num_rows in %s = %d", collection_name, num_rows) return num_rows > 0 def load(self, collection_name: str, ids, documents, embeddings, metadatas): collection = self.chroma_client.get_or_create_collection(name=collection_name) collection.add( ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas, ) def fetch_random_data( self, collection_name: str, metadata_where_clause: MetadataWhereClause = None, n_results=1, ): # fetch all documents once logger.info( "getting %d random verses from [%s] | metadata_where_clause = %s", n_results, collection_name, metadata_where_clause, ) collection = self.chroma_client.get_or_create_collection(name=collection_name) data = collection.get( include=["metadatas", "documents"], where=( metadata_where_clause.to_chroma_where() if metadata_where_clause is not None else None ) ) docs = data["documents"] # list of all verse texts ids = data["ids"] metas = data["metadatas"] if not docs: logger.warning("No data found! - data=%s", data) return chromadb.QueryResult(ids=[], documents=[], metadatas=[]) # pick k random indices indices = random.sample(range(len(docs)), k=min(n_results, len(docs))) return chromadb.QueryResult( ids=[ids[i] for i in indices], documents=[docs[i] for i in indices], metadatas=[metas[i] for i in indices], ) def fetch_first_match( self, collection_name: str, metadata_where_clause: MetadataWhereClause = None ): """This version is created to support the browse module""" logger.info( "getting first matching verses from [%s] | metadata_where_clause = %s", collection_name, metadata_where_clause, ) collection = self.chroma_client.get_or_create_collection(name=collection_name) data = collection.get( limit=1, #hardcoded to 1 by design include=["metadatas", "documents"], where=( metadata_where_clause.to_chroma_where() if metadata_where_clause is not None else None ) ) docs = data["documents"] # list of all verse texts ids = data["ids"] metas = data["metadatas"] if not docs: logger.warning("No data found! - data=%s", data) return chromadb.GetResult(ids=[], documents=[], metadatas=[]) # pick k random indices return data def search( self, collection_name: str, query: str = None, metadata_where_clause: MetadataWhereClause = None, n_results=2, search_type: Literal["semantic", "literal", "random"] = "semantic", ): logger.info( "Search for [%s] in [%s]| metadata_where_clause=%s | search_type=%s | n_results=%d", query, collection_name, metadata_where_clause, search_type, n_results, ) if search_type == "semantic": return self.search_semantic( collection_name=collection_name, query=query, metadata_where_clause=metadata_where_clause, n_results=n_results, ) elif search_type == "literal": return self.search_for_literal( collection_name=collection_name, literal_to_search_for=query, metadata_where_clause=metadata_where_clause, n_results=n_results, ) else: # random return self.fetch_random_data( collection_name=collection_name, metadata_where_clause=metadata_where_clause, n_results=n_results, ) def fetch_document_by_index(self, collection_name: str, index: int): """ Fetch one document at a time from a ChromaDB collection using pagination (index = 0-based). Args: collection_name: Name of the ChromaDB collection. index: Zero-based index of the document to fetch. Returns: dict: { "document": , : , : , ... } Or a dict with "error" key if something went wrong. """ logger.info("fetching index %d from [%s]", index, collection_name) collection = self.chroma_client.get_or_create_collection(name=collection_name) try: response = collection.get( limit=1, # offset=index, # pagination via offset include=["metadatas", "documents"], where={"_global_index": index}, ) except Exception as e: logger.error("Error fetching document: %s", e, exc_info=True) return {"error": f"There was an error fetching the document: {str(e)}"} documents = response.get("documents", []) metadatas = response.get("metadatas", []) ids = response.get("ids", []) if documents: # merge document text with metadata result = {"document": documents[0]} if metadatas: result.update(metadatas[0]) if ids: result["id"] = ids[0] print("raw data = ", result) return result else: print("No data available") # show a sample data record response1 = collection.get( limit=2, # offset=index, # pagination via offset include=["metadatas", "documents"], ) print("sample data : ", response1) return {"error": "No data available."} def search_semantic( self, collection_name: str, query: str | None = None, metadata_where_clause: MetadataWhereClause | None = None, n_results=2, ): logger.info( "Vector Semantic Search for [%s] in [%s] | metadata_where_clause = %s", query, collection_name, metadata_where_clause, ) collection = self.chroma_client.get_or_create_collection(name=collection_name) try: q = query.strip() if query is not None else "" if not q: # fallback: fetch random verse return self.fetch_random_data( collection_name=collection_name, metadata_where_clause=metadata_where_clause, n_results=n_results, ) else: response = collection.query( query_embeddings=get_embedding( [query], SanatanConfig().get_embedding_for_collection(collection_name), ), # query_texts=[query], n_results=n_results, where=( metadata_where_clause.to_chroma_where() if metadata_where_clause is not None else None ), include=["metadatas", "documents", "distances"], ) except Exception as e: logger.error("Error in search: %s", e, exc_info=True) return chromadb.QueryResult( documents=[], ids=[], metadatas=[], distances=[], ) validated_response = validate_relevance_queryresult(query, response) logger.info( "status = %s | reason= %s", validated_response.status, validated_response.reason, ) return validated_response.result def search_for_literal( self, collection_name: str, literal_to_search_for: str | None = None, metadata_where_clause: MetadataWhereClause | None = None, n_results=2, ): logger.info( "Searching literally for [%s] in [%s] | metadata_where_clause = %s", literal_to_search_for, collection_name, metadata_where_clause, ) if literal_to_search_for is None or literal_to_search_for.strip() == "": logger.warning("Nothing to search literally.") raise Exception("query cannot be None or empty for a literal search!") # return self.fetch_random_data( # collection_name=collection_name, # ) collection = self.chroma_client.get_or_create_collection(name=collection_name) def normalize(text): return unicodedata.normalize("NFKC", text).lower() # 1. Try native contains response = collection.get( where=( metadata_where_clause.to_chroma_where() if metadata_where_clause is not None else None ), where_document={"$contains": literal_to_search_for}, limit=n_results, ) if response["documents"] and any(response["documents"]): return chromadb.QueryResult( ids=response["ids"], documents=response["documents"], metadatas=response["metadatas"], ) # 2. Regex fallback (normalized) logger.info("⚠ No luck. Falling back to regex for %s", literal_to_search_for) regex = re.compile(re.escape(normalize(literal_to_search_for))) logger.info("regex = %s", regex) all_docs = collection.get( where=( metadata_where_clause.to_chroma_where() if metadata_where_clause is not None else None ), ) matched_docs = [] for doc_list, metadata_list, doc_id_list in zip( all_docs["documents"], all_docs["metadatas"], all_docs["ids"] ): # Ensure all are lists if isinstance(doc_list, str): doc_list = [doc_list] if isinstance(metadata_list, dict): metadata_list = [metadata_list] if isinstance(doc_id_list, str): doc_id_list = [doc_id_list] for i in range(len(doc_list)): d = doc_list[i] current_metadata = metadata_list[i] current_id = doc_id_list[i] doc_match = regex.search(normalize(d)) metadata_match = False for key, value in current_metadata.items(): if isinstance(value, str) and regex.search(normalize(value)): metadata_match = True break elif isinstance(value, list): if any( isinstance(v, str) and regex.search(normalize(v)) for v in value ): metadata_match = True break if doc_match or metadata_match: matched_docs.append( { "id": current_id, "document": d, "metadata": current_metadata, } ) if len(matched_docs) >= n_results: break if len(matched_docs) >= n_results: break return chromadb.QueryResult( { "documents": [[d["document"] for d in matched_docs]], "ids": [[d["id"] for d in matched_docs]], "metadatas": [[d["metadata"] for d in matched_docs]], } ) def count(self, collection_name: str): collection = self.chroma_client.get_or_create_collection(name=collection_name) total_count = collection.count() logger.info("Total records in [%s] = %d", collection_name, total_count) return total_count def test_sanity(self): for scripture in SanatanConfig().scriptures: count = self.count(collection_name=scripture["collection_name"]) if count == 0: raise Exception(f"No data in collection {scripture["collection_name"]}") def reembed_collection_openai(self, collection_name: str, batch_size: int = 50): """ Deletes and recreates a Chroma collection with OpenAI text-embedding-3-large embeddings. All existing documents are re-embedded and inserted into the new collection. Args: collection_name: The name of the collection to delete/recreate. batch_size: Number of documents to process per batch. """ # Step 1: Fetch old collection data (if exists) try: old_collection = self.chroma_client.get_collection(name=collection_name) old_data = old_collection.get(include=["documents", "metadatas"]) documents = old_data["documents"] metadatas = old_data["metadatas"] ids = old_data["ids"] print(f"Fetched {len(documents)} documents from old collection.") # Step 2: Delete old collection # self.chroma_client.delete_collection(collection_name) # print(f"Deleted old collection '{collection_name}'.") except chromadb.errors.NotFoundError: print(f"No existing collection named '{collection_name}', starting fresh.") documents, metadatas, ids = [], [], [] # Step 3: Create new collection with correct embedding dimension new_collection = self.chroma_client.create_collection( name=f"{collection_name}_openai", embedding_function=None, # embeddings will be provided manually ) print( f"Created new collection '{collection_name}_openai' with embedding_dim=3072." ) # Step 4: Re-embed and insert documents in batches for i in tqdm( range(0, len(documents), batch_size), desc="Re-embedding batches" ): batch_docs = documents[i : i + batch_size] batch_metadatas = metadatas[i : i + batch_size] batch_ids = ids[i : i + batch_size] embeddings = get_embedding(batch_docs, backend="openai") new_collection.add( ids=batch_ids, documents=batch_docs, metadatas=batch_metadatas, embeddings=embeddings, ) print("All documents re-embedded and added to new collection successfully!") def add_unit_index_to_collection(self, collection_name: str, unit_field: str): if collection_name != "yt_metadata": # safeguard just incase return collection = self.chroma_client.get_collection(name=collection_name) # fetch everything in batches (in case your collection is large) batch_size = 100 offset = 0 unit_counter = 1 while True: result = collection.get( limit=batch_size, offset=offset, include=["documents", "metadatas", "embeddings"], ) ids = result["ids"] if not ids: break # no more docs docs = result["documents"] metas = result["metadatas"] embeddings = result["embeddings"] # add unit_index to metadata updated_metas = [] for meta in metas: # ensure meta is not None m = meta.copy() if meta else {} m[unit_field] = unit_counter updated_metas.append(m) unit_counter += 1 # upsert with same IDs (will overwrite metadata but keep same id+doc) collection.upsert( ids=ids, documents=docs, metadatas=updated_metas, embeddings=embeddings, ) offset += batch_size print( f"✅ Finished adding {unit_field} to {unit_counter-1} documents in {collection_name}." ) def get_list_of_values( self, collection_name: str, metadata_field_name: str ) -> list: """ Returns the unique values for a given metadata field in a collection. """ # Get the collection collection = self.chroma_client.get_or_create_collection(name=collection_name) # Fetch all metadata from the collection query_result = collection.get(include=["metadatas"]) values = set() # use a set to automatically deduplicate metadatas = query_result.get("metadatas", []) if metadatas: # Handle both flat list and nested list formats if isinstance(metadatas[0], dict): # flat list of dicts for md in metadatas: if metadata_field_name in md: values.add(md[metadata_field_name]) elif isinstance(metadatas[0], list): # nested list for md_list in metadatas: for md in md_list: if metadata_field_name in md: values.add(md[metadata_field_name]) return sorted(list(values)) def build_global_index_for_all_scriptures(self, force: bool = False): import pandas as pd import numpy as np logger.info("build_global_index_for_all_scriptures: started") config = SanatanConfig() for scripture in config.scriptures: scripture_name = scripture["name"] chapter_order = scripture.get("chapter_order", None) # if scripture_name != "vishnu_sahasranamam": # continue logger.info( "build_global_index_for_all_scriptures:%s: Processing", scripture_name ) collection_name = scripture["collection_name"] collection = self.chroma_client.get_or_create_collection( name=collection_name ) metadata_fields = scripture.get("metadata_fields", []) # Get metadata field names marked as unique unique_fields = [f["name"] for f in metadata_fields if f.get("is_unique")] if not unique_fields: if metadata_fields: unique_fields = [metadata_fields[0]["name"]] else: logger.warning( f"No metadata fields defined for {collection_name}, skipping" ) continue logger.info( "build_global_index_for_all_scriptures:%s:unique fields: %s", scripture_name, unique_fields, ) # Build chapter_order mapping if defined chapter_order_mapping = {} for field in metadata_fields: if callable(chapter_order): chapter_order_mapping = chapter_order() logger.info( "build_global_index_for_all_scriptures:%s:chapter_order_mapping: %s", scripture_name, chapter_order_mapping, ) # Fetch all records (keep embeddings for upsert) try: results = collection.get( include=["metadatas", "documents", "embeddings"] ) except Exception as e: logger.error( "build_global_index_for_all_scriptures:%s Error getting data from chromadb", scripture_name, exc_info=True, ) continue ids = results["ids"] metadatas = results["metadatas"] documents = results["documents"] embeddings = results.get("embeddings", [None] * len(ids)) if not force and metadatas and "_global_index" in metadatas[0]: logger.warning( "build_global_index_for_all_scriptures:%s: global index already available. skipping collection", scripture_name, ) continue # Create a DataFrame for metadata sorting df = pd.DataFrame(metadatas) df["_id"] = ids df["_doc"] = documents # Add sortable columns for each unique field for field_name in unique_fields: if field_name.lower() == "chapter" and chapter_order_mapping: # Map chapter names to their defined order df["_sort_" + field_name] = ( df[field_name].map(chapter_order_mapping).fillna(np.inf) ) else: # Try numeric, fallback to string lowercase def parse_val(v): if v is None: return float("inf") if isinstance(v, int): return v if isinstance(v, str): v = v.strip() return int(v) if v.isdigit() else v.lower() return str(v) df["_sort_" + field_name] = df[field_name].apply(parse_val) sort_cols = ["_sort_" + f for f in unique_fields] df = df.sort_values(by=sort_cols, kind="stable").reset_index(drop=True) # Assign global index df["_global_index"] = range(1, len(df) + 1) logger.info( "build_global_index_for_all_scriptures:%s: updating database", scripture_name, ) # Batch upsert BATCH_SIZE = 5000 # safely below max batch size for i in range(0, len(df), BATCH_SIZE): batch_df = df.iloc[i : i + BATCH_SIZE] batch_ids = batch_df["_id"].tolist() batch_docs = batch_df["_doc"].tolist() batch_metas = [ {k: record[k] for k in metadatas[0].keys() if k in record} | {"_global_index": record["_global_index"]} for record in batch_df.to_dict(orient="records") ] # Use original metadata keys for upsert batch_metas = [ {k: record[k] for k in metadatas[0].keys() if k in record} | {"_global_index": record["_global_index"]} for record in batch_df.to_dict(orient="records") ] batch_embeds = [embeddings[idx] for idx in batch_df.index] collection.update( ids=batch_ids, # documents=batch_docs, metadatas=batch_metas, # embeddings=batch_embeds, ) logger.info( "build_global_index_for_all_scriptures:%s: ✅ Updated with %d records", scripture_name, len(df), )