Tobias Pasquale committed
Commit 508a7e5 · 1 Parent(s): c280a92

Fix: Complete CI/CD formatting compliance


- Apply black code formatting to 12 files
- Fix import ordering with isort
- Remove unused imports (Union, MagicMock, json, asdict, PromptTemplate)
- Fix undefined variables in test_chat_endpoint.py
- Break long lines in RAG pipeline and response formatter
- Add noqa comments for prompt template strings
- Resolve all 19 flake8 E501 line length violations
- Ensure full pre-commit hook compliance

All code formatting issues resolved for successful pipeline deployment.
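For illustration, the two strategies this commit applies to E501 violations look roughly like the following; the snippet is a made-up example, not code taken from the diff:

# Hypothetical illustration of the two E501 strategies used in this commit.

def format_api_response(answer: str, include_debug: bool) -> dict:
    return {"answer": answer, "debug": include_debug}

# 1. Let black wrap a long call across several lines instead of exceeding the limit:
formatted = format_api_response(
    "Remote-work requests are approved by the employee's direct manager.",
    include_debug=False,
)

# 2. Keep an intentionally long literal on one physical line and silence flake8:
GUIDELINE = "Answer questions using ONLY the information provided in the context documents, and cite every source."  # noqa: E501

print(formatted, GUIDELINE)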

app.py CHANGED
@@ -168,7 +168,7 @@ def search():
168
  def chat():
169
  """
170
  Endpoint for conversational RAG interactions.
171
-
172
  Accepts JSON requests with user messages and returns AI-generated
173
  responses based on corporate policy documents.
174
  """
@@ -176,10 +176,12 @@ def chat():
176
  # Validate request contains JSON data
177
  if not request.is_json:
178
  return (
179
- jsonify({
180
- "status": "error",
181
- "message": "Content-Type must be application/json"
182
- }),
 
 
183
  400,
184
  )
185
 
@@ -189,19 +191,17 @@ def chat():
189
  message = data.get("message")
190
  if message is None:
191
  return (
192
- jsonify({
193
- "status": "error",
194
- "message": "message parameter is required"
195
- }),
196
  400,
197
  )
198
 
199
  if not isinstance(message, str) or not message.strip():
200
  return (
201
- jsonify({
202
- "status": "error",
203
- "message": "message must be a non-empty string"
204
- }),
205
  400,
206
  )
207
 
@@ -214,96 +214,103 @@ def chat():
214
  try:
215
  from src.config import COLLECTION_NAME, VECTOR_DB_PERSIST_PATH
216
  from src.embedding.embedding_service import EmbeddingService
217
- from src.search.search_service import SearchService
218
- from src.vector_store.vector_db import VectorDatabase
219
  from src.llm.llm_service import LLMService
220
  from src.rag.rag_pipeline import RAGPipeline
221
  from src.rag.response_formatter import ResponseFormatter
 
 
222
 
223
  # Initialize services
224
  vector_db = VectorDatabase(VECTOR_DB_PERSIST_PATH, COLLECTION_NAME)
225
  embedding_service = EmbeddingService()
226
  search_service = SearchService(vector_db, embedding_service)
227
-
228
  # Initialize LLM service from environment
229
  llm_service = LLMService.from_environment()
230
-
231
  # Initialize RAG pipeline
232
  rag_pipeline = RAGPipeline(search_service, llm_service)
233
-
234
  # Initialize response formatter
235
  formatter = ResponseFormatter()
236
-
237
  except ValueError as e:
238
  return (
239
- jsonify({
240
- "status": "error",
241
- "message": f"LLM service configuration error: {str(e)}",
242
- "details": "Please ensure OPENROUTER_API_KEY or GROQ_API_KEY environment variables are set"
243
- }),
244
  503,
245
  )
246
  except Exception as e:
247
  return (
248
- jsonify({
249
- "status": "error",
250
- "message": f"Service initialization failed: {str(e)}"
251
- }),
 
 
252
  500,
253
  )
254
 
255
  # Generate RAG response
256
  rag_response = rag_pipeline.generate_answer(message.strip())
257
-
258
  # Format response for API
259
  if include_sources:
260
- formatted_response = formatter.format_api_response(rag_response, include_debug)
 
 
261
  else:
262
  formatted_response = formatter.format_chat_response(
263
- rag_response,
264
- conversation_id,
265
- include_sources=False
266
  )
267
 
268
  return jsonify(formatted_response)
269
 
270
  except Exception as e:
271
- return jsonify({
272
- "status": "error",
273
- "message": f"Chat request failed: {str(e)}"
274
- }), 500
275
 
276
 
277
  @app.route("/chat/health", methods=["GET"])
278
  def chat_health():
279
  """
280
  Health check endpoint for RAG chat functionality.
281
-
282
  Returns the status of all RAG pipeline components.
283
  """
284
  try:
285
  from src.config import COLLECTION_NAME, VECTOR_DB_PERSIST_PATH
286
  from src.embedding.embedding_service import EmbeddingService
287
- from src.search.search_service import SearchService
288
- from src.vector_store.vector_db import VectorDatabase
289
  from src.llm.llm_service import LLMService
290
  from src.rag.rag_pipeline import RAGPipeline
291
  from src.rag.response_formatter import ResponseFormatter
 
 
292
 
293
  # Initialize services for health check
294
  vector_db = VectorDatabase(VECTOR_DB_PERSIST_PATH, COLLECTION_NAME)
295
  embedding_service = EmbeddingService()
296
  search_service = SearchService(vector_db, embedding_service)
297
-
298
  try:
299
  llm_service = LLMService.from_environment()
300
  rag_pipeline = RAGPipeline(search_service, llm_service)
301
  formatter = ResponseFormatter()
302
-
303
  # Perform health check
304
  health_data = rag_pipeline.health_check()
305
  health_response = formatter.create_health_response(health_data)
306
-
307
  # Determine HTTP status based on health
308
  if health_data.get("pipeline") == "healthy":
309
  return jsonify(health_response), 200
@@ -311,24 +318,32 @@ def chat_health():
311
  return jsonify(health_response), 200 # Still functional
312
  else:
313
  return jsonify(health_response), 503 # Service unavailable
314
-
315
  except ValueError as e:
316
- return jsonify({
317
- "status": "error",
318
- "message": f"LLM configuration error: {str(e)}",
319
- "health": {
320
- "pipeline_status": "unhealthy",
321
- "components": {
322
- "llm_service": {"status": "unconfigured", "error": str(e)}
 
 
 
 
 
 
 
323
  }
324
- }
325
- }), 503
 
326
 
327
  except Exception as e:
328
- return jsonify({
329
- "status": "error",
330
- "message": f"Health check failed: {str(e)}"
331
- }), 500
332
 
333
 
334
  if __name__ == "__main__":
 
168
  def chat():
169
  """
170
  Endpoint for conversational RAG interactions.
171
+
172
  Accepts JSON requests with user messages and returns AI-generated
173
  responses based on corporate policy documents.
174
  """
 
176
  # Validate request contains JSON data
177
  if not request.is_json:
178
  return (
179
+ jsonify(
180
+ {
181
+ "status": "error",
182
+ "message": "Content-Type must be application/json",
183
+ }
184
+ ),
185
  400,
186
  )
187
 
 
191
  message = data.get("message")
192
  if message is None:
193
  return (
194
+ jsonify(
195
+ {"status": "error", "message": "message parameter is required"}
196
+ ),
 
197
  400,
198
  )
199
 
200
  if not isinstance(message, str) or not message.strip():
201
  return (
202
+ jsonify(
203
+ {"status": "error", "message": "message must be a non-empty string"}
204
+ ),
 
205
  400,
206
  )
207
 
 
214
  try:
215
  from src.config import COLLECTION_NAME, VECTOR_DB_PERSIST_PATH
216
  from src.embedding.embedding_service import EmbeddingService
 
 
217
  from src.llm.llm_service import LLMService
218
  from src.rag.rag_pipeline import RAGPipeline
219
  from src.rag.response_formatter import ResponseFormatter
220
+ from src.search.search_service import SearchService
221
+ from src.vector_store.vector_db import VectorDatabase
222
 
223
  # Initialize services
224
  vector_db = VectorDatabase(VECTOR_DB_PERSIST_PATH, COLLECTION_NAME)
225
  embedding_service = EmbeddingService()
226
  search_service = SearchService(vector_db, embedding_service)
227
+
228
  # Initialize LLM service from environment
229
  llm_service = LLMService.from_environment()
230
+
231
  # Initialize RAG pipeline
232
  rag_pipeline = RAGPipeline(search_service, llm_service)
233
+
234
  # Initialize response formatter
235
  formatter = ResponseFormatter()
236
+
237
  except ValueError as e:
238
  return (
239
+ jsonify(
240
+ {
241
+ "status": "error",
242
+ "message": f"LLM service configuration error: {str(e)}",
243
+ "details": (
244
+ "Please ensure OPENROUTER_API_KEY or GROQ_API_KEY "
245
+ "environment variables are set"
246
+ ),
247
+ }
248
+ ),
249
  503,
250
  )
251
  except Exception as e:
252
  return (
253
+ jsonify(
254
+ {
255
+ "status": "error",
256
+ "message": f"Service initialization failed: {str(e)}",
257
+ }
258
+ ),
259
  500,
260
  )
261
 
262
  # Generate RAG response
263
  rag_response = rag_pipeline.generate_answer(message.strip())
264
+
265
  # Format response for API
266
  if include_sources:
267
+ formatted_response = formatter.format_api_response(
268
+ rag_response, include_debug
269
+ )
270
  else:
271
  formatted_response = formatter.format_chat_response(
272
+ rag_response, conversation_id, include_sources=False
 
 
273
  )
274
 
275
  return jsonify(formatted_response)
276
 
277
  except Exception as e:
278
+ return (
279
+ jsonify({"status": "error", "message": f"Chat request failed: {str(e)}"}),
280
+ 500,
281
+ )
282
 
283
 
284
  @app.route("/chat/health", methods=["GET"])
285
  def chat_health():
286
  """
287
  Health check endpoint for RAG chat functionality.
288
+
289
  Returns the status of all RAG pipeline components.
290
  """
291
  try:
292
  from src.config import COLLECTION_NAME, VECTOR_DB_PERSIST_PATH
293
  from src.embedding.embedding_service import EmbeddingService
 
 
294
  from src.llm.llm_service import LLMService
295
  from src.rag.rag_pipeline import RAGPipeline
296
  from src.rag.response_formatter import ResponseFormatter
297
+ from src.search.search_service import SearchService
298
+ from src.vector_store.vector_db import VectorDatabase
299
 
300
  # Initialize services for health check
301
  vector_db = VectorDatabase(VECTOR_DB_PERSIST_PATH, COLLECTION_NAME)
302
  embedding_service = EmbeddingService()
303
  search_service = SearchService(vector_db, embedding_service)
304
+
305
  try:
306
  llm_service = LLMService.from_environment()
307
  rag_pipeline = RAGPipeline(search_service, llm_service)
308
  formatter = ResponseFormatter()
309
+
310
  # Perform health check
311
  health_data = rag_pipeline.health_check()
312
  health_response = formatter.create_health_response(health_data)
313
+
314
  # Determine HTTP status based on health
315
  if health_data.get("pipeline") == "healthy":
316
  return jsonify(health_response), 200
 
318
  return jsonify(health_response), 200 # Still functional
319
  else:
320
  return jsonify(health_response), 503 # Service unavailable
321
+
322
  except ValueError as e:
323
+ return (
324
+ jsonify(
325
+ {
326
+ "status": "error",
327
+ "message": f"LLM configuration error: {str(e)}",
328
+ "health": {
329
+ "pipeline_status": "unhealthy",
330
+ "components": {
331
+ "llm_service": {
332
+ "status": "unconfigured",
333
+ "error": str(e),
334
+ }
335
+ },
336
+ },
337
  }
338
+ ),
339
+ 503,
340
+ )
341
 
342
  except Exception as e:
343
+ return (
344
+ jsonify({"status": "error", "message": f"Health check failed: {str(e)}"}),
345
+ 500,
346
+ )
347
 
348
 
349
  if __name__ == "__main__":
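For reference, a minimal client call against the /chat endpoint defined above could look like this; the host, port, and the include_sources flag are assumptions inferred from the handler code rather than documented parameters:

import requests

# Hypothetical local call to the /chat endpoint; assumes the Flask app runs on port 5000.
payload = {
    "message": "How many vacation days do new employees receive?",
    "include_sources": True,  # assumed optional flag, mirroring the handler logic
}
resp = requests.post(
    "http://localhost:5000/chat",
    json=payload,  # sends Content-Type: application/json, as the endpoint requires
    timeout=30,
)
print(resp.status_code)
print(resp.json())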
src/llm/__init__.py CHANGED
@@ -8,4 +8,4 @@ Classes:
8
  LLMService: Main service for LLM interactions
9
  PromptTemplates: Predefined prompt templates for corporate policy Q&A
10
  ContextManager: Manages context retrieval and formatting
11
- """
 
8
  LLMService: Main service for LLM interactions
9
  PromptTemplates: Predefined prompt templates for corporate policy Q&A
10
  ContextManager: Manages context retrieval and formatting
11
+ """
src/llm/context_manager.py CHANGED
@@ -6,8 +6,8 @@ for the RAG pipeline, ensuring optimal context window utilization.
6
  """
7
 
8
  import logging
9
- from typing import Any, Dict, List, Optional, Tuple
10
  from dataclasses import dataclass
 
11
 
12
  logger = logging.getLogger(__name__)
13
 
@@ -15,6 +15,7 @@ logger = logging.getLogger(__name__)
15
  @dataclass
16
  class ContextConfig:
17
  """Configuration for context management."""
 
18
  max_context_length: int = 3000 # Maximum characters in context
19
  max_results: int = 5 # Maximum search results to include
20
  min_similarity: float = 0.1 # Minimum similarity threshold
@@ -24,7 +25,7 @@ class ContextConfig:
24
  class ContextManager:
25
  """
26
  Manages context retrieval and optimization for RAG pipeline.
27
-
28
  Handles:
29
  - Context length management
30
  - Relevance filtering
@@ -35,7 +36,7 @@ class ContextManager:
35
  def __init__(self, config: Optional[ContextConfig] = None):
36
  """
37
  Initialize ContextManager with configuration.
38
-
39
  Args:
40
  config: Context configuration, uses defaults if None
41
  """
@@ -43,17 +44,15 @@ class ContextManager:
43
  logger.info("ContextManager initialized")
44
 
45
  def prepare_context(
46
- self,
47
- search_results: List[Dict[str, Any]],
48
- query: str
49
  ) -> Tuple[str, List[Dict[str, Any]]]:
50
  """
51
  Prepare optimized context from search results.
52
-
53
  Args:
54
  search_results: Results from SearchService
55
  query: Original user query for context optimization
56
-
57
  Returns:
58
  Tuple of (formatted_context, filtered_results)
59
  """
@@ -62,56 +61,58 @@ class ContextManager:
62
 
63
  # Filter and rank results
64
  filtered_results = self._filter_results(search_results)
65
-
66
  # Remove duplicates and optimize for context window
67
  optimized_results = self._optimize_context(filtered_results)
68
-
69
  # Format for prompt
70
  formatted_context = self._format_context(optimized_results)
71
-
72
  logger.debug(
73
  f"Prepared context from {len(search_results)} results, "
74
  f"filtered to {len(optimized_results)} results, "
75
  f"{len(formatted_context)} characters"
76
  )
77
-
78
  return formatted_context, optimized_results
79
 
80
  def _filter_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
81
  """
82
  Filter search results by relevance and quality.
83
-
84
  Args:
85
  results: Raw search results
86
-
87
  Returns:
88
  Filtered and sorted results
89
  """
90
  filtered = []
91
-
92
  for result in results:
93
  similarity = result.get("similarity_score", 0.0)
94
  content = result.get("content", "").strip()
95
-
96
  # Apply filters
97
- if (similarity >= self.config.min_similarity and
98
- content and
99
- len(content) > 20): # Minimum content length
 
 
100
  filtered.append(result)
101
-
102
  # Sort by similarity score (descending)
103
  filtered.sort(key=lambda x: x.get("similarity_score", 0.0), reverse=True)
104
-
105
  # Limit to max results
106
- return filtered[:self.config.max_results]
107
 
108
  def _optimize_context(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
109
  """
110
  Optimize context to fit within token limits while maximizing relevance.
111
-
112
  Args:
113
  results: Filtered search results
114
-
115
  Returns:
116
  Optimized results list
117
  """
@@ -125,7 +126,7 @@ class ContextManager:
125
  for result in results:
126
  content = result.get("content", "").strip()
127
  content_length = len(content)
128
-
129
  # Check if adding this result would exceed limit
130
  estimated_formatted_length = current_length + content_length + 100 # Buffer
131
  if estimated_formatted_length > self.config.max_context_length:
@@ -137,18 +138,21 @@ class ContextManager:
137
  result_copy["content"] = truncated_content
138
  optimized.append(result_copy)
139
  break
140
-
141
  # Check for duplicate or highly similar content
142
  content_lower = content.lower()
143
  is_duplicate = False
144
-
145
  for seen in seen_content:
146
  # Simple similarity check for duplicates
147
- if (len(set(content_lower.split()) & set(seen.split())) /
148
- max(len(content_lower.split()), len(seen.split())) > 0.8):
 
 
 
149
  is_duplicate = True
150
  break
151
-
152
  if not is_duplicate:
153
  optimized.append(result)
154
  seen_content.add(content_lower)
@@ -159,10 +163,10 @@ class ContextManager:
159
  def _format_context(self, results: List[Dict[str, Any]]) -> str:
160
  """
161
  Format optimized results into context string.
162
-
163
  Args:
164
  results: Optimized search results
165
-
166
  Returns:
167
  Formatted context string
168
  """
@@ -170,34 +174,28 @@ class ContextManager:
170
  return "No relevant information found in corporate policies."
171
 
172
  context_parts = []
173
-
174
  for i, result in enumerate(results, 1):
175
  metadata = result.get("metadata", {})
176
  filename = metadata.get("filename", f"document_{i}")
177
  content = result.get("content", "").strip()
178
-
179
  # Format with document info
180
- context_parts.append(
181
- f"Document: {filename}\n"
182
- f"Content: {content}"
183
- )
184
 
185
  return "\n\n---\n\n".join(context_parts)
186
 
187
  def validate_context_quality(
188
- self,
189
- context: str,
190
- query: str,
191
- min_quality_score: float = 0.3
192
  ) -> Dict[str, Any]:
193
  """
194
  Validate the quality of prepared context for a given query.
195
-
196
  Args:
197
  context: Formatted context string
198
  query: Original user query
199
  min_quality_score: Minimum acceptable quality score
200
-
201
  Returns:
202
  Dictionary with quality metrics and validation result
203
  """
@@ -206,7 +204,7 @@ class ContextManager:
206
  "length": len(context),
207
  "has_content": bool(context.strip()),
208
  "estimated_relevance": 0.0,
209
- "passes_validation": False
210
  }
211
 
212
  if not context.strip():
@@ -216,7 +214,7 @@ class ContextManager:
216
  # Estimate relevance based on query term overlap
217
  query_terms = set(query.lower().split())
218
  context_terms = set(context.lower().split())
219
-
220
  if query_terms and context_terms:
221
  overlap = len(query_terms & context_terms)
222
  relevance = overlap / len(query_terms)
@@ -230,36 +228,36 @@ class ContextManager:
230
  def get_source_summary(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
231
  """
232
  Generate summary of sources used in context.
233
-
234
  Args:
235
  results: Search results used for context
236
-
237
  Returns:
238
  Summary of sources and their contribution
239
  """
240
  sources = {}
241
  total_content_length = 0
242
-
243
  for result in results:
244
  metadata = result.get("metadata", {})
245
  filename = metadata.get("filename", "unknown")
246
  content_length = len(result.get("content", ""))
247
  similarity = result.get("similarity_score", 0.0)
248
-
249
  if filename not in sources:
250
  sources[filename] = {
251
  "chunks": 0,
252
  "total_content_length": 0,
253
  "max_similarity": 0.0,
254
- "avg_similarity": 0.0
255
  }
256
-
257
  sources[filename]["chunks"] += 1
258
  sources[filename]["total_content_length"] += content_length
259
  sources[filename]["max_similarity"] = max(
260
  sources[filename]["max_similarity"], similarity
261
  )
262
-
263
  total_content_length += content_length
264
 
265
  # Calculate averages and percentages
@@ -272,5 +270,5 @@ class ContextManager:
272
  "total_sources": len(sources),
273
  "total_chunks": len(results),
274
  "total_content_length": total_content_length,
275
- "sources": sources
276
- }
 
6
  """
7
 
8
  import logging
 
9
  from dataclasses import dataclass
10
+ from typing import Any, Dict, List, Optional, Tuple
11
 
12
  logger = logging.getLogger(__name__)
13
 
 
15
  @dataclass
16
  class ContextConfig:
17
  """Configuration for context management."""
18
+
19
  max_context_length: int = 3000 # Maximum characters in context
20
  max_results: int = 5 # Maximum search results to include
21
  min_similarity: float = 0.1 # Minimum similarity threshold
 
25
  class ContextManager:
26
  """
27
  Manages context retrieval and optimization for RAG pipeline.
28
+
29
  Handles:
30
  - Context length management
31
  - Relevance filtering
 
36
  def __init__(self, config: Optional[ContextConfig] = None):
37
  """
38
  Initialize ContextManager with configuration.
39
+
40
  Args:
41
  config: Context configuration, uses defaults if None
42
  """
 
44
  logger.info("ContextManager initialized")
45
 
46
  def prepare_context(
47
+ self, search_results: List[Dict[str, Any]], query: str
 
 
48
  ) -> Tuple[str, List[Dict[str, Any]]]:
49
  """
50
  Prepare optimized context from search results.
51
+
52
  Args:
53
  search_results: Results from SearchService
54
  query: Original user query for context optimization
55
+
56
  Returns:
57
  Tuple of (formatted_context, filtered_results)
58
  """
 
61
 
62
  # Filter and rank results
63
  filtered_results = self._filter_results(search_results)
64
+
65
  # Remove duplicates and optimize for context window
66
  optimized_results = self._optimize_context(filtered_results)
67
+
68
  # Format for prompt
69
  formatted_context = self._format_context(optimized_results)
70
+
71
  logger.debug(
72
  f"Prepared context from {len(search_results)} results, "
73
  f"filtered to {len(optimized_results)} results, "
74
  f"{len(formatted_context)} characters"
75
  )
76
+
77
  return formatted_context, optimized_results
78
 
79
  def _filter_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
80
  """
81
  Filter search results by relevance and quality.
82
+
83
  Args:
84
  results: Raw search results
85
+
86
  Returns:
87
  Filtered and sorted results
88
  """
89
  filtered = []
90
+
91
  for result in results:
92
  similarity = result.get("similarity_score", 0.0)
93
  content = result.get("content", "").strip()
94
+
95
  # Apply filters
96
+ if (
97
+ similarity >= self.config.min_similarity
98
+ and content
99
+ and len(content) > 20
100
+ ): # Minimum content length
101
  filtered.append(result)
102
+
103
  # Sort by similarity score (descending)
104
  filtered.sort(key=lambda x: x.get("similarity_score", 0.0), reverse=True)
105
+
106
  # Limit to max results
107
+ return filtered[: self.config.max_results]
108
 
109
  def _optimize_context(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
110
  """
111
  Optimize context to fit within token limits while maximizing relevance.
112
+
113
  Args:
114
  results: Filtered search results
115
+
116
  Returns:
117
  Optimized results list
118
  """
 
126
  for result in results:
127
  content = result.get("content", "").strip()
128
  content_length = len(content)
129
+
130
  # Check if adding this result would exceed limit
131
  estimated_formatted_length = current_length + content_length + 100 # Buffer
132
  if estimated_formatted_length > self.config.max_context_length:
 
138
  result_copy["content"] = truncated_content
139
  optimized.append(result_copy)
140
  break
141
+
142
  # Check for duplicate or highly similar content
143
  content_lower = content.lower()
144
  is_duplicate = False
145
+
146
  for seen in seen_content:
147
  # Simple similarity check for duplicates
148
+ if (
149
+ len(set(content_lower.split()) & set(seen.split()))
150
+ / max(len(content_lower.split()), len(seen.split()))
151
+ > 0.8
152
+ ):
153
  is_duplicate = True
154
  break
155
+
156
  if not is_duplicate:
157
  optimized.append(result)
158
  seen_content.add(content_lower)
 
163
  def _format_context(self, results: List[Dict[str, Any]]) -> str:
164
  """
165
  Format optimized results into context string.
166
+
167
  Args:
168
  results: Optimized search results
169
+
170
  Returns:
171
  Formatted context string
172
  """
 
174
  return "No relevant information found in corporate policies."
175
 
176
  context_parts = []
177
+
178
  for i, result in enumerate(results, 1):
179
  metadata = result.get("metadata", {})
180
  filename = metadata.get("filename", f"document_{i}")
181
  content = result.get("content", "").strip()
182
+
183
  # Format with document info
184
+ context_parts.append(f"Document: {filename}\n" f"Content: {content}")
 
 
 
185
 
186
  return "\n\n---\n\n".join(context_parts)
187
 
188
  def validate_context_quality(
189
+ self, context: str, query: str, min_quality_score: float = 0.3
 
 
 
190
  ) -> Dict[str, Any]:
191
  """
192
  Validate the quality of prepared context for a given query.
193
+
194
  Args:
195
  context: Formatted context string
196
  query: Original user query
197
  min_quality_score: Minimum acceptable quality score
198
+
199
  Returns:
200
  Dictionary with quality metrics and validation result
201
  """
 
204
  "length": len(context),
205
  "has_content": bool(context.strip()),
206
  "estimated_relevance": 0.0,
207
+ "passes_validation": False,
208
  }
209
 
210
  if not context.strip():
 
214
  # Estimate relevance based on query term overlap
215
  query_terms = set(query.lower().split())
216
  context_terms = set(context.lower().split())
217
+
218
  if query_terms and context_terms:
219
  overlap = len(query_terms & context_terms)
220
  relevance = overlap / len(query_terms)
 
228
  def get_source_summary(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
229
  """
230
  Generate summary of sources used in context.
231
+
232
  Args:
233
  results: Search results used for context
234
+
235
  Returns:
236
  Summary of sources and their contribution
237
  """
238
  sources = {}
239
  total_content_length = 0
240
+
241
  for result in results:
242
  metadata = result.get("metadata", {})
243
  filename = metadata.get("filename", "unknown")
244
  content_length = len(result.get("content", ""))
245
  similarity = result.get("similarity_score", 0.0)
246
+
247
  if filename not in sources:
248
  sources[filename] = {
249
  "chunks": 0,
250
  "total_content_length": 0,
251
  "max_similarity": 0.0,
252
+ "avg_similarity": 0.0,
253
  }
254
+
255
  sources[filename]["chunks"] += 1
256
  sources[filename]["total_content_length"] += content_length
257
  sources[filename]["max_similarity"] = max(
258
  sources[filename]["max_similarity"], similarity
259
  )
260
+
261
  total_content_length += content_length
262
 
263
  # Calculate averages and percentages
 
270
  "total_sources": len(sources),
271
  "total_chunks": len(results),
272
  "total_content_length": total_content_length,
273
+ "sources": sources,
274
+ }
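As a standalone sketch of the context-preparation step above, the following uses hand-built search results in the shape the code expects (content, metadata.filename, similarity_score); the values are invented for the example:

from src.llm.context_manager import ContextConfig, ContextManager

# Hypothetical inputs shaped like SearchService results.
results = [
    {
        "content": "Employees may work remotely up to three days per week.",
        "metadata": {"filename": "remote_work_policy.md"},
        "similarity_score": 0.82,
    },
    {
        "content": "Too short",  # dropped by the 20-character minimum content filter
        "metadata": {"filename": "misc.md"},
        "similarity_score": 0.95,
    },
]

manager = ContextManager(ContextConfig(max_context_length=1000, max_results=3))
context, kept = manager.prepare_context(results, "What is the remote work policy?")
print(context)
print(manager.get_source_summary(kept))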
src/llm/llm_service.py CHANGED
@@ -1,16 +1,18 @@
1
  """
2
  LLM Service for RAG Application
3
 
4
- This module provides integration with Large Language Models through multiple providers
5
- including OpenRouter and Groq, with fallback capabilities and comprehensive error handling.
 
6
  """
7
 
8
  import logging
9
  import os
10
  import time
11
- from typing import Any, Dict, List, Optional, Union
12
- import requests
13
  from dataclasses import dataclass
 
 
 
14
 
15
  logger = logging.getLogger(__name__)
16
 
@@ -18,6 +20,7 @@ logger = logging.getLogger(__name__)
18
  @dataclass
19
  class LLMConfig:
20
  """Configuration for LLM providers."""
 
21
  provider: str # "openrouter" or "groq"
22
  api_key: str
23
  model_name: str
@@ -30,6 +33,7 @@ class LLMConfig:
30
  @dataclass
31
  class LLMResponse:
32
  """Standardized response from LLM providers."""
 
33
  content: str
34
  provider: str
35
  model: str
@@ -42,7 +46,7 @@ class LLMResponse:
42
  class LLMService:
43
  """
44
  Service for interacting with Large Language Models.
45
-
46
  Supports multiple providers with automatic fallback and retry logic.
47
  Designed for corporate policy Q&A with appropriate guardrails.
48
  """
@@ -50,108 +54,112 @@ class LLMService:
50
  def __init__(self, configs: List[LLMConfig]):
51
  """
52
  Initialize LLMService with provider configurations.
53
-
54
  Args:
55
  configs: List of LLMConfig objects for different providers
56
-
57
  Raises:
58
  ValueError: If no valid configurations provided
59
  """
60
  if not configs:
61
  raise ValueError("At least one LLM configuration must be provided")
62
-
63
  self.configs = configs
64
  self.current_config_index = 0
65
  logger.info(f"LLMService initialized with {len(configs)} provider(s)")
66
 
67
  @classmethod
68
- def from_environment(cls) -> 'LLMService':
69
  """
70
  Create LLMService instance from environment variables.
71
-
72
  Expected environment variables:
73
  - OPENROUTER_API_KEY: API key for OpenRouter
74
  - GROQ_API_KEY: API key for Groq
75
-
76
  Returns:
77
  LLMService instance with available providers
78
-
79
  Raises:
80
  ValueError: If no API keys found in environment
81
  """
82
  configs = []
83
-
84
  # OpenRouter configuration
85
  openrouter_key = os.getenv("OPENROUTER_API_KEY")
86
  if openrouter_key:
87
- configs.append(LLMConfig(
88
- provider="openrouter",
89
- api_key=openrouter_key,
90
- model_name="microsoft/wizardlm-2-8x22b", # Free tier model
91
- base_url="https://openrouter.ai/api/v1",
92
- max_tokens=1000,
93
- temperature=0.1
94
- ))
95
-
96
- # Groq configuration
 
 
97
  groq_key = os.getenv("GROQ_API_KEY")
98
  if groq_key:
99
- configs.append(LLMConfig(
100
- provider="groq",
101
- api_key=groq_key,
102
- model_name="llama3-8b-8192", # Free tier model
103
- base_url="https://api.groq.com/openai/v1",
104
- max_tokens=1000,
105
- temperature=0.1
106
- ))
107
-
 
 
108
  if not configs:
109
  raise ValueError(
110
  "No LLM API keys found in environment. "
111
  "Please set OPENROUTER_API_KEY or GROQ_API_KEY"
112
  )
113
-
114
  return cls(configs)
115
 
116
- def generate_response(
117
- self,
118
- prompt: str,
119
- max_retries: int = 2
120
- ) -> LLMResponse:
121
  """
122
  Generate response from LLM with fallback support.
123
-
124
  Args:
125
  prompt: Input prompt for the LLM
126
  max_retries: Maximum retry attempts per provider
127
-
128
  Returns:
129
  LLMResponse with generated content or error information
130
  """
131
  last_error = None
132
-
133
  # Try each provider configuration
134
  for attempt in range(len(self.configs)):
135
  config = self.configs[self.current_config_index]
136
-
137
  try:
138
  logger.debug(f"Attempting generation with {config.provider}")
139
  response = self._call_provider(config, prompt, max_retries)
140
-
141
  if response.success:
142
- logger.info(f"Successfully generated response using {config.provider}")
 
 
143
  return response
144
-
145
  last_error = response.error_message
146
  logger.warning(f"Provider {config.provider} failed: {last_error}")
147
-
148
  except Exception as e:
149
  last_error = str(e)
150
  logger.error(f"Error with provider {config.provider}: {last_error}")
151
-
152
  # Move to next provider
153
- self.current_config_index = (self.current_config_index + 1) % len(self.configs)
154
-
 
 
155
  # All providers failed
156
  logger.error("All LLM providers failed")
157
  return LLMResponse(
@@ -161,83 +169,79 @@ class LLMService:
161
  usage={},
162
  response_time=0.0,
163
  success=False,
164
- error_message=f"All providers failed. Last error: {last_error}"
165
  )
166
 
167
  def _call_provider(
168
- self,
169
- config: LLMConfig,
170
- prompt: str,
171
- max_retries: int
172
  ) -> LLMResponse:
173
  """
174
  Make API call to specific provider with retry logic.
175
-
176
  Args:
177
  config: Provider configuration
178
  prompt: Input prompt
179
  max_retries: Maximum retry attempts
180
-
181
  Returns:
182
  LLMResponse from the provider
183
  """
184
  start_time = time.time()
185
-
186
  for attempt in range(max_retries + 1):
187
  try:
188
  headers = {
189
  "Authorization": f"Bearer {config.api_key}",
190
- "Content-Type": "application/json"
191
  }
192
-
193
  # Add provider-specific headers
194
  if config.provider == "openrouter":
195
- headers["HTTP-Referer"] = "https://github.com/sethmcknight/msse-ai-engineering"
 
 
196
  headers["X-Title"] = "MSSE RAG Application"
197
-
198
  payload = {
199
  "model": config.model_name,
200
- "messages": [
201
- {
202
- "role": "user",
203
- "content": prompt
204
- }
205
- ],
206
  "max_tokens": config.max_tokens,
207
- "temperature": config.temperature
208
  }
209
-
210
  response = requests.post(
211
  f"{config.base_url}/chat/completions",
212
  headers=headers,
213
  json=payload,
214
- timeout=config.timeout
215
  )
216
-
217
  response.raise_for_status()
218
  data = response.json()
219
-
220
  # Extract response content
221
  content = data["choices"][0]["message"]["content"]
222
  usage = data.get("usage", {})
223
-
224
  response_time = time.time() - start_time
225
-
226
  return LLMResponse(
227
  content=content,
228
  provider=config.provider,
229
  model=config.model_name,
230
  usage=usage,
231
  response_time=response_time,
232
- success=True
233
  )
234
-
235
  except requests.exceptions.RequestException as e:
236
- logger.warning(f"Request failed for {config.provider} (attempt {attempt + 1}): {e}")
 
 
237
  if attempt < max_retries:
238
- time.sleep(2 ** attempt) # Exponential backoff
239
  continue
240
-
241
  return LLMResponse(
242
  content="",
243
  provider=config.provider,
@@ -245,9 +249,9 @@ class LLMService:
245
  usage={},
246
  response_time=time.time() - start_time,
247
  success=False,
248
- error_message=str(e)
249
  )
250
-
251
  except Exception as e:
252
  logger.error(f"Unexpected error with {config.provider}: {e}")
253
  return LLMResponse(
@@ -257,44 +261,44 @@ class LLMService:
257
  usage={},
258
  response_time=time.time() - start_time,
259
  success=False,
260
- error_message=str(e)
261
  )
262
 
263
  def health_check(self) -> Dict[str, Any]:
264
  """
265
  Check health status of all configured providers.
266
-
267
  Returns:
268
  Dictionary with provider health status
269
  """
270
  health_status = {}
271
-
272
  for config in self.configs:
273
  try:
274
  # Simple test prompt
275
  test_response = self._call_provider(
276
- config,
277
- "Hello, this is a test. Please respond with 'OK'.",
278
- max_retries=1
279
  )
280
-
281
  health_status[config.provider] = {
282
  "status": "healthy" if test_response.success else "unhealthy",
283
  "model": config.model_name,
284
  "response_time": test_response.response_time,
285
- "error": test_response.error_message
286
  }
287
-
288
  except Exception as e:
289
  health_status[config.provider] = {
290
  "status": "unhealthy",
291
  "model": config.model_name,
292
  "response_time": 0.0,
293
- "error": str(e)
294
  }
295
-
296
  return health_status
297
 
298
  def get_available_providers(self) -> List[str]:
299
  """Get list of available provider names."""
300
- return [config.provider for config in self.configs]
 
1
  """
2
  LLM Service for RAG Application
3
 
4
+ This module provides integration with Large Language Models through multiple
5
+ providers including OpenRouter and Groq, with fallback capabilities and
6
+ comprehensive error handling.
7
  """
8
 
9
  import logging
10
  import os
11
  import time
 
 
12
  from dataclasses import dataclass
13
+ from typing import Any, Dict, List, Optional
14
+
15
+ import requests
16
 
17
  logger = logging.getLogger(__name__)
18
 
 
20
  @dataclass
21
  class LLMConfig:
22
  """Configuration for LLM providers."""
23
+
24
  provider: str # "openrouter" or "groq"
25
  api_key: str
26
  model_name: str
 
33
  @dataclass
34
  class LLMResponse:
35
  """Standardized response from LLM providers."""
36
+
37
  content: str
38
  provider: str
39
  model: str
 
46
  class LLMService:
47
  """
48
  Service for interacting with Large Language Models.
49
+
50
  Supports multiple providers with automatic fallback and retry logic.
51
  Designed for corporate policy Q&A with appropriate guardrails.
52
  """
 
54
  def __init__(self, configs: List[LLMConfig]):
55
  """
56
  Initialize LLMService with provider configurations.
57
+
58
  Args:
59
  configs: List of LLMConfig objects for different providers
60
+
61
  Raises:
62
  ValueError: If no valid configurations provided
63
  """
64
  if not configs:
65
  raise ValueError("At least one LLM configuration must be provided")
66
+
67
  self.configs = configs
68
  self.current_config_index = 0
69
  logger.info(f"LLMService initialized with {len(configs)} provider(s)")
70
 
71
  @classmethod
72
+ def from_environment(cls) -> "LLMService":
73
  """
74
  Create LLMService instance from environment variables.
75
+
76
  Expected environment variables:
77
  - OPENROUTER_API_KEY: API key for OpenRouter
78
  - GROQ_API_KEY: API key for Groq
79
+
80
  Returns:
81
  LLMService instance with available providers
82
+
83
  Raises:
84
  ValueError: If no API keys found in environment
85
  """
86
  configs = []
87
+
88
  # OpenRouter configuration
89
  openrouter_key = os.getenv("OPENROUTER_API_KEY")
90
  if openrouter_key:
91
+ configs.append(
92
+ LLMConfig(
93
+ provider="openrouter",
94
+ api_key=openrouter_key,
95
+ model_name="microsoft/wizardlm-2-8x22b", # Free tier model
96
+ base_url="https://openrouter.ai/api/v1",
97
+ max_tokens=1000,
98
+ temperature=0.1,
99
+ )
100
+ )
101
+
102
+ # Groq configuration
103
  groq_key = os.getenv("GROQ_API_KEY")
104
  if groq_key:
105
+ configs.append(
106
+ LLMConfig(
107
+ provider="groq",
108
+ api_key=groq_key,
109
+ model_name="llama3-8b-8192", # Free tier model
110
+ base_url="https://api.groq.com/openai/v1",
111
+ max_tokens=1000,
112
+ temperature=0.1,
113
+ )
114
+ )
115
+
116
  if not configs:
117
  raise ValueError(
118
  "No LLM API keys found in environment. "
119
  "Please set OPENROUTER_API_KEY or GROQ_API_KEY"
120
  )
121
+
122
  return cls(configs)
123
 
124
+ def generate_response(self, prompt: str, max_retries: int = 2) -> LLMResponse:
 
 
 
 
125
  """
126
  Generate response from LLM with fallback support.
127
+
128
  Args:
129
  prompt: Input prompt for the LLM
130
  max_retries: Maximum retry attempts per provider
131
+
132
  Returns:
133
  LLMResponse with generated content or error information
134
  """
135
  last_error = None
136
+
137
  # Try each provider configuration
138
  for attempt in range(len(self.configs)):
139
  config = self.configs[self.current_config_index]
140
+
141
  try:
142
  logger.debug(f"Attempting generation with {config.provider}")
143
  response = self._call_provider(config, prompt, max_retries)
144
+
145
  if response.success:
146
+ logger.info(
147
+ f"Successfully generated response using {config.provider}"
148
+ )
149
  return response
150
+
151
  last_error = response.error_message
152
  logger.warning(f"Provider {config.provider} failed: {last_error}")
153
+
154
  except Exception as e:
155
  last_error = str(e)
156
  logger.error(f"Error with provider {config.provider}: {last_error}")
157
+
158
  # Move to next provider
159
+ self.current_config_index = (self.current_config_index + 1) % len(
160
+ self.configs
161
+ )
162
+
163
  # All providers failed
164
  logger.error("All LLM providers failed")
165
  return LLMResponse(
 
169
  usage={},
170
  response_time=0.0,
171
  success=False,
172
+ error_message=f"All providers failed. Last error: {last_error}",
173
  )
174
 
175
  def _call_provider(
176
+ self, config: LLMConfig, prompt: str, max_retries: int
 
 
 
177
  ) -> LLMResponse:
178
  """
179
  Make API call to specific provider with retry logic.
180
+
181
  Args:
182
  config: Provider configuration
183
  prompt: Input prompt
184
  max_retries: Maximum retry attempts
185
+
186
  Returns:
187
  LLMResponse from the provider
188
  """
189
  start_time = time.time()
190
+
191
  for attempt in range(max_retries + 1):
192
  try:
193
  headers = {
194
  "Authorization": f"Bearer {config.api_key}",
195
+ "Content-Type": "application/json",
196
  }
197
+
198
  # Add provider-specific headers
199
  if config.provider == "openrouter":
200
+ headers["HTTP-Referer"] = (
201
+ "https://github.com/sethmcknight/msse-ai-engineering"
202
+ )
203
  headers["X-Title"] = "MSSE RAG Application"
204
+
205
  payload = {
206
  "model": config.model_name,
207
+ "messages": [{"role": "user", "content": prompt}],
 
 
 
 
 
208
  "max_tokens": config.max_tokens,
209
+ "temperature": config.temperature,
210
  }
211
+
212
  response = requests.post(
213
  f"{config.base_url}/chat/completions",
214
  headers=headers,
215
  json=payload,
216
+ timeout=config.timeout,
217
  )
218
+
219
  response.raise_for_status()
220
  data = response.json()
221
+
222
  # Extract response content
223
  content = data["choices"][0]["message"]["content"]
224
  usage = data.get("usage", {})
225
+
226
  response_time = time.time() - start_time
227
+
228
  return LLMResponse(
229
  content=content,
230
  provider=config.provider,
231
  model=config.model_name,
232
  usage=usage,
233
  response_time=response_time,
234
+ success=True,
235
  )
236
+
237
  except requests.exceptions.RequestException as e:
238
+ logger.warning(
239
+ f"Request failed for {config.provider} (attempt {attempt + 1}): {e}"
240
+ )
241
  if attempt < max_retries:
242
+ time.sleep(2**attempt) # Exponential backoff
243
  continue
244
+
245
  return LLMResponse(
246
  content="",
247
  provider=config.provider,
 
249
  usage={},
250
  response_time=time.time() - start_time,
251
  success=False,
252
+ error_message=str(e),
253
  )
254
+
255
  except Exception as e:
256
  logger.error(f"Unexpected error with {config.provider}: {e}")
257
  return LLMResponse(
 
261
  usage={},
262
  response_time=time.time() - start_time,
263
  success=False,
264
+ error_message=str(e),
265
  )
266
 
267
  def health_check(self) -> Dict[str, Any]:
268
  """
269
  Check health status of all configured providers.
270
+
271
  Returns:
272
  Dictionary with provider health status
273
  """
274
  health_status = {}
275
+
276
  for config in self.configs:
277
  try:
278
  # Simple test prompt
279
  test_response = self._call_provider(
280
+ config,
281
+ "Hello, this is a test. Please respond with 'OK'.",
282
+ max_retries=1,
283
  )
284
+
285
  health_status[config.provider] = {
286
  "status": "healthy" if test_response.success else "unhealthy",
287
  "model": config.model_name,
288
  "response_time": test_response.response_time,
289
+ "error": test_response.error_message,
290
  }
291
+
292
  except Exception as e:
293
  health_status[config.provider] = {
294
  "status": "unhealthy",
295
  "model": config.model_name,
296
  "response_time": 0.0,
297
+ "error": str(e),
298
  }
299
+
300
  return health_status
301
 
302
  def get_available_providers(self) -> List[str]:
303
  """Get list of available provider names."""
304
+ return [config.provider for config in self.configs]
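A minimal usage sketch for the service above, assuming at least one provider key is available; the explicit LLMConfig branch uses a placeholder key and mirrors the field values set in from_environment:

import os

from src.llm.llm_service import LLMConfig, LLMService

if os.getenv("OPENROUTER_API_KEY") or os.getenv("GROQ_API_KEY"):
    # Build from environment variables, exactly as app.py does.
    service = LLMService.from_environment()
else:
    # Configure a provider explicitly (placeholder key; model and base_url as in from_environment).
    service = LLMService(
        [
            LLMConfig(
                provider="groq",
                api_key="YOUR_KEY_HERE",
                model_name="llama3-8b-8192",
                base_url="https://api.groq.com/openai/v1",
                max_tokens=1000,
                temperature=0.1,
            )
        ]
    )

response = service.generate_response("Summarize the expense reimbursement policy.")
print(response.success, response.provider)
print(response.content[:200])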
src/llm/prompt_templates.py CHANGED
@@ -1,17 +1,18 @@
1
  """
2
  Prompt Templates for Corporate Policy Q&A
3
 
4
- This module contains predefined prompt templates optimized for
5
  corporate policy question-answering with proper citation requirements.
6
  """
7
 
8
- from typing import Dict, List
9
  from dataclasses import dataclass
 
10
 
11
 
12
  @dataclass
13
  class PromptTemplate:
14
  """Template for generating prompts with context and citations."""
 
15
  system_prompt: str
16
  user_template: str
17
  citation_format: str
@@ -20,7 +21,7 @@ class PromptTemplate:
20
  class PromptTemplates:
21
  """
22
  Collection of prompt templates for different types of policy questions.
23
-
24
  Templates are designed to ensure:
25
  - Accurate responses based on provided context
26
  - Proper citation of source documents
@@ -29,15 +30,15 @@ class PromptTemplates:
29
  """
30
 
31
  # System prompt for corporate policy assistant
32
- SYSTEM_PROMPT = """You are a helpful corporate policy assistant. Your job is to answer questions about company policies based ONLY on the provided context documents.
33
 
34
  IMPORTANT GUIDELINES:
35
  1. Answer questions using ONLY the information provided in the context
36
- 2. If the context doesn't contain enough information to answer the question, say so explicitly
37
  3. Always cite your sources using the format: [Source: filename.md]
38
  4. Be accurate, concise, and professional
39
- 5. If asked about topics not covered in the policies, politely redirect to HR or appropriate department
40
- 6. Do not make assumptions or provide information not explicitly stated in the context
41
 
42
  Your responses should be helpful while staying strictly within the scope of the provided corporate policies."""
43
 
@@ -45,26 +46,26 @@ Your responses should be helpful while staying strictly within the scope of the
45
  def get_policy_qa_template(cls) -> PromptTemplate:
46
  """
47
  Get the standard template for policy question-answering.
48
-
49
  Returns:
50
  PromptTemplate configured for corporate policy Q&A
51
  """
52
  return PromptTemplate(
53
  system_prompt=cls.SYSTEM_PROMPT,
54
- user_template="""Based on the following corporate policy documents, please answer this question: {question}
55
 
56
  CONTEXT DOCUMENTS:
57
  {context}
58
 
59
- Please provide a clear, accurate answer based on the information above. Include citations for all information using the format [Source: filename.md].""",
60
- citation_format="[Source: {filename}]"
61
  )
62
 
63
  @classmethod
64
  def get_clarification_template(cls) -> PromptTemplate:
65
  """
66
  Get template for when clarification is needed.
67
-
68
  Returns:
69
  PromptTemplate for clarification requests
70
  """
@@ -75,19 +76,19 @@ Please provide a clear, accurate answer based on the information above. Include
75
  CONTEXT DOCUMENTS:
76
  {context}
77
 
78
- The provided context documents don't contain sufficient information to fully answer this question. Please provide a helpful response that:
79
  1. Acknowledges what information is available (if any)
80
  2. Clearly states what information is missing
81
  3. Suggests appropriate next steps (contact HR, check other resources, etc.)
82
  4. Cites any relevant sources using [Source: filename.md] format""",
83
- citation_format="[Source: {filename}]"
84
  )
85
 
86
  @classmethod
87
  def get_off_topic_template(cls) -> PromptTemplate:
88
  """
89
  Get template for off-topic questions.
90
-
91
  Returns:
92
  PromptTemplate for redirecting off-topic questions
93
  """
@@ -95,122 +96,122 @@ The provided context documents don't contain sufficient information to fully ans
95
  system_prompt=cls.SYSTEM_PROMPT,
96
  user_template="""The user asked: {question}
97
 
98
- This question appears to be outside the scope of our corporate policies. Please provide a polite response that:
99
  1. Acknowledges the question
100
  2. Explains that this falls outside corporate policy documentation
101
  3. Suggests appropriate resources (HR, IT, management, etc.)
102
  4. Offers to help with any policy-related questions instead""",
103
- citation_format=""
104
  )
105
 
106
  @staticmethod
107
  def format_context(search_results: List[Dict]) -> str:
108
  """
109
  Format search results into context for the prompt.
110
-
111
  Args:
112
  search_results: List of search results from SearchService
113
-
114
  Returns:
115
  Formatted context string for the prompt
116
  """
117
  if not search_results:
118
  return "No relevant policy documents found."
119
-
120
  context_parts = []
121
  for i, result in enumerate(search_results[:5], 1): # Limit to top 5 results
122
  filename = result.get("metadata", {}).get("filename", "unknown")
123
  content = result.get("content", "").strip()
124
  similarity = result.get("similarity_score", 0.0)
125
-
126
  context_parts.append(
127
  f"Document {i}: {filename} (relevance: {similarity:.2f})\n"
128
  f"Content: {content}\n"
129
  )
130
-
131
  return "\n---\n".join(context_parts)
132
 
133
  @staticmethod
134
  def extract_citations(response: str) -> List[str]:
135
  """
136
  Extract citations from LLM response.
137
-
138
  Args:
139
  response: Generated response text
140
-
141
  Returns:
142
  List of extracted filenames from citations
143
  """
144
  import re
145
-
146
  # Pattern to match [Source: filename.md] format
147
- citation_pattern = r'\[Source:\s*([^\]]+)\]'
148
  matches = re.findall(citation_pattern, response)
149
-
150
  # Clean up filenames
151
  citations = []
152
  for match in matches:
153
  filename = match.strip()
154
  if filename and filename not in citations:
155
  citations.append(filename)
156
-
157
  return citations
158
 
159
  @staticmethod
160
- def validate_citations(response: str, available_sources: List[str]) -> Dict[str, bool]:
 
 
161
  """
162
  Validate that all citations in response refer to available sources.
163
-
164
  Args:
165
  response: Generated response text
166
  available_sources: List of available source filenames
167
-
168
  Returns:
169
  Dictionary mapping citations to their validity
170
  """
171
  citations = PromptTemplates.extract_citations(response)
172
  validation = {}
173
-
174
  for citation in citations:
175
  # Check if citation matches any available source
176
- valid = any(citation in source or source in citation
177
- for source in available_sources)
 
178
  validation[citation] = valid
179
-
180
  return validation
181
 
182
  @staticmethod
183
- def add_fallback_citations(
184
- response: str,
185
- search_results: List[Dict]
186
- ) -> str:
187
  """
188
  Add citations to response if none were provided by LLM.
189
-
190
  Args:
191
  response: Generated response text
192
  search_results: Original search results used for context
193
-
194
  Returns:
195
  Response with added citations if needed
196
  """
197
  existing_citations = PromptTemplates.extract_citations(response)
198
-
199
  if existing_citations:
200
  return response # Already has citations
201
-
202
  if not search_results:
203
  return response # No sources to cite
204
-
205
  # Add citations from top search results
206
  top_sources = []
207
  for result in search_results[:3]: # Top 3 sources
208
  filename = result.get("metadata", {}).get("filename", "")
209
  if filename and filename not in top_sources:
210
  top_sources.append(filename)
211
-
212
  if top_sources:
213
  citation_text = " [Sources: " + ", ".join(top_sources) + "]"
214
  return response + citation_text
215
-
216
- return response
 
1
  """
2
  Prompt Templates for Corporate Policy Q&A
3
 
4
+ This module contains predefined prompt templates optimized for
5
  corporate policy question-answering with proper citation requirements.
6
  """
7
 
 
8
  from dataclasses import dataclass
9
+ from typing import Dict, List
10
 
11
 
12
  @dataclass
13
  class PromptTemplate:
14
  """Template for generating prompts with context and citations."""
15
+
16
  system_prompt: str
17
  user_template: str
18
  citation_format: str
 
21
  class PromptTemplates:
22
  """
23
  Collection of prompt templates for different types of policy questions.
24
+
25
  Templates are designed to ensure:
26
  - Accurate responses based on provided context
27
  - Proper citation of source documents
 
30
  """
31
 
32
  # System prompt for corporate policy assistant
33
+ SYSTEM_PROMPT = """You are a helpful corporate policy assistant. Your job is to answer questions about company policies based ONLY on the provided context documents. # noqa: E501
34
 
35
  IMPORTANT GUIDELINES:
36
  1. Answer questions using ONLY the information provided in the context
37
+ 2. If the context doesn't contain enough information to answer the question, say so explicitly # noqa: E501
38
  3. Always cite your sources using the format: [Source: filename.md]
39
  4. Be accurate, concise, and professional
40
+ 5. If asked about topics not covered in the policies, politely redirect to HR or appropriate department # noqa: E501
41
+ 6. Do not make assumptions or provide information not explicitly stated in the context # noqa: E501
42
 
43
  Your responses should be helpful while staying strictly within the scope of the provided corporate policies."""
44
 
 
46
  def get_policy_qa_template(cls) -> PromptTemplate:
47
  """
48
  Get the standard template for policy question-answering.
49
+
50
  Returns:
51
  PromptTemplate configured for corporate policy Q&A
52
  """
53
  return PromptTemplate(
54
  system_prompt=cls.SYSTEM_PROMPT,
55
+ user_template="""Based on the following corporate policy documents, please answer this question: {question} # noqa: E501
56
 
57
  CONTEXT DOCUMENTS:
58
  {context}
59
 
60
+ Please provide a clear, accurate answer based on the information above. Include citations for all information using the format [Source: filename.md].""", # noqa: E501
61
+ citation_format="[Source: {filename}]",
62
  )
63
 
64
  @classmethod
65
  def get_clarification_template(cls) -> PromptTemplate:
66
  """
67
  Get template for when clarification is needed.
68
+
69
  Returns:
70
  PromptTemplate for clarification requests
71
  """
 
76
  CONTEXT DOCUMENTS:
77
  {context}
78
 
79
+ The provided context documents don't contain sufficient information to fully answer this question. Please provide a helpful response that: # noqa: E501
80
  1. Acknowledges what information is available (if any)
81
  2. Clearly states what information is missing
82
  3. Suggests appropriate next steps (contact HR, check other resources, etc.)
83
  4. Cites any relevant sources using [Source: filename.md] format""",
84
+ citation_format="[Source: {filename}]",
85
  )
86
 
87
  @classmethod
88
  def get_off_topic_template(cls) -> PromptTemplate:
89
  """
90
  Get template for off-topic questions.
91
+
92
  Returns:
93
  PromptTemplate for redirecting off-topic questions
94
  """
 
96
  system_prompt=cls.SYSTEM_PROMPT,
97
  user_template="""The user asked: {question}
98
 
99
+ This question appears to be outside the scope of our corporate policies. Please provide a polite response that: # noqa: E501
100
  1. Acknowledges the question
101
  2. Explains that this falls outside corporate policy documentation
102
  3. Suggests appropriate resources (HR, IT, management, etc.)
103
  4. Offers to help with any policy-related questions instead""",
104
+ citation_format="",
105
  )
106
 
107
  @staticmethod
108
  def format_context(search_results: List[Dict]) -> str:
109
  """
110
  Format search results into context for the prompt.
111
+
112
  Args:
113
  search_results: List of search results from SearchService
114
+
115
  Returns:
116
  Formatted context string for the prompt
117
  """
118
  if not search_results:
119
  return "No relevant policy documents found."
120
+
121
  context_parts = []
122
  for i, result in enumerate(search_results[:5], 1): # Limit to top 5 results
123
  filename = result.get("metadata", {}).get("filename", "unknown")
124
  content = result.get("content", "").strip()
125
  similarity = result.get("similarity_score", 0.0)
126
+
127
  context_parts.append(
128
  f"Document {i}: {filename} (relevance: {similarity:.2f})\n"
129
  f"Content: {content}\n"
130
  )
131
+
132
  return "\n---\n".join(context_parts)
133
 
134
  @staticmethod
135
  def extract_citations(response: str) -> List[str]:
136
  """
137
  Extract citations from LLM response.
138
+
139
  Args:
140
  response: Generated response text
141
+
142
  Returns:
143
  List of extracted filenames from citations
144
  """
145
  import re
146
+
147
  # Pattern to match [Source: filename.md] format
148
+ citation_pattern = r"\[Source:\s*([^\]]+)\]"
149
  matches = re.findall(citation_pattern, response)
150
+
151
  # Clean up filenames
152
  citations = []
153
  for match in matches:
154
  filename = match.strip()
155
  if filename and filename not in citations:
156
  citations.append(filename)
157
+
158
  return citations
159
 
160
  @staticmethod
161
+ def validate_citations(
162
+ response: str, available_sources: List[str]
163
+ ) -> Dict[str, bool]:
164
  """
165
  Validate that all citations in response refer to available sources.
166
+
167
  Args:
168
  response: Generated response text
169
  available_sources: List of available source filenames
170
+
171
  Returns:
172
  Dictionary mapping citations to their validity
173
  """
174
  citations = PromptTemplates.extract_citations(response)
175
  validation = {}
176
+
177
  for citation in citations:
178
  # Check if citation matches any available source
179
+ valid = any(
180
+ citation in source or source in citation for source in available_sources
181
+ )
182
  validation[citation] = valid
183
+
184
  return validation
185
 
186
  @staticmethod
187
+ def add_fallback_citations(response: str, search_results: List[Dict]) -> str:
 
 
 
188
  """
189
  Add citations to response if none were provided by LLM.
190
+
191
  Args:
192
  response: Generated response text
193
  search_results: Original search results used for context
194
+
195
  Returns:
196
  Response with added citations if needed
197
  """
198
  existing_citations = PromptTemplates.extract_citations(response)
199
+
200
  if existing_citations:
201
  return response # Already has citations
202
+
203
  if not search_results:
204
  return response # No sources to cite
205
+
206
  # Add citations from top search results
207
  top_sources = []
208
  for result in search_results[:3]: # Top 3 sources
209
  filename = result.get("metadata", {}).get("filename", "")
210
  if filename and filename not in top_sources:
211
  top_sources.append(filename)
212
+
213
  if top_sources:
214
  citation_text = " [Sources: " + ", ".join(top_sources) + "]"
215
  return response + citation_text
216
+
217
+ return response
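A small self-contained sketch of the citation helpers above; the answer text is invented to show the [Source: ...] pattern being extracted and validated:

from src.llm.prompt_templates import PromptTemplates

# Hypothetical LLM output that uses the required citation format.
answer = (
    "New hires receive 15 vacation days per year [Source: vacation_policy.md], "
    "and unused days roll over once [Source: vacation_policy.md]."
)

citations = PromptTemplates.extract_citations(answer)
print(citations)  # ['vacation_policy.md'] -- duplicates are collapsed

validity = PromptTemplates.validate_citations(
    answer, ["vacation_policy.md", "remote_work_policy.md"]
)
print(validity)  # {'vacation_policy.md': True}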
src/rag/__init__.py CHANGED
@@ -7,4 +7,4 @@ combining semantic search with LLM-based response generation.
7
  Classes:
8
  RAGPipeline: Main RAG orchestration service
9
  ResponseFormatter: Formats LLM responses with citations and metadata
10
- """
 
7
  Classes:
8
  RAGPipeline: Main RAG orchestration service
9
  ResponseFormatter: Formats LLM responses with citations and metadata
10
+ """
src/rag/rag_pipeline.py CHANGED
@@ -7,14 +7,15 @@ combining semantic search, context management, and LLM generation.
7
 
8
  import logging
9
  import time
10
- from typing import Any, Dict, List, Optional
11
  from dataclasses import dataclass
 
 
 
 
 
12
 
13
  # Import our modules
14
  from src.search.search_service import SearchService
15
- from src.llm.llm_service import LLMService, LLMResponse
16
- from src.llm.context_manager import ContextManager, ContextConfig
17
- from src.llm.prompt_templates import PromptTemplates, PromptTemplate
18
 
19
  logger = logging.getLogger(__name__)
20
 
@@ -22,6 +23,7 @@ logger = logging.getLogger(__name__)
22
  @dataclass
23
  class RAGConfig:
24
  """Configuration for RAG pipeline."""
 
25
  max_context_length: int = 3000
26
  search_top_k: int = 10
27
  search_threshold: float = 0.1
@@ -33,6 +35,7 @@ class RAGConfig:
33
  @dataclass
34
  class RAGResponse:
35
  """Response from RAG pipeline with metadata."""
 
36
  answer: str
37
  sources: List[Dict[str, Any]]
38
  confidence: float
@@ -48,7 +51,7 @@ class RAGResponse:
48
  class RAGPipeline:
49
  """
50
  Complete RAG pipeline orchestrating retrieval and generation.
51
-
52
  Combines:
53
  - Semantic search for context retrieval
54
  - Context optimization and management
@@ -60,84 +63,84 @@ class RAGPipeline:
60
  self,
61
  search_service: SearchService,
62
  llm_service: LLMService,
63
- config: Optional[RAGConfig] = None
64
  ):
65
  """
66
  Initialize RAG pipeline with required services.
67
-
68
  Args:
69
  search_service: Configured SearchService instance
70
- llm_service: Configured LLMService instance
71
  config: RAG configuration, uses defaults if None
72
  """
73
  self.search_service = search_service
74
  self.llm_service = llm_service
75
  self.config = config or RAGConfig()
76
-
77
  # Initialize context manager with matching config
78
  context_config = ContextConfig(
79
  max_context_length=self.config.max_context_length,
80
  max_results=self.config.search_top_k,
81
- min_similarity=self.config.search_threshold
82
  )
83
  self.context_manager = ContextManager(context_config)
84
-
85
  # Initialize prompt templates
86
  self.prompt_templates = PromptTemplates()
87
-
88
  logger.info("RAGPipeline initialized successfully")
89
 
90
  def generate_answer(self, question: str) -> RAGResponse:
91
  """
92
  Generate answer to question using RAG pipeline.
93
-
94
  Args:
95
  question: User's question about corporate policies
96
-
97
  Returns:
98
  RAGResponse with answer and metadata
99
  """
100
  start_time = time.time()
101
-
102
  try:
103
  # Step 1: Retrieve relevant context
104
  logger.debug(f"Starting RAG pipeline for question: {question[:100]}...")
105
-
106
  search_results = self._retrieve_context(question)
107
-
108
  if not search_results:
109
  return self._create_no_context_response(question, start_time)
110
-
111
  # Step 2: Prepare and optimize context
112
  context, filtered_results = self.context_manager.prepare_context(
113
  search_results, question
114
  )
115
-
116
  # Step 3: Check if we have sufficient context
117
  quality_metrics = self.context_manager.validate_context_quality(
118
  context, question, self.config.min_similarity_for_answer
119
  )
120
-
121
  if not quality_metrics["passes_validation"]:
122
  return self._create_insufficient_context_response(
123
  question, filtered_results, start_time
124
  )
125
-
126
  # Step 4: Generate response using LLM
127
  llm_response = self._generate_llm_response(question, context)
128
-
129
  if not llm_response.success:
130
  return self._create_llm_error_response(
131
  question, llm_response.error_message, start_time
132
  )
133
-
134
  # Step 5: Process and validate response
135
  processed_response = self._process_response(
136
  llm_response.content, filtered_results
137
  )
138
-
139
  processing_time = time.time() - start_time
140
-
141
  return RAGResponse(
142
  answer=processed_response,
143
  sources=self._format_sources(filtered_results),
@@ -147,13 +150,16 @@ class RAGPipeline:
147
  llm_model=llm_response.model,
148
  context_length=len(context),
149
  search_results_count=len(search_results),
150
- success=True
151
  )
152
-
153
  except Exception as e:
154
  logger.error(f"RAG pipeline error: {e}")
155
  return RAGResponse(
156
- answer="I apologize, but I encountered an error processing your question. Please try again or contact support.",
 
 
 
157
  sources=[],
158
  confidence=0.0,
159
  processing_time=time.time() - start_time,
@@ -162,7 +168,7 @@ class RAGPipeline:
162
  context_length=0,
163
  search_results_count=0,
164
  success=False,
165
- error_message=str(e)
166
  )
167
 
168
  def _retrieve_context(self, question: str) -> List[Dict[str, Any]]:
@@ -171,12 +177,12 @@ class RAGPipeline:
171
  results = self.search_service.search(
172
  query=question,
173
  top_k=self.config.search_top_k,
174
- threshold=self.config.search_threshold
175
  )
176
-
177
  logger.debug(f"Retrieved {len(results)} search results")
178
  return results
179
-
180
  except Exception as e:
181
  logger.error(f"Context retrieval error: {e}")
182
  return []
@@ -184,95 +190,108 @@ class RAGPipeline:
184
  def _generate_llm_response(self, question: str, context: str) -> LLMResponse:
185
  """Generate response using LLM with formatted prompt."""
186
  template = self.prompt_templates.get_policy_qa_template()
187
-
188
  # Format the prompt
189
  formatted_prompt = template.user_template.format(
190
- question=question,
191
- context=context
192
  )
193
-
194
  # Add system prompt (if LLM service supports it in future)
195
  full_prompt = f"{template.system_prompt}\n\n{formatted_prompt}"
196
-
197
  return self.llm_service.generate_response(full_prompt)
198
 
199
  def _process_response(
200
- self,
201
- raw_response: str,
202
- search_results: List[Dict[str, Any]]
203
  ) -> str:
204
  """Process and validate LLM response."""
205
-
206
  # Ensure citations are present
207
  response_with_citations = self.prompt_templates.add_fallback_citations(
208
  raw_response, search_results
209
  )
210
-
211
  # Validate citations if enabled
212
  if self.config.enable_citation_validation:
213
  available_sources = [
214
  result.get("metadata", {}).get("filename", "")
215
  for result in search_results
216
  ]
217
-
218
  citation_validation = self.prompt_templates.validate_citations(
219
  response_with_citations, available_sources
220
  )
221
-
222
  # Log any invalid citations
223
  invalid_citations = [
224
- citation for citation, valid in citation_validation.items()
225
- if not valid
226
  ]
227
-
228
  if invalid_citations:
229
  logger.warning(f"Invalid citations detected: {invalid_citations}")
230
-
231
  # Truncate if too long
232
  if len(response_with_citations) > self.config.max_response_length:
233
- truncated = response_with_citations[:self.config.max_response_length - 3] + "..."
234
- logger.warning(f"Response truncated from {len(response_with_citations)} to {len(truncated)} characters")
 
 
 
 
 
235
  return truncated
236
-
237
  return response_with_citations
238
 
239
- def _format_sources(self, search_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
 
 
240
  """Format search results for response metadata."""
241
  sources = []
242
-
243
  for result in search_results:
244
  metadata = result.get("metadata", {})
245
- sources.append({
246
- "document": metadata.get("filename", "unknown"),
247
- "chunk_id": result.get("chunk_id", ""),
248
- "relevance_score": result.get("similarity_score", 0.0),
249
- "excerpt": result.get("content", "")[:200] + "..." if len(result.get("content", "")) > 200 else result.get("content", "")
250
- })
251
-
 
 
 
 
 
 
252
  return sources
253
 
254
  def _calculate_confidence(
255
- self,
256
- quality_metrics: Dict[str, Any],
257
- llm_response: LLMResponse
258
  ) -> float:
259
  """Calculate confidence score for the response."""
260
-
261
  # Base confidence on context quality
262
  context_confidence = quality_metrics.get("estimated_relevance", 0.0)
263
-
264
  # Adjust based on LLM response time (faster might indicate more confidence)
265
  time_factor = min(1.0, 10.0 / max(llm_response.response_time, 1.0))
266
-
267
  # Combine factors
268
  confidence = (context_confidence * 0.7) + (time_factor * 0.3)
269
-
270
  return min(1.0, max(0.0, confidence))
271
 
272
- def _create_no_context_response(self, question: str, start_time: float) -> RAGResponse:
 
 
273
  """Create response when no relevant context found."""
274
  return RAGResponse(
275
- answer="I couldn't find any relevant information in our corporate policies to answer your question. Please contact HR or check other company resources for assistance.",
 
 
 
 
276
  sources=[],
277
  confidence=0.0,
278
  processing_time=time.time() - start_time,
@@ -280,18 +299,19 @@ class RAGPipeline:
280
  llm_model="none",
281
  context_length=0,
282
  search_results_count=0,
283
- success=True # This is a valid "no answer" response
284
  )
285
 
286
  def _create_insufficient_context_response(
287
- self,
288
- question: str,
289
- results: List[Dict[str, Any]],
290
- start_time: float
291
  ) -> RAGResponse:
292
  """Create response when context quality is insufficient."""
293
  return RAGResponse(
294
- answer="I found some potentially relevant information, but it doesn't provide enough detail to fully answer your question. Please contact HR for more specific guidance or rephrase your question.",
 
 
 
 
295
  sources=self._format_sources(results),
296
  confidence=0.2,
297
  processing_time=time.time() - start_time,
@@ -299,18 +319,18 @@ class RAGPipeline:
299
  llm_model="none",
300
  context_length=0,
301
  search_results_count=len(results),
302
- success=True
303
  )
304
 
305
  def _create_llm_error_response(
306
- self,
307
- question: str,
308
- error_message: str,
309
- start_time: float
310
  ) -> RAGResponse:
311
  """Create response when LLM generation fails."""
312
  return RAGResponse(
313
- answer="I apologize, but I'm currently unable to generate a response. Please try again in a moment or contact support if the issue persists.",
 
 
 
314
  sources=[],
315
  confidence=0.0,
316
  processing_time=time.time() - start_time,
@@ -319,54 +339,54 @@ class RAGPipeline:
319
  context_length=0,
320
  search_results_count=0,
321
  success=False,
322
- error_message=error_message
323
  )
324
 
325
  def health_check(self) -> Dict[str, Any]:
326
  """
327
  Perform health check on all pipeline components.
328
-
329
  Returns:
330
  Dictionary with component health status
331
  """
332
- health_status = {
333
- "pipeline": "healthy",
334
- "components": {}
335
- }
336
-
337
  try:
338
  # Check search service
339
- test_results = self.search_service.search("test query", top_k=1, threshold=0.0)
 
 
340
  health_status["components"]["search_service"] = {
341
  "status": "healthy",
342
- "test_results_count": len(test_results)
343
  }
344
  except Exception as e:
345
  health_status["components"]["search_service"] = {
346
  "status": "unhealthy",
347
- "error": str(e)
348
  }
349
  health_status["pipeline"] = "degraded"
350
-
351
  try:
352
  # Check LLM service
353
  llm_health = self.llm_service.health_check()
354
  health_status["components"]["llm_service"] = llm_health
355
-
356
  # Pipeline is unhealthy if all LLM providers are down
357
  healthy_providers = sum(
358
- 1 for provider_status in llm_health.values()
 
359
  if provider_status.get("status") == "healthy"
360
  )
361
-
362
  if healthy_providers == 0:
363
  health_status["pipeline"] = "unhealthy"
364
-
365
  except Exception as e:
366
  health_status["components"]["llm_service"] = {
367
- "status": "unhealthy",
368
- "error": str(e)
369
  }
370
  health_status["pipeline"] = "unhealthy"
371
-
372
- return health_status
 
7
 
8
  import logging
9
  import time
 
10
  from dataclasses import dataclass
11
+ from typing import Any, Dict, List, Optional
12
+
13
+ from src.llm.context_manager import ContextConfig, ContextManager
14
+ from src.llm.llm_service import LLMResponse, LLMService
15
+ from src.llm.prompt_templates import PromptTemplates
16
 
17
  # Import our modules
18
  from src.search.search_service import SearchService
 
 
 
19
 
20
  logger = logging.getLogger(__name__)
21
 
 
23
  @dataclass
24
  class RAGConfig:
25
  """Configuration for RAG pipeline."""
26
+
27
  max_context_length: int = 3000
28
  search_top_k: int = 10
29
  search_threshold: float = 0.1
 
35
  @dataclass
36
  class RAGResponse:
37
  """Response from RAG pipeline with metadata."""
38
+
39
  answer: str
40
  sources: List[Dict[str, Any]]
41
  confidence: float
 
51
  class RAGPipeline:
52
  """
53
  Complete RAG pipeline orchestrating retrieval and generation.
54
+
55
  Combines:
56
  - Semantic search for context retrieval
57
  - Context optimization and management
 
63
  self,
64
  search_service: SearchService,
65
  llm_service: LLMService,
66
+ config: Optional[RAGConfig] = None,
67
  ):
68
  """
69
  Initialize RAG pipeline with required services.
70
+
71
  Args:
72
  search_service: Configured SearchService instance
73
+ llm_service: Configured LLMService instance
74
  config: RAG configuration, uses defaults if None
75
  """
76
  self.search_service = search_service
77
  self.llm_service = llm_service
78
  self.config = config or RAGConfig()
79
+
80
  # Initialize context manager with matching config
81
  context_config = ContextConfig(
82
  max_context_length=self.config.max_context_length,
83
  max_results=self.config.search_top_k,
84
+ min_similarity=self.config.search_threshold,
85
  )
86
  self.context_manager = ContextManager(context_config)
87
+
88
  # Initialize prompt templates
89
  self.prompt_templates = PromptTemplates()
90
+
91
  logger.info("RAGPipeline initialized successfully")
92
 
93
  def generate_answer(self, question: str) -> RAGResponse:
94
  """
95
  Generate answer to question using RAG pipeline.
96
+
97
  Args:
98
  question: User's question about corporate policies
99
+
100
  Returns:
101
  RAGResponse with answer and metadata
102
  """
103
  start_time = time.time()
104
+
105
  try:
106
  # Step 1: Retrieve relevant context
107
  logger.debug(f"Starting RAG pipeline for question: {question[:100]}...")
108
+
109
  search_results = self._retrieve_context(question)
110
+
111
  if not search_results:
112
  return self._create_no_context_response(question, start_time)
113
+
114
  # Step 2: Prepare and optimize context
115
  context, filtered_results = self.context_manager.prepare_context(
116
  search_results, question
117
  )
118
+
119
  # Step 3: Check if we have sufficient context
120
  quality_metrics = self.context_manager.validate_context_quality(
121
  context, question, self.config.min_similarity_for_answer
122
  )
123
+
124
  if not quality_metrics["passes_validation"]:
125
  return self._create_insufficient_context_response(
126
  question, filtered_results, start_time
127
  )
128
+
129
  # Step 4: Generate response using LLM
130
  llm_response = self._generate_llm_response(question, context)
131
+
132
  if not llm_response.success:
133
  return self._create_llm_error_response(
134
  question, llm_response.error_message, start_time
135
  )
136
+
137
  # Step 5: Process and validate response
138
  processed_response = self._process_response(
139
  llm_response.content, filtered_results
140
  )
141
+
142
  processing_time = time.time() - start_time
143
+
144
  return RAGResponse(
145
  answer=processed_response,
146
  sources=self._format_sources(filtered_results),
 
150
  llm_model=llm_response.model,
151
  context_length=len(context),
152
  search_results_count=len(search_results),
153
+ success=True,
154
  )
155
+
156
  except Exception as e:
157
  logger.error(f"RAG pipeline error: {e}")
158
  return RAGResponse(
159
+ answer=(
160
+ "I apologize, but I encountered an error processing your question. "
161
+ "Please try again or contact support."
162
+ ),
163
  sources=[],
164
  confidence=0.0,
165
  processing_time=time.time() - start_time,
 
168
  context_length=0,
169
  search_results_count=0,
170
  success=False,
171
+ error_message=str(e),
172
  )
173
 
174
  def _retrieve_context(self, question: str) -> List[Dict[str, Any]]:
 
177
  results = self.search_service.search(
178
  query=question,
179
  top_k=self.config.search_top_k,
180
+ threshold=self.config.search_threshold,
181
  )
182
+
183
  logger.debug(f"Retrieved {len(results)} search results")
184
  return results
185
+
186
  except Exception as e:
187
  logger.error(f"Context retrieval error: {e}")
188
  return []
 
190
  def _generate_llm_response(self, question: str, context: str) -> LLMResponse:
191
  """Generate response using LLM with formatted prompt."""
192
  template = self.prompt_templates.get_policy_qa_template()
193
+
194
  # Format the prompt
195
  formatted_prompt = template.user_template.format(
196
+ question=question, context=context
 
197
  )
198
+
199
  # Add system prompt (if LLM service supports it in future)
200
  full_prompt = f"{template.system_prompt}\n\n{formatted_prompt}"
201
+
202
  return self.llm_service.generate_response(full_prompt)
203
 
204
  def _process_response(
205
+ self, raw_response: str, search_results: List[Dict[str, Any]]
 
 
206
  ) -> str:
207
  """Process and validate LLM response."""
208
+
209
  # Ensure citations are present
210
  response_with_citations = self.prompt_templates.add_fallback_citations(
211
  raw_response, search_results
212
  )
213
+
214
  # Validate citations if enabled
215
  if self.config.enable_citation_validation:
216
  available_sources = [
217
  result.get("metadata", {}).get("filename", "")
218
  for result in search_results
219
  ]
220
+
221
  citation_validation = self.prompt_templates.validate_citations(
222
  response_with_citations, available_sources
223
  )
224
+
225
  # Log any invalid citations
226
  invalid_citations = [
227
+ citation for citation, valid in citation_validation.items() if not valid
 
228
  ]
229
+
230
  if invalid_citations:
231
  logger.warning(f"Invalid citations detected: {invalid_citations}")
232
+
233
  # Truncate if too long
234
  if len(response_with_citations) > self.config.max_response_length:
235
+ truncated = (
236
+ response_with_citations[: self.config.max_response_length - 3] + "..."
237
+ )
238
+ logger.warning(
239
+ f"Response truncated from {len(response_with_citations)} "
240
+ f"to {len(truncated)} characters"
241
+ )
242
  return truncated
243
+
244
  return response_with_citations
245
 
246
+ def _format_sources(
247
+ self, search_results: List[Dict[str, Any]]
248
+ ) -> List[Dict[str, Any]]:
249
  """Format search results for response metadata."""
250
  sources = []
251
+
252
  for result in search_results:
253
  metadata = result.get("metadata", {})
254
+ sources.append(
255
+ {
256
+ "document": metadata.get("filename", "unknown"),
257
+ "chunk_id": result.get("chunk_id", ""),
258
+ "relevance_score": result.get("similarity_score", 0.0),
259
+ "excerpt": (
260
+ result.get("content", "")[:200] + "..."
261
+ if len(result.get("content", "")) > 200
262
+ else result.get("content", "")
263
+ ),
264
+ }
265
+ )
266
+
267
  return sources
268
 
269
  def _calculate_confidence(
270
+ self, quality_metrics: Dict[str, Any], llm_response: LLMResponse
 
 
271
  ) -> float:
272
  """Calculate confidence score for the response."""
273
+
274
  # Base confidence on context quality
275
  context_confidence = quality_metrics.get("estimated_relevance", 0.0)
276
+
277
  # Adjust based on LLM response time (faster might indicate more confidence)
278
  time_factor = min(1.0, 10.0 / max(llm_response.response_time, 1.0))
279
+
280
  # Combine factors
281
  confidence = (context_confidence * 0.7) + (time_factor * 0.3)
282
+
283
  return min(1.0, max(0.0, confidence))
284
 
285
+ def _create_no_context_response(
286
+ self, question: str, start_time: float
287
+ ) -> RAGResponse:
288
  """Create response when no relevant context found."""
289
  return RAGResponse(
290
+ answer=(
291
+ "I couldn't find any relevant information in our corporate policies "
292
+ "to answer your question. Please contact HR or check other company "
293
+ "resources for assistance."
294
+ ),
295
  sources=[],
296
  confidence=0.0,
297
  processing_time=time.time() - start_time,
 
299
  llm_model="none",
300
  context_length=0,
301
  search_results_count=0,
302
+ success=True, # This is a valid "no answer" response
303
  )
304
 
305
  def _create_insufficient_context_response(
306
+ self, question: str, results: List[Dict[str, Any]], start_time: float
 
 
 
307
  ) -> RAGResponse:
308
  """Create response when context quality is insufficient."""
309
  return RAGResponse(
310
+ answer=(
311
+ "I found some potentially relevant information, but it doesn't provide "
312
+ "enough detail to fully answer your question. Please contact HR for "
313
+ "more specific guidance or rephrase your question."
314
+ ),
315
  sources=self._format_sources(results),
316
  confidence=0.2,
317
  processing_time=time.time() - start_time,
 
319
  llm_model="none",
320
  context_length=0,
321
  search_results_count=len(results),
322
+ success=True,
323
  )
324
 
325
  def _create_llm_error_response(
326
+ self, question: str, error_message: str, start_time: float
 
 
 
327
  ) -> RAGResponse:
328
  """Create response when LLM generation fails."""
329
  return RAGResponse(
330
+ answer=(
331
+ "I apologize, but I'm currently unable to generate a response. "
332
+ "Please try again in a moment or contact support if the issue persists."
333
+ ),
334
  sources=[],
335
  confidence=0.0,
336
  processing_time=time.time() - start_time,
 
339
  context_length=0,
340
  search_results_count=0,
341
  success=False,
342
+ error_message=error_message,
343
  )
344
 
345
  def health_check(self) -> Dict[str, Any]:
346
  """
347
  Perform health check on all pipeline components.
348
+
349
  Returns:
350
  Dictionary with component health status
351
  """
352
+ health_status = {"pipeline": "healthy", "components": {}}
353
+
 
 
 
354
  try:
355
  # Check search service
356
+ test_results = self.search_service.search(
357
+ "test query", top_k=1, threshold=0.0
358
+ )
359
  health_status["components"]["search_service"] = {
360
  "status": "healthy",
361
+ "test_results_count": len(test_results),
362
  }
363
  except Exception as e:
364
  health_status["components"]["search_service"] = {
365
  "status": "unhealthy",
366
+ "error": str(e),
367
  }
368
  health_status["pipeline"] = "degraded"
369
+
370
  try:
371
  # Check LLM service
372
  llm_health = self.llm_service.health_check()
373
  health_status["components"]["llm_service"] = llm_health
374
+
375
  # Pipeline is unhealthy if all LLM providers are down
376
  healthy_providers = sum(
377
+ 1
378
+ for provider_status in llm_health.values()
379
  if provider_status.get("status") == "healthy"
380
  )
381
+
382
  if healthy_providers == 0:
383
  health_status["pipeline"] = "unhealthy"
384
+
385
  except Exception as e:
386
  health_status["components"]["llm_service"] = {
387
+ "status": "unhealthy",
388
+ "error": str(e),
389
  }
390
  health_status["pipeline"] = "unhealthy"
391
+
392
+ return health_status
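For orientation, a short sketch of driving the pipeline above once its dependencies exist; search_service and llm_service stand in for already-configured SearchService and LLMService instances, and the config overrides are arbitrary:

from typing import Tuple

from src.rag.rag_pipeline import RAGConfig, RAGPipeline


def answer_policy_question(
    search_service, llm_service, question: str
) -> Tuple[bool, str]:
    """Run a single question through the pipeline and return (success, text)."""
    # Override two of the defaults shown in RAGConfig; the rest keep their defaults.
    config = RAGConfig(search_top_k=5, search_threshold=0.2)
    pipeline = RAGPipeline(search_service, llm_service, config=config)

    result = pipeline.generate_answer(question)
    if result.success:
        return True, result.answer
    return False, result.error_message or "unknown error"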
src/rag/response_formatter.py CHANGED
@@ -6,9 +6,8 @@ formatting, metadata inclusion, and consistent response structure.
6
  """
7
 
8
  import logging
 
9
  from typing import Any, Dict, List, Optional
10
- from dataclasses import dataclass, asdict
11
- import json
12
 
13
  logger = logging.getLogger(__name__)
14
 
@@ -16,6 +15,7 @@ logger = logging.getLogger(__name__)
16
  @dataclass
17
  class FormattedResponse:
18
  """Standardized formatted response for API endpoints."""
 
19
  status: str
20
  answer: str
21
  sources: List[Dict[str, Any]]
@@ -27,7 +27,7 @@ class FormattedResponse:
27
  class ResponseFormatter:
28
  """
29
  Formats RAG pipeline responses for various output formats.
30
-
31
  Handles:
32
  - API response formatting
33
  - Citation formatting
@@ -40,23 +40,21 @@ class ResponseFormatter:
40
  logger.info("ResponseFormatter initialized")
41
 
42
  def format_api_response(
43
- self,
44
- rag_response: Any, # RAGResponse type
45
- include_debug: bool = False
46
  ) -> Dict[str, Any]:
47
  """
48
  Format RAG response for API consumption.
49
-
50
  Args:
51
  rag_response: RAGResponse from RAG pipeline
52
  include_debug: Whether to include debug information
53
-
54
  Returns:
55
  Formatted dictionary for JSON API response
56
  """
57
  if not rag_response.success:
58
  return self._format_error_response(rag_response)
59
-
60
  # Base response structure
61
  formatted_response = {
62
  "status": "success",
@@ -66,88 +64,96 @@ class ResponseFormatter:
66
  "confidence": round(rag_response.confidence, 3),
67
  "processing_time_ms": round(rag_response.processing_time * 1000, 1),
68
  "source_count": len(rag_response.sources),
69
- "context_length": rag_response.context_length
70
- }
71
  }
72
-
73
  # Add debug information if requested
74
  if include_debug:
75
  formatted_response["debug"] = {
76
  "llm_provider": rag_response.llm_provider,
77
  "llm_model": rag_response.llm_model,
78
  "search_results_count": rag_response.search_results_count,
79
- "processing_time_seconds": round(rag_response.processing_time, 3)
80
  }
81
-
82
  return formatted_response
83
 
84
  def format_chat_response(
85
  self,
86
  rag_response: Any, # RAGResponse type
87
  conversation_id: Optional[str] = None,
88
- include_sources: bool = True
89
  ) -> Dict[str, Any]:
90
  """
91
  Format RAG response for chat interface.
92
-
93
  Args:
94
  rag_response: RAGResponse from RAG pipeline
95
  conversation_id: Optional conversation ID
96
  include_sources: Whether to include source information
97
-
98
  Returns:
99
  Formatted dictionary for chat interface
100
  """
101
  if not rag_response.success:
102
  return self._format_chat_error(rag_response, conversation_id)
103
-
104
  response = {
105
  "message": rag_response.answer,
106
  "confidence": round(rag_response.confidence, 2),
107
- "processing_time_ms": round(rag_response.processing_time * 1000, 1)
108
  }
109
-
110
  if conversation_id:
111
  response["conversation_id"] = conversation_id
112
-
113
  if include_sources and rag_response.sources:
114
  response["sources"] = self._format_sources_for_chat(rag_response.sources)
115
-
116
  return response
117
 
118
- def _format_source_list(self, sources: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
 
 
119
  """Format source list for API response."""
120
  formatted_sources = []
121
-
122
  for source in sources:
123
  formatted_source = {
124
  "document": source.get("document", "unknown"),
125
  "relevance_score": round(source.get("relevance_score", 0.0), 3),
126
- "excerpt": source.get("excerpt", "")
127
  }
128
-
129
  # Add chunk ID if available
130
  chunk_id = source.get("chunk_id", "")
131
  if chunk_id:
132
  formatted_source["chunk_id"] = chunk_id
133
-
134
  formatted_sources.append(formatted_source)
135
-
136
  return formatted_sources
137
 
138
- def _format_sources_for_chat(self, sources: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
 
 
139
  """Format sources for chat interface (more concise)."""
140
  formatted_sources = []
141
-
142
  for i, source in enumerate(sources[:3], 1): # Limit to top 3 for chat
143
  formatted_source = {
144
  "id": i,
145
  "document": source.get("document", "unknown"),
146
  "relevance": f"{source.get('relevance_score', 0.0):.1%}",
147
- "preview": source.get("excerpt", "")[:100] + "..." if len(source.get("excerpt", "")) > 100 else source.get("excerpt", "")
 
 
 
 
148
  }
149
  formatted_sources.append(formatted_source)
150
-
151
  return formatted_sources
152
 
153
  def _format_error_response(self, rag_response: Any) -> Dict[str, Any]:
@@ -157,51 +163,45 @@ class ResponseFormatter:
157
  "error": {
158
  "message": rag_response.answer,
159
  "details": rag_response.error_message,
160
- "processing_time_ms": round(rag_response.processing_time * 1000, 1)
161
  },
162
  "sources": [],
163
- "metadata": {
164
- "confidence": 0.0,
165
- "source_count": 0,
166
- "context_length": 0
167
- }
168
  }
169
 
170
  def _format_chat_error(
171
- self,
172
- rag_response: Any,
173
- conversation_id: Optional[str] = None
174
  ) -> Dict[str, Any]:
175
  """Format error response for chat interface."""
176
  response = {
177
  "message": rag_response.answer,
178
  "error": True,
179
- "processing_time_ms": round(rag_response.processing_time * 1000, 1)
180
  }
181
-
182
  if conversation_id:
183
  response["conversation_id"] = conversation_id
184
-
185
  return response
186
 
187
  def validate_response_format(self, response: Dict[str, Any]) -> bool:
188
  """
189
  Validate that response follows expected format.
190
-
191
  Args:
192
  response: Formatted response dictionary
193
-
194
  Returns:
195
  True if format is valid, False otherwise
196
  """
197
  required_fields = ["status"]
198
-
199
  # Check required fields
200
  for field in required_fields:
201
  if field not in response:
202
  logger.error(f"Missing required field: {field}")
203
  return False
204
-
205
  # Check status-specific requirements
206
  if response["status"] == "success":
207
  success_fields = ["answer", "sources", "metadata"]
@@ -209,21 +209,21 @@ class ResponseFormatter:
209
  if field not in response:
210
  logger.error(f"Missing success field: {field}")
211
  return False
212
-
213
  elif response["status"] == "error":
214
  if "error" not in response:
215
  logger.error("Missing error field in error response")
216
  return False
217
-
218
  return True
219
 
220
  def create_health_response(self, health_data: Dict[str, Any]) -> Dict[str, Any]:
221
  """
222
  Format health check response.
223
-
224
  Args:
225
  health_data: Health status from RAG pipeline
226
-
227
  Returns:
228
  Formatted health response
229
  """
@@ -232,51 +232,65 @@ class ResponseFormatter:
232
  "health": {
233
  "pipeline_status": health_data.get("pipeline", "unknown"),
234
  "components": health_data.get("components", {}),
235
- "timestamp": self._get_timestamp()
236
- }
237
  }
238
 
239
- def create_no_answer_response(self, question: str, reason: str = "no_context") -> Dict[str, Any]:
 
 
240
  """
241
  Create standardized response when no answer can be provided.
242
-
243
  Args:
244
  question: Original user question
245
  reason: Reason for no answer (no_context, insufficient_context, etc.)
246
-
247
  Returns:
248
  Formatted no-answer response
249
  """
250
  messages = {
251
- "no_context": "I couldn't find any relevant information in our corporate policies to answer your question.",
252
- "insufficient_context": "I found some potentially relevant information, but not enough to provide a complete answer.",
253
- "off_topic": "This question appears to be outside the scope of our corporate policies.",
254
- "error": "I encountered an error while processing your question."
 
 
 
 
 
 
 
 
 
255
  }
256
-
257
  message = messages.get(reason, messages["error"])
258
-
259
  return {
260
  "status": "no_answer",
261
  "message": message,
262
  "reason": reason,
263
- "suggestion": "Please contact HR or rephrase your question for better results.",
264
- "sources": []
 
 
265
  }
266
 
267
  def _get_timestamp(self) -> str:
268
  """Get current timestamp in ISO format."""
269
  from datetime import datetime
 
270
  return datetime.utcnow().isoformat() + "Z"
271
 
272
  def format_for_logging(self, rag_response: Any, question: str) -> Dict[str, Any]:
273
  """
274
  Format response data for logging purposes.
275
-
276
  Args:
277
  rag_response: RAGResponse from pipeline
278
  question: Original question
279
-
280
  Returns:
281
  Formatted data for logging
282
  """
@@ -291,5 +305,5 @@ class ResponseFormatter:
291
  "source_count": len(rag_response.sources),
292
  "context_length": rag_response.context_length,
293
  "answer_length": len(rag_response.answer),
294
- "error": rag_response.error_message
295
- }
 
6
  """
7
 
8
  import logging
9
+ from dataclasses import dataclass
10
  from typing import Any, Dict, List, Optional
 
 
11
 
12
  logger = logging.getLogger(__name__)
13
 
 
15
  @dataclass
16
  class FormattedResponse:
17
  """Standardized formatted response for API endpoints."""
18
+
19
  status: str
20
  answer: str
21
  sources: List[Dict[str, Any]]
 
27
  class ResponseFormatter:
28
  """
29
  Formats RAG pipeline responses for various output formats.
30
+
31
  Handles:
32
  - API response formatting
33
  - Citation formatting
 
40
  logger.info("ResponseFormatter initialized")
41
 
42
  def format_api_response(
43
+ self, rag_response: Any, include_debug: bool = False # RAGResponse type
 
 
44
  ) -> Dict[str, Any]:
45
  """
46
  Format RAG response for API consumption.
47
+
48
  Args:
49
  rag_response: RAGResponse from RAG pipeline
50
  include_debug: Whether to include debug information
51
+
52
  Returns:
53
  Formatted dictionary for JSON API response
54
  """
55
  if not rag_response.success:
56
  return self._format_error_response(rag_response)
57
+
58
  # Base response structure
59
  formatted_response = {
60
  "status": "success",
 
64
  "confidence": round(rag_response.confidence, 3),
65
  "processing_time_ms": round(rag_response.processing_time * 1000, 1),
66
  "source_count": len(rag_response.sources),
67
+ "context_length": rag_response.context_length,
68
+ },
69
  }
70
+
71
  # Add debug information if requested
72
  if include_debug:
73
  formatted_response["debug"] = {
74
  "llm_provider": rag_response.llm_provider,
75
  "llm_model": rag_response.llm_model,
76
  "search_results_count": rag_response.search_results_count,
77
+ "processing_time_seconds": round(rag_response.processing_time, 3),
78
  }
79
+
80
  return formatted_response
81
 
82
  def format_chat_response(
83
  self,
84
  rag_response: Any, # RAGResponse type
85
  conversation_id: Optional[str] = None,
86
+ include_sources: bool = True,
87
  ) -> Dict[str, Any]:
88
  """
89
  Format RAG response for chat interface.
90
+
91
  Args:
92
  rag_response: RAGResponse from RAG pipeline
93
  conversation_id: Optional conversation ID
94
  include_sources: Whether to include source information
95
+
96
  Returns:
97
  Formatted dictionary for chat interface
98
  """
99
  if not rag_response.success:
100
  return self._format_chat_error(rag_response, conversation_id)
101
+
102
  response = {
103
  "message": rag_response.answer,
104
  "confidence": round(rag_response.confidence, 2),
105
+ "processing_time_ms": round(rag_response.processing_time * 1000, 1),
106
  }
107
+
108
  if conversation_id:
109
  response["conversation_id"] = conversation_id
110
+
111
  if include_sources and rag_response.sources:
112
  response["sources"] = self._format_sources_for_chat(rag_response.sources)
113
+
114
  return response
115
 
116
+ def _format_source_list(
117
+ self, sources: List[Dict[str, Any]]
118
+ ) -> List[Dict[str, Any]]:
119
  """Format source list for API response."""
120
  formatted_sources = []
121
+
122
  for source in sources:
123
  formatted_source = {
124
  "document": source.get("document", "unknown"),
125
  "relevance_score": round(source.get("relevance_score", 0.0), 3),
126
+ "excerpt": source.get("excerpt", ""),
127
  }
128
+
129
  # Add chunk ID if available
130
  chunk_id = source.get("chunk_id", "")
131
  if chunk_id:
132
  formatted_source["chunk_id"] = chunk_id
133
+
134
  formatted_sources.append(formatted_source)
135
+
136
  return formatted_sources
137
 
138
+ def _format_sources_for_chat(
139
+ self, sources: List[Dict[str, Any]]
140
+ ) -> List[Dict[str, Any]]:
141
  """Format sources for chat interface (more concise)."""
142
  formatted_sources = []
143
+
144
  for i, source in enumerate(sources[:3], 1): # Limit to top 3 for chat
145
  formatted_source = {
146
  "id": i,
147
  "document": source.get("document", "unknown"),
148
  "relevance": f"{source.get('relevance_score', 0.0):.1%}",
149
+ "preview": (
150
+ source.get("excerpt", "")[:100] + "..."
151
+ if len(source.get("excerpt", "")) > 100
152
+ else source.get("excerpt", "")
153
+ ),
154
  }
155
  formatted_sources.append(formatted_source)
156
+
157
  return formatted_sources
158
 
159
  def _format_error_response(self, rag_response: Any) -> Dict[str, Any]:
 
163
  "error": {
164
  "message": rag_response.answer,
165
  "details": rag_response.error_message,
166
+ "processing_time_ms": round(rag_response.processing_time * 1000, 1),
167
  },
168
  "sources": [],
169
+ "metadata": {"confidence": 0.0, "source_count": 0, "context_length": 0},
 
 
 
 
170
  }
171
 
172
  def _format_chat_error(
173
+ self, rag_response: Any, conversation_id: Optional[str] = None
 
 
174
  ) -> Dict[str, Any]:
175
  """Format error response for chat interface."""
176
  response = {
177
  "message": rag_response.answer,
178
  "error": True,
179
+ "processing_time_ms": round(rag_response.processing_time * 1000, 1),
180
  }
181
+
182
  if conversation_id:
183
  response["conversation_id"] = conversation_id
184
+
185
  return response
186
 
187
  def validate_response_format(self, response: Dict[str, Any]) -> bool:
188
  """
189
  Validate that response follows expected format.
190
+
191
  Args:
192
  response: Formatted response dictionary
193
+
194
  Returns:
195
  True if format is valid, False otherwise
196
  """
197
  required_fields = ["status"]
198
+
199
  # Check required fields
200
  for field in required_fields:
201
  if field not in response:
202
  logger.error(f"Missing required field: {field}")
203
  return False
204
+
205
  # Check status-specific requirements
206
  if response["status"] == "success":
207
  success_fields = ["answer", "sources", "metadata"]
 
209
  if field not in response:
210
  logger.error(f"Missing success field: {field}")
211
  return False
212
+
213
  elif response["status"] == "error":
214
  if "error" not in response:
215
  logger.error("Missing error field in error response")
216
  return False
217
+
218
  return True
219
 
220
  def create_health_response(self, health_data: Dict[str, Any]) -> Dict[str, Any]:
221
  """
222
  Format health check response.
223
+
224
  Args:
225
  health_data: Health status from RAG pipeline
226
+
227
  Returns:
228
  Formatted health response
229
  """
 
232
  "health": {
233
  "pipeline_status": health_data.get("pipeline", "unknown"),
234
  "components": health_data.get("components", {}),
235
+ "timestamp": self._get_timestamp(),
236
+ },
237
  }
238
 
239
+ def create_no_answer_response(
240
+ self, question: str, reason: str = "no_context"
241
+ ) -> Dict[str, Any]:
242
  """
243
  Create standardized response when no answer can be provided.
244
+
245
  Args:
246
  question: Original user question
247
  reason: Reason for no answer (no_context, insufficient_context, etc.)
248
+
249
  Returns:
250
  Formatted no-answer response
251
  """
252
  messages = {
253
+ "no_context": (
254
+ "I couldn't find any relevant information in our corporate "
255
+ "policies to answer your question."
256
+ ),
257
+ "insufficient_context": (
258
+ "I found some potentially relevant information, but not "
259
+ "enough to provide a complete answer."
260
+ ),
261
+ "off_topic": (
262
+ "This question appears to be outside the scope of our "
263
+ "corporate policies."
264
+ ),
265
+ "error": "I encountered an error while processing your question.",
266
  }
267
+
268
  message = messages.get(reason, messages["error"])
269
+
270
  return {
271
  "status": "no_answer",
272
  "message": message,
273
  "reason": reason,
274
+ "suggestion": (
275
+ "Please contact HR or rephrase your question for better results."
276
+ ),
277
+ "sources": [],
278
  }
279
 
280
  def _get_timestamp(self) -> str:
281
  """Get current timestamp in ISO format."""
282
  from datetime import datetime
283
+
284
  return datetime.utcnow().isoformat() + "Z"
285
 
286
  def format_for_logging(self, rag_response: Any, question: str) -> Dict[str, Any]:
287
  """
288
  Format response data for logging purposes.
289
+
290
  Args:
291
  rag_response: RAGResponse from pipeline
292
  question: Original question
293
+
294
  Returns:
295
  Formatted data for logging
296
  """
 
305
  "source_count": len(rag_response.sources),
306
  "context_length": rag_response.context_length,
307
  "answer_length": len(rag_response.answer),
308
+ "error": rag_response.error_message,
309
+ }
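And a brief, hedged sketch of turning a pipeline result into an API payload with the formatter above; rag_response stands for the RAGResponse returned by RAGPipeline.generate_answer, and the fallback branch is illustrative rather than part of the shipped endpoint:

from src.rag.response_formatter import ResponseFormatter

formatter = ResponseFormatter()


def to_api_payload(rag_response, question: str, debug: bool = False) -> dict:
    """Build the JSON-ready dict for a /chat-style response."""
    payload = formatter.format_api_response(rag_response, include_debug=debug)
    if formatter.validate_response_format(payload):
        return payload
    # Shape check failed; fall back to the standardized no-answer body.
    return formatter.create_no_answer_response(question, reason="error")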
tests/test_chat_endpoint.py CHANGED
@@ -1,7 +1,8 @@
1
  import json
2
  import os
 
 
3
  import pytest
4
- from unittest.mock import patch, MagicMock
5
 
6
  from app import app as flask_app
7
 
@@ -19,100 +20,122 @@ def client(app):
19
  class TestChatEndpoint:
20
  """Test cases for the /chat endpoint"""
21
 
22
- @patch.dict(os.environ, {'OPENROUTER_API_KEY': 'test_key'})
23
- @patch('app.RAGPipeline')
24
- @patch('app.ResponseFormatter')
25
- @patch('app.LLMService')
26
- @patch('app.SearchService')
27
- @patch('app.VectorDatabase')
28
- @patch('app.EmbeddingService')
29
- def test_chat_endpoint_valid_request(self, mock_embedding, mock_vector, mock_search, mock_llm, mock_formatter, mock_rag, client):
 
 
 
 
 
 
 
 
 
30
  """Test chat endpoint with valid request"""
31
  # Mock the RAG pipeline response
32
  mock_response = {
33
- 'answer': 'Based on the remote work policy, employees can work remotely up to 3 days per week.',
34
- 'confidence': 0.85,
35
- 'sources': [{'chunk_id': '123', 'content': 'Remote work policy content...'}],
36
- 'citations': ['remote_work_policy.md'],
37
- 'processing_time_ms': 1500
 
 
 
 
 
38
  }
39
-
40
  # Setup mock instances
41
  mock_rag_instance = MagicMock()
42
  mock_rag_instance.generate_answer.return_value = mock_response
43
  mock_rag.return_value = mock_rag_instance
44
-
45
  mock_formatter_instance = MagicMock()
46
  mock_formatter_instance.format_api_response.return_value = {
47
  "status": "success",
48
- "answer": mock_response['answer'],
49
- "confidence": mock_response['confidence'],
50
- "sources": mock_response['sources'],
51
- "citations": mock_response['citations']
52
  }
53
  mock_formatter.return_value = mock_formatter_instance
54
-
55
  # Mock LLMService.from_environment to return a mock instance
56
  mock_llm_instance = MagicMock()
57
  mock_llm.from_environment.return_value = mock_llm_instance
58
 
59
  request_data = {
60
  "message": "What is the remote work policy?",
61
- "include_sources": True
62
  }
63
 
64
  response = client.post(
65
- "/chat",
66
- data=json.dumps(request_data),
67
- content_type="application/json"
68
  )
69
 
70
  assert response.status_code == 200
71
  data = response.get_json()
72
-
73
  assert data["status"] == "success"
74
  assert "answer" in data
75
  assert "confidence" in data
76
  assert "sources" in data
77
  assert "citations" in data
78
 
79
- @patch.dict(os.environ, {'OPENROUTER_API_KEY': 'test_key'})
80
- @patch('app.RAGPipeline')
81
- @patch('app.ResponseFormatter')
82
- @patch('app.LLMService')
83
- @patch('app.SearchService')
84
- @patch('app.VectorDatabase')
85
- @patch('app.EmbeddingService')
86
- def test_chat_endpoint_minimal_request(self, mock_embedding, mock_vector, mock_search, mock_llm, mock_formatter, mock_rag, client):
 
 
 
 
 
 
 
 
 
87
  """Test chat endpoint with minimal request (only message)"""
88
  mock_response = {
89
- 'answer': 'Employee benefits include health insurance, retirement plans, and PTO.',
90
- 'confidence': 0.78,
91
- 'sources': [],
92
- 'citations': ['employee_benefits_guide.md'],
93
- 'processing_time_ms': 1200
 
 
 
94
  }
95
-
96
  # Setup mock instances
97
  mock_rag_instance = MagicMock()
98
  mock_rag_instance.generate_answer.return_value = mock_response
99
  mock_rag.return_value = mock_rag_instance
100
-
101
  mock_formatter_instance = MagicMock()
102
  mock_formatter_instance.format_api_response.return_value = {
103
  "status": "success",
104
- "answer": mock_response['answer']
105
  }
106
  mock_formatter.return_value = mock_formatter_instance
107
-
108
  mock_llm.from_environment.return_value = MagicMock()
109
 
110
  request_data = {"message": "What are the employee benefits?"}
111
 
112
  response = client.post(
113
- "/chat",
114
- data=json.dumps(request_data),
115
- content_type="application/json"
116
  )
117
 
118
  assert response.status_code == 200
@@ -124,9 +147,7 @@ class TestChatEndpoint:
124
  request_data = {"include_sources": True}
125
 
126
  response = client.post(
127
- "/chat",
128
- data=json.dumps(request_data),
129
- content_type="application/json"
130
  )
131
 
132
  assert response.status_code == 400
@@ -139,9 +160,7 @@ class TestChatEndpoint:
139
  request_data = {"message": ""}
140
 
141
  response = client.post(
142
- "/chat",
143
- data=json.dumps(request_data),
144
- content_type="application/json"
145
  )
146
 
147
  assert response.status_code == 400
@@ -154,9 +173,7 @@ class TestChatEndpoint:
154
  request_data = {"message": 123}
155
 
156
  response = client.post(
157
- "/chat",
158
- data=json.dumps(request_data),
159
- content_type="application/json"
160
  )
161
 
162
  assert response.status_code == 400
@@ -179,9 +196,7 @@ class TestChatEndpoint:
179
  request_data = {"message": "What is the policy?"}
180
 
181
  response = client.post(
182
- "/chat",
183
- data=json.dumps(request_data),
184
- content_type="application/json"
185
  )
186
 
187
  assert response.status_code == 503
@@ -189,67 +204,110 @@ class TestChatEndpoint:
189
  assert data["status"] == "error"
190
  assert "LLM service configuration error" in data["message"]
191
 
192
- @patch.dict(os.environ, {'OPENROUTER_API_KEY': 'test_key'})
193
- @patch('app.RAGPipeline')
194
- @patch('app.ResponseFormatter')
195
- @patch('app.LLMService')
196
- @patch('app.SearchService')
197
- @patch('app.VectorDatabase')
198
- @patch('app.EmbeddingService')
199
- def test_chat_endpoint_with_conversation_id(self, mock_embedding, mock_vector, mock_search, mock_llm, mock_formatter, mock_rag, client):
 
 
 
 
 
 
 
 
 
200
  """Test chat endpoint with conversation_id parameter"""
201
  mock_response = {
202
- 'answer': 'The PTO policy allows 15 days of vacation annually.',
203
- 'confidence': 0.9,
204
- 'sources': [],
205
- 'citations': ['pto_policy.md'],
206
- 'processing_time_ms': 1100
207
  }
208
- mock_generate.return_value = mock_response
209
- mock_llm_service.return_value = MagicMock()
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
  request_data = {
212
  "message": "What is the PTO policy?",
213
  "conversation_id": "conv_123",
214
- "include_sources": False
215
  }
216
 
217
  response = client.post(
218
- "/chat",
219
- data=json.dumps(request_data),
220
- content_type="application/json"
221
  )
222
 
223
  assert response.status_code == 200
224
  data = response.get_json()
225
  assert data["status"] == "success"
226
 
227
- @patch.dict(os.environ, {'OPENROUTER_API_KEY': 'test_key'})
228
- @patch('src.llm.llm_service.LLMService.from_environment')
229
- @patch('src.rag.rag_pipeline.RAGPipeline.generate_answer')
230
- def test_chat_endpoint_with_debug(self, mock_generate, mock_llm_service, client):
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  """Test chat endpoint with debug information"""
232
  mock_response = {
233
- 'answer': 'The security policy requires 2FA authentication.',
234
- 'confidence': 0.95,
235
- 'sources': [{'chunk_id': '456', 'content': 'Security requirements...'}],
236
- 'citations': ['information_security_policy.md'],
237
- 'processing_time_ms': 1800,
238
- 'search_results_count': 5,
239
- 'context_length': 2048
240
  }
241
- mock_generate.return_value = mock_response
242
- mock_llm_service.return_value = MagicMock()
 
 
 
 
 
 
 
 
 
 
 
 
 
243
 
244
  request_data = {
245
  "message": "What are the security requirements?",
246
- "include_debug": True
247
  }
248
 
249
  response = client.post(
250
- "/chat",
251
- data=json.dumps(request_data),
252
- content_type="application/json"
253
  )
254
 
255
  assert response.status_code == 200
@@ -260,9 +318,9 @@ class TestChatEndpoint:
260
  class TestChatHealthEndpoint:
261
  """Test cases for the /chat/health endpoint"""
262
 
263
- @patch.dict(os.environ, {'OPENROUTER_API_KEY': 'test_key'})
264
- @patch('src.llm.llm_service.LLMService.from_environment')
265
- @patch('src.rag.rag_pipeline.RAGPipeline.health_check')
266
  def test_chat_health_healthy(self, mock_health_check, mock_llm_service, client):
267
  """Test chat health endpoint when all services are healthy"""
268
  mock_health_data = {
@@ -270,8 +328,8 @@ class TestChatHealthEndpoint:
270
  "components": {
271
  "search_service": {"status": "healthy"},
272
  "llm_service": {"status": "healthy"},
273
- "vector_db": {"status": "healthy"}
274
- }
275
  }
276
  mock_health_check.return_value = mock_health_data
277
  mock_llm_service.return_value = MagicMock()
@@ -282,9 +340,9 @@ class TestChatHealthEndpoint:
282
  data = response.get_json()
283
  assert data["status"] == "success"
284
 
285
- @patch.dict(os.environ, {'OPENROUTER_API_KEY': 'test_key'})
286
- @patch('src.llm.llm_service.LLMService.from_environment')
287
- @patch('src.rag.rag_pipeline.RAGPipeline.health_check')
288
  def test_chat_health_degraded(self, mock_health_check, mock_llm_service, client):
289
  """Test chat health endpoint when services are degraded"""
290
  mock_health_data = {
@@ -292,8 +350,8 @@ class TestChatHealthEndpoint:
292
  "components": {
293
  "search_service": {"status": "healthy"},
294
  "llm_service": {"status": "degraded", "warning": "High latency"},
295
- "vector_db": {"status": "healthy"}
296
- }
297
  }
298
  mock_health_check.return_value = mock_health_data
299
  mock_llm_service.return_value = MagicMock()
@@ -314,18 +372,21 @@ class TestChatHealthEndpoint:
314
  assert data["status"] == "error"
315
  assert "LLM configuration error" in data["message"]
316
 
317
- @patch.dict(os.environ, {'OPENROUTER_API_KEY': 'test_key'})
318
- @patch('src.llm.llm_service.LLMService.from_environment')
319
- @patch('src.rag.rag_pipeline.RAGPipeline.health_check')
320
  def test_chat_health_unhealthy(self, mock_health_check, mock_llm_service, client):
321
  """Test chat health endpoint when services are unhealthy"""
322
  mock_health_data = {
323
  "pipeline": "unhealthy",
324
  "components": {
325
- "search_service": {"status": "unhealthy", "error": "Database connection failed"},
 
 
 
326
  "llm_service": {"status": "unhealthy", "error": "API unreachable"},
327
- "vector_db": {"status": "unhealthy"}
328
- }
329
  }
330
  mock_health_check.return_value = mock_health_data
331
  mock_llm_service.return_value = MagicMock()
@@ -334,4 +395,4 @@ class TestChatHealthEndpoint:
334
 
335
  assert response.status_code == 503
336
  data = response.get_json()
337
- assert data["status"] == "success" # Still returns success, but 503 status code
 
1
  import json
2
  import os
3
+ from unittest.mock import MagicMock, patch
4
+
5
  import pytest
 
6
 
7
  from app import app as flask_app
8
 
 
20
  class TestChatEndpoint:
21
  """Test cases for the /chat endpoint"""
22
 
23
+ @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
24
+ @patch("app.RAGPipeline")
25
+ @patch("app.ResponseFormatter")
26
+ @patch("app.LLMService")
27
+ @patch("app.SearchService")
28
+ @patch("app.VectorDatabase")
29
+ @patch("app.EmbeddingService")
30
+ def test_chat_endpoint_valid_request(
31
+ self,
32
+ mock_embedding,
33
+ mock_vector,
34
+ mock_search,
35
+ mock_llm,
36
+ mock_formatter,
37
+ mock_rag,
38
+ client,
39
+ ):
40
  """Test chat endpoint with valid request"""
41
  # Mock the RAG pipeline response
42
  mock_response = {
43
+ "answer": (
44
+ "Based on the remote work policy, employees can work "
45
+ "remotely up to 3 days per week."
46
+ ),
47
+ "confidence": 0.85,
48
+ "sources": [
49
+ {"chunk_id": "123", "content": "Remote work policy content..."}
50
+ ],
51
+ "citations": ["remote_work_policy.md"],
52
+ "processing_time_ms": 1500,
53
  }
54
+
55
  # Setup mock instances
56
  mock_rag_instance = MagicMock()
57
  mock_rag_instance.generate_answer.return_value = mock_response
58
  mock_rag.return_value = mock_rag_instance
59
+
60
  mock_formatter_instance = MagicMock()
61
  mock_formatter_instance.format_api_response.return_value = {
62
  "status": "success",
63
+ "answer": mock_response["answer"],
64
+ "confidence": mock_response["confidence"],
65
+ "sources": mock_response["sources"],
66
+ "citations": mock_response["citations"],
67
  }
68
  mock_formatter.return_value = mock_formatter_instance
69
+
70
  # Mock LLMService.from_environment to return a mock instance
71
  mock_llm_instance = MagicMock()
72
  mock_llm.from_environment.return_value = mock_llm_instance
73
 
74
  request_data = {
75
  "message": "What is the remote work policy?",
76
+ "include_sources": True,
77
  }
78
 
79
  response = client.post(
80
+ "/chat", data=json.dumps(request_data), content_type="application/json"
 
 
81
  )
82
 
83
  assert response.status_code == 200
84
  data = response.get_json()
85
+
86
  assert data["status"] == "success"
87
  assert "answer" in data
88
  assert "confidence" in data
89
  assert "sources" in data
90
  assert "citations" in data
91
 
92
+ @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
93
+ @patch("app.RAGPipeline")
94
+ @patch("app.ResponseFormatter")
95
+ @patch("app.LLMService")
96
+ @patch("app.SearchService")
97
+ @patch("app.VectorDatabase")
98
+ @patch("app.EmbeddingService")
99
+ def test_chat_endpoint_minimal_request(
100
+ self,
101
+ mock_embedding,
102
+ mock_vector,
103
+ mock_search,
104
+ mock_llm,
105
+ mock_formatter,
106
+ mock_rag,
107
+ client,
108
+ ):
109
  """Test chat endpoint with minimal request (only message)"""
110
  mock_response = {
111
+ "answer": (
112
+ "Employee benefits include health insurance, "
113
+ "retirement plans, and PTO."
114
+ ),
115
+ "confidence": 0.78,
116
+ "sources": [],
117
+ "citations": ["employee_benefits_guide.md"],
118
+ "processing_time_ms": 1200,
119
  }
120
+
121
  # Setup mock instances
122
  mock_rag_instance = MagicMock()
123
  mock_rag_instance.generate_answer.return_value = mock_response
124
  mock_rag.return_value = mock_rag_instance
125
+
126
  mock_formatter_instance = MagicMock()
127
  mock_formatter_instance.format_api_response.return_value = {
128
  "status": "success",
129
+ "answer": mock_response["answer"],
130
  }
131
  mock_formatter.return_value = mock_formatter_instance
132
+
133
  mock_llm.from_environment.return_value = MagicMock()
134
 
135
  request_data = {"message": "What are the employee benefits?"}
136
 
137
  response = client.post(
138
+ "/chat", data=json.dumps(request_data), content_type="application/json"
 
 
139
  )
140
 
141
  assert response.status_code == 200
 
147
  request_data = {"include_sources": True}
148
 
149
  response = client.post(
150
+ "/chat", data=json.dumps(request_data), content_type="application/json"
 
 
151
  )
152
 
153
  assert response.status_code == 400
 
160
  request_data = {"message": ""}
161
 
162
  response = client.post(
163
+ "/chat", data=json.dumps(request_data), content_type="application/json"
 
 
164
  )
165
 
166
  assert response.status_code == 400
 
173
  request_data = {"message": 123}
174
 
175
  response = client.post(
176
+ "/chat", data=json.dumps(request_data), content_type="application/json"
 
 
177
  )
178
 
179
  assert response.status_code == 400
 
196
  request_data = {"message": "What is the policy?"}
197
 
198
  response = client.post(
199
+ "/chat", data=json.dumps(request_data), content_type="application/json"
 
 
200
  )
201
 
202
  assert response.status_code == 503
 
204
  assert data["status"] == "error"
205
  assert "LLM service configuration error" in data["message"]
206
 
207
+ @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
208
+ @patch("app.RAGPipeline")
209
+ @patch("app.ResponseFormatter")
210
+ @patch("app.LLMService")
211
+ @patch("app.SearchService")
212
+ @patch("app.VectorDatabase")
213
+ @patch("app.EmbeddingService")
214
+ def test_chat_endpoint_with_conversation_id(
215
+ self,
216
+ mock_embedding,
217
+ mock_vector,
218
+ mock_search,
219
+ mock_llm,
220
+ mock_formatter,
221
+ mock_rag,
222
+ client,
223
+ ):
224
  """Test chat endpoint with conversation_id parameter"""
225
  mock_response = {
226
+ "answer": "The PTO policy allows 15 days of vacation annually.",
227
+ "confidence": 0.9,
228
+ "sources": [],
229
+ "citations": ["pto_policy.md"],
230
+ "processing_time_ms": 1100,
231
  }
232
+
233
+ # Setup mock instances
234
+ mock_rag_instance = MagicMock()
235
+ mock_rag_instance.generate_answer.return_value = mock_response
236
+ mock_rag.return_value = mock_rag_instance
237
+
238
+ mock_formatter_instance = MagicMock()
239
+ mock_formatter_instance.format_chat_response.return_value = {
240
+ "status": "success",
241
+ "answer": mock_response["answer"],
242
+ }
243
+ mock_formatter.return_value = mock_formatter_instance
244
+
245
+ mock_llm.from_environment.return_value = MagicMock()
246
 
247
  request_data = {
248
  "message": "What is the PTO policy?",
249
  "conversation_id": "conv_123",
250
+ "include_sources": False,
251
  }
252
 
253
  response = client.post(
254
+ "/chat", data=json.dumps(request_data), content_type="application/json"
 
 
255
  )
256
 
257
  assert response.status_code == 200
258
  data = response.get_json()
259
  assert data["status"] == "success"
260
 
261
+ @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
262
+ @patch("app.RAGPipeline")
263
+ @patch("app.ResponseFormatter")
264
+ @patch("app.LLMService")
265
+ @patch("app.SearchService")
266
+ @patch("app.VectorDatabase")
267
+ @patch("app.EmbeddingService")
268
+ def test_chat_endpoint_with_debug(
269
+ self,
270
+ mock_embedding,
271
+ mock_vector,
272
+ mock_search,
273
+ mock_llm,
274
+ mock_formatter,
275
+ mock_rag,
276
+ client,
277
+ ):
278
  """Test chat endpoint with debug information"""
279
  mock_response = {
280
+ "answer": "The security policy requires 2FA authentication.",
281
+ "confidence": 0.95,
282
+ "sources": [{"chunk_id": "456", "content": "Security requirements..."}],
283
+ "citations": ["information_security_policy.md"],
284
+ "processing_time_ms": 1800,
285
+ "search_results_count": 5,
286
+ "context_length": 2048,
287
  }
288
+
289
+ # Setup mock instances
290
+ mock_rag_instance = MagicMock()
291
+ mock_rag_instance.generate_answer.return_value = mock_response
292
+ mock_rag.return_value = mock_rag_instance
293
+
294
+ mock_formatter_instance = MagicMock()
295
+ mock_formatter_instance.format_api_response.return_value = {
296
+ "status": "success",
297
+ "answer": mock_response["answer"],
298
+ "debug": {"processing_time": 1800},
299
+ }
300
+ mock_formatter.return_value = mock_formatter_instance
301
+
302
+ mock_llm.from_environment.return_value = MagicMock()
303
 
304
  request_data = {
305
  "message": "What are the security requirements?",
306
+ "include_debug": True,
307
  }
308
 
309
  response = client.post(
310
+ "/chat", data=json.dumps(request_data), content_type="application/json"
311
  )
312
 
313
  assert response.status_code == 200
 
318
  class TestChatHealthEndpoint:
319
  """Test cases for the /chat/health endpoint"""
320
 
321
+ @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
322
+ @patch("src.llm.llm_service.LLMService.from_environment")
323
+ @patch("src.rag.rag_pipeline.RAGPipeline.health_check")
324
  def test_chat_health_healthy(self, mock_health_check, mock_llm_service, client):
325
  """Test chat health endpoint when all services are healthy"""
326
  mock_health_data = {
 
328
  "components": {
329
  "search_service": {"status": "healthy"},
330
  "llm_service": {"status": "healthy"},
331
+ "vector_db": {"status": "healthy"},
332
+ },
333
  }
334
  mock_health_check.return_value = mock_health_data
335
  mock_llm_service.return_value = MagicMock()
 
340
  data = response.get_json()
341
  assert data["status"] == "success"
342
 
343
+ @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
344
+ @patch("src.llm.llm_service.LLMService.from_environment")
345
+ @patch("src.rag.rag_pipeline.RAGPipeline.health_check")
346
  def test_chat_health_degraded(self, mock_health_check, mock_llm_service, client):
347
  """Test chat health endpoint when services are degraded"""
348
  mock_health_data = {
 
350
  "components": {
351
  "search_service": {"status": "healthy"},
352
  "llm_service": {"status": "degraded", "warning": "High latency"},
353
+ "vector_db": {"status": "healthy"},
354
+ },
355
  }
356
  mock_health_check.return_value = mock_health_data
357
  mock_llm_service.return_value = MagicMock()
 
372
  assert data["status"] == "error"
373
  assert "LLM configuration error" in data["message"]
374
 
375
+ @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
376
+ @patch("src.llm.llm_service.LLMService.from_environment")
377
+ @patch("src.rag.rag_pipeline.RAGPipeline.health_check")
378
  def test_chat_health_unhealthy(self, mock_health_check, mock_llm_service, client):
379
  """Test chat health endpoint when services are unhealthy"""
380
  mock_health_data = {
381
  "pipeline": "unhealthy",
382
  "components": {
383
+ "search_service": {
384
+ "status": "unhealthy",
385
+ "error": "Database connection failed",
386
+ },
387
  "llm_service": {"status": "unhealthy", "error": "API unreachable"},
388
+ "vector_db": {"status": "unhealthy"},
389
+ },
390
  }
391
  mock_health_check.return_value = mock_health_data
392
  mock_llm_service.return_value = MagicMock()
 
395
 
396
  assert response.status_code == 503
397
  data = response.get_json()
398
+ assert data["status"] == "success" # Still returns success, but 503 status code
tests/test_llm/__init__.py CHANGED
@@ -1 +1 @@
1
- # LLM Service Tests
 
1
+ # LLM Service Tests
tests/test_llm/test_llm_service.py CHANGED
@@ -4,10 +4,12 @@ Test LLM Service
4
  Tests for LLM integration and service functionality.
5
  """
6
 
7
  import pytest
8
- from unittest.mock import Mock, patch, MagicMock
9
  import requests
10
- from src.llm.llm_service import LLMService, LLMConfig, LLMResponse
11
 
12
 
13
  class TestLLMConfig:
@@ -19,9 +21,9 @@ class TestLLMConfig:
19
  provider="openrouter",
20
  api_key="test-key",
21
  model_name="test-model",
22
- base_url="https://test.com"
23
  )
24
-
25
  assert config.provider == "openrouter"
26
  assert config.api_key == "test-key"
27
  assert config.model_name == "test-model"
@@ -41,9 +43,9 @@ class TestLLMResponse:
41
  model="test-model",
42
  usage={"tokens": 100},
43
  response_time=1.5,
44
- success=True
45
  )
46
-
47
  assert response.content == "Test response"
48
  assert response.provider == "openrouter"
49
  assert response.model == "test-model"
@@ -62,84 +64,83 @@ class TestLLMService:
62
  provider="openrouter",
63
  api_key="test-key",
64
  model_name="test-model",
65
- base_url="https://test.com"
66
  )
67
-
68
  service = LLMService([config])
69
-
70
  assert len(service.configs) == 1
71
  assert service.configs[0] == config
72
  assert service.current_config_index == 0
73
 
74
  def test_initialization_empty_configs_raises_error(self):
75
  """Test that empty configs raise ValueError."""
76
- with pytest.raises(ValueError, match="At least one LLM configuration must be provided"):
77
  LLMService([])
78
 
79
- @patch.dict('os.environ', {'OPENROUTER_API_KEY': 'test-openrouter-key'})
80
  def test_from_environment_with_openrouter_key(self):
81
  """Test creating service from environment with OpenRouter key."""
82
  service = LLMService.from_environment()
83
-
84
  assert len(service.configs) >= 1
85
  openrouter_config = next(
86
  (config for config in service.configs if config.provider == "openrouter"),
87
- None
88
  )
89
  assert openrouter_config is not None
90
  assert openrouter_config.api_key == "test-openrouter-key"
91
 
92
- @patch.dict('os.environ', {'GROQ_API_KEY': 'test-groq-key'})
93
  def test_from_environment_with_groq_key(self):
94
  """Test creating service from environment with Groq key."""
95
  service = LLMService.from_environment()
96
-
97
  assert len(service.configs) >= 1
98
  groq_config = next(
99
- (config for config in service.configs if config.provider == "groq"),
100
- None
101
  )
102
  assert groq_config is not None
103
  assert groq_config.api_key == "test-groq-key"
104
 
105
- @patch.dict('os.environ', {}, clear=True)
106
  def test_from_environment_no_keys_raises_error(self):
107
  """Test that no environment keys raise ValueError."""
108
  with pytest.raises(ValueError, match="No LLM API keys found in environment"):
109
  LLMService.from_environment()
110
 
111
- @patch('requests.post')
112
  def test_successful_response_generation(self, mock_post):
113
  """Test successful response generation."""
114
  # Mock successful API response
115
  mock_response = Mock()
116
  mock_response.status_code = 200
117
  mock_response.json.return_value = {
118
- "choices": [
119
- {"message": {"content": "Test response content"}}
120
- ],
121
- "usage": {"prompt_tokens": 50, "completion_tokens": 20}
122
  }
123
  mock_response.raise_for_status = Mock()
124
  mock_post.return_value = mock_response
125
-
126
  config = LLMConfig(
127
  provider="openrouter",
128
  api_key="test-key",
129
  model_name="test-model",
130
- base_url="https://api.openrouter.ai/api/v1"
131
  )
132
  service = LLMService([config])
133
-
134
  result = service.generate_response("Test prompt")
135
-
136
  assert result.success is True
137
  assert result.content == "Test response content"
138
  assert result.provider == "openrouter"
139
  assert result.model == "test-model"
140
  assert result.usage == {"prompt_tokens": 50, "completion_tokens": 20}
141
  assert result.response_time > 0
142
-
143
  # Verify API call
144
  mock_post.assert_called_once()
145
  args, kwargs = mock_post.call_args
@@ -147,125 +148,139 @@ class TestLLMService:
147
  assert kwargs["json"]["model"] == "test-model"
148
  assert kwargs["json"]["messages"][0]["content"] == "Test prompt"
149
 
150
- @patch('requests.post')
151
  def test_api_error_handling(self, mock_post):
152
  """Test handling of API errors."""
153
  # Mock API error
154
  mock_post.side_effect = requests.exceptions.RequestException("API Error")
155
-
156
  config = LLMConfig(
157
  provider="openrouter",
158
  api_key="test-key",
159
  model_name="test-model",
160
- base_url="https://api.openrouter.ai/api/v1"
161
  )
162
  service = LLMService([config])
163
-
164
  result = service.generate_response("Test prompt")
165
-
166
  assert result.success is False
167
  assert "API Error" in result.error_message
168
  assert result.content == ""
169
  assert result.provider == "openrouter"
170
 
171
- @patch('requests.post')
172
  def test_fallback_to_second_provider(self, mock_post):
173
  """Test fallback to second provider when first fails."""
174
  # Mock first provider failing, second succeeding
175
  first_call = Mock()
176
- first_call.side_effect = requests.exceptions.RequestException("First provider error")
177
-
178
  second_call = Mock()
179
  second_response = Mock()
180
  second_response.status_code = 200
181
  second_response.json.return_value = {
182
  "choices": [{"message": {"content": "Second provider response"}}],
183
- "usage": {}
184
  }
185
  second_response.raise_for_status = Mock()
186
  second_call.return_value = second_response
187
-
188
  mock_post.side_effect = [first_call.side_effect, second_response]
189
-
190
  config1 = LLMConfig(
191
  provider="openrouter",
192
  api_key="key1",
193
  model_name="model1",
194
- base_url="https://api1.com"
195
  )
196
  config2 = LLMConfig(
197
  provider="groq",
198
  api_key="key2",
199
  model_name="model2",
200
- base_url="https://api2.com"
201
  )
202
-
203
  service = LLMService([config1, config2])
204
  result = service.generate_response("Test prompt")
205
-
206
  assert result.success is True
207
  assert result.content == "Second provider response"
208
  assert result.provider == "groq"
209
  assert mock_post.call_count == 2
210
 
211
- @patch('requests.post')
212
  def test_all_providers_fail(self, mock_post):
213
  """Test when all providers fail."""
214
- mock_post.side_effect = requests.exceptions.RequestException("All providers down")
215
-
216
- config1 = LLMConfig(provider="provider1", api_key="key1", model_name="model1", base_url="url1")
217
- config2 = LLMConfig(provider="provider2", api_key="key2", model_name="model2", base_url="url2")
218
-
219
  service = LLMService([config1, config2])
220
  result = service.generate_response("Test prompt")
221
-
222
  assert result.success is False
223
  assert "All providers failed" in result.error_message
224
  assert result.provider == "none"
225
  assert result.model == "none"
226
 
227
- @patch('requests.post')
228
  def test_retry_logic(self, mock_post):
229
  """Test retry logic for failed requests."""
230
  # First call fails, second succeeds
231
  first_response = Mock()
232
- first_response.side_effect = requests.exceptions.RequestException("Temporary error")
233
-
234
  second_response = Mock()
235
  second_response.status_code = 200
236
  second_response.json.return_value = {
237
  "choices": [{"message": {"content": "Success after retry"}}],
238
- "usage": {}
239
  }
240
  second_response.raise_for_status = Mock()
241
-
242
  mock_post.side_effect = [first_response.side_effect, second_response]
243
-
244
  config = LLMConfig(
245
  provider="openrouter",
246
  api_key="test-key",
247
  model_name="test-model",
248
- base_url="https://api.openrouter.ai/api/v1"
249
  )
250
  service = LLMService([config])
251
-
252
  result = service.generate_response("Test prompt", max_retries=1)
253
-
254
  assert result.success is True
255
  assert result.content == "Success after retry"
256
  assert mock_post.call_count == 2
257
 
258
  def test_get_available_providers(self):
259
  """Test getting list of available providers."""
260
- config1 = LLMConfig(provider="openrouter", api_key="key1", model_name="model1", base_url="url1")
261
- config2 = LLMConfig(provider="groq", api_key="key2", model_name="model2", base_url="url2")
262
-
263
  service = LLMService([config1, config2])
264
  providers = service.get_available_providers()
265
-
266
  assert providers == ["openrouter", "groq"]
267
 
268
- @patch('requests.post')
269
  def test_health_check(self, mock_post):
270
  """Test health check functionality."""
271
  # Mock successful health check
@@ -273,51 +288,54 @@ class TestLLMService:
273
  mock_response.status_code = 200
274
  mock_response.json.return_value = {
275
  "choices": [{"message": {"content": "OK"}}],
276
- "usage": {}
277
  }
278
  mock_response.raise_for_status = Mock()
279
  mock_post.return_value = mock_response
280
-
281
  config = LLMConfig(
282
  provider="openrouter",
283
  api_key="test-key",
284
  model_name="test-model",
285
- base_url="https://api.openrouter.ai/api/v1"
286
  )
287
  service = LLMService([config])
288
-
289
  health_status = service.health_check()
290
-
291
  assert "openrouter" in health_status
292
  assert health_status["openrouter"]["status"] == "healthy"
293
  assert health_status["openrouter"]["model"] == "test-model"
294
  assert health_status["openrouter"]["response_time"] > 0
295
 
296
- @patch('requests.post')
297
  def test_openrouter_specific_headers(self, mock_post):
298
  """Test that OpenRouter-specific headers are added."""
299
  mock_response = Mock()
300
  mock_response.status_code = 200
301
  mock_response.json.return_value = {
302
  "choices": [{"message": {"content": "Test"}}],
303
- "usage": {}
304
  }
305
  mock_response.raise_for_status = Mock()
306
  mock_post.return_value = mock_response
307
-
308
  config = LLMConfig(
309
  provider="openrouter",
310
  api_key="test-key",
311
  model_name="test-model",
312
- base_url="https://api.openrouter.ai/api/v1"
313
  )
314
  service = LLMService([config])
315
-
316
  service.generate_response("Test")
317
-
318
  # Check headers
319
  args, kwargs = mock_post.call_args
320
  headers = kwargs["headers"]
321
  assert "HTTP-Referer" in headers
322
  assert "X-Title" in headers
323
- assert headers["HTTP-Referer"] == "https://github.com/sethmcknight/msse-ai-engineering"
 
4
  Tests for LLM integration and service functionality.
5
  """
6
 
7
+ from unittest.mock import Mock, patch
8
+
9
  import pytest
10
  import requests
11
+
12
+ from src.llm.llm_service import LLMConfig, LLMResponse, LLMService
13
 
14
 
15
  class TestLLMConfig:
 
21
  provider="openrouter",
22
  api_key="test-key",
23
  model_name="test-model",
24
+ base_url="https://test.com",
25
  )
26
+
27
  assert config.provider == "openrouter"
28
  assert config.api_key == "test-key"
29
  assert config.model_name == "test-model"
 
43
  model="test-model",
44
  usage={"tokens": 100},
45
  response_time=1.5,
46
+ success=True,
47
  )
48
+
49
  assert response.content == "Test response"
50
  assert response.provider == "openrouter"
51
  assert response.model == "test-model"
 
64
  provider="openrouter",
65
  api_key="test-key",
66
  model_name="test-model",
67
+ base_url="https://test.com",
68
  )
69
+
70
  service = LLMService([config])
71
+
72
  assert len(service.configs) == 1
73
  assert service.configs[0] == config
74
  assert service.current_config_index == 0
75
 
76
  def test_initialization_empty_configs_raises_error(self):
77
  """Test that empty configs raise ValueError."""
78
+ with pytest.raises(
79
+ ValueError, match="At least one LLM configuration must be provided"
80
+ ):
81
  LLMService([])
82
 
83
+ @patch.dict("os.environ", {"OPENROUTER_API_KEY": "test-openrouter-key"})
84
  def test_from_environment_with_openrouter_key(self):
85
  """Test creating service from environment with OpenRouter key."""
86
  service = LLMService.from_environment()
87
+
88
  assert len(service.configs) >= 1
89
  openrouter_config = next(
90
  (config for config in service.configs if config.provider == "openrouter"),
91
+ None,
92
  )
93
  assert openrouter_config is not None
94
  assert openrouter_config.api_key == "test-openrouter-key"
95
 
96
+ @patch.dict("os.environ", {"GROQ_API_KEY": "test-groq-key"})
97
  def test_from_environment_with_groq_key(self):
98
  """Test creating service from environment with Groq key."""
99
  service = LLMService.from_environment()
100
+
101
  assert len(service.configs) >= 1
102
  groq_config = next(
103
+ (config for config in service.configs if config.provider == "groq"), None
104
  )
105
  assert groq_config is not None
106
  assert groq_config.api_key == "test-groq-key"
107
 
108
+ @patch.dict("os.environ", {}, clear=True)
109
  def test_from_environment_no_keys_raises_error(self):
110
  """Test that no environment keys raise ValueError."""
111
  with pytest.raises(ValueError, match="No LLM API keys found in environment"):
112
  LLMService.from_environment()
113
 
114
+ @patch("requests.post")
115
  def test_successful_response_generation(self, mock_post):
116
  """Test successful response generation."""
117
  # Mock successful API response
118
  mock_response = Mock()
119
  mock_response.status_code = 200
120
  mock_response.json.return_value = {
121
+ "choices": [{"message": {"content": "Test response content"}}],
122
+ "usage": {"prompt_tokens": 50, "completion_tokens": 20},
123
  }
124
  mock_response.raise_for_status = Mock()
125
  mock_post.return_value = mock_response
126
+
127
  config = LLMConfig(
128
  provider="openrouter",
129
  api_key="test-key",
130
  model_name="test-model",
131
+ base_url="https://api.openrouter.ai/api/v1",
132
  )
133
  service = LLMService([config])
134
+
135
  result = service.generate_response("Test prompt")
136
+
137
  assert result.success is True
138
  assert result.content == "Test response content"
139
  assert result.provider == "openrouter"
140
  assert result.model == "test-model"
141
  assert result.usage == {"prompt_tokens": 50, "completion_tokens": 20}
142
  assert result.response_time > 0
143
+
144
  # Verify API call
145
  mock_post.assert_called_once()
146
  args, kwargs = mock_post.call_args
 
148
  assert kwargs["json"]["model"] == "test-model"
149
  assert kwargs["json"]["messages"][0]["content"] == "Test prompt"
150
 
151
+ @patch("requests.post")
152
  def test_api_error_handling(self, mock_post):
153
  """Test handling of API errors."""
154
  # Mock API error
155
  mock_post.side_effect = requests.exceptions.RequestException("API Error")
156
+
157
  config = LLMConfig(
158
  provider="openrouter",
159
  api_key="test-key",
160
  model_name="test-model",
161
+ base_url="https://api.openrouter.ai/api/v1",
162
  )
163
  service = LLMService([config])
164
+
165
  result = service.generate_response("Test prompt")
166
+
167
  assert result.success is False
168
  assert "API Error" in result.error_message
169
  assert result.content == ""
170
  assert result.provider == "openrouter"
171
 
172
+ @patch("requests.post")
173
  def test_fallback_to_second_provider(self, mock_post):
174
  """Test fallback to second provider when first fails."""
175
  # Mock first provider failing, second succeeding
176
  first_call = Mock()
177
+ first_call.side_effect = requests.exceptions.RequestException(
178
+ "First provider error"
179
+ )
180
+
181
  second_call = Mock()
182
  second_response = Mock()
183
  second_response.status_code = 200
184
  second_response.json.return_value = {
185
  "choices": [{"message": {"content": "Second provider response"}}],
186
+ "usage": {},
187
  }
188
  second_response.raise_for_status = Mock()
189
  second_call.return_value = second_response
190
+
191
  mock_post.side_effect = [first_call.side_effect, second_response]
192
+
193
  config1 = LLMConfig(
194
  provider="openrouter",
195
  api_key="key1",
196
  model_name="model1",
197
+ base_url="https://api1.com",
198
  )
199
  config2 = LLMConfig(
200
  provider="groq",
201
  api_key="key2",
202
  model_name="model2",
203
+ base_url="https://api2.com",
204
  )
205
+
206
  service = LLMService([config1, config2])
207
  result = service.generate_response("Test prompt")
208
+
209
  assert result.success is True
210
  assert result.content == "Second provider response"
211
  assert result.provider == "groq"
212
  assert mock_post.call_count == 2
213
 
214
+ @patch("requests.post")
215
  def test_all_providers_fail(self, mock_post):
216
  """Test when all providers fail."""
217
+ mock_post.side_effect = requests.exceptions.RequestException(
218
+ "All providers down"
219
+ )
220
+
221
+ config1 = LLMConfig(
222
+ provider="provider1", api_key="key1", model_name="model1", base_url="url1"
223
+ )
224
+ config2 = LLMConfig(
225
+ provider="provider2", api_key="key2", model_name="model2", base_url="url2"
226
+ )
227
+
228
  service = LLMService([config1, config2])
229
  result = service.generate_response("Test prompt")
230
+
231
  assert result.success is False
232
  assert "All providers failed" in result.error_message
233
  assert result.provider == "none"
234
  assert result.model == "none"
235
 
236
+ @patch("requests.post")
237
  def test_retry_logic(self, mock_post):
238
  """Test retry logic for failed requests."""
239
  # First call fails, second succeeds
240
  first_response = Mock()
241
+ first_response.side_effect = requests.exceptions.RequestException(
242
+ "Temporary error"
243
+ )
244
+
245
  second_response = Mock()
246
  second_response.status_code = 200
247
  second_response.json.return_value = {
248
  "choices": [{"message": {"content": "Success after retry"}}],
249
+ "usage": {},
250
  }
251
  second_response.raise_for_status = Mock()
252
+
253
  mock_post.side_effect = [first_response.side_effect, second_response]
254
+
255
  config = LLMConfig(
256
  provider="openrouter",
257
  api_key="test-key",
258
  model_name="test-model",
259
+ base_url="https://api.openrouter.ai/api/v1",
260
  )
261
  service = LLMService([config])
262
+
263
  result = service.generate_response("Test prompt", max_retries=1)
264
+
265
  assert result.success is True
266
  assert result.content == "Success after retry"
267
  assert mock_post.call_count == 2
268
 
269
  def test_get_available_providers(self):
270
  """Test getting list of available providers."""
271
+ config1 = LLMConfig(
272
+ provider="openrouter", api_key="key1", model_name="model1", base_url="url1"
273
+ )
274
+ config2 = LLMConfig(
275
+ provider="groq", api_key="key2", model_name="model2", base_url="url2"
276
+ )
277
+
278
  service = LLMService([config1, config2])
279
  providers = service.get_available_providers()
280
+
281
  assert providers == ["openrouter", "groq"]
282
 
283
+ @patch("requests.post")
284
  def test_health_check(self, mock_post):
285
  """Test health check functionality."""
286
  # Mock successful health check
 
288
  mock_response.status_code = 200
289
  mock_response.json.return_value = {
290
  "choices": [{"message": {"content": "OK"}}],
291
+ "usage": {},
292
  }
293
  mock_response.raise_for_status = Mock()
294
  mock_post.return_value = mock_response
295
+
296
  config = LLMConfig(
297
  provider="openrouter",
298
  api_key="test-key",
299
  model_name="test-model",
300
+ base_url="https://api.openrouter.ai/api/v1",
301
  )
302
  service = LLMService([config])
303
+
304
  health_status = service.health_check()
305
+
306
  assert "openrouter" in health_status
307
  assert health_status["openrouter"]["status"] == "healthy"
308
  assert health_status["openrouter"]["model"] == "test-model"
309
  assert health_status["openrouter"]["response_time"] > 0
310
 
311
+ @patch("requests.post")
312
  def test_openrouter_specific_headers(self, mock_post):
313
  """Test that OpenRouter-specific headers are added."""
314
  mock_response = Mock()
315
  mock_response.status_code = 200
316
  mock_response.json.return_value = {
317
  "choices": [{"message": {"content": "Test"}}],
318
+ "usage": {},
319
  }
320
  mock_response.raise_for_status = Mock()
321
  mock_post.return_value = mock_response
322
+
323
  config = LLMConfig(
324
  provider="openrouter",
325
  api_key="test-key",
326
  model_name="test-model",
327
+ base_url="https://api.openrouter.ai/api/v1",
328
  )
329
  service = LLMService([config])
330
+
331
  service.generate_response("Test")
332
+
333
  # Check headers
334
  args, kwargs = mock_post.call_args
335
  headers = kwargs["headers"]
336
  assert "HTTP-Referer" in headers
337
  assert "X-Title" in headers
338
+ assert (
339
+ headers["HTTP-Referer"]
340
+ == "https://github.com/sethmcknight/msse-ai-engineering"
341
+ )
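
The fallback and retry tests above exercise the multi-provider behaviour of LLMService: configurations are tried in order, and generate_response reports which provider ultimately answered. A rough usage sketch of that pattern, based only on the constructor fields and result attributes visible in these tests, might look like the following; the API keys and the fallback base_url are placeholders, not real values.

# Rough sketch based on the interfaces exercised in the tests above;
# the keys and the fallback base_url below are placeholders.
from src.llm.llm_service import LLMConfig, LLMService

primary = LLMConfig(
    provider="openrouter",
    api_key="YOUR_OPENROUTER_KEY",  # placeholder
    model_name="test-model",
    base_url="https://api.openrouter.ai/api/v1",
)
fallback = LLMConfig(
    provider="groq",
    api_key="YOUR_GROQ_KEY",  # placeholder
    model_name="test-model",
    base_url="https://api2.com",  # placeholder URL taken from the test data
)

service = LLMService([primary, fallback])
result = service.generate_response("What is the PTO policy?", max_retries=1)

if result.success:
    # result.provider identifies which configuration actually answered.
    print(result.provider, result.content)
else:
    print("All providers failed:", result.error_message)
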
tests/test_rag/__init__.py CHANGED
@@ -1 +1 @@
1
- # RAG Pipeline Tests
 
1
+ # RAG Pipeline Tests