Lyte commited on
Commit
73f3dc0
·
verified ·
1 Parent(s): 4c8f98f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -31
app.py CHANGED
@@ -1,47 +1,167 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModel
 
3
  import torch
4
  import torch.nn.functional as F
 
5
 
6
- # Load model and tokenizer
7
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B")
8
- model = AutoModel.from_pretrained("Qwen/Qwen3-Embedding-0.6B")
9
 
10
- def get_embedding(text):
11
- inputs = tokenizer(text, return_tensors="pt", truncation=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  with torch.no_grad():
13
- outputs = model(**inputs)
14
- return outputs.last_hidden_state[:, 0, :] # [CLS] token
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- def compare_sentences(reference, comparisons):
17
- if len(reference) > 250:
18
- return "❌ Error: Reference exceeds 250 character limit."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- comparison_list = [s.strip() for s in comparisons.strip().split('\n') if s.strip()]
21
- if not comparison_list:
22
- return "❌ Error: No comparison sentences provided."
 
 
 
 
23
 
24
- if any(len(s) > 250 for s in comparison_list):
25
- return " Error: One or more comparison sentences exceed 250 characters."
 
 
26
 
27
- ref_emb = get_embedding(reference)
28
- comp_embs = torch.cat([get_embedding(s) for s in comparison_list], dim=0)
29
 
30
- similarities = F.cosine_similarity(ref_emb, comp_embs).tolist()
31
- results = "\n".join([f"Similarity with: \"{s}\"\n→ {round(score, 4)}" for s, score in zip(comparison_list, similarities)])
 
 
32
 
33
- return results
 
 
34
 
35
- demo = gr.Interface(
36
- fn=compare_sentences,
37
- inputs=[
38
- gr.Textbox(label="Reference Sentence (max 250 characters)", lines=2, placeholder="Type the reference sentence here..."),
39
- gr.Textbox(label="Comparison Sentences (one per line, each max 250 characters)", lines=8, placeholder="Type comparison sentences here, one per line..."),
40
- ],
41
- outputs="text",
42
- title="Qwen3 Embedding Comparison Demo",
43
- description="Enter a reference sentence and multiple comparison sentences (one per line). The model computes the cosine similarity between the reference and each comparison."
44
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  if __name__ == "__main__":
47
- demo.launch()
 
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModel
3
+ from sentence_transformers import CrossEncoder
4
  import torch
5
  import torch.nn.functional as F
6
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
7
 
8
+ # --- Constants ---
9
+ TOP_K_FINAL = 3
10
+ RETRIEVAL_CANDIDATE_COUNT = 20
11
 
12
+ # --- 1. SETUP: Load all necessary models ---
13
+
14
+ print("Loading Qwen3 Embedding Model (Retriever)...")
15
+ # Using the model you specified
16
+ embedding_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B")
17
+ embedding_model = AutoModel.from_pretrained("Qwen/Qwen3-Embedding-0.6B")
18
+ print("Qwen3 Embedding Model loaded.")
19
+
20
+ print("Loading Reranker model (Cross-Encoder)...")
21
+ reranker_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
22
+ print("Reranker model loaded.")
23
+
24
+
25
+ # --- 2. CORE FUNCTIONS ---
26
+
27
+ def get_qwen_embeddings_batch(texts):
28
+ """
29
+ A new function to get embeddings for a BATCH of texts using Qwen3.
30
+ This is much more efficient than one-by-one.
31
+ """
32
+ # Important: `padding=True` and `truncation=True` are key for batching
33
+ inputs = embedding_tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
34
  with torch.no_grad():
35
+ outputs = embedding_model(**inputs)
36
+ # Extract the [CLS] token's embedding for each text in the batch
37
+ embeddings = outputs.last_hidden_state[:, 0, :]
38
+ return embeddings
39
+
40
+ def process_and_index_document(source_text):
41
+ """
42
+ This function is triggered by the 'Index Document' button.
43
+ It chunks the text, creates embeddings, and stores them.
44
+ """
45
+ if not source_text or not source_text.strip():
46
+ # Update the UI to show an error and hide the search bar
47
+ return None, None, "❌ Error: Please provide some source text.", gr.update(visible=False)
48
+
49
+ print("--- Starting document processing ---")
50
+
51
+ # a. Chunk the document
52
+ text_splitter = RecursiveCharacterTextSplitter(
53
+ chunk_size=500, chunk_overlap=50,
54
+ length_function=len, separators=["\n\n", "\n", " ", ""],
55
+ )
56
+ chunks = text_splitter.split_text(source_text)
57
+ print(f"Document split into {len(chunks)} chunks.")
58
+
59
+ # b. Vectorize the chunks using Qwen3
60
+ print("Vectorizing chunks with Qwen3... (This might take a moment)")
61
+ embeddings = get_qwen_embeddings_batch(chunks)
62
+ print("Vectorization complete. Shape:", embeddings.shape)
63
+
64
+ # c. Return the processed data and update UI
65
+ success_message = f"✅ Document indexed successfully into {len(chunks)} chunks."
66
+ # The last return value makes the search group visible
67
+ return chunks, embeddings, success_message, gr.update(visible=True)
68
+
69
 
70
+ def search_and_rerank(user_query, document_chunks, document_embeddings):
71
+ """
72
+ The main search logic (retrieval + reranking).
73
+ This function now takes the chunks and embeddings from the session state.
74
+ """
75
+ if not user_query or not user_query.strip():
76
+ return [""] * (TOP_K_FINAL * 2)
77
+
78
+ if document_chunks is None:
79
+ return ["Please index a document first."] * (TOP_K_FINAL * 2)
80
+
81
+ # --- STAGE 1: RETRIEVAL ---
82
+ query_embedding = get_qwen_embeddings_batch([user_query]) # Embed the single query
83
+
84
+ # Use PyTorch's cosine similarity
85
+ similarities = F.cosine_similarity(query_embedding, document_embeddings)
86
+
87
+ # Get the top candidates
88
+ top_retrieval_indices = torch.topk(similarities, k=min(RETRIEVAL_CANDIDATE_COUNT, len(document_chunks))).indices
89
+ candidate_chunks = [document_chunks[idx] for idx in top_retrieval_indices]
90
+
91
+ # --- STAGE 2: RERANKING ---
92
+ reranker_input_pairs = [[user_query, chunk] for chunk in candidate_chunks]
93
+ rerank_scores = reranker_model.predict(reranker_input_pairs)
94
 
95
+ reranked_results = sorted(zip(rerank_scores, candidate_chunks), key=lambda x: x[0], reverse=True)
96
+
97
+ # --- Prepare final output ---
98
+ outputs = []
99
+ for score, chunk in reranked_results[:TOP_K_FINAL]:
100
+ outputs.append(f"Rerank Score: {score:.4f}")
101
+ outputs.append(chunk)
102
 
103
+ while len(outputs) < TOP_K_FINAL * 2:
104
+ outputs.extend(["", ""])
105
+
106
+ return outputs
107
 
108
+ # --- 3. GRADIO USER INTERFACE ---
 
109
 
110
+ with gr.Blocks(theme=gr.themes.Soft()) as iface:
111
+ gr.Markdown("# 🧠 Dynamic RAG with Qwen3 + Reranker")
112
+ gr.Markdown("**Step 1:** Paste your source text below and click 'Index Document'.\n"
113
+ "**Step 2:** Once indexed, use the search bar to ask questions.")
114
 
115
+ # We use gr.State to hold session-specific data (chunks and embeddings)
116
+ chunks_state = gr.State()
117
+ embeddings_state = gr.State()
118
 
119
+ with gr.Row():
120
+ source_document_input = gr.Textbox(
121
+ label="Source Document Text",
122
+ placeholder="Paste the full text of your document here...",
123
+ lines=15,
124
+ scale=2
125
+ )
126
+
127
+ index_button = gr.Button("Index Document 🚀")
128
+ status_display = gr.Markdown("Status: Ready to index a document.")
129
+
130
+ # The search UI is hidden until indexing is complete
131
+ with gr.Column(visible=False) as search_ui_group:
132
+ gr.Markdown("---")
133
+ gr.Markdown("### Step 2: Search Your Document")
134
+ query_input = gr.Textbox(
135
+ label="Your Question or Topic",
136
+ placeholder="e.g., What is the main goal of the project?",
137
+ lines=1
138
+ )
139
+
140
+ output_components = []
141
+ for i in range(TOP_K_FINAL):
142
+ with gr.Group():
143
+ score = gr.Textbox(label=f"Result {i+1} Score", interactive=False)
144
+ chunk_text = gr.Textbox(label="Retrieved Chunk", interactive=False, lines=4)
145
+ output_components.extend([score, chunk_text])
146
+
147
+ # --- Connect UI components to functions ---
148
+
149
+ # When the index button is clicked...
150
+ index_button.click(
151
+ fn=process_and_index_document,
152
+ inputs=[source_document_input],
153
+ # The outputs are the state variables, the status message, and the search UI group
154
+ outputs=[chunks_state, embeddings_state, status_display, search_ui_group]
155
+ )
156
+
157
+ # When the query input changes (live search)...
158
+ query_input.change(
159
+ fn=search_and_rerank,
160
+ # Inputs must include the state variables
161
+ inputs=[query_input, chunks_state, embeddings_state],
162
+ outputs=output_components
163
+ )
164
 
165
  if __name__ == "__main__":
166
+ print("\nInterface is launching... Go to the printed URL.")
167
+ iface.launch()