import gradio as gr
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import CrossEncoder
import torch
import torch.nn.functional as F
from langchain.text_splitter import RecursiveCharacterTextSplitter

# --- Constants ---
TOP_K_FINAL = 3
RETRIEVAL_CANDIDATE_COUNT = 20

# --- 1. SETUP: Load all necessary models ---
print("Loading Qwen3 Embedding Model (Retriever)...")
embedding_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B")
embedding_model = AutoModel.from_pretrained("Qwen/Qwen3-Embedding-0.6B")
print("Qwen3 Embedding Model loaded.")

print("Loading Reranker model (Cross-Encoder)...")
reranker_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
print("Reranker model loaded.")

# --- 2. CORE FUNCTIONS ---

def get_qwen_embeddings_batch(texts):
    """
    Embed a batch of texts with Qwen3 in a single forward pass.
    Batching is much more efficient than embedding texts one at a time.
    """
    # `padding=True` and `truncation=True` are key for batching
    inputs = embedding_tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = embedding_model(**inputs)
    # Qwen3-Embedding is decoder-only and has no [CLS] token, so pool the
    # hidden state of each sequence's last non-padding token instead.
    last_hidden = outputs.last_hidden_state
    attention_mask = inputs["attention_mask"]
    if attention_mask[:, -1].all():
        # Left padding (or no padding): the last position is always valid
        embeddings = last_hidden[:, -1]
    else:
        # Right padding: index each sequence's final real token
        seq_lengths = attention_mask.sum(dim=1) - 1
        embeddings = last_hidden[torch.arange(last_hidden.size(0)), seq_lengths]
    # L2-normalize so cosine similarity behaves as expected
    return F.normalize(embeddings, p=2, dim=1)


def process_and_index_document(source_text):
    """
    Triggered by the 'Index Document' button.
    Chunks the text, creates embeddings, and stores both in session state.
    """
    if not source_text or not source_text.strip():
        # Show an error and keep the search UI hidden
        return None, None, "❌ Error: Please provide some source text.", gr.update(visible=False)

    print("--- Starting document processing ---")

    # a. Chunk the document
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        length_function=len,
        separators=["\n\n", "\n", " ", ""],
    )
    chunks = text_splitter.split_text(source_text)
    print(f"Document split into {len(chunks)} chunks.")

    # b. Vectorize the chunks with Qwen3
    print("Vectorizing chunks with Qwen3... (This might take a moment)")
    embeddings = get_qwen_embeddings_batch(chunks)
    print("Vectorization complete. Shape:", embeddings.shape)

    # c. Return the processed data and update the UI
    success_message = f"✅ Document indexed successfully into {len(chunks)} chunks."
    # The last return value makes the search group visible
    return chunks, embeddings, success_message, gr.update(visible=True)


def search_and_rerank(user_query, document_chunks, document_embeddings):
    """
    The main search logic (retrieval + reranking).
    Takes the chunks and embeddings from the per-session state.
    """
    if not user_query or not user_query.strip():
        return [""] * (TOP_K_FINAL * 2)
    if document_chunks is None:
        return ["Please index a document first."] * (TOP_K_FINAL * 2)

    # --- STAGE 1: RETRIEVAL ---
    # Embed the single query
    query_embedding = get_qwen_embeddings_batch([user_query])

    # Cosine similarity between the query and every chunk (broadcasts over rows)
    similarities = F.cosine_similarity(query_embedding, document_embeddings)

    # Keep the top candidates for the reranker
    k = min(RETRIEVAL_CANDIDATE_COUNT, len(document_chunks))
    top_retrieval_indices = torch.topk(similarities, k=k).indices
    candidate_chunks = [document_chunks[idx] for idx in top_retrieval_indices]

    # --- STAGE 2: RERANKING ---
    reranker_input_pairs = [[user_query, chunk] for chunk in candidate_chunks]
    rerank_scores = reranker_model.predict(reranker_input_pairs)
    reranked_results = sorted(zip(rerank_scores, candidate_chunks), key=lambda x: x[0], reverse=True)

    # --- Prepare final output ---
    outputs = []
    for score, chunk in reranked_results[:TOP_K_FINAL]:
        outputs.append(f"Rerank Score: {score:.4f}")
        outputs.append(chunk)
    # Pad with empty strings if fewer than TOP_K_FINAL results came back
    while len(outputs) < TOP_K_FINAL * 2:
        outputs.extend(["", ""])
    return outputs

# --- 3. GRADIO USER INTERFACE ---
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# 🧠 Dynamic RAG with Qwen3 + Reranker")
    gr.Markdown(
        "**Step 1:** Paste your source text below and click 'Index Document'.\n"
        "**Step 2:** Once indexed, use the search bar to ask questions."
    )

    # gr.State holds session-specific data (chunks and embeddings)
    chunks_state = gr.State()
    embeddings_state = gr.State()

    with gr.Row():
        source_document_input = gr.Textbox(
            label="Source Document Text",
            placeholder="Paste the full text of your document here...",
            lines=15,
            scale=2,
        )
        index_button = gr.Button("Index Document 🚀")

    status_display = gr.Markdown("Status: Ready to index a document.")

    # The search UI stays hidden until indexing is complete
    with gr.Column(visible=False) as search_ui_group:
        gr.Markdown("---")
        gr.Markdown("### Step 2: Search Your Document")
        query_input = gr.Textbox(
            label="Your Question or Topic",
            placeholder="e.g., What is the main goal of the project?",
            lines=1,
        )
        output_components = []
        for i in range(TOP_K_FINAL):
            with gr.Group():
                score = gr.Textbox(label=f"Result {i+1} Score", interactive=False)
                chunk_text = gr.Textbox(label="Retrieved Chunk", interactive=False, lines=4)
            output_components.extend([score, chunk_text])

    # --- Connect UI components to functions ---
    # When the index button is clicked...
    index_button.click(
        fn=process_and_index_document,
        inputs=[source_document_input],
        # Outputs: the state variables, the status message, and the search UI group
        outputs=[chunks_state, embeddings_state, status_display, search_ui_group],
    )

    # When the query input changes (live search)...
    query_input.change(
        fn=search_and_rerank,
        # Inputs must include the state variables
        inputs=[query_input, chunks_state, embeddings_state],
        outputs=output_components,
    )

if __name__ == "__main__":
    print("\nInterface is launching... Go to the printed URL.")
    iface.launch()