msse-ai-engineering / scripts /init_pgvector.py
sethmcknight
Refactor test cases for improved readability and consistency
159faf0
raw
history blame
7.14 kB
#!/usr/bin/env python3
"""
Initialize the pgvector extension in a PostgreSQL database.

This script connects to the database specified by the DATABASE_URL environment
variable and enables the pgvector extension if it is not already installed.
The server must run PostgreSQL 13 or newer.

Usage:
    python scripts/init_pgvector.py

Environment Variables:
    DATABASE_URL: PostgreSQL connection string (required)

Exit Codes:
    0: Success - pgvector extension is installed and working
    1: Error - connection failed, extension installation failed, or other error
"""
import logging
import math
import os
import re
import sys
from contextlib import closing

import psycopg2  # type: ignore
import psycopg2.extras  # type: ignore


def setup_logging() -> logging.Logger:
    """Configure INFO-level root logging and return this module's logger.

    Records are formatted as ``<timestamp> - <LEVEL> - <message>`` with a
    ``YYYY-mm-dd HH:MM:SS`` timestamp.

    Returns:
        The logger named after this module.
    """
    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    date_format = "%Y-%m-%d %H:%M:%S"
    logging.basicConfig(level=logging.INFO, format=log_format, datefmt=date_format)
    module_logger = logging.getLogger(__name__)
    return module_logger


def get_database_url() -> str:
    """Return the PostgreSQL connection string stored in ``DATABASE_URL``.

    Raises:
        ValueError: If ``DATABASE_URL`` is unset or set to an empty string.
    """
    url = os.environ.get("DATABASE_URL")
    if url:
        return url
    raise ValueError("DATABASE_URL environment variable is required")


def test_connection(connection_string: str, logger: logging.Logger) -> bool:
    """Check that the database accepts connections and answers a trivial query.

    Opens a connection, runs ``SELECT 1`` and confirms the returned value is
    ``1``. The connection is closed before returning.

    Args:
        connection_string: Connection string/URL passed to ``psycopg2.connect``.
        logger: Logger used to report the outcome.

    Returns:
        True if the query returned ``1``; False on any error or on an
        unexpected result.
    """
    try:
        # psycopg2's ``with conn:`` only ends the transaction and leaves the
        # connection open; ``closing`` guarantees it is actually closed.
        with closing(psycopg2.connect(connection_string)) as conn:
            with conn.cursor() as cur:
                cur.execute("SELECT 1;")
                row = cur.fetchone()
    except Exception as exc:
        logger.error("❌ Database connection failed: %s", exc)
        return False

    if row and row[0] == 1:
        logger.info("✅ Database connection successful")
        return True
    logger.error("❌ Unexpected result from connection test")
    return False


def check_postgresql_version(connection_string: str, logger: logging.Logger) -> bool:
    """Check that the server's PostgreSQL major version supports pgvector (13+).

    The major version is taken from ``SELECT version()``, whose output looks
    like ``"PostgreSQL 15.4 on x86_64-pc-linux-gnu, ..."``. Only the digits
    immediately after ``PostgreSQL`` are used, so pre-release and development
    builds (e.g. ``"PostgreSQL 17beta1"``, ``"PostgreSQL 16devel"``) parse
    correctly instead of raising. If no version number can be found, a
    warning is logged and the check passes so initialization can continue.

    Args:
        connection_string: Connection string/URL passed to ``psycopg2.connect``.
        logger: Logger used to report the outcome.

    Returns:
        False if the query fails, returns nothing, or reports a major version
        below 13; True otherwise.
    """
    try:
        # ``closing`` ensures the connection is closed; psycopg2's ``with conn:``
        # does not close it.
        with closing(psycopg2.connect(connection_string)) as conn:
            with conn.cursor() as cur:
                cur.execute("SELECT version();")
                row = cur.fetchone()
    except Exception as exc:
        logger.error("❌ Failed to check PostgreSQL version: %s", exc)
        return False

    if not row:
        logger.error("❌ Could not get PostgreSQL version")
        return False

    version_string = str(row[0])
    match = re.search(r"PostgreSQL (\d+)", version_string)
    if match is None:
        logger.warning("⚠️ Could not parse PostgreSQL version: %s", version_string)
        return True  # Proceed anyway: an unknown format is not treated as fatal.

    major_version = int(match.group(1))
    if major_version < 13:
        logger.error("❌ PostgreSQL version %s is too old (requires 13+)", major_version)
        return False
    logger.info("✅ PostgreSQL version %s supports pgvector", major_version)
    return True


def install_pgvector_extension(connection_string: str, logger: logging.Logger) -> bool:
    """Create the ``vector`` extension unless it already exists.

    Runs ``CREATE EXTENSION IF NOT EXISTS vector`` on an autocommit
    connection, so the statement takes effect as soon as it executes.

    Args:
        connection_string: Connection string/URL passed to ``psycopg2.connect``.
        logger: Logger used to report progress and errors.

    Returns:
        True if the statement succeeded, including when the extension was
        already present. False on insufficient privileges or any other error.
    """
    try:
        # ``closing`` guarantees the connection is closed; psycopg2's
        # ``with conn:`` would only end the transaction.
        with closing(psycopg2.connect(connection_string)) as conn:
            # Autocommit: there is no ``with conn:`` block to commit the DDL,
            # so let each statement commit on its own.
            conn.autocommit = True
            with conn.cursor() as cur:
                logger.info("Installing pgvector extension...")
                cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
    except psycopg2.errors.InsufficientPrivilege as exc:
        logger.error("❌ Insufficient privileges to install extension: %s", exc)
        logger.error("Make sure your database user has CREATE privilege or is a superuser")
        return False
    except Exception as exc:
        logger.error("❌ Failed to install pgvector extension: %s", exc)
        return False

    # IF NOT EXISTS makes this a no-op when the extension is already there.
    logger.info("✅ pgvector extension installed (or already present)")
    return True


def verify_pgvector_installation(connection_string: str, logger: logging.Logger) -> bool:
    """Verify that the pgvector extension is installed and usable.

    Runs three checks over a single connection and stops at the first failure:

    1. ``pg_extension`` contains a row for ``vector``; its version is logged.
    2. A literal can be cast to ``vector(3)``.
    3. ``<->`` (pgvector's L2 distance operator) between ``[1,2,3]`` and
       ``[1,2,4]`` yields 1.

    Args:
        connection_string: Connection string/URL passed to ``psycopg2.connect``.
        logger: Logger used to report each check.

    Returns:
        True if all three checks pass; False otherwise, including on any
        database error.
    """
    try:
        # psycopg2's ``with conn:`` does not close the connection; ``closing`` does.
        with closing(psycopg2.connect(connection_string)) as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
                cur.execute(
                    "SELECT extname, extversion FROM pg_extension WHERE extname = 'vector';"
                )
                ext_row = cur.fetchone()
                if not ext_row:
                    logger.error("❌ pgvector extension not found in pg_extension")
                    return False
                logger.info("✅ pgvector extension version: %s", ext_row["extversion"])

                cur.execute("SELECT '[1,2,3]'::vector(3);")
                if not cur.fetchone():
                    logger.error("❌ Vector type test failed")
                    return False
                logger.info("✅ Vector type functioning correctly")

                cur.execute("SELECT '[1,2,3]'::vector(3) <-> '[1,2,4]'::vector(3);")
                distance_row = cur.fetchone()
    except Exception as exc:
        logger.error("❌ Failed to verify pgvector installation: %s", exc)
        return False

    distance = distance_row[0] if distance_row else None
    # Compare with a tolerance instead of exact ``==`` on a floating-point result.
    if distance is not None and math.isclose(float(distance), 1.0):
        logger.info("✅ Vector distance operations working")
        return True
    logger.error("❌ Vector distance operations failed (got %r)", distance)
    return False


def main() -> int:
    """Run the pgvector initialization steps and return a process exit code.

    The steps run in order and stop at the first failure:

    1. Read ``DATABASE_URL``.
    2. Test the connection.
    3. Check the PostgreSQL version.
    4. Install the extension.
    5. Verify the installation.

    Returns:
        0 if every step succeeded; 1 on a missing ``DATABASE_URL``, a failed
        step, or any unexpected exception.
    """
    logger = setup_logging()
    try:
        logger.info("🚀 Starting pgvector initialization...")

        try:
            database_url = get_database_url()
        except ValueError as exc:
            # A configuration problem, not a crash: report it without a traceback.
            logger.error("❌ %s", exc)
            return 1
        logger.info("📑 Got DATABASE_URL from environment")

        steps = (
            test_connection,
            check_postgresql_version,
            install_pgvector_extension,
            verify_pgvector_installation,
        )
        for step in steps:
            # Each step logs its own failure reason before returning False.
            if not step(database_url, logger):
                return 1

        logger.info("🎉 pgvector initialization completed successfully!")
        logger.info("   Your PostgreSQL database is now ready for vector operations.")
        return 0
    except Exception as exc:
        # Anything reaching here is unforeseen; keep the traceback for debugging.
        logger.exception("❌ Unexpected error: %s", exc)
        return 1


if __name__ == "__main__":
    # Use main()'s return value (0 or 1) as the process exit status.
    raise SystemExit(main())