Spaces:
Sleeping
Sleeping
File size: 7,140 Bytes
dca679b 159faf0 dca679b 159faf0 dca679b 159faf0 dca679b 159faf0 dca679b 159faf0 dca679b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
#!/usr/bin/env python3
"""
Initialize pgvector extension in PostgreSQL database.
This script connects to the database specified by DATABASE_URL environment variable
and enables the pgvector extension if not already installed.
Usage:
python scripts/init_pgvector.py
Environment Variables:
DATABASE_URL: PostgreSQL connection string (required)
Exit Codes:
0: Success - pgvector extension is installed and working
1: Error - connection failed, extension installation failed, or other error
"""
import logging
import os
import sys
import psycopg2 # type: ignore
import psycopg2.extras # type: ignore
def setup_logging() -> logging.Logger:
    """Configure basic logging and return this module's logger.

    Calls ``logging.basicConfig`` with INFO level and a timestamped
    ``time - level - message`` line format.

    Returns:
        The logger obtained via ``logging.getLogger(__name__)``.
    """
    line_format = "%(asctime)s - %(levelname)s - %(message)s"
    timestamp_format = "%Y-%m-%d %H:%M:%S"
    logging.basicConfig(
        level=logging.INFO,
        format=line_format,
        datefmt=timestamp_format,
    )
    module_logger = logging.getLogger(__name__)
    return module_logger
def get_database_url() -> str:
    """Return the PostgreSQL connection string from the environment.

    Returns:
        The value of the ``DATABASE_URL`` environment variable.

    Raises:
        ValueError: If ``DATABASE_URL`` is unset or empty.
    """
    url = os.environ.get("DATABASE_URL")
    if url:
        return url
    raise ValueError("DATABASE_URL environment variable is required")
def test_connection(connection_string: str, logger: logging.Logger) -> bool:
    """Check that the database is reachable by running ``SELECT 1``.

    Args:
        connection_string: PostgreSQL connection string passed to
            ``psycopg2.connect``.
        logger: Logger that receives the success or failure message.

    Returns:
        True if the query returned ``1``; False if it returned anything
        else, or if connecting or querying raised an exception.
    """
    try:
        conn = psycopg2.connect(connection_string)
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT 1;")
                row = cur.fetchone()
        finally:
            # psycopg2's ``with connection`` only ends the transaction; it
            # does not close the connection, so close it explicitly.
            conn.close()
    except Exception as e:
        # Boundary handler: this function reports failure via its return
        # value, so any connection/query error is logged and mapped to False.
        logger.error("❌ Database connection failed: %s", e)
        return False

    if row and row[0] == 1:
        logger.info("✅ Database connection successful")
        return True

    logger.error("❌ Unexpected result from connection test")
    return False
def _parse_major_version(version_string: str) -> "int | None":
    """Return the major version from a ``SELECT version()`` banner, or None.

    The banner looks like ``"PostgreSQL 15.4 on x86_64-pc-linux-gnu..."``.
    Only the leading digits of the second whitespace-separated token are
    used, so pre-release tokens such as ``16beta1`` or ``17devel`` parse too.
    """
    tokens = version_string.split()
    if len(tokens) < 2:
        return None
    version_token = tokens[1]
    # Length of the leading run of ASCII digits in the version token.
    digit_count = len(version_token) - len(version_token.lstrip("0123456789"))
    if digit_count == 0:
        return None
    return int(version_token[:digit_count])


def check_postgresql_version(connection_string: str, logger: logging.Logger) -> bool:
    """Check if PostgreSQL version supports pgvector (13+).

    Args:
        connection_string: PostgreSQL connection string passed to
            ``psycopg2.connect``.
        logger: Logger that receives the outcome message.

    Returns:
        True if the server's major version is 13 or newer, or if the
        version banner could not be parsed (the check is then skipped with
        a warning). False if the version is older than 13, if the server
        returned no row, or if connecting or querying raised an exception.
    """
    try:
        conn = psycopg2.connect(connection_string)
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT version();")
                row = cur.fetchone()
        finally:
            # ``with connection`` in psycopg2 does not close the connection,
            # so it is closed explicitly here.
            conn.close()
    except Exception as e:
        logger.error("❌ Failed to check PostgreSQL version: %s", e)
        return False

    if not row:
        logger.error("❌ Could not get PostgreSQL version")
        return False

    version_string = str(row[0])
    major_version = _parse_major_version(version_string)

    if major_version is None:
        logger.warning("⚠️ Could not parse PostgreSQL version: %s", version_string)
        return True  # Proceed anyway

    if major_version >= 13:
        logger.info("✅ PostgreSQL version %s supports pgvector", major_version)
        return True

    logger.error(
        "❌ PostgreSQL version %s is too old (requires 13+)",
        major_version,
    )
    return False
def install_pgvector_extension(connection_string: str, logger: logging.Logger) -> bool:
    """Install the pgvector extension if it is not already present.

    Runs ``CREATE EXTENSION IF NOT EXISTS vector;`` on an autocommit
    connection.

    Args:
        connection_string: PostgreSQL connection string passed to
            ``psycopg2.connect``.
        logger: Logger that receives progress and error messages.

    Returns:
        True if the statement executed without error; False if the user
        lacks the required privilege or any other exception was raised.
    """
    try:
        conn = psycopg2.connect(connection_string)
        try:
            # Enable autocommit on the bare connection rather than inside a
            # ``with connection`` block: psycopg2 2.9+ opens a transaction
            # for that block even on autocommit connections.
            conn.autocommit = True
            with conn.cursor() as cur:
                logger.info("Installing pgvector extension...")
                cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
        finally:
            # The connection is not closed by any context manager; do it here.
            conn.close()
    except psycopg2.errors.InsufficientPrivilege as e:
        logger.error("❌ Insufficient privileges to install extension: %s", str(e))
        logger.error("Make sure your database user has CREATE privilege or is a superuser")
        return False
    except Exception as e:
        logger.error("❌ Failed to install pgvector extension: %s", e)
        return False

    logger.info("✅ pgvector extension installed successfully")
    return True
def verify_pgvector_installation(connection_string: str, logger: logging.Logger) -> bool:
    """Verify pgvector extension is properly installed.

    Performs three checks in order and stops at the first failure:
    the ``vector`` row exists in ``pg_extension``, a literal can be cast
    to ``vector(3)``, and the ``<->`` distance between ``[1,2,3]`` and
    ``[1,2,4]`` equals 1.

    Args:
        connection_string: PostgreSQL connection string passed to
            ``psycopg2.connect``.
        logger: Logger that receives the result of each check.

    Returns:
        True if all three checks pass; False if any check fails or an
        exception is raised.
    """
    try:
        conn = psycopg2.connect(connection_string)
        try:
            with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
                # Check extension is installed
                cur.execute(
                    "SELECT extname, extversion FROM pg_extension WHERE extname = 'vector';"
                )
                extension_row = cur.fetchone()
                if not extension_row:
                    logger.error("❌ pgvector extension not found in pg_extension")
                    return False
                logger.info("✅ pgvector extension version: %s", extension_row["extversion"])

                # Test basic vector functionality
                cur.execute("SELECT '[1,2,3]'::vector(3);")
                if not cur.fetchone():
                    logger.error("❌ Vector type test failed")
                    return False
                logger.info("✅ Vector type functioning correctly")

                # Test vector operations
                cur.execute("SELECT '[1,2,3]'::vector(3) <-> '[1,2,4]'::vector(3);")
                distance_row = cur.fetchone()
                # Compare with a tolerance instead of exact float equality;
                # a NULL distance counts as a failed check.
                if (
                    distance_row
                    and distance_row[0] is not None
                    and abs(distance_row[0] - 1.0) < 1e-6
                ):
                    logger.info("✅ Vector distance operations working")
                    return True

                logger.error("❌ Vector distance operations failed")
                return False
        finally:
            # psycopg2's ``with connection`` would not close the connection,
            # so it is closed explicitly on every exit path.
            conn.close()
    except Exception as e:
        logger.error("❌ Failed to verify pgvector installation: %s", e)
        return False
def main() -> int:
    """Run the pgvector initialization steps and return a process exit code.

    Reads ``DATABASE_URL``, then runs the connection test, the PostgreSQL
    version check, the extension install and the installation
    verification, in that order, stopping at the first step that fails.

    Returns:
        0 if every step succeeded; 1 if any step returned False or an
        exception was raised (including a missing ``DATABASE_URL``).
    """
    logger = setup_logging()

    try:
        logger.info("🚀 Starting pgvector initialization...")

        # Get database connection string
        database_url = get_database_url()
        logger.info("📡 Got DATABASE_URL from environment")

        # Each step takes (connection_string, logger) and returns a bool;
        # order matters: connect, check version, install, then verify.
        steps = (
            test_connection,
            check_postgresql_version,
            install_pgvector_extension,
            verify_pgvector_installation,
        )
        for step in steps:
            if not step(database_url, logger):
                return 1

        logger.info("🎉 pgvector initialization completed successfully!")
        logger.info(" Your PostgreSQL database is now ready for vector operations.")
        return 0

    except Exception as e:
        # Top-level boundary: convert any unexpected error into exit code 1.
        logger.error("❌ Unexpected error: %s", e)
        return 1
# Script entry point: exit the process with the status code returned by main().
if __name__ == "__main__":
    raise SystemExit(main())
|