import os
import requests
import pickle
import sentence_transformers
import faiss
import gradio as gr
from transformers import pipeline
import numpy as np
from sentence_transformers import CrossEncoder
# ------------------------------
# Configuration
# ------------------------------
INDEX_URL = "https://huggingface.co/LoneWolfgang/abalone-index/resolve/main/index.faiss"
DOCSTORE_URL = "https://huggingface.co/LoneWolfgang/abalone-index/resolve/main/docstore.pkl"
INDEX_DIR = "data/index"
SBERT = "all-MiniLM-L12-v2"
# ------------------------------
# Ensure data folder exists
# ------------------------------
os.makedirs(INDEX_DIR, exist_ok=True)
# ------------------------------
# Download helper
# ------------------------------
def download_file(url, dest_path):
print(f"Downloading {url} ...")
r = requests.get(url)
r.raise_for_status()
with open(dest_path, "wb") as f:
f.write(r.content)
print(f"Saved to {dest_path}")
# Download index + docstore
download_file(INDEX_URL, os.path.join(INDEX_DIR, "index.faiss"))
download_file(DOCSTORE_URL, os.path.join(INDEX_DIR, "docstore.pkl"))
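# The two artifacts are assumed to contain: a FAISS index over segment embeddings
# (index.faiss) and a pickled list of segment dicts (docstore.pkl), where each
# segment carries at least "text", "sentences", "url", and "document_id",
# as used by the Retriever below.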
# ------------------------------
# Retriever
# ------------------------------
class Retriever:
    def __init__(self, index_dir, cross_encoder_model="cross-encoder/ms-marco-MiniLM-L-6-v2"):
        index, segments = self._load_index(index_dir)
        self.index = index
        self.segments = segments
        # bi-encoder
        self.sbert = sentence_transformers.SentenceTransformer(SBERT)
        # cross-encoder
        self.cross = CrossEncoder(cross_encoder_model)

    def _load_index(self, index_dir):
        index = faiss.read_index(os.path.join(index_dir, "index.faiss"))
        with open(os.path.join(index_dir, "docstore.pkl"), "rb") as f:
            segments = pickle.load(f)
        return index, segments

    def preprocess_query(self, query):
        # L2-normalize the query embedding; assumes the index was built over
        # normalized embeddings, so inner-product search behaves like cosine similarity.
        embedding = self.sbert.encode([query]).astype("float32")
        faiss.normalize_L2(embedding)
        return embedding
    def retrieve(self, query, k=50):
        # ---------- Stage 1: Bi-Encoder ----------
        # Dense search over segment embeddings for the top-k candidates.
        embedding = self.preprocess_query(query)
        D, I = self.index.search(embedding, k)
        candidates = []
        ce_pairs_segments = []
        for idx in I[0]:
            if idx == -1:
                # FAISS pads results with -1 when fewer than k hits are available.
                continue
            seg = self.segments[idx]
            candidates.append(seg)
            ce_pairs_segments.append([query, seg["text"]])

        # ---------- Stage 2: Cross-Encoder Re-Rank ----------
        # Score each (query, segment) pair jointly and keep the best segment.
        segment_scores = self.cross.predict(ce_pairs_segments)
        best_seg_idx = int(np.argmax(segment_scores))
        best_segment = candidates[best_seg_idx]

        # ---------- Stage 3: Cross-Encoder Sentence Ranking ----------
        # Score each sentence of the winning segment to pick the highlight.
        sentences = best_segment["sentences"]
        ce_pairs_sentences = [[query, s] for s in sentences]
        sentence_scores = self.cross.predict(ce_pairs_sentences)
        best_sent_idx = int(np.argmax(sentence_scores))
        best_sentence = sentences[best_sent_idx].strip()

        highlighted_text = (
            best_segment["text"]
            .replace(best_sentence, f"**{best_sentence}**")
            .replace("\n", " ")
        )
        return {
            "text": highlighted_text,
            "url": best_segment.get("url"),
            "document_id": best_segment.get("document_id"),
            "segment_score": float(segment_scores[best_seg_idx]),
            "sentence_score": float(sentence_scores[best_sent_idx]),
        }
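# Shape of the dict `retrieve` returns (hypothetical query, shown for illustration only):
#   hit = retriever.retrieve("What do abalone eat?")
#   hit["text"]           -> segment text with the best sentence wrapped in **bold**
#   hit["url"]            -> source URL shown in the UI
#   hit["segment_score"]  -> cross-encoder score of the winning segment
#   hit["sentence_score"] -> cross-encoder score of the highlighted sentence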
# ------------------------------
# Generators (loaded once)
# ------------------------------
generators = {
    "TinyLlama": pipeline(
        "text-generation",
        model="LoneWolfgang/tinyllama-for-abalone-RAG",
        max_new_tokens=150,
        temperature=0.1,
    ),
    "FLAN-T5": pipeline(
        "text2text-generation",
        model="google/flan-t5-base",
        max_length=200,
    ),
}
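# Quick smoke test for a generator (hypothetical, kept commented out so the
# Space starts without extra work):
# print(generators["FLAN-T5"]("Question: What is an abalone?")[0]["generated_text"])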
retriever = Retriever(INDEX_DIR)
# ------------------------------
# Combined function: retrieve → generate
# ------------------------------
def answer_query(query, model_choice):
    doc = retriever.retrieve(query)
    url = doc["url"]
    context = doc["text"].replace("\n", " ")

    if model_choice == "No Generation":
        # Just return the retrieved context, no model generation
        return (
            f"#### Response\n\n"
            f"{context}\n\n"
            f"---\n"
            f"[Source]({url})"
        )
    else:
        prompt = f"""
You answer questions strictly using the provided context.
Context: {context}
Question: {query}
"""
        # Choose generator
        gen = generators[model_choice]
        if model_choice == "TinyLlama":
            # The text-generation pipeline echoes the prompt, so keep only the
            # text after the final <|assistant|> tag.
            out = gen(f"<|system|>{prompt}<|assistant|>")[0]["generated_text"]
            result = out.split("<|assistant|>")[-1].strip()
        else:
            # FLAN-T5 returns only the generated answer in "generated_text"
            result = gen(prompt)[0]["generated_text"]

        return (
            f"#### Response\n\n"
            f"{result}\n\n"
            f"---\n"
            f"#### Context\n\n"
            f"{context}\n\n"
            f"---\n"
            f"[Source]({url})"
        )
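# For "No Generation", answer_query returns Markdown shaped roughly like:
#   #### Response
#   ...retrieved segment with the **best sentence** highlighted...
#   ---
#   [Source](https://...)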
# ------------------------------
# Gradio UI
# ------------------------------
demo = gr.Interface(
    fn=answer_query,
    inputs=[
        gr.Textbox(label="Enter your question"),
        gr.Radio(
            ["TinyLlama", "FLAN-T5", "No Generation"],
            label="Choose Model",
            value="No Generation",
        ),
    ],
    outputs=gr.Markdown(label="Answer"),
    title="Abalone RAG Demo",
    description="""This RAG system uses [SBERT](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2) for initial retrieval and a [Cross Encoder](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L6-v2) for re-ranking and highlighting.
Sentence embeddings are computed and [indexed](https://huggingface.co/LoneWolfgang/abalone-index) using FAISS.
For generation, you can choose between:
- [FLAN-T5](https://huggingface.co/google/flan-t5-base): fast and reliable, the baseline experience.
- [Finetuned TinyLlama](https://huggingface.co/LoneWolfgang/tinyllama-for-abalone-RAG): slower, but more expressive.
- **No Generation**: only retrieve and highlight the relevant context, without generating a response. Use this to explore retrieval quality.
""",
)
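# The Radio value is passed to answer_query as model_choice, so the labels above
# must stay in sync with the keys of the generators dict ("TinyLlama", "FLAN-T5")
# plus the special "No Generation" option.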
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)