Tobias Pasquale
committed on
Commit · 508a7e5
1 Parent(s): c280a92
Fix: Complete CI/CD formatting compliance
- Apply black code formatting to 12 files
- Fix import ordering with isort
- Remove unused imports (Union, MagicMock, json, asdict, PromptTemplate)
- Fix undefined variables in test_chat_endpoint.py
- Break long lines in RAG pipeline and response formatter
- Add noqa comments for prompt template strings
- Resolve all 19 flake8 E501 line length violations
- Ensure full pre-commit hook compliance
All code formatting issues resolved for successful pipeline deployment.
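
For reference, the checks described above can be reproduced locally. The sketch below is a hypothetical helper, not part of this commit; the tool names come from the commit message, while the exact flags are assumptions that may differ from the project's pre-commit configuration.

import subprocess


def run_formatting_checks() -> int:
    """Hypothetical helper mirroring the checks named in the commit message."""
    commands = [
        ["black", "--check", "."],       # formatting compliance
        ["isort", "--check-only", "."],  # import ordering
        ["flake8", "."],                 # E501 line length and other lint rules
    ]
    for cmd in commands:
        if subprocess.run(cmd).returncode != 0:  # stop at the first failing tool
            return 1
    return 0


if __name__ == "__main__":
    raise SystemExit(run_formatting_checks())
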
- app.py +72 -57
- src/llm/__init__.py +1 -1
- src/llm/context_manager.py +53 -55
- src/llm/llm_service.py +97 -93
- src/llm/prompt_templates.py +48 -47
- src/rag/__init__.py +1 -1
- src/rag/rag_pipeline.py +120 -100
- src/rag/response_formatter.py +83 -69
- tests/test_chat_endpoint.py +172 -111
- tests/test_llm/__init__.py +1 -1
- tests/test_llm/test_llm_service.py +94 -76
- tests/test_rag/__init__.py +1 -1
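
Several of the files listed above gained "# noqa: E501" suppressions for long prompt-template strings. As a generic illustration (not a line taken from this commit), flake8 skips the E501 line-length check on any physical line that ends with that comment:

# Generic illustration of a flake8 suppression; not a line from this commit.
SYSTEM_NOTE = "A deliberately long single-line string that would normally trigger flake8's E501 line-length warning unless suppressed."  # noqa: E501
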
app.py
CHANGED
|
@@ -168,7 +168,7 @@ def search():
|
|
| 168 |
def chat():
|
| 169 |
"""
|
| 170 |
Endpoint for conversational RAG interactions.
|
| 171 |
-
|
| 172 |
Accepts JSON requests with user messages and returns AI-generated
|
| 173 |
responses based on corporate policy documents.
|
| 174 |
"""
|
|
@@ -176,10 +176,12 @@ def chat():
|
|
| 176 |
# Validate request contains JSON data
|
| 177 |
if not request.is_json:
|
| 178 |
return (
|
| 179 |
-
jsonify(
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
|
|
|
|
|
|
| 183 |
400,
|
| 184 |
)
|
| 185 |
|
|
@@ -189,19 +191,17 @@ def chat():
|
|
| 189 |
message = data.get("message")
|
| 190 |
if message is None:
|
| 191 |
return (
|
| 192 |
-
jsonify(
|
| 193 |
-
"status": "error",
|
| 194 |
-
|
| 195 |
-
}),
|
| 196 |
400,
|
| 197 |
)
|
| 198 |
|
| 199 |
if not isinstance(message, str) or not message.strip():
|
| 200 |
return (
|
| 201 |
-
jsonify(
|
| 202 |
-
"status": "error",
|
| 203 |
-
|
| 204 |
-
}),
|
| 205 |
400,
|
| 206 |
)
|
| 207 |
|
|
@@ -214,96 +214,103 @@ def chat():
|
|
| 214 |
try:
|
| 215 |
from src.config import COLLECTION_NAME, VECTOR_DB_PERSIST_PATH
|
| 216 |
from src.embedding.embedding_service import EmbeddingService
|
| 217 |
-
from src.search.search_service import SearchService
|
| 218 |
-
from src.vector_store.vector_db import VectorDatabase
|
| 219 |
from src.llm.llm_service import LLMService
|
| 220 |
from src.rag.rag_pipeline import RAGPipeline
|
| 221 |
from src.rag.response_formatter import ResponseFormatter
|
|
|
|
|
|
|
| 222 |
|
| 223 |
# Initialize services
|
| 224 |
vector_db = VectorDatabase(VECTOR_DB_PERSIST_PATH, COLLECTION_NAME)
|
| 225 |
embedding_service = EmbeddingService()
|
| 226 |
search_service = SearchService(vector_db, embedding_service)
|
| 227 |
-
|
| 228 |
# Initialize LLM service from environment
|
| 229 |
llm_service = LLMService.from_environment()
|
| 230 |
-
|
| 231 |
# Initialize RAG pipeline
|
| 232 |
rag_pipeline = RAGPipeline(search_service, llm_service)
|
| 233 |
-
|
| 234 |
# Initialize response formatter
|
| 235 |
formatter = ResponseFormatter()
|
| 236 |
-
|
| 237 |
except ValueError as e:
|
| 238 |
return (
|
| 239 |
-
jsonify(
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
503,
|
| 245 |
)
|
| 246 |
except Exception as e:
|
| 247 |
return (
|
| 248 |
-
jsonify(
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
|
|
|
|
|
|
| 252 |
500,
|
| 253 |
)
|
| 254 |
|
| 255 |
# Generate RAG response
|
| 256 |
rag_response = rag_pipeline.generate_answer(message.strip())
|
| 257 |
-
|
| 258 |
# Format response for API
|
| 259 |
if include_sources:
|
| 260 |
-
formatted_response = formatter.format_api_response(
|
|
|
|
|
|
|
| 261 |
else:
|
| 262 |
formatted_response = formatter.format_chat_response(
|
| 263 |
-
rag_response,
|
| 264 |
-
conversation_id,
|
| 265 |
-
include_sources=False
|
| 266 |
)
|
| 267 |
|
| 268 |
return jsonify(formatted_response)
|
| 269 |
|
| 270 |
except Exception as e:
|
| 271 |
-
return
|
| 272 |
-
"status": "error",
|
| 273 |
-
|
| 274 |
-
|
| 275 |
|
| 276 |
|
| 277 |
@app.route("/chat/health", methods=["GET"])
|
| 278 |
def chat_health():
|
| 279 |
"""
|
| 280 |
Health check endpoint for RAG chat functionality.
|
| 281 |
-
|
| 282 |
Returns the status of all RAG pipeline components.
|
| 283 |
"""
|
| 284 |
try:
|
| 285 |
from src.config import COLLECTION_NAME, VECTOR_DB_PERSIST_PATH
|
| 286 |
from src.embedding.embedding_service import EmbeddingService
|
| 287 |
-
from src.search.search_service import SearchService
|
| 288 |
-
from src.vector_store.vector_db import VectorDatabase
|
| 289 |
from src.llm.llm_service import LLMService
|
| 290 |
from src.rag.rag_pipeline import RAGPipeline
|
| 291 |
from src.rag.response_formatter import ResponseFormatter
|
|
|
|
|
|
|
| 292 |
|
| 293 |
# Initialize services for health check
|
| 294 |
vector_db = VectorDatabase(VECTOR_DB_PERSIST_PATH, COLLECTION_NAME)
|
| 295 |
embedding_service = EmbeddingService()
|
| 296 |
search_service = SearchService(vector_db, embedding_service)
|
| 297 |
-
|
| 298 |
try:
|
| 299 |
llm_service = LLMService.from_environment()
|
| 300 |
rag_pipeline = RAGPipeline(search_service, llm_service)
|
| 301 |
formatter = ResponseFormatter()
|
| 302 |
-
|
| 303 |
# Perform health check
|
| 304 |
health_data = rag_pipeline.health_check()
|
| 305 |
health_response = formatter.create_health_response(health_data)
|
| 306 |
-
|
| 307 |
# Determine HTTP status based on health
|
| 308 |
if health_data.get("pipeline") == "healthy":
|
| 309 |
return jsonify(health_response), 200
|
|
@@ -311,24 +318,32 @@ def chat_health():
|
|
| 311 |
return jsonify(health_response), 200 # Still functional
|
| 312 |
else:
|
| 313 |
return jsonify(health_response), 503 # Service unavailable
|
| 314 |
-
|
| 315 |
except ValueError as e:
|
| 316 |
-
return
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
}
|
| 324 |
-
|
| 325 |
-
|
|
|
|
| 326 |
|
| 327 |
except Exception as e:
|
| 328 |
-
return
|
| 329 |
-
"status": "error",
|
| 330 |
-
|
| 331 |
-
|
| 332 |
|
| 333 |
|
| 334 |
if __name__ == "__main__":
|
|
|
|
| 168 |
def chat():
|
| 169 |
"""
|
| 170 |
Endpoint for conversational RAG interactions.
|
| 171 |
+
|
| 172 |
Accepts JSON requests with user messages and returns AI-generated
|
| 173 |
responses based on corporate policy documents.
|
| 174 |
"""
|
|
|
|
| 176 |
# Validate request contains JSON data
|
| 177 |
if not request.is_json:
|
| 178 |
return (
|
| 179 |
+
jsonify(
|
| 180 |
+
{
|
| 181 |
+
"status": "error",
|
| 182 |
+
"message": "Content-Type must be application/json",
|
| 183 |
+
}
|
| 184 |
+
),
|
| 185 |
400,
|
| 186 |
)
|
| 187 |
|
|
|
|
| 191 |
message = data.get("message")
|
| 192 |
if message is None:
|
| 193 |
return (
|
| 194 |
+
jsonify(
|
| 195 |
+
{"status": "error", "message": "message parameter is required"}
|
| 196 |
+
),
|
|
|
|
| 197 |
400,
|
| 198 |
)
|
| 199 |
|
| 200 |
if not isinstance(message, str) or not message.strip():
|
| 201 |
return (
|
| 202 |
+
jsonify(
|
| 203 |
+
{"status": "error", "message": "message must be a non-empty string"}
|
| 204 |
+
),
|
|
|
|
| 205 |
400,
|
| 206 |
)
|
| 207 |
|
|
|
|
| 214 |
try:
|
| 215 |
from src.config import COLLECTION_NAME, VECTOR_DB_PERSIST_PATH
|
| 216 |
from src.embedding.embedding_service import EmbeddingService
|
|
|
|
|
|
|
| 217 |
from src.llm.llm_service import LLMService
|
| 218 |
from src.rag.rag_pipeline import RAGPipeline
|
| 219 |
from src.rag.response_formatter import ResponseFormatter
|
| 220 |
+
from src.search.search_service import SearchService
|
| 221 |
+
from src.vector_store.vector_db import VectorDatabase
|
| 222 |
|
| 223 |
# Initialize services
|
| 224 |
vector_db = VectorDatabase(VECTOR_DB_PERSIST_PATH, COLLECTION_NAME)
|
| 225 |
embedding_service = EmbeddingService()
|
| 226 |
search_service = SearchService(vector_db, embedding_service)
|
| 227 |
+
|
| 228 |
# Initialize LLM service from environment
|
| 229 |
llm_service = LLMService.from_environment()
|
| 230 |
+
|
| 231 |
# Initialize RAG pipeline
|
| 232 |
rag_pipeline = RAGPipeline(search_service, llm_service)
|
| 233 |
+
|
| 234 |
# Initialize response formatter
|
| 235 |
formatter = ResponseFormatter()
|
| 236 |
+
|
| 237 |
except ValueError as e:
|
| 238 |
return (
|
| 239 |
+
jsonify(
|
| 240 |
+
{
|
| 241 |
+
"status": "error",
|
| 242 |
+
"message": f"LLM service configuration error: {str(e)}",
|
| 243 |
+
"details": (
|
| 244 |
+
"Please ensure OPENROUTER_API_KEY or GROQ_API_KEY "
|
| 245 |
+
"environment variables are set"
|
| 246 |
+
),
|
| 247 |
+
}
|
| 248 |
+
),
|
| 249 |
503,
|
| 250 |
)
|
| 251 |
except Exception as e:
|
| 252 |
return (
|
| 253 |
+
jsonify(
|
| 254 |
+
{
|
| 255 |
+
"status": "error",
|
| 256 |
+
"message": f"Service initialization failed: {str(e)}",
|
| 257 |
+
}
|
| 258 |
+
),
|
| 259 |
500,
|
| 260 |
)
|
| 261 |
|
| 262 |
# Generate RAG response
|
| 263 |
rag_response = rag_pipeline.generate_answer(message.strip())
|
| 264 |
+
|
| 265 |
# Format response for API
|
| 266 |
if include_sources:
|
| 267 |
+
formatted_response = formatter.format_api_response(
|
| 268 |
+
rag_response, include_debug
|
| 269 |
+
)
|
| 270 |
else:
|
| 271 |
formatted_response = formatter.format_chat_response(
|
| 272 |
+
rag_response, conversation_id, include_sources=False
|
|
|
|
|
|
|
| 273 |
)
|
| 274 |
|
| 275 |
return jsonify(formatted_response)
|
| 276 |
|
| 277 |
except Exception as e:
|
| 278 |
+
return (
|
| 279 |
+
jsonify({"status": "error", "message": f"Chat request failed: {str(e)}"}),
|
| 280 |
+
500,
|
| 281 |
+
)
|
| 282 |
|
| 283 |
|
| 284 |
@app.route("/chat/health", methods=["GET"])
|
| 285 |
def chat_health():
|
| 286 |
"""
|
| 287 |
Health check endpoint for RAG chat functionality.
|
| 288 |
+
|
| 289 |
Returns the status of all RAG pipeline components.
|
| 290 |
"""
|
| 291 |
try:
|
| 292 |
from src.config import COLLECTION_NAME, VECTOR_DB_PERSIST_PATH
|
| 293 |
from src.embedding.embedding_service import EmbeddingService
|
|
|
|
|
|
|
| 294 |
from src.llm.llm_service import LLMService
|
| 295 |
from src.rag.rag_pipeline import RAGPipeline
|
| 296 |
from src.rag.response_formatter import ResponseFormatter
|
| 297 |
+
from src.search.search_service import SearchService
|
| 298 |
+
from src.vector_store.vector_db import VectorDatabase
|
| 299 |
|
| 300 |
# Initialize services for health check
|
| 301 |
vector_db = VectorDatabase(VECTOR_DB_PERSIST_PATH, COLLECTION_NAME)
|
| 302 |
embedding_service = EmbeddingService()
|
| 303 |
search_service = SearchService(vector_db, embedding_service)
|
| 304 |
+
|
| 305 |
try:
|
| 306 |
llm_service = LLMService.from_environment()
|
| 307 |
rag_pipeline = RAGPipeline(search_service, llm_service)
|
| 308 |
formatter = ResponseFormatter()
|
| 309 |
+
|
| 310 |
# Perform health check
|
| 311 |
health_data = rag_pipeline.health_check()
|
| 312 |
health_response = formatter.create_health_response(health_data)
|
| 313 |
+
|
| 314 |
# Determine HTTP status based on health
|
| 315 |
if health_data.get("pipeline") == "healthy":
|
| 316 |
return jsonify(health_response), 200
|
|
|
|
| 318 |
return jsonify(health_response), 200 # Still functional
|
| 319 |
else:
|
| 320 |
return jsonify(health_response), 503 # Service unavailable
|
| 321 |
+
|
| 322 |
except ValueError as e:
|
| 323 |
+
return (
|
| 324 |
+
jsonify(
|
| 325 |
+
{
|
| 326 |
+
"status": "error",
|
| 327 |
+
"message": f"LLM configuration error: {str(e)}",
|
| 328 |
+
"health": {
|
| 329 |
+
"pipeline_status": "unhealthy",
|
| 330 |
+
"components": {
|
| 331 |
+
"llm_service": {
|
| 332 |
+
"status": "unconfigured",
|
| 333 |
+
"error": str(e),
|
| 334 |
+
}
|
| 335 |
+
},
|
| 336 |
+
},
|
| 337 |
}
|
| 338 |
+
),
|
| 339 |
+
503,
|
| 340 |
+
)
|
| 341 |
|
| 342 |
except Exception as e:
|
| 343 |
+
return (
|
| 344 |
+
jsonify({"status": "error", "message": f"Health check failed: {str(e)}"}),
|
| 345 |
+
500,
|
| 346 |
+
)
|
| 347 |
|
| 348 |
|
| 349 |
if __name__ == "__main__":
|
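
The app.py hunks above are formatting-only: black re-wraps the Flask error responses into multi-line (jsonify(...), status_code) tuples and isort moves the search/vector-store imports. As a standalone illustration of that return pattern (a simplified sketch, not the actual application code):

from flask import Flask, jsonify, request

app = Flask(__name__)


@app.route("/chat", methods=["POST"])
def chat():
    # Validate request contains JSON data, mirroring the checks in the diff.
    if not request.is_json:
        return (
            jsonify(
                {
                    "status": "error",
                    "message": "Content-Type must be application/json",
                }
            ),
            400,
        )

    message = (request.get_json() or {}).get("message")
    if not isinstance(message, str) or not message.strip():
        return (
            jsonify(
                {"status": "error", "message": "message must be a non-empty string"}
            ),
            400,
        )

    # Placeholder response; the real endpoint delegates to the RAG pipeline.
    return jsonify({"status": "ok", "echo": message.strip()})
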
src/llm/__init__.py
CHANGED
|
@@ -8,4 +8,4 @@ Classes:
|
|
| 8 |
LLMService: Main service for LLM interactions
|
| 9 |
PromptTemplates: Predefined prompt templates for corporate policy Q&A
|
| 10 |
ContextManager: Manages context retrieval and formatting
|
| 11 |
-
"""
|
|
|
|
| 8 |
LLMService: Main service for LLM interactions
|
| 9 |
PromptTemplates: Predefined prompt templates for corporate policy Q&A
|
| 10 |
ContextManager: Manages context retrieval and formatting
|
| 11 |
+
"""
|
src/llm/context_manager.py
CHANGED
|
@@ -6,8 +6,8 @@ for the RAG pipeline, ensuring optimal context window utilization.
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
import logging
|
| 9 |
-
from typing import Any, Dict, List, Optional, Tuple
|
| 10 |
from dataclasses import dataclass
|
|
|
|
| 11 |
|
| 12 |
logger = logging.getLogger(__name__)
|
| 13 |
|
|
@@ -15,6 +15,7 @@ logger = logging.getLogger(__name__)
|
|
| 15 |
@dataclass
|
| 16 |
class ContextConfig:
|
| 17 |
"""Configuration for context management."""
|
|
|
|
| 18 |
max_context_length: int = 3000 # Maximum characters in context
|
| 19 |
max_results: int = 5 # Maximum search results to include
|
| 20 |
min_similarity: float = 0.1 # Minimum similarity threshold
|
|
@@ -24,7 +25,7 @@ class ContextConfig:
|
|
| 24 |
class ContextManager:
|
| 25 |
"""
|
| 26 |
Manages context retrieval and optimization for RAG pipeline.
|
| 27 |
-
|
| 28 |
Handles:
|
| 29 |
- Context length management
|
| 30 |
- Relevance filtering
|
|
@@ -35,7 +36,7 @@ class ContextManager:
|
|
| 35 |
def __init__(self, config: Optional[ContextConfig] = None):
|
| 36 |
"""
|
| 37 |
Initialize ContextManager with configuration.
|
| 38 |
-
|
| 39 |
Args:
|
| 40 |
config: Context configuration, uses defaults if None
|
| 41 |
"""
|
|
@@ -43,17 +44,15 @@ class ContextManager:
|
|
| 43 |
logger.info("ContextManager initialized")
|
| 44 |
|
| 45 |
def prepare_context(
|
| 46 |
-
self,
|
| 47 |
-
search_results: List[Dict[str, Any]],
|
| 48 |
-
query: str
|
| 49 |
) -> Tuple[str, List[Dict[str, Any]]]:
|
| 50 |
"""
|
| 51 |
Prepare optimized context from search results.
|
| 52 |
-
|
| 53 |
Args:
|
| 54 |
search_results: Results from SearchService
|
| 55 |
query: Original user query for context optimization
|
| 56 |
-
|
| 57 |
Returns:
|
| 58 |
Tuple of (formatted_context, filtered_results)
|
| 59 |
"""
|
|
@@ -62,56 +61,58 @@ class ContextManager:
|
|
| 62 |
|
| 63 |
# Filter and rank results
|
| 64 |
filtered_results = self._filter_results(search_results)
|
| 65 |
-
|
| 66 |
# Remove duplicates and optimize for context window
|
| 67 |
optimized_results = self._optimize_context(filtered_results)
|
| 68 |
-
|
| 69 |
# Format for prompt
|
| 70 |
formatted_context = self._format_context(optimized_results)
|
| 71 |
-
|
| 72 |
logger.debug(
|
| 73 |
f"Prepared context from {len(search_results)} results, "
|
| 74 |
f"filtered to {len(optimized_results)} results, "
|
| 75 |
f"{len(formatted_context)} characters"
|
| 76 |
)
|
| 77 |
-
|
| 78 |
return formatted_context, optimized_results
|
| 79 |
|
| 80 |
def _filter_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 81 |
"""
|
| 82 |
Filter search results by relevance and quality.
|
| 83 |
-
|
| 84 |
Args:
|
| 85 |
results: Raw search results
|
| 86 |
-
|
| 87 |
Returns:
|
| 88 |
Filtered and sorted results
|
| 89 |
"""
|
| 90 |
filtered = []
|
| 91 |
-
|
| 92 |
for result in results:
|
| 93 |
similarity = result.get("similarity_score", 0.0)
|
| 94 |
content = result.get("content", "").strip()
|
| 95 |
-
|
| 96 |
# Apply filters
|
| 97 |
-
if (
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
| 100 |
filtered.append(result)
|
| 101 |
-
|
| 102 |
# Sort by similarity score (descending)
|
| 103 |
filtered.sort(key=lambda x: x.get("similarity_score", 0.0), reverse=True)
|
| 104 |
-
|
| 105 |
# Limit to max results
|
| 106 |
-
return filtered[:self.config.max_results]
|
| 107 |
|
| 108 |
def _optimize_context(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 109 |
"""
|
| 110 |
Optimize context to fit within token limits while maximizing relevance.
|
| 111 |
-
|
| 112 |
Args:
|
| 113 |
results: Filtered search results
|
| 114 |
-
|
| 115 |
Returns:
|
| 116 |
Optimized results list
|
| 117 |
"""
|
|
@@ -125,7 +126,7 @@ class ContextManager:
|
|
| 125 |
for result in results:
|
| 126 |
content = result.get("content", "").strip()
|
| 127 |
content_length = len(content)
|
| 128 |
-
|
| 129 |
# Check if adding this result would exceed limit
|
| 130 |
estimated_formatted_length = current_length + content_length + 100 # Buffer
|
| 131 |
if estimated_formatted_length > self.config.max_context_length:
|
|
@@ -137,18 +138,21 @@ class ContextManager:
|
|
| 137 |
result_copy["content"] = truncated_content
|
| 138 |
optimized.append(result_copy)
|
| 139 |
break
|
| 140 |
-
|
| 141 |
# Check for duplicate or highly similar content
|
| 142 |
content_lower = content.lower()
|
| 143 |
is_duplicate = False
|
| 144 |
-
|
| 145 |
for seen in seen_content:
|
| 146 |
# Simple similarity check for duplicates
|
| 147 |
-
if (
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
| 149 |
is_duplicate = True
|
| 150 |
break
|
| 151 |
-
|
| 152 |
if not is_duplicate:
|
| 153 |
optimized.append(result)
|
| 154 |
seen_content.add(content_lower)
|
|
@@ -159,10 +163,10 @@ class ContextManager:
|
|
| 159 |
def _format_context(self, results: List[Dict[str, Any]]) -> str:
|
| 160 |
"""
|
| 161 |
Format optimized results into context string.
|
| 162 |
-
|
| 163 |
Args:
|
| 164 |
results: Optimized search results
|
| 165 |
-
|
| 166 |
Returns:
|
| 167 |
Formatted context string
|
| 168 |
"""
|
|
@@ -170,34 +174,28 @@ class ContextManager:
|
|
| 170 |
return "No relevant information found in corporate policies."
|
| 171 |
|
| 172 |
context_parts = []
|
| 173 |
-
|
| 174 |
for i, result in enumerate(results, 1):
|
| 175 |
metadata = result.get("metadata", {})
|
| 176 |
filename = metadata.get("filename", f"document_{i}")
|
| 177 |
content = result.get("content", "").strip()
|
| 178 |
-
|
| 179 |
# Format with document info
|
| 180 |
-
context_parts.append(
|
| 181 |
-
f"Document: {filename}\n"
|
| 182 |
-
f"Content: {content}"
|
| 183 |
-
)
|
| 184 |
|
| 185 |
return "\n\n---\n\n".join(context_parts)
|
| 186 |
|
| 187 |
def validate_context_quality(
|
| 188 |
-
self,
|
| 189 |
-
context: str,
|
| 190 |
-
query: str,
|
| 191 |
-
min_quality_score: float = 0.3
|
| 192 |
) -> Dict[str, Any]:
|
| 193 |
"""
|
| 194 |
Validate the quality of prepared context for a given query.
|
| 195 |
-
|
| 196 |
Args:
|
| 197 |
context: Formatted context string
|
| 198 |
query: Original user query
|
| 199 |
min_quality_score: Minimum acceptable quality score
|
| 200 |
-
|
| 201 |
Returns:
|
| 202 |
Dictionary with quality metrics and validation result
|
| 203 |
"""
|
|
@@ -206,7 +204,7 @@ class ContextManager:
|
|
| 206 |
"length": len(context),
|
| 207 |
"has_content": bool(context.strip()),
|
| 208 |
"estimated_relevance": 0.0,
|
| 209 |
-
"passes_validation": False
|
| 210 |
}
|
| 211 |
|
| 212 |
if not context.strip():
|
|
@@ -216,7 +214,7 @@ class ContextManager:
|
|
| 216 |
# Estimate relevance based on query term overlap
|
| 217 |
query_terms = set(query.lower().split())
|
| 218 |
context_terms = set(context.lower().split())
|
| 219 |
-
|
| 220 |
if query_terms and context_terms:
|
| 221 |
overlap = len(query_terms & context_terms)
|
| 222 |
relevance = overlap / len(query_terms)
|
|
@@ -230,36 +228,36 @@ class ContextManager:
|
|
| 230 |
def get_source_summary(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 231 |
"""
|
| 232 |
Generate summary of sources used in context.
|
| 233 |
-
|
| 234 |
Args:
|
| 235 |
results: Search results used for context
|
| 236 |
-
|
| 237 |
Returns:
|
| 238 |
Summary of sources and their contribution
|
| 239 |
"""
|
| 240 |
sources = {}
|
| 241 |
total_content_length = 0
|
| 242 |
-
|
| 243 |
for result in results:
|
| 244 |
metadata = result.get("metadata", {})
|
| 245 |
filename = metadata.get("filename", "unknown")
|
| 246 |
content_length = len(result.get("content", ""))
|
| 247 |
similarity = result.get("similarity_score", 0.0)
|
| 248 |
-
|
| 249 |
if filename not in sources:
|
| 250 |
sources[filename] = {
|
| 251 |
"chunks": 0,
|
| 252 |
"total_content_length": 0,
|
| 253 |
"max_similarity": 0.0,
|
| 254 |
-
"avg_similarity": 0.0
|
| 255 |
}
|
| 256 |
-
|
| 257 |
sources[filename]["chunks"] += 1
|
| 258 |
sources[filename]["total_content_length"] += content_length
|
| 259 |
sources[filename]["max_similarity"] = max(
|
| 260 |
sources[filename]["max_similarity"], similarity
|
| 261 |
)
|
| 262 |
-
|
| 263 |
total_content_length += content_length
|
| 264 |
|
| 265 |
# Calculate averages and percentages
|
|
@@ -272,5 +270,5 @@ class ContextManager:
|
|
| 272 |
"total_sources": len(sources),
|
| 273 |
"total_chunks": len(results),
|
| 274 |
"total_content_length": total_content_length,
|
| 275 |
-
"sources": sources
|
| 276 |
-
}
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
import logging
|
|
|
|
| 9 |
from dataclasses import dataclass
|
| 10 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 11 |
|
| 12 |
logger = logging.getLogger(__name__)
|
| 13 |
|
|
|
|
| 15 |
@dataclass
|
| 16 |
class ContextConfig:
|
| 17 |
"""Configuration for context management."""
|
| 18 |
+
|
| 19 |
max_context_length: int = 3000 # Maximum characters in context
|
| 20 |
max_results: int = 5 # Maximum search results to include
|
| 21 |
min_similarity: float = 0.1 # Minimum similarity threshold
|
|
|
|
| 25 |
class ContextManager:
|
| 26 |
"""
|
| 27 |
Manages context retrieval and optimization for RAG pipeline.
|
| 28 |
+
|
| 29 |
Handles:
|
| 30 |
- Context length management
|
| 31 |
- Relevance filtering
|
|
|
|
| 36 |
def __init__(self, config: Optional[ContextConfig] = None):
|
| 37 |
"""
|
| 38 |
Initialize ContextManager with configuration.
|
| 39 |
+
|
| 40 |
Args:
|
| 41 |
config: Context configuration, uses defaults if None
|
| 42 |
"""
|
|
|
|
| 44 |
logger.info("ContextManager initialized")
|
| 45 |
|
| 46 |
def prepare_context(
|
| 47 |
+
self, search_results: List[Dict[str, Any]], query: str
|
|
|
|
|
|
|
| 48 |
) -> Tuple[str, List[Dict[str, Any]]]:
|
| 49 |
"""
|
| 50 |
Prepare optimized context from search results.
|
| 51 |
+
|
| 52 |
Args:
|
| 53 |
search_results: Results from SearchService
|
| 54 |
query: Original user query for context optimization
|
| 55 |
+
|
| 56 |
Returns:
|
| 57 |
Tuple of (formatted_context, filtered_results)
|
| 58 |
"""
|
|
|
|
| 61 |
|
| 62 |
# Filter and rank results
|
| 63 |
filtered_results = self._filter_results(search_results)
|
| 64 |
+
|
| 65 |
# Remove duplicates and optimize for context window
|
| 66 |
optimized_results = self._optimize_context(filtered_results)
|
| 67 |
+
|
| 68 |
# Format for prompt
|
| 69 |
formatted_context = self._format_context(optimized_results)
|
| 70 |
+
|
| 71 |
logger.debug(
|
| 72 |
f"Prepared context from {len(search_results)} results, "
|
| 73 |
f"filtered to {len(optimized_results)} results, "
|
| 74 |
f"{len(formatted_context)} characters"
|
| 75 |
)
|
| 76 |
+
|
| 77 |
return formatted_context, optimized_results
|
| 78 |
|
| 79 |
def _filter_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 80 |
"""
|
| 81 |
Filter search results by relevance and quality.
|
| 82 |
+
|
| 83 |
Args:
|
| 84 |
results: Raw search results
|
| 85 |
+
|
| 86 |
Returns:
|
| 87 |
Filtered and sorted results
|
| 88 |
"""
|
| 89 |
filtered = []
|
| 90 |
+
|
| 91 |
for result in results:
|
| 92 |
similarity = result.get("similarity_score", 0.0)
|
| 93 |
content = result.get("content", "").strip()
|
| 94 |
+
|
| 95 |
# Apply filters
|
| 96 |
+
if (
|
| 97 |
+
similarity >= self.config.min_similarity
|
| 98 |
+
and content
|
| 99 |
+
and len(content) > 20
|
| 100 |
+
): # Minimum content length
|
| 101 |
filtered.append(result)
|
| 102 |
+
|
| 103 |
# Sort by similarity score (descending)
|
| 104 |
filtered.sort(key=lambda x: x.get("similarity_score", 0.0), reverse=True)
|
| 105 |
+
|
| 106 |
# Limit to max results
|
| 107 |
+
return filtered[: self.config.max_results]
|
| 108 |
|
| 109 |
def _optimize_context(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 110 |
"""
|
| 111 |
Optimize context to fit within token limits while maximizing relevance.
|
| 112 |
+
|
| 113 |
Args:
|
| 114 |
results: Filtered search results
|
| 115 |
+
|
| 116 |
Returns:
|
| 117 |
Optimized results list
|
| 118 |
"""
|
|
|
|
| 126 |
for result in results:
|
| 127 |
content = result.get("content", "").strip()
|
| 128 |
content_length = len(content)
|
| 129 |
+
|
| 130 |
# Check if adding this result would exceed limit
|
| 131 |
estimated_formatted_length = current_length + content_length + 100 # Buffer
|
| 132 |
if estimated_formatted_length > self.config.max_context_length:
|
|
|
|
| 138 |
result_copy["content"] = truncated_content
|
| 139 |
optimized.append(result_copy)
|
| 140 |
break
|
| 141 |
+
|
| 142 |
# Check for duplicate or highly similar content
|
| 143 |
content_lower = content.lower()
|
| 144 |
is_duplicate = False
|
| 145 |
+
|
| 146 |
for seen in seen_content:
|
| 147 |
# Simple similarity check for duplicates
|
| 148 |
+
if (
|
| 149 |
+
len(set(content_lower.split()) & set(seen.split()))
|
| 150 |
+
/ max(len(content_lower.split()), len(seen.split()))
|
| 151 |
+
> 0.8
|
| 152 |
+
):
|
| 153 |
is_duplicate = True
|
| 154 |
break
|
| 155 |
+
|
| 156 |
if not is_duplicate:
|
| 157 |
optimized.append(result)
|
| 158 |
seen_content.add(content_lower)
|
|
|
|
| 163 |
def _format_context(self, results: List[Dict[str, Any]]) -> str:
|
| 164 |
"""
|
| 165 |
Format optimized results into context string.
|
| 166 |
+
|
| 167 |
Args:
|
| 168 |
results: Optimized search results
|
| 169 |
+
|
| 170 |
Returns:
|
| 171 |
Formatted context string
|
| 172 |
"""
|
|
|
|
| 174 |
return "No relevant information found in corporate policies."
|
| 175 |
|
| 176 |
context_parts = []
|
| 177 |
+
|
| 178 |
for i, result in enumerate(results, 1):
|
| 179 |
metadata = result.get("metadata", {})
|
| 180 |
filename = metadata.get("filename", f"document_{i}")
|
| 181 |
content = result.get("content", "").strip()
|
| 182 |
+
|
| 183 |
# Format with document info
|
| 184 |
+
context_parts.append(f"Document: {filename}\n" f"Content: {content}")
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
return "\n\n---\n\n".join(context_parts)
|
| 187 |
|
| 188 |
def validate_context_quality(
|
| 189 |
+
self, context: str, query: str, min_quality_score: float = 0.3
|
|
|
|
|
|
|
|
|
|
| 190 |
) -> Dict[str, Any]:
|
| 191 |
"""
|
| 192 |
Validate the quality of prepared context for a given query.
|
| 193 |
+
|
| 194 |
Args:
|
| 195 |
context: Formatted context string
|
| 196 |
query: Original user query
|
| 197 |
min_quality_score: Minimum acceptable quality score
|
| 198 |
+
|
| 199 |
Returns:
|
| 200 |
Dictionary with quality metrics and validation result
|
| 201 |
"""
|
|
|
|
| 204 |
"length": len(context),
|
| 205 |
"has_content": bool(context.strip()),
|
| 206 |
"estimated_relevance": 0.0,
|
| 207 |
+
"passes_validation": False,
|
| 208 |
}
|
| 209 |
|
| 210 |
if not context.strip():
|
|
|
|
| 214 |
# Estimate relevance based on query term overlap
|
| 215 |
query_terms = set(query.lower().split())
|
| 216 |
context_terms = set(context.lower().split())
|
| 217 |
+
|
| 218 |
if query_terms and context_terms:
|
| 219 |
overlap = len(query_terms & context_terms)
|
| 220 |
relevance = overlap / len(query_terms)
|
|
|
|
| 228 |
def get_source_summary(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 229 |
"""
|
| 230 |
Generate summary of sources used in context.
|
| 231 |
+
|
| 232 |
Args:
|
| 233 |
results: Search results used for context
|
| 234 |
+
|
| 235 |
Returns:
|
| 236 |
Summary of sources and their contribution
|
| 237 |
"""
|
| 238 |
sources = {}
|
| 239 |
total_content_length = 0
|
| 240 |
+
|
| 241 |
for result in results:
|
| 242 |
metadata = result.get("metadata", {})
|
| 243 |
filename = metadata.get("filename", "unknown")
|
| 244 |
content_length = len(result.get("content", ""))
|
| 245 |
similarity = result.get("similarity_score", 0.0)
|
| 246 |
+
|
| 247 |
if filename not in sources:
|
| 248 |
sources[filename] = {
|
| 249 |
"chunks": 0,
|
| 250 |
"total_content_length": 0,
|
| 251 |
"max_similarity": 0.0,
|
| 252 |
+
"avg_similarity": 0.0,
|
| 253 |
}
|
| 254 |
+
|
| 255 |
sources[filename]["chunks"] += 1
|
| 256 |
sources[filename]["total_content_length"] += content_length
|
| 257 |
sources[filename]["max_similarity"] = max(
|
| 258 |
sources[filename]["max_similarity"], similarity
|
| 259 |
)
|
| 260 |
+
|
| 261 |
total_content_length += content_length
|
| 262 |
|
| 263 |
# Calculate averages and percentages
|
|
|
|
| 270 |
"total_sources": len(sources),
|
| 271 |
"total_chunks": len(results),
|
| 272 |
"total_content_length": total_content_length,
|
| 273 |
+
"sources": sources,
|
| 274 |
+
}
|
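
The src/llm/context_manager.py changes are likewise mechanical: isort reorders two imports and black re-wraps long conditions, including the word-overlap duplicate check. Read in isolation, that check amounts to the following helper (an illustration with the same 0.8 threshold, not code from the module):

def looks_like_duplicate(content: str, seen: str, threshold: float = 0.8) -> bool:
    """Return True when two chunks share most of their words (set overlap)."""
    a = set(content.lower().split())
    b = set(seen.lower().split())
    if not a or not b:  # guard added for the sketch; avoids dividing by zero
        return False
    return len(a & b) / max(len(a), len(b)) > threshold


# Example: identical wording (ignoring case) exceeds the 0.8 threshold.
print(looks_like_duplicate("Vacation policy details", "vacation policy details"))  # True
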
src/llm/llm_service.py
CHANGED
|
@@ -1,16 +1,18 @@
|
|
| 1 |
"""
|
| 2 |
LLM Service for RAG Application
|
| 3 |
|
| 4 |
-
This module provides integration with Large Language Models through multiple
|
| 5 |
-
including OpenRouter and Groq, with fallback capabilities and
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
import logging
|
| 9 |
import os
|
| 10 |
import time
|
| 11 |
-
from typing import Any, Dict, List, Optional, Union
|
| 12 |
-
import requests
|
| 13 |
from dataclasses import dataclass
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
logger = logging.getLogger(__name__)
|
| 16 |
|
|
@@ -18,6 +20,7 @@ logger = logging.getLogger(__name__)
|
|
| 18 |
@dataclass
|
| 19 |
class LLMConfig:
|
| 20 |
"""Configuration for LLM providers."""
|
|
|
|
| 21 |
provider: str # "openrouter" or "groq"
|
| 22 |
api_key: str
|
| 23 |
model_name: str
|
|
@@ -30,6 +33,7 @@ class LLMConfig:
|
|
| 30 |
@dataclass
|
| 31 |
class LLMResponse:
|
| 32 |
"""Standardized response from LLM providers."""
|
|
|
|
| 33 |
content: str
|
| 34 |
provider: str
|
| 35 |
model: str
|
|
@@ -42,7 +46,7 @@ class LLMResponse:
|
|
| 42 |
class LLMService:
|
| 43 |
"""
|
| 44 |
Service for interacting with Large Language Models.
|
| 45 |
-
|
| 46 |
Supports multiple providers with automatic fallback and retry logic.
|
| 47 |
Designed for corporate policy Q&A with appropriate guardrails.
|
| 48 |
"""
|
|
@@ -50,108 +54,112 @@ class LLMService:
|
|
| 50 |
def __init__(self, configs: List[LLMConfig]):
|
| 51 |
"""
|
| 52 |
Initialize LLMService with provider configurations.
|
| 53 |
-
|
| 54 |
Args:
|
| 55 |
configs: List of LLMConfig objects for different providers
|
| 56 |
-
|
| 57 |
Raises:
|
| 58 |
ValueError: If no valid configurations provided
|
| 59 |
"""
|
| 60 |
if not configs:
|
| 61 |
raise ValueError("At least one LLM configuration must be provided")
|
| 62 |
-
|
| 63 |
self.configs = configs
|
| 64 |
self.current_config_index = 0
|
| 65 |
logger.info(f"LLMService initialized with {len(configs)} provider(s)")
|
| 66 |
|
| 67 |
@classmethod
|
| 68 |
-
def from_environment(cls) ->
|
| 69 |
"""
|
| 70 |
Create LLMService instance from environment variables.
|
| 71 |
-
|
| 72 |
Expected environment variables:
|
| 73 |
- OPENROUTER_API_KEY: API key for OpenRouter
|
| 74 |
- GROQ_API_KEY: API key for Groq
|
| 75 |
-
|
| 76 |
Returns:
|
| 77 |
LLMService instance with available providers
|
| 78 |
-
|
| 79 |
Raises:
|
| 80 |
ValueError: If no API keys found in environment
|
| 81 |
"""
|
| 82 |
configs = []
|
| 83 |
-
|
| 84 |
# OpenRouter configuration
|
| 85 |
openrouter_key = os.getenv("OPENROUTER_API_KEY")
|
| 86 |
if openrouter_key:
|
| 87 |
-
configs.append(
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
| 97 |
groq_key = os.getenv("GROQ_API_KEY")
|
| 98 |
if groq_key:
|
| 99 |
-
configs.append(
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
|
|
|
|
|
|
| 108 |
if not configs:
|
| 109 |
raise ValueError(
|
| 110 |
"No LLM API keys found in environment. "
|
| 111 |
"Please set OPENROUTER_API_KEY or GROQ_API_KEY"
|
| 112 |
)
|
| 113 |
-
|
| 114 |
return cls(configs)
|
| 115 |
|
| 116 |
-
def generate_response(
|
| 117 |
-
self,
|
| 118 |
-
prompt: str,
|
| 119 |
-
max_retries: int = 2
|
| 120 |
-
) -> LLMResponse:
|
| 121 |
"""
|
| 122 |
Generate response from LLM with fallback support.
|
| 123 |
-
|
| 124 |
Args:
|
| 125 |
prompt: Input prompt for the LLM
|
| 126 |
max_retries: Maximum retry attempts per provider
|
| 127 |
-
|
| 128 |
Returns:
|
| 129 |
LLMResponse with generated content or error information
|
| 130 |
"""
|
| 131 |
last_error = None
|
| 132 |
-
|
| 133 |
# Try each provider configuration
|
| 134 |
for attempt in range(len(self.configs)):
|
| 135 |
config = self.configs[self.current_config_index]
|
| 136 |
-
|
| 137 |
try:
|
| 138 |
logger.debug(f"Attempting generation with {config.provider}")
|
| 139 |
response = self._call_provider(config, prompt, max_retries)
|
| 140 |
-
|
| 141 |
if response.success:
|
| 142 |
-
logger.info(
|
|
|
|
|
|
|
| 143 |
return response
|
| 144 |
-
|
| 145 |
last_error = response.error_message
|
| 146 |
logger.warning(f"Provider {config.provider} failed: {last_error}")
|
| 147 |
-
|
| 148 |
except Exception as e:
|
| 149 |
last_error = str(e)
|
| 150 |
logger.error(f"Error with provider {config.provider}: {last_error}")
|
| 151 |
-
|
| 152 |
# Move to next provider
|
| 153 |
-
self.current_config_index = (self.current_config_index + 1) % len(
|
| 154 |
-
|
|
|
|
|
|
|
| 155 |
# All providers failed
|
| 156 |
logger.error("All LLM providers failed")
|
| 157 |
return LLMResponse(
|
|
@@ -161,83 +169,79 @@ class LLMService:
|
|
| 161 |
usage={},
|
| 162 |
response_time=0.0,
|
| 163 |
success=False,
|
| 164 |
-
error_message=f"All providers failed. Last error: {last_error}"
|
| 165 |
)
|
| 166 |
|
| 167 |
def _call_provider(
|
| 168 |
-
self,
|
| 169 |
-
config: LLMConfig,
|
| 170 |
-
prompt: str,
|
| 171 |
-
max_retries: int
|
| 172 |
) -> LLMResponse:
|
| 173 |
"""
|
| 174 |
Make API call to specific provider with retry logic.
|
| 175 |
-
|
| 176 |
Args:
|
| 177 |
config: Provider configuration
|
| 178 |
prompt: Input prompt
|
| 179 |
max_retries: Maximum retry attempts
|
| 180 |
-
|
| 181 |
Returns:
|
| 182 |
LLMResponse from the provider
|
| 183 |
"""
|
| 184 |
start_time = time.time()
|
| 185 |
-
|
| 186 |
for attempt in range(max_retries + 1):
|
| 187 |
try:
|
| 188 |
headers = {
|
| 189 |
"Authorization": f"Bearer {config.api_key}",
|
| 190 |
-
"Content-Type": "application/json"
|
| 191 |
}
|
| 192 |
-
|
| 193 |
# Add provider-specific headers
|
| 194 |
if config.provider == "openrouter":
|
| 195 |
-
headers["HTTP-Referer"] =
|
|
|
|
|
|
|
| 196 |
headers["X-Title"] = "MSSE RAG Application"
|
| 197 |
-
|
| 198 |
payload = {
|
| 199 |
"model": config.model_name,
|
| 200 |
-
"messages": [
|
| 201 |
-
{
|
| 202 |
-
"role": "user",
|
| 203 |
-
"content": prompt
|
| 204 |
-
}
|
| 205 |
-
],
|
| 206 |
"max_tokens": config.max_tokens,
|
| 207 |
-
"temperature": config.temperature
|
| 208 |
}
|
| 209 |
-
|
| 210 |
response = requests.post(
|
| 211 |
f"{config.base_url}/chat/completions",
|
| 212 |
headers=headers,
|
| 213 |
json=payload,
|
| 214 |
-
timeout=config.timeout
|
| 215 |
)
|
| 216 |
-
|
| 217 |
response.raise_for_status()
|
| 218 |
data = response.json()
|
| 219 |
-
|
| 220 |
# Extract response content
|
| 221 |
content = data["choices"][0]["message"]["content"]
|
| 222 |
usage = data.get("usage", {})
|
| 223 |
-
|
| 224 |
response_time = time.time() - start_time
|
| 225 |
-
|
| 226 |
return LLMResponse(
|
| 227 |
content=content,
|
| 228 |
provider=config.provider,
|
| 229 |
model=config.model_name,
|
| 230 |
usage=usage,
|
| 231 |
response_time=response_time,
|
| 232 |
-
success=True
|
| 233 |
)
|
| 234 |
-
|
| 235 |
except requests.exceptions.RequestException as e:
|
| 236 |
-
logger.warning(
|
|
|
|
|
|
|
| 237 |
if attempt < max_retries:
|
| 238 |
-
time.sleep(2
|
| 239 |
continue
|
| 240 |
-
|
| 241 |
return LLMResponse(
|
| 242 |
content="",
|
| 243 |
provider=config.provider,
|
|
@@ -245,9 +249,9 @@ class LLMService:
|
|
| 245 |
usage={},
|
| 246 |
response_time=time.time() - start_time,
|
| 247 |
success=False,
|
| 248 |
-
error_message=str(e)
|
| 249 |
)
|
| 250 |
-
|
| 251 |
except Exception as e:
|
| 252 |
logger.error(f"Unexpected error with {config.provider}: {e}")
|
| 253 |
return LLMResponse(
|
|
@@ -257,44 +261,44 @@ class LLMService:
|
|
| 257 |
usage={},
|
| 258 |
response_time=time.time() - start_time,
|
| 259 |
success=False,
|
| 260 |
-
error_message=str(e)
|
| 261 |
)
|
| 262 |
|
| 263 |
def health_check(self) -> Dict[str, Any]:
|
| 264 |
"""
|
| 265 |
Check health status of all configured providers.
|
| 266 |
-
|
| 267 |
Returns:
|
| 268 |
Dictionary with provider health status
|
| 269 |
"""
|
| 270 |
health_status = {}
|
| 271 |
-
|
| 272 |
for config in self.configs:
|
| 273 |
try:
|
| 274 |
# Simple test prompt
|
| 275 |
test_response = self._call_provider(
|
| 276 |
-
config,
|
| 277 |
-
"Hello, this is a test. Please respond with 'OK'.",
|
| 278 |
-
max_retries=1
|
| 279 |
)
|
| 280 |
-
|
| 281 |
health_status[config.provider] = {
|
| 282 |
"status": "healthy" if test_response.success else "unhealthy",
|
| 283 |
"model": config.model_name,
|
| 284 |
"response_time": test_response.response_time,
|
| 285 |
-
"error": test_response.error_message
|
| 286 |
}
|
| 287 |
-
|
| 288 |
except Exception as e:
|
| 289 |
health_status[config.provider] = {
|
| 290 |
"status": "unhealthy",
|
| 291 |
"model": config.model_name,
|
| 292 |
"response_time": 0.0,
|
| 293 |
-
"error": str(e)
|
| 294 |
}
|
| 295 |
-
|
| 296 |
return health_status
|
| 297 |
|
| 298 |
def get_available_providers(self) -> List[str]:
|
| 299 |
"""Get list of available provider names."""
|
| 300 |
-
return [config.provider for config in self.configs]
|
|
|
|
| 1 |
"""
|
| 2 |
LLM Service for RAG Application
|
| 3 |
|
| 4 |
+
This module provides integration with Large Language Models through multiple
|
| 5 |
+
providers including OpenRouter and Groq, with fallback capabilities and
|
| 6 |
+
comprehensive error handling.
|
| 7 |
"""
|
| 8 |
|
| 9 |
import logging
|
| 10 |
import os
|
| 11 |
import time
|
|
|
|
|
|
|
| 12 |
from dataclasses import dataclass
|
| 13 |
+
from typing import Any, Dict, List, Optional
|
| 14 |
+
|
| 15 |
+
import requests
|
| 16 |
|
| 17 |
logger = logging.getLogger(__name__)
|
| 18 |
|
|
|
|
| 20 |
@dataclass
|
| 21 |
class LLMConfig:
|
| 22 |
"""Configuration for LLM providers."""
|
| 23 |
+
|
| 24 |
provider: str # "openrouter" or "groq"
|
| 25 |
api_key: str
|
| 26 |
model_name: str
|
|
|
|
| 33 |
@dataclass
|
| 34 |
class LLMResponse:
|
| 35 |
"""Standardized response from LLM providers."""
|
| 36 |
+
|
| 37 |
content: str
|
| 38 |
provider: str
|
| 39 |
model: str
|
|
|
|
| 46 |
class LLMService:
|
| 47 |
"""
|
| 48 |
Service for interacting with Large Language Models.
|
| 49 |
+
|
| 50 |
Supports multiple providers with automatic fallback and retry logic.
|
| 51 |
Designed for corporate policy Q&A with appropriate guardrails.
|
| 52 |
"""
|
|
|
|
| 54 |
def __init__(self, configs: List[LLMConfig]):
|
| 55 |
"""
|
| 56 |
Initialize LLMService with provider configurations.
|
| 57 |
+
|
| 58 |
Args:
|
| 59 |
configs: List of LLMConfig objects for different providers
|
| 60 |
+
|
| 61 |
Raises:
|
| 62 |
ValueError: If no valid configurations provided
|
| 63 |
"""
|
| 64 |
if not configs:
|
| 65 |
raise ValueError("At least one LLM configuration must be provided")
|
| 66 |
+
|
| 67 |
self.configs = configs
|
| 68 |
self.current_config_index = 0
|
| 69 |
logger.info(f"LLMService initialized with {len(configs)} provider(s)")
|
| 70 |
|
| 71 |
@classmethod
|
| 72 |
+
def from_environment(cls) -> "LLMService":
|
| 73 |
"""
|
| 74 |
Create LLMService instance from environment variables.
|
| 75 |
+
|
| 76 |
Expected environment variables:
|
| 77 |
- OPENROUTER_API_KEY: API key for OpenRouter
|
| 78 |
- GROQ_API_KEY: API key for Groq
|
| 79 |
+
|
| 80 |
Returns:
|
| 81 |
LLMService instance with available providers
|
| 82 |
+
|
| 83 |
Raises:
|
| 84 |
ValueError: If no API keys found in environment
|
| 85 |
"""
|
| 86 |
configs = []
|
| 87 |
+
|
| 88 |
# OpenRouter configuration
|
| 89 |
openrouter_key = os.getenv("OPENROUTER_API_KEY")
|
| 90 |
if openrouter_key:
|
| 91 |
+
configs.append(
|
| 92 |
+
LLMConfig(
|
| 93 |
+
provider="openrouter",
|
| 94 |
+
api_key=openrouter_key,
|
| 95 |
+
model_name="microsoft/wizardlm-2-8x22b", # Free tier model
|
| 96 |
+
base_url="https://openrouter.ai/api/v1",
|
| 97 |
+
max_tokens=1000,
|
| 98 |
+
temperature=0.1,
|
| 99 |
+
)
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# Groq configuration
|
| 103 |
groq_key = os.getenv("GROQ_API_KEY")
|
| 104 |
if groq_key:
|
| 105 |
+
configs.append(
|
| 106 |
+
LLMConfig(
|
| 107 |
+
provider="groq",
|
| 108 |
+
api_key=groq_key,
|
| 109 |
+
model_name="llama3-8b-8192", # Free tier model
|
| 110 |
+
base_url="https://api.groq.com/openai/v1",
|
| 111 |
+
max_tokens=1000,
|
| 112 |
+
temperature=0.1,
|
| 113 |
+
)
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
if not configs:
|
| 117 |
raise ValueError(
|
| 118 |
"No LLM API keys found in environment. "
|
| 119 |
"Please set OPENROUTER_API_KEY or GROQ_API_KEY"
|
| 120 |
)
|
| 121 |
+
|
| 122 |
return cls(configs)
|
| 123 |
|
| 124 |
+
def generate_response(self, prompt: str, max_retries: int = 2) -> LLMResponse:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
"""
|
| 126 |
Generate response from LLM with fallback support.
|
| 127 |
+
|
| 128 |
Args:
|
| 129 |
prompt: Input prompt for the LLM
|
| 130 |
max_retries: Maximum retry attempts per provider
|
| 131 |
+
|
| 132 |
Returns:
|
| 133 |
LLMResponse with generated content or error information
|
| 134 |
"""
|
| 135 |
last_error = None
|
| 136 |
+
|
| 137 |
# Try each provider configuration
|
| 138 |
for attempt in range(len(self.configs)):
|
| 139 |
config = self.configs[self.current_config_index]
|
| 140 |
+
|
| 141 |
try:
|
| 142 |
logger.debug(f"Attempting generation with {config.provider}")
|
| 143 |
response = self._call_provider(config, prompt, max_retries)
|
| 144 |
+
|
| 145 |
if response.success:
|
| 146 |
+
logger.info(
|
| 147 |
+
f"Successfully generated response using {config.provider}"
|
| 148 |
+
)
|
| 149 |
return response
|
| 150 |
+
|
| 151 |
last_error = response.error_message
|
| 152 |
logger.warning(f"Provider {config.provider} failed: {last_error}")
|
| 153 |
+
|
| 154 |
except Exception as e:
|
| 155 |
last_error = str(e)
|
| 156 |
logger.error(f"Error with provider {config.provider}: {last_error}")
|
| 157 |
+
|
| 158 |
# Move to next provider
|
| 159 |
+
self.current_config_index = (self.current_config_index + 1) % len(
|
| 160 |
+
self.configs
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
# All providers failed
|
| 164 |
logger.error("All LLM providers failed")
|
| 165 |
return LLMResponse(
|
|
|
|
| 169 |
usage={},
|
| 170 |
response_time=0.0,
|
| 171 |
success=False,
|
| 172 |
+
error_message=f"All providers failed. Last error: {last_error}",
|
| 173 |
)
|
| 174 |
|
| 175 |
def _call_provider(
|
| 176 |
+
self, config: LLMConfig, prompt: str, max_retries: int
|
|
|
|
|
|
|
|
|
|
| 177 |
) -> LLMResponse:
|
| 178 |
"""
|
| 179 |
Make API call to specific provider with retry logic.
|
| 180 |
+
|
| 181 |
Args:
|
| 182 |
config: Provider configuration
|
| 183 |
prompt: Input prompt
|
| 184 |
max_retries: Maximum retry attempts
|
| 185 |
+
|
| 186 |
Returns:
|
| 187 |
LLMResponse from the provider
|
| 188 |
"""
|
| 189 |
start_time = time.time()
|
| 190 |
+
|
| 191 |
for attempt in range(max_retries + 1):
|
| 192 |
try:
|
| 193 |
headers = {
|
| 194 |
"Authorization": f"Bearer {config.api_key}",
|
| 195 |
+
"Content-Type": "application/json",
|
| 196 |
}
|
| 197 |
+
|
| 198 |
# Add provider-specific headers
|
| 199 |
if config.provider == "openrouter":
|
| 200 |
+
headers["HTTP-Referer"] = (
|
| 201 |
+
"https://github.com/sethmcknight/msse-ai-engineering"
|
| 202 |
+
)
|
| 203 |
headers["X-Title"] = "MSSE RAG Application"
|
| 204 |
+
|
| 205 |
payload = {
|
| 206 |
"model": config.model_name,
|
| 207 |
+
"messages": [{"role": "user", "content": prompt}],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
"max_tokens": config.max_tokens,
|
| 209 |
+
"temperature": config.temperature,
|
| 210 |
}
|
| 211 |
+
|
| 212 |
response = requests.post(
|
| 213 |
f"{config.base_url}/chat/completions",
|
| 214 |
headers=headers,
|
| 215 |
json=payload,
|
| 216 |
+
timeout=config.timeout,
|
| 217 |
)
|
| 218 |
+
|
| 219 |
response.raise_for_status()
|
| 220 |
data = response.json()
|
| 221 |
+
|
| 222 |
# Extract response content
|
| 223 |
content = data["choices"][0]["message"]["content"]
|
| 224 |
usage = data.get("usage", {})
|
| 225 |
+
|
| 226 |
response_time = time.time() - start_time
|
| 227 |
+
|
| 228 |
return LLMResponse(
|
| 229 |
content=content,
|
| 230 |
provider=config.provider,
|
| 231 |
model=config.model_name,
|
| 232 |
usage=usage,
|
| 233 |
response_time=response_time,
|
| 234 |
+
success=True,
|
| 235 |
)
|
| 236 |
+
|
| 237 |
except requests.exceptions.RequestException as e:
|
| 238 |
+
logger.warning(
|
| 239 |
+
f"Request failed for {config.provider} (attempt {attempt + 1}): {e}"
|
| 240 |
+
)
|
| 241 |
if attempt < max_retries:
|
| 242 |
+
time.sleep(2**attempt) # Exponential backoff
|
| 243 |
continue
|
| 244 |
+
|
| 245 |
return LLMResponse(
|
| 246 |
content="",
|
| 247 |
provider=config.provider,
|
|
|
|
| 249 |
usage={},
|
| 250 |
response_time=time.time() - start_time,
|
| 251 |
success=False,
|
| 252 |
+
error_message=str(e),
|
| 253 |
)
|
| 254 |
+
|
| 255 |
except Exception as e:
|
| 256 |
logger.error(f"Unexpected error with {config.provider}: {e}")
|
| 257 |
return LLMResponse(
|
|
|
|
| 261 |
usage={},
|
| 262 |
response_time=time.time() - start_time,
|
| 263 |
success=False,
|
| 264 |
+
error_message=str(e),
|
| 265 |
)
|
| 266 |
|
| 267 |
def health_check(self) -> Dict[str, Any]:
|
| 268 |
"""
|
| 269 |
Check health status of all configured providers.
|
| 270 |
+
|
| 271 |
Returns:
|
| 272 |
Dictionary with provider health status
|
| 273 |
"""
|
| 274 |
health_status = {}
|
| 275 |
+
|
| 276 |
for config in self.configs:
|
| 277 |
try:
|
| 278 |
# Simple test prompt
|
| 279 |
test_response = self._call_provider(
|
| 280 |
+
config,
|
| 281 |
+
"Hello, this is a test. Please respond with 'OK'.",
|
| 282 |
+
max_retries=1,
|
| 283 |
)
|
| 284 |
+
|
| 285 |
health_status[config.provider] = {
|
| 286 |
"status": "healthy" if test_response.success else "unhealthy",
|
| 287 |
"model": config.model_name,
|
| 288 |
"response_time": test_response.response_time,
|
| 289 |
+
"error": test_response.error_message,
|
| 290 |
}
|
| 291 |
+
|
| 292 |
except Exception as e:
|
| 293 |
health_status[config.provider] = {
|
| 294 |
"status": "unhealthy",
|
| 295 |
"model": config.model_name,
|
| 296 |
"response_time": 0.0,
|
| 297 |
+
"error": str(e),
|
| 298 |
}
|
| 299 |
+
|
| 300 |
return health_status
|
| 301 |
|
| 302 |
def get_available_providers(self) -> List[str]:
|
| 303 |
"""Get list of available provider names."""
|
| 304 |
+
return [config.provider for config in self.configs]
|
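
The src/llm/llm_service.py hunk is again formatting-only, but it touches the retry and fallback path: failed requests back off with time.sleep(2 ** attempt) before the service rotates to the next provider. A minimal sketch of that backoff loop on its own (illustrative, not the service implementation):

import time


def call_with_backoff(call, max_retries: int = 2):
    """Retry a callable with exponential backoff (1s, 2s, 4s, ...)."""
    for attempt in range(max_retries + 1):
        try:
            return call()
        except Exception:
            if attempt < max_retries:
                time.sleep(2 ** attempt)  # exponential backoff, as in the diff
                continue
            raise


# Example (hypothetical): call_with_backoff(lambda: requests.post(url, json=payload))
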
src/llm/prompt_templates.py
CHANGED
|
@@ -1,17 +1,18 @@
|
|
| 1 |
"""
|
| 2 |
Prompt Templates for Corporate Policy Q&A
|
| 3 |
|
| 4 |
-
This module contains predefined prompt templates optimized for
|
| 5 |
corporate policy question-answering with proper citation requirements.
|
| 6 |
"""
|
| 7 |
|
| 8 |
-
from typing import Dict, List
|
| 9 |
from dataclasses import dataclass
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
@dataclass
|
| 13 |
class PromptTemplate:
|
| 14 |
"""Template for generating prompts with context and citations."""
|
|
|
|
| 15 |
system_prompt: str
|
| 16 |
user_template: str
|
| 17 |
citation_format: str
|
|
@@ -20,7 +21,7 @@ class PromptTemplate:
|
|
| 20 |
class PromptTemplates:
|
| 21 |
"""
|
| 22 |
Collection of prompt templates for different types of policy questions.
|
| 23 |
-
|
| 24 |
Templates are designed to ensure:
|
| 25 |
- Accurate responses based on provided context
|
| 26 |
- Proper citation of source documents
|
|
@@ -29,15 +30,15 @@ class PromptTemplates:
|
|
| 29 |
"""
|
| 30 |
|
| 31 |
# System prompt for corporate policy assistant
|
| 32 |
-
SYSTEM_PROMPT = """You are a helpful corporate policy assistant. Your job is to answer questions about company policies based ONLY on the provided context documents.
|
| 33 |
|
| 34 |
IMPORTANT GUIDELINES:
|
| 35 |
1. Answer questions using ONLY the information provided in the context
|
| 36 |
-
2. If the context doesn't contain enough information to answer the question, say so explicitly
|
| 37 |
3. Always cite your sources using the format: [Source: filename.md]
|
| 38 |
4. Be accurate, concise, and professional
|
| 39 |
-
5. If asked about topics not covered in the policies, politely redirect to HR or appropriate department
|
| 40 |
-
6. Do not make assumptions or provide information not explicitly stated in the context
|
| 41 |
|
| 42 |
Your responses should be helpful while staying strictly within the scope of the provided corporate policies."""
|
| 43 |
|
|
@@ -45,26 +46,26 @@ Your responses should be helpful while staying strictly within the scope of the
|
|
| 45 |
def get_policy_qa_template(cls) -> PromptTemplate:
|
| 46 |
"""
|
| 47 |
Get the standard template for policy question-answering.
|
| 48 |
-
|
| 49 |
Returns:
|
| 50 |
PromptTemplate configured for corporate policy Q&A
|
| 51 |
"""
|
| 52 |
return PromptTemplate(
|
| 53 |
system_prompt=cls.SYSTEM_PROMPT,
|
| 54 |
-
user_template="""Based on the following corporate policy documents, please answer this question: {question}
|
| 55 |
|
| 56 |
CONTEXT DOCUMENTS:
|
| 57 |
{context}
|
| 58 |
|
| 59 |
-
Please provide a clear, accurate answer based on the information above. Include citations for all information using the format [Source: filename.md].""",
|
| 60 |
-
citation_format="[Source: {filename}]"
|
| 61 |
)
|
| 62 |
|
| 63 |
@classmethod
|
| 64 |
def get_clarification_template(cls) -> PromptTemplate:
|
| 65 |
"""
|
| 66 |
Get template for when clarification is needed.
|
| 67 |
-
|
| 68 |
Returns:
|
| 69 |
PromptTemplate for clarification requests
|
| 70 |
"""
|
|
@@ -75,19 +76,19 @@ Please provide a clear, accurate answer based on the information above. Include
|
|
| 75 |
CONTEXT DOCUMENTS:
|
| 76 |
{context}
|
| 77 |
|
| 78 |
-
The provided context documents don't contain sufficient information to fully answer this question. Please provide a helpful response that:
|
| 79 |
1. Acknowledges what information is available (if any)
|
| 80 |
2. Clearly states what information is missing
|
| 81 |
3. Suggests appropriate next steps (contact HR, check other resources, etc.)
|
| 82 |
4. Cites any relevant sources using [Source: filename.md] format""",
|
| 83 |
-
citation_format="[Source: {filename}]"
|
| 84 |
)
|
| 85 |
|
| 86 |
@classmethod
|
| 87 |
def get_off_topic_template(cls) -> PromptTemplate:
|
| 88 |
"""
|
| 89 |
Get template for off-topic questions.
|
| 90 |
-
|
| 91 |
Returns:
|
| 92 |
PromptTemplate for redirecting off-topic questions
|
| 93 |
"""
|
|
@@ -95,122 +96,122 @@ The provided context documents don't contain sufficient information to fully ans
|
|
| 95 |
system_prompt=cls.SYSTEM_PROMPT,
|
| 96 |
user_template="""The user asked: {question}
|
| 97 |
|
| 98 |
-
This question appears to be outside the scope of our corporate policies. Please provide a polite response that:
|
| 99 |
1. Acknowledges the question
|
| 100 |
2. Explains that this falls outside corporate policy documentation
|
| 101 |
3. Suggests appropriate resources (HR, IT, management, etc.)
|
| 102 |
4. Offers to help with any policy-related questions instead""",
|
| 103 |
-
citation_format=""
|
| 104 |
)
|
| 105 |
|
| 106 |
@staticmethod
|
| 107 |
def format_context(search_results: List[Dict]) -> str:
|
| 108 |
"""
|
| 109 |
Format search results into context for the prompt.
|
| 110 |
-
|
| 111 |
Args:
|
| 112 |
search_results: List of search results from SearchService
|
| 113 |
-
|
| 114 |
Returns:
|
| 115 |
Formatted context string for the prompt
|
| 116 |
"""
|
| 117 |
if not search_results:
|
| 118 |
return "No relevant policy documents found."
|
| 119 |
-
|
| 120 |
context_parts = []
|
| 121 |
for i, result in enumerate(search_results[:5], 1): # Limit to top 5 results
|
| 122 |
filename = result.get("metadata", {}).get("filename", "unknown")
|
| 123 |
content = result.get("content", "").strip()
|
| 124 |
similarity = result.get("similarity_score", 0.0)
|
| 125 |
-
|
| 126 |
context_parts.append(
|
| 127 |
f"Document {i}: {filename} (relevance: {similarity:.2f})\n"
|
| 128 |
f"Content: {content}\n"
|
| 129 |
)
|
| 130 |
-
|
| 131 |
return "\n---\n".join(context_parts)
|
| 132 |
|
| 133 |
@staticmethod
|
| 134 |
def extract_citations(response: str) -> List[str]:
|
| 135 |
"""
|
| 136 |
Extract citations from LLM response.
|
| 137 |
-
|
| 138 |
Args:
|
| 139 |
response: Generated response text
|
| 140 |
-
|
| 141 |
Returns:
|
| 142 |
List of extracted filenames from citations
|
| 143 |
"""
|
| 144 |
import re
|
| 145 |
-
|
| 146 |
# Pattern to match [Source: filename.md] format
|
| 147 |
-
citation_pattern = r
|
| 148 |
matches = re.findall(citation_pattern, response)
|
| 149 |
-
|
| 150 |
# Clean up filenames
|
| 151 |
citations = []
|
| 152 |
for match in matches:
|
| 153 |
filename = match.strip()
|
| 154 |
if filename and filename not in citations:
|
| 155 |
citations.append(filename)
|
| 156 |
-
|
| 157 |
return citations
|
| 158 |
|
| 159 |
@staticmethod
|
| 160 |
-
def validate_citations(
|
|
|
|
|
|
|
| 161 |
"""
|
| 162 |
Validate that all citations in response refer to available sources.
|
| 163 |
-
|
| 164 |
Args:
|
| 165 |
response: Generated response text
|
| 166 |
available_sources: List of available source filenames
|
| 167 |
-
|
| 168 |
Returns:
|
| 169 |
Dictionary mapping citations to their validity
|
| 170 |
"""
|
| 171 |
citations = PromptTemplates.extract_citations(response)
|
| 172 |
validation = {}
|
| 173 |
-
|
| 174 |
for citation in citations:
|
| 175 |
# Check if citation matches any available source
|
| 176 |
-
valid = any(
|
| 177 |
-
|
|
|
|
| 178 |
validation[citation] = valid
|
| 179 |
-
|
| 180 |
return validation
|
| 181 |
|
| 182 |
@staticmethod
|
| 183 |
-
def add_fallback_citations(
|
| 184 |
-
response: str,
|
| 185 |
-
search_results: List[Dict]
|
| 186 |
-
) -> str:
|
| 187 |
"""
|
| 188 |
Add citations to response if none were provided by LLM.
|
| 189 |
-
|
| 190 |
Args:
|
| 191 |
response: Generated response text
|
| 192 |
search_results: Original search results used for context
|
| 193 |
-
|
| 194 |
Returns:
|
| 195 |
Response with added citations if needed
|
| 196 |
"""
|
| 197 |
existing_citations = PromptTemplates.extract_citations(response)
|
| 198 |
-
|
| 199 |
if existing_citations:
|
| 200 |
return response # Already has citations
|
| 201 |
-
|
| 202 |
if not search_results:
|
| 203 |
return response # No sources to cite
|
| 204 |
-
|
| 205 |
# Add citations from top search results
|
| 206 |
top_sources = []
|
| 207 |
for result in search_results[:3]: # Top 3 sources
|
| 208 |
filename = result.get("metadata", {}).get("filename", "")
|
| 209 |
if filename and filename not in top_sources:
|
| 210 |
top_sources.append(filename)
|
| 211 |
-
|
| 212 |
if top_sources:
|
| 213 |
citation_text = " [Sources: " + ", ".join(top_sources) + "]"
|
| 214 |
return response + citation_text
|
| 215 |
-
|
| 216 |
-
return response
|
|
|
|
| 1 |
"""
|
| 2 |
Prompt Templates for Corporate Policy Q&A
|
| 3 |
|
| 4 |
+
This module contains predefined prompt templates optimized for
|
| 5 |
corporate policy question-answering with proper citation requirements.
|
| 6 |
"""
|
| 7 |
|
|
|
|
| 8 |
from dataclasses import dataclass
|
| 9 |
+
from typing import Dict, List
|
| 10 |
|
| 11 |
|
| 12 |
@dataclass
|
| 13 |
class PromptTemplate:
|
| 14 |
"""Template for generating prompts with context and citations."""
|
| 15 |
+
|
| 16 |
system_prompt: str
|
| 17 |
user_template: str
|
| 18 |
citation_format: str
|
|
|
|
| 21 |
class PromptTemplates:
|
| 22 |
"""
|
| 23 |
Collection of prompt templates for different types of policy questions.
|
| 24 |
+
|
| 25 |
Templates are designed to ensure:
|
| 26 |
- Accurate responses based on provided context
|
| 27 |
- Proper citation of source documents
|
|
|
|
| 30 |
"""
|
| 31 |
|
| 32 |
# System prompt for corporate policy assistant
|
| 33 |
+
SYSTEM_PROMPT = """You are a helpful corporate policy assistant. Your job is to answer questions about company policies based ONLY on the provided context documents. # noqa: E501
|
| 34 |
|
| 35 |
IMPORTANT GUIDELINES:
|
| 36 |
1. Answer questions using ONLY the information provided in the context
|
| 37 |
+
2. If the context doesn't contain enough information to answer the question, say so explicitly # noqa: E501
|
| 38 |
3. Always cite your sources using the format: [Source: filename.md]
|
| 39 |
4. Be accurate, concise, and professional
|
| 40 |
+
5. If asked about topics not covered in the policies, politely redirect to HR or appropriate department # noqa: E501
|
| 41 |
+
6. Do not make assumptions or provide information not explicitly stated in the context # noqa: E501
|
| 42 |
|
| 43 |
Your responses should be helpful while staying strictly within the scope of the provided corporate policies."""
|
| 44 |
|
|
|
|
| 46 |
def get_policy_qa_template(cls) -> PromptTemplate:
|
| 47 |
"""
|
| 48 |
Get the standard template for policy question-answering.
|
| 49 |
+
|
| 50 |
Returns:
|
| 51 |
PromptTemplate configured for corporate policy Q&A
|
| 52 |
"""
|
| 53 |
return PromptTemplate(
|
| 54 |
system_prompt=cls.SYSTEM_PROMPT,
|
| 55 |
+
user_template="""Based on the following corporate policy documents, please answer this question: {question} # noqa: E501
|
| 56 |
|
| 57 |
CONTEXT DOCUMENTS:
|
| 58 |
{context}
|
| 59 |
|
| 60 |
+
Please provide a clear, accurate answer based on the information above. Include citations for all information using the format [Source: filename.md].""", # noqa: E501
|
| 61 |
+
citation_format="[Source: {filename}]",
|
| 62 |
)
|
| 63 |
|
| 64 |
@classmethod
|
| 65 |
def get_clarification_template(cls) -> PromptTemplate:
|
| 66 |
"""
|
| 67 |
Get template for when clarification is needed.
|
| 68 |
+
|
| 69 |
Returns:
|
| 70 |
PromptTemplate for clarification requests
|
| 71 |
"""
|
|
|
|
| 76 |
CONTEXT DOCUMENTS:
|
| 77 |
{context}
|
| 78 |
|
| 79 |
+
The provided context documents don't contain sufficient information to fully answer this question. Please provide a helpful response that: # noqa: E501
|
| 80 |
1. Acknowledges what information is available (if any)
|
| 81 |
2. Clearly states what information is missing
|
| 82 |
3. Suggests appropriate next steps (contact HR, check other resources, etc.)
|
| 83 |
4. Cites any relevant sources using [Source: filename.md] format""",
|
| 84 |
+
citation_format="[Source: {filename}]",
|
| 85 |
)
|
| 86 |
|
| 87 |
@classmethod
|
| 88 |
def get_off_topic_template(cls) -> PromptTemplate:
|
| 89 |
"""
|
| 90 |
Get template for off-topic questions.
|
| 91 |
+
|
| 92 |
Returns:
|
| 93 |
PromptTemplate for redirecting off-topic questions
|
| 94 |
"""
|
|
|
|
| 96 |
system_prompt=cls.SYSTEM_PROMPT,
|
| 97 |
user_template="""The user asked: {question}
|
| 98 |
|
| 99 |
+
This question appears to be outside the scope of our corporate policies. Please provide a polite response that: # noqa: E501
|
| 100 |
1. Acknowledges the question
|
| 101 |
2. Explains that this falls outside corporate policy documentation
|
| 102 |
3. Suggests appropriate resources (HR, IT, management, etc.)
|
| 103 |
4. Offers to help with any policy-related questions instead""",
|
| 104 |
+
citation_format="",
|
| 105 |
)
|
| 106 |
|
| 107 |
@staticmethod
|
| 108 |
def format_context(search_results: List[Dict]) -> str:
|
| 109 |
"""
|
| 110 |
Format search results into context for the prompt.
|
| 111 |
+
|
| 112 |
Args:
|
| 113 |
search_results: List of search results from SearchService
|
| 114 |
+
|
| 115 |
Returns:
|
| 116 |
Formatted context string for the prompt
|
| 117 |
"""
|
| 118 |
if not search_results:
|
| 119 |
return "No relevant policy documents found."
|
| 120 |
+
|
| 121 |
context_parts = []
|
| 122 |
for i, result in enumerate(search_results[:5], 1): # Limit to top 5 results
|
| 123 |
filename = result.get("metadata", {}).get("filename", "unknown")
|
| 124 |
content = result.get("content", "").strip()
|
| 125 |
similarity = result.get("similarity_score", 0.0)
|
| 126 |
+
|
| 127 |
context_parts.append(
|
| 128 |
f"Document {i}: {filename} (relevance: {similarity:.2f})\n"
|
| 129 |
f"Content: {content}\n"
|
| 130 |
)
|
| 131 |
+
|
| 132 |
return "\n---\n".join(context_parts)
|
| 133 |
|
| 134 |
@staticmethod
|
| 135 |
def extract_citations(response: str) -> List[str]:
|
| 136 |
"""
|
| 137 |
Extract citations from LLM response.
|
| 138 |
+
|
| 139 |
Args:
|
| 140 |
response: Generated response text
|
| 141 |
+
|
| 142 |
Returns:
|
| 143 |
List of extracted filenames from citations
|
| 144 |
"""
|
| 145 |
import re
|
| 146 |
+
|
| 147 |
# Pattern to match [Source: filename.md] format
|
| 148 |
+
citation_pattern = r"\[Source:\s*([^\]]+)\]"
|
| 149 |
matches = re.findall(citation_pattern, response)
|
| 150 |
+
|
| 151 |
# Clean up filenames
|
| 152 |
citations = []
|
| 153 |
for match in matches:
|
| 154 |
filename = match.strip()
|
| 155 |
if filename and filename not in citations:
|
| 156 |
citations.append(filename)
|
| 157 |
+
|
| 158 |
return citations
|
| 159 |
|
| 160 |
@staticmethod
|
| 161 |
+
def validate_citations(
|
| 162 |
+
response: str, available_sources: List[str]
|
| 163 |
+
) -> Dict[str, bool]:
|
| 164 |
"""
|
| 165 |
Validate that all citations in response refer to available sources.
|
| 166 |
+
|
| 167 |
Args:
|
| 168 |
response: Generated response text
|
| 169 |
available_sources: List of available source filenames
|
| 170 |
+
|
| 171 |
Returns:
|
| 172 |
Dictionary mapping citations to their validity
|
| 173 |
"""
|
| 174 |
citations = PromptTemplates.extract_citations(response)
|
| 175 |
validation = {}
|
| 176 |
+
|
| 177 |
for citation in citations:
|
| 178 |
# Check if citation matches any available source
|
| 179 |
+
valid = any(
|
| 180 |
+
citation in source or source in citation for source in available_sources
|
| 181 |
+
)
|
| 182 |
validation[citation] = valid
|
| 183 |
+
|
| 184 |
return validation
|
| 185 |
|
| 186 |
@staticmethod
|
| 187 |
+
def add_fallback_citations(response: str, search_results: List[Dict]) -> str:
|
|
|
| 188 |
"""
|
| 189 |
Add citations to response if none were provided by LLM.
|
| 190 |
+
|
| 191 |
Args:
|
| 192 |
response: Generated response text
|
| 193 |
search_results: Original search results used for context
|
| 194 |
+
|
| 195 |
Returns:
|
| 196 |
Response with added citations if needed
|
| 197 |
"""
|
| 198 |
existing_citations = PromptTemplates.extract_citations(response)
|
| 199 |
+
|
| 200 |
if existing_citations:
|
| 201 |
return response # Already has citations
|
| 202 |
+
|
| 203 |
if not search_results:
|
| 204 |
return response # No sources to cite
|
| 205 |
+
|
| 206 |
# Add citations from top search results
|
| 207 |
top_sources = []
|
| 208 |
for result in search_results[:3]: # Top 3 sources
|
| 209 |
filename = result.get("metadata", {}).get("filename", "")
|
| 210 |
if filename and filename not in top_sources:
|
| 211 |
top_sources.append(filename)
|
| 212 |
+
|
| 213 |
if top_sources:
|
| 214 |
citation_text = " [Sources: " + ", ".join(top_sources) + "]"
|
| 215 |
return response + citation_text
|
| 216 |
+
|
| 217 |
+
return response
|
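For orientation, here is a minimal usage sketch of the citation helpers defined in this module. The sample response text, filenames, and search result below are invented for illustration; the class and method names match the file above.

from src.llm.prompt_templates import PromptTemplates

# Hypothetical LLM output and source list, for illustration only
response = "Employees may work remotely up to three days per week [Source: remote_work.md]"
available_sources = ["remote_work.md", "pto_policy.md"]

print(PromptTemplates.extract_citations(response))
# ['remote_work.md']
print(PromptTemplates.validate_citations(response, available_sources))
# {'remote_work.md': True}

# If the model returned no citation, fall back to the top search results
uncited = "Employees may work remotely up to three days per week."
search_results = [
    {"metadata": {"filename": "remote_work.md"}, "content": "...", "similarity_score": 0.82}
]
print(PromptTemplates.add_fallback_citations(uncited, search_results))
# 'Employees may work remotely up to three days per week. [Sources: remote_work.md]'
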
src/rag/__init__.py
CHANGED
|
@@ -7,4 +7,4 @@ combining semantic search with LLM-based response generation.
|
|
| 7 |
Classes:
|
| 8 |
RAGPipeline: Main RAG orchestration service
|
| 9 |
ResponseFormatter: Formats LLM responses with citations and metadata
|
| 10 |
-
"""
|
|
|
|
| 7 |
Classes:
|
| 8 |
RAGPipeline: Main RAG orchestration service
|
| 9 |
ResponseFormatter: Formats LLM responses with citations and metadata
|
| 10 |
+
"""
|
src/rag/rag_pipeline.py
CHANGED
|
@@ -7,14 +7,15 @@ combining semantic search, context management, and LLM generation.
|
|
| 7 |
|
| 8 |
import logging
|
| 9 |
import time
|
| 10 |
-
from typing import Any, Dict, List, Optional
|
| 11 |
from dataclasses import dataclass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
# Import our modules
|
| 14 |
from src.search.search_service import SearchService
|
| 15 |
-
from src.llm.llm_service import LLMService, LLMResponse
|
| 16 |
-
from src.llm.context_manager import ContextManager, ContextConfig
|
| 17 |
-
from src.llm.prompt_templates import PromptTemplates, PromptTemplate
|
| 18 |
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
|
@@ -22,6 +23,7 @@ logger = logging.getLogger(__name__)
|
|
| 22 |
@dataclass
|
| 23 |
class RAGConfig:
|
| 24 |
"""Configuration for RAG pipeline."""
|
|
|
|
| 25 |
max_context_length: int = 3000
|
| 26 |
search_top_k: int = 10
|
| 27 |
search_threshold: float = 0.1
|
|
@@ -33,6 +35,7 @@ class RAGConfig:
|
|
| 33 |
@dataclass
|
| 34 |
class RAGResponse:
|
| 35 |
"""Response from RAG pipeline with metadata."""
|
|
|
|
| 36 |
answer: str
|
| 37 |
sources: List[Dict[str, Any]]
|
| 38 |
confidence: float
|
|
@@ -48,7 +51,7 @@ class RAGResponse:
|
|
| 48 |
class RAGPipeline:
|
| 49 |
"""
|
| 50 |
Complete RAG pipeline orchestrating retrieval and generation.
|
| 51 |
-
|
| 52 |
Combines:
|
| 53 |
- Semantic search for context retrieval
|
| 54 |
- Context optimization and management
|
|
@@ -60,84 +63,84 @@ class RAGPipeline:
|
|
| 60 |
self,
|
| 61 |
search_service: SearchService,
|
| 62 |
llm_service: LLMService,
|
| 63 |
-
config: Optional[RAGConfig] = None
|
| 64 |
):
|
| 65 |
"""
|
| 66 |
Initialize RAG pipeline with required services.
|
| 67 |
-
|
| 68 |
Args:
|
| 69 |
search_service: Configured SearchService instance
|
| 70 |
-
llm_service: Configured LLMService instance
|
| 71 |
config: RAG configuration, uses defaults if None
|
| 72 |
"""
|
| 73 |
self.search_service = search_service
|
| 74 |
self.llm_service = llm_service
|
| 75 |
self.config = config or RAGConfig()
|
| 76 |
-
|
| 77 |
# Initialize context manager with matching config
|
| 78 |
context_config = ContextConfig(
|
| 79 |
max_context_length=self.config.max_context_length,
|
| 80 |
max_results=self.config.search_top_k,
|
| 81 |
-
min_similarity=self.config.search_threshold
|
| 82 |
)
|
| 83 |
self.context_manager = ContextManager(context_config)
|
| 84 |
-
|
| 85 |
# Initialize prompt templates
|
| 86 |
self.prompt_templates = PromptTemplates()
|
| 87 |
-
|
| 88 |
logger.info("RAGPipeline initialized successfully")
|
| 89 |
|
| 90 |
def generate_answer(self, question: str) -> RAGResponse:
|
| 91 |
"""
|
| 92 |
Generate answer to question using RAG pipeline.
|
| 93 |
-
|
| 94 |
Args:
|
| 95 |
question: User's question about corporate policies
|
| 96 |
-
|
| 97 |
Returns:
|
| 98 |
RAGResponse with answer and metadata
|
| 99 |
"""
|
| 100 |
start_time = time.time()
|
| 101 |
-
|
| 102 |
try:
|
| 103 |
# Step 1: Retrieve relevant context
|
| 104 |
logger.debug(f"Starting RAG pipeline for question: {question[:100]}...")
|
| 105 |
-
|
| 106 |
search_results = self._retrieve_context(question)
|
| 107 |
-
|
| 108 |
if not search_results:
|
| 109 |
return self._create_no_context_response(question, start_time)
|
| 110 |
-
|
| 111 |
# Step 2: Prepare and optimize context
|
| 112 |
context, filtered_results = self.context_manager.prepare_context(
|
| 113 |
search_results, question
|
| 114 |
)
|
| 115 |
-
|
| 116 |
# Step 3: Check if we have sufficient context
|
| 117 |
quality_metrics = self.context_manager.validate_context_quality(
|
| 118 |
context, question, self.config.min_similarity_for_answer
|
| 119 |
)
|
| 120 |
-
|
| 121 |
if not quality_metrics["passes_validation"]:
|
| 122 |
return self._create_insufficient_context_response(
|
| 123 |
question, filtered_results, start_time
|
| 124 |
)
|
| 125 |
-
|
| 126 |
# Step 4: Generate response using LLM
|
| 127 |
llm_response = self._generate_llm_response(question, context)
|
| 128 |
-
|
| 129 |
if not llm_response.success:
|
| 130 |
return self._create_llm_error_response(
|
| 131 |
question, llm_response.error_message, start_time
|
| 132 |
)
|
| 133 |
-
|
| 134 |
# Step 5: Process and validate response
|
| 135 |
processed_response = self._process_response(
|
| 136 |
llm_response.content, filtered_results
|
| 137 |
)
|
| 138 |
-
|
| 139 |
processing_time = time.time() - start_time
|
| 140 |
-
|
| 141 |
return RAGResponse(
|
| 142 |
answer=processed_response,
|
| 143 |
sources=self._format_sources(filtered_results),
|
|
@@ -147,13 +150,16 @@ class RAGPipeline:
|
|
| 147 |
llm_model=llm_response.model,
|
| 148 |
context_length=len(context),
|
| 149 |
search_results_count=len(search_results),
|
| 150 |
-
success=True
|
| 151 |
)
|
| 152 |
-
|
| 153 |
except Exception as e:
|
| 154 |
logger.error(f"RAG pipeline error: {e}")
|
| 155 |
return RAGResponse(
|
| 156 |
-
answer="I apologize, but I encountered an error processing your question. Please try again or contact support.",
|
|
|
|
|
|
|
|
|
|
| 157 |
sources=[],
|
| 158 |
confidence=0.0,
|
| 159 |
processing_time=time.time() - start_time,
|
|
@@ -162,7 +168,7 @@ class RAGPipeline:
|
|
| 162 |
context_length=0,
|
| 163 |
search_results_count=0,
|
| 164 |
success=False,
|
| 165 |
-
error_message=str(e)
|
| 166 |
)
|
| 167 |
|
| 168 |
def _retrieve_context(self, question: str) -> List[Dict[str, Any]]:
|
|
@@ -171,12 +177,12 @@ class RAGPipeline:
|
|
| 171 |
results = self.search_service.search(
|
| 172 |
query=question,
|
| 173 |
top_k=self.config.search_top_k,
|
| 174 |
-
threshold=self.config.search_threshold
|
| 175 |
)
|
| 176 |
-
|
| 177 |
logger.debug(f"Retrieved {len(results)} search results")
|
| 178 |
return results
|
| 179 |
-
|
| 180 |
except Exception as e:
|
| 181 |
logger.error(f"Context retrieval error: {e}")
|
| 182 |
return []
|
|
@@ -184,95 +190,108 @@ class RAGPipeline:
|
|
| 184 |
def _generate_llm_response(self, question: str, context: str) -> LLMResponse:
|
| 185 |
"""Generate response using LLM with formatted prompt."""
|
| 186 |
template = self.prompt_templates.get_policy_qa_template()
|
| 187 |
-
|
| 188 |
# Format the prompt
|
| 189 |
formatted_prompt = template.user_template.format(
|
| 190 |
-
question=question,
|
| 191 |
-
context=context
|
| 192 |
)
|
| 193 |
-
|
| 194 |
# Add system prompt (if LLM service supports it in future)
|
| 195 |
full_prompt = f"{template.system_prompt}\n\n{formatted_prompt}"
|
| 196 |
-
|
| 197 |
return self.llm_service.generate_response(full_prompt)
|
| 198 |
|
| 199 |
def _process_response(
|
| 200 |
-
self,
|
| 201 |
-
raw_response: str,
|
| 202 |
-
search_results: List[Dict[str, Any]]
|
| 203 |
) -> str:
|
| 204 |
"""Process and validate LLM response."""
|
| 205 |
-
|
| 206 |
# Ensure citations are present
|
| 207 |
response_with_citations = self.prompt_templates.add_fallback_citations(
|
| 208 |
raw_response, search_results
|
| 209 |
)
|
| 210 |
-
|
| 211 |
# Validate citations if enabled
|
| 212 |
if self.config.enable_citation_validation:
|
| 213 |
available_sources = [
|
| 214 |
result.get("metadata", {}).get("filename", "")
|
| 215 |
for result in search_results
|
| 216 |
]
|
| 217 |
-
|
| 218 |
citation_validation = self.prompt_templates.validate_citations(
|
| 219 |
response_with_citations, available_sources
|
| 220 |
)
|
| 221 |
-
|
| 222 |
# Log any invalid citations
|
| 223 |
invalid_citations = [
|
| 224 |
-
citation for citation, valid in citation_validation.items()
|
| 225 |
-
if not valid
|
| 226 |
]
|
| 227 |
-
|
| 228 |
if invalid_citations:
|
| 229 |
logger.warning(f"Invalid citations detected: {invalid_citations}")
|
| 230 |
-
|
| 231 |
# Truncate if too long
|
| 232 |
if len(response_with_citations) > self.config.max_response_length:
|
| 233 |
-
truncated = response_with_citations[:self.config.max_response_length - 3] + "..."
|
| 234 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
return truncated
|
| 236 |
-
|
| 237 |
return response_with_citations
|
| 238 |
|
| 239 |
-
def _format_sources(self, search_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
|
|
|
|
|
| 240 |
"""Format search results for response metadata."""
|
| 241 |
sources = []
|
| 242 |
-
|
| 243 |
for result in search_results:
|
| 244 |
metadata = result.get("metadata", {})
|
| 245 |
-
sources.append(
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
return sources
|
| 253 |
|
| 254 |
def _calculate_confidence(
|
| 255 |
-
self,
|
| 256 |
-
quality_metrics: Dict[str, Any],
|
| 257 |
-
llm_response: LLMResponse
|
| 258 |
) -> float:
|
| 259 |
"""Calculate confidence score for the response."""
|
| 260 |
-
|
| 261 |
# Base confidence on context quality
|
| 262 |
context_confidence = quality_metrics.get("estimated_relevance", 0.0)
|
| 263 |
-
|
| 264 |
# Adjust based on LLM response time (faster might indicate more confidence)
|
| 265 |
time_factor = min(1.0, 10.0 / max(llm_response.response_time, 1.0))
|
| 266 |
-
|
| 267 |
# Combine factors
|
| 268 |
confidence = (context_confidence * 0.7) + (time_factor * 0.3)
|
| 269 |
-
|
| 270 |
return min(1.0, max(0.0, confidence))
|
| 271 |
|
| 272 |
-
def _create_no_context_response(self, question: str, start_time: float) -> RAGResponse:
|
|
|
|
|
|
|
| 273 |
"""Create response when no relevant context found."""
|
| 274 |
return RAGResponse(
|
| 275 |
-
answer="I couldn't find any relevant information in our corporate policies to answer your question. Please contact HR or check other company resources for assistance.",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
sources=[],
|
| 277 |
confidence=0.0,
|
| 278 |
processing_time=time.time() - start_time,
|
|
@@ -280,18 +299,19 @@ class RAGPipeline:
|
|
| 280 |
llm_model="none",
|
| 281 |
context_length=0,
|
| 282 |
search_results_count=0,
|
| 283 |
-
success=True # This is a valid "no answer" response
|
| 284 |
)
|
| 285 |
|
| 286 |
def _create_insufficient_context_response(
|
| 287 |
-
self,
|
| 288 |
-
question: str,
|
| 289 |
-
results: List[Dict[str, Any]],
|
| 290 |
-
start_time: float
|
| 291 |
) -> RAGResponse:
|
| 292 |
"""Create response when context quality is insufficient."""
|
| 293 |
return RAGResponse(
|
| 294 |
-
answer="I found some potentially relevant information, but it doesn't provide enough detail to fully answer your question. Please contact HR for more specific guidance or rephrase your question.",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
sources=self._format_sources(results),
|
| 296 |
confidence=0.2,
|
| 297 |
processing_time=time.time() - start_time,
|
|
@@ -299,18 +319,18 @@ class RAGPipeline:
|
|
| 299 |
llm_model="none",
|
| 300 |
context_length=0,
|
| 301 |
search_results_count=len(results),
|
| 302 |
-
success=True
|
| 303 |
)
|
| 304 |
|
| 305 |
def _create_llm_error_response(
|
| 306 |
-
self,
|
| 307 |
-
question: str,
|
| 308 |
-
error_message: str,
|
| 309 |
-
start_time: float
|
| 310 |
) -> RAGResponse:
|
| 311 |
"""Create response when LLM generation fails."""
|
| 312 |
return RAGResponse(
|
| 313 |
-
answer="I apologize, but I'm currently unable to generate a response. Please try again in a moment or contact support if the issue persists.",
|
|
|
|
|
|
|
|
|
|
| 314 |
sources=[],
|
| 315 |
confidence=0.0,
|
| 316 |
processing_time=time.time() - start_time,
|
|
@@ -319,54 +339,54 @@ class RAGPipeline:
|
|
| 319 |
context_length=0,
|
| 320 |
search_results_count=0,
|
| 321 |
success=False,
|
| 322 |
-
error_message=error_message
|
| 323 |
)
|
| 324 |
|
| 325 |
def health_check(self) -> Dict[str, Any]:
|
| 326 |
"""
|
| 327 |
Perform health check on all pipeline components.
|
| 328 |
-
|
| 329 |
Returns:
|
| 330 |
Dictionary with component health status
|
| 331 |
"""
|
| 332 |
-
health_status = {
|
| 333 |
-
|
| 334 |
-
"components": {}
|
| 335 |
-
}
|
| 336 |
-
|
| 337 |
try:
|
| 338 |
# Check search service
|
| 339 |
-
test_results = self.search_service.search("test query", top_k=1, threshold=0.0)
|
|
|
|
|
|
|
| 340 |
health_status["components"]["search_service"] = {
|
| 341 |
"status": "healthy",
|
| 342 |
-
"test_results_count": len(test_results)
|
| 343 |
}
|
| 344 |
except Exception as e:
|
| 345 |
health_status["components"]["search_service"] = {
|
| 346 |
"status": "unhealthy",
|
| 347 |
-
"error": str(e)
|
| 348 |
}
|
| 349 |
health_status["pipeline"] = "degraded"
|
| 350 |
-
|
| 351 |
try:
|
| 352 |
# Check LLM service
|
| 353 |
llm_health = self.llm_service.health_check()
|
| 354 |
health_status["components"]["llm_service"] = llm_health
|
| 355 |
-
|
| 356 |
# Pipeline is unhealthy if all LLM providers are down
|
| 357 |
healthy_providers = sum(
|
| 358 |
-
1
|
|
|
|
| 359 |
if provider_status.get("status") == "healthy"
|
| 360 |
)
|
| 361 |
-
|
| 362 |
if healthy_providers == 0:
|
| 363 |
health_status["pipeline"] = "unhealthy"
|
| 364 |
-
|
| 365 |
except Exception as e:
|
| 366 |
health_status["components"]["llm_service"] = {
|
| 367 |
-
"status": "unhealthy",
|
| 368 |
-
"error": str(e)
|
| 369 |
}
|
| 370 |
health_status["pipeline"] = "unhealthy"
|
| 371 |
-
|
| 372 |
-
return health_status
|
|
|
|
| 7 |
|
| 8 |
import logging
|
| 9 |
import time
|
|
|
|
| 10 |
from dataclasses import dataclass
|
| 11 |
+
from typing import Any, Dict, List, Optional
|
| 12 |
+
|
| 13 |
+
from src.llm.context_manager import ContextConfig, ContextManager
|
| 14 |
+
from src.llm.llm_service import LLMResponse, LLMService
|
| 15 |
+
from src.llm.prompt_templates import PromptTemplates
|
| 16 |
|
| 17 |
# Import our modules
|
| 18 |
from src.search.search_service import SearchService
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
logger = logging.getLogger(__name__)
|
| 21 |
|
|
|
|
| 23 |
@dataclass
|
| 24 |
class RAGConfig:
|
| 25 |
"""Configuration for RAG pipeline."""
|
| 26 |
+
|
| 27 |
max_context_length: int = 3000
|
| 28 |
search_top_k: int = 10
|
| 29 |
search_threshold: float = 0.1
|
|
|
|
| 35 |
@dataclass
|
| 36 |
class RAGResponse:
|
| 37 |
"""Response from RAG pipeline with metadata."""
|
| 38 |
+
|
| 39 |
answer: str
|
| 40 |
sources: List[Dict[str, Any]]
|
| 41 |
confidence: float
|
|
|
|
| 51 |
class RAGPipeline:
|
| 52 |
"""
|
| 53 |
Complete RAG pipeline orchestrating retrieval and generation.
|
| 54 |
+
|
| 55 |
Combines:
|
| 56 |
- Semantic search for context retrieval
|
| 57 |
- Context optimization and management
|
|
|
|
| 63 |
self,
|
| 64 |
search_service: SearchService,
|
| 65 |
llm_service: LLMService,
|
| 66 |
+
config: Optional[RAGConfig] = None,
|
| 67 |
):
|
| 68 |
"""
|
| 69 |
Initialize RAG pipeline with required services.
|
| 70 |
+
|
| 71 |
Args:
|
| 72 |
search_service: Configured SearchService instance
|
| 73 |
+
llm_service: Configured LLMService instance
|
| 74 |
config: RAG configuration, uses defaults if None
|
| 75 |
"""
|
| 76 |
self.search_service = search_service
|
| 77 |
self.llm_service = llm_service
|
| 78 |
self.config = config or RAGConfig()
|
| 79 |
+
|
| 80 |
# Initialize context manager with matching config
|
| 81 |
context_config = ContextConfig(
|
| 82 |
max_context_length=self.config.max_context_length,
|
| 83 |
max_results=self.config.search_top_k,
|
| 84 |
+
min_similarity=self.config.search_threshold,
|
| 85 |
)
|
| 86 |
self.context_manager = ContextManager(context_config)
|
| 87 |
+
|
| 88 |
# Initialize prompt templates
|
| 89 |
self.prompt_templates = PromptTemplates()
|
| 90 |
+
|
| 91 |
logger.info("RAGPipeline initialized successfully")
|
| 92 |
|
| 93 |
def generate_answer(self, question: str) -> RAGResponse:
|
| 94 |
"""
|
| 95 |
Generate answer to question using RAG pipeline.
|
| 96 |
+
|
| 97 |
Args:
|
| 98 |
question: User's question about corporate policies
|
| 99 |
+
|
| 100 |
Returns:
|
| 101 |
RAGResponse with answer and metadata
|
| 102 |
"""
|
| 103 |
start_time = time.time()
|
| 104 |
+
|
| 105 |
try:
|
| 106 |
# Step 1: Retrieve relevant context
|
| 107 |
logger.debug(f"Starting RAG pipeline for question: {question[:100]}...")
|
| 108 |
+
|
| 109 |
search_results = self._retrieve_context(question)
|
| 110 |
+
|
| 111 |
if not search_results:
|
| 112 |
return self._create_no_context_response(question, start_time)
|
| 113 |
+
|
| 114 |
# Step 2: Prepare and optimize context
|
| 115 |
context, filtered_results = self.context_manager.prepare_context(
|
| 116 |
search_results, question
|
| 117 |
)
|
| 118 |
+
|
| 119 |
# Step 3: Check if we have sufficient context
|
| 120 |
quality_metrics = self.context_manager.validate_context_quality(
|
| 121 |
context, question, self.config.min_similarity_for_answer
|
| 122 |
)
|
| 123 |
+
|
| 124 |
if not quality_metrics["passes_validation"]:
|
| 125 |
return self._create_insufficient_context_response(
|
| 126 |
question, filtered_results, start_time
|
| 127 |
)
|
| 128 |
+
|
| 129 |
# Step 4: Generate response using LLM
|
| 130 |
llm_response = self._generate_llm_response(question, context)
|
| 131 |
+
|
| 132 |
if not llm_response.success:
|
| 133 |
return self._create_llm_error_response(
|
| 134 |
question, llm_response.error_message, start_time
|
| 135 |
)
|
| 136 |
+
|
| 137 |
# Step 5: Process and validate response
|
| 138 |
processed_response = self._process_response(
|
| 139 |
llm_response.content, filtered_results
|
| 140 |
)
|
| 141 |
+
|
| 142 |
processing_time = time.time() - start_time
|
| 143 |
+
|
| 144 |
return RAGResponse(
|
| 145 |
answer=processed_response,
|
| 146 |
sources=self._format_sources(filtered_results),
|
|
|
|
| 150 |
llm_model=llm_response.model,
|
| 151 |
context_length=len(context),
|
| 152 |
search_results_count=len(search_results),
|
| 153 |
+
success=True,
|
| 154 |
)
|
| 155 |
+
|
| 156 |
except Exception as e:
|
| 157 |
logger.error(f"RAG pipeline error: {e}")
|
| 158 |
return RAGResponse(
|
| 159 |
+
answer=(
|
| 160 |
+
"I apologize, but I encountered an error processing your question. "
|
| 161 |
+
"Please try again or contact support."
|
| 162 |
+
),
|
| 163 |
sources=[],
|
| 164 |
confidence=0.0,
|
| 165 |
processing_time=time.time() - start_time,
|
|
|
|
| 168 |
context_length=0,
|
| 169 |
search_results_count=0,
|
| 170 |
success=False,
|
| 171 |
+
error_message=str(e),
|
| 172 |
)
|
| 173 |
|
| 174 |
def _retrieve_context(self, question: str) -> List[Dict[str, Any]]:
|
|
|
|
| 177 |
results = self.search_service.search(
|
| 178 |
query=question,
|
| 179 |
top_k=self.config.search_top_k,
|
| 180 |
+
threshold=self.config.search_threshold,
|
| 181 |
)
|
| 182 |
+
|
| 183 |
logger.debug(f"Retrieved {len(results)} search results")
|
| 184 |
return results
|
| 185 |
+
|
| 186 |
except Exception as e:
|
| 187 |
logger.error(f"Context retrieval error: {e}")
|
| 188 |
return []
|
|
|
|
| 190 |
def _generate_llm_response(self, question: str, context: str) -> LLMResponse:
|
| 191 |
"""Generate response using LLM with formatted prompt."""
|
| 192 |
template = self.prompt_templates.get_policy_qa_template()
|
| 193 |
+
|
| 194 |
# Format the prompt
|
| 195 |
formatted_prompt = template.user_template.format(
|
| 196 |
+
question=question, context=context
|
|
|
|
| 197 |
)
|
| 198 |
+
|
| 199 |
# Add system prompt (if LLM service supports it in future)
|
| 200 |
full_prompt = f"{template.system_prompt}\n\n{formatted_prompt}"
|
| 201 |
+
|
| 202 |
return self.llm_service.generate_response(full_prompt)
|
| 203 |
|
| 204 |
def _process_response(
|
| 205 |
+
self, raw_response: str, search_results: List[Dict[str, Any]]
|
|
|
|
|
|
|
| 206 |
) -> str:
|
| 207 |
"""Process and validate LLM response."""
|
| 208 |
+
|
| 209 |
# Ensure citations are present
|
| 210 |
response_with_citations = self.prompt_templates.add_fallback_citations(
|
| 211 |
raw_response, search_results
|
| 212 |
)
|
| 213 |
+
|
| 214 |
# Validate citations if enabled
|
| 215 |
if self.config.enable_citation_validation:
|
| 216 |
available_sources = [
|
| 217 |
result.get("metadata", {}).get("filename", "")
|
| 218 |
for result in search_results
|
| 219 |
]
|
| 220 |
+
|
| 221 |
citation_validation = self.prompt_templates.validate_citations(
|
| 222 |
response_with_citations, available_sources
|
| 223 |
)
|
| 224 |
+
|
| 225 |
# Log any invalid citations
|
| 226 |
invalid_citations = [
|
| 227 |
+
citation for citation, valid in citation_validation.items() if not valid
|
|
|
|
| 228 |
]
|
| 229 |
+
|
| 230 |
if invalid_citations:
|
| 231 |
logger.warning(f"Invalid citations detected: {invalid_citations}")
|
| 232 |
+
|
| 233 |
# Truncate if too long
|
| 234 |
if len(response_with_citations) > self.config.max_response_length:
|
| 235 |
+
truncated = (
|
| 236 |
+
response_with_citations[: self.config.max_response_length - 3] + "..."
|
| 237 |
+
)
|
| 238 |
+
logger.warning(
|
| 239 |
+
f"Response truncated from {len(response_with_citations)} "
|
| 240 |
+
f"to {len(truncated)} characters"
|
| 241 |
+
)
|
| 242 |
return truncated
|
| 243 |
+
|
| 244 |
return response_with_citations
|
| 245 |
|
| 246 |
+
def _format_sources(
|
| 247 |
+
self, search_results: List[Dict[str, Any]]
|
| 248 |
+
) -> List[Dict[str, Any]]:
|
| 249 |
"""Format search results for response metadata."""
|
| 250 |
sources = []
|
| 251 |
+
|
| 252 |
for result in search_results:
|
| 253 |
metadata = result.get("metadata", {})
|
| 254 |
+
sources.append(
|
| 255 |
+
{
|
| 256 |
+
"document": metadata.get("filename", "unknown"),
|
| 257 |
+
"chunk_id": result.get("chunk_id", ""),
|
| 258 |
+
"relevance_score": result.get("similarity_score", 0.0),
|
| 259 |
+
"excerpt": (
|
| 260 |
+
result.get("content", "")[:200] + "..."
|
| 261 |
+
if len(result.get("content", "")) > 200
|
| 262 |
+
else result.get("content", "")
|
| 263 |
+
),
|
| 264 |
+
}
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
return sources
|
| 268 |
|
| 269 |
def _calculate_confidence(
|
| 270 |
+
self, quality_metrics: Dict[str, Any], llm_response: LLMResponse
|
|
|
|
|
|
|
| 271 |
) -> float:
|
| 272 |
"""Calculate confidence score for the response."""
|
| 273 |
+
|
| 274 |
# Base confidence on context quality
|
| 275 |
context_confidence = quality_metrics.get("estimated_relevance", 0.0)
|
| 276 |
+
|
| 277 |
# Adjust based on LLM response time (faster might indicate more confidence)
|
| 278 |
time_factor = min(1.0, 10.0 / max(llm_response.response_time, 1.0))
|
| 279 |
+
|
| 280 |
# Combine factors
|
| 281 |
confidence = (context_confidence * 0.7) + (time_factor * 0.3)
|
| 282 |
+
|
| 283 |
return min(1.0, max(0.0, confidence))
|
| 284 |
|
| 285 |
+
def _create_no_context_response(
|
| 286 |
+
self, question: str, start_time: float
|
| 287 |
+
) -> RAGResponse:
|
| 288 |
"""Create response when no relevant context found."""
|
| 289 |
return RAGResponse(
|
| 290 |
+
answer=(
|
| 291 |
+
"I couldn't find any relevant information in our corporate policies "
|
| 292 |
+
"to answer your question. Please contact HR or check other company "
|
| 293 |
+
"resources for assistance."
|
| 294 |
+
),
|
| 295 |
sources=[],
|
| 296 |
confidence=0.0,
|
| 297 |
processing_time=time.time() - start_time,
|
|
|
|
| 299 |
llm_model="none",
|
| 300 |
context_length=0,
|
| 301 |
search_results_count=0,
|
| 302 |
+
success=True, # This is a valid "no answer" response
|
| 303 |
)
|
| 304 |
|
| 305 |
def _create_insufficient_context_response(
|
| 306 |
+
self, question: str, results: List[Dict[str, Any]], start_time: float
|
|
|
|
|
|
|
|
|
|
| 307 |
) -> RAGResponse:
|
| 308 |
"""Create response when context quality is insufficient."""
|
| 309 |
return RAGResponse(
|
| 310 |
+
answer=(
|
| 311 |
+
"I found some potentially relevant information, but it doesn't provide "
|
| 312 |
+
"enough detail to fully answer your question. Please contact HR for "
|
| 313 |
+
"more specific guidance or rephrase your question."
|
| 314 |
+
),
|
| 315 |
sources=self._format_sources(results),
|
| 316 |
confidence=0.2,
|
| 317 |
processing_time=time.time() - start_time,
|
|
|
|
| 319 |
llm_model="none",
|
| 320 |
context_length=0,
|
| 321 |
search_results_count=len(results),
|
| 322 |
+
success=True,
|
| 323 |
)
|
| 324 |
|
| 325 |
def _create_llm_error_response(
|
| 326 |
+
self, question: str, error_message: str, start_time: float
|
|
|
|
|
|
|
|
|
|
| 327 |
) -> RAGResponse:
|
| 328 |
"""Create response when LLM generation fails."""
|
| 329 |
return RAGResponse(
|
| 330 |
+
answer=(
|
| 331 |
+
"I apologize, but I'm currently unable to generate a response. "
|
| 332 |
+
"Please try again in a moment or contact support if the issue persists."
|
| 333 |
+
),
|
| 334 |
sources=[],
|
| 335 |
confidence=0.0,
|
| 336 |
processing_time=time.time() - start_time,
|
|
|
|
| 339 |
context_length=0,
|
| 340 |
search_results_count=0,
|
| 341 |
success=False,
|
| 342 |
+
error_message=error_message,
|
| 343 |
)
|
| 344 |
|
| 345 |
def health_check(self) -> Dict[str, Any]:
|
| 346 |
"""
|
| 347 |
Perform health check on all pipeline components.
|
| 348 |
+
|
| 349 |
Returns:
|
| 350 |
Dictionary with component health status
|
| 351 |
"""
|
| 352 |
+
health_status = {"pipeline": "healthy", "components": {}}
|
| 353 |
+
|
|
|
|
|
|
|
|
|
|
| 354 |
try:
|
| 355 |
# Check search service
|
| 356 |
+
test_results = self.search_service.search(
|
| 357 |
+
"test query", top_k=1, threshold=0.0
|
| 358 |
+
)
|
| 359 |
health_status["components"]["search_service"] = {
|
| 360 |
"status": "healthy",
|
| 361 |
+
"test_results_count": len(test_results),
|
| 362 |
}
|
| 363 |
except Exception as e:
|
| 364 |
health_status["components"]["search_service"] = {
|
| 365 |
"status": "unhealthy",
|
| 366 |
+
"error": str(e),
|
| 367 |
}
|
| 368 |
health_status["pipeline"] = "degraded"
|
| 369 |
+
|
| 370 |
try:
|
| 371 |
# Check LLM service
|
| 372 |
llm_health = self.llm_service.health_check()
|
| 373 |
health_status["components"]["llm_service"] = llm_health
|
| 374 |
+
|
| 375 |
# Pipeline is unhealthy if all LLM providers are down
|
| 376 |
healthy_providers = sum(
|
| 377 |
+
1
|
| 378 |
+
for provider_status in llm_health.values()
|
| 379 |
if provider_status.get("status") == "healthy"
|
| 380 |
)
|
| 381 |
+
|
| 382 |
if healthy_providers == 0:
|
| 383 |
health_status["pipeline"] = "unhealthy"
|
| 384 |
+
|
| 385 |
except Exception as e:
|
| 386 |
health_status["components"]["llm_service"] = {
|
| 387 |
+
"status": "unhealthy",
|
| 388 |
+
"error": str(e),
|
| 389 |
}
|
| 390 |
health_status["pipeline"] = "unhealthy"
|
| 391 |
+
|
| 392 |
+
return health_status
|
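A minimal sketch of driving the pipeline above directly, assuming `search_service` and `llm_service` are already-configured SearchService and LLMService instances; the question text is made up.

from src.rag.rag_pipeline import RAGConfig, RAGPipeline

# Override a couple of defaults; the remaining RAGConfig fields keep their defaults
config = RAGConfig(max_context_length=2000, search_top_k=5)
pipeline = RAGPipeline(search_service, llm_service, config=config)

result = pipeline.generate_answer("How many PTO days do new employees receive?")
if result.success:
    print(result.answer)
    print(f"confidence={result.confidence:.2f}, sources={len(result.sources)}")
else:
    print(f"RAG error: {result.error_message}")

# Component status, e.g. for a /chat/health endpoint
status = pipeline.health_check()
print(status["pipeline"], list(status["components"].keys()))
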
src/rag/response_formatter.py
CHANGED
|
@@ -6,9 +6,8 @@ formatting, metadata inclusion, and consistent response structure.
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
import logging
|
|
|
|
| 9 |
from typing import Any, Dict, List, Optional
|
| 10 |
-
from dataclasses import dataclass, asdict
|
| 11 |
-
import json
|
| 12 |
|
| 13 |
logger = logging.getLogger(__name__)
|
| 14 |
|
|
@@ -16,6 +15,7 @@ logger = logging.getLogger(__name__)
|
|
| 16 |
@dataclass
|
| 17 |
class FormattedResponse:
|
| 18 |
"""Standardized formatted response for API endpoints."""
|
|
|
|
| 19 |
status: str
|
| 20 |
answer: str
|
| 21 |
sources: List[Dict[str, Any]]
|
|
@@ -27,7 +27,7 @@ class FormattedResponse:
|
|
| 27 |
class ResponseFormatter:
|
| 28 |
"""
|
| 29 |
Formats RAG pipeline responses for various output formats.
|
| 30 |
-
|
| 31 |
Handles:
|
| 32 |
- API response formatting
|
| 33 |
- Citation formatting
|
|
@@ -40,23 +40,21 @@ class ResponseFormatter:
|
|
| 40 |
logger.info("ResponseFormatter initialized")
|
| 41 |
|
| 42 |
def format_api_response(
|
| 43 |
-
self,
|
| 44 |
-
rag_response: Any, # RAGResponse type
|
| 45 |
-
include_debug: bool = False
|
| 46 |
) -> Dict[str, Any]:
|
| 47 |
"""
|
| 48 |
Format RAG response for API consumption.
|
| 49 |
-
|
| 50 |
Args:
|
| 51 |
rag_response: RAGResponse from RAG pipeline
|
| 52 |
include_debug: Whether to include debug information
|
| 53 |
-
|
| 54 |
Returns:
|
| 55 |
Formatted dictionary for JSON API response
|
| 56 |
"""
|
| 57 |
if not rag_response.success:
|
| 58 |
return self._format_error_response(rag_response)
|
| 59 |
-
|
| 60 |
# Base response structure
|
| 61 |
formatted_response = {
|
| 62 |
"status": "success",
|
|
@@ -66,88 +64,96 @@ class ResponseFormatter:
|
|
| 66 |
"confidence": round(rag_response.confidence, 3),
|
| 67 |
"processing_time_ms": round(rag_response.processing_time * 1000, 1),
|
| 68 |
"source_count": len(rag_response.sources),
|
| 69 |
-
"context_length": rag_response.context_length
|
| 70 |
-
}
|
| 71 |
}
|
| 72 |
-
|
| 73 |
# Add debug information if requested
|
| 74 |
if include_debug:
|
| 75 |
formatted_response["debug"] = {
|
| 76 |
"llm_provider": rag_response.llm_provider,
|
| 77 |
"llm_model": rag_response.llm_model,
|
| 78 |
"search_results_count": rag_response.search_results_count,
|
| 79 |
-
"processing_time_seconds": round(rag_response.processing_time, 3)
|
| 80 |
}
|
| 81 |
-
|
| 82 |
return formatted_response
|
| 83 |
|
| 84 |
def format_chat_response(
|
| 85 |
self,
|
| 86 |
rag_response: Any, # RAGResponse type
|
| 87 |
conversation_id: Optional[str] = None,
|
| 88 |
-
include_sources: bool = True
|
| 89 |
) -> Dict[str, Any]:
|
| 90 |
"""
|
| 91 |
Format RAG response for chat interface.
|
| 92 |
-
|
| 93 |
Args:
|
| 94 |
rag_response: RAGResponse from RAG pipeline
|
| 95 |
conversation_id: Optional conversation ID
|
| 96 |
include_sources: Whether to include source information
|
| 97 |
-
|
| 98 |
Returns:
|
| 99 |
Formatted dictionary for chat interface
|
| 100 |
"""
|
| 101 |
if not rag_response.success:
|
| 102 |
return self._format_chat_error(rag_response, conversation_id)
|
| 103 |
-
|
| 104 |
response = {
|
| 105 |
"message": rag_response.answer,
|
| 106 |
"confidence": round(rag_response.confidence, 2),
|
| 107 |
-
"processing_time_ms": round(rag_response.processing_time * 1000, 1)
|
| 108 |
}
|
| 109 |
-
|
| 110 |
if conversation_id:
|
| 111 |
response["conversation_id"] = conversation_id
|
| 112 |
-
|
| 113 |
if include_sources and rag_response.sources:
|
| 114 |
response["sources"] = self._format_sources_for_chat(rag_response.sources)
|
| 115 |
-
|
| 116 |
return response
|
| 117 |
|
| 118 |
-
def _format_source_list(self, sources: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
|
|
|
|
|
| 119 |
"""Format source list for API response."""
|
| 120 |
formatted_sources = []
|
| 121 |
-
|
| 122 |
for source in sources:
|
| 123 |
formatted_source = {
|
| 124 |
"document": source.get("document", "unknown"),
|
| 125 |
"relevance_score": round(source.get("relevance_score", 0.0), 3),
|
| 126 |
-
"excerpt": source.get("excerpt", "")
|
| 127 |
}
|
| 128 |
-
|
| 129 |
# Add chunk ID if available
|
| 130 |
chunk_id = source.get("chunk_id", "")
|
| 131 |
if chunk_id:
|
| 132 |
formatted_source["chunk_id"] = chunk_id
|
| 133 |
-
|
| 134 |
formatted_sources.append(formatted_source)
|
| 135 |
-
|
| 136 |
return formatted_sources
|
| 137 |
|
| 138 |
-
def _format_sources_for_chat(self, sources: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
|
|
|
|
|
| 139 |
"""Format sources for chat interface (more concise)."""
|
| 140 |
formatted_sources = []
|
| 141 |
-
|
| 142 |
for i, source in enumerate(sources[:3], 1): # Limit to top 3 for chat
|
| 143 |
formatted_source = {
|
| 144 |
"id": i,
|
| 145 |
"document": source.get("document", "unknown"),
|
| 146 |
"relevance": f"{source.get('relevance_score', 0.0):.1%}",
|
| 147 |
-
"preview":
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
}
|
| 149 |
formatted_sources.append(formatted_source)
|
| 150 |
-
|
| 151 |
return formatted_sources
|
| 152 |
|
| 153 |
def _format_error_response(self, rag_response: Any) -> Dict[str, Any]:
|
|
@@ -157,51 +163,45 @@ class ResponseFormatter:
|
|
| 157 |
"error": {
|
| 158 |
"message": rag_response.answer,
|
| 159 |
"details": rag_response.error_message,
|
| 160 |
-
"processing_time_ms": round(rag_response.processing_time * 1000, 1)
|
| 161 |
},
|
| 162 |
"sources": [],
|
| 163 |
-
"metadata": {
|
| 164 |
-
"confidence": 0.0,
|
| 165 |
-
"source_count": 0,
|
| 166 |
-
"context_length": 0
|
| 167 |
-
}
|
| 168 |
}
|
| 169 |
|
| 170 |
def _format_chat_error(
|
| 171 |
-
self,
|
| 172 |
-
rag_response: Any,
|
| 173 |
-
conversation_id: Optional[str] = None
|
| 174 |
) -> Dict[str, Any]:
|
| 175 |
"""Format error response for chat interface."""
|
| 176 |
response = {
|
| 177 |
"message": rag_response.answer,
|
| 178 |
"error": True,
|
| 179 |
-
"processing_time_ms": round(rag_response.processing_time * 1000, 1)
|
| 180 |
}
|
| 181 |
-
|
| 182 |
if conversation_id:
|
| 183 |
response["conversation_id"] = conversation_id
|
| 184 |
-
|
| 185 |
return response
|
| 186 |
|
| 187 |
def validate_response_format(self, response: Dict[str, Any]) -> bool:
|
| 188 |
"""
|
| 189 |
Validate that response follows expected format.
|
| 190 |
-
|
| 191 |
Args:
|
| 192 |
response: Formatted response dictionary
|
| 193 |
-
|
| 194 |
Returns:
|
| 195 |
True if format is valid, False otherwise
|
| 196 |
"""
|
| 197 |
required_fields = ["status"]
|
| 198 |
-
|
| 199 |
# Check required fields
|
| 200 |
for field in required_fields:
|
| 201 |
if field not in response:
|
| 202 |
logger.error(f"Missing required field: {field}")
|
| 203 |
return False
|
| 204 |
-
|
| 205 |
# Check status-specific requirements
|
| 206 |
if response["status"] == "success":
|
| 207 |
success_fields = ["answer", "sources", "metadata"]
|
|
@@ -209,21 +209,21 @@ class ResponseFormatter:
|
|
| 209 |
if field not in response:
|
| 210 |
logger.error(f"Missing success field: {field}")
|
| 211 |
return False
|
| 212 |
-
|
| 213 |
elif response["status"] == "error":
|
| 214 |
if "error" not in response:
|
| 215 |
logger.error("Missing error field in error response")
|
| 216 |
return False
|
| 217 |
-
|
| 218 |
return True
|
| 219 |
|
| 220 |
def create_health_response(self, health_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 221 |
"""
|
| 222 |
Format health check response.
|
| 223 |
-
|
| 224 |
Args:
|
| 225 |
health_data: Health status from RAG pipeline
|
| 226 |
-
|
| 227 |
Returns:
|
| 228 |
Formatted health response
|
| 229 |
"""
|
|
@@ -232,51 +232,65 @@ class ResponseFormatter:
|
|
| 232 |
"health": {
|
| 233 |
"pipeline_status": health_data.get("pipeline", "unknown"),
|
| 234 |
"components": health_data.get("components", {}),
|
| 235 |
-
"timestamp": self._get_timestamp()
|
| 236 |
-
}
|
| 237 |
}
|
| 238 |
|
| 239 |
-
def create_no_answer_response(self, question: str, reason: str = "no_context") -> Dict[str, Any]:
|
|
|
|
|
|
|
| 240 |
"""
|
| 241 |
Create standardized response when no answer can be provided.
|
| 242 |
-
|
| 243 |
Args:
|
| 244 |
question: Original user question
|
| 245 |
reason: Reason for no answer (no_context, insufficient_context, etc.)
|
| 246 |
-
|
| 247 |
Returns:
|
| 248 |
Formatted no-answer response
|
| 249 |
"""
|
| 250 |
messages = {
|
| 251 |
-
"no_context":
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
}
|
| 256 |
-
|
| 257 |
message = messages.get(reason, messages["error"])
|
| 258 |
-
|
| 259 |
return {
|
| 260 |
"status": "no_answer",
|
| 261 |
"message": message,
|
| 262 |
"reason": reason,
|
| 263 |
-
"suggestion":
|
| 264 |
-
|
|
|
|
|
|
|
| 265 |
}
|
| 266 |
|
| 267 |
def _get_timestamp(self) -> str:
|
| 268 |
"""Get current timestamp in ISO format."""
|
| 269 |
from datetime import datetime
|
|
|
|
| 270 |
return datetime.utcnow().isoformat() + "Z"
|
| 271 |
|
| 272 |
def format_for_logging(self, rag_response: Any, question: str) -> Dict[str, Any]:
|
| 273 |
"""
|
| 274 |
Format response data for logging purposes.
|
| 275 |
-
|
| 276 |
Args:
|
| 277 |
rag_response: RAGResponse from pipeline
|
| 278 |
question: Original question
|
| 279 |
-
|
| 280 |
Returns:
|
| 281 |
Formatted data for logging
|
| 282 |
"""
|
|
@@ -291,5 +305,5 @@ class ResponseFormatter:
|
|
| 291 |
"source_count": len(rag_response.sources),
|
| 292 |
"context_length": rag_response.context_length,
|
| 293 |
"answer_length": len(rag_response.answer),
|
| 294 |
-
"error": rag_response.error_message
|
| 295 |
-
}
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
import logging
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
|
| 11 |
|
| 12 |
logger = logging.getLogger(__name__)
|
| 13 |
|
|
|
|
| 15 |
@dataclass
|
| 16 |
class FormattedResponse:
|
| 17 |
"""Standardized formatted response for API endpoints."""
|
| 18 |
+
|
| 19 |
status: str
|
| 20 |
answer: str
|
| 21 |
sources: List[Dict[str, Any]]
|
|
|
|
| 27 |
class ResponseFormatter:
|
| 28 |
"""
|
| 29 |
Formats RAG pipeline responses for various output formats.
|
| 30 |
+
|
| 31 |
Handles:
|
| 32 |
- API response formatting
|
| 33 |
- Citation formatting
|
|
|
|
| 40 |
logger.info("ResponseFormatter initialized")
|
| 41 |
|
| 42 |
def format_api_response(
|
| 43 |
+
self, rag_response: Any, include_debug: bool = False # RAGResponse type
|
|
|
|
|
|
|
| 44 |
) -> Dict[str, Any]:
|
| 45 |
"""
|
| 46 |
Format RAG response for API consumption.
|
| 47 |
+
|
| 48 |
Args:
|
| 49 |
rag_response: RAGResponse from RAG pipeline
|
| 50 |
include_debug: Whether to include debug information
|
| 51 |
+
|
| 52 |
Returns:
|
| 53 |
Formatted dictionary for JSON API response
|
| 54 |
"""
|
| 55 |
if not rag_response.success:
|
| 56 |
return self._format_error_response(rag_response)
|
| 57 |
+
|
| 58 |
# Base response structure
|
| 59 |
formatted_response = {
|
| 60 |
"status": "success",
|
|
|
|
| 64 |
"confidence": round(rag_response.confidence, 3),
|
| 65 |
"processing_time_ms": round(rag_response.processing_time * 1000, 1),
|
| 66 |
"source_count": len(rag_response.sources),
|
| 67 |
+
"context_length": rag_response.context_length,
|
| 68 |
+
},
|
| 69 |
}
|
| 70 |
+
|
| 71 |
# Add debug information if requested
|
| 72 |
if include_debug:
|
| 73 |
formatted_response["debug"] = {
|
| 74 |
"llm_provider": rag_response.llm_provider,
|
| 75 |
"llm_model": rag_response.llm_model,
|
| 76 |
"search_results_count": rag_response.search_results_count,
|
| 77 |
+
"processing_time_seconds": round(rag_response.processing_time, 3),
|
| 78 |
}
|
| 79 |
+
|
| 80 |
return formatted_response
|
| 81 |
|
| 82 |
def format_chat_response(
|
| 83 |
self,
|
| 84 |
rag_response: Any, # RAGResponse type
|
| 85 |
conversation_id: Optional[str] = None,
|
| 86 |
+
include_sources: bool = True,
|
| 87 |
) -> Dict[str, Any]:
|
| 88 |
"""
|
| 89 |
Format RAG response for chat interface.
|
| 90 |
+
|
| 91 |
Args:
|
| 92 |
rag_response: RAGResponse from RAG pipeline
|
| 93 |
conversation_id: Optional conversation ID
|
| 94 |
include_sources: Whether to include source information
|
| 95 |
+
|
| 96 |
Returns:
|
| 97 |
Formatted dictionary for chat interface
|
| 98 |
"""
|
| 99 |
if not rag_response.success:
|
| 100 |
return self._format_chat_error(rag_response, conversation_id)
|
| 101 |
+
|
| 102 |
response = {
|
| 103 |
"message": rag_response.answer,
|
| 104 |
"confidence": round(rag_response.confidence, 2),
|
| 105 |
+
"processing_time_ms": round(rag_response.processing_time * 1000, 1),
|
| 106 |
}
|
| 107 |
+
|
| 108 |
if conversation_id:
|
| 109 |
response["conversation_id"] = conversation_id
|
| 110 |
+
|
| 111 |
if include_sources and rag_response.sources:
|
| 112 |
response["sources"] = self._format_sources_for_chat(rag_response.sources)
|
| 113 |
+
|
| 114 |
return response
|
| 115 |
|
| 116 |
+
def _format_source_list(
|
| 117 |
+
self, sources: List[Dict[str, Any]]
|
| 118 |
+
) -> List[Dict[str, Any]]:
|
| 119 |
"""Format source list for API response."""
|
| 120 |
formatted_sources = []
|
| 121 |
+
|
| 122 |
for source in sources:
|
| 123 |
formatted_source = {
|
| 124 |
"document": source.get("document", "unknown"),
|
| 125 |
"relevance_score": round(source.get("relevance_score", 0.0), 3),
|
| 126 |
+
"excerpt": source.get("excerpt", ""),
|
| 127 |
}
|
| 128 |
+
|
| 129 |
# Add chunk ID if available
|
| 130 |
chunk_id = source.get("chunk_id", "")
|
| 131 |
if chunk_id:
|
| 132 |
formatted_source["chunk_id"] = chunk_id
|
| 133 |
+
|
| 134 |
formatted_sources.append(formatted_source)
|
| 135 |
+
|
| 136 |
return formatted_sources
|
| 137 |
|
| 138 |
+
def _format_sources_for_chat(
|
| 139 |
+
self, sources: List[Dict[str, Any]]
|
| 140 |
+
) -> List[Dict[str, Any]]:
|
| 141 |
"""Format sources for chat interface (more concise)."""
|
| 142 |
formatted_sources = []
|
| 143 |
+
|
| 144 |
for i, source in enumerate(sources[:3], 1): # Limit to top 3 for chat
|
| 145 |
formatted_source = {
|
| 146 |
"id": i,
|
| 147 |
"document": source.get("document", "unknown"),
|
| 148 |
"relevance": f"{source.get('relevance_score', 0.0):.1%}",
|
| 149 |
+
"preview": (
|
| 150 |
+
source.get("excerpt", "")[:100] + "..."
|
| 151 |
+
if len(source.get("excerpt", "")) > 100
|
| 152 |
+
else source.get("excerpt", "")
|
| 153 |
+
),
|
| 154 |
}
|
| 155 |
formatted_sources.append(formatted_source)
|
| 156 |
+
|
| 157 |
return formatted_sources
|
| 158 |
|
| 159 |
def _format_error_response(self, rag_response: Any) -> Dict[str, Any]:
|
|
|
|
| 163 |
"error": {
|
| 164 |
"message": rag_response.answer,
|
| 165 |
"details": rag_response.error_message,
|
| 166 |
+
"processing_time_ms": round(rag_response.processing_time * 1000, 1),
|
| 167 |
},
|
| 168 |
"sources": [],
|
| 169 |
+
"metadata": {"confidence": 0.0, "source_count": 0, "context_length": 0},
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
}
|
| 171 |
|
| 172 |
def _format_chat_error(
|
| 173 |
+
self, rag_response: Any, conversation_id: Optional[str] = None
|
|
|
|
|
|
|
| 174 |
) -> Dict[str, Any]:
|
| 175 |
"""Format error response for chat interface."""
|
| 176 |
response = {
|
| 177 |
"message": rag_response.answer,
|
| 178 |
"error": True,
|
| 179 |
+
"processing_time_ms": round(rag_response.processing_time * 1000, 1),
|
| 180 |
}
|
| 181 |
+
|
| 182 |
if conversation_id:
|
| 183 |
response["conversation_id"] = conversation_id
|
| 184 |
+
|
| 185 |
return response
|
| 186 |
|
| 187 |
def validate_response_format(self, response: Dict[str, Any]) -> bool:
|
| 188 |
"""
|
| 189 |
Validate that response follows expected format.
|
| 190 |
+
|
| 191 |
Args:
|
| 192 |
response: Formatted response dictionary
|
| 193 |
+
|
| 194 |
Returns:
|
| 195 |
True if format is valid, False otherwise
|
| 196 |
"""
|
| 197 |
required_fields = ["status"]
|
| 198 |
+
|
| 199 |
# Check required fields
|
| 200 |
for field in required_fields:
|
| 201 |
if field not in response:
|
| 202 |
logger.error(f"Missing required field: {field}")
|
| 203 |
return False
|
| 204 |
+
|
| 205 |
# Check status-specific requirements
|
| 206 |
if response["status"] == "success":
|
| 207 |
success_fields = ["answer", "sources", "metadata"]
|
|
|
|
| 209 |
if field not in response:
|
| 210 |
logger.error(f"Missing success field: {field}")
|
| 211 |
return False
|
| 212 |
+
|
| 213 |
elif response["status"] == "error":
|
| 214 |
if "error" not in response:
|
| 215 |
logger.error("Missing error field in error response")
|
| 216 |
return False
|
| 217 |
+
|
| 218 |
return True
|
| 219 |
|
| 220 |
def create_health_response(self, health_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 221 |
"""
|
| 222 |
Format health check response.
|
| 223 |
+
|
| 224 |
Args:
|
| 225 |
health_data: Health status from RAG pipeline
|
| 226 |
+
|
| 227 |
Returns:
|
| 228 |
Formatted health response
|
| 229 |
"""
|
|
|
|
| 232 |
"health": {
|
| 233 |
"pipeline_status": health_data.get("pipeline", "unknown"),
|
| 234 |
"components": health_data.get("components", {}),
|
| 235 |
+
"timestamp": self._get_timestamp(),
|
| 236 |
+
},
|
| 237 |
}
|
| 238 |
|
| 239 |
+
def create_no_answer_response(
|
| 240 |
+
self, question: str, reason: str = "no_context"
|
| 241 |
+
) -> Dict[str, Any]:
|
| 242 |
"""
|
| 243 |
Create standardized response when no answer can be provided.
|
| 244 |
+
|
| 245 |
Args:
|
| 246 |
question: Original user question
|
| 247 |
reason: Reason for no answer (no_context, insufficient_context, etc.)
|
| 248 |
+
|
| 249 |
Returns:
|
| 250 |
Formatted no-answer response
|
| 251 |
"""
|
| 252 |
messages = {
|
| 253 |
+
"no_context": (
|
| 254 |
+
"I couldn't find any relevant information in our corporate "
|
| 255 |
+
"policies to answer your question."
|
| 256 |
+
),
|
| 257 |
+
"insufficient_context": (
|
| 258 |
+
"I found some potentially relevant information, but not "
|
| 259 |
+
"enough to provide a complete answer."
|
| 260 |
+
),
|
| 261 |
+
"off_topic": (
|
| 262 |
+
"This question appears to be outside the scope of our "
|
| 263 |
+
"corporate policies."
|
| 264 |
+
),
|
| 265 |
+
"error": "I encountered an error while processing your question.",
|
| 266 |
}
|
| 267 |
+
|
| 268 |
message = messages.get(reason, messages["error"])
|
| 269 |
+
|
| 270 |
return {
|
| 271 |
"status": "no_answer",
|
| 272 |
"message": message,
|
| 273 |
"reason": reason,
|
| 274 |
+
"suggestion": (
|
| 275 |
+
"Please contact HR or rephrase your question for better results."
|
| 276 |
+
),
|
| 277 |
+
"sources": [],
|
| 278 |
}
|
| 279 |
|
| 280 |
def _get_timestamp(self) -> str:
|
| 281 |
"""Get current timestamp in ISO format."""
|
| 282 |
from datetime import datetime
|
| 283 |
+
|
| 284 |
return datetime.utcnow().isoformat() + "Z"
|
| 285 |
|
| 286 |
def format_for_logging(self, rag_response: Any, question: str) -> Dict[str, Any]:
|
| 287 |
"""
|
| 288 |
Format response data for logging purposes.
|
| 289 |
+
|
| 290 |
Args:
|
| 291 |
rag_response: RAGResponse from pipeline
|
| 292 |
question: Original question
|
| 293 |
+
|
| 294 |
Returns:
|
| 295 |
Formatted data for logging
|
| 296 |
"""
|
|
|
|
| 305 |
"source_count": len(rag_response.sources),
|
| 306 |
"context_length": rag_response.context_length,
|
| 307 |
"answer_length": len(rag_response.answer),
|
| 308 |
+
"error": rag_response.error_message,
|
| 309 |
+
}
|
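A small sketch of how the formatter above is typically applied to a pipeline result; `rag_response` is assumed to be a RAGResponse returned by RAGPipeline.generate_answer().

from src.rag.response_formatter import ResponseFormatter

formatter = ResponseFormatter()

# JSON payload for the API, with the optional debug block included
api_payload = formatter.format_api_response(rag_response, include_debug=True)
print(formatter.validate_response_format(api_payload))

# Leaner payload for the chat UI, limited to the top sources
chat_payload = formatter.format_chat_response(
    rag_response, conversation_id="conv_123", include_sources=True
)

# Standardized fallback when retrieval produced nothing usable
no_answer = formatter.create_no_answer_response(
    "What is the cafeteria menu today?", reason="no_context"
)
print(no_answer["status"], no_answer["reason"])  # no_answer no_context
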
tests/test_chat_endpoint.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
import json
|
| 2 |
import os
|
|
|
|
|
|
|
| 3 |
import pytest
|
| 4 |
-
from unittest.mock import patch, MagicMock
|
| 5 |
|
| 6 |
from app import app as flask_app
|
| 7 |
|
|
@@ -19,100 +20,122 @@ def client(app):
|
|
| 19 |
class TestChatEndpoint:
|
| 20 |
"""Test cases for the /chat endpoint"""
|
| 21 |
|
| 22 |
-
@patch.dict(os.environ, {
|
| 23 |
-
@patch(
|
| 24 |
-
@patch(
|
| 25 |
-
@patch(
|
| 26 |
-
@patch(
|
| 27 |
-
@patch(
|
| 28 |
-
@patch(
|
| 29 |
-
def test_chat_endpoint_valid_request(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
"""Test chat endpoint with valid request"""
|
| 31 |
# Mock the RAG pipeline response
|
| 32 |
mock_response = {
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
}
|
| 39 |
-
|
| 40 |
# Setup mock instances
|
| 41 |
mock_rag_instance = MagicMock()
|
| 42 |
mock_rag_instance.generate_answer.return_value = mock_response
|
| 43 |
mock_rag.return_value = mock_rag_instance
|
| 44 |
-
|
| 45 |
mock_formatter_instance = MagicMock()
|
| 46 |
mock_formatter_instance.format_api_response.return_value = {
|
| 47 |
"status": "success",
|
| 48 |
-
"answer": mock_response[
|
| 49 |
-
"confidence": mock_response[
|
| 50 |
-
"sources": mock_response[
|
| 51 |
-
"citations": mock_response[
|
| 52 |
}
|
| 53 |
mock_formatter.return_value = mock_formatter_instance
|
| 54 |
-
|
| 55 |
# Mock LLMService.from_environment to return a mock instance
|
| 56 |
mock_llm_instance = MagicMock()
|
| 57 |
mock_llm.from_environment.return_value = mock_llm_instance
|
| 58 |
|
| 59 |
request_data = {
|
| 60 |
"message": "What is the remote work policy?",
|
| 61 |
-
"include_sources": True
|
| 62 |
}
|
| 63 |
|
| 64 |
response = client.post(
|
| 65 |
-
"/chat",
|
| 66 |
-
data=json.dumps(request_data),
|
| 67 |
-
content_type="application/json"
|
| 68 |
)
|
| 69 |
|
| 70 |
assert response.status_code == 200
|
| 71 |
data = response.get_json()
|
| 72 |
-
|
| 73 |
assert data["status"] == "success"
|
| 74 |
assert "answer" in data
|
| 75 |
assert "confidence" in data
|
| 76 |
assert "sources" in data
|
| 77 |
assert "citations" in data
|
| 78 |
|
| 79 |
-
@patch.dict(os.environ, {
|
| 80 |
-
@patch(
|
| 81 |
-
@patch(
|
| 82 |
-
@patch(
|
| 83 |
-
@patch(
|
| 84 |
-
@patch(
|
| 85 |
-
@patch(
|
| 86 |
-
def test_chat_endpoint_minimal_request(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
"""Test chat endpoint with minimal request (only message)"""
|
| 88 |
mock_response = {
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
| 94 |
}
|
| 95 |
-
|
| 96 |
# Setup mock instances
|
| 97 |
mock_rag_instance = MagicMock()
|
| 98 |
mock_rag_instance.generate_answer.return_value = mock_response
|
| 99 |
mock_rag.return_value = mock_rag_instance
|
| 100 |
-
|
| 101 |
mock_formatter_instance = MagicMock()
|
| 102 |
mock_formatter_instance.format_api_response.return_value = {
|
| 103 |
"status": "success",
|
| 104 |
-
"answer": mock_response[
|
| 105 |
}
|
| 106 |
mock_formatter.return_value = mock_formatter_instance
|
| 107 |
-
|
| 108 |
mock_llm.from_environment.return_value = MagicMock()
|
| 109 |
|
| 110 |
request_data = {"message": "What are the employee benefits?"}
|
| 111 |
|
| 112 |
response = client.post(
|
| 113 |
-
"/chat",
|
| 114 |
-
data=json.dumps(request_data),
|
| 115 |
-
content_type="application/json"
|
| 116 |
)
|
| 117 |
|
| 118 |
assert response.status_code == 200
|
|
@@ -124,9 +147,7 @@ class TestChatEndpoint:
|
|
| 124 |
request_data = {"include_sources": True}
|
| 125 |
|
| 126 |
response = client.post(
|
| 127 |
-
"/chat",
|
| 128 |
-
data=json.dumps(request_data),
|
| 129 |
-
content_type="application/json"
|
| 130 |
)
|
| 131 |
|
| 132 |
assert response.status_code == 400
|
|
@@ -139,9 +160,7 @@ class TestChatEndpoint:
|
|
| 139 |
request_data = {"message": ""}
|
| 140 |
|
| 141 |
response = client.post(
|
| 142 |
-
"/chat",
|
| 143 |
-
data=json.dumps(request_data),
|
| 144 |
-
content_type="application/json"
|
| 145 |
)
|
| 146 |
|
| 147 |
assert response.status_code == 400
|
|
@@ -154,9 +173,7 @@ class TestChatEndpoint:
|
|
| 154 |
request_data = {"message": 123}
|
| 155 |
|
| 156 |
response = client.post(
|
| 157 |
-
"/chat",
|
| 158 |
-
data=json.dumps(request_data),
|
| 159 |
-
content_type="application/json"
|
| 160 |
)
|
| 161 |
|
| 162 |
assert response.status_code == 400
|
|
@@ -179,9 +196,7 @@ class TestChatEndpoint:
|
|
| 179 |
request_data = {"message": "What is the policy?"}
|
| 180 |
|
| 181 |
response = client.post(
|
| 182 |
-
"/chat",
|
| 183 |
-
data=json.dumps(request_data),
|
| 184 |
-
content_type="application/json"
|
| 185 |
)
|
| 186 |
|
| 187 |
assert response.status_code == 503
|
|
@@ -189,67 +204,110 @@ class TestChatEndpoint:
|
|
| 189 |
assert data["status"] == "error"
|
| 190 |
assert "LLM service configuration error" in data["message"]
|
| 191 |
|
| 192 |
-
@patch.dict(os.environ, {
|
| 193 |
-
@patch(
|
| 194 |
-
@patch(
|
| 195 |
-
@patch(
|
| 196 |
-
@patch(
|
| 197 |
-
@patch(
|
| 198 |
-
@patch(
|
| 199 |
-
def test_chat_endpoint_with_conversation_id(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
"""Test chat endpoint with conversation_id parameter"""
|
| 201 |
mock_response = {
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
}
|
| 208 |
-
|
| 209 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
request_data = {
|
| 212 |
"message": "What is the PTO policy?",
|
| 213 |
"conversation_id": "conv_123",
|
| 214 |
-
"include_sources": False
|
| 215 |
}
|
| 216 |
|
| 217 |
response = client.post(
|
| 218 |
-
"/chat",
|
| 219 |
-
data=json.dumps(request_data),
|
| 220 |
-
content_type="application/json"
|
| 221 |
)
|
| 222 |
|
| 223 |
assert response.status_code == 200
|
| 224 |
data = response.get_json()
|
| 225 |
assert data["status"] == "success"
|
| 226 |
|
| 227 |
-
@patch.dict(os.environ, {
|
| 228 |
-
@patch(
|
| 229 |
-
@patch(
|
| 230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
"""Test chat endpoint with debug information"""
|
| 232 |
mock_response = {
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
}
|
| 241 |
-
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
request_data = {
|
| 245 |
"message": "What are the security requirements?",
|
| 246 |
-
"include_debug": True
|
| 247 |
}
|
| 248 |
|
| 249 |
response = client.post(
|
| 250 |
-
"/chat",
|
| 251 |
-
data=json.dumps(request_data),
|
| 252 |
-
content_type="application/json"
|
| 253 |
)
|
| 254 |
|
| 255 |
assert response.status_code == 200
|
|
@@ -260,9 +318,9 @@ class TestChatEndpoint:
|
|
| 260 |
class TestChatHealthEndpoint:
|
| 261 |
"""Test cases for the /chat/health endpoint"""
|
| 262 |
|
| 263 |
-
@patch.dict(os.environ, {
|
| 264 |
-
@patch(
|
| 265 |
-
@patch(
|
| 266 |
def test_chat_health_healthy(self, mock_health_check, mock_llm_service, client):
|
| 267 |
"""Test chat health endpoint when all services are healthy"""
|
| 268 |
mock_health_data = {
|
|
@@ -270,8 +328,8 @@ class TestChatHealthEndpoint:
|
|
| 270 |
"components": {
|
| 271 |
"search_service": {"status": "healthy"},
|
| 272 |
"llm_service": {"status": "healthy"},
|
| 273 |
-
"vector_db": {"status": "healthy"}
|
| 274 |
-
}
|
| 275 |
}
|
| 276 |
mock_health_check.return_value = mock_health_data
|
| 277 |
mock_llm_service.return_value = MagicMock()
|
|
@@ -282,9 +340,9 @@ class TestChatHealthEndpoint:
|
|
| 282 |
data = response.get_json()
|
| 283 |
assert data["status"] == "success"
|
| 284 |
|
| 285 |
-
@patch.dict(os.environ, {
|
| 286 |
-
@patch(
|
| 287 |
-
@patch(
|
| 288 |
def test_chat_health_degraded(self, mock_health_check, mock_llm_service, client):
|
| 289 |
"""Test chat health endpoint when services are degraded"""
|
| 290 |
mock_health_data = {
|
|
@@ -292,8 +350,8 @@ class TestChatHealthEndpoint:
|
|
| 292 |
"components": {
|
| 293 |
"search_service": {"status": "healthy"},
|
| 294 |
"llm_service": {"status": "degraded", "warning": "High latency"},
|
| 295 |
-
"vector_db": {"status": "healthy"}
|
| 296 |
-
}
|
| 297 |
}
|
| 298 |
mock_health_check.return_value = mock_health_data
|
| 299 |
mock_llm_service.return_value = MagicMock()
|
|
@@ -314,18 +372,21 @@ class TestChatHealthEndpoint:
|
|
| 314 |
assert data["status"] == "error"
|
| 315 |
assert "LLM configuration error" in data["message"]
|
| 316 |
|
| 317 |
-
@patch.dict(os.environ, {
|
| 318 |
-
@patch(
|
| 319 |
-
@patch(
|
| 320 |
def test_chat_health_unhealthy(self, mock_health_check, mock_llm_service, client):
|
| 321 |
"""Test chat health endpoint when services are unhealthy"""
|
| 322 |
mock_health_data = {
|
| 323 |
"pipeline": "unhealthy",
|
| 324 |
"components": {
|
| 325 |
-
"search_service": {
|
|
|
|
|
|
|
|
|
|
| 326 |
"llm_service": {"status": "unhealthy", "error": "API unreachable"},
|
| 327 |
-
"vector_db": {"status": "unhealthy"}
|
| 328 |
-
}
|
| 329 |
}
|
| 330 |
mock_health_check.return_value = mock_health_data
|
| 331 |
mock_llm_service.return_value = MagicMock()
|
|
@@ -334,4 +395,4 @@ class TestChatHealthEndpoint:
|
|
| 334 |
|
| 335 |
assert response.status_code == 503
|
| 336 |
data = response.get_json()
|
| 337 |
-
assert data["status"] == "success" # Still returns success, but 503 status code
|
|
|
|
import json
import os
+from unittest.mock import MagicMock, patch
+
import pytest

from app import app as flask_app


class TestChatEndpoint:
    """Test cases for the /chat endpoint"""

+    @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
+    @patch("app.RAGPipeline")
+    @patch("app.ResponseFormatter")
+    @patch("app.LLMService")
+    @patch("app.SearchService")
+    @patch("app.VectorDatabase")
+    @patch("app.EmbeddingService")
+    def test_chat_endpoint_valid_request(
+        self,
+        mock_embedding,
+        mock_vector,
+        mock_search,
+        mock_llm,
+        mock_formatter,
+        mock_rag,
+        client,
+    ):
        """Test chat endpoint with valid request"""
        # Mock the RAG pipeline response
        mock_response = {
+            "answer": (
+                "Based on the remote work policy, employees can work "
+                "remotely up to 3 days per week."
+            ),
+            "confidence": 0.85,
+            "sources": [
+                {"chunk_id": "123", "content": "Remote work policy content..."}
+            ],
+            "citations": ["remote_work_policy.md"],
+            "processing_time_ms": 1500,
        }
+
        # Setup mock instances
        mock_rag_instance = MagicMock()
        mock_rag_instance.generate_answer.return_value = mock_response
        mock_rag.return_value = mock_rag_instance
+
        mock_formatter_instance = MagicMock()
        mock_formatter_instance.format_api_response.return_value = {
            "status": "success",
+            "answer": mock_response["answer"],
+            "confidence": mock_response["confidence"],
+            "sources": mock_response["sources"],
+            "citations": mock_response["citations"],
        }
        mock_formatter.return_value = mock_formatter_instance
+
        # Mock LLMService.from_environment to return a mock instance
        mock_llm_instance = MagicMock()
        mock_llm.from_environment.return_value = mock_llm_instance

        request_data = {
            "message": "What is the remote work policy?",
+            "include_sources": True,
        }

        response = client.post(
+            "/chat", data=json.dumps(request_data), content_type="application/json"
        )

        assert response.status_code == 200
        data = response.get_json()
+
        assert data["status"] == "success"
        assert "answer" in data
        assert "confidence" in data
        assert "sources" in data
        assert "citations" in data

+    @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
+    @patch("app.RAGPipeline")
+    @patch("app.ResponseFormatter")
+    @patch("app.LLMService")
+    @patch("app.SearchService")
+    @patch("app.VectorDatabase")
+    @patch("app.EmbeddingService")
+    def test_chat_endpoint_minimal_request(
+        self,
+        mock_embedding,
+        mock_vector,
+        mock_search,
+        mock_llm,
+        mock_formatter,
+        mock_rag,
+        client,
+    ):
        """Test chat endpoint with minimal request (only message)"""
        mock_response = {
+            "answer": (
+                "Employee benefits include health insurance, "
+                "retirement plans, and PTO."
+            ),
+            "confidence": 0.78,
+            "sources": [],
+            "citations": ["employee_benefits_guide.md"],
+            "processing_time_ms": 1200,
        }
+
        # Setup mock instances
        mock_rag_instance = MagicMock()
        mock_rag_instance.generate_answer.return_value = mock_response
        mock_rag.return_value = mock_rag_instance
+
        mock_formatter_instance = MagicMock()
        mock_formatter_instance.format_api_response.return_value = {
            "status": "success",
+            "answer": mock_response["answer"],
        }
        mock_formatter.return_value = mock_formatter_instance
+
        mock_llm.from_environment.return_value = MagicMock()

        request_data = {"message": "What are the employee benefits?"}

        response = client.post(
+            "/chat", data=json.dumps(request_data), content_type="application/json"
        )

        assert response.status_code == 200

        request_data = {"include_sources": True}

        response = client.post(
+            "/chat", data=json.dumps(request_data), content_type="application/json"
        )

        assert response.status_code == 400

        request_data = {"message": ""}

        response = client.post(
+            "/chat", data=json.dumps(request_data), content_type="application/json"
        )

        assert response.status_code == 400

        request_data = {"message": 123}

        response = client.post(
+            "/chat", data=json.dumps(request_data), content_type="application/json"
        )

        assert response.status_code == 400

        request_data = {"message": "What is the policy?"}

        response = client.post(
+            "/chat", data=json.dumps(request_data), content_type="application/json"
        )

        assert response.status_code == 503

        assert data["status"] == "error"
        assert "LLM service configuration error" in data["message"]

+    @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
+    @patch("app.RAGPipeline")
+    @patch("app.ResponseFormatter")
+    @patch("app.LLMService")
+    @patch("app.SearchService")
+    @patch("app.VectorDatabase")
+    @patch("app.EmbeddingService")
+    def test_chat_endpoint_with_conversation_id(
+        self,
+        mock_embedding,
+        mock_vector,
+        mock_search,
+        mock_llm,
+        mock_formatter,
+        mock_rag,
+        client,
+    ):
        """Test chat endpoint with conversation_id parameter"""
        mock_response = {
+            "answer": "The PTO policy allows 15 days of vacation annually.",
+            "confidence": 0.9,
+            "sources": [],
+            "citations": ["pto_policy.md"],
+            "processing_time_ms": 1100,
        }
+
+        # Setup mock instances
+        mock_rag_instance = MagicMock()
+        mock_rag_instance.generate_answer.return_value = mock_response
+        mock_rag.return_value = mock_rag_instance
+
+        mock_formatter_instance = MagicMock()
+        mock_formatter_instance.format_chat_response.return_value = {
+            "status": "success",
+            "answer": mock_response["answer"],
+        }
+        mock_formatter.return_value = mock_formatter_instance
+
+        mock_llm.from_environment.return_value = MagicMock()

        request_data = {
            "message": "What is the PTO policy?",
            "conversation_id": "conv_123",
+            "include_sources": False,
        }

        response = client.post(
+            "/chat", data=json.dumps(request_data), content_type="application/json"
        )

        assert response.status_code == 200
        data = response.get_json()
        assert data["status"] == "success"

+    @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
+    @patch("app.RAGPipeline")
+    @patch("app.ResponseFormatter")
+    @patch("app.LLMService")
+    @patch("app.SearchService")
+    @patch("app.VectorDatabase")
+    @patch("app.EmbeddingService")
+    def test_chat_endpoint_with_debug(
+        self,
+        mock_embedding,
+        mock_vector,
+        mock_search,
+        mock_llm,
+        mock_formatter,
+        mock_rag,
+        client,
+    ):
        """Test chat endpoint with debug information"""
        mock_response = {
+            "answer": "The security policy requires 2FA authentication.",
+            "confidence": 0.95,
+            "sources": [{"chunk_id": "456", "content": "Security requirements..."}],
+            "citations": ["information_security_policy.md"],
+            "processing_time_ms": 1800,
+            "search_results_count": 5,
+            "context_length": 2048,
        }
+
+        # Setup mock instances
+        mock_rag_instance = MagicMock()
+        mock_rag_instance.generate_answer.return_value = mock_response
+        mock_rag.return_value = mock_rag_instance
+
+        mock_formatter_instance = MagicMock()
+        mock_formatter_instance.format_api_response.return_value = {
+            "status": "success",
+            "answer": mock_response["answer"],
+            "debug": {"processing_time": 1800},
+        }
+        mock_formatter.return_value = mock_formatter_instance
+
+        mock_llm.from_environment.return_value = MagicMock()

        request_data = {
            "message": "What are the security requirements?",
+            "include_debug": True,
        }

        response = client.post(
+            "/chat", data=json.dumps(request_data), content_type="application/json"
        )

        assert response.status_code == 200

class TestChatHealthEndpoint:
    """Test cases for the /chat/health endpoint"""

+    @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
+    @patch("src.llm.llm_service.LLMService.from_environment")
+    @patch("src.rag.rag_pipeline.RAGPipeline.health_check")
    def test_chat_health_healthy(self, mock_health_check, mock_llm_service, client):
        """Test chat health endpoint when all services are healthy"""
        mock_health_data = {

            "components": {
                "search_service": {"status": "healthy"},
                "llm_service": {"status": "healthy"},
+                "vector_db": {"status": "healthy"},
+            },
        }
        mock_health_check.return_value = mock_health_data
        mock_llm_service.return_value = MagicMock()

        data = response.get_json()
        assert data["status"] == "success"

+    @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
+    @patch("src.llm.llm_service.LLMService.from_environment")
+    @patch("src.rag.rag_pipeline.RAGPipeline.health_check")
    def test_chat_health_degraded(self, mock_health_check, mock_llm_service, client):
        """Test chat health endpoint when services are degraded"""
        mock_health_data = {

            "components": {
                "search_service": {"status": "healthy"},
                "llm_service": {"status": "degraded", "warning": "High latency"},
+                "vector_db": {"status": "healthy"},
+            },
        }
        mock_health_check.return_value = mock_health_data
        mock_llm_service.return_value = MagicMock()

        assert data["status"] == "error"
        assert "LLM configuration error" in data["message"]

+    @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
+    @patch("src.llm.llm_service.LLMService.from_environment")
+    @patch("src.rag.rag_pipeline.RAGPipeline.health_check")
    def test_chat_health_unhealthy(self, mock_health_check, mock_llm_service, client):
        """Test chat health endpoint when services are unhealthy"""
        mock_health_data = {
            "pipeline": "unhealthy",
            "components": {
+                "search_service": {
+                    "status": "unhealthy",
+                    "error": "Database connection failed",
+                },
                "llm_service": {"status": "unhealthy", "error": "API unreachable"},
+                "vector_db": {"status": "unhealthy"},
+            },
        }
        mock_health_check.return_value = mock_health_data
        mock_llm_service.return_value = MagicMock()

        assert response.status_code == 503
        data = response.get_json()
+        assert data["status"] == "success"  # Still returns success, but 503 status code
tests/test_llm/__init__.py
CHANGED
@@ -1 +1 @@
-# LLM Service Tests
+# LLM Service Tests
tests/test_llm/test_llm_service.py
CHANGED
@@ -4,10 +4,12 @@ Test LLM Service
Tests for LLM integration and service functionality.
"""

import pytest
-from unittest.mock import Mock, patch, MagicMock
import requests
-


class TestLLMConfig:

@@ -19,9 +21,9 @@ class TestLLMConfig:
            provider="openrouter",
            api_key="test-key",
            model_name="test-model",
-            base_url="https://test.com"
        )
-
        assert config.provider == "openrouter"
        assert config.api_key == "test-key"
        assert config.model_name == "test-model"

@@ -41,9 +43,9 @@ class TestLLMResponse:
            model="test-model",
            usage={"tokens": 100},
            response_time=1.5,
-            success=True
        )
-
        assert response.content == "Test response"
        assert response.provider == "openrouter"
        assert response.model == "test-model"

@@ -62,84 +64,83 @@ class TestLLMService:
            provider="openrouter",
            api_key="test-key",
            model_name="test-model",
-            base_url="https://test.com"
        )
-
        service = LLMService([config])
-
        assert len(service.configs) == 1
        assert service.configs[0] == config
        assert service.current_config_index == 0

    def test_initialization_empty_configs_raises_error(self):
        """Test that empty configs raise ValueError."""
-        with pytest.raises(
            LLMService([])

-    @patch.dict(
    def test_from_environment_with_openrouter_key(self):
        """Test creating service from environment with OpenRouter key."""
        service = LLMService.from_environment()
-
        assert len(service.configs) >= 1
        openrouter_config = next(
            (config for config in service.configs if config.provider == "openrouter"),
-            None
        )
        assert openrouter_config is not None
        assert openrouter_config.api_key == "test-openrouter-key"

-    @patch.dict(
    def test_from_environment_with_groq_key(self):
        """Test creating service from environment with Groq key."""
        service = LLMService.from_environment()
-
        assert len(service.configs) >= 1
        groq_config = next(
-            (config for config in service.configs if config.provider == "groq"),
-            None
        )
        assert groq_config is not None
        assert groq_config.api_key == "test-groq-key"

-    @patch.dict(
    def test_from_environment_no_keys_raises_error(self):
        """Test that no environment keys raise ValueError."""
        with pytest.raises(ValueError, match="No LLM API keys found in environment"):
            LLMService.from_environment()

-    @patch(
    def test_successful_response_generation(self, mock_post):
        """Test successful response generation."""
        # Mock successful API response
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
-            "choices": [
-
-            ],
-            "usage": {"prompt_tokens": 50, "completion_tokens": 20}
        }
        mock_response.raise_for_status = Mock()
        mock_post.return_value = mock_response
-
        config = LLMConfig(
            provider="openrouter",
            api_key="test-key",
            model_name="test-model",
-            base_url="https://api.openrouter.ai/api/v1"
        )
        service = LLMService([config])
-
        result = service.generate_response("Test prompt")
-
        assert result.success is True
        assert result.content == "Test response content"
        assert result.provider == "openrouter"
        assert result.model == "test-model"
        assert result.usage == {"prompt_tokens": 50, "completion_tokens": 20}
        assert result.response_time > 0
-
        # Verify API call
        mock_post.assert_called_once()
        args, kwargs = mock_post.call_args

@@ -147,125 +148,139 @@ class TestLLMService:
        assert kwargs["json"]["model"] == "test-model"
        assert kwargs["json"]["messages"][0]["content"] == "Test prompt"

-    @patch(
    def test_api_error_handling(self, mock_post):
        """Test handling of API errors."""
        # Mock API error
        mock_post.side_effect = requests.exceptions.RequestException("API Error")
-
        config = LLMConfig(
            provider="openrouter",
            api_key="test-key",
            model_name="test-model",
-            base_url="https://api.openrouter.ai/api/v1"
        )
        service = LLMService([config])
-
        result = service.generate_response("Test prompt")
-
        assert result.success is False
        assert "API Error" in result.error_message
        assert result.content == ""
        assert result.provider == "openrouter"

-    @patch(
    def test_fallback_to_second_provider(self, mock_post):
        """Test fallback to second provider when first fails."""
        # Mock first provider failing, second succeeding
        first_call = Mock()
-        first_call.side_effect = requests.exceptions.RequestException(
-
        second_call = Mock()
        second_response = Mock()
        second_response.status_code = 200
        second_response.json.return_value = {
            "choices": [{"message": {"content": "Second provider response"}}],
-            "usage": {}
        }
        second_response.raise_for_status = Mock()
        second_call.return_value = second_response
-
        mock_post.side_effect = [first_call.side_effect, second_response]
-
        config1 = LLMConfig(
            provider="openrouter",
            api_key="key1",
            model_name="model1",
-            base_url="https://api1.com"
        )
        config2 = LLMConfig(
            provider="groq",
            api_key="key2",
            model_name="model2",
-            base_url="https://api2.com"
        )
-
        service = LLMService([config1, config2])
        result = service.generate_response("Test prompt")
-
        assert result.success is True
        assert result.content == "Second provider response"
        assert result.provider == "groq"
        assert mock_post.call_count == 2

-    @patch(
    def test_all_providers_fail(self, mock_post):
        """Test when all providers fail."""
-        mock_post.side_effect = requests.exceptions.RequestException(
-
-
-
-
        service = LLMService([config1, config2])
        result = service.generate_response("Test prompt")
-
        assert result.success is False
        assert "All providers failed" in result.error_message
        assert result.provider == "none"
        assert result.model == "none"

-    @patch(
    def test_retry_logic(self, mock_post):
        """Test retry logic for failed requests."""
        # First call fails, second succeeds
        first_response = Mock()
-        first_response.side_effect = requests.exceptions.RequestException(
-
        second_response = Mock()
        second_response.status_code = 200
        second_response.json.return_value = {
            "choices": [{"message": {"content": "Success after retry"}}],
-            "usage": {}
        }
        second_response.raise_for_status = Mock()
-
        mock_post.side_effect = [first_response.side_effect, second_response]
-
        config = LLMConfig(
            provider="openrouter",
            api_key="test-key",
            model_name="test-model",
-            base_url="https://api.openrouter.ai/api/v1"
        )
        service = LLMService([config])
-
        result = service.generate_response("Test prompt", max_retries=1)
-
        assert result.success is True
        assert result.content == "Success after retry"
        assert mock_post.call_count == 2

    def test_get_available_providers(self):
        """Test getting list of available providers."""
-        config1 = LLMConfig(
-
-
        service = LLMService([config1, config2])
        providers = service.get_available_providers()
-
        assert providers == ["openrouter", "groq"]

-    @patch(
    def test_health_check(self, mock_post):
        """Test health check functionality."""
        # Mock successful health check

@@ -273,51 +288,54 @@ class TestLLMService:
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "OK"}}],
-            "usage": {}
        }
        mock_response.raise_for_status = Mock()
        mock_post.return_value = mock_response
-
        config = LLMConfig(
            provider="openrouter",
            api_key="test-key",
            model_name="test-model",
-            base_url="https://api.openrouter.ai/api/v1"
        )
        service = LLMService([config])
-
        health_status = service.health_check()
-
        assert "openrouter" in health_status
        assert health_status["openrouter"]["status"] == "healthy"
        assert health_status["openrouter"]["model"] == "test-model"
        assert health_status["openrouter"]["response_time"] > 0

-    @patch(
    def test_openrouter_specific_headers(self, mock_post):
        """Test that OpenRouter-specific headers are added."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "Test"}}],
-            "usage": {}
        }
        mock_response.raise_for_status = Mock()
        mock_post.return_value = mock_response
-
        config = LLMConfig(
            provider="openrouter",
            api_key="test-key",
            model_name="test-model",
-            base_url="https://api.openrouter.ai/api/v1"
        )
        service = LLMService([config])
-
        service.generate_response("Test")
-
        # Check headers
        args, kwargs = mock_post.call_args
        headers = kwargs["headers"]
        assert "HTTP-Referer" in headers
        assert "X-Title" in headers
-        assert
Tests for LLM integration and service functionality.
"""

+from unittest.mock import Mock, patch
+
import pytest
import requests
+
+from src.llm.llm_service import LLMConfig, LLMResponse, LLMService


class TestLLMConfig:

            provider="openrouter",
            api_key="test-key",
            model_name="test-model",
+            base_url="https://test.com",
        )
+
        assert config.provider == "openrouter"
        assert config.api_key == "test-key"
        assert config.model_name == "test-model"

            model="test-model",
            usage={"tokens": 100},
            response_time=1.5,
+            success=True,
        )
+
        assert response.content == "Test response"
        assert response.provider == "openrouter"
        assert response.model == "test-model"

            provider="openrouter",
            api_key="test-key",
            model_name="test-model",
+            base_url="https://test.com",
        )
+
        service = LLMService([config])
+
        assert len(service.configs) == 1
        assert service.configs[0] == config
        assert service.current_config_index == 0

    def test_initialization_empty_configs_raises_error(self):
        """Test that empty configs raise ValueError."""
+        with pytest.raises(
+            ValueError, match="At least one LLM configuration must be provided"
+        ):
            LLMService([])

+    @patch.dict("os.environ", {"OPENROUTER_API_KEY": "test-openrouter-key"})
    def test_from_environment_with_openrouter_key(self):
        """Test creating service from environment with OpenRouter key."""
        service = LLMService.from_environment()
+
        assert len(service.configs) >= 1
        openrouter_config = next(
            (config for config in service.configs if config.provider == "openrouter"),
+            None,
        )
        assert openrouter_config is not None
        assert openrouter_config.api_key == "test-openrouter-key"

+    @patch.dict("os.environ", {"GROQ_API_KEY": "test-groq-key"})
    def test_from_environment_with_groq_key(self):
        """Test creating service from environment with Groq key."""
        service = LLMService.from_environment()
+
        assert len(service.configs) >= 1
        groq_config = next(
+            (config for config in service.configs if config.provider == "groq"), None
        )
        assert groq_config is not None
        assert groq_config.api_key == "test-groq-key"

+    @patch.dict("os.environ", {}, clear=True)
    def test_from_environment_no_keys_raises_error(self):
        """Test that no environment keys raise ValueError."""
        with pytest.raises(ValueError, match="No LLM API keys found in environment"):
            LLMService.from_environment()

+    @patch("requests.post")
    def test_successful_response_generation(self, mock_post):
        """Test successful response generation."""
        # Mock successful API response
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
+            "choices": [{"message": {"content": "Test response content"}}],
+            "usage": {"prompt_tokens": 50, "completion_tokens": 20},
        }
        mock_response.raise_for_status = Mock()
        mock_post.return_value = mock_response
+
        config = LLMConfig(
            provider="openrouter",
            api_key="test-key",
            model_name="test-model",
+            base_url="https://api.openrouter.ai/api/v1",
        )
        service = LLMService([config])
+
        result = service.generate_response("Test prompt")
+
        assert result.success is True
        assert result.content == "Test response content"
        assert result.provider == "openrouter"
        assert result.model == "test-model"
        assert result.usage == {"prompt_tokens": 50, "completion_tokens": 20}
        assert result.response_time > 0
+
        # Verify API call
        mock_post.assert_called_once()
        args, kwargs = mock_post.call_args

        assert kwargs["json"]["model"] == "test-model"
        assert kwargs["json"]["messages"][0]["content"] == "Test prompt"

+    @patch("requests.post")
    def test_api_error_handling(self, mock_post):
        """Test handling of API errors."""
        # Mock API error
        mock_post.side_effect = requests.exceptions.RequestException("API Error")
+
        config = LLMConfig(
            provider="openrouter",
            api_key="test-key",
            model_name="test-model",
+            base_url="https://api.openrouter.ai/api/v1",
        )
        service = LLMService([config])
+
        result = service.generate_response("Test prompt")
+
        assert result.success is False
        assert "API Error" in result.error_message
        assert result.content == ""
        assert result.provider == "openrouter"

+    @patch("requests.post")
    def test_fallback_to_second_provider(self, mock_post):
        """Test fallback to second provider when first fails."""
        # Mock first provider failing, second succeeding
        first_call = Mock()
+        first_call.side_effect = requests.exceptions.RequestException(
+            "First provider error"
+        )
+
        second_call = Mock()
        second_response = Mock()
        second_response.status_code = 200
        second_response.json.return_value = {
            "choices": [{"message": {"content": "Second provider response"}}],
+            "usage": {},
        }
        second_response.raise_for_status = Mock()
        second_call.return_value = second_response
+
        mock_post.side_effect = [first_call.side_effect, second_response]
+
        config1 = LLMConfig(
            provider="openrouter",
            api_key="key1",
            model_name="model1",
+            base_url="https://api1.com",
        )
        config2 = LLMConfig(
            provider="groq",
            api_key="key2",
            model_name="model2",
+            base_url="https://api2.com",
        )
+
        service = LLMService([config1, config2])
        result = service.generate_response("Test prompt")
+
        assert result.success is True
        assert result.content == "Second provider response"
        assert result.provider == "groq"
        assert mock_post.call_count == 2

+    @patch("requests.post")
    def test_all_providers_fail(self, mock_post):
        """Test when all providers fail."""
+        mock_post.side_effect = requests.exceptions.RequestException(
+            "All providers down"
+        )
+
+        config1 = LLMConfig(
+            provider="provider1", api_key="key1", model_name="model1", base_url="url1"
+        )
+        config2 = LLMConfig(
+            provider="provider2", api_key="key2", model_name="model2", base_url="url2"
+        )
+
        service = LLMService([config1, config2])
        result = service.generate_response("Test prompt")
+
        assert result.success is False
        assert "All providers failed" in result.error_message
        assert result.provider == "none"
        assert result.model == "none"

+    @patch("requests.post")
    def test_retry_logic(self, mock_post):
        """Test retry logic for failed requests."""
        # First call fails, second succeeds
        first_response = Mock()
+        first_response.side_effect = requests.exceptions.RequestException(
+            "Temporary error"
+        )
+
        second_response = Mock()
        second_response.status_code = 200
        second_response.json.return_value = {
            "choices": [{"message": {"content": "Success after retry"}}],
+            "usage": {},
        }
        second_response.raise_for_status = Mock()
+
        mock_post.side_effect = [first_response.side_effect, second_response]
+
        config = LLMConfig(
            provider="openrouter",
            api_key="test-key",
            model_name="test-model",
+            base_url="https://api.openrouter.ai/api/v1",
        )
        service = LLMService([config])
+
        result = service.generate_response("Test prompt", max_retries=1)
+
        assert result.success is True
        assert result.content == "Success after retry"
        assert mock_post.call_count == 2

    def test_get_available_providers(self):
        """Test getting list of available providers."""
+        config1 = LLMConfig(
+            provider="openrouter", api_key="key1", model_name="model1", base_url="url1"
+        )
+        config2 = LLMConfig(
+            provider="groq", api_key="key2", model_name="model2", base_url="url2"
+        )
+
        service = LLMService([config1, config2])
        providers = service.get_available_providers()
+
        assert providers == ["openrouter", "groq"]

+    @patch("requests.post")
    def test_health_check(self, mock_post):
        """Test health check functionality."""
        # Mock successful health check

        mock_response.status_code = 200
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "OK"}}],
+            "usage": {},
        }
        mock_response.raise_for_status = Mock()
        mock_post.return_value = mock_response
+
        config = LLMConfig(
            provider="openrouter",
            api_key="test-key",
            model_name="test-model",
+            base_url="https://api.openrouter.ai/api/v1",
        )
        service = LLMService([config])
+
        health_status = service.health_check()
+
        assert "openrouter" in health_status
        assert health_status["openrouter"]["status"] == "healthy"
        assert health_status["openrouter"]["model"] == "test-model"
        assert health_status["openrouter"]["response_time"] > 0

+    @patch("requests.post")
    def test_openrouter_specific_headers(self, mock_post):
        """Test that OpenRouter-specific headers are added."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "Test"}}],
+            "usage": {},
        }
        mock_response.raise_for_status = Mock()
        mock_post.return_value = mock_response
+
        config = LLMConfig(
            provider="openrouter",
            api_key="test-key",
            model_name="test-model",
+            base_url="https://api.openrouter.ai/api/v1",
        )
        service = LLMService([config])
+
        service.generate_response("Test")
+
        # Check headers
        args, kwargs = mock_post.call_args
        headers = kwargs["headers"]
        assert "HTTP-Referer" in headers
        assert "X-Title" in headers
+        assert (
+            headers["HTTP-Referer"]
+            == "https://github.com/sethmcknight/msse-ai-engineering"
+        )
tests/test_rag/__init__.py
CHANGED
@@ -1 +1 @@
-# RAG Pipeline Tests
+# RAG Pipeline Tests