Abdullah Zaki
		
	commited on
		
		
					Commit 
							
							·
						
						387baae
	
1
								Parent(s):
							
							6d7551e
								
files
Browse files- .env +2 -0
- app.py +124 -0
- requirements.txt +7 -0
    	
        .env
    ADDED
    
    | @@ -0,0 +1,2 @@ | |
|  | |
|  | 
|  | |
| 1 | 
            +
            SUPABASE_URL=https://hgsdcoqgvdjuxvcscqzn.supabase.co
         | 
| 2 | 
            +
            SUPABASE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Imhnc2Rjb3FndmRqdXh2Y3NjcXpuIiwicm9sZSI6ImFub24iLCJpYXQiOjE3NDkxNTMxNDEsImV4cCI6MjA2NDcyOTE0MX0.pYigfNha5pge2DMj9sMOwQ1RUqwh2Cy_zQws3A5IwRo
         | 
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,124 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
            +
            import pandas as pd
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            from chronos import ChronosPipeline
         | 
| 6 | 
            +
            from transformers import AutoTokenizer, AutoModelForCausalLM
         | 
| 7 | 
            +
            from supabase import create_client, Client
         | 
| 8 | 
            +
            import os
         | 
| 9 | 
            +
            import plotly.express as px
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            # Initialize Supabase client with API key from environment variables
         | 
| 12 | 
            +
            SUPABASE_URL = os.getenv("SUPABASE_URL")
         | 
| 13 | 
            +
            SUPABASE_KEY = os.getenv("SUPABASE_KEY")
         | 
| 14 | 
            +
            if not SUPABASE_URL or not SUPABASE_KEY:
         | 
| 15 | 
            +
                raise ValueError("SUPABASE_URL and SUPABASE_KEY must be set as environment variables.")
         | 
| 16 | 
            +
            supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            # Initialize Chronos-T5-Large for forecasting
         | 
| 19 | 
            +
            chronos_pipeline = ChronosPipeline.from_pretrained(
         | 
| 20 | 
            +
                "amazon/chronos-t5-large",
         | 
| 21 | 
            +
                device_map="cuda" if torch.cuda.is_available() else "cpu",
         | 
| 22 | 
            +
                torch_dtype=torch.bfloat16
         | 
| 23 | 
            +
            )
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            # Initialize Prophet-Qwen3-4B-SFT for Arabic reports
         | 
| 26 | 
            +
            qwen_tokenizer = AutoTokenizer.from_pretrained("radm/prophet-qwen3-4b-sft")
         | 
| 27 | 
            +
            qwen_model = AutoModelForCausalLM.from_pretrained(
         | 
| 28 | 
            +
                "radm/prophet-qwen3-4b-sft",
         | 
| 29 | 
            +
                device_map="cuda" if torch.cuda.is_available() else "cpu",
         | 
| 30 | 
            +
                torch_dtype=torch.bfloat16
         | 
| 31 | 
            +
            )
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            def fetch_supabase_data(table_name: str = "sentiment_data") -> pd.DataFrame:
         | 
| 34 | 
            +
                """Fetch time series data from Supabase using the provided API key."""
         | 
| 35 | 
            +
                try:
         | 
| 36 | 
            +
                    response = supabase.table(table_name).select("date, sentiment").order("date", desc=False).execute()
         | 
| 37 | 
            +
                    if response.data:
         | 
| 38 | 
            +
                        df = pd.DataFrame(response.data)
         | 
| 39 | 
            +
                        df['date'] = pd.to_datetime(df['date'])
         | 
| 40 | 
            +
                        return df
         | 
| 41 | 
            +
                    else:
         | 
| 42 | 
            +
                        raise ValueError("No data found in Supabase table.")
         | 
| 43 | 
            +
                except Exception as e:
         | 
| 44 | 
            +
                    raise Exception(f"Error fetching Supabase data: {str(e)}")
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            def forecast_and_report(data_source: str, csv_file=None, prediction_length: int = 30, table_name: str = "sentiment_data"):
         | 
| 47 | 
            +
                """Run forecasting with Chronos-T5-Large and generate Arabic report with Qwen3-4B-SFT."""
         | 
| 48 | 
            +
                try:
         | 
| 49 | 
            +
                    # Load data
         | 
| 50 | 
            +
                    if data_source == "Supabase":
         | 
| 51 | 
            +
                        df = fetch_supabase_data(table_name)
         | 
| 52 | 
            +
                    else:
         | 
| 53 | 
            +
                        if not csv_file:
         | 
| 54 | 
            +
                            return {"error": "Please upload a CSV file."}, None, None
         | 
| 55 | 
            +
                        df = pd.read_csv(csv_file)
         | 
| 56 | 
            +
                        if "sentiment" not in df.columns or "date" not in df.columns:
         | 
| 57 | 
            +
                            return {"error": "CSV must contain 'date' and 'sentiment' columns."}, None, None
         | 
| 58 | 
            +
                        df['date'] = pd.to_datetime(df['date'])
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    # Prepare time series
         | 
| 61 | 
            +
                    context = torch.tensor(df["sentiment"].values, dtype=torch.float32)
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    # Run forecast
         | 
| 64 | 
            +
                    forecast = chronos_pipeline.predict(context, prediction_length)
         | 
| 65 | 
            +
                    low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                    # Format forecast results
         | 
| 68 | 
            +
                    forecast_dates = pd.date_range(start=df["date"].iloc[-1] + pd.Timedelta(days=1), periods=prediction_length, freq="D")
         | 
| 69 | 
            +
                    forecast_df = pd.DataFrame({
         | 
| 70 | 
            +
                        "date": forecast_dates,
         | 
| 71 | 
            +
                        "low": low,
         | 
| 72 | 
            +
                        "median": median,
         | 
| 73 | 
            +
                        "high": high
         | 
| 74 | 
            +
                    })
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                    # Create forecast plot
         | 
| 77 | 
            +
                    plot_df = forecast_df.copy()
         | 
| 78 | 
            +
                    fig = px.line(plot_df, x="date", y=["median", "low", "high"], title="Sentiment Forecast")
         | 
| 79 | 
            +
                    fig.update_traces(line=dict(color="blue"), selector=dict(name="median"))
         | 
| 80 | 
            +
                    fig.update_traces(line=dict(color="red", dash="dash"), selector=dict(name="low"))
         | 
| 81 | 
            +
                    fig.update_traces(line=dict(color="green", dash="dash"), selector=dict(name="high"))
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    # Generate Arabic report
         | 
| 84 | 
            +
                    prompt = (
         | 
| 85 | 
            +
                        "اكتب تقريراً رسمياً بالعربية يلخص توقعات المشاعر للأيام الثلاثين القادمة بناءً على البيانات التالية:\n"
         | 
| 86 | 
            +
                        f"- متوسط التوقعات: {median[:5].tolist()} (أول 5 أيام)...\n"
         | 
| 87 | 
            +
                        f"- الحد الأدنى (10%): {low[:5].tolist()}...\n"
         | 
| 88 | 
            +
                        f"- الحد الأعلى (90%): {high[:5].tolist()}...\n"
         | 
| 89 | 
            +
                        "التقرير يجب أن يكون موجزاً (200-300 كلمة)، يشرح الاتجاهات، ويستخدم لغة رسمية."
         | 
| 90 | 
            +
                    )
         | 
| 91 | 
            +
                    inputs = qwen_tokenizer(prompt, return_tensors="pt").to(qwen_model.device)
         | 
| 92 | 
            +
                    outputs = qwen_model.generate(
         | 
| 93 | 
            +
                        inputs["input_ids"],
         | 
| 94 | 
            +
                        max_new_tokens=500,
         | 
| 95 | 
            +
                        do_sample=True,
         | 
| 96 | 
            +
                        temperature=0.7,
         | 
| 97 | 
            +
                        top_p=0.9
         | 
| 98 | 
            +
                    )
         | 
| 99 | 
            +
                    report = qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                    return forecast_df.to_dict(), fig, report
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                except Exception as e:
         | 
| 104 | 
            +
                    return {"error": f"An error occurred: {str(e)}"}, None, None
         | 
| 105 | 
            +
             | 
| 106 | 
            +
            # Gradio interface
         | 
| 107 | 
            +
            with gr.Blocks() as demo:
         | 
| 108 | 
            +
                gr.Markdown("# Sentiment Forecasting and Arabic Reporting")
         | 
| 109 | 
            +
                data_source = gr.Radio(["Supabase", "CSV Upload"], label="Data Source", value="Supabase")
         | 
| 110 | 
            +
                csv_file = gr.File(label="Upload CSV (if CSV selected)")
         | 
| 111 | 
            +
                table_name = gr.Textbox(label="Supabase Table Name", value="sentiment_data")
         | 
| 112 | 
            +
                prediction_length = gr.Slider(1, 60, value=30, step=1, label="Prediction Length (days)")
         | 
| 113 | 
            +
                submit = gr.Button("Run Forecast and Generate Report")
         | 
| 114 | 
            +
                output = gr.JSON(label="Forecast Results")
         | 
| 115 | 
            +
                plot = gr.Plot(label="Forecast Plot")
         | 
| 116 | 
            +
                report = gr.Textbox(label="Arabic Report", lines=10)
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                submit.click(
         | 
| 119 | 
            +
                    fn=forecast_and_report,
         | 
| 120 | 
            +
                    inputs=[data_source, csv_file, prediction_length, table_name],
         | 
| 121 | 
            +
                    outputs=[output, plot, report]
         | 
| 122 | 
            +
                )
         | 
| 123 | 
            +
             | 
| 124 | 
            +
            demo.launch()
         | 
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,7 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            torch>=2.0.0
         | 
| 2 | 
            +
            transformers>=4.35.0
         | 
| 3 | 
            +
            gradio>=4.0.0
         | 
| 4 | 
            +
            pandas>=2.0.0
         | 
| 5 | 
            +
            numpy>=1.24.0
         | 
| 6 | 
            +
            supabase>=2.0.0
         | 
| 7 | 
            +
            git+https://github.com/amazon-science/chronos-forecasting.git
         | 
