"""
Advanced RAG techniques for improved retrieval and generation.

Includes: Query Expansion, Reranking, Contextual Compression, Hybrid Search.
"""
from typing import List, Dict, Optional, Tuple
import numpy as np
from dataclasses import dataclass
import re


@dataclass
class RetrievedDocument:
    """Document retrieved from the vector database."""
    id: str            # Unique document identifier from the store
    text: str          # Raw (or compressed) document text
    confidence: float  # Retrieval/rerank score; higher is more relevant
    metadata: Dict     # Arbitrary payload stored alongside the vector


class AdvancedRAG:
    """Advanced RAG system with modern retrieval techniques.

    Wraps an embedding service (must expose ``encode_text(str) -> ndarray``)
    and a Qdrant-style vector store (must expose
    ``search(query_embedding, limit, score_threshold) -> list[dict]``).
    """

    def __init__(self, embedding_service, qdrant_service):
        self.embedding_service = embedding_service
        self.qdrant_service = qdrant_service

    def expand_query(self, query: str) -> List[str]:
        """
        Expand a query into up to 3 variations for multi-query retrieval.

        Simple rule-based expansion for Vietnamese: builds variants by
        dropping common question words, and extracts a key phrase from
        longer queries.

        Args:
            query: Original user query.

        Returns:
            Up to 3 query strings; the original query is always first.
        """
        queries = [query]

        # Common Vietnamese question/function words to strip for variants.
        question_words = ['ai', 'gì', 'nào', 'đâu', 'khi nào', 'như thế nào',
                          'tại sao', 'có', 'là', 'được', 'không']

        query_lower = query.lower()
        for qw in question_words:
            # BUGFIX: match whole words only. The previous plain substring
            # replace mangled words that merely contain a question word
            # (e.g. stripping 'là' from 'làm gì' produced 'm gì').
            pattern = r'(?<!\w)' + re.escape(qw) + r'(?!\w)'
            variant = re.sub(r'\s+', ' ', re.sub(pattern, '', query_lower)).strip()
            if variant and variant != query_lower and variant not in queries:
                queries.append(variant)

        # For longer queries, also try a key phrase: skip a leading
        # question word if present, otherwise take the first 3 words.
        words = query.split()
        if len(words) > 3:
            if words[0].lower() in question_words:
                key_phrase = ' '.join(words[1:])
            else:
                key_phrase = ' '.join(words[:3])
            if key_phrase not in queries:
                queries.append(key_phrase)

        return queries[:3]  # Cap at 3 variations to bound search cost

    def multi_query_retrieval(
        self,
        query: str,
        top_k: int = 5,
        score_threshold: float = 0.5
    ) -> List[RetrievedDocument]:
        """
        Retrieve documents using multiple query variations.

        Runs a vector search for each expanded query and merges the
        results, deduplicating by document id (keeping the best score).

        Args:
            query: Original user query.
            top_k: Maximum number of documents to return.
            score_threshold: Minimum similarity score for a hit.

        Returns:
            Up to ``top_k`` documents sorted by confidence, descending.
        """
        expanded_queries = self.expand_query(query)
        all_results: Dict[str, RetrievedDocument] = {}  # doc_id -> best hit

        for q in expanded_queries:
            # Embed each query variant and search the vector store.
            query_embedding = self.embedding_service.encode_text(q)
            results = self.qdrant_service.search(
                query_embedding=query_embedding,
                limit=top_k,
                score_threshold=score_threshold
            )
            # Keep the highest-confidence hit for each document id.
            for result in results:
                doc_id = result["id"]
                if doc_id not in all_results or result["confidence"] > all_results[doc_id].confidence:
                    all_results[doc_id] = RetrievedDocument(
                        id=doc_id,
                        text=result["metadata"].get("text", ""),
                        confidence=result["confidence"],
                        metadata=result["metadata"]
                    )

        sorted_results = sorted(all_results.values(),
                                key=lambda d: d.confidence, reverse=True)
        return sorted_results[:top_k]

    def rerank_documents(
        self,
        query: str,
        documents: List[RetrievedDocument],
        use_cross_encoder: bool = False
    ) -> List[RetrievedDocument]:
        """
        Rerank documents by semantic similarity to the original query.

        Recomputes cosine similarity between the query and each document
        and blends it with the original retrieval confidence
        (0.6 * similarity + 0.4 * confidence).

        Args:
            query: Original user query.
            documents: Candidates to rerank.
            use_cross_encoder: Reserved for a future cross-encoder
                reranker; currently ignored.

        Returns:
            New RetrievedDocument list sorted by the blended score.
        """
        if not documents:
            return documents

        query_vec = np.asarray(self.embedding_service.encode_text(query)).flatten()
        query_norm = np.linalg.norm(query_vec)

        reranked = []
        for doc in documents:
            doc_vec = np.asarray(self.embedding_service.encode_text(doc.text)).flatten()
            doc_norm = np.linalg.norm(doc_vec)

            # BUGFIX: a raw dot product is only a cosine similarity when the
            # embeddings are unit-normalized; normalize explicitly (and guard
            # against zero vectors) so the 0.6/0.4 blend stays meaningful.
            if query_norm > 0 and doc_norm > 0:
                similarity = float(np.dot(query_vec, doc_vec) / (query_norm * doc_norm))
            else:
                similarity = 0.0

            new_score = 0.6 * similarity + 0.4 * doc.confidence
            reranked.append(RetrievedDocument(
                id=doc.id,
                text=doc.text,
                confidence=float(new_score),
                metadata=doc.metadata
            ))

        reranked.sort(key=lambda d: d.confidence, reverse=True)
        return reranked

    def compress_context(
        self,
        query: str,
        documents: List[RetrievedDocument],
        max_tokens: int = 500
    ) -> List[RetrievedDocument]:
        """
        Compress each document to its most query-relevant sentences.

        Sentences are scored by word overlap with the query; the best
        ones are kept until a ~``max_tokens`` word budget is reached.
        If no sentence overlaps at all, a prefix of the original text
        is kept instead.

        Args:
            query: Original user query.
            documents: Documents to compress.
            max_tokens: Approximate word budget per document.

        Returns:
            New RetrievedDocument list with compressed text.
        """
        compressed_docs = []
        query_words = set(query.lower().split())  # loop-invariant, hoisted

        for doc in documents:
            sentences = self._split_sentences(doc.text)

            # Score each sentence by word overlap with the query.
            scored_sentences = []
            for sent in sentences:
                overlap = len(query_words & set(sent.lower().split()))
                if overlap > 0:
                    scored_sentences.append((sent, overlap))
            scored_sentences.sort(key=lambda x: x[1], reverse=True)

            # Greedily take the top sentences up to the word budget.
            compressed_text = ""
            word_count = 0
            for sent, _score in scored_sentences:
                sent_len = len(sent.split())
                if word_count + sent_len <= max_tokens:
                    compressed_text += sent + " "
                    word_count += sent_len
                else:
                    break

            # Fallback: nothing overlapped -- keep a prefix of the original
            # (~5 chars/word is a rough length estimate).
            if not compressed_text.strip():
                compressed_text = doc.text[:max_tokens * 5]

            compressed_docs.append(RetrievedDocument(
                id=doc.id,
                text=compressed_text.strip(),
                confidence=doc.confidence,
                metadata=doc.metadata
            ))

        return compressed_docs

    def _split_sentences(self, text: str) -> List[str]:
        """Split text into sentences on ./!/? runs, dropping empties."""
        sentences = re.split(r'[.!?]+', text)
        return [s.strip() for s in sentences if s.strip()]

    def hybrid_rag_pipeline(
        self,
        query: str,
        top_k: int = 5,
        score_threshold: float = 0.5,
        use_reranking: bool = True,
        use_compression: bool = True,
        max_context_tokens: int = 500
    ) -> Tuple[List[RetrievedDocument], Dict]:
        """
        Complete advanced RAG pipeline.

        1. Multi-query retrieval (over-fetches 2*top_k candidates)
        2. Reranking (optional)
        3. Contextual compression (optional)

        Args:
            query: Original user query.
            top_k: Number of documents to return.
            score_threshold: Minimum retrieval score.
            use_reranking: Re-score candidates against the query.
            use_compression: Compress documents to relevant sentences.
            max_context_tokens: Word budget per compressed document.

        Returns:
            Tuple of (documents, stats) where stats records per-stage
            counters and the expanded query list.
        """
        stats = {
            "original_query": query,
            "expanded_queries": [],
            "initial_results": 0,
            "after_rerank": 0,
            "after_compression": 0
        }

        # Step 1: multi-query retrieval; over-fetch candidates for reranking.
        stats["expanded_queries"] = self.expand_query(query)
        documents = self.multi_query_retrieval(
            query=query,
            top_k=top_k * 2,
            score_threshold=score_threshold
        )
        stats["initial_results"] = len(documents)

        # Step 2: optional reranking.
        if use_reranking and documents:
            documents = self.rerank_documents(query, documents)
            documents = documents[:top_k]  # Keep top_k after reranking
            stats["after_rerank"] = len(documents)
        else:
            # BUGFIX: cap at top_k even when reranking is skipped --
            # previously up to 2*top_k candidates leaked to the caller.
            documents = documents[:top_k]

        # Step 3: optional contextual compression.
        if use_compression and documents:
            documents = self.compress_context(
                query=query,
                documents=documents,
                max_tokens=max_context_tokens
            )
            stats["after_compression"] = len(documents)

        return documents, stats

    def format_context_for_llm(
        self,
        documents: List[RetrievedDocument],
        include_metadata: bool = True
    ) -> str:
        """
        Format retrieved documents into a context string for the LLM.

        Each document is wrapped in a numbered header with its relevance
        score; selected metadata fields are appended when requested.

        Args:
            documents: Documents to render.
            include_metadata: Append truthy, non-text metadata per document.

        Returns:
            Formatted context string, or "" when there are no documents.
        """
        if not documents:
            return ""

        context_parts = ["RELEVANT CONTEXT:\n"]
        for i, doc in enumerate(documents, 1):
            context_parts.append(f"\n--- Document {i} (Relevance: {doc.confidence:.2%}) ---")
            context_parts.append(doc.text)

            if include_metadata and doc.metadata:
                # Show every truthy metadata field except the raw text body.
                meta_str = []
                for key, value in doc.metadata.items():
                    if key not in ['text', 'texts'] and value:
                        meta_str.append(f"{key}: {value}")
                if meta_str:
                    context_parts.append(f"[Metadata: {', '.join(meta_str)}]")

        context_parts.append("\n--- End of Context ---\n")
        return "\n".join(context_parts)

    def build_rag_prompt(
        self,
        query: str,
        context: str,
        system_message: str = "You are a helpful AI assistant."
    ) -> str:
        """
        Build an optimized RAG prompt for the LLM.

        Args:
            query: User question.
            context: Formatted context (see ``format_context_for_llm``).
            system_message: System persona line placed first.

        Returns:
            Full prompt string with grounding instructions.
        """
        prompt_template = f"""{system_message}

{context}

INSTRUCTIONS:
1. Answer the user's question using ONLY the information provided in the context above
2. If the context doesn't contain relevant information, say "Tôi không tìm thấy thông tin liên quan trong dữ liệu."
3. Cite relevant parts of the context when answering
4. Be concise and accurate
5. Answer in Vietnamese if the question is in Vietnamese

USER QUESTION: {query}

YOUR ANSWER:"""
        return prompt_template