#!/usr/bin/env python3
"""
Initialize the pgvector extension in a PostgreSQL database.

This script connects to the database specified by the DATABASE_URL
environment variable, enables the pgvector extension if it is not already
installed, and then verifies that the ``vector`` type and its ``<->``
distance operator work.

Usage:
    python scripts/init_pgvector.py

Environment Variables:
    DATABASE_URL: PostgreSQL connection string (required)

Exit Codes:
    0: Success - pgvector extension is installed and working
    1: Error - connection failed, extension installation failed, or other error
"""

import logging
import math
import os
import sys
from contextlib import closing

import psycopg2  # type: ignore
import psycopg2.errors  # type: ignore
import psycopg2.extras  # type: ignore

# Minimum PostgreSQL major version this script accepts before installing pgvector.
MIN_POSTGRES_MAJOR_VERSION = 13


def setup_logging() -> logging.Logger:
    """Configure root logging at INFO level with timestamps; return a module logger."""
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    return logging.getLogger(__name__)


def get_database_url() -> str:
    """Return the DATABASE_URL environment variable.

    Raises:
        ValueError: If DATABASE_URL is unset or empty.
    """
    database_url = os.getenv("DATABASE_URL")
    if not database_url:
        raise ValueError("DATABASE_URL environment variable is required")
    return database_url


def _connect(connection_string: str):
    """Open a psycopg2 connection that is closed when its ``with`` block exits.

    ``with psycopg2.connect(...) as conn`` only commits/rolls back the current
    transaction; it does NOT close the connection. Wrapping in ``closing``
    guarantees the socket is released on every path, including errors.
    """
    return closing(psycopg2.connect(connection_string))


def test_connection(connection_string: str, logger: logging.Logger) -> bool:
    """Check that the database accepts connections and answers ``SELECT 1``.

    Returns:
        True on success; False (after logging the reason) otherwise.
    """
    try:
        with _connect(connection_string) as conn, conn.cursor() as cur:
            cur.execute("SELECT 1;")
            result = cur.fetchone()
    except Exception as e:  # report any connection problem as a failed check
        logger.error("❌ Database connection failed: %s", e)
        return False

    if result and result[0] == 1:
        logger.info("✅ Database connection successful")
        return True
    logger.error("❌ Unexpected result from connection test")
    return False


def check_postgresql_version(
    connection_string: str,
    logger: logging.Logger,
    *,
    min_major_version: int = MIN_POSTGRES_MAJOR_VERSION,
) -> bool:
    """Check that the server's major version is at least ``min_major_version``.

    The version is taken from psycopg2's numeric ``connection.server_version``
    (e.g. 150004 for 15.4) rather than by parsing the free-form
    ``SELECT version()`` text, which broke on builds such as "16beta1".

    Args:
        connection_string: libpq connection string / URL.
        logger: Logger used for progress and error messages.
        min_major_version: Oldest acceptable major version (default 13).

    Returns:
        True if the version is new enough, or if it cannot be determined (a
        warning is logged and initialization proceeds anyway). False if the
        server is too old or the check itself failed.
    """
    try:
        with _connect(connection_string) as conn:
            server_version = conn.server_version
    except Exception as e:
        logger.error("❌ Failed to check PostgreSQL version: %s", e)
        return False

    if not server_version:
        logger.warning("⚠️ Could not determine PostgreSQL version")
        return True  # Proceed anyway

    # 10+: major*10000 + minor (150004 -> 15). Pre-10: major*10000 + minor*100
    # + patch (90605 -> 9), which still compares correctly against 13.
    major_version = server_version // 10000
    if major_version >= min_major_version:
        logger.info("✅ PostgreSQL version %s supports pgvector", major_version)
        return True
    logger.error(
        "❌ PostgreSQL version %s is too old (requires %s+)",
        major_version,
        min_major_version,
    )
    return False


def install_pgvector_extension(connection_string: str, logger: logging.Logger) -> bool:
    """Run ``CREATE EXTENSION IF NOT EXISTS vector`` (a no-op when already installed).

    Returns:
        True if the statement succeeded; False (after logging) on any error,
        with an extra hint when the user lacks the needed privilege.
    """
    try:
        with _connect(connection_string) as conn:
            # Nothing commits on our behalf (the wrapper only closes the
            # connection), so autocommit makes the DDL take effect immediately.
            conn.autocommit = True
            with conn.cursor() as cur:
                logger.info("Installing pgvector extension...")
                cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
    except psycopg2.errors.InsufficientPrivilege as e:
        logger.error("❌ Insufficient privileges to install extension: %s", e)
        logger.error("Make sure your database user has CREATE privilege or is a superuser")
        return False
    except Exception as e:
        logger.error("❌ Failed to install pgvector extension: %s", e)
        return False

    logger.info("✅ pgvector extension installed successfully")
    return True


def verify_pgvector_installation(connection_string: str, logger: logging.Logger) -> bool:
    """Verify pgvector is registered and its type and distance operator work.

    Checks, in order:
      1. ``vector`` appears in ``pg_extension`` (its version is logged);
      2. a ``vector(3)`` literal can be cast;
      3. ``[1,2,3] <-> [1,2,4]`` evaluates to 1.0.

    Returns:
        True if all checks pass; False (after logging) otherwise.
    """
    try:
        with _connect(connection_string) as conn, conn.cursor(
            cursor_factory=psycopg2.extras.DictCursor
        ) as cur:
            # Check extension is installed
            cur.execute(
                "SELECT extname, extversion FROM pg_extension "
                "WHERE extname = 'vector';"
            )
            result = cur.fetchone()
            if not result:
                logger.error("❌ pgvector extension not found in pg_extension")
                return False
            logger.info("✅ pgvector extension version: %s", result["extversion"])

            # Test basic vector functionality
            cur.execute("SELECT '[1,2,3]'::vector(3);")
            if not cur.fetchone():
                logger.error("❌ Vector type test failed")
                return False
            logger.info("✅ Vector type functioning correctly")

            # Test vector operations: the vectors differ by 1 in one coordinate,
            # so the distance from <-> is exactly 1.
            cur.execute("SELECT '[1,2,3]'::vector(3) <-> '[1,2,4]'::vector(3);")
            distance_result = cur.fetchone()
            distance = distance_result[0] if distance_result else None
            if distance is not None and math.isclose(float(distance), 1.0):
                logger.info("✅ Vector distance operations working")
                return True
            logger.error("❌ Vector distance operations failed")
            return False
    except Exception as e:
        logger.error("❌ Failed to verify pgvector installation: %s", e)
        return False


def main() -> int:
    """Run every initialization step in order and return the process exit code.

    Returns:
        0 if all steps succeed, 1 as soon as any step fails.
    """
    logger = setup_logging()
    logger.info("🚀 Starting pgvector initialization...")

    try:
        database_url = get_database_url()
    except ValueError as e:
        # Configuration error, not an unexpected crash: report it plainly.
        logger.error("❌ %s", e)
        return 1
    logger.info("📡 Got DATABASE_URL from environment")

    steps = (
        test_connection,
        check_postgresql_version,
        install_pgvector_extension,
        verify_pgvector_installation,
    )
    try:
        for step in steps:
            # Each step logs its own failure reason; stop at the first one.
            if not step(database_url, logger):
                return 1
    except Exception:
        logger.exception("❌ Unexpected error")
        return 1

    logger.info("🎉 pgvector initialization completed successfully!")
    logger.info("   Your PostgreSQL database is now ready for vector operations.")
    return 0


if __name__ == "__main__":
    sys.exit(main())