import streamlit as st
import numpy as np
import soundfile as sf
import librosa
import tempfile
import os
import hashlib
from concurrent.futures import ThreadPoolExecutor
from functools import partial


class AudioTranscriber:
    def __init__(self):
        self.model = None
        self.processor = None
        self.transcription_cache = {}
        # Cap any single segment at 5 seconds by default.
        self.max_segment_duration = 5.0

    def set_max_segment_duration(self, duration):
        """Set the maximum duration for any segment, in seconds."""
        self.max_segment_duration = duration

    def load_model(self):
        """Load a lightweight transcription model on first use."""
        if self.model is None:
            with st.spinner("Loading transcription model..."):
                try:
                    from transformers import pipeline

                    # whisper-small on CPU keeps resource needs modest while
                    # handling up to 30-second chunks per call.
                    self.model = pipeline(
                        "automatic-speech-recognition",
                        model="openai/whisper-small",
                        chunk_length_s=30,
                        device="cpu",
                    )
                except Exception as e:
                    st.warning(f"Error loading transcription model: {str(e)}. Using fallback method.")
                    self.model = None
        return self.model

    def segment_audio(self, audio_file, num_segments=5, min_segment_duration=3.0):
        """Split the audio file into chunks of roughly min_segment_duration to
        self.max_segment_duration seconds (3 s and 5 s by default)."""
        # Persist the uploaded file so the decoders can read it from disk.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            tmp_file.write(audio_file.getvalue())
            audio_path = tmp_file.name

        def split_waveform(y, sr, duration):
            # Use enough segments that none exceeds the maximum duration cap,
            # scaling the count up further for longer recordings.
            ideal_segments = max(num_segments, int(duration / self.max_segment_duration) + 1)
            actual_segments = max(ideal_segments, int(duration / min_segment_duration))
            segment_duration = min(duration / actual_segments, self.max_segment_duration)

            segments = []
            timestamps = []
            current_time = 0.0
            while current_time < duration:
                start_time = current_time
                end_time = min(start_time + segment_duration, duration)
                start_sample = int(start_time * sr)
                end_sample = int(end_time * sr)
                segments.append(y[start_sample:end_sample])
                timestamps.append((start_time, end_time))
                current_time = end_time
            return segments, timestamps

        try:
            y, sr = librosa.load(audio_path, sr=None)
            duration = librosa.get_duration(y=y, sr=sr)
            return split_waveform(y, sr, duration)
        except Exception as e:
            st.warning(f"Error segmenting audio: {str(e)}. Using simplified segmentation.")
            # Fall back to soundfile, which reads plain WAV data directly.
            try:
                y, sr = sf.read(audio_path)
                duration = len(y) / sr
                return split_waveform(y, sr, duration)
            except Exception as inner_e:
                st.error(f"Critical error in audio segmentation: {str(inner_e)}")
                # Last resort: silent one-second placeholder segments at 16 kHz.
                segments = [np.zeros(16000) for _ in range(num_segments)]
                timestamps = [(i, min(i + 1, i + self.max_segment_duration)) for i in range(num_segments)]
                return segments, timestamps
        finally:
            # Always remove the temporary file, even when segmentation fails.
            if os.path.exists(audio_path):
                try:
                    os.unlink(audio_path)
                except OSError:
                    pass

    def transcribe_segment(self, segment, sr=16000):
        """Transcribe a single audio segment."""
        # Hash the raw samples so identical segments are only transcribed once.
        cache_key = hashlib.md5(segment.tobytes()).hexdigest()
        if cache_key in self.transcription_cache:
            return self.transcription_cache[cache_key]

        try:
            model = self.load_model()
            if model is not None:
                # The pipeline expects a file path, so write the segment to a
                # temporary WAV file. Note that `sr` must match the segment's
                # actual sample rate or the audio will be mis-pitched.
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                    sf.write(tmp_file.name, segment, sr)
                    segment_path = tmp_file.name

                result = model(segment_path)
                transcription = result["text"]

                if os.path.exists(segment_path):
                    os.unlink(segment_path)
            else:
                # Model unavailable; fall back to a placeholder label.
                transcription = "Audio content"
        except Exception as e:
            st.warning(f"Error transcribing segment: {str(e)}. Using fallback method.")
            transcription = "Audio content"

        self.transcription_cache[cache_key] = transcription
        return transcription

    def transcribe_segments(self, segments, sr=16000, parallel=False, max_workers=4):
        """Transcribe multiple audio segments, optionally in parallel."""
        if parallel and len(segments) > 1:
            # Segments are independent, so they can be transcribed concurrently;
            # executor.map preserves the input order of the results.
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                transcribe_func = partial(self.transcribe_segment, sr=sr)
                transcriptions = list(executor.map(transcribe_func, segments))
        else:
            transcriptions = [self.transcribe_segment(segment, sr) for segment in segments]
        return transcriptions

    def clear_cache(self):
        """Clear the transcription cache."""
        self.transcription_cache = {}
        return True
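

# A minimal usage sketch (not part of the original module), assuming this file
# is run with `streamlit run`. Only the methods defined above plus standard
# Streamlit widgets (st.file_uploader, st.write) are used; the widget labels
# and parameter values are illustrative.
if __name__ == "__main__":
    transcriber = AudioTranscriber()
    transcriber.set_max_segment_duration(5.0)

    uploaded = st.file_uploader("Upload a WAV file", type=["wav"])
    if uploaded is not None:
        segments, timestamps = transcriber.segment_audio(uploaded, num_segments=5)
        # NOTE: transcribe_segments assumes 16 kHz audio; pass the file's real
        # sample rate via sr= if it differs.
        texts = transcriber.transcribe_segments(segments, parallel=True)
        for (start, end), text in zip(timestamps, texts):
            st.write(f"[{start:.1f}s - {end:.1f}s] {text}")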