Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Initialize pgvector extension in PostgreSQL database. | |
| This script connects to the database specified by DATABASE_URL environment variable | |
| and enables the pgvector extension if not already installed. | |
| Usage: | |
| python scripts/init_pgvector.py | |
| Environment Variables: | |
| DATABASE_URL: PostgreSQL connection string (required) | |
| Exit Codes: | |
| 0: Success - pgvector extension is installed and working | |
| 1: Error - connection failed, extension installation failed, or other error | |
| """ | |
| import logging | |
| import os | |
| import sys | |
| import psycopg2 # type: ignore | |
| import psycopg2.extras # type: ignore | |
def setup_logging() -> logging.Logger:
    """Configure logging via ``basicConfig`` and return this module's logger.

    The configuration requests INFO level and a timestamped
    ``time - level - message`` line format.
    """
    log_settings = {
        "level": logging.INFO,
        "format": "%(asctime)s - %(levelname)s - %(message)s",
        "datefmt": "%Y-%m-%d %H:%M:%S",
    }
    logging.basicConfig(**log_settings)
    module_logger = logging.getLogger(__name__)
    return module_logger
def get_database_url() -> str:
    """Return the connection string stored in the ``DATABASE_URL`` variable.

    Raises:
        ValueError: If ``DATABASE_URL`` is unset or empty.
    """
    url = os.environ.get("DATABASE_URL")
    if url:
        return url
    raise ValueError("DATABASE_URL environment variable is required")
def test_connection(connection_string: str, logger: logging.Logger) -> bool:
    """Check that the database is reachable by running ``SELECT 1``.

    Args:
        connection_string: PostgreSQL connection string passed to
            ``psycopg2.connect``.
        logger: Logger that receives the success or failure message.

    Returns:
        True if the query returned ``1``; False if the result was unexpected
        or any exception occurred while connecting or querying.
    """
    try:
        conn = psycopg2.connect(connection_string)
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT 1;")
                result = cur.fetchone()
        finally:
            # psycopg2's ``with connection`` only ends the transaction; the
            # connection itself has to be closed explicitly or it leaks.
            conn.close()
    except Exception as e:
        logger.error("β Database connection failed: %s", e)
        return False

    if result and result[0] == 1:
        logger.info("β Database connection successful")
        return True

    logger.error("β Unexpected result from connection test")
    return False
def check_postgresql_version(connection_string: str, logger: logging.Logger) -> bool:
    """Check that the server's major version is at least 13.

    The major version is read from the second whitespace-separated token of
    ``SELECT version()`` (e.g. ``"PostgreSQL 15.4 on x86_64-pc-linux-gnu..."``).
    Only the token's leading digits are used, so pre-release strings such as
    ``"17beta1"`` or ``"16devel"`` are parsed instead of raising.

    Args:
        connection_string: PostgreSQL connection string passed to
            ``psycopg2.connect``.
        logger: Logger that receives the outcome message.

    Returns:
        True if the major version is 13 or newer, or if the version string
        could not be parsed (a warning is logged and the caller proceeds).
        False if the version is older, no row was returned, or connecting or
        querying raised an exception.
    """
    try:
        conn = psycopg2.connect(connection_string)
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT version();")
                result = cur.fetchone()
        finally:
            # psycopg2's ``with connection`` does not close the connection,
            # so release it explicitly.
            conn.close()
    except Exception as e:
        logger.error("β Failed to check PostgreSQL version: %s", e)
        return False

    if not result:
        logger.error("β Could not get PostgreSQL version")
        return False

    version_string = str(result[0])
    version_parts = version_string.split()
    version_token = version_parts[1] if len(version_parts) >= 2 else ""
    # Keep only the leading ASCII digits of the token ("15.4" -> "15").
    remainder = version_token.lstrip("0123456789")
    major_digits = version_token[: len(version_token) - len(remainder)]

    if not major_digits:
        logger.warning("β οΈ Could not parse PostgreSQL version: %s", version_string)
        return True  # Proceed anyway

    major_version = int(major_digits)
    if major_version >= 13:
        logger.info("β PostgreSQL version %s supports pgvector", major_version)
        return True

    logger.error(
        "β PostgreSQL version %s is too old (requires 13+)",
        major_version,
    )
    return False
def install_pgvector_extension(connection_string: str, logger: logging.Logger) -> bool:
    """Run ``CREATE EXTENSION IF NOT EXISTS vector`` on the target database.

    Args:
        connection_string: PostgreSQL connection string passed to
            ``psycopg2.connect``.
        logger: Logger that receives progress and error messages.

    Returns:
        True if the statement executed without error; False if the database
        user lacks the required privilege or any other exception occurred.
    """
    try:
        conn = psycopg2.connect(connection_string)
        try:
            # Enable autocommit before any statement runs so the DDL is
            # committed immediately; the connection is then closed without
            # an explicit commit.
            conn.autocommit = True
            with conn.cursor() as cur:
                logger.info("Installing pgvector extension...")
                cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
                logger.info("β pgvector extension installed successfully")
                return True
        finally:
            # psycopg2's ``with connection`` does not close the connection,
            # so release it explicitly.
            conn.close()
    except psycopg2.errors.InsufficientPrivilege as e:
        logger.error("β Insufficient privileges to install extension: %s", str(e))
        logger.error("Make sure your database user has CREATE privilege or is a superuser")
        return False
    except Exception as e:
        logger.error("β Failed to install pgvector extension: %s", e)
        return False
def verify_pgvector_installation(connection_string: str, logger: logging.Logger) -> bool:
    """Verify that the ``vector`` extension is registered and usable.

    Three checks run in order: the extension appears in ``pg_extension``,
    a literal can be cast to ``vector(3)``, and the ``<->`` operator returns
    a distance of 1 between ``[1,2,3]`` and ``[1,2,4]``.

    Args:
        connection_string: PostgreSQL connection string passed to
            ``psycopg2.connect``.
        logger: Logger that receives the result of each check.

    Returns:
        True if all three checks pass; False if any check fails or an
        exception is raised.
    """
    expected_distance = 1.0
    tolerance = 1e-6

    try:
        conn = psycopg2.connect(connection_string)
        try:
            with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
                # Check extension is installed
                cur.execute(
                    "SELECT extname, extversion FROM pg_extension WHERE extname = 'vector';"
                )
                result = cur.fetchone()
                if not result:
                    logger.error("β pgvector extension not found in pg_extension")
                    return False
                logger.info("β pgvector extension version: %s", result["extversion"])

                # Test basic vector functionality
                cur.execute("SELECT '[1,2,3]'::vector(3);")
                vector_result = cur.fetchone()
                if not vector_result:
                    logger.error("β Vector type test failed")
                    return False
                logger.info("β Vector type functioning correctly")

                # Test vector operations
                cur.execute("SELECT '[1,2,3]'::vector(3) <-> '[1,2,4]'::vector(3);")
                distance_result = cur.fetchone()
                # Compare with a tolerance rather than exact float equality,
                # and treat a NULL distance as a failed check.
                distance_ok = (
                    bool(distance_result)
                    and distance_result[0] is not None
                    and abs(distance_result[0] - expected_distance) < tolerance
                )
                if distance_ok:
                    logger.info("β Vector distance operations working")
                    return True
                logger.error("β Vector distance operations failed")
                return False
        finally:
            # psycopg2's ``with connection`` does not close the connection,
            # so release it explicitly.
            conn.close()
    except Exception as e:
        logger.error("β Failed to verify pgvector installation: %s", e)
        return False
def main() -> int:
    """Run the pgvector initialization steps in order.

    Returns:
        0 when every step reports success; 1 as soon as a step reports
        failure or an exception is raised inside the run.
    """
    logger = setup_logging()
    try:
        logger.info("π Starting pgvector initialization...")

        database_url = get_database_url()
        logger.info("π‘ Got DATABASE_URL from environment")

        # Connection test, version check, install, verification:
        # each step takes (url, logger) and the run stops at the first failure.
        pipeline = (
            test_connection,
            check_postgresql_version,
            install_pgvector_extension,
            verify_pgvector_installation,
        )
        for step in pipeline:
            succeeded = step(database_url, logger)
            if not succeeded:
                return 1

        logger.info("π pgvector initialization completed successfully!")
        logger.info(" Your PostgreSQL database is now ready for vector operations.")
        return 0
    except Exception as exc:
        logger.error(f"β Unexpected error: {exc}")
        return 1
# Script entry point: exit the process with the status returned by main().
if __name__ == "__main__":
    raise SystemExit(main())