Spaces:
Sleeping
Sleeping
sethmcknight
committed on
Commit
·
fa40b0e
1
Parent(s):
307e1fd
fix(postgres): use psycopg2.sql.Identifier/SQL for table/sequence names to prevent SQL injection and satisfy PR feedback
Browse files
src/vector_db/postgres_vector_service.py
CHANGED
|
@@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional
|
|
| 10 |
|
| 11 |
import psycopg2
|
| 12 |
import psycopg2.extras
|
|
|
|
| 13 |
|
| 14 |
logger = logging.getLogger(__name__)
|
| 15 |
|
|
@@ -68,8 +69,9 @@ class PostgresVectorService:
|
|
| 68 |
|
| 69 |
# Create table with initial structure (dimension will be added later)
|
| 70 |
cur.execute(
|
| 71 |
-
|
| 72 |
-
|
|
|
|
| 73 |
id SERIAL PRIMARY KEY,
|
| 74 |
content TEXT NOT NULL,
|
| 75 |
embedding vector,
|
|
@@ -77,18 +79,18 @@ class PostgresVectorService:
|
|
| 77 |
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 78 |
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
| 79 |
);
|
| 80 |
-
"""
|
| 81 |
-
|
| 82 |
-
)
|
| 83 |
)
|
| 84 |
|
| 85 |
# Create index for text search
|
| 86 |
cur.execute(
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
| 92 |
)
|
| 93 |
)
|
| 94 |
|
|
@@ -123,28 +125,29 @@ class PostgresVectorService:
|
|
| 123 |
if result and ("vector(%s)" % dimension) not in str(result):
|
| 124 |
# Drop existing index if it exists
|
| 125 |
cur.execute(
|
| 126 |
-
"DROP INDEX IF EXISTS
|
| 127 |
-
|
| 128 |
)
|
| 129 |
)
|
| 130 |
|
| 131 |
# Alter column to correct dimension
|
| 132 |
cur.execute(
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
| 136 |
)
|
| 137 |
)
|
| 138 |
|
| 139 |
# Create optimized index for similarity search
|
| 140 |
cur.execute(
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
)
|
| 149 |
)
|
| 150 |
|
|
@@ -192,19 +195,13 @@ class PostgresVectorService:
|
|
| 192 |
with self._get_connection() as conn:
|
| 193 |
with conn.cursor() as cur:
|
| 194 |
for text, embedding, metadata in zip(texts, embeddings, metadatas):
|
| 195 |
-
# Insert document and get ID
|
| 196 |
cur.execute(
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
(
|
| 203 |
-
self.table_name,
|
| 204 |
-
text,
|
| 205 |
-
embedding,
|
| 206 |
-
psycopg2.extras.Json(metadata),
|
| 207 |
-
),
|
| 208 |
)
|
| 209 |
|
| 210 |
doc_id = cur.fetchone()[0]
|
|
@@ -254,14 +251,18 @@ class PostgresVectorService:
|
|
| 254 |
if conditions:
|
| 255 |
where_clause = "WHERE " + " AND ".join(conditions)
|
| 256 |
|
| 257 |
-
query
|
|
|
|
|
|
|
|
|
|
| 258 |
SELECT id, content, metadata,
|
| 259 |
1 - (embedding <=> %s) as similarity_score
|
| 260 |
-
FROM {
|
| 261 |
-
{
|
| 262 |
ORDER BY embedding <=> %s
|
| 263 |
LIMIT %s;
|
| 264 |
"""
|
|
|
|
| 265 |
|
| 266 |
with self._get_connection() as conn:
|
| 267 |
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
|
@@ -283,15 +284,18 @@ class PostgresVectorService:
|
|
| 283 |
with self._get_connection() as conn:
|
| 284 |
with conn.cursor() as cur:
|
| 285 |
# Get document count
|
| 286 |
-
cur.execute(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
doc_count = cur.fetchone()[0]
|
| 288 |
|
| 289 |
# Get table size
|
| 290 |
cur.execute(
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
(self.table_name,),
|
| 295 |
)
|
| 296 |
table_size = cur.fetchone()[0]
|
| 297 |
|
|
@@ -335,10 +339,9 @@ class PostgresVectorService:
|
|
| 335 |
int_ids = [int(doc_id) for doc_id in document_ids]
|
| 336 |
|
| 337 |
cur.execute(
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
""",
|
| 342 |
(int_ids,),
|
| 343 |
)
|
| 344 |
|
|
@@ -357,14 +360,22 @@ class PostgresVectorService:
|
|
| 357 |
"""
|
| 358 |
with self._get_connection() as conn:
|
| 359 |
with conn.cursor() as cur:
|
| 360 |
-
cur.execute(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
count_before = cur.fetchone()[0]
|
| 362 |
|
| 363 |
-
cur.execute(
|
|
|
|
|
|
|
| 364 |
|
| 365 |
# Reset the sequence
|
| 366 |
cur.execute(
|
| 367 |
-
"ALTER SEQUENCE
|
|
|
|
|
|
|
| 368 |
)
|
| 369 |
|
| 370 |
conn.commit()
|
|
@@ -411,11 +422,10 @@ class PostgresVectorService:
|
|
| 411 |
updates.append("updated_at = CURRENT_TIMESTAMP")
|
| 412 |
params.append(int(document_id))
|
| 413 |
|
| 414 |
-
query
|
| 415 |
-
|
| 416 |
-
SET
|
| 417 |
-
|
| 418 |
-
"""
|
| 419 |
|
| 420 |
with self._get_connection() as conn:
|
| 421 |
with conn.cursor() as cur:
|
|
@@ -443,12 +453,11 @@ class PostgresVectorService:
|
|
| 443 |
with self._get_connection() as conn:
|
| 444 |
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
| 445 |
cur.execute(
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
(self.table_name, int(document_id)),
|
| 452 |
)
|
| 453 |
|
| 454 |
row = cur.fetchone()
|
|
|
|
| 10 |
|
| 11 |
import psycopg2
|
| 12 |
import psycopg2.extras
|
| 13 |
+
from psycopg2 import sql
|
| 14 |
|
| 15 |
logger = logging.getLogger(__name__)
|
| 16 |
|
|
|
|
| 69 |
|
| 70 |
# Create table with initial structure (dimension will be added later)
|
| 71 |
cur.execute(
|
| 72 |
+
sql.SQL(
|
| 73 |
+
"""
|
| 74 |
+
CREATE TABLE IF NOT EXISTS {} (
|
| 75 |
id SERIAL PRIMARY KEY,
|
| 76 |
content TEXT NOT NULL,
|
| 77 |
embedding vector,
|
|
|
|
| 79 |
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 80 |
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
| 81 |
);
|
| 82 |
+
"""
|
| 83 |
+
).format(sql.Identifier(self.table_name))
|
|
|
|
| 84 |
)
|
| 85 |
|
| 86 |
# Create index for text search
|
| 87 |
cur.execute(
|
| 88 |
+
sql.SQL(
|
| 89 |
+
"CREATE INDEX IF NOT EXISTS {} "
|
| 90 |
+
"ON {} USING gin(to_tsvector('english', content));"
|
| 91 |
+
).format(
|
| 92 |
+
sql.Identifier(f"idx_{self.table_name}_content"),
|
| 93 |
+
sql.Identifier(self.table_name),
|
| 94 |
)
|
| 95 |
)
|
| 96 |
|
|
|
|
| 125 |
if result and ("vector(%s)" % dimension) not in str(result):
|
| 126 |
# Drop existing index if it exists
|
| 127 |
cur.execute(
|
| 128 |
+
sql.SQL("DROP INDEX IF EXISTS {}; ").format(
|
| 129 |
+
sql.Identifier(f"idx_{self.table_name}_embedding_cosine")
|
| 130 |
)
|
| 131 |
)
|
| 132 |
|
| 133 |
# Alter column to correct dimension
|
| 134 |
cur.execute(
|
| 135 |
+
sql.SQL(
|
| 136 |
+
"ALTER TABLE {} ALTER COLUMN embedding TYPE vector({});"
|
| 137 |
+
).format(
|
| 138 |
+
sql.Identifier(self.table_name), sql.Literal(dimension)
|
| 139 |
)
|
| 140 |
)
|
| 141 |
|
| 142 |
# Create optimized index for similarity search
|
| 143 |
cur.execute(
|
| 144 |
+
sql.SQL(
|
| 145 |
+
"CREATE INDEX IF NOT EXISTS {} ON {} "
|
| 146 |
+
"USING ivfflat (embedding vector_cosine_ops) "
|
| 147 |
+
"WITH (lists = 100);"
|
| 148 |
+
).format(
|
| 149 |
+
sql.Identifier(f"idx_{self.table_name}_embedding_cosine"),
|
| 150 |
+
sql.Identifier(self.table_name),
|
| 151 |
)
|
| 152 |
)
|
| 153 |
|
|
|
|
| 195 |
with self._get_connection() as conn:
|
| 196 |
with conn.cursor() as cur:
|
| 197 |
for text, embedding, metadata in zip(texts, embeddings, metadatas):
|
| 198 |
+
# Insert document and get ID (table name composed safely)
|
| 199 |
cur.execute(
|
| 200 |
+
sql.SQL(
|
| 201 |
+
"INSERT INTO {} (content, embedding, metadata) "
|
| 202 |
+
"VALUES (%s, %s, %s) RETURNING id;"
|
| 203 |
+
).format(sql.Identifier(self.table_name)),
|
| 204 |
+
(text, embedding, psycopg2.extras.Json(metadata)),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
)
|
| 206 |
|
| 207 |
doc_id = cur.fetchone()[0]
|
|
|
|
| 251 |
if conditions:
|
| 252 |
where_clause = "WHERE " + " AND ".join(conditions)
|
| 253 |
|
| 254 |
+
# Compose query safely with identifier for table name. where_clause
|
| 255 |
+
# contains only parameter placeholders (%s) and logical operators.
|
| 256 |
+
query = sql.SQL(
|
| 257 |
+
"""
|
| 258 |
SELECT id, content, metadata,
|
| 259 |
1 - (embedding <=> %s) as similarity_score
|
| 260 |
+
FROM {}
|
| 261 |
+
{}
|
| 262 |
ORDER BY embedding <=> %s
|
| 263 |
LIMIT %s;
|
| 264 |
"""
|
| 265 |
+
).format(sql.Identifier(self.table_name), sql.SQL(where_clause))
|
| 266 |
|
| 267 |
with self._get_connection() as conn:
|
| 268 |
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
|
|
|
| 284 |
with self._get_connection() as conn:
|
| 285 |
with conn.cursor() as cur:
|
| 286 |
# Get document count
|
| 287 |
+
cur.execute(
|
| 288 |
+
sql.SQL("SELECT COUNT(*) FROM {};").format(
|
| 289 |
+
sql.Identifier(self.table_name)
|
| 290 |
+
)
|
| 291 |
+
)
|
| 292 |
doc_count = cur.fetchone()[0]
|
| 293 |
|
| 294 |
# Get table size
|
| 295 |
cur.execute(
|
| 296 |
+
sql.SQL(
|
| 297 |
+
"SELECT pg_size_pretty(pg_total_relation_size({})) as size;"
|
| 298 |
+
).format(sql.Identifier(self.table_name))
|
|
|
|
| 299 |
)
|
| 300 |
table_size = cur.fetchone()[0]
|
| 301 |
|
|
|
|
| 339 |
int_ids = [int(doc_id) for doc_id in document_ids]
|
| 340 |
|
| 341 |
cur.execute(
|
| 342 |
+
sql.SQL("DELETE FROM {} WHERE id = ANY(%s);").format(
|
| 343 |
+
sql.Identifier(self.table_name)
|
| 344 |
+
),
|
|
|
|
| 345 |
(int_ids,),
|
| 346 |
)
|
| 347 |
|
|
|
|
| 360 |
"""
|
| 361 |
with self._get_connection() as conn:
|
| 362 |
with conn.cursor() as cur:
|
| 363 |
+
cur.execute(
|
| 364 |
+
sql.SQL("SELECT COUNT(*) FROM {};").format(
|
| 365 |
+
sql.Identifier(self.table_name)
|
| 366 |
+
)
|
| 367 |
+
)
|
| 368 |
count_before = cur.fetchone()[0]
|
| 369 |
|
| 370 |
+
cur.execute(
|
| 371 |
+
sql.SQL("DELETE FROM {};").format(sql.Identifier(self.table_name))
|
| 372 |
+
)
|
| 373 |
|
| 374 |
# Reset the sequence
|
| 375 |
cur.execute(
|
| 376 |
+
sql.SQL("ALTER SEQUENCE {} RESTART WITH 1;").format(
|
| 377 |
+
sql.Identifier(f"{self.table_name}_id_seq")
|
| 378 |
+
)
|
| 379 |
)
|
| 380 |
|
| 381 |
conn.commit()
|
|
|
|
| 422 |
updates.append("updated_at = CURRENT_TIMESTAMP")
|
| 423 |
params.append(int(document_id))
|
| 424 |
|
| 425 |
+
# Compose update query with safe identifier for the table name.
|
| 426 |
+
query = sql.SQL(
|
| 427 |
+
"UPDATE {} SET " + ", ".join(updates) + " WHERE id = %s"
|
| 428 |
+
).format(sql.Identifier(self.table_name))
|
|
|
|
| 429 |
|
| 430 |
with self._get_connection() as conn:
|
| 431 |
with conn.cursor() as cur:
|
|
|
|
| 453 |
with self._get_connection() as conn:
|
| 454 |
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
| 455 |
cur.execute(
|
| 456 |
+
sql.SQL(
|
| 457 |
+
"SELECT id, content, metadata, created_at, "
|
| 458 |
+
"updated_at FROM {} WHERE id = %s;"
|
| 459 |
+
).format(sql.Identifier(self.table_name)),
|
| 460 |
+
(int(document_id),),
|
|
|
|
| 461 |
)
|
| 462 |
|
| 463 |
row = cur.fetchone()
|