import faiss
import pickle
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.llms import HuggingFacePipeline
from langchain.vectorstores import FAISS as LangChainFAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.docstore import InMemoryDocstore
from langchain.schema import Document
from langchain.chains import RetrievalQA
import gradio as gr
# Paths (relative to app root)
vector_path = "vector_store_faiss_chroma/faiss_index.index"
metadata_path = "vector_store_faiss_chroma/metadata.pkl"
#model_path = "HuggingFaceModels/falcon-1b-instruct"
model_path = "tiiuae/Falcon3-1B-Instruct"
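# The FAISS index and metadata are assumed to have been built offline with the
# same sentence-transformers embedding model that is re-declared below, so that
# query embeddings match the dimensionality of the stored vectors.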
# Load the FAISS index
faiss_index = faiss.read_index(vector_path)
# Load metadata (text chunks)
with open(f"{metadata_path}", "rb") as f:
metadata = pickle.load(f)
# Rebuild LangChain Documents
docs = [Document(page_content=doc["page_content"]) for doc in metadata]
# Link documents to FAISS vectors
docstore = InMemoryDocstore({str(i): docs[i] for i in range(len(docs))})
id_map = {i: str(i) for i in range(len(docs))}
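# LangChain's FAISS wrapper resolves search hits by mapping the FAISS row
# position (0..n-1) to a docstore key, so both structures share the same string indices.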
# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)
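# Loads in full precision by default; torch_dtype and device_map could be passed
# to from_pretrained here if memory on the Space is tight.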
# Create a generation pipeline
text_generator_pipeline = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    return_full_text=False,
    max_new_tokens=512,
    do_sample=True,  # sampling must be enabled for the temperature setting to take effect
    temperature=0.2,
)
# Wrap it as a LangChain LLM
llm = HuggingFacePipeline(pipeline=text_generator_pipeline)
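# HuggingFacePipeline wraps the local transformers pipeline, so generation runs
# inside this Space rather than through the hosted Inference API.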
# Re-declare the embedding function (must match the model used to build the index)
embed_fn = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# Create vectorstore and retriever
vectorstore_faiss = LangChainFAISS(
    index=faiss_index,
    docstore=docstore,
    index_to_docstore_id=id_map,
    embedding_function=embed_fn,  # used to embed incoming queries at retrieval time
)
# Create a retriever that returns top-k most relevant chunks
retriever = vectorstore_faiss.as_retriever(search_kwargs={"k": 3})
# Create the RAG pipeline (Retriever + LLM)
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",  # "stuff" concatenates the retrieved chunks into a single prompt
    retriever=retriever,
    return_source_documents=True,
)
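# With return_source_documents=True the chain returns a dict holding the answer
# under "result" and the retrieved chunks under "source_documents".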
# πŸ” Chatbot function: takes a user question, returns generated answer
def ask_rag(query):
result = qa_chain({"query": query})
answer = result["result"]
# Optional: include sources (limited to 2)
sources = result.get("source_documents", [])
source_texts = "\n\n".join([f"πŸ”Ή Source {i+1}:\n{doc.page_content[:300]}..." for i, doc in enumerate(sources[:2])])
return f"πŸ“˜ Answer:\n{answer}\n\nπŸ“š Sources:\n{source_texts}"
# πŸŽ›οΈ Gradio UI components
gr.Interface(
fn=ask_rag,
inputs=gr.Textbox(lines=2, placeholder="Ask me about UCT admissions, housing, fees..."),
outputs="text",
title="πŸŽ“ University of Cape Town Course Advisor Chatbot",
description="Ask academic questions. Powered by FAISS + Falcon-E-1B + LangChain.",
allow_flagging="never"
).launch()