# GAIAgent / tools.py
"""
Unified Tools Definition for GAIAgent.
Includes: Text Logic, Web Search, PDF Parsing, and Multimodal Capabilities.
All tools return structured Dictionaries.
"""
import os
import sys
import re
import ast
import io
import time
import base64
import json
import requests
import arxiv
from typing import Dict, Any, Optional
from io import StringIO
from bs4 import BeautifulSoup
from pypdf import PdfReader
# External Libraries
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage
from langchain_community.document_loaders import WikipediaLoader
from langchain_tavily import TavilySearch
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_mistralai import ChatMistralAI
from youtube_transcript_api import YouTubeTranscriptApi
# --- CONSTANTS ---
GAIA_DATASET_BASE_URL = "https://huggingface.co/datasets/gaia-benchmark/GAIA/resolve/main/2023/validation"
# =============================================================================
# 0. DOWNLOAD HELPER (Internal Use)
# =============================================================================
def _ensure_local_file_exists(file_name: str) -> str:
"""
Downloads a file from the official GAIA dataset repository if it does not exist locally.
Uses HF_TOKEN for authentication to prevent 401 Unauthorized errors.
Returns:
str: The absolute path to the downloaded file.
Raises:
FileNotFoundError: If the file cannot be downloaded.
"""
if not file_name:
return ""
path = os.path.join(os.getcwd(), file_name)
# 1. Local Cache Check
if os.path.exists(path):
return path
# 2. Setup URL and Authentication
download_url = f"{GAIA_DATASET_BASE_URL}/{file_name}"
# Retrieve token from environment variables (Secrets)
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
# Construct Headers
headers = {
"User-Agent": "Mozilla/5.0 (compatible; GAIAgent/1.0)",
}
# CRITICAL: Inject Authorization header if token is present
if hf_token:
headers["Authorization"] = f"Bearer {hf_token}"
else:
print(f"⚠️ WARNING: HF_TOKEN not found. Downloading '{file_name}' might fail with 401 Unauthorized.")
try:
print(f"📥 Downloading file: {file_name} from {download_url}...")
# Use stream=True for efficient memory usage
response = requests.get(download_url, headers=headers, timeout=30, stream=True)
# 3. Handle 404 (File might be in the 'test' set instead of 'validation')
if response.status_code == 404:
fallback_url = download_url.replace("/validation/", "/test/")
print(f"⚠️ File not found in validation set. Trying test set: {fallback_url}")
response = requests.get(fallback_url, headers=headers, timeout=30, stream=True)
# 4. Handle Auth Errors (401/403) explicitly for better debugging
if response.status_code in [401, 403]:
raise PermissionError(f"Authentication failed (Status {response.status_code}). Check your HF_TOKEN.")
response.raise_for_status() # Raise error for other bad status codes
# 5. Write to Disk
with open(path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
# 6. Integrity Check
if os.path.getsize(path) == 0:
os.remove(path)
raise Exception("Downloaded file is empty (0 bytes).")
print(f"✅ Download complete: {path} ({os.path.getsize(path)} bytes)")
return path
except Exception as e:
# Re-raise as FileNotFoundError so the calling tool handles it gracefully
raise FileNotFoundError(f"Failed to download '{file_name}' from GAIA Dataset: {str(e)}")
# =============================================================================
# 1. LOGIC & DATA TOOLS
# =============================================================================
@tool
def python_repl(code: str, file_name: Optional[str] = None) -> Dict[str, str]:
"""
Executes Python code in a local environment.
Use this for: Math, Logic Puzzles, List Processing, and Excel/CSV analysis (pandas).
Args:
code (str): Valid Python code.
- Do NOT use markdown backticks.
- Ensure the last line is an expression to see the result.
- Assume local files (like .xlsx) are in the current directory.
file_name (str, optional): If the task involves a file (e.g., 'data.xlsx'), pass its name here.
Returns:
Dict: {"output": str} or {"error": str}
"""
    # --- 1. EXPLICIT DOWNLOAD ---
    # If the agent passed the file name explicitly, fetch it now.
if file_name:
try:
_ensure_local_file_exists(file_name)
except Exception as e:
return {"error": str(e)}
    # --- 2. SMART AUTO-DOWNLOAD ---
    # If the agent forgot to pass file_name, look for GAIA UUID filenames in the code.
    # Regex for standard GAIA UUID filenames (e.g. f918266a-b3e0... .py/.xlsx)
gaia_pattern = r"([a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}\.[a-z0-9]+)"
found_files = re.findall(gaia_pattern, code)
for found_file in found_files:
        # Avoid re-downloading a file already handled above
if file_name and found_file == file_name:
continue
try:
print(f"🕵️‍♂️ Auto-detected file in code: {found_file}. Downloading...")
_ensure_local_file_exists(found_file)
except Exception as e:
print(f"⚠️ Auto-download failed for {found_file}: {e}")
    # --- 3. SANITIZATION (strip markdown fences) ---
code = code.strip().strip("`").replace("python\n", "")
old_stdout = sys.stdout
redirected_output = sys.stdout = StringIO()
    # --- 4. ENVIRONMENT SETUP ---
    # Inject common libraries to avoid trivial ImportErrors
import pandas as pd
import numpy as np
import math
import subprocess
local_vars = {
"pd": pd,
"np": np,
"math": math,
"subprocess": subprocess
}
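    # Parse the code with ast so that, when the last statement is a bare expression,
    # its value is evaluated and returned (REPL-style) instead of being discarded.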
try:
tree = ast.parse(code)
last_node = tree.body[-1] if tree.body else None
exec_result = ""
if isinstance(last_node, ast.Expr):
code_body = tree.body[:-1]
last_expr = last_node.value
if code_body:
                exec(compile(ast.Module(body=code_body, type_ignores=[]), filename="<string>", mode="exec"), local_vars, local_vars)
            result = eval(compile(ast.Expression(body=last_expr), filename="<string>", mode="eval"), local_vars, local_vars)
printed = redirected_output.getvalue()
exec_result = f"{printed}\n{str(result)}".strip() if printed else str(result)
else:
            exec(code, local_vars, local_vars)
exec_result = redirected_output.getvalue()
if not exec_result: exec_result = "Code executed successfully (no output)."
return {"output": exec_result}
except Exception as e:
return {"error": f"Python Execution Error: {repr(e)}"}
finally:
sys.stdout = old_stdout
# =============================================================================
# 2. SEARCH & KNOWLEDGE TOOLS
# =============================================================================
@tool
def web_search(query: str) -> Dict[str, str]:
"""
Performs a high-quality web search using Tavily.
Returns Titles, Snippets, and most importantly, URLs (Links).
Use this to find updated information or to discover the URL of a specific article/paper
that you will subsequently read using the 'scrape_website' tool.
Args:
query (str): The search query. Use specific keywords or 'site:domain.com'.
Example: 'site:universetoday.com Carolyn Collins Petersen'
Returns:
Dict: {"results": str (XML formatted)} or {"error": str}
"""
try:
api_key = os.environ.get("TAVILY_API_KEY")
if not api_key:
return {"error": "TAVILY_API_KEY not found."}
        # Avoid shadowing the imported @tool decorator with the local variable name
        search_tool = TavilySearch(
            max_results=5,
            search_depth="advanced"  # "advanced" depth is essential for GAIA-style questions
        )
        response = search_tool.invoke({"query": query})
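        # Normalize the response shape: the tool may return a dict with a "results"
        # key, a plain list of results, or a JSON string; handle all three.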
results = []
if isinstance(response, dict) and "results" in response:
results = response["results"]
elif isinstance(response, list):
results = response
elif isinstance(response, str):
try:
parsed = json.loads(response)
results = parsed.get("results", [])
            except Exception:
pass
if not results:
return {"results": "No relevant results found."}
formatted = []
for doc in results:
url = doc.get('url', 'No URL')
content = doc.get('content', 'No Content')
title = doc.get('title', 'No Title')
formatted.append(
f'<Result>\n'
f'<Title>{title}</Title>\n'
f'<Source>{url}</Source>\n'
f'<Snippet>{content}</Snippet>\n'
f'</Result>'
)
return {"results": "\n\n".join(formatted)}
except Exception as e:
return {"error": f"Search Error: {str(e)}"}
@tool
def wiki_search(query: str) -> Dict[str, str]:
"""
Searches Wikipedia for a topic.
NOTE: This returns TEXT ONLY. It does not show images or revision history.
Args:
query (str): The topic title (e.g. "Thomas Aquinas").
Returns:
Dict: {"wiki_content": str} or {"error": str}
"""
try:
# Load max 2 docs, limit content to avoid context overflow
loader = WikipediaLoader(query=query, load_max_docs=2, doc_content_chars_max=12000)
docs = loader.load()
if not docs:
return {"wiki_content": "No Wikipedia page found."}
formatted = []
for d in docs:
formatted.append(f"Title: {d.metadata.get('title')}\nContent:\n{d.page_content}")
return {"wiki_content": "\n---\n".join(formatted)}
except Exception as e:
return {"error": f"Wikipedia Error: {str(e)}"}
@tool
def arxiv_search(query: str) -> Dict[str, str]:
"""
Searches ArXiv for scientific papers.
Returns Metadata (Abstract) and the PDF URL.
Does NOT return full text. Use 'scrape_website' on the PDF URL to read it.
Args:
query (str): Paper title, author, or keywords.
Returns:
Dict: {"papers": str (XML formatted with PDF_URL)}
"""
try:
client = arxiv.Client()
search = arxiv.Search(query=query, max_results=3, sort_by=arxiv.SortCriterion.Relevance)
results = []
for r in client.results(search):
# Clean abstract
summary = r.summary.replace("\n", " ")
results.append(
f"<Paper>\nTitle: {r.title}\nDate: {r.published.strftime('%Y-%m-%d')}\n"
f"Summary: {summary}\nPDF_URL: {r.pdf_url}\n</Paper>"
)
if not results:
return {"papers": "No papers found."}
return {"papers": "\n\n".join(results)}
except Exception as e:
return {"error": f"Arxiv Error: {str(e)}"}
# =============================================================================
# 3. DOCUMENT READING (HTML & PDF)
# =============================================================================
@tool
def scrape_website(url: str) -> Dict[str, str]:
"""
Scrapes content from a specific URL.
Supports standard HTML webpages AND PDF files.
Args:
url (str): The full URL starting with http:// or https://.
Returns:
Dict: {"content": str} or {"error": str}
"""
try:
headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"}
response = requests.get(url, headers=headers, timeout=20)
response.raise_for_status()
# 1. PDF DETECTION & PARSING
content_type = response.headers.get('Content-Type', '').lower()
if 'application/pdf' in content_type or url.lower().endswith('.pdf'):
try:
reader = PdfReader(io.BytesIO(response.content))
# Extract first 15 pages (covers Intro, Abstract, Acknowledgments)
pages_text = [p.extract_text() for p in reader.pages[:15] if p.extract_text()]
full_text = "\n".join(pages_text)
return {"content": f"[PDF EXTRACTED]\n{full_text[:30000]}..."}
except Exception as e:
return {"error": f"PDF Parsing Failed: {str(e)}"}
# 2. HTML PARSING
soup = BeautifulSoup(response.content, 'html.parser')
# Remove clutter
for junk in soup(["script", "style", "nav", "footer", "iframe", "aside"]):
junk.extract()
text = soup.get_text(separator=' ', strip=True)
return {"content": text[:30000]} # Limit to ~30k chars
except Exception as e:
return {"error": f"Scrape Error: {str(e)}"}
# =============================================================================
# 4. MULTIMODAL TOOLS (Vision, Audio, Video)
# =============================================================================
@tool
def analyze_image(file_name: str, question: str) -> Dict[str, str]:
"""
Analyzes a local image file using Google Gemini 2.5 Flash (Vision).
Use for: Diagrams, Chess boards, Plots, Maps, Photos.
Args:
file_name (str): The local filename (e.g. "chess.png").
question (str): What to ask about the image.
Returns:
Dict: {"image_analysis": str}
"""
try:
# 1. AUTO-DOWNLOAD
try:
_ensure_local_file_exists(file_name)
except Exception as e:
return {"error": str(e)}
# 2. SETUP GEMINI
api_key = os.environ.get("GEMINI_API_KEY")
if not api_key: return {"error": "GEMINI_API_KEY not set."}
path = os.path.join(os.getcwd(), file_name)
if not os.path.exists(path):
return {"error": f"File '{file_name}' not found."}
        # Use Gemini 2.5 Flash, which handles vision (image) inputs well
llm = ChatGoogleGenerativeAI(
model="gemini-2.5-flash",
google_api_key=api_key,
temperature=0
)
        # 3. IMAGE PREPARATION
        # Detect the file extension to pick the correct MIME type
mime_type = "image/jpeg"
if file_name.lower().endswith(".png"):
mime_type = "image/png"
with open(path, "rb") as f:
image_data = base64.b64encode(f.read()).decode("utf-8")
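        # Gemini accepts inline images as base64 data URLs inside a multimodal HumanMessage.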
msg = HumanMessage(content=[
{"type": "text", "text": question},
{
"type": "image_url",
"image_url": {"url": f"data:{mime_type};base64,{image_data}"}
}
])
print(f"👁️ Analyzing {file_name} with Gemini 1.5 Pro...")
response = llm.invoke([msg])
return {"image_analysis": response.content}
except Exception as e:
return {"error": f"Vision Error (Gemini): {str(e)}"}
@tool
def transcribe_audio(file_name: str) -> Dict[str, str]:
"""
Transcribes a local MP3/WAV audio file to text.
Use for: Podcasts, Recipes, Voice notes.
Args:
file_name (str): The local filename (e.g. "recipe.mp3").
Returns:
Dict: {"transcript": str}
"""
try:
# AUTO-DOWNLOAD
try:
_ensure_local_file_exists(file_name)
except Exception as e:
return {"error": str(e)}
# Requires: pip install openai-whisper
import whisper
path = os.path.join(os.getcwd(), file_name)
if not os.path.exists(path):
return {"error": f"File '{file_name}' not found."}
# Load base model (auto-downloads if needed)
model = whisper.load_model("base")
result = model.transcribe(path)
return {"transcript": result["text"]}
except ImportError:
return {"error": "Library 'openai-whisper' not installed."}
except Exception as e:
return {"error": f"Audio Transcription Error: {str(e)}"}
@tool
def get_youtube_transcript(video_url: str) -> Dict[str, str]:
"""
Extracts the transcript (subtitles) from a YouTube video.
Use for: "What did X say in the video?", summaries.
Args:
video_url (str): The full YouTube URL.
Returns:
Dict: {"transcript": str}
"""
try:
# Extract Video ID
if "v=" in video_url:
vid = video_url.split("v=")[1].split("&")[0]
elif "youtu.be" in video_url:
vid = video_url.split("/")[-1]
else:
return {"error": "Could not parse Video ID."}
# Fetch Transcript
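        # NOTE: get_transcript() is the classic youtube-transcript-api call; newer
        # (1.x) releases of the library replace it with an instance-based fetch().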
transcript_list = YouTubeTranscriptApi.get_transcript(vid)
full_text = " ".join([t['text'] for t in transcript_list])
return {"transcript": full_text[:20000]} # Limit length
except Exception as e:
return {"error": f"YouTube Error: {str(e)}"}
# =============================================================================
# EXPORT ALL TOOLS
# =============================================================================
TOOLS = [
python_repl,
web_search,
wiki_search,
arxiv_search,
scrape_website,
analyze_image,
transcribe_audio,
get_youtube_transcript
]
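
# =============================================================================
# LOCAL SMOKE TEST (optional)
# =============================================================================
# A minimal sketch for manual testing, assuming the relevant API keys (e.g.
# TAVILY_API_KEY) are set as environment variables; it is not used by the agent.
if __name__ == "__main__":
    # @tool-decorated functions are LangChain tools: invoke them with a dict of args.
    print(python_repl.invoke({"code": "sum(range(10))"}))  # -> {'output': '45'}
    if os.environ.get("TAVILY_API_KEY"):
        print(web_search.invoke({"query": "GAIA benchmark"}))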