import hashlib
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from functools import partial

import librosa
import numpy as np
import soundfile as sf
import streamlit as st


class AudioTranscriber:
    def __init__(self):
        self.model = None
        self.processor = None
        self.transcription_cache = {}
        self.max_segment_duration = 5.0  # Maximum segment duration in seconds

    def set_max_segment_duration(self, duration):
        """Set the maximum duration for any segment, in seconds."""
        self.max_segment_duration = duration

    def load_model(self):
        """Lazily load a lightweight transcription model."""
        if self.model is None:
            with st.spinner("Loading transcription model..."):
                try:
                    from transformers import pipeline

                    # Use a small model for transcription to save memory
                    self.model = pipeline(
                        "automatic-speech-recognition",
                        model="openai/whisper-small",
                        chunk_length_s=30,
                        device="cpu",
                    )
                except Exception as e:
                    st.warning(
                        f"Error loading transcription model: {str(e)}. "
                        "Using fallback method."
                    )
                    self.model = None
        return self.model

    def segment_audio(self, audio_file, num_segments=5, min_segment_duration=3.0):
        """Segment the audio file into chunks whose durations stay between
        min_segment_duration and self.max_segment_duration seconds."""
        # Save the uploaded audio to a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            tmp_file.write(audio_file.getvalue())
            audio_path = tmp_file.name

        try:
            # Load the audio, resampling to 16 kHz so segments match the
            # default sample rate expected by transcribe_segment
            y, sr = librosa.load(audio_path, sr=16000)

            # Get total duration
            duration = librosa.get_duration(y=y, sr=sr)

            # Create enough segments that each is <= max_segment_duration
            ideal_segments = max(num_segments, int(duration / self.max_segment_duration) + 1)

            # Cap the segment count so segments don't fall below
            # min_segment_duration (fewer segments means longer segments)
            actual_segments = min(ideal_segments, max(1, int(duration / min_segment_duration)))

            # Calculate segment duration
            segment_duration = min(duration / actual_segments, self.max_segment_duration)

            # Slice fixed-length windows until the audio is exhausted
            segments = []
            timestamps = []
            current_time = 0.0
            while current_time < duration:
                start_time = current_time
                end_time = min(start_time + segment_duration, duration)

                # Convert time to samples
                start_sample = int(start_time * sr)
                end_sample = int(end_time * sr)

                # Extract segment
                segments.append(y[start_sample:end_sample])
                timestamps.append((start_time, end_time))
                current_time = end_time

            return segments, timestamps
        except Exception as e:
            st.warning(f"Error segmenting audio: {str(e)}. Using simplified segmentation.")
Using simplified segmentation.") # Fallback: Create equal segments try: y, sr = sf.read(audio_path) duration = len(y) / sr # Calculate ideal number of segments based on max_segment_duration ideal_segments = max(num_segments, int(duration / self.max_segment_duration) + 1) # Ensure we don't create segments that are too short actual_segments = max(ideal_segments, int(duration / min_segment_duration)) # Calculate segment duration segment_duration = min(duration / actual_segments, self.max_segment_duration) # Create segments segments = [] timestamps = [] # Create more segments to ensure each is under max_segment_duration current_time = 0 while current_time < duration: start_time = current_time end_time = min(start_time + segment_duration, duration) # Convert time to samples start_sample = int(start_time * sr) end_sample = int(end_time * sr) # Extract segment segment = y[start_sample:end_sample] segments.append(segment) timestamps.append((start_time, end_time)) current_time = end_time return segments, timestamps except Exception as inner_e: st.error(f"Critical error in audio segmentation: {str(inner_e)}") # Last resort: Create dummy segments segments = [np.zeros(16000) for _ in range(num_segments)] # 1-second silent segments timestamps = [(i, min(i+1, i+self.max_segment_duration)) for i in range(num_segments)] return segments, timestamps finally: # Clean up temporary file if os.path.exists(audio_path): try: os.unlink(audio_path) except: pass def transcribe_segment(self, segment, sr=16000): """Transcribe a single audio segment""" # Generate a cache key based on the audio data import hashlib cache_key = hashlib.md5(segment.tobytes()).hexdigest() # Check if result is in cache if cache_key in self.transcription_cache: return self.transcription_cache[cache_key] try: # Load the model if not already loaded model = self.load_model() if model is not None: # Save segment to a temporary file with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file: sf.write(tmp_file.name, segment, sr) segment_path = tmp_file.name # Transcribe using the model result = model(segment_path) transcription = result["text"] # Clean up temporary file if os.path.exists(segment_path): os.unlink(segment_path) else: # Fallback: Return empty string or placeholder transcription = "Audio content" except Exception as e: st.warning(f"Error transcribing segment: {str(e)}. Using fallback method.") # Fallback: Return empty string or placeholder transcription = "Audio content" # Cache the result self.transcription_cache[cache_key] = transcription return transcription def transcribe_segments(self, segments, sr=16000, parallel=False, max_workers=4): """Transcribe multiple audio segments with parallel processing""" if parallel and len(segments) > 1: # Process in parallel using ThreadPoolExecutor with ThreadPoolExecutor(max_workers=max_workers) as executor: # Create a partial function with fixed sample rate transcribe_func = partial(self.transcribe_segment, sr=sr) # Map and collect results transcriptions = list(executor.map(transcribe_func, segments)) else: # Process sequentially transcriptions = [] for segment in segments: transcription = self.transcribe_segment(segment, sr) transcriptions.append(transcription) return transcriptions def clear_cache(self): """Clear the transcription cache""" self.transcription_cache = {} return True