|
|
import logging
import os

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from chronos import ChronosPipeline
from supabase import create_client, Client
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
|
|
|
|
|
# Supabase credentials are read from the environment; the app refuses to
# start when either one is missing or empty.
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")

if not (SUPABASE_URL and SUPABASE_KEY):
    raise ValueError("SUPABASE_URL and SUPABASE_KEY must be set as environment variables.")

# Module-level client shared by the data-fetching code in this file.
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
|
|
|
|
|
|
|
|
# Both models are loaded once, at import time, onto the GPU when CUDA is
# available and onto the CPU otherwise.
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Time-series forecasting model.
chronos_pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-large",
    device_map=_DEVICE,
    torch_dtype=torch.bfloat16,
)

# Tokenizer and causal language model used to write the report text.
qwen_tokenizer = AutoTokenizer.from_pretrained("radm/prophet-qwen3-4b-sft")
qwen_model = AutoModelForCausalLM.from_pretrained(
    "radm/prophet-qwen3-4b-sft",
    device_map=_DEVICE,
    torch_dtype=torch.bfloat16,
)
|
|
|
|
|
def fetch_supabase_data(table_name: str = "sentiment_data") -> pd.DataFrame:
    """Fetch the ``date``/``sentiment`` time series from a Supabase table.

    The query goes through the module-level ``supabase`` client, selects the
    ``date`` and ``sentiment`` columns and asks for ascending ``date`` order.
    The ``date`` column of the result is converted with ``pd.to_datetime``.

    Args:
        table_name: Name of the Supabase table to query.

    Returns:
        A DataFrame built from the rows in the response.

    Raises:
        Exception: If the query fails, the response contains no rows, or the
            dates cannot be parsed. The message starts with
            "Error fetching Supabase data: " and the underlying error is
            chained as ``__cause__``.
    """
    try:
        response = (
            supabase.table(table_name)
            .select("date, sentiment")
            .order("date", desc=False)
            .execute()
        )
        if not response.data:
            raise ValueError("No data found in Supabase table.")

        df = pd.DataFrame(response.data)
        df["date"] = pd.to_datetime(df["date"])
        return df
    except Exception as e:
        # Same exception type and message as before so existing callers keep
        # working; chaining preserves the original traceback for debugging.
        raise Exception(f"Error fetching Supabase data: {e}") from e
|
|
|
|
|
def forecast_and_report(data_source: str, csv_file=None, prediction_length: int = 30, table_name: str = "sentiment_data"):
    """Forecast the sentiment series and generate an Arabic report about it.

    The input series comes either from Supabase (``data_source ==
    "Supabase"``) or from an uploaded CSV with ``date`` and ``sentiment``
    columns. The series is forecast with the module-level Chronos pipeline,
    summarised as 10% / 50% / 90% quantiles, plotted with Plotly, and
    described in Arabic by the module-level Qwen model.

    Args:
        data_source: "Supabase" to read from the database; any other value
            selects the CSV path.
        csv_file: Path or file-like object accepted by ``pd.read_csv``;
            required when the CSV path is selected.
        prediction_length: Number of daily steps to forecast.
        table_name: Supabase table to read when ``data_source`` is "Supabase".

    Returns:
        A 3-tuple ``(forecast, figure, report)``. ``forecast`` is
        ``forecast_df.to_dict()`` with the columns ``date``, ``low``,
        ``median`` and ``high``. On any failure the tuple is
        ``({"error": message}, None, None)``; no exception propagates.
    """
    try:
        # `pd.date_range(periods=...)` needs an integer, and the value may
        # arrive as a float from the UI.
        prediction_length = int(prediction_length)

        if data_source == "Supabase":
            df = fetch_supabase_data(table_name)
        else:
            if not csv_file:
                return {"error": "Please upload a CSV file."}, None, None
            df = pd.read_csv(csv_file)
            if "sentiment" not in df.columns or "date" not in df.columns:
                return {"error": "CSV must contain 'date' and 'sentiment' columns."}, None, None
            df["date"] = pd.to_datetime(df["date"])
            # The Supabase query is ordered by date, but an uploaded CSV need
            # not be. The model context must be chronological and the
            # forecast starts after the last row, so sort here.
            df = df.sort_values("date", kind="stable").reset_index(drop=True)

        if df.empty:
            return {"error": "No data rows available to forecast."}, None, None

        context = torch.tensor(df["sentiment"].to_numpy(), dtype=torch.float32)

        forecast = chronos_pipeline.predict(context, prediction_length)
        # Quantiles are taken along axis 0 of the first series' forecast
        # (presumably the sample axis -- confirm against the Chronos docs).
        low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)

        # Forecast dates continue daily from the day after the last observation.
        forecast_dates = pd.date_range(
            start=df["date"].iloc[-1] + pd.Timedelta(days=1),
            periods=prediction_length,
            freq="D",
        )
        forecast_df = pd.DataFrame({
            "date": forecast_dates,
            "low": low,
            "median": median,
            "high": high,
        })

        fig = px.line(forecast_df, x="date", y=["median", "low", "high"], title="Sentiment Forecast")
        fig.update_traces(line=dict(color="blue"), selector=dict(name="median"))
        fig.update_traces(line=dict(color="red", dash="dash"), selector=dict(name="low"))
        fig.update_traces(line=dict(color="green", dash="dash"), selector=dict(name="high"))

        prompt = _build_report_prompt(low, median, high, prediction_length)
        report = _generate_report(prompt)

        return forecast_df.to_dict(), fig, report

    except Exception as e:
        # Top-level boundary for the UI callback: report the message to the
        # user, but keep the traceback in the logs instead of discarding it.
        logging.getLogger(__name__).exception("Forecast/report generation failed")
        return {"error": f"An error occurred: {str(e)}"}, None, None


def _build_report_prompt(low, median, high, prediction_length: int) -> str:
    """Build the Arabic instruction prompt from the first few forecast values."""
    # Show at most five values, and state the real horizon rather than a
    # hard-coded thirty days.
    shown = min(5, prediction_length)
    return (
        "اكتب تقريراً رسمياً بالعربية يلخص توقعات المشاعر للفترة القادمة "
        f"(عدد الأيام: {prediction_length}) بناءً على البيانات التالية:\n"
        f"- متوسط التوقعات: {median[:shown].tolist()} (أول {shown} أيام)...\n"
        f"- الحد الأدنى (10%): {low[:shown].tolist()}...\n"
        f"- الحد الأعلى (90%): {high[:shown].tolist()}...\n"
        "التقرير يجب أن يكون موجزاً (200-300 كلمة)، يشرح الاتجاهات، ويستخدم لغة رسمية."
    )


def _generate_report(prompt: str) -> str:
    """Run the module-level Qwen model on *prompt* and return only the new text."""
    inputs = qwen_tokenizer(prompt, return_tensors="pt").to(qwen_model.device)
    # Unpack the whole encoding so the attention mask is passed along with
    # the input ids.
    outputs = qwen_model.generate(
        **inputs,
        max_new_tokens=500,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )
    # The generated sequence starts with the prompt tokens; drop them so the
    # report does not echo the instructions.
    prompt_len = inputs["input_ids"].shape[-1]
    return qwen_tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
|
|
|
|
|
|
|
|
# Gradio front end: input controls, a run button, and three output widgets
# wired to `forecast_and_report`.
with gr.Blocks() as demo:
    gr.Markdown("# Sentiment Forecasting and Arabic Reporting")

    # Inputs, in the order they appear on the page.
    data_source = gr.Radio(
        choices=["Supabase", "CSV Upload"],
        value="Supabase",
        label="Data Source",
    )
    csv_file = gr.File(label="Upload CSV (if CSV selected)")
    table_name = gr.Textbox(value="sentiment_data", label="Supabase Table Name")
    prediction_length = gr.Slider(
        minimum=1,
        maximum=60,
        step=1,
        value=30,
        label="Prediction Length (days)",
    )
    submit = gr.Button("Run Forecast and Generate Report")

    # Outputs: raw forecast table, plot, and the generated report text.
    output = gr.JSON(label="Forecast Results")
    plot = gr.Plot(label="Forecast Plot")
    report = gr.Textbox(label="Arabic Report", lines=10)

    # The input list follows the parameter order of `forecast_and_report`.
    submit.click(
        fn=forecast_and_report,
        inputs=[data_source, csv_file, prediction_length, table_name],
        outputs=[output, plot, report],
    )

demo.launch()