"""Gradio demo: bilingual emotion recognition from text or speech, using a
Mamba classifier over jina embeddings and OpenAI Whisper for transcription."""
import logging

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import torch
import whisper

from model import Mamba

logging.basicConfig(level=logging.INFO)
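
# Classify `text` with the global Mamba model and return a Plotly bar chart of
# per-emotion probabilities, the echoed text, and the top emotion as Markdown.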
def plotly_plot_text(text):
    data = pd.DataFrame()
    data['Emotion'] = ['😠 anger', '🤢 disgust', '😨 fear', '😊 joy/happiness', '😐 neutral', '😢 sadness', '😲 surprise/enthusiasm']
    data['Probability'] = model.predict_proba([text])[0].tolist()
    p = px.bar(data, x='Emotion', y='Probability', color="Probability")
    return (
        p,
        f"🗣️ Transcription:\n{text}",
        f"## 🏆 Dominant Emotion: {data['Emotion'].values[np.argmax(np.array(data['Probability']))]}"
    )
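
# Speech-to-text via OpenAI Whisper; the "base" checkpoint is loaded lazily on
# the first call and cached at module level so repeated requests don't reload it.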
_whisper_model = None

def transcribe_audio(audio_path):
    global _whisper_model
    if _whisper_model is None:
        _whisper_model = whisper.load_model("base")
    try:
        result = _whisper_model.transcribe(audio_path, fp16=False)
        return result.get('text', '')
    except Exception as e:
        logging.error(f"Transcription failed: {e}")
        return ""
def plotly_plot_audio(audio_path):
    data = pd.DataFrame()
    data['Emotion'] = ['😠 anger', '🤢 disgust', '😨 fear', '😊 joy/happiness', '😐 neutral', '😢 sadness', '😲 surprise/enthusiasm']
    try:
        text = transcribe_audio(audio_path)
        data['Probability'] = model.predict_proba([text])[0].tolist() if text.strip() else [0.0] * len(data)
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        return (
            p,
            f"🗣️ Transcription:\n{text}",
            f"## 🏆 Dominant Emotion: {data['Emotion'].values[np.argmax(np.array(data['Probability']))]}"
        )
    except Exception as e:
        logging.error(f"Processing failed: {e}")
        data['Probability'] = [0.0] * len(data)
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        return (
            p,
            "❌ Error processing audio",
            "⚠️ Processing Error"
        )
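
# Assemble the Gradio Blocks UI: audio (upload or microphone) and free-text
# inputs both drive the same three outputs, namely the probability chart, the
# transcription box, and the dominant-emotion headline.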
def create_demo():
    with gr.Blocks(theme=gr.themes.Soft(), title="Emotion Detection") as demo:
        gr.Markdown("# Text-based bilingual emotion recognition")
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Record or Upload Audio",
                    format="wav",
                    interactive=True
                )
            with gr.Column():
                text_input = gr.Text(label="Write Text")
        with gr.Row():
            top_emotion = gr.Markdown("## 🏆 Dominant Emotion: Waiting for input ...",
                                      elem_classes="dominant-emotion")
        with gr.Row():
            text_plot = gr.Plot(label="Text Analysis")
            transcription = gr.Textbox(
                label="📝 Transcription Results",
                placeholder="Transcribed text will appear here...",
                lines=3,
                max_lines=6
            )
        # Register both handlers so either input path updates the shared outputs.
        text_input.change(fn=plotly_plot_text, inputs=text_input, outputs=[text_plot, transcription, top_emotion])
        audio_input.change(fn=plotly_plot_audio, inputs=audio_input, outputs=[text_plot, transcription, top_emotion])
    return demo
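
# Entry point: restore the trained Mamba classifier (jina embeddings, 7 emotion
# classes) from the checkpoint next to this script, then launch the demo.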
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Mamba(num_layers=2, d_input=1024, d_model=512, num_classes=7, model_name='jina', pooling=None).to(device)
    checkpoint = torch.load("Mamba_jina_checkpoint.pth", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    demo = create_demo()
    demo.launch()