""" RAG Pipeline - Core RAG Functionality This module orchestrates the complete RAG (Retrieval-Augmented Generation) pipeline, combining semantic search, context management, and LLM generation. """ import logging import time from dataclasses import dataclass from typing import Any, Dict, List, Optional from src.llm.context_manager import ContextConfig, ContextManager from src.llm.llm_service import LLMResponse, LLMService from src.llm.prompt_templates import PromptTemplates # Import our modules from src.search.search_service import SearchService logger = logging.getLogger(__name__) @dataclass class RAGConfig: """Configuration for RAG pipeline.""" max_context_length: int = 3000 search_top_k: int = 10 search_threshold: float = 0.0 # No threshold filtering at search level min_similarity_for_answer: float = 0.2 # Threshold for normalized distance similarity max_response_length: int = 1000 enable_citation_validation: bool = True @dataclass class RAGResponse: """Response from RAG pipeline with metadata.""" answer: str sources: List[Dict[str, Any]] confidence: float processing_time: float llm_provider: str llm_model: str context_length: int search_results_count: int success: bool error_message: Optional[str] = None class RAGPipeline: """ Complete RAG pipeline orchestrating retrieval and generation. Combines: - Semantic search for context retrieval - Context optimization and management - LLM-based response generation - Citation validation and formatting """ def __init__( self, search_service: SearchService, llm_service: LLMService, config: Optional[RAGConfig] = None, ): """ Initialize RAG pipeline with required services. Args: search_service: Configured SearchService instance llm_service: Configured LLMService instance config: RAG configuration, uses defaults if None """ self.search_service = search_service self.llm_service = llm_service self.config = config or RAGConfig() # Initialize context manager with matching config context_config = ContextConfig( max_context_length=self.config.max_context_length, max_results=self.config.search_top_k, min_similarity=self.config.search_threshold, ) self.context_manager = ContextManager(context_config) # Initialize prompt templates self.prompt_templates = PromptTemplates() logger.info("RAGPipeline initialized successfully") def generate_answer(self, question: str) -> RAGResponse: """ Generate answer to question using RAG pipeline. 
    def generate_answer(self, question: str) -> RAGResponse:
        """
        Generate answer to question using RAG pipeline.

        Args:
            question: User's question about corporate policies

        Returns:
            RAGResponse with answer and metadata
        """
        start_time = time.time()

        try:
            # Step 1: Retrieve relevant context
            logger.debug(f"Starting RAG pipeline for question: {question[:100]}...")
            search_results = self._retrieve_context(question)

            if not search_results:
                return self._create_no_context_response(question, start_time)

            # Step 2: Prepare and optimize context
            context, filtered_results = self.context_manager.prepare_context(search_results, question)

            # Step 3: Check if we have sufficient context
            quality_metrics = self.context_manager.validate_context_quality(
                context, question, self.config.min_similarity_for_answer
            )
            if not quality_metrics["passes_validation"]:
                return self._create_insufficient_context_response(question, filtered_results, start_time)

            # Step 4: Generate response using LLM
            llm_response = self._generate_llm_response(question, context)

            if not llm_response.success:
                return self._create_llm_error_response(question, llm_response.error_message, start_time)

            # Step 5: Process and validate response
            processed_response = self._process_response(llm_response.content, filtered_results)

            processing_time = time.time() - start_time

            return RAGResponse(
                answer=processed_response,
                sources=self._format_sources(filtered_results),
                confidence=self._calculate_confidence(quality_metrics, llm_response),
                processing_time=processing_time,
                llm_provider=llm_response.provider,
                llm_model=llm_response.model,
                context_length=len(context),
                search_results_count=len(search_results),
                success=True,
            )

        except Exception as e:
            logger.error(f"RAG pipeline error: {e}")
            return RAGResponse(
                answer=(
                    "I apologize, but I encountered an error processing your question. "
                    "Please try again or contact support."
                ),
                sources=[],
                confidence=0.0,
                processing_time=time.time() - start_time,
                llm_provider="none",
                llm_model="none",
                context_length=0,
                search_results_count=0,
                success=False,
                error_message=str(e),
            )

    def _retrieve_context(self, question: str) -> List[Dict[str, Any]]:
        """Retrieve relevant context using search service."""
        try:
            results = self.search_service.search(
                query=question,
                top_k=self.config.search_top_k,
                threshold=self.config.search_threshold,
            )
            logger.debug(f"Retrieved {len(results)} search results")
            return results
        except Exception as e:
            logger.error(f"Context retrieval error: {e}")
            return []

    def _generate_llm_response(self, question: str, context: str) -> LLMResponse:
        """Generate response using LLM with formatted prompt."""
        template = self.prompt_templates.get_policy_qa_template()

        # Format the prompt
        formatted_prompt = template.user_template.format(question=question, context=context)

        # Add system prompt (if LLM service supports it in future)
        full_prompt = f"{template.system_prompt}\n\n{formatted_prompt}"

        return self.llm_service.generate_response(full_prompt)

    def _process_response(self, raw_response: str, search_results: List[Dict[str, Any]]) -> str:
        """Process and validate LLM response."""
        # Ensure citations are present
        response_with_citations = self.prompt_templates.add_fallback_citations(raw_response, search_results)

        # Validate citations if enabled
        if self.config.enable_citation_validation:
            available_sources = [result.get("metadata", {}).get("filename", "") for result in search_results]
            citation_validation = self.prompt_templates.validate_citations(response_with_citations, available_sources)

            # Log any invalid citations
            invalid_citations = [citation for citation, valid in citation_validation.items() if not valid]
            if invalid_citations:
                logger.warning(f"Invalid citations detected: {invalid_citations}")
        # Truncate if too long
        if len(response_with_citations) > self.config.max_response_length:
            truncated = response_with_citations[: self.config.max_response_length - 3] + "..."
            logger.warning(
                f"Response truncated from {len(response_with_citations)} "
                f"to {len(truncated)} characters"
            )
            return truncated

        return response_with_citations

    def _format_sources(self, search_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Format search results for response metadata."""
        sources = []
        for result in search_results:
            metadata = result.get("metadata", {})
            sources.append(
                {
                    "document": metadata.get("filename", "unknown"),
                    "chunk_id": result.get("chunk_id", ""),
                    "relevance_score": result.get("similarity_score", 0.0),
                    "excerpt": (
                        result.get("content", "")[:200] + "..."
                        if len(result.get("content", "")) > 200
                        else result.get("content", "")
                    ),
                }
            )
        return sources

    def _calculate_confidence(self, quality_metrics: Dict[str, Any], llm_response: LLMResponse) -> float:
        """Calculate confidence score for the response."""
        # Base confidence on context quality
        context_confidence = quality_metrics.get("estimated_relevance", 0.0)

        # Adjust based on LLM response time (faster might indicate more confidence)
        time_factor = min(1.0, 10.0 / max(llm_response.response_time, 1.0))

        # Combine factors
        confidence = (context_confidence * 0.7) + (time_factor * 0.3)

        return min(1.0, max(0.0, confidence))

    def _create_no_context_response(self, question: str, start_time: float) -> RAGResponse:
        """Create response when no relevant context is found."""
        return RAGResponse(
            answer=(
                "I couldn't find any relevant information in our corporate policies "
                "to answer your question. Please contact HR or check other company "
                "resources for assistance."
            ),
            sources=[],
            confidence=0.0,
            processing_time=time.time() - start_time,
            llm_provider="none",
            llm_model="none",
            context_length=0,
            search_results_count=0,
            success=True,  # This is a valid "no answer" response
        )

    def _create_insufficient_context_response(
        self, question: str, results: List[Dict[str, Any]], start_time: float
    ) -> RAGResponse:
        """Create response when context quality is insufficient."""
        return RAGResponse(
            answer=(
                "I found some potentially relevant information, but it doesn't provide "
                "enough detail to fully answer your question. Please contact HR for "
                "more specific guidance or rephrase your question."
            ),
            sources=self._format_sources(results),
            confidence=0.2,
            processing_time=time.time() - start_time,
            llm_provider="none",
            llm_model="none",
            context_length=0,
            search_results_count=len(results),
            success=True,
        )

    def _create_llm_error_response(self, question: str, error_message: str, start_time: float) -> RAGResponse:
        """Create response when LLM generation fails."""
        return RAGResponse(
            answer=(
                "I apologize, but I'm currently unable to generate a response. "
                "Please try again in a moment or contact support if the issue persists."
            ),
            sources=[],
            confidence=0.0,
            processing_time=time.time() - start_time,
            llm_provider="error",
            llm_model="error",
            context_length=0,
            search_results_count=0,
            success=False,
            error_message=error_message,
        )

    def health_check(self) -> Dict[str, Any]:
        """
        Perform health check on all pipeline components.
        Returns:
            Dictionary with component health status
        """
        health_status = {"pipeline": "healthy", "components": {}}

        try:
            # Check search service
            test_results = self.search_service.search("test query", top_k=1, threshold=0.0)
            health_status["components"]["search_service"] = {
                "status": "healthy",
                "test_results_count": len(test_results),
            }
        except Exception as e:
            health_status["components"]["search_service"] = {
                "status": "unhealthy",
                "error": str(e),
            }
            health_status["pipeline"] = "degraded"

        try:
            # Check LLM service
            llm_health = self.llm_service.health_check()
            health_status["components"]["llm_service"] = llm_health

            # Pipeline is unhealthy if all LLM providers are down
            healthy_providers = sum(
                1
                for provider_status in llm_health.values()
                if provider_status.get("status") == "healthy"
            )
            if healthy_providers == 0:
                health_status["pipeline"] = "unhealthy"
        except Exception as e:
            health_status["components"]["llm_service"] = {
                "status": "unhealthy",
                "error": str(e),
            }
            health_status["pipeline"] = "unhealthy"

        return health_status
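

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the pipeline itself):
# it exercises only the RAGPipeline API defined above. How SearchService and
# LLMService are constructed depends on their own modules, so already-built
# instances are taken as parameters here, and the sample question is purely
# hypothetical.
# ---------------------------------------------------------------------------
def _example_usage(search_service: SearchService, llm_service: LLMService) -> None:
    """Run one hypothetical question through the pipeline and print the result."""
    # Override a couple of defaults to show how RAGConfig is used
    config = RAGConfig(search_top_k=5, max_context_length=2000)
    pipeline = RAGPipeline(search_service, llm_service, config)

    # Check component health before answering questions
    health = pipeline.health_check()
    print(f"Pipeline status: {health['pipeline']}")

    # Ask a sample policy question and inspect the structured response
    response = pipeline.generate_answer("How many vacation days do new employees receive?")
    print(f"Answer: {response.answer}")
    print(f"Confidence: {response.confidence:.2f} (took {response.processing_time:.2f}s)")
    for source in response.sources:
        print(f"  - {source['document']} (relevance: {source['relevance_score']:.2f})")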