"""
Vietnamese Speaker Profiling - Multi-Model Gradio Interface

Supports: Vietnamese Wav2Vec2 and PhoWhisper encoders
"""

import os
import traceback
from pathlib import Path

import gradio as gr
import torch
import torchaudio
from safetensors.torch import load_file as load_safetensors

# Model configurations
MODELS_CONFIG = {
    "Wav2Vec2 Vietnamese": {
        "path": "model/vulehuubinh",
        "encoder_name": "nguyenvulebinh/wav2vec2-base-vi-vlsp2020",
        "is_whisper": False,
        "description": "Vietnamese Wav2Vec2 pretrained model - Fast inference",
    },
    "PhoWhisper": {
        "path": "model/pho",
        "encoder_name": "vinai/PhoWhisper-base",
        "is_whisper": True,
        "description": "Vietnamese Whisper model - Higher accuracy",
    },
}

# Labels: index i is the display name for output logit i of the matching head.
GENDER_LABELS = ["Male", "Female"]
DIALECT_LABELS = ["Northern", "Central", "Southern"]

# Whisper encoders are fed a fixed-length window of this many seconds.
WHISPER_WINDOW_SECONDS = 30


class MultiModelProfiler:
    """Speaker profiler that serves several encoder models side by side.

    Every model in ``MODELS_CONFIG`` whose directory exists is loaded eagerly
    at construction time. Models that fail to load are reported and skipped
    so the remaining ones stay usable.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.sampling_rate = 16000
        self.models = {}
        self.processors = {}
        # Name of the first model that loaded successfully (None if none did).
        self.current_model = None

        print(f"Using device: {self.device}")

        # Pre-load all models
        self._load_all_models()

    def _load_all_models(self):
        """Load every configured model whose directory exists on disk."""
        for model_name, config in MODELS_CONFIG.items():
            model_path = Path(config["path"])
            if model_path.exists():
                print(f"Loading {model_name}...")
                self._load_single_model(model_name, config)
            else:
                print(f"Model not found: {model_path}")

    def _load_single_model(self, model_name: str, config: dict):
        """Build one model, load its fine-tuned weights and register it.

        Args:
            model_name: Display name / registry key of the model.
            config: One entry of ``MODELS_CONFIG``.

        Errors are printed with a traceback instead of being raised, so one
        broken model does not prevent the app from starting.
        """
        try:
            model_path = Path(config["path"])
            encoder_name = config["encoder_name"]

            # Load processor (feature extractor matching the encoder family).
            if config["is_whisper"]:
                from transformers import WhisperFeatureExtractor

                processor = WhisperFeatureExtractor.from_pretrained(encoder_name)
            else:
                from transformers import Wav2Vec2FeatureExtractor

                processor = Wav2Vec2FeatureExtractor.from_pretrained(encoder_name)

            # Load model - use MultiTaskSpeakerModel
            from src.models import MultiTaskSpeakerModel

            model = MultiTaskSpeakerModel(
                model_name=encoder_name,
                num_genders=len(GENDER_LABELS),
                num_dialects=len(DIALECT_LABELS),
                dropout=0.1,
                freeze_encoder=True,
            )
            self._load_checkpoint(model, model_path)

            model.to(self.device)
            model.eval()

            self.models[model_name] = model
            self.processors[model_name] = processor
            if self.current_model is None:
                self.current_model = model_name

            print(f"✓ {model_name} loaded successfully")
        except Exception as e:
            print(f"✗ Error loading {model_name}: {e}")
            traceback.print_exc()

    def _load_checkpoint(self, model, model_path: Path):
        """Load fine-tuned weights into ``model`` from ``model_path``.

        Prefers ``model.safetensors``. Falls back to ``best_model.pt``, which
        may be either a raw state dict or a dict with a ``"model_state_dict"``
        entry.

        Raises:
            FileNotFoundError: if neither checkpoint file exists. Without it
                the model would be served without its fine-tuned weights and
                every prediction would be meaningless.
        """
        checkpoint_path = model_path / "model.safetensors"
        if checkpoint_path.exists():
            state_dict = load_safetensors(str(checkpoint_path))
            model.load_state_dict(state_dict)
            print(f"Loaded checkpoint: {checkpoint_path}")
            return

        pt_path = model_path / "best_model.pt"
        if pt_path.exists():
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only load checkpoints that come from a trusted source.
            checkpoint = torch.load(pt_path, map_location=self.device, weights_only=False)
            if "model_state_dict" in checkpoint:
                model.load_state_dict(checkpoint["model_state_dict"])
            else:
                model.load_state_dict(checkpoint)
            print(f"Loaded checkpoint: {pt_path}")
            return

        raise FileNotFoundError(
            f"No checkpoint found in {model_path} "
            "(expected model.safetensors or best_model.pt)"
        )

    def _load_waveform(self, audio_path: str) -> torch.Tensor:
        """Load audio as a 1-D mono tensor resampled to ``self.sampling_rate``.

        Raises:
            ValueError: if the file contains no samples.
        """
        waveform, sr = torchaudio.load(audio_path)
        if waveform.numel() == 0:
            raise ValueError("audio is empty")

        # torchaudio returns (channels, samples); average channels to mono.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample if needed
        if sr != self.sampling_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sampling_rate)
            waveform = resampler(waveform)

        return waveform.squeeze(0)

    def _prepare_inputs(self, waveform: torch.Tensor, processor, is_whisper: bool) -> torch.Tensor:
        """Convert a 1-D waveform into the encoder input tensor on ``self.device``."""
        if is_whisper:
            # Whisper requires exactly 30 seconds of audio: zero-pad short
            # clips on the right and truncate long ones.
            whisper_length = self.sampling_rate * WHISPER_WINDOW_SECONDS  # 480000 samples
            num_samples = waveform.shape[0]
            if num_samples < whisper_length:
                waveform = torch.nn.functional.pad(waveform, (0, whisper_length - num_samples))
            else:
                waveform = waveform[:whisper_length]
            inputs = processor(
                waveform.numpy(),
                sampling_rate=self.sampling_rate,
                return_tensors="pt",
            )
            return inputs.input_features.to(self.device)

        # Wav2Vec2 uses the raw waveform
        inputs = processor(
            waveform.numpy(),
            sampling_rate=self.sampling_rate,
            return_tensors="pt",
            padding=True,
        )
        return inputs.input_values.to(self.device)

    @staticmethod
    def _format_prediction(logits: torch.Tensor, labels) -> str:
        """Format the top class of single-example logits as ``"Label (xx.x%)"``."""
        probs = torch.softmax(logits, dim=-1)
        idx = probs.argmax(dim=-1).item()
        confidence = probs[0, idx].item() * 100
        return f"{labels[idx]} ({confidence:.1f}%)"

    def predict(self, audio_path: str, model_name: str):
        """Predict gender and dialect from an audio file.

        Args:
            audio_path: Path to an audio file readable by ``torchaudio.load``.
            model_name: Key of a loaded model. Unknown names fall back to the
                first loaded model.

        Returns:
            A ``(gender_result, dialect_result)`` pair of display strings,
            e.g. ``("Male (97.3%)", "Northern (88.1%)")``. On failure, both
            strings describe the error instead of raising.
        """
        if model_name not in self.models:
            available = list(self.models)
            if not available:
                return "No models available", "No models available"
            model_name = available[0]

        try:
            model = self.models[model_name]
            processor = self.processors[model_name]
            is_whisper = MODELS_CONFIG[model_name]["is_whisper"]

            waveform = self._load_waveform(audio_path)
            input_tensor = self._prepare_inputs(waveform, processor, is_whisper)

            # Inference
            with torch.no_grad():
                gender_logits, dialect_logits = model(input_tensor)

            gender_result = self._format_prediction(gender_logits, GENDER_LABELS)
            dialect_result = self._format_prediction(dialect_logits, DIALECT_LABELS)
            return gender_result, dialect_result
        except Exception as e:
            traceback.print_exc()
            return f"Error: {e}", f"Error: {e}"

    def get_available_models(self):
        """Return the names of all successfully loaded models."""
        return list(self.models.keys())


def create_interface():
    """Create the Gradio interface with model selection.

    Constructs a ``MultiModelProfiler`` (which loads all available models)
    and wires it to an audio input, a model dropdown and two result boxes.
    """
    profiler = MultiModelProfiler()
    # Placeholder choice keeps the dropdown non-empty when nothing loaded.
    available_models = profiler.get_available_models() or ["No models available"]

    def predict_wrapper(audio, model_name):
        if audio is None:
            return "Please upload audio", "Please upload audio"
        return profiler.predict(audio, model_name)

    # Create model info text: ✓ = loaded, ✗ = missing or failed to load.
    model_info = "".join(
        f"{'✓' if name in profiler.models else '✗'} **{name}**: {config['description']}\n"
        for name, config in MODELS_CONFIG.items()
    )

    # Use gr.Blocks without theme for compatibility with older Gradio
    with gr.Blocks(title="Vietnamese Speaker Profiling") as demo:
        gr.Markdown(
            """
            # 🎙️ Vietnamese Speaker Profiling

            Analyze Vietnamese speech to predict **Gender** and **Dialect Region**.

            Supports multiple AI models - choose the one that works best for you!
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📤 Input")
                audio_input = gr.Audio(
                    label="Upload or Record Audio",
                    type="filepath",
                    sources=["upload", "microphone"],
                )
                model_dropdown = gr.Dropdown(
                    choices=available_models,
                    value=available_models[0],
                    label="🤖 Select Model",
                    info="Choose the AI model for analysis",
                )
                submit_btn = gr.Button("🔍 Analyze", variant="primary", size="lg")

                gr.Markdown("### ℹ️ Available Models")
                gr.Markdown(model_info)

            with gr.Column(scale=1):
                gr.Markdown("### 📊 Results")
                gender_output = gr.Textbox(label="👤 Gender", interactive=False)
                dialect_output = gr.Textbox(label="🗣️ Dialect Region", interactive=False)

                gr.Markdown(
                    """
                    ### 📖 Dialect Regions
                    - **Northern**: Hanoi and surrounding areas
                    - **Central**: Huế, Đà Nẵng, and Central Vietnam
                    - **Southern**: Ho Chi Minh City and Southern Vietnam
                    """
                )

        submit_btn.click(
            fn=predict_wrapper,
            inputs=[audio_input, model_dropdown],
            outputs=[gender_output, dialect_output],
        )

        gr.Markdown(
            """
            ---
            *Made with ❤️ for Vietnamese Speech Processing Research*
            """
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)