"""Streamlit app: launch training of a Telugu Wav2Vec2-XLSR ASR model and
run inference on an uploaded WAV file with the trained checkpoint."""

import os
import subprocess

import soundfile as sf
import streamlit as st
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Directory train_pipeline.py is expected to save the checkpoint into.
MODEL_DIR = "./model"
# Wav2Vec2-XLSR checkpoints are trained on 16 kHz mono audio.
TARGET_SR = 16_000

st.set_page_config(page_title="Telugu ASR Training", layout="wide")
st.title("Telugu ASR - Train XLS-R + Run Inference")


@st.cache_resource
def _load_model(model_dir: str):
    """Load the fine-tuned processor and model once, cached across reruns.

    Streamlit re-executes the whole script on every interaction; without the
    cache the checkpoint would be re-read from disk for each upload.
    """
    processor = Wav2Vec2Processor.from_pretrained(model_dir)
    model = Wav2Vec2ForCTC.from_pretrained(model_dir)
    model.eval()  # disable dropout etc. for deterministic inference
    return processor, model


def _prepare_audio(audio, sr: int):
    """Return (mono float32 audio at 16 kHz, 16000).

    soundfile returns (frames, channels) for multi-channel files and whatever
    sample rate the file was recorded at; the model requires 16 kHz mono.
    """
    if audio.ndim > 1:
        # Downmix stereo/multi-channel to mono by averaging channels.
        audio = audio.mean(axis=1)
    if sr != TARGET_SR:
        # Linear-interpolation resample via torch to avoid extra dependencies.
        wav = torch.as_tensor(audio, dtype=torch.float32).view(1, 1, -1)
        new_len = int(round(wav.shape[-1] * TARGET_SR / sr))
        wav = torch.nn.functional.interpolate(
            wav, size=new_len, mode="linear", align_corners=False
        )
        audio = wav.view(-1).numpy()
        sr = TARGET_SR
    return audio, sr


# ============================================================
# SECTION 1 — TRAIN BUTTON
# ============================================================
st.header("Train the Model")
st.write("""
Click the button below to start training the Wav2Vec2-XLSR model on Telugu dataset.
This will run train_pipeline.py inside this Space.
""")

if st.button("🚀 Start Training"):
    st.info("Training started... this may take several minutes to hours depending on GPU.")
    # Run the training script detached from the Streamlit request cycle.
    # NOTE: stdout/stderr go to DEVNULL, not PIPE — nothing here reads the
    # pipes, and an unread PIPE deadlocks the child once the OS buffer fills.
    # Progress is visible in the Space logs, as the message below says.
    subprocess.Popen(
        ["python3", "train_pipeline.py"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    st.success("Training script launched. Check the Space logs for progress.")

# ============================================================
# SECTION 2 — INFERENCE
# ============================================================
st.header("Inference (Upload Audio)")

if not os.path.exists(MODEL_DIR):
    st.warning("No model found. Train it first.")
else:
    uploaded_audio = st.file_uploader("Upload WAV File", type=["wav"])

    if uploaded_audio is not None:
        audio, sr = sf.read(uploaded_audio)
        # Normalize to the 16 kHz mono format the model was trained on.
        audio, sr = _prepare_audio(audio, sr)

        processor, model = _load_model(MODEL_DIR)

        inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
        with torch.no_grad():
            logits = model(inputs.input_values).logits

        # Greedy CTC decoding: most likely token per frame, then collapse.
        pred_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(pred_ids)[0]

        st.subheader("Transcription")
        st.write(transcription)