import gradio as gr
import pandas as pd
import numpy as np
import torch
from chronos import ChronosPipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
from supabase import create_client, Client
import os
import plotly.express as px
# Initialize Supabase client with API key from environment variables
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
if not SUPABASE_URL or not SUPABASE_KEY:
    raise ValueError("SUPABASE_URL and SUPABASE_KEY must be set as environment variables.")
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
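# Note: these credentials are expected to come from the host environment (e.g. Space
# or repository secrets); this file does not configure them itself.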
# Initialize Chronos-T5-Large for forecasting
chronos_pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-large",
    device_map="cuda" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.bfloat16
)
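# Note: ChronosPipeline.predict returns sample forecast paths with shape
# (num_series, num_samples, prediction_length); the 10%/50%/90% quantiles taken over
# the sample dimension further below turn them into forecast bands.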
# Initialize Prophet-Qwen3-4B-SFT for Arabic reports
qwen_tokenizer = AutoTokenizer.from_pretrained("radm/prophet-qwen3-4b-sft")
qwen_model = AutoModelForCausalLM.from_pretrained(
    "radm/prophet-qwen3-4b-sft",
    device_map="cuda" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.bfloat16
)
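# Rough sizing note: a 4B-parameter model in bfloat16 is on the order of 8 GB of weights,
# so a GPU (or generous RAM on a CPU-only host) is assumed here.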

def fetch_supabase_data(table_name: str = "sentiment_data") -> pd.DataFrame:
    """Fetch time series data from Supabase using the provided API key."""
    try:
        response = supabase.table(table_name).select("date, sentiment").order("date", desc=False).execute()
        if response.data:
            df = pd.DataFrame(response.data)
            df['date'] = pd.to_datetime(df['date'])
            return df
        else:
            raise ValueError("No data found in Supabase table.")
    except Exception as e:
        raise RuntimeError(f"Error fetching Supabase data: {str(e)}") from e

def forecast_and_report(data_source: str, csv_file=None, prediction_length: int = 30, table_name: str = "sentiment_data"):
    """Run forecasting with Chronos-T5-Large and generate an Arabic report with Qwen3-4B-SFT."""
    try:
        # Load data
        if data_source == "Supabase":
            df = fetch_supabase_data(table_name)
        else:
            if not csv_file:
                return {"error": "Please upload a CSV file."}, None, None
            df = pd.read_csv(csv_file)
            if "sentiment" not in df.columns or "date" not in df.columns:
                return {"error": "CSV must contain 'date' and 'sentiment' columns."}, None, None
            df['date'] = pd.to_datetime(df['date'])
            df = df.sort_values("date")  # ensure the series is in chronological order
        # Prepare the time series as the forecasting context
        prediction_length = int(prediction_length)
        context = torch.tensor(df["sentiment"].values, dtype=torch.float32)
        # Run the forecast and summarize the sample paths as 10%/50%/90% quantiles
        forecast = chronos_pipeline.predict(context, prediction_length)
        low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
        # Build a forecast frame indexed by the days following the last observed date
        forecast_dates = pd.date_range(start=df["date"].iloc[-1] + pd.Timedelta(days=1), periods=prediction_length, freq="D")
        forecast_df = pd.DataFrame({
            "date": forecast_dates,
            "low": low,
            "median": median,
            "high": high
        })
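        # Note: freq="D" assumes a daily series; if the input data has a different
        # frequency, the date range (and the wording of the report prompt) should change.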
        # Create the forecast plot (median plus 10%/90% bands)
        plot_df = forecast_df.copy()
        fig = px.line(plot_df, x="date", y=["median", "low", "high"], title="Sentiment Forecast")
        fig.update_traces(line=dict(color="blue"), selector=dict(name="median"))
        fig.update_traces(line=dict(color="red", dash="dash"), selector=dict(name="low"))
        fig.update_traces(line=dict(color="green", dash="dash"), selector=dict(name="high"))
        # Generate the Arabic report. The prompt (in Arabic) asks for a formal 200-300 word
        # report summarizing the forecast for the next `prediction_length` days, using the
        # first five days of the 10%/50%/90% bands as context.
        prompt = (
            "اكتب تقريراً رسمياً بالعربية يلخص توقعات المشاعر "
            f"للأيام الـ{prediction_length} القادمة بناءً على البيانات التالية:\n"
            f"- متوسط التوقعات: {median[:5].tolist()} (أول 5 أيام)...\n"
            f"- الحد الأدنى (10%): {low[:5].tolist()}...\n"
            f"- الحد الأعلى (90%): {high[:5].tolist()}...\n"
            "التقرير يجب أن يكون موجزاً (200-300 كلمة)، يشرح الاتجاهات، ويستخدم لغة رسمية."
        )
        inputs = qwen_tokenizer(prompt, return_tensors="pt").to(qwen_model.device)
        outputs = qwen_model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=500,
            do_sample=True,
            temperature=0.7,
            top_p=0.9
        )
        # Decode only the newly generated tokens so the prompt is not echoed into the report
        report = qwen_tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
        # Return dates as ISO strings so the forecast serializes cleanly to JSON
        result = forecast_df.copy()
        result["date"] = result["date"].dt.strftime("%Y-%m-%d")
        return result.to_dict(orient="records"), fig, report
    except Exception as e:
        return {"error": f"An error occurred: {str(e)}"}, None, None
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Sentiment Forecasting and Arabic Reporting")
    data_source = gr.Radio(["Supabase", "CSV Upload"], label="Data Source", value="Supabase")
    csv_file = gr.File(label="Upload CSV (if CSV selected)")
    table_name = gr.Textbox(label="Supabase Table Name", value="sentiment_data")
    prediction_length = gr.Slider(1, 60, value=30, step=1, label="Prediction Length (days)")
    submit = gr.Button("Run Forecast and Generate Report")
    output = gr.JSON(label="Forecast Results")
    plot = gr.Plot(label="Forecast Plot")
    report = gr.Textbox(label="Arabic Report", lines=10)
    submit.click(
        fn=forecast_and_report,
        inputs=[data_source, csv_file, prediction_length, table_name],
        outputs=[output, plot, report]
    )

demo.launch()