import os
import re
import html  # for decoding HTML escape codes in scraped pages
import time
import asyncio

import aiohttp
import discord
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
st.set_page_config(page_title="TeapotAI Discord Bot", page_icon=":robot_face:", layout="wide")

# Load the TeapotAI seq2seq model and tokenizer once at startup
model_name = "teapotai/teapotllm"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
def log_time(func):
    """Decorator that logs the wall-clock runtime of an async function."""
    async def wrapper(*args, **kwargs):
        start_time = time.time()
        result = await func(*args, **kwargs)
        end_time = time.time()
        print(f"{func.__name__} executed in {end_time - start_time:.4f} seconds")
        return result
    return wrapper
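
# log_time is not applied to anything below; a minimal usage sketch
# (an assumption, not part of the original wiring):
#
#   @log_time
#   async def some_async_helper():
#       await asyncio.sleep(1)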
API_KEY = os.environ.get("brave_api_key")

async def brave_search(query, count=1):
    """Query the Brave Search API and return (title, description, url) tuples."""
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY}
    params = {"q": query, "count": count}
    async with aiohttp.ClientSession() as session:
        async with session.get(url, headers=headers, params=params) as response:
            if response.status == 200:
                results = await response.json()
                # print(results)  # uncomment to inspect the raw API response
                return [(res["title"], res["description"], res["url"])
                        for res in results.get("web", {}).get("results", [])]
            else:
                print(f"Error: {response.status}, {await response.text()}")
                return []
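
# Example usage (sketch): callers unpack the tuples directly, e.g.
#
#   results = await brave_search("teapot history", count=3)
#   for title, description, url in results:
#       print(title, url)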
# Extract the first URL from the text; strip every URL out of the query itself
def extract_first_url(query):
    urls = re.findall(r'https?://\S+', query)  # find all URLs
    if urls:
        query = re.sub(r'https?://\S+', '', query)  # remove all URLs from the query
        return query, urls[0]
    return query, None
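
# Example: extract_first_url("summarize https://example.com please")
# returns ("summarize  please", "https://example.com")
# (note the leftover double space where the URL was removed)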
async def extract_text_from_html(url, max_words=250, max_chars=2000):
    """Fetch a page and return the first `max_words` words of its <p> text,
    capped at `max_chars` characters. Uses regexes rather than a real HTML
    parser, so this is a best-effort scrape."""
    # Fetch the HTML content asynchronously
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            html_content = await response.text()

    # Find all text within <p> tags
    p_tag_content = re.findall(r'<p>(.*?)</p>', html_content, re.DOTALL)
    # Strip any remaining HTML tags from the extracted text
    clean_text = [re.sub(r'<.*?>', '', p) for p in p_tag_content]
    # Decode HTML escape codes (e.g., &lt; -> <)
    decoded_text = [html.unescape(p) for p in clean_text]
    # Join all paragraphs into one large string
    full_text = ' '.join(decoded_text)

    # Keep only the first `max_words` words...
    words = full_text.split()
    first_words = ' '.join(words[:max_words])
    # ...and ensure the text does not exceed `max_chars` characters
    return first_words[:max_chars]
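
# Caveat: paragraphs with attributes (e.g. <p class="intro">) will not match
# the r'<p>(.*?)</p>' pattern above; a more tolerant pattern (an assumption,
# not in the original) would be r'<p[^>]*>(.*?)</p>'.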
# pipeline_lock = asyncio.Lock()

async def query_teapot(prompt, context, user_input):
    input_text = prompt + "\n" + context + "\n" + user_input
    inputs = tokenizer(input_text, return_tensors="pt")
    # async with pipeline_lock:  # ensure only one generation runs at a time
    # Run the blocking generate() call in a worker thread so the event loop stays free
    output = await asyncio.to_thread(model.generate, **inputs, max_new_tokens=512)
    output_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return output_text
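
# The model sees one flattened string: a call like (sketch)
#
#   await query_teapot("You are Teapot...", "Paris is the capital of France.",
#                      "What is the capital of France?")
#
# tokenizes "You are Teapot...\nParis is the capital of France.\nWhat is the
# capital of France?" as a single input.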
async def handle_chat(user_input):
    """Route a message to scraping, translation, or web search, then generate a reply."""
    results = []
    prompt = """You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization."""

    # Check if there's a URL and process the input
    processed_query, url = extract_first_url(user_input)

    if url:
        # Scrape task: use the linked page as context
        search_start_time = time.time()
        context = await extract_text_from_html(url)
        user_input = processed_query
        search_end_time = time.time()
    elif "translate" in user_input:
        # Translation shim: no context and no system prompt
        search_start_time = time.time()
        context = ""
        prompt = ""
        search_end_time = time.time()
    else:
        # Search task
        search_start_time = time.time()
        if len(user_input) < 400 and "context:" not in user_input and "Context:" not in user_input:
            results = await brave_search(user_input)
            if len(results) == 0:  # no information found
                return "I'm sorry but I don't have any information on that.", ""
            documents = [desc.replace('<strong>', '').replace('</strong>', '') for _, desc, _ in results]
            context = "\n".join(documents)
        else:
            context = ""  # the user supplied their own context inline
        search_end_time = time.time()

    generation_start_time = time.time()
    response = await query_teapot(prompt, context, user_input)
    generation_end_time = time.time()

    debug_info = f"""
Prompt:
{prompt}
Context:
{context}
Query:
{user_input}
Search time: {search_end_time - search_start_time:.2f} seconds
Generation time: {generation_end_time - generation_start_time:.2f} seconds
Response: {response}
"""
    return response, debug_info
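
# Example (sketch): handle_chat("What is the boiling point of water?") runs a
# Brave search for context, while handle_chat("translate hello to French")
# skips search entirely and relies on the model alone.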

st.write("418 I'm a teapot")

DISCORD_TOKEN = os.environ.get("discord_key")

# Enable the intents the bot needs
intents = discord.Intents.default()
intents.messages = True  # message-related events
intents.message_content = True  # required since discord.py 2.0 to read message text
client = discord.Client(intents=intents)

# Event: the bot has connected to Discord
@client.event
async def on_ready():
    print(f'Logged in as {client.user}')

# Event: a message is received
@client.event
async def on_message(message):
    # Ignore the bot's own messages to prevent a reply loop
    if message.author == client.user:
        return

    # Only respond when the bot is mentioned
    if f'<@{client.user.id}>' not in message.content:
        return

    print(message.content)
    is_debug = "debug:" in message.content or "Debug:" in message.content

    async with message.channel.typing():
        cleaned_message = message.content.replace("debug:", "").replace("Debug:", "").replace(f'<@{client.user.id}>', "")
        response, debug_info = await handle_chat(cleaned_message)
        print(response)
        sent_message = await message.reply(response)

        # In debug mode, attach the trace as a thread on the reply
        if is_debug:
            thread = await sent_message.create_thread(name=f"Debug Thread: '{cleaned_message[:80]}'", auto_archive_duration=60)
            # Send the debug info in the created thread
            await thread.send(debug_info)

def initialize():
    st.session_state["initialized"] = True
    # client.run() blocks here, keeping the bot alive for the life of the process
    client.run(DISCORD_TOKEN)

# Only start the bot once per session; Streamlit reruns the script on interaction
if not st.session_state.get("initialized"):
    initialize()
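
# To run locally (a sketch, assuming this file is saved as app.py):
#   export brave_api_key=... discord_key=...
#   streamlit run app.py
# Note that client.run() blocks, so the bot keeps the script alive.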