Jordan Klein committed on
Commit
ceda798
·
1 Parent(s): 32aff05

updated description

Browse files
Files changed (1) hide show
  1. app.py +49 -48
app.py CHANGED
@@ -32,7 +32,6 @@ def download_file(url, dest_path):
32
  f.write(r.content)
33
  print(f"Saved to {dest_path}")
34
 
35
-
36
  # Download index + docstore
37
  download_file(INDEX_URL, os.path.join(INDEX_DIR, "index.faiss"))
38
  download_file(DOCSTORE_URL, os.path.join(INDEX_DIR, "docstore.pkl"))
@@ -41,12 +40,7 @@ download_file(DOCSTORE_URL, os.path.join(INDEX_DIR, "docstore.pkl"))
41
  # Retriever
42
  # ------------------------------
43
  class Retriever:
44
-
45
- def __init__(
46
- self,
47
- index_dir,
48
- cross_encoder_model="cross-encoder/ms-marco-MiniLM-L-6-v2"
49
- ):
50
  index, segments = self._load_index(index_dir)
51
  self.index = index
52
  self.segments = segments
@@ -59,7 +53,7 @@ class Retriever:
59
 
60
  def _load_index(self, index_dir):
61
  index = faiss.read_index(os.path.join(index_dir, "index.faiss"))
62
- with open(os.path.join(index_dir, "docstore.pkl") , "rb") as f:
63
  segments = pickle.load(f)
64
  return index, segments
65
 
@@ -68,29 +62,20 @@ class Retriever:
68
  faiss.normalize_L2(embedding)
69
  return embedding
70
 
71
- def _cosine_similarity(self, a, b):
72
- return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
73
-
74
  def retrieve(self, query, k=50):
75
- """
76
- 1. Retrieve top-k segments using bi-encoder (FAISS)
77
- 2. Re-rank segments using cross-encoder on segment['text']
78
- 3. Re-score each sentence inside chosen segment using cross-encoder
79
- 4. Highlight the best sentence
80
- """
81
- # ---------- Stage 1: Bi-Encoder Retrieval ----------
82
  embedding = self.preprocess_query(query)
83
  D, I = self.index.search(embedding, k)
84
 
85
  candidates = []
86
- ce_pairs_segments = [] # (query, segment_text)
87
 
88
  for idx in I[0]:
89
  seg = self.segments[idx]
90
  candidates.append(seg)
91
  ce_pairs_segments.append([query, seg["text"]])
92
 
93
- # ---------- Stage 2: Cross-Encoder Re-Rank Segments ----------
94
  segment_scores = self.cross.predict(ce_pairs_segments)
95
  best_seg_idx = int(np.argmax(segment_scores))
96
  best_segment = candidates[best_seg_idx]
@@ -98,21 +83,18 @@ class Retriever:
98
  # ---------- Stage 3: Cross-Encoder Sentence Ranking ----------
99
  sentences = best_segment["sentences"]
100
  ce_pairs_sentences = [[query, s] for s in sentences]
101
-
102
  sentence_scores = self.cross.predict(ce_pairs_sentences)
103
- best_sent_idx = int(np.argmax(sentence_scores))
104
 
 
105
  best_sentence = sentences[best_sent_idx].strip()
106
 
107
- # Highlight within full segment
108
  highlighted_text = (
109
  best_segment["text"]
110
  .replace(best_sentence, f"**{best_sentence}**")
111
  .replace("\n", " ")
112
  )
113
 
114
- # ---------- Result ----------
115
- result = {
116
  "text": highlighted_text,
117
  "url": best_segment.get("url"),
118
  "document_id": best_segment.get("document_id"),
@@ -120,46 +102,51 @@ class Retriever:
120
  "sentence_score": float(sentence_scores[best_sent_idx]),
121
  }
122
 
123
- return result
124
-
125
-
126
  # ------------------------------
127
- # Lightweight Generator
128
  # ------------------------------
129
- # Finetuned TinyLlama
130
- generator = pipeline(
131
- "text-generation",
132
- model="LoneWolfgang/tinyllama-for-abalone-RAG",
133
- max_new_tokens=150,
134
- temperature=0.1,
135
- )
 
 
 
 
 
 
136
 
 
137
 
138
  # ------------------------------
139
  # Combined function: retrieve → generate
140
  # ------------------------------
141
- retriever = Retriever(INDEX_DIR)
142
-
143
- def answer_query(query):
144
  doc = retriever.retrieve(query)
145
 
146
  url = doc["url"]
147
  context = doc["text"].replace("\n", " ")
148
 
149
  prompt = f"""
150
- <|system|>
151
  You answer questions strictly using the provided context.
152
- <|user|>
153
  Context: {context}
154
 
155
  Question: {query}
156
- <|assistant|>
157
  """
158
 
159
- result = generator(prompt)[0]["generated_text"]
 
160
 
161
- # Keep only model completion after the assistant token
162
- result = result.split("<|assistant|>")[-1].strip()
 
 
 
 
163
 
164
  return (
165
  f"#### Response\n\n"
@@ -176,12 +163,26 @@ def answer_query(query):
176
  # ------------------------------
177
  demo = gr.Interface(
178
  fn=answer_query,
179
- inputs=gr.Textbox(label="Enter your question"),
 
 
 
 
 
 
 
180
  outputs=gr.Markdown(label="Answer"),
181
  title="Abalone RAG Demo",
182
- description="This RAG system uses SBERT + Cross-Encoders for Retrieval with TinyLlama finetuned on responses from GPT5."
 
 
 
 
 
 
 
 
183
  )
184
 
185
  if __name__ == "__main__":
186
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
187
-
 
32
  f.write(r.content)
33
  print(f"Saved to {dest_path}")
34
 
 
35
  # Download index + docstore
36
  download_file(INDEX_URL, os.path.join(INDEX_DIR, "index.faiss"))
37
  download_file(DOCSTORE_URL, os.path.join(INDEX_DIR, "docstore.pkl"))
 
40
  # Retriever
41
  # ------------------------------
42
  class Retriever:
43
+ def __init__(self, index_dir, cross_encoder_model="cross-encoder/ms-marco-MiniLM-L-6-v2"):
 
 
 
 
 
44
  index, segments = self._load_index(index_dir)
45
  self.index = index
46
  self.segments = segments
 
53
 
54
  def _load_index(self, index_dir):
55
  index = faiss.read_index(os.path.join(index_dir, "index.faiss"))
56
+ with open(os.path.join(index_dir, "docstore.pkl"), "rb") as f:
57
  segments = pickle.load(f)
58
  return index, segments
59
 
 
62
  faiss.normalize_L2(embedding)
63
  return embedding
64
 
 
 
 
65
  def retrieve(self, query, k=50):
66
+ # ---------- Stage 1: Bi-Encoder ----------
 
 
 
 
 
 
67
  embedding = self.preprocess_query(query)
68
  D, I = self.index.search(embedding, k)
69
 
70
  candidates = []
71
+ ce_pairs_segments = []
72
 
73
  for idx in I[0]:
74
  seg = self.segments[idx]
75
  candidates.append(seg)
76
  ce_pairs_segments.append([query, seg["text"]])
77
 
78
+ # ---------- Stage 2: Cross-Encoder Re-Rank ----------
79
  segment_scores = self.cross.predict(ce_pairs_segments)
80
  best_seg_idx = int(np.argmax(segment_scores))
81
  best_segment = candidates[best_seg_idx]
 
83
  # ---------- Stage 3: Cross-Encoder Sentence Ranking ----------
84
  sentences = best_segment["sentences"]
85
  ce_pairs_sentences = [[query, s] for s in sentences]
 
86
  sentence_scores = self.cross.predict(ce_pairs_sentences)
 
87
 
88
+ best_sent_idx = int(np.argmax(sentence_scores))
89
  best_sentence = sentences[best_sent_idx].strip()
90
 
 
91
  highlighted_text = (
92
  best_segment["text"]
93
  .replace(best_sentence, f"**{best_sentence}**")
94
  .replace("\n", " ")
95
  )
96
 
97
+ return {
 
98
  "text": highlighted_text,
99
  "url": best_segment.get("url"),
100
  "document_id": best_segment.get("document_id"),
 
102
  "sentence_score": float(sentence_scores[best_sent_idx]),
103
  }
104
 
 
 
 
105
  # ------------------------------
106
+ # Generators (loaded once)
107
  # ------------------------------
108
+ generators = {
109
+ "TinyLlama": pipeline(
110
+ "text-generation",
111
+ model="LoneWolfgang/tinyllama-for-abalone-RAG",
112
+ max_new_tokens=150,
113
+ temperature=0.1,
114
+ ),
115
+ "FLAN-T5": pipeline(
116
+ "text2text-generation",
117
+ model="google/flan-t5-base",
118
+ max_length=200,
119
+ )
120
+ }
121
 
122
+ retriever = Retriever(INDEX_DIR)
123
 
124
  # ------------------------------
125
  # Combined function: retrieve → generate
126
  # ------------------------------
127
+ def answer_query(query, model_choice):
 
 
128
  doc = retriever.retrieve(query)
129
 
130
  url = doc["url"]
131
  context = doc["text"].replace("\n", " ")
132
 
133
  prompt = f"""
 
134
  You answer questions strictly using the provided context.
135
+
136
  Context: {context}
137
 
138
  Question: {query}
 
139
  """
140
 
141
+ # Choose generator
142
+ gen = generators[model_choice]
143
 
144
+ if model_choice == "TinyLlama":
145
+ out = gen(f"<|system|>{prompt}<|assistant|>")[0]["generated_text"]
146
+ result = out.split("<|assistant|>")[-1].strip()
147
+ else:
148
+ # FLAN-T5 returns text in "generated_text"
149
+ result = gen(prompt)[0]["generated_text"]
150
 
151
  return (
152
  f"#### Response\n\n"
 
163
  # ------------------------------
164
  demo = gr.Interface(
165
  fn=answer_query,
166
+ inputs=[
167
+ gr.Textbox(label="Enter your question"),
168
+ gr.Radio(
169
+ ["TinyLlama", "FLAN-T5"],
170
+ label="Choose Model",
171
+ value="FLAN-T5"
172
+ )
173
+ ],
174
  outputs=gr.Markdown(label="Answer"),
175
  title="Abalone RAG Demo",
176
+ description="""This RAG system uses [SBERT](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) for initial retrieval and a [Cross Encoder](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L6-v2) for re-ranking and highlighting.
177
+
178
+ Sentence embeddings are computed and [indexed](https://huggingface.co/LoneWolfgang/abalone-index) using FAISS.
179
+
180
+ For generation, you can choose between:
181
+
182
+ - [FLAN-T5](https://huggingface.co/google/flan-t5-base) — fast, reliable, and ideal for exploring retrieval quality.
183
+ - [Finetuned TinyLlama](https://huggingface.co/LoneWolfgang/tinyllama-for-abalone-RAG) — slower, but more expressive.
184
+ """
185
  )
186
 
187
  if __name__ == "__main__":
188
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)