Tobias Pasquale commited on
Commit
1300b38
·
1 Parent(s): 923405c

style: Fix line length in search service tests

Browse files
tests/test_search/test_search_service.py CHANGED
@@ -53,7 +53,9 @@ class TestSearchFunctionality:
53
  self.mock_vector_db = Mock(spec=VectorDatabase)
54
  self.mock_embedding_service = Mock(spec=EmbeddingService)
55
  self.search_service = SearchService(
56
- vector_db=self.mock_vector_db, embedding_service=self.mock_embedding_service
 
 
57
  )
58
 
59
  def test_search_with_valid_query(self):
@@ -330,4 +332,85 @@ class TestIntegrationWithRealComponents:
330
  # Basic validation
331
  assert len(results) > 0
332
  assert results[0]["chunk_id"] == "test_doc"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  assert 0.0 <= results[0]["similarity_score"] <= 1.0
 
53
  self.mock_vector_db = Mock(spec=VectorDatabase)
54
  self.mock_embedding_service = Mock(spec=EmbeddingService)
55
  self.search_service = SearchService(
56
+ vector_db=self.mock_vector_db,
57
+ embedding_service=self.mock_embedding_service,
58
+ enable_query_expansion=False, # Disable for unit tests
59
  )
60
 
61
  def test_search_with_valid_query(self):
 
332
  # Basic validation
333
  assert len(results) > 0
334
  assert results[0]["chunk_id"] == "test_doc"
335
+
336
+
337
+ class TestQueryExpansion:
338
+ """Test query expansion functionality."""
339
+
340
+ def setup_method(self):
341
+ """Set up test fixtures for query expansion tests."""
342
+ self.mock_vector_db = Mock(spec=VectorDatabase)
343
+ self.mock_embedding_service = Mock(spec=EmbeddingService)
344
+ # Enable query expansion for these tests
345
+ self.search_service = SearchService(
346
+ vector_db=self.mock_vector_db,
347
+ embedding_service=self.mock_embedding_service,
348
+ enable_query_expansion=True,
349
+ )
350
+
351
+ def test_query_expansion_enabled(self):
352
+ """Test that query expansion works when enabled."""
353
+ # Mock embedding generation
354
+ mock_embedding = [0.1, 0.2, 0.3, 0.4]
355
+ self.mock_embedding_service.embed_text.return_value = mock_embedding
356
+
357
+ # Mock vector database search results
358
+ mock_raw_results = [
359
+ {
360
+ "id": "doc_1",
361
+ "document": "Remote work policy content...",
362
+ "distance": 0.15,
363
+ "metadata": {"filename": "remote_work_policy.md", "chunk_index": 0},
364
+ }
365
+ ]
366
+ self.mock_vector_db.search.return_value = mock_raw_results
367
+
368
+ # Perform search with query that should be expanded
369
+ results = self.search_service.search("work from home", top_k=1)
370
+
371
+ # Verify that the query was expanded (should contain more than original query)
372
+ actual_call = self.mock_embedding_service.embed_text.call_args[0][0]
373
+ assert "work from home" in actual_call
374
+ # Check that expansion terms were added
375
+ assert any(
376
+ term in actual_call for term in ["remote work", "telecommuting", "WFH"]
377
+ )
378
+
379
+ # Verify results are still returned correctly
380
+ assert len(results) == 1
381
+ assert results[0]["chunk_id"] == "doc_1"
382
+
383
+ def test_query_expansion_disabled(self):
384
+ """Test that query expansion can be disabled."""
385
+ # Create search service with expansion disabled
386
+ search_service_no_expansion = SearchService(
387
+ vector_db=self.mock_vector_db,
388
+ embedding_service=self.mock_embedding_service,
389
+ enable_query_expansion=False,
390
+ )
391
+
392
+ # Mock embedding generation
393
+ mock_embedding = [0.1, 0.2, 0.3, 0.4]
394
+ self.mock_embedding_service.embed_text.return_value = mock_embedding
395
+
396
+ # Mock vector database search results
397
+ mock_raw_results = [
398
+ {
399
+ "id": "doc_1",
400
+ "document": "Content...",
401
+ "distance": 0.15,
402
+ "metadata": {"filename": "test.md", "chunk_index": 0},
403
+ }
404
+ ]
405
+ self.mock_vector_db.search.return_value = mock_raw_results
406
+
407
+ # Perform search
408
+ original_query = "work from home"
409
+ results = search_service_no_expansion.search(original_query, top_k=1)
410
+
411
+ # Verify that the original query was used without expansion
412
+ self.mock_embedding_service.embed_text.assert_called_with(original_query)
413
+
414
+ # Verify results are returned
415
+ assert len(results) == 1
416
  assert 0.0 <= results[0]["similarity_score"] <= 1.0