Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 2,955 Bytes
74c37c0 210796c 32cb5b7 210796c 74c37c0 32cb5b7 74c37c0 210796c 32cb5b7 74c37c0 210796c b2bbee4 210796c 74c37c0 210796c b2bbee4 d2bda67 b2bbee4 210796c b2bbee4 32cb5b7 210796c 32cb5b7 74c37c0 32cb5b7 74c37c0 32cb5b7 210796c 32cb5b7 d2bda67 32cb5b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import argparse
import chromadb
from tqdm import tqdm # Optional: For progress bar
# Registry of copyable databases, keyed by the id accepted on the command line.
# Each entry names the on-disk chroma store to read from and the collection to
# copy; source and destination collection names happen to match for both entries.
db_config = {
    "youtube_db": {
        "source_db_path": "../youtube_surfer_ai_agent/youtube_db",  # path is relative to this script's CWD
        "source_collection_name": "yt_metadata",
        "destination_collection_name": "yt_metadata",
    },
    "divya_prabandham": {
        "source_db_path": "../uveda_analyzer/chromadb_store",
        "source_collection_name": "divya_prabandham",
        "destination_collection_name": "divya_prabandham",
    },
}
# Command-line interface: the caller must pick one of the configured databases.
# argparse itself enforces both presence (required=True) and validity
# (choices=...), exiting with a usage error otherwise, so no manual
# post-parse validation is needed.
parser = argparse.ArgumentParser(description="My app with database parameter")
parser.add_argument(
    "--db",
    type=str,
    required=True,
    choices=list(db_config.keys()),
    help=f"Id of the database to use. allowed_values : {', '.join(db_config.keys())}",
)
args = parser.parse_args()
db_id = args.db  # guaranteed by argparse to be a key of db_config
# Connect to source and destination local persistent clients.
source_client = chromadb.PersistentClient(path=db_config[db_id]["source_db_path"])
destination_client = chromadb.PersistentClient(path="./chromadb-store")
source_collection_name = db_config[db_id]["source_collection_name"]
destination_collection_name = db_config[db_id]["destination_collection_name"]
# Get the source collection (raises if it does not exist -- we cannot copy
# from an absent source, so failing loudly here is correct).
source_collection = source_client.get_collection(source_collection_name)
# Retrieve all data from the source collection in one call.
source_data = source_collection.get(include=["documents", "metadatas", "embeddings"])
# Drop any pre-existing destination collection so the copy starts from a clean
# slate. NOTE: the original used get_or_create_collection() as an existence
# test, but that call always returns a collection (creating one on demand),
# so it created an empty collection just to delete it and always reported a
# deletion. Probe with get_collection() instead.
try:
    destination_client.get_collection(destination_collection_name)
except Exception:
    # Collection does not exist yet -- nothing to clean up. (chromadb's
    # "missing collection" exception type varies across versions, hence the
    # broad catch; it is confined to this single probe call.)
    pass
else:
    print("Deleting existing collection", destination_collection_name)
    destination_client.delete_collection(destination_collection_name)
destination_collection = destination_client.get_or_create_collection(
    destination_collection_name,
    metadata=source_collection.metadata,  # Copy metadata if needed
)
# Add data to the destination collection in batches (chroma enforces a
# per-call record limit, and batching keeps memory per add() call bounded).
BATCH_SIZE = 500
total_records = len(source_data["ids"])
# The embeddings entry is loop-invariant: decide once, outside the loop,
# whether to copy it. Compare against None rather than relying on truthiness
# because chroma may return embeddings as a numpy array, whose truth value
# is ambiguous.
has_embeddings = source_data.get("embeddings") is not None
print(f"Copying {total_records} records in batches of {BATCH_SIZE}...")
for start in tqdm(range(0, total_records, BATCH_SIZE)):
    end = start + BATCH_SIZE  # slicing past the end is safe in Python
    destination_collection.add(
        ids=source_data["ids"][start:end],
        documents=source_data["documents"][start:end],
        metadatas=source_data["metadatas"][start:end],
        embeddings=source_data["embeddings"][start:end] if has_embeddings else None,
    )
print("✅ Collection copied successfully!")
# Report both counts so a mismatch (partial copy) is visible at a glance.
print("Total records in source collection = ", source_collection.count())
print("Total records in destination collection = ", destination_collection.count())
|