LoneWolfgang committed on
Commit
01bfc89
·
1 Parent(s): b840334

Added Generation

Browse files
Files changed (1) hide show
  1. app.py +53 -27
app.py CHANGED
@@ -4,9 +4,10 @@ import pickle
4
  import sentence_transformers
5
  import faiss
6
  import gradio as gr
 
7
 
8
  # ------------------------------
9
- # Configuration: URLs to your files
10
  # ------------------------------
11
  INDEX_URL = "https://huggingface.co/LoneWolfgang/abalone-index/resolve/main/index.faiss"
12
  DOCSTORE_URL = "https://huggingface.co/LoneWolfgang/abalone-index/resolve/main/docstore.pkl"
@@ -31,21 +32,18 @@ def download_file(url, dest_path):
31
  else:
32
  print(f"{dest_path} already exists, skipping download.")
33
 
34
- # Download the FAISS index and docstore
35
  download_file(INDEX_URL, os.path.join(INDEX_DIR, "index.faiss"))
36
  download_file(DOCSTORE_URL, os.path.join(INDEX_DIR, "docstore.pkl"))
37
 
38
  # ------------------------------
39
- # Retriever class
40
  # ------------------------------
41
  class Retriever:
42
  def __init__(self, index_dir, sbert_model="all-MiniLM-L12-v2"):
43
- # Load FAISS index
44
  self.index = faiss.read_index(os.path.join(index_dir, "index.faiss"))
45
- # Load docstore
46
  with open(os.path.join(index_dir, "docstore.pkl"), "rb") as f:
47
  self.segments = pickle.load(f)
48
- # Load SentenceTransformer
49
  self.sbert = sentence_transformers.SentenceTransformer(sbert_model)
50
 
51
  def preprocess_query(self, query):
@@ -53,38 +51,66 @@ class Retriever:
53
  faiss.normalize_L2(embedding)
54
  return embedding
55
 
56
- def retrieve(self, query, k=5):
57
  embedding = self.preprocess_query(query)
58
  D, I = self.index.search(embedding, k)
59
- results = []
60
- for rank, (idx, score) in enumerate(zip(I[0], D[0]), start=1):
61
- text = self.segments[idx]
62
- results.append(f"**{rank}. (Score={score:.4f})**\n{text}")
63
- return "\n\n".join(results)
64
 
65
  # ------------------------------
66
- # Instantiate retriever
67
  # ------------------------------
68
- retriever = Retriever(INDEX_DIR)
 
 
 
 
 
 
 
69
 
70
  # ------------------------------
71
- # Gradio interface
72
  # ------------------------------
73
- def search(query, top_k):
74
- return retriever.retrieve(query, k=top_k)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  demo = gr.Interface(
77
- fn=search,
78
- inputs=[
79
- gr.Textbox(label="Enter your query"),
80
- gr.Slider(1, 10, value=5, step=1, label="Top K")
81
- ],
82
- outputs=gr.Markdown(label="Results"),
83
- title="FAISS Retriever",
84
- description="Semantic search using SentenceTransformers + FAISS.",
85
  theme="soft",
86
- allow_flagging="never"
87
  )
88
 
89
  if __name__ == "__main__":
90
- demo.launch()
 
4
  import sentence_transformers
5
  import faiss
6
  import gradio as gr
7
+ from transformers import pipeline
8
 
9
  # ------------------------------
10
+ # Configuration
11
  # ------------------------------
12
  INDEX_URL = "https://huggingface.co/LoneWolfgang/abalone-index/resolve/main/index.faiss"
13
  DOCSTORE_URL = "https://huggingface.co/LoneWolfgang/abalone-index/resolve/main/docstore.pkl"
 
32
  else:
33
  print(f"{dest_path} already exists, skipping download.")
34
 
35
+ # Download index + docstore
36
  download_file(INDEX_URL, os.path.join(INDEX_DIR, "index.faiss"))
37
  download_file(DOCSTORE_URL, os.path.join(INDEX_DIR, "docstore.pkl"))
38
 
39
  # ------------------------------
40
+ # Retriever
41
  # ------------------------------
42
  class Retriever:
43
  def __init__(self, index_dir, sbert_model="all-MiniLM-L12-v2"):
 
44
  self.index = faiss.read_index(os.path.join(index_dir, "index.faiss"))
 
45
  with open(os.path.join(index_dir, "docstore.pkl"), "rb") as f:
46
  self.segments = pickle.load(f)
 
47
  self.sbert = sentence_transformers.SentenceTransformer(sbert_model)
48
 
49
  def preprocess_query(self, query):
 
51
  faiss.normalize_L2(embedding)
52
  return embedding
53
 
54
+ def retrieve(self, query, k=1):
55
  embedding = self.preprocess_query(query)
56
  D, I = self.index.search(embedding, k)
57
+ top_docs = [self.segments[idx] for idx in I[0]]
58
+ return top_docs, D[0]
 
 
 
59
 
60
  # ------------------------------
61
+ # Lightweight Generator
62
  # ------------------------------
63
+ # FLAN-T5-base is small (~250M) and fast to run on CPU
64
+ generator = pipeline(
65
+ "text2text-generation",
66
+ model="google/flan-t5-base",
67
+ tokenizer="google/flan-t5-base",
68
+ max_new_tokens=150,
69
+ temperature=0.1,
70
+ )
71
 
72
  # ------------------------------
73
+ # Combined function: retrieve → generate
74
  # ------------------------------
75
+ retriever = Retriever(INDEX_DIR)
76
+
77
+ def answer_query(query):
78
+ docs, scores = retriever.retrieve(query, k=1)
79
+ record = docs[0]
80
+ url = record["url"]
81
+ context = record["text"]
82
+
83
+ prompt = (
84
+ f"Answer the following question based on the context.\n\n"
85
+ f"Context:\n{context}\n\n"
86
+ f"Question: {query}\nAnswer:"
87
+ )
88
+ result = generator(prompt)[0]["generated_text"]
89
+
90
+ return f"""
91
+ ### Response
92
+ {result}
93
 
94
+ ---
95
+
96
+ **Context**
97
+
98
+ {context}
99
+
100
+ **[Source]({url})**
101
+ """
102
+
103
+ # ------------------------------
104
+ # Gradio UI
105
+ # ------------------------------
106
  demo = gr.Interface(
107
+ fn=answer_query,
108
+ inputs=gr.Textbox(label="Enter your question"),
109
+ outputs=gr.Markdown(label="Answer"),
110
+ title="RAG Demo",
111
+ description="Retrieves the top 1 passage and generates an answer using FLAN-T5.",
 
 
 
112
  theme="soft",
 
113
  )
114
 
115
  if __name__ == "__main__":
116
+ demo.launch()