import streamlit as st
import torch
from llama_index.core.agent import ReActAgent
from llama_index.core.tools import FunctionTool
from llama_index.llms.huggingface.base import HuggingFaceLLM
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

st.set_page_config(page_title="Simple Text Summarizer", page_icon="🦙")


# 1. Load the summarizer model once and cache it for efficiency.
@st.cache_resource
def get_summarizer():
    """Return the summarization pipeline, built once and reused across reruns."""
    return pipeline("summarization", model="sshleifer/distilbart-cnn-6-6")


# 2. The tool function uses the cached model instead of reloading it every time.
def summarize_text(text: str) -> str:
    """Summarizes a long piece of text. Use this tool for long documents or articles."""
    # NOTE: the docstring above is handed to FunctionTool.from_defaults in
    # load_agent(), so it is kept verbatim; implementation notes live in
    # comments instead.

    # Nothing to summarize: skip the model call rather than feeding it an
    # empty prompt.
    if not text or not text.strip():
        return ""

    summarizer = get_summarizer()
    # truncation=True clips inputs that exceed the model's maximum input
    # length instead of letting the forward pass fail on long documents.
    summary = summarizer(
        text,
        max_length=150,
        min_length=30,
        do_sample=False,
        truncation=True,
    )
    return summary[0]["summary_text"]


# 3. Cache the agent so it is not rebuilt on every interaction.
@st.cache_resource
def load_agent() -> ReActAgent:
    """Build the ReAct agent (summarize tool + local Hugging Face LLM) and cache it."""
    summarize_tool = FunctionTool.from_defaults(fn=summarize_text)

    model_name = "google/flan-t5-base"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16
    )

    # NOTE(review): a seq2seq model is passed to HuggingFaceLLM here; confirm
    # that the wrapper decodes encoder-decoder output as expected, and that
    # context_window=2048 matches what this checkpoint can actually handle.
    llm = HuggingFaceLLM(
        model=model,
        tokenizer=tokenizer,
        query_wrapper_prompt="{query_str}",
        context_window=2048,
        max_new_tokens=256,
    )

    agent = ReActAgent.from_tools([summarize_tool], llm=llm, verbose=True)
    return agent


def main():
    """Render the chat UI: replay stored history, then answer the newest prompt."""
    st.title("LlamaIndex Summarizer Agent 🤖")

    agent = load_agent()

    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Streamlit reruns the script on every interaction, so the stored
    # conversation has to be redrawn each time.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("Ask a question or paste text to summarize..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = agent.chat(prompt)
            # Convert once so the text rendered and the text stored in the
            # history are guaranteed to be the same string.
            answer = str(response)
            st.markdown(answer)
        st.session_state.messages.append({"role": "assistant", "content": answer})


if __name__ == "__main__":
    main()