# Video-Fx / transcriber.py
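"""Lightweight audio transcription helper for the Video-Fx Streamlit app.

Splits uploaded audio into short segments, transcribes each one with a small
Whisper pipeline, and caches results keyed by each segment's raw bytes.
"""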
import hashlib
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from functools import partial

import librosa
import numpy as np
import soundfile as sf
import streamlit as st

class AudioTranscriber:
    def __init__(self):
        self.model = None
        self.processor = None
        self.transcription_cache = {}
        self.max_segment_duration = 5.0  # Maximum segment duration in seconds

    def set_max_segment_duration(self, duration):
        """Set the maximum duration for any segment in seconds."""
        self.max_segment_duration = duration

    def load_model(self):
        """Load a lightweight transcription model."""
        if self.model is None:
            with st.spinner("Loading transcription model..."):
                try:
                    from transformers import pipeline

                    # Use a small model for transcription to save memory
                    self.model = pipeline(
                        "automatic-speech-recognition",
                        model="openai/whisper-small",
                        chunk_length_s=30,
                        device="cpu",
                    )
                except Exception as e:
                    st.warning(
                        f"Error loading transcription model: {str(e)}. "
                        "Using fallback method."
                    )
                    self.model = None
        return self.model

    def _split_waveform(self, y, sr, duration, num_segments, min_segment_duration):
        """Split a loaded waveform into segments and (start, end) timestamps."""
        # Create enough segments that none exceeds max_segment_duration
        ideal_segments = max(num_segments, int(duration / self.max_segment_duration) + 1)
        # Cap the count so segments don't fall below min_segment_duration
        actual_segments = max(1, min(ideal_segments, int(duration / min_segment_duration)))
        segment_duration = min(duration / actual_segments, self.max_segment_duration)
        segments = []
        timestamps = []
        current_time = 0.0
        while current_time < duration:
            start_time = current_time
            end_time = min(start_time + segment_duration, duration)
            # Convert times to sample indices
            start_sample = int(start_time * sr)
            end_sample = int(end_time * sr)
            segments.append(y[start_sample:end_sample])
            timestamps.append((start_time, end_time))
            current_time = end_time
        return segments, timestamps

    def segment_audio(self, audio_file, num_segments=5, min_segment_duration=3.0):
        """Segment the uploaded audio into chunks between min_segment_duration
        and self.max_segment_duration seconds long."""
        # Save the uploaded audio to a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            tmp_file.write(audio_file.getvalue())
            audio_path = tmp_file.name
        try:
            # Load and resample to 16 kHz, the rate transcribe_segment expects
            y, sr = librosa.load(audio_path, sr=16000)
            duration = librosa.get_duration(y=y, sr=sr)
            return self._split_waveform(y, sr, duration, num_segments, min_segment_duration)
        except Exception as e:
            st.warning(f"Error segmenting audio: {str(e)}. Using simplified segmentation.")
            # Fallback: read at the file's native rate with soundfile. Note the
            # native rate may differ from the 16 kHz default assumed downstream.
            try:
                y, sr = sf.read(audio_path)
                duration = len(y) / sr
                return self._split_waveform(y, sr, duration, num_segments, min_segment_duration)
            except Exception as inner_e:
                st.error(f"Critical error in audio segmentation: {str(inner_e)}")
                # Last resort: 1-second silent segments at 16 kHz
                segments = [np.zeros(16000) for _ in range(num_segments)]
                timestamps = [(float(i), float(i + 1)) for i in range(num_segments)]
                return segments, timestamps
        finally:
            # Clean up the temporary file
            if os.path.exists(audio_path):
                try:
                    os.unlink(audio_path)
                except OSError:
                    pass

    def transcribe_segment(self, segment, sr=16000):
        """Transcribe a single audio segment."""
        # Cache results keyed by a hash of the segment's raw bytes
        cache_key = hashlib.md5(segment.tobytes()).hexdigest()
        if cache_key in self.transcription_cache:
            return self.transcription_cache[cache_key]
        try:
            # Load the model if not already loaded
            model = self.load_model()
            if model is not None:
                # Write the segment to a temporary wav file for the pipeline
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                    segment_path = tmp_file.name
                sf.write(segment_path, segment, sr)
                try:
                    # Transcribe using the model
                    result = model(segment_path)
                    transcription = result["text"]
                finally:
                    # Clean up the temporary file even if transcription fails
                    if os.path.exists(segment_path):
                        os.unlink(segment_path)
            else:
                # Fallback: return a placeholder
                transcription = "Audio content"
        except Exception as e:
            st.warning(f"Error transcribing segment: {str(e)}. Using fallback method.")
            # Fallback: return a placeholder
            transcription = "Audio content"
        # Cache the result
        self.transcription_cache[cache_key] = transcription
        return transcription

    def transcribe_segments(self, segments, sr=16000, parallel=False, max_workers=4):
        """Transcribe multiple audio segments, optionally in parallel."""
        if parallel and len(segments) > 1:
            # ThreadPoolExecutor.map returns results in input order
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                # Bind the sample rate so map only varies the segment
                transcribe_func = partial(self.transcribe_segment, sr=sr)
                transcriptions = list(executor.map(transcribe_func, segments))
        else:
            # Process sequentially
            transcriptions = [self.transcribe_segment(segment, sr) for segment in segments]
        return transcriptions

    def clear_cache(self):
        """Clear the transcription cache."""
        self.transcription_cache = {}
        return True
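

# A minimal usage sketch (not part of the original module), assuming this file
# is run as a Streamlit page. The uploader widget, labels, and file types below
# are illustrative, not the Video-Fx app's actual UI.
if __name__ == "__main__":
    uploaded = st.file_uploader("Upload audio", type=["wav", "mp3"])
    if uploaded is not None:
        transcriber = AudioTranscriber()
        transcriber.set_max_segment_duration(5.0)
        segments, timestamps = transcriber.segment_audio(uploaded, num_segments=5)
        # Segments come back at 16 kHz, matching transcribe_segments' default sr
        texts = transcriber.transcribe_segments(segments, parallel=True)
        for (start, end), text in zip(timestamps, texts):
            st.write(f"[{start:.1f}s - {end:.1f}s] {text}")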