saikamal1108 commited on
Commit
8ca0243
·
verified ·
1 Parent(s): 5438852

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -23
app.py CHANGED
@@ -1,27 +1,63 @@
1
- # app.py
2
  import streamlit as st
3
  import soundfile as sf
4
  import torch
5
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
6
- from train_pipeline import train_model
7
-
8
- st.title("Telugu ASR - XLS-R Training + Inference")
9
-
10
- if st.button("Start Training"):
11
- st.write("🚀 Training started…")
12
- train_model()
13
- st.success("Training complete!")
14
-
15
- uploaded = st.file_uploader("Upload WAV audio", type=["wav"])
16
- if uploaded:
17
- audio, sr = sf.read(uploaded)
18
- processor = Wav2Vec2Processor.from_pretrained("./model")
19
- model = Wav2Vec2ForCTC.from_pretrained("./model").to("cpu")
20
-
21
- inputs = processor(audio, sampling_rate=sr, return_tensors="pt")
22
- with torch.no_grad():
23
- logits = model(inputs.input_values).logits
24
-
25
- pred_ids = torch.argmax(logits, dim=-1)
26
- st.subheader("Transcription")
27
- st.write(processor.batch_decode(pred_ids)[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import soundfile as sf
3
  import torch
4
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
5
+ import os
6
+ import subprocess
7
+
8
+ st.set_page_config(page_title="Telugu ASR Training", layout="wide")
9
+
10
+ st.title("Telugu ASR - Train XLS-R + Run Inference")
11
+
12
+ # ============================================================
13
+ # SECTION 1 — TRAIN BUTTON
14
+ # ============================================================
15
+
16
+ st.header("Train the Model")
17
+
18
+ st.write("""
19
+ Click the button below to start training the Wav2Vec2-XLSR model on Telugu dataset.
20
+ This will run train_pipeline.py inside this Space.
21
+ """)
22
+
23
+ if st.button("🚀 Start Training"):
24
+ st.info("Training started... this may take several minutes to hours depending on GPU.")
25
+
26
+ # run training script as a subprocess so Streamlit initializes properly
27
+ process = subprocess.Popen(
28
+ ["python3", "train_pipeline.py"],
29
+ stdout=subprocess.PIPE,
30
+ stderr=subprocess.PIPE,
31
+ text=True
32
+ )
33
+
34
+ st.success("Training script launched. Check the Space logs for progress.")
35
+
36
+ # ============================================================
37
+ # SECTION 2 — INFERENCE
38
+ # ============================================================
39
+
40
+ st.header("Inference (Upload Audio)")
41
+
42
+ if not os.path.exists("./model"):
43
+ st.warning("No model found. Train it first.")
44
+ else:
45
+ uploaded_audio = st.file_uploader("Upload WAV File", type=["wav"])
46
+
47
+ if uploaded_audio is not None:
48
+ audio, sr = sf.read(uploaded_audio)
49
+
50
+ # Load trained model
51
+ processor = Wav2Vec2Processor.from_pretrained("./model")
52
+ model = Wav2Vec2ForCTC.from_pretrained("./model")
53
+
54
+ inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
55
+
56
+ with torch.no_grad():
57
+ logits = model(inputs.input_values).logits
58
+
59
+ pred_ids = torch.argmax(logits, dim=-1)
60
+ transcription = processor.batch_decode(pred_ids)[0]
61
+
62
+ st.subheader("Transcription")
63
+ st.write(transcription)