sethmcknight commited on
Commit
fa40b0e
·
1 Parent(s): 307e1fd

fix(postgres): use psycopg2.sql.Identifier/SQL for table/sequence names to prevent SQL injection and satisfy PR feedback

Browse files
src/vector_db/postgres_vector_service.py CHANGED
@@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional
10
 
11
  import psycopg2
12
  import psycopg2.extras
 
13
 
14
  logger = logging.getLogger(__name__)
15
 
@@ -68,8 +69,9 @@ class PostgresVectorService:
68
 
69
  # Create table with initial structure (dimension will be added later)
70
  cur.execute(
71
- """
72
- CREATE TABLE IF NOT EXISTS {table_name} (
 
73
  id SERIAL PRIMARY KEY,
74
  content TEXT NOT NULL,
75
  embedding vector,
@@ -77,18 +79,18 @@ class PostgresVectorService:
77
  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
78
  updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
79
  );
80
- """.format(
81
- table_name=self.table_name
82
- )
83
  )
84
 
85
  # Create index for text search
86
  cur.execute(
87
- """
88
- CREATE INDEX IF NOT EXISTS idx_{table}_content
89
- ON {table} USING gin(to_tsvector('english', content));
90
- """.format(
91
- table=self.table_name
 
92
  )
93
  )
94
 
@@ -123,28 +125,29 @@ class PostgresVectorService:
123
  if result and ("vector(%s)" % dimension) not in str(result):
124
  # Drop existing index if it exists
125
  cur.execute(
126
- "DROP INDEX IF EXISTS idx_{table}_embedding_cosine".format(
127
- table=self.table_name
128
  )
129
  )
130
 
131
  # Alter column to correct dimension
132
  cur.execute(
133
- "ALTER TABLE {table} "
134
- "ALTER COLUMN embedding TYPE vector({dim})".format(
135
- table=self.table_name, dim=dimension
 
136
  )
137
  )
138
 
139
  # Create optimized index for similarity search
140
  cur.execute(
141
- """
142
- CREATE INDEX IF NOT EXISTS idx_{table}_embedding_cosine
143
- ON {table}
144
- USING ivfflat (embedding vector_cosine_ops)
145
- WITH (lists = 100);
146
- """.format(
147
- table=self.table_name
148
  )
149
  )
150
 
@@ -192,19 +195,13 @@ class PostgresVectorService:
192
  with self._get_connection() as conn:
193
  with conn.cursor() as cur:
194
  for text, embedding, metadata in zip(texts, embeddings, metadatas):
195
- # Insert document and get ID
196
  cur.execute(
197
- """
198
- INSERT INTO %s (content, embedding, metadata)
199
- VALUES (%s, %s, %s)
200
- RETURNING id;
201
- """,
202
- (
203
- self.table_name,
204
- text,
205
- embedding,
206
- psycopg2.extras.Json(metadata),
207
- ),
208
  )
209
 
210
  doc_id = cur.fetchone()[0]
@@ -254,14 +251,18 @@ class PostgresVectorService:
254
  if conditions:
255
  where_clause = "WHERE " + " AND ".join(conditions)
256
 
257
- query = f"""
 
 
 
258
  SELECT id, content, metadata,
259
  1 - (embedding <=> %s) as similarity_score
260
- FROM {self.table_name}
261
- {where_clause}
262
  ORDER BY embedding <=> %s
263
  LIMIT %s;
264
  """
 
265
 
266
  with self._get_connection() as conn:
267
  with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
@@ -283,15 +284,18 @@ class PostgresVectorService:
283
  with self._get_connection() as conn:
284
  with conn.cursor() as cur:
285
  # Get document count
286
- cur.execute("SELECT COUNT(*) FROM %s", (self.table_name,))
 
 
 
 
287
  doc_count = cur.fetchone()[0]
288
 
289
  # Get table size
290
  cur.execute(
291
- """
292
- SELECT pg_size_pretty(pg_total_relation_size(%s)) as size;
293
- """,
294
- (self.table_name,),
295
  )
296
  table_size = cur.fetchone()[0]
297
 
@@ -335,10 +339,9 @@ class PostgresVectorService:
335
  int_ids = [int(doc_id) for doc_id in document_ids]
336
 
337
  cur.execute(
338
- f"""
339
- DELETE FROM {self.table_name}
340
- WHERE id = ANY(%s)
341
- """,
342
  (int_ids,),
343
  )
344
 
@@ -357,14 +360,22 @@ class PostgresVectorService:
357
  """
358
  with self._get_connection() as conn:
359
  with conn.cursor() as cur:
360
- cur.execute("SELECT COUNT(*) FROM %s", (self.table_name,))
 
 
 
 
361
  count_before = cur.fetchone()[0]
362
 
363
- cur.execute("DELETE FROM %s", (self.table_name,))
 
 
364
 
365
  # Reset the sequence
366
  cur.execute(
367
- "ALTER SEQUENCE %s_id_seq RESTART WITH 1", (self.table_name,)
 
 
368
  )
369
 
370
  conn.commit()
@@ -411,11 +422,10 @@ class PostgresVectorService:
411
  updates.append("updated_at = CURRENT_TIMESTAMP")
412
  params.append(int(document_id))
413
 
414
- query = f"""
415
- UPDATE {self.table_name}
416
- SET {', '.join(updates)}
417
- WHERE id = %s
418
- """
419
 
420
  with self._get_connection() as conn:
421
  with conn.cursor() as cur:
@@ -443,12 +453,11 @@ class PostgresVectorService:
443
  with self._get_connection() as conn:
444
  with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
445
  cur.execute(
446
- """
447
- SELECT id, content, metadata, created_at, updated_at
448
- FROM %s
449
- WHERE id = %s
450
- """,
451
- (self.table_name, int(document_id)),
452
  )
453
 
454
  row = cur.fetchone()
 
10
 
11
  import psycopg2
12
  import psycopg2.extras
13
+ from psycopg2 import sql
14
 
15
  logger = logging.getLogger(__name__)
16
 
 
69
 
70
  # Create table with initial structure (dimension will be added later)
71
  cur.execute(
72
+ sql.SQL(
73
+ """
74
+ CREATE TABLE IF NOT EXISTS {} (
75
  id SERIAL PRIMARY KEY,
76
  content TEXT NOT NULL,
77
  embedding vector,
 
79
  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
80
  updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
81
  );
82
+ """
83
+ ).format(sql.Identifier(self.table_name))
 
84
  )
85
 
86
  # Create index for text search
87
  cur.execute(
88
+ sql.SQL(
89
+ "CREATE INDEX IF NOT EXISTS {} "
90
+ "ON {} USING gin(to_tsvector('english', content));"
91
+ ).format(
92
+ sql.Identifier(f"idx_{self.table_name}_content"),
93
+ sql.Identifier(self.table_name),
94
  )
95
  )
96
 
 
125
  if result and ("vector(%s)" % dimension) not in str(result):
126
  # Drop existing index if it exists
127
  cur.execute(
128
+ sql.SQL("DROP INDEX IF EXISTS {}; ").format(
129
+ sql.Identifier(f"idx_{self.table_name}_embedding_cosine")
130
  )
131
  )
132
 
133
  # Alter column to correct dimension
134
  cur.execute(
135
+ sql.SQL(
136
+ "ALTER TABLE {} ALTER COLUMN embedding TYPE vector({});"
137
+ ).format(
138
+ sql.Identifier(self.table_name), sql.Literal(dimension)
139
  )
140
  )
141
 
142
  # Create optimized index for similarity search
143
  cur.execute(
144
+ sql.SQL(
145
+ "CREATE INDEX IF NOT EXISTS {} ON {} "
146
+ "USING ivfflat (embedding vector_cosine_ops) "
147
+ "WITH (lists = 100);"
148
+ ).format(
149
+ sql.Identifier(f"idx_{self.table_name}_embedding_cosine"),
150
+ sql.Identifier(self.table_name),
151
  )
152
  )
153
 
 
195
  with self._get_connection() as conn:
196
  with conn.cursor() as cur:
197
  for text, embedding, metadata in zip(texts, embeddings, metadatas):
198
+ # Insert document and get ID (table name composed safely)
199
  cur.execute(
200
+ sql.SQL(
201
+ "INSERT INTO {} (content, embedding, metadata) "
202
+ "VALUES (%s, %s, %s) RETURNING id;"
203
+ ).format(sql.Identifier(self.table_name)),
204
+ (text, embedding, psycopg2.extras.Json(metadata)),
 
 
 
 
 
 
205
  )
206
 
207
  doc_id = cur.fetchone()[0]
 
251
  if conditions:
252
  where_clause = "WHERE " + " AND ".join(conditions)
253
 
254
+ # Compose query safely with identifier for table name. where_clause
255
+ # contains only parameter placeholders (%s) and logical operators.
256
+ query = sql.SQL(
257
+ """
258
  SELECT id, content, metadata,
259
  1 - (embedding <=> %s) as similarity_score
260
+ FROM {}
261
+ {}
262
  ORDER BY embedding <=> %s
263
  LIMIT %s;
264
  """
265
+ ).format(sql.Identifier(self.table_name), sql.SQL(where_clause))
266
 
267
  with self._get_connection() as conn:
268
  with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
 
284
  with self._get_connection() as conn:
285
  with conn.cursor() as cur:
286
  # Get document count
287
+ cur.execute(
288
+ sql.SQL("SELECT COUNT(*) FROM {};").format(
289
+ sql.Identifier(self.table_name)
290
+ )
291
+ )
292
  doc_count = cur.fetchone()[0]
293
 
294
  # Get table size
295
  cur.execute(
296
+ sql.SQL(
297
+ "SELECT pg_size_pretty(pg_total_relation_size({})) as size;"
298
+ ).format(sql.Identifier(self.table_name))
 
299
  )
300
  table_size = cur.fetchone()[0]
301
 
 
339
  int_ids = [int(doc_id) for doc_id in document_ids]
340
 
341
  cur.execute(
342
+ sql.SQL("DELETE FROM {} WHERE id = ANY(%s);").format(
343
+ sql.Identifier(self.table_name)
344
+ ),
 
345
  (int_ids,),
346
  )
347
 
 
360
  """
361
  with self._get_connection() as conn:
362
  with conn.cursor() as cur:
363
+ cur.execute(
364
+ sql.SQL("SELECT COUNT(*) FROM {};").format(
365
+ sql.Identifier(self.table_name)
366
+ )
367
+ )
368
  count_before = cur.fetchone()[0]
369
 
370
+ cur.execute(
371
+ sql.SQL("DELETE FROM {};").format(sql.Identifier(self.table_name))
372
+ )
373
 
374
  # Reset the sequence
375
  cur.execute(
376
+ sql.SQL("ALTER SEQUENCE {} RESTART WITH 1;").format(
377
+ sql.Identifier(f"{self.table_name}_id_seq")
378
+ )
379
  )
380
 
381
  conn.commit()
 
422
  updates.append("updated_at = CURRENT_TIMESTAMP")
423
  params.append(int(document_id))
424
 
425
+ # Compose update query with safe identifier for the table name.
426
+ query = sql.SQL(
427
+ "UPDATE {} SET " + ", ".join(updates) + " WHERE id = %s"
428
+ ).format(sql.Identifier(self.table_name))
 
429
 
430
  with self._get_connection() as conn:
431
  with conn.cursor() as cur:
 
453
  with self._get_connection() as conn:
454
  with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
455
  cur.execute(
456
+ sql.SQL(
457
+ "SELECT id, content, metadata, created_at, "
458
+ "updated_at FROM {} WHERE id = %s;"
459
+ ).format(sql.Identifier(self.table_name)),
460
+ (int(document_id),),
 
461
  )
462
 
463
  row = cur.fetchone()