# sanatan_ai / copy_chromadb.py
# Uploaded by vikramvasudevan using huggingface_hub (commit 74c37c0, verified).
import argparse
import chromadb
from tqdm import tqdm # Optional: For progress bar
# Registry of copyable databases: each entry maps a database id to the path of
# its source ChromaDB store and the collection names to copy from/to.
db_config = dict(
    youtube_db=dict(
        source_db_path="../youtube_surfer_ai_agent/youtube_db",
        source_collection_name="yt_metadata",
        destination_collection_name="yt_metadata",
    ),
    divya_prabandham=dict(
        source_db_path="../uveda_analyzer/chromadb_store",
        source_collection_name="divya_prabandham",
        destination_collection_name="divya_prabandham",
    ),
)
# Parse command-line arguments.  `required=True` makes argparse exit with a
# usage error when --db is missing, and `choices` rejects any id that is not
# a db_config key -- so no further validation is needed after parsing.
parser = argparse.ArgumentParser(description="My app with database parameter")
parser.add_argument(
    "--db",
    type=str,
    required=True,
    choices=list(db_config.keys()),
    help=f"Id of the database to use. allowed_values : {', '.join(db_config.keys())}",
)
args = parser.parse_args()
db_id = args.db
# NOTE: the former `db_id is None` / `db_id not in db_config` re-checks were
# removed as dead code: argparse already guarantees both conditions above.
# Connect to source and destination local persistent clients.
source_client = chromadb.PersistentClient(path=db_config[db_id]["source_db_path"])
destination_client = chromadb.PersistentClient(path="./chromadb-store")
source_collection_name = db_config[db_id]["source_collection_name"]
destination_collection_name = db_config[db_id]["destination_collection_name"]
# Get the source collection (raises if it does not exist -- there is nothing
# to copy in that case, so failing loudly is correct).
source_collection = source_client.get_collection(source_collection_name)
# Retrieve all data from the source collection, including embeddings (which
# `get()` omits by default).
source_data = source_collection.get(include=["documents", "metadatas", "embeddings"])
# Drop any pre-existing destination collection so the copy starts clean.
# The original code called get_or_create_collection() as the existence test,
# but that call always returns a (possibly freshly created) collection and is
# therefore always truthy -- it could create a collection only to delete it.
# Deleting directly and tolerating "does not exist" avoids that round trip.
try:
    destination_client.delete_collection(destination_collection_name)
    print("Deleting existing collection", destination_collection_name)
except Exception:
    # Broad catch is deliberate: the "collection not found" exception type
    # varies across chromadb versions.  Nothing existed, nothing to delete.
    pass
destination_collection = destination_client.get_or_create_collection(
    destination_collection_name,
    metadata=source_collection.metadata,  # Copy metadata if needed
)
# Add data to the destination collection in batches to keep each request to
# the client reasonably small.
BATCH_SIZE = 500
total_records = len(source_data["ids"])
# Whether the source returned usable embeddings is loop-invariant; hoist it
# out of the loop (the original re-evaluated this check for every batch).
has_embeddings = (
    "embeddings" in source_data and source_data["embeddings"] is not None
)
print(f"Copying {total_records} records in batches of {BATCH_SIZE}...")
for start in tqdm(range(0, total_records, BATCH_SIZE)):
    stop = start + BATCH_SIZE
    destination_collection.add(
        ids=source_data["ids"][start:stop],
        documents=source_data["documents"][start:stop],
        metadatas=source_data["metadatas"][start:stop],
        # Pass None (not an empty slice) when no embeddings were returned so
        # the destination recomputes/omits them per its own configuration.
        embeddings=source_data["embeddings"][start:stop] if has_embeddings else None,
    )
print("✅ Collection copied successfully!")
# Sanity check: both counts should match if the copy succeeded.
print("Total records in source collection = ", source_collection.count())
print("Total records in destination collection = ", destination_collection.count())