| import spaces | |
| import boto3 | |
| from botocore.exceptions import NoCredentialsError, ClientError | |
| from botocore.client import Config | |
| import os, pathlib | |
| CACHE_ROOT = "/home/user/app/cache" # any folder you own | |
| os.environ.update( | |
| TORCH_HOME = f"{CACHE_ROOT}/torch", | |
| XDG_CACHE_HOME = f"{CACHE_ROOT}/xdg", # torch fallback | |
| PYANNOTE_CACHE = f"{CACHE_ROOT}/pyannote", | |
| HF_HOME = f"{CACHE_ROOT}/huggingface", | |
| TRANSFORMERS_CACHE= f"{CACHE_ROOT}/transformers", | |
| MPLCONFIGDIR = f"{CACHE_ROOT}/mpl", | |
| ) | |
| INITIAL_PROMPT = ''' | |
| Use normal punctuation; end sentences properly. | |
| ''' | |
| # make sure the cache directories exist (only the ones set above, not every value in os.environ) | |
| for key in ("TORCH_HOME", "XDG_CACHE_HOME", "PYANNOTE_CACHE", "HF_HOME", "TRANSFORMERS_CACHE", "MPLCONFIGDIR"): | |
| pathlib.Path(os.environ[key]).mkdir(parents=True, exist_ok=True) | |
| # ---- make cuDNN libs discoverable before importing torch ---- | |
| import os, pathlib, sys, ctypes | |
| def _cudnn_lib_dir(): | |
| try: | |
| import nvidia.cudnn as _cudnn | |
| except Exception: | |
| return None | |
| # Namespace-safe resolution: prefer __file__, fall back to __path__[0] | |
| base = None | |
| if getattr(_cudnn, "__file__", None): | |
| base = pathlib.Path(_cudnn.__file__).parent | |
| elif getattr(_cudnn, "__path__", None): | |
| base = pathlib.Path(next(iter(_cudnn.__path__))) | |
| if base is None: | |
| return None | |
| libdir = base / "lib" | |
| return str(libdir) if libdir.exists() else None | |
| _cudnn = _cudnn_lib_dir() | |
| if _cudnn: | |
| os.environ["LD_LIBRARY_PATH"] = _cudnn + ":" + os.environ.get("LD_LIBRARY_PATH", "") | |
| # ------------------------------------------------------------- | |
| import torch, ctranslate2, os | |
| print("torch", torch.__version__, "CUDA build:", torch.version.cuda, | |
| "cuDNN:", torch.backends.cudnn.version()) | |
| print("CT2:", ctranslate2.__version__) | |
| print("LD_LIBRARY_PATH has cudnn/lib?", any("cudnn/lib" in p for p in os.environ.get("LD_LIBRARY_PATH","").split(":"))) | |
| def _preload(paths): | |
| for p in paths: | |
| if os.path.exists(p): | |
| ctypes.CDLL(p, mode=ctypes.RTLD_GLOBAL) | |
| if _cudnn: | |
| _preload([ | |
| f"{_cudnn}/libcudnn.so.9", # core (cuDNN 9) | |
| f"{_cudnn}/libcudnn_ops.so.9", | |
| f"{_cudnn}/libcudnn_cnn.so.9", | |
| f"{_cudnn}/libcudnn_adv.so.9", | |
| ]) | |
| import gradio as gr | |
| import torchaudio | |
| import numpy as np | |
| import pandas as pd | |
| import time | |
| import datetime | |
| import re | |
| import subprocess | |
| import os | |
| import tempfile | |
| import spaces | |
| from faster_whisper import WhisperModel, BatchedInferencePipeline | |
| from faster_whisper.vad import VadOptions | |
| import whisperx | |
| import requests | |
| import base64 | |
| from pyannote.audio import Pipeline, Inference, Model | |
| from pyannote.core import Segment | |
| import importlib.util, ctypes, tempfile, wave, math | |
| import json | |
| import webrtcvad | |
| S3_ENDPOINT = os.getenv("S3_ENDPOINT") | |
| S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY") | |
| S3_SECRET_KEY = os.getenv("S3_SECRET_KEY") | |
| # Function to upload file to Cloudflare R2 | |
| def upload_data_to_r2(data, bucket_name, object_name, content_type='application/octet-stream'): | |
| """ | |
| Upload data directly to a Cloudflare R2 bucket. | |
| :param data: Data to upload (bytes or string). | |
| :param bucket_name: Name of the R2 bucket. | |
| :param object_name: Name of the object to save in the bucket. | |
| :param content_type: MIME type of the data. | |
| :return: True if data was uploaded, else False. | |
| """ | |
| try: | |
| # Convert string to bytes if necessary | |
| if isinstance(data, str): | |
| data = data.encode('utf-8') | |
| # Initialize a session using Cloudflare R2 credentials | |
| session = boto3.session.Session() | |
| s3 = session.client('s3', | |
| endpoint_url=f'https://{S3_ENDPOINT}', | |
| aws_access_key_id=S3_ACCESS_KEY, | |
| aws_secret_access_key=S3_SECRET_KEY, | |
| config = Config(s3={"addressing_style": "virtual", 'payload_signing_enabled': False}, signature_version='v4', | |
| request_checksum_calculation='when_required', | |
| response_checksum_validation='when_required',), | |
| ) | |
| # Upload the data to R2 bucket | |
| s3.put_object( | |
| Bucket=bucket_name, | |
| Key=object_name, | |
| Body=data, | |
| ContentType=content_type, | |
| ContentLength=len(data), # make length explicit to avoid streaming | |
| ) | |
| print(f"Data uploaded to R2 bucket '{bucket_name}' as '{object_name}'") | |
| return True | |
| except NoCredentialsError: | |
| print("Credentials not available") | |
| return False | |
| except ClientError as e: | |
| print(f"Failed to upload data to R2 bucket: {e}") | |
| return False | |
| except Exception as e: | |
| print(f"An unexpected error occurred: {e}") | |
| return False | |
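| # Illustrative usage sketch (comment only, not executed). The bucket name matches the | |
| # "intermediate" bucket used elsewhere in this file; the object key is a placeholder. | |
| #   payload = json.dumps({"status": "ok"}) | |
| #   upload_data_to_r2(payload, bucket_name="intermediate", | |
| #                     object_name="ai-transcribe/demo.json", | |
| #                     content_type="application/json") | |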
| from huggingface_hub import snapshot_download | |
| # ----------------------------------------------------------------------------- | |
| # Model Management | |
| # ----------------------------------------------------------------------------- | |
| MODELS = { | |
| "large-v3-turbo": { | |
| "whisperx_name": "large-v3-turbo", | |
| }, | |
| "large-v3": { | |
| "whisperx_name": "large-v3", | |
| }, | |
| "large-v2": { | |
| "whisperx_name": "large-v2", | |
| }, | |
| } | |
| DEFAULT_MODEL = "large-v3-turbo" | |
| # Supported languages for alignment models (whisperX) | |
| ALIGN_LANGUAGES = ["en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh", "ar", "nl", "tr", "pl", "cs", "sv", "da", "fi", "no", "uk"] | |
| # ----------------------------------------------------------------------------- | |
| # Audio preprocess helper (from input_and_preprocess rule) | |
| # ----------------------------------------------------------------------------- | |
| TRIM_THRESHOLD_MS = 10_000 # 10 seconds | |
| DEFAULT_PAD_MS = 250 # safety context around detected speech | |
| FRAME_MS = 30 # VAD frame | |
| HANG_MS = 240 # hangover (keep speech "on" after silence) | |
| VAD_LEVEL = 2 # 0-3 | |
| def _decode_chunk_to_pcm(task: dict) -> bytes: | |
| """Use ffmpeg to decode the chunk to s16le mono @ 16k PCM bytes.""" | |
| src = task["source_uri"] | |
| ing = task["ingest_recipe"] | |
| seek = task["ffmpeg_seek"] | |
| cmd = [ | |
| "ffmpeg", "-nostdin", "-hide_banner", "-v", "error", | |
| "-ss", f"{max(0.0, float(seek['pre_ss_sec'])):.3f}", | |
| "-i", src, | |
| "-map", "0:a:0", | |
| "-ss", f"{float(seek['post_ss_sec']):.2f}", | |
| "-t", f"{float(seek['t_sec']):.3f}", | |
| ] | |
| # Optional L/R extraction | |
| if ing.get("channel_extract_filter"): | |
| cmd += ["-af", ing["channel_extract_filter"]] | |
| # Force mono 16k s16le to stdout | |
| cmd += ["-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le", "-f", "s16le", "pipe:1"] | |
| p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |
| pcm, err = p.communicate() | |
| if p.returncode != 0: | |
| raise RuntimeError(f"ffmpeg failed: {err.decode('utf-8', 'ignore')}") | |
| return pcm | |
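| # Shape of the task fields read above, with placeholder values (a sketch of the expected | |
| # input, not a schema guaranteed by the upstream job producer): | |
| #   task = { | |
| #       "source_uri": "https://example.com/audio.mp3",  # any ffmpeg-readable input | |
| #       "ffmpeg_seek": {"pre_ss_sec": 0.0, "post_ss_sec": 0.0, "t_sec": 30.0}, | |
| #       "ingest_recipe": {"channel_extract_filter": None},  # or an ffmpeg -af expression to pick one channel | |
| #   } | |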
| def _find_head_tail_speech_ms( | |
| pcm: bytes, | |
| sr: int = 16000, | |
| frame_ms: int = FRAME_MS, | |
| vad_level: int = VAD_LEVEL, | |
| hang_ms: int = HANG_MS, | |
| ): | |
| """Return (first_ms, last_ms) speech boundaries using webrtcvad with hangover.""" | |
| if not pcm: | |
| return None, None | |
| vad = webrtcvad.Vad(int(vad_level)) | |
| bpf = 2 # bytes per sample (s16) | |
| samples_per_ms = sr // 1000 # 16 | |
| bytes_per_frame = samples_per_ms * bpf * frame_ms | |
| n_frames = len(pcm) // bytes_per_frame | |
| if n_frames == 0: | |
| return None, None | |
| first_ms, last_ms = None, None | |
| t_ms = 0 | |
| in_speech = False | |
| silence_run = 0 | |
| view = memoryview(pcm)[: n_frames * bytes_per_frame] | |
| for i in range(n_frames): | |
| frame = view[i * bytes_per_frame : (i + 1) * bytes_per_frame] | |
| if vad.is_speech(frame, sr): | |
| if first_ms is None: | |
| first_ms = t_ms | |
| in_speech = True | |
| silence_run = 0 | |
| else: | |
| if in_speech: | |
| silence_run += frame_ms | |
| if silence_run >= hang_ms: | |
| last_ms = t_ms - (silence_run - hang_ms) | |
| in_speech = False | |
| silence_run = 0 | |
| t_ms += frame_ms | |
| if in_speech: | |
| last_ms = t_ms | |
| return first_ms, last_ms | |
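| # Worked example of how these boundaries are used by _process_single_chunk below: | |
| # for a 60_000 ms chunk with speech detected between first_ms=12_000 and last_ms=55_000, | |
| # head silence is 12_000 ms and tail silence is 60_000 - 55_000 = 5_000 ms, so only the | |
| # head exceeds TRIM_THRESHOLD_MS (10_000 ms) and triggers trimming. | |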
| def _write_wav(path: str, pcm: bytes, sr: int = 16000): | |
| os.makedirs(os.path.dirname(path), exist_ok=True) | |
| with wave.open(path, "wb") as w: | |
| w.setnchannels(1) | |
| w.setsampwidth(2) # s16 | |
| w.setframerate(sr) | |
| w.writeframes(pcm) | |
| def prepare_and_save_audio_for_model(task: dict, out_dir: str) -> dict: | |
| """ | |
| 1) Decode chunk(s) to mono 16k PCM. | |
| 2) Run VAD to locate head/tail silence. | |
| 3) Trim only if head or tail >= 10s. | |
| 4) Save the (possibly trimmed) WAV to local file(s). | |
| 5) Return timing metadata, including 'trimmed_start_ms' to preserve global timestamps. | |
| Args: | |
| task: dict containing either: | |
| - "chunk": single chunk dict, or | |
| - "chunk": list of chunk dicts | |
| out_dir: output directory for WAV files | |
| Returns: | |
| A wrapper dict with general fields (e.g., job_id, channel, sr, filekey) | |
| and a "chunks" array containing metadata dict(s) for each processed chunk. | |
| This structure is returned for both single and multiple chunk inputs. | |
| """ | |
| result = { | |
| "job_id": task.get("job_id", "job"), | |
| "channel": task["channel"], | |
| "sr": 16000, | |
| "options": task.get("options", None), | |
| "filekey": task.get("filekey", None), | |
| } | |
| chunk_result = _process_single_chunk(task, out_dir) | |
| result["chunk"] = chunk_result | |
| return result | |
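| # Sketch of the returned wrapper (values are placeholders; see _process_single_chunk | |
| # for the full set of per-chunk fields): | |
| #   { | |
| #     "job_id": "job", "channel": "L", "sr": 16000, | |
| #     "options": {...}, "filekey": "ai-transcribe/split/job-0.json", | |
| #     "chunk": {"out_wav_path": "...", "trim_applied": False, "trimmed_start_ms": 0, | |
| #               "abs_start_ms": 0, "dur_ms": 60000, "chunk_idx": "0", "skip": False, ...}, | |
| #   } | |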
| def _process_single_chunk(task: dict, out_dir: str) -> dict: | |
| """ | |
| Process a single chunk - extracted from the original prepare_and_save_audio_for_model logic. | |
| 1) Decode chunk to mono 16k PCM. | |
| 2) Run VAD to locate head/tail silence. | |
| 3) Trim only if head or tail >= 10s. | |
| 4) Save the (possibly trimmed) WAV to local file. | |
| 5) Return timing metadata, including 'trimmed_start_ms' to preserve global timestamps. | |
| """ | |
| # 0) Names & constants | |
| sr = 16000 | |
| bpf = 2 | |
| samples_per_ms = sr // 1000 | |
| def bytes_from_ms(ms: int) -> int: | |
| return int(ms * samples_per_ms) * bpf | |
| ch = task["channel"] | |
| ck = task["chunk"] | |
| job = task.get("job_id", "job") | |
| idx = str(ck["idx"]) | |
| # 1) Decode chunk | |
| pcm = _decode_chunk_to_pcm(task) | |
| planned_dur_ms = int(ck["dur_ms"]) | |
| # 2) VAD head/tail detection | |
| first_ms, last_ms = _find_head_tail_speech_ms(pcm, sr=sr) | |
| head_sil_ms = int(first_ms) if first_ms is not None else planned_dur_ms | |
| tail_sil_ms = int(planned_dur_ms - last_ms) if last_ms is not None else planned_dur_ms | |
| # 3) Decide trimming (only if head or tail >= 10s) | |
| trim_applied = False | |
| eff_start_ms = 0 | |
| eff_end_ms = planned_dur_ms | |
| trimmed_pcm = pcm | |
| if (head_sil_ms >= TRIM_THRESHOLD_MS) or (tail_sil_ms >= TRIM_THRESHOLD_MS): | |
| # If no speech found at all, mark skip | |
| if first_ms is None or last_ms is None or last_ms <= first_ms: | |
| out_wav_path = os.path.join(out_dir, f"{job}_{ch}_{idx}_nospeech.wav") | |
| _write_wav(out_wav_path, b"", sr) | |
| return { | |
| "out_wav_path": out_wav_path, | |
| "sr": sr, | |
| "trim_applied": False, | |
| "trimmed_start_ms": 0, | |
| "head_silence_ms": head_sil_ms, | |
| "tail_silence_ms": tail_sil_ms, | |
| "effective_start_ms": 0, | |
| "effective_dur_ms": 0, | |
| "abs_start_ms": ck["global_offset_ms"], | |
| "dur_ms": ck["dur_ms"], | |
| "chunk_idx": idx, | |
| "channel": ch, | |
| "skip": True, | |
| } | |
| # Apply padding & slice | |
| start_ms = max(0, int(first_ms) - DEFAULT_PAD_MS) | |
| end_ms = min(planned_dur_ms, int(last_ms) + DEFAULT_PAD_MS) | |
| if end_ms > start_ms: | |
| eff_start_ms = start_ms | |
| eff_end_ms = end_ms | |
| trimmed_pcm = pcm[bytes_from_ms(start_ms) : bytes_from_ms(end_ms)] | |
| trim_applied = True | |
| # 4) Write WAV to local file (trimmed or original) | |
| tag = "trim" if trim_applied else "full" | |
| out_wav_path = os.path.join(out_dir, f"{job}_{ch}_{idx}_{tag}.wav") | |
| _write_wav(out_wav_path, trimmed_pcm, sr) | |
| # 5) Return metadata | |
| return { | |
| "out_wav_path": out_wav_path, | |
| "sr": sr, | |
| "trim_applied": trim_applied, | |
| "trimmed_start_ms": eff_start_ms if trim_applied else 0, | |
| "head_silence_ms": head_sil_ms, | |
| "tail_silence_ms": tail_sil_ms, | |
| "effective_start_ms": eff_start_ms, | |
| "effective_dur_ms": eff_end_ms - eff_start_ms, | |
| "abs_start_ms": int(ck["global_offset_ms"]) + eff_start_ms, | |
| "dur_ms": ck["dur_ms"], | |
| "chunk_idx": idx, | |
| "channel": ch, | |
| "job_id": job, | |
| "skip": False if (trim_applied or len(pcm) > 0) else True, | |
| } | |
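| # Timestamp bookkeeping: when the head is trimmed, the saved WAV starts effective_start_ms | |
| # into the chunk, so abs_start_ms = global_offset_ms + trimmed_start_ms. Downstream code | |
| # passes base_offset_s = abs_start_ms / 1000.0 to transcription and adds it to every | |
| # segment/word time, which recovers global timestamps for the original recording. | |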
| # Download once; later runs are instant | |
| # snapshot_download( | |
| # repo_id=MODEL_REPO, | |
| # local_dir=LOCAL_DIR, | |
| # local_dir_use_symlinks=True, # saves disk space | |
| # resume_download=True | |
| # ) | |
| # model_cache_path = LOCAL_DIR # <ββ this is what we pass to WhisperModel | |
| # Lazy global holder ---------------------------------------------------------- | |
| _whipser_x_transcribe_models = {} | |
| _whipser_x_align_models = {} | |
| _faster_whisper_transcribe_models = {} | |
| _faster_whisper_batched_pipelines = {} | |
| _diarizer = None | |
| _embedder = None | |
| # Preload alignment and diarization models at startup (no GPU decorator) | |
| def _preload_alignment_and_diarization_models(): | |
| """Preload WhisperX alignment and diarization models on CUDA device""" | |
| global _whipser_x_align_models, _diarizer | |
| print("Preloading all WhisperX alignment models...") | |
| for lang in ALIGN_LANGUAGES: | |
| try: | |
| print(f"Loading alignment model for language '{lang}'...") | |
| device = "cuda" | |
| align_model, align_metadata = whisperx.load_align_model( | |
| language_code=lang, | |
| device=device, | |
| model_dir=CACHE_ROOT | |
| ) | |
| _whipser_x_align_models[lang] = { | |
| "model": align_model, | |
| "metadata": align_metadata | |
| } | |
| print(f"Alignment model for '{lang}' loaded successfully") | |
| except Exception as e: | |
| print(f"Could not load alignment model for '{lang}': {e}") | |
| # Create global diarization pipeline | |
| try: | |
| print("Loading diarization model...") | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cudnn.allow_tf32 = True | |
| torch.set_float32_matmul_precision('high') | |
| _diarizer = Pipeline.from_pretrained( | |
| "pyannote/speaker-diarization-community-1", | |
| use_auth_token=os.getenv("HF_TOKEN"), | |
| ).to(torch.device("cuda")) | |
| print("Diarization model loaded successfully") | |
| except Exception as e: | |
| import traceback | |
| traceback.print_exc() | |
| print(f"Could not load diarization model: {e}") | |
| _diarizer = None | |
| print("WhisperX alignment and diarization models preloaded successfully!") | |
| # Call preload function at startup | |
| _preload_alignment_and_diarization_models() | |
| # Preload WhisperX transcribe models with GPU decorator | |
| def _preload_whisperx_transcribe_models(): | |
| """Preload all WhisperX transcribe models on GPU""" | |
| global _whipser_x_transcribe_models | |
| print("Preloading all WhisperX transcribe models on GPU...") | |
| for model_name in MODELS.keys(): | |
| try: | |
| print(f"Loading WhisperX transcribe model '{model_name}'...") | |
| whisperx_model_name = MODELS[model_name]["whisperx_name"] | |
| device = "cuda" | |
| compute_type = "float16" | |
| model = whisperx.load_model( | |
| whisperx_model_name, | |
| device=device, | |
| compute_type=compute_type, | |
| download_root=CACHE_ROOT | |
| ) | |
| _whipser_x_transcribe_models[model_name] = model | |
| print(f"WhisperX transcribe model '{model_name}' loaded successfully") | |
| except Exception as e: | |
| import traceback | |
| traceback.print_exc() | |
| print(f"Could not load WhisperX transcribe model '{model_name}': {e}") | |
| print("All WhisperX transcribe models preloaded successfully!") | |
| # ----------------------------------------------------------------------------- | |
| class WhisperTranscriber: | |
| def __init__(self): | |
| # do **not** create the models here! | |
| pass | |
| def preprocess_from_task_json(self, task_json: str): | |
| """Parse task JSON and run prepare_and_save_audio_for_model, returning metadata.""" | |
| try: | |
| task = json.loads(task_json) | |
| except Exception as e: | |
| raise RuntimeError(f"Invalid JSON: {e}") | |
| out_dir = os.path.join("/tmp/gradio", "preprocessed") | |
| os.makedirs(out_dir, exist_ok=True) | |
| meta = None | |
| #task could be a single chunk or a list of chunks | |
| if isinstance(task, list): | |
| meta = [] | |
| for chunk in task: | |
| meta.append(prepare_and_save_audio_for_model(chunk, out_dir)) | |
| else: | |
| meta = prepare_and_save_audio_for_model(task, out_dir) | |
| return meta | |
| # each call gets a GPU slice | |
| def transcribe_full_audio(self, audio_path, language=None, translate=False, prompt=None, batch_size=16, base_offset_s: float = 0.0, clip_timestamps=None, engine="whisperx", model_name: str = DEFAULT_MODEL, transcribe_options: dict = None): | |
| """Transcribe the entire audio file using selected engine, then align with WhisperX. | |
| engine: "whisperx" | "faster_whisper" | |
| Always uses WhisperX alignment regardless of transcription engine. | |
| """ | |
| global _whipser_x_transcribe_models, _whipser_x_align_models, _faster_whisper_transcribe_models | |
| start_time = time.time() | |
| # Resolve engine (allow override from transcribe_options) | |
| if transcribe_options and isinstance(transcribe_options, dict) and transcribe_options.get("engine"): | |
| engine = str(transcribe_options.get("engine")).strip().lower() | |
| # Transcribe using the selected engine | |
| initial_segments = [] | |
| detected_language = language if language else "unknown" | |
| # Load audio once (whisperx.load_audio returns float32 mono at 16 kHz) | |
| audio = whisperx.load_audio(audio_path) | |
| print(audio_path) | |
| if engine == "whisperx": | |
| # Lazy-load WhisperX model on first use | |
| if model_name not in _whipser_x_transcribe_models: | |
| print(f"Loading WhisperX transcribe model '{model_name}' on GPU...") | |
| if model_name not in MODELS: | |
| raise ValueError(f"Model '{model_name}' not found in MODELS registry. Available: {list(MODELS.keys())}") | |
| whisperx_model_name = MODELS[model_name]["whisperx_name"] | |
| device = "cuda" | |
| compute_type = "float16" | |
| whisper_model = whisperx.load_model( | |
| whisperx_model_name, | |
| device=device, | |
| compute_type=compute_type, | |
| download_root=CACHE_ROOT, | |
| asr_options=transcribe_options | |
| ) | |
| _whipser_x_transcribe_models[model_name] = whisper_model | |
| print(f"WhisperX transcribe model '{model_name}' loaded successfully") | |
| else: | |
| whisper_model = _whipser_x_transcribe_models[model_name] | |
| print(f"Transcribing full audio with WhisperX model '{model_name}' and batch size {batch_size}...") | |
| result = whisper_model.transcribe( | |
| audio, | |
| language=language, | |
| batch_size=batch_size, | |
| #initial_prompt=prompt, | |
| #task="translate" if translate else "transcribe" | |
| ) | |
| detected_language = result.get("language", detected_language) | |
| initial_segments = result.get("segments", []) | |
| elif engine == "faster_whisper": | |
| # Lazy-load Faster-Whisper model on first use | |
| if model_name not in _faster_whisper_transcribe_models: | |
| print(f"Loading Faster-Whisper transcribe model '{model_name}' on GPU...") | |
| # Use the same name by default; extend MODELS with specific mapping if needed | |
| faster_name = MODELS.get(model_name, {}).get("whisperx_name", model_name) | |
| fw_model = WhisperModel( | |
| faster_name, | |
| device="cuda", | |
| compute_type="float16", | |
| download_root=CACHE_ROOT, | |
| ) | |
| _faster_whisper_transcribe_models[model_name] = fw_model | |
| print(f"Faster-Whisper transcribe model '{model_name}' loaded successfully") | |
| else: | |
| fw_model = _faster_whisper_transcribe_models[model_name] | |
| print(f"Transcribing full audio with Faster-Whisper model '{model_name}' and batch size {batch_size}...") | |
| task = "translate" if translate else "transcribe" | |
| # Build kwargs from transcribe_options for Faster-Whisper's transcribe API | |
| fw_kwargs = {} | |
| if isinstance(transcribe_options, dict): | |
| allowed = { | |
| "log_progress", | |
| "beam_size", | |
| "best_of", | |
| "patience", | |
| "length_penalty", | |
| "repetition_penalty", | |
| "no_repeat_ngram_size", | |
| "temperature", | |
| "compression_ratio_threshold", | |
| "log_prob_threshold", | |
| "no_speech_threshold", | |
| "condition_on_previous_text", | |
| "prompt_reset_on_temperature", | |
| "initial_prompt", | |
| "prefix", | |
| "suppress_blank", | |
| "suppress_tokens", | |
| "without_timestamps", | |
| "max_initial_timestamp", | |
| #"word_timestamps", | |
| #"prepend_punctuations", | |
| #"append_punctuations", | |
| "multilingual", | |
| "vad_filter", | |
| "vad_parameters", | |
| "max_new_tokens", | |
| "chunk_length", | |
| "clip_timestamps", | |
| "hallucination_silence_threshold", | |
| "batch_size", | |
| "hotwords", | |
| "language_detection_threshold", | |
| "language_detection_segments", | |
| } | |
| for k in allowed: | |
| if k in transcribe_options and transcribe_options[k] is not None: | |
| fw_kwargs[k] = transcribe_options[k] | |
| # Ensure sensible defaults and avoid duplicates | |
| if "initial_prompt" not in fw_kwargs and prompt is not None: | |
| fw_kwargs["initial_prompt"] = prompt | |
| if "batch_size" not in fw_kwargs and batch_size is not None: | |
| fw_kwargs["batch_size"] = batch_size | |
| if "vad_filter" not in fw_kwargs: | |
| fw_kwargs["vad_filter"] = False # preserve boundaries for alignment | |
| # language and task are passed explicitly; do not include in fw_kwargs | |
| fw_kwargs.pop("language", None) | |
| fw_kwargs.pop("task", None) | |
| fw_kwargs["prepend_punctuations"] = "\"'βΒΏ([{-" | |
| fw_kwargs["append_punctuations"] = "\"'.γ,οΌ!οΌ?οΌ:οΌβ)]}γ" | |
| fw_kwargs["without_timestamps"] = False #True | |
| fw_kwargs["max_initial_timestamp"] = 1.0 | |
| fw_kwargs["word_timestamps"] = True #False | |
| # Choose between single and batched transcription per docs | |
| effective_bs = int(fw_kwargs.get("batch_size", batch_size if batch_size is not None else 8)) | |
| use_batched = effective_bs > 1 | |
| print(fw_kwargs) | |
| # Note: faster-whisper accepts either a file path or a numpy array; the file path is passed here | |
| if use_batched: | |
| if model_name not in _faster_whisper_batched_pipelines: | |
| _faster_whisper_batched_pipelines[model_name] = BatchedInferencePipeline(model=fw_model) | |
| batched_model = _faster_whisper_batched_pipelines[model_name] | |
| segments_iter, info = batched_model.transcribe( | |
| audio_path, | |
| language=language, | |
| task=task, | |
| **fw_kwargs, | |
| ) | |
| else: | |
| fw_kwargs.pop("batch_size", None) | |
| segments_iter, info = fw_model.transcribe( | |
| audio_path, | |
| language=language, | |
| task=task, | |
| **fw_kwargs, | |
| ) | |
| detected_language = getattr(info, "language", detected_language) | |
| # Convert to WhisperX-like segment dicts | |
| initial_segments = [{ | |
| "start": float(s.start), | |
| "end": float(s.end), | |
| "text": s.text or "", | |
| } for s in segments_iter] | |
| else: | |
| raise ValueError(f"Unknown engine '{engine}'. Supported: 'whisperx', 'faster_whisper'") | |
| print(f"Detected language: {detected_language}, segments: {len(initial_segments)}, transcribing done in {time.time() - start_time:.2f} seconds") | |
| # Align with centralized alignment method when available | |
| segments = initial_segments | |
| if detected_language in _whipser_x_align_models: | |
| try: | |
| align_out = self.align_timestamp( | |
| audio_url=audio_path, | |
| text=None, | |
| language=detected_language, | |
| engine="whisperx", | |
| options={"segments": initial_segments}, | |
| ) | |
| if isinstance(align_out, dict) and align_out.get("segments"): | |
| segments = align_out["segments"] | |
| except Exception as e: | |
| print(f"Alignment via align_timestamp failed: {e}, using original timestamps") | |
| else: | |
| print(f"No WhisperX alignment model available for language '{detected_language}', using original timestamps") | |
| # Process segments into the expected format | |
| results = [] | |
| for seg in segments: | |
| words_list = [] | |
| if "words" in seg: | |
| for word in seg["words"]: | |
| words_list.append({ | |
| "start": float(word.get("start", 0.0)) + float(base_offset_s), | |
| "end": float(word.get("end", 0.0)) + float(base_offset_s), | |
| "word": word.get("word", ""), | |
| "probability": word.get("score", 1.0), | |
| "speaker": "SPEAKER_00" | |
| }) | |
| results.append({ | |
| "start": float(seg.get("start", 0.0)) + float(base_offset_s), | |
| "end": float(seg.get("end", 0.0)) + float(base_offset_s), | |
| "text": seg.get("text", ""), | |
| "speaker": "SPEAKER_00", | |
| "avg_logprob": seg.get("avg_logprob", 0.0) if "avg_logprob" in seg else 0.0, | |
| "words": words_list, | |
| "duration": float(seg.get("end", 0.0)) - float(seg.get("start", 0.0)), | |
| "language": detected_language, | |
| }) | |
| print(results) | |
| transcription_time = time.time() - start_time | |
| print(f"Full audio transcribed and aligned in {transcription_time:.2f} seconds using batch size {batch_size}") | |
| return results, detected_language | |
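| # Minimal usage sketch (paths and option values below are hypothetical): | |
| #   segs, lang = self.transcribe_full_audio( | |
| #       "/tmp/gradio/preprocessed/job_L_0_full.wav", | |
| #       language="en", batch_size=16, base_offset_s=120.0, | |
| #       engine="faster_whisper", model_name="large-v3-turbo", | |
| #       transcribe_options={"beam_size": 5, "vad_filter": False}, | |
| #   ) | |
| #   # segs: list of {"start", "end", "text", "words", ...} dicts in global seconds | |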
| # alignment requires GPU | |
| def align_timestamp(self, audio_url, text, language, engine="whisperx", options: dict = None): | |
| """Return word-level alignment for the given text/audio using the specified engine. | |
| Args: | |
| audio_url: Path or URL to the audio file. | |
| text: String text to align. If options contains 'segments', this can be None. | |
| language: Language code (e.g., 'en'). Must be supported by WhisperX align models. | |
| engine: Currently only 'whisperx' is supported. | |
| options: Optional dict. Recognized keys: | |
| - 'segments': list of {start, end, text} to align (preferred for segment-aware alignment) | |
| Returns: | |
| dict with keys: | |
| - 'segments': aligned segments including word timings (if available) | |
| - 'words': flat list of aligned words across all segments | |
| """ | |
| global _whipser_x_align_models | |
| if engine != "whisperx": | |
| raise ValueError(f"align_timestamp engine '{engine}' not supported. Only 'whisperx' is supported") | |
| if language not in _whipser_x_align_models: | |
| raise ValueError(f"No WhisperX alignment model available for language '{language}'") | |
| # Resolve audio path (download if URL) | |
| local_path = None | |
| tmp_file = None | |
| try: | |
| if isinstance(audio_url, str) and audio_url.startswith(("http://", "https://")): | |
| resp = requests.get(audio_url, stream=True, timeout=60) | |
| resp.raise_for_status() | |
| tmp_f = tempfile.NamedTemporaryFile(suffix=".audio", delete=False) | |
| for chunk in resp.iter_content(chunk_size=8192): | |
| if chunk: | |
| tmp_f.write(chunk) | |
| tmp_f.flush() | |
| tmp_f.close() | |
| tmp_file = tmp_f.name | |
| local_path = tmp_file | |
| else: | |
| local_path = audio_url | |
| # Load audio and decide segments to align | |
| audio = whisperx.load_audio(local_path) | |
| sr = 16000.0 # whisperx loads at 16k | |
| audio_duration = float(len(audio)) / sr if hasattr(audio, "__len__") else None | |
| segments_to_align = None | |
| if options and isinstance(options, dict) and options.get("segments"): | |
| segments_to_align = options.get("segments") | |
| else: | |
| if not text or not str(text).strip(): | |
| raise ValueError("align_timestamp requires 'text' when 'segments' are not provided in options") | |
| if audio_duration is None: | |
| raise ValueError("Could not determine audio duration for alignment") | |
| segments_to_align = [{ | |
| "text": str(text), | |
| "start": 0.0, | |
| "end": audio_duration, | |
| }] | |
| # Perform alignment | |
| align_info = _whipser_x_align_models[language] | |
| aligned = whisperx.align( | |
| segments_to_align, | |
| align_info["model"], | |
| align_info["metadata"], | |
| audio, | |
| "cuda", | |
| return_char_alignments=False, | |
| ) | |
| aligned_segments = aligned.get("segments", segments_to_align) | |
| words_flat = [] | |
| for seg in aligned_segments: | |
| for w in seg.get("words", []) or []: | |
| words_flat.append({ | |
| "start": float(w.get("start", 0.0)), | |
| "end": float(w.get("end", 0.0)), | |
| "word": w.get("word", ""), | |
| "probability": w.get("score", 1.0) | |
| }) | |
| return {"segments": aligned_segments, "words": words_flat, "language": language} | |
| finally: | |
| if tmp_file: | |
| try: | |
| os.unlink(tmp_file) | |
| except Exception: | |
| pass | |
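| # Usage sketch (values are hypothetical): align an existing transcript string against its | |
| # audio to recover word timings, e.g. | |
| #   out = self.align_timestamp("/tmp/clip.wav", "hello world", "en") | |
| #   out["words"]  # -> [{"start": 0.12, "end": 0.38, "word": "hello", "probability": 0.97}, ...] | |
| # or pass options={"segments": [...]} to align pre-segmented text instead of a single string. | |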
| # Removed audio cutting; transcription is done once on the full (preprocessed) audio | |
| # each call gets a GPU slice | |
| def perform_diarization(self, audio_path, num_speakers=None, base_offset_s: float = 0.0): | |
| """Perform speaker diarization; return segments with global timestamps and per-speaker embeddings.""" | |
| global _diarizer | |
| if _diarizer is None: | |
| print("Diarization model not available, creating single speaker segment") | |
| # Load audio to get duration | |
| waveform, sample_rate = torchaudio.load(audio_path) | |
| duration = waveform.shape[1] / sample_rate | |
| # Try to compute a single-speaker embedding | |
| speaker_embeddings = {} | |
| try: | |
| embedder = self._load_embedder() | |
| # Provide waveform as (channel, time) and pad if too short | |
| min_embed_duration_sec = 1.0 | |
| min_samples = int(min_embed_duration_sec * sample_rate) | |
| if waveform.shape[1] < min_samples: | |
| pad_len = min_samples - waveform.shape[1] | |
| pad = torch.zeros(waveform.shape[0], pad_len, dtype=waveform.dtype, device=waveform.device) | |
| waveform = torch.cat([waveform, pad], dim=1) | |
| emb = embedder({"waveform": waveform, "sample_rate": sample_rate}) | |
| speaker_embeddings["SPEAKER_00"] = emb.squeeze().tolist() | |
| except Exception: | |
| pass | |
| return [{ | |
| "start": 0.0 + float(base_offset_s), | |
| "end": duration + float(base_offset_s), | |
| "speaker": "SPEAKER_00" | |
| }], 1, speaker_embeddings | |
| print("Starting diarization...") | |
| start_time = time.time() | |
| # Load audio for diarization | |
| waveform, sample_rate = torchaudio.load(audio_path) | |
| # Perform diarization | |
| diarization = _diarizer( | |
| {"waveform": waveform, "sample_rate": sample_rate}, | |
| num_speakers=num_speakers, | |
| ) | |
| # Convert to list format | |
| diarize_segments = [] | |
| diarization_list = list(diarization.itertracks(yield_label=True)) | |
| print(diarization_list) | |
| for turn, _, speaker in diarization_list: | |
| diarize_segments.append({ | |
| "start": float(turn.start) + float(base_offset_s), | |
| "end": float(turn.end) + float(base_offset_s), | |
| "speaker": speaker | |
| }) | |
| unique_speakers = {segment["speaker"] for segment in diarize_segments} | |
| detected_num_speakers = len(unique_speakers) | |
| # Compute per-speaker embeddings by averaging segment embeddings | |
| speaker_embeddings = {} | |
| try: | |
| embedder = self._load_embedder() | |
| spk_to_embs = {spk: [] for spk in unique_speakers} | |
| # Primary path: slice in-memory waveform and zero-pad short segments | |
| min_embed_duration_sec = 3.0 | |
| audio_duration_sec = float(waveform.shape[1]) / float(sample_rate) | |
| for turn, _, speaker in diarization_list: | |
| seg_start = float(turn.start) | |
| seg_end = float(turn.end) | |
| if seg_end <= seg_start: | |
| continue | |
| start_sample = max(0, int(seg_start * sample_rate)) | |
| end_sample = min(waveform.shape[1], int(seg_end * sample_rate)) | |
| if end_sample <= start_sample: | |
| continue | |
| seg_wav = waveform[:, start_sample:end_sample].contiguous() | |
| min_samples = int(min_embed_duration_sec * sample_rate) | |
| if seg_wav.shape[1] < min_samples: | |
| pad_len = min_samples - seg_wav.shape[1] | |
| pad = torch.zeros(seg_wav.shape[0], pad_len, dtype=seg_wav.dtype, device=seg_wav.device) | |
| seg_wav = torch.cat([seg_wav, pad], dim=1) | |
| try: | |
| emb = embedder({"waveform": seg_wav, "sample_rate": sample_rate}) | |
| except Exception: | |
| # Fallback: use crop on the file with expanded window to minimum duration | |
| desired_end = min(seg_start + min_embed_duration_sec, audio_duration_sec) | |
| desired_start = max(0.0, desired_end - min_embed_duration_sec) | |
| emb = embedder.crop(audio_path, Segment(desired_start, desired_end)) | |
| spk_to_embs[speaker].append(emb.squeeze()) | |
| # average | |
| for spk, embs in spk_to_embs.items(): | |
| if len(embs) == 0: | |
| continue | |
| # stack and mean | |
| try: | |
| import torch as _torch | |
| embs_tensor = _torch.stack([_torch.as_tensor(e) for e in embs], dim=0) | |
| centroid = embs_tensor.mean(dim=0) | |
| # L2 normalize | |
| centroid = centroid / (centroid.norm(p=2) + 1e-12) | |
| speaker_embeddings[spk] = centroid.cpu().tolist() | |
| except Exception: | |
| # fallback to first embedding | |
| speaker_embeddings[spk] = np.asarray(embs[0]).tolist() | |
| #print(speaker_embeddings[spk]) | |
| except Exception as e: | |
| print(f"Error during embedding calculation: {e}") | |
| print(f"Diarization segments: {diarize_segments}") | |
| pass | |
| diarization_time = time.time() - start_time | |
| print(f"Diarization completed in {diarization_time:.2f} seconds") | |
| return diarize_segments, detected_num_speakers, speaker_embeddings | |
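| # Sketch of the outputs (values are hypothetical): diarize_segments is a list like | |
| #   [{"start": 120.0, "end": 127.4, "speaker": "SPEAKER_00"}, ...] in global seconds | |
| # (base_offset_s already added), detected_num_speakers is an int, and speaker_embeddings | |
| # maps each speaker label to an L2-normalized centroid embedding (a plain list of floats). | |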
| def _load_embedder(self): | |
| """Lazy-load speaker embedding inference model on GPU.""" | |
| global _embedder | |
| if _embedder is None: | |
| # window="whole" to compute one embedding per provided chunk | |
| token = os.getenv("HF_TOKEN") | |
| model = Model.from_pretrained("pyannote/embedding", use_auth_token=token) | |
| _embedder = Inference(model, window="whole", device=torch.device("cuda")) | |
| return _embedder | |
| def assign_speakers_to_transcription(self, transcription_results, diarization_segments): | |
| """Assign speakers to words and segments based on overlap with diarization segments. | |
| Also detects diarization segments that do not overlap any transcription segment and | |
| returns them so they can be re-processed (e.g., re-transcribed) later. | |
| """ | |
| if not diarization_segments: | |
| return transcription_results, [] | |
| # Helper: find the diarization speaker active at time t, or closest | |
| def speaker_at(t: float): | |
| for dseg in diarization_segments: | |
| if float(dseg["start"]) <= t < float(dseg["end"]): | |
| return dseg["speaker"] | |
| # if not inside, return closest segment's speaker | |
| closest = None | |
| best_dist = float("inf") | |
| for dseg in diarization_segments: | |
| if t < float(dseg["start"]): | |
| d = float(dseg["start"]) - t | |
| elif t > float(dseg["end"]): | |
| d = t - float(dseg["end"]) | |
| else: | |
| d = 0.0 | |
| if d < best_dist: | |
| best_dist = d | |
| closest = dseg | |
| return closest["speaker"] if closest else "SPEAKER_00" | |
| # Helper: overlap length between two intervals | |
| def interval_overlap(a_start: float, a_end: float, b_start: float, b_end: float) -> float: | |
| return max(0.0, min(a_end, b_end) - max(a_start, b_start)) | |
| # Helper: choose speaker for an interval by maximum overlap with diarization | |
| def best_speaker_for_interval(start_t: float, end_t: float) -> str: | |
| best_spk = None | |
| best_ov = -1.0 | |
| for dseg in diarization_segments: | |
| ov = interval_overlap(float(start_t), float(end_t), float(dseg["start"]), float(dseg["end"])) | |
| if ov > best_ov: | |
| best_ov = ov | |
| best_spk = dseg["speaker"] | |
| if best_ov > 0.0 and best_spk is not None: | |
| return best_spk | |
| # fallback to nearest by midpoint | |
| mid = (float(start_t) + float(end_t)) / 2.0 | |
| return speaker_at(mid) | |
| # First pass: assign speakers to words and apply smoothing | |
| for seg in transcription_results: | |
| if seg.get("words"): | |
| words = seg["words"] | |
| # 1) Initial assignment by overlap | |
| for w in words: | |
| w_start = float(w["start"]) | |
| w_end = float(w["end"]) | |
| w["speaker"] = best_speaker_for_interval(w_start, w_end) | |
| # 2) Small median filter (window=3) to fix isolated outliers | |
| if len(words) >= 3: | |
| smoothed = [words[i]["speaker"] for i in range(len(words))] | |
| for i in range(1, len(words) - 1): | |
| prev_spk = words[i - 1]["speaker"] | |
| curr_spk = words[i]["speaker"] | |
| next_spk = words[i + 1]["speaker"] | |
| if prev_spk == next_spk and curr_spk != prev_spk: | |
| smoothed[i] = prev_spk | |
| for i in range(len(words)): | |
| words[i]["speaker"] = smoothed[i] | |
| else: | |
| # No word timings: choose by overlap with diarization over the whole segment | |
| seg["speaker"] = best_speaker_for_interval(float(seg["start"]), float(seg["end"])) | |
| # Second pass: split segments that have speaker changes within them | |
| split_segments = [] | |
| for seg in transcription_results: | |
| words = seg.get("words", []) | |
| if not words or len(words) <= 1: | |
| # No words or single word - can't split, assign speaker directly | |
| if not words: | |
| seg["speaker"] = best_speaker_for_interval(float(seg["start"]), float(seg["end"])) | |
| else: | |
| seg["speaker"] = words[0].get("speaker", "SPEAKER_00") | |
| split_segments.append(seg) | |
| continue | |
| # Find speaker transition points with minimum duration filter | |
| current_speaker = words[0].get("speaker", "SPEAKER_00") | |
| split_points = [0] # Always start with first word | |
| min_segment_duration = 0.5 # Minimum 0.5 seconds per segment | |
| for i in range(1, len(words)): | |
| word_speaker = words[i].get("speaker", "SPEAKER_00") | |
| if word_speaker != current_speaker: | |
| # Check if this would create a segment that's too short | |
| if split_points: | |
| last_split = split_points[-1] | |
| segment_start_time = float(words[last_split]["start"]) | |
| current_word_time = float(words[i-1]["end"]) | |
| segment_duration = current_word_time - segment_start_time | |
| # Only split if the previous segment would be long enough | |
| if segment_duration >= min_segment_duration: | |
| split_points.append(i) | |
| current_speaker = word_speaker | |
| # If too short, continue without splitting (speaker will be resolved by dominant speaker logic) | |
| else: | |
| split_points.append(i) | |
| current_speaker = word_speaker | |
| split_points.append(len(words)) # End point | |
| # Create sub-segments if we found speaker changes | |
| if len(split_points) <= 2: | |
| # No splits needed - process as single segment | |
| self._assign_dominant_speaker_to_segment(seg, speaker_at, best_speaker_for_interval) | |
| split_segments.append(seg) | |
| else: | |
| # Split into multiple segments | |
| for i in range(len(split_points) - 1): | |
| start_idx = split_points[i] | |
| end_idx = split_points[i + 1] | |
| if end_idx <= start_idx: | |
| continue | |
| subseg_words = words[start_idx:end_idx] | |
| if not subseg_words: | |
| continue | |
| # Calculate segment timing and text from words | |
| subseg_start = float(subseg_words[0]["start"]) | |
| subseg_end = float(subseg_words[-1]["end"]) | |
| subseg_text = " ".join(w.get("word", "").strip() for w in subseg_words if w.get("word", "").strip()) | |
| # Create new sub-segment | |
| new_seg = { | |
| "start": subseg_start, | |
| "end": subseg_end, | |
| "text": subseg_text, | |
| "words": subseg_words, | |
| "duration": subseg_end - subseg_start, | |
| } | |
| # Copy over other fields from original segment if they exist | |
| for key in ["avg_logprob"]: | |
| if key in seg: | |
| new_seg[key] = seg[key] | |
| # Assign dominant speaker to this sub-segment | |
| self._assign_dominant_speaker_to_segment(new_seg, speaker_at, best_speaker_for_interval) | |
| split_segments.append(new_seg) | |
| # Update transcription_results with split segments | |
| transcription_results = split_segments | |
| # Identify diarization segments that have no overlapping transcription segments | |
| unmatched_diarization_segments = [] | |
| for dseg in diarization_segments: | |
| d_start = float(dseg["start"]) | |
| d_end = float(dseg["end"]) | |
| # Calculate total coverage | |
| total_coverage = 0.0 | |
| for s in transcription_results: | |
| overlap = interval_overlap(d_start, d_end, float(s["start"]), float(s["end"])) | |
| total_coverage += overlap | |
| coverage_ratio = total_coverage / (d_end - d_start) | |
| is_well_covered = coverage_ratio >= 0.85 # 85% or more covered | |
| if not is_well_covered and (d_end - d_start)*(1-coverage_ratio) > 1.5: # If poorly covered, add to unmatched list | |
| unmatched_diarization_segments.append({ | |
| "start": d_start, | |
| "end": d_end, | |
| "speaker": dseg["speaker"], | |
| }) | |
| print("unmatched_diarization_segments", unmatched_diarization_segments) | |
| return transcription_results, unmatched_diarization_segments | |
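| # Coverage example (numbers are hypothetical): a 6 s diarization turn overlapped by only | |
| # 3 s of transcription has coverage_ratio = 0.5 < 0.85 and 6 * (1 - 0.5) = 3.0 > 1.5 s of | |
| # uncovered speech, so it is returned as unmatched and later re-transcribed by process_audio. | |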
| def _assign_dominant_speaker_to_segment(self, seg, speaker_at_func, best_speaker_for_interval_func): | |
| """Assign dominant speaker to a segment based on word durations and boundary stabilization.""" | |
| words = seg.get("words", []) | |
| if not words: | |
| # No words: use segment-level overlap | |
| seg["speaker"] = best_speaker_for_interval_func(float(seg["start"]), float(seg["end"])) | |
| return | |
| # 1) Determine dominant speaker by summed word durations | |
| speaker_dur = {} | |
| total_word_dur = 0.0 | |
| for w in words: | |
| dur = max(0.0, float(w["end"]) - float(w["start"])) | |
| total_word_dur += dur | |
| spk = w.get("speaker", "SPEAKER_00") | |
| speaker_dur[spk] = speaker_dur.get(spk, 0.0) + dur | |
| if speaker_dur: | |
| dominant_speaker = max(speaker_dur.items(), key=lambda kv: kv[1])[0] | |
| else: | |
| dominant_speaker = speaker_at_func((float(seg["start"]) + float(seg["end"])) / 2.0) | |
| # 2) Boundary stabilization: relabel tiny prefix/suffix runs to dominant | |
| seg_duration = max(1e-6, float(seg["end"]) - float(seg["start"])) | |
| max_boundary_sec = 0.5 # hard cap for how much to relabel at edges | |
| max_boundary_frac = 0.2 # or up to 20% of the segment duration | |
| # prefix | |
| prefix_dur = 0.0 | |
| prefix_count = 0 | |
| for w in words: | |
| if w.get("speaker") == dominant_speaker: | |
| break | |
| prefix_dur += max(0.0, float(w["end"]) - float(w["start"])) | |
| prefix_count += 1 | |
| if prefix_count > 0 and prefix_dur <= min(max_boundary_sec, max_boundary_frac * seg_duration): | |
| for i in range(prefix_count): | |
| words[i]["speaker"] = dominant_speaker | |
| # suffix | |
| suffix_dur = 0.0 | |
| suffix_count = 0 | |
| for w in reversed(words): | |
| if w.get("speaker") == dominant_speaker: | |
| break | |
| suffix_dur += max(0.0, float(w["end"]) - float(w["start"])) | |
| suffix_count += 1 | |
| if suffix_count > 0 and suffix_dur <= min(max_boundary_sec, max_boundary_frac * seg_duration): | |
| for i in range(len(words) - suffix_count, len(words)): | |
| words[i]["speaker"] = dominant_speaker | |
| # 3) Final segment speaker | |
| seg["speaker"] = dominant_speaker | |
| def group_segments_by_speaker(self, segments, max_gap=1.0, max_duration=30.0): | |
| """Group consecutive segments from the same speaker""" | |
| if not segments: | |
| return segments | |
| grouped_segments = [] | |
| current_group = segments[0].copy() | |
| sentence_end_pattern = r"[.!?]+" | |
| for segment in segments[1:]: | |
| time_gap = segment["start"] - current_group["end"] | |
| current_duration = current_group["end"] - current_group["start"] | |
| # Conditions for combining segments | |
| can_combine = ( | |
| segment["speaker"] == current_group["speaker"] and | |
| time_gap <= max_gap and | |
| current_duration < max_duration and | |
| not re.search(sentence_end_pattern, current_group["text"][-1:]) | |
| ) | |
| if can_combine: | |
| # Merge segments | |
| current_group["end"] = segment["end"] | |
| current_group["text"] += " " + segment["text"] | |
| current_group["words"].extend(segment["words"]) | |
| current_group["duration"] = current_group["end"] - current_group["start"] | |
| else: | |
| # Start new group | |
| grouped_segments.append(current_group) | |
| current_group = segment.copy() | |
| grouped_segments.append(current_group) | |
| # Clean up text | |
| for segment in grouped_segments: | |
| segment["text"] = re.sub(r"\s+", " ", segment["text"]).strip() | |
| #segment["text"] = re.sub(r"\s+([.,!?])", r"\1", segment["text"]) | |
| return grouped_segments | |
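| # Example of the merge rule (times are hypothetical): two consecutive segments from | |
| # SPEAKER_00, one ending at 10.2 s and the next starting at 10.8 s (gap 0.6 s <= max_gap), | |
| # are merged as long as the first does not already end in . ! or ? and the current group | |
| # is still shorter than max_duration seconds. | |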
| def process_audio_transcribe(self, task_json, language=None, | |
| translate=False, prompt=None, batch_size=8, model_name: str = DEFAULT_MODEL): | |
| """Main processing function with diarization using task JSON for a single chunk. | |
| Transcribes full (preprocessed) audio once, performs diarization, merges speakers into transcription. | |
| """ | |
| if not task_json or not str(task_json).strip(): | |
| return {"error": "No JSON provided"} | |
| pre_meta = None | |
| try: | |
| print("Starting new processing pipeline...") | |
| # Step 1: Preprocess per chunk JSON | |
| print("Preprocessing chunk JSON...") | |
| pre_meta = self.preprocess_from_task_json(task_json) | |
| #transcribe_options = pre_meta.get("options", None) | |
| if isinstance(pre_meta, list): | |
| return self.transcribe_segments(pre_meta, language, translate, prompt, batch_size, model_name) | |
| elif isinstance(pre_meta, dict) and "chunk" in pre_meta: | |
| return self.transcribe_chunk(pre_meta, language, translate, prompt, batch_size, model_name) | |
| except Exception as e: | |
| import traceback | |
| traceback.print_exc() | |
| return {"error": f"Processing failed: {str(e)}"} | |
| def transcribe_chunk(self, pre_meta, language=None, | |
| translate=False, prompt=None, batch_size=8, model_name: str = DEFAULT_MODEL): | |
| """Main processing function with diarization using task JSON for a single chunk. | |
| Transcribes full (preprocessed) audio once, performs diarization, merges speakers into transcription. | |
| """ | |
| try: | |
| transcribe_options = pre_meta.get("options", None) | |
| print("Transcribing chunk...") | |
| # Step 1: Preprocess per chunk JSON | |
| if pre_meta["chunk"].get("skip"): | |
| return {"segments": [], "language": "unknown", "num_speakers": 0, "transcription_method": "diarized_segments_batched", "batch_size": batch_size} | |
| wav_path = pre_meta["chunk"]["out_wav_path"] | |
| base_offset_s = float(pre_meta["chunk"].get("abs_start_ms", 0)) / 1000.0 | |
| # Step 2: Transcribe full audio once | |
| transcription_results, detected_language = self.transcribe_full_audio( | |
| wav_path, language, translate, prompt, batch_size, base_offset_s=base_offset_s, engine=(transcribe_options or {}).get("engine", "whisperx"), model_name=model_name, transcribe_options=transcribe_options | |
| ) | |
| # Step 6: Return results | |
| result = { | |
| "segments": transcription_results, | |
| "language": detected_language, | |
| "batch_size": batch_size, | |
| } | |
| # job_id = pre_meta["job_id"] | |
| # task_id = pre_meta["chunk_idx"] | |
| filekey = pre_meta["filekey"]#f"ai-transcribe/split/{job_id}-{task_id}.json" | |
| ret = upload_data_to_r2(json.dumps(result), "intermediate", filekey) | |
| if ret: | |
| return {"filekey": filekey} | |
| else: | |
| return {"error": "Failed to upload to R2"} | |
| except Exception as e: | |
| import traceback | |
| traceback.print_exc() | |
| return {"error": f"Processing failed: {str(e)}"} | |
| finally: | |
| # Clean up preprocessed wav | |
| if pre_meta and pre_meta["chunk"].get("out_wav_path") and os.path.exists(pre_meta["chunk"]["out_wav_path"]): | |
| try: | |
| os.unlink(pre_meta["chunk"]["out_wav_path"]) | |
| except Exception: | |
| pass | |
| def transcribe_segments(self, pre_metas, language=None, | |
| translate=False, prompt=None, batch_size=8, model_name: str = DEFAULT_MODEL): | |
| """Main processing function with diarization using task JSON for a single chunk. | |
| Transcribes full (preprocessed) audio once, performs diarization, merges speakers into transcription. | |
| """ | |
| try: | |
| print("Transcribing segments...") | |
| transcription_results = [] | |
| # Step 1: Preprocess per chunk JSON | |
| for pre_meta in pre_metas: | |
| transcribe_options = pre_meta.get("options", None) | |
| chunk = pre_meta["chunk"] | |
| if chunk.get("skip"): | |
| # Skip silent chunks but keep processing the remaining ones | |
| continue | |
| wav_path = chunk["out_wav_path"] | |
| base_offset_s = float(chunk.get("abs_start_ms", 0)) / 1000.0 | |
| # Step 2: Transcribe full audio once | |
| transcription_result, detected_language = self.transcribe_full_audio( | |
| wav_path, language, translate, prompt, batch_size, base_offset_s=base_offset_s, engine=(transcribe_options or {}).get("engine", "faster_whisper"), model_name=model_name, transcribe_options=transcribe_options | |
| ) | |
| # Step 6: Return results | |
| result = {} | |
| result.update(chunk) | |
| result["segments"] = transcription_result | |
| result["language"] = detected_language | |
| result["batch_size"] = batch_size | |
| transcription_results.append(result) | |
| # job_id = pre_meta["job_id"] | |
| # task_id = pre_meta["chunk_idx"] | |
| filekey = pre_meta["filekey"]#f"ai-transcribe/split/{job_id}-{task_id}.json" | |
| ret = upload_data_to_r2(json.dumps(transcription_results), "intermediate", filekey) | |
| if ret: | |
| return {"filekey": filekey} | |
| else: | |
| return {"error": "Failed to upload to R2"} | |
| except Exception as e: | |
| import traceback | |
| traceback.print_exc() | |
| return {"error": f"Processing failed: {str(e)}"} | |
| finally: | |
| # Clean up preprocessed wav | |
| if pre_metas: | |
| for pre_meta in pre_metas: | |
| chunk = pre_meta["chunk"] | |
| if chunk.get("out_wav_path") and os.path.exists(chunk["out_wav_path"]): | |
| try: | |
| pass | |
| #os.unlink(chunk["out_wav_path"]) | |
| except Exception: | |
| pass | |
| # each call gets a GPU slice | |
| def process_audio_diarization(self, task_json, num_speakers=0): | |
| """Process audio for diarization only, returning speaker information. | |
| Args: | |
| task_json: Task JSON containing audio processing information | |
| num_speakers: Number of speakers (0 for auto-detection) | |
| Returns: | |
| dict: {"filekey": ...} for the uploaded diarization JSON, or {"error": ...} on failure | |
| """ | |
| if not task_json or not str(task_json).strip(): | |
| return {"error": "No JSON provided"} | |
| pre_meta = None | |
| try: | |
| print("Starting diarization-only pipeline...") | |
| # Step 1: Preprocess from task JSON | |
| print("Preprocessing chunk JSON...") | |
| pre_meta = self.preprocess_from_task_json(task_json) | |
| if pre_meta["chunk"].get("skip"): | |
| # Return minimal result for skipped audio | |
| task = json.loads(task_json) | |
| job_id = task.get("job_id", "job") | |
| task_id = str(task["chunk"]["idx"]) | |
| result = { | |
| "num_speakers": 0, | |
| "speaker_embeddings": {} | |
| } | |
| filekey = pre_meta["filekey"]#f"ai-transcribe/split/{job_id}-{task_id}-diarization.json" | |
| ret = upload_data_to_r2(json.dumps(result), "intermediate", filekey) | |
| if ret: | |
| return {"filekey": filekey} | |
| else: | |
| return {"error": "Failed to upload to R2"} | |
| wav_path = pre_meta["chunk"]["out_wav_path"] | |
| base_offset_s = float(pre_meta["chunk"].get("abs_start_ms", 0)) / 1000.0 | |
| # Step 2: Perform diarization | |
| print("Performing diarization...") | |
| start_time = time.time() | |
| diarization_segments, detected_num_speakers, speaker_embeddings = self.perform_diarization( | |
| wav_path, num_speakers if num_speakers > 0 else None, base_offset_s=base_offset_s | |
| ) | |
| diarization_time = time.time() - start_time | |
| print(f"Diarization completed in {diarization_time:.2f} seconds") | |
| # Step 3: Compose JSON response | |
| result = { | |
| "num_speakers": detected_num_speakers, | |
| "speaker_embeddings": speaker_embeddings, | |
| "diarization_segments": diarization_segments, | |
| "idx": pre_meta["chunk"]["chunk_idx"], | |
| "abs_start_ms": pre_meta["chunk"]["abs_start_ms"], | |
| "dur_ms": pre_meta["chunk"]["dur_ms"], | |
| } | |
| if pre_meta.get("channel", None): | |
| result["channel"] = pre_meta["channel"] | |
| # set channel in each diarization segment | |
| for seg in diarization_segments: | |
| seg["channel"] = pre_meta["channel"] | |
| # Step 4: Upload to R2 | |
| #job_id = pre_meta["job_id"] | |
| #task_id = pre_meta["chunk_idx"] | |
| #filekey = f"ai-transcribe/split/{job_id}-{task_id}-diarization.json" | |
| filekey = pre_meta["filekey"] | |
| ret = upload_data_to_r2(json.dumps(result), "intermediate", filekey) | |
| if ret: | |
| # Step 5: Return filekey | |
| return {"filekey": filekey} | |
| else: | |
| return {"error": "Failed to upload to R2"} | |
| except Exception as e: | |
| import traceback | |
| traceback.print_exc() | |
| return {"error": f"Diarization processing failed: {str(e)}"} | |
| finally: | |
| # Clean up preprocessed wav | |
| if pre_meta and pre_meta.get("chunk") and pre_meta["chunk"].get("out_wav_path") and os.path.exists(pre_meta["chunk"]["out_wav_path"]): | |
| try: | |
| os.unlink(pre_meta["chunk"]["out_wav_path"]) | |
| except Exception: | |
| pass | |
| # each call gets a GPU slice | |
| def process_audio(self, task_json, num_speakers=None, language=None, | |
| translate=False, prompt=None, group_segments=True, batch_size=8, model_name: str = DEFAULT_MODEL): | |
| """Main processing function with diarization using task JSON for a single chunk. | |
| Transcribes full (preprocessed) audio once, performs diarization, merges speakers into transcription. | |
| """ | |
| if not task_json or not str(task_json).strip(): | |
| return {"error": "No JSON provided"} | |
| pre_meta = None | |
| try: | |
| print("Starting new processing pipeline...") | |
| # Step 1: Preprocess per chunk JSON | |
| print("Preprocessing chunk JSON...") | |
| pre_meta = self.preprocess_from_task_json(task_json) | |
| if pre_meta["chunk"].get("skip"): | |
| return {"segments": [], "language": "unknown", "num_speakers": 0, "transcription_method": "diarized_segments_batched", "batch_size": batch_size} | |
| wav_path = pre_meta["chunk"]["out_wav_path"] | |
| base_offset_s = float(pre_meta["chunk"].get("abs_start_ms", 0)) / 1000.0 | |
| # Step 3: Perform diarization with global offset | |
| diarization_segments, detected_num_speakers, speaker_embeddings = self.perform_diarization( | |
| wav_path, num_speakers, base_offset_s=base_offset_s | |
| ) | |
| # Convert diarization_segments to clip_timestamps format | |
| # Format: "start,end,start,end,..." with timestamps relative to the file (subtract base_offset_s) | |
| clip_timestamps_list = [] | |
| for seg in diarization_segments: | |
| # Convert global timestamps back to local file timestamps | |
| local_start = max(0.0, float(seg["start"]) - base_offset_s) | |
| local_end = max(local_start, float(seg["end"]) - base_offset_s) | |
| clip_timestamps_list.extend([str(local_start), str(local_end)]) | |
| clip_timestamps = ",".join(clip_timestamps_list) if clip_timestamps_list else None | |
| # Step 2: Transcribe full audio once | |
| transcription_results, detected_language = self.transcribe_full_audio( | |
| wav_path, language, translate, prompt, batch_size, base_offset_s=base_offset_s, clip_timestamps=None, model_name=model_name | |
| ) | |
| unmatched_diarization_segments = [] | |
| # Step 4: Merge diarization into transcription (assign speakers) | |
| transcription_results, unmatched_diarization_segments = self.assign_speakers_to_transcription( | |
| transcription_results, diarization_segments | |
| ) | |
| # Step 4.1: Transcribe diarization-only regions and merge | |
| if unmatched_diarization_segments: | |
| waveform, sample_rate = torchaudio.load(wav_path) | |
| extra_segments = [] | |
| for dseg in unmatched_diarization_segments: | |
| d_start = float(dseg["start"]) # global seconds | |
| d_end = float(dseg["end"]) # global seconds | |
| if d_end <= d_start: | |
| continue | |
| # Map global time to local file time | |
| local_start = max(0.0, d_start - float(base_offset_s)) | |
| local_end = max(local_start, d_end - float(base_offset_s)) | |
| start_sample = max(0, int(local_start * sample_rate)) | |
| end_sample = min(waveform.shape[1], int(local_end * sample_rate)) | |
| if end_sample <= start_sample: | |
| continue | |
| seg_wav = waveform[:, start_sample:end_sample].contiguous() | |
| tmp_f = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) | |
| tmp_path = tmp_f.name | |
| tmp_f.close() | |
| try: | |
| torchaudio.save(tmp_path, seg_wav.cpu(), sample_rate) | |
| seg_transcription, _ = self.transcribe_full_audio( | |
| tmp_path, | |
| language=language, | |
| translate=translate, | |
| prompt=prompt, | |
| batch_size=batch_size, | |
| base_offset_s=d_start, | |
| model_name=model_name | |
| ) | |
| extra_segments.extend(seg_transcription) | |
| finally: | |
| try: | |
| os.unlink(tmp_path) | |
| except Exception: | |
| pass | |
| if extra_segments: | |
| transcription_results.extend(extra_segments) | |
| transcription_results.sort(key=lambda s: float(s.get("start", 0.0))) | |
| # Re-assign speakers on the combined set | |
| transcription_results, _ = self.assign_speakers_to_transcription( | |
| transcription_results, diarization_segments | |
| ) | |
| # Step 5: Group segments if requested | |
| if group_segments: | |
| transcription_results = self.group_segments_by_speaker(transcription_results) | |
| # Step 6: Return results | |
| result = { | |
| "segments": transcription_results, | |
| "language": detected_language, | |
| "num_speakers": detected_num_speakers, | |
| "transcription_method": "diarized_segments_batched", | |
| "batch_size": batch_size, | |
| "speaker_embeddings": speaker_embeddings, | |
| } | |
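| # Illustrative shape of the uploaded result (values are examples only): | |
| # {"segments": [{"start": 12.3, "end": 15.8, "text": "...", "speaker": "SPEAKER_00"}, ...], | |
| #  "language": "en", "num_speakers": 2, "transcription_method": "diarized_segments_batched", | |
| #  "batch_size": 8, "speaker_embeddings": {...}} | |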
| job_id = pre_meta["job_id"] | |
| task_id = pre_meta["chunk_idx"] | |
| filekey = f"ai-transcribe/split/{job_id}-{task_id}.json" | |
| ret = upload_data_to_r2(json.dumps(result), "intermediate", filekey) | |
| if ret: | |
| return {"filekey": filekey} | |
| else: | |
| return {"error": "Failed to upload to R2"} | |
| except Exception as e: | |
| import traceback | |
| traceback.print_exc() | |
| return {"error": f"Processing failed: {str(e)}"} | |
| finally: | |
| # Clean up preprocessed wav | |
| if pre_meta and pre_meta.get("out_wav_path") and os.path.exists(pre_meta["out_wav_path"]): | |
| try: | |
| os.unlink(pre_meta["out_wav_path"]) | |
| except Exception: | |
| pass | |
| # Initialize transcriber | |
| transcriber = WhisperTranscriber() | |
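| # Illustrative only (not wired into the UI): a hypothetical helper showing how a downstream worker could | |
| # retrieve the chunk result that process_audio uploads. It reuses this module's boto3/json imports and the | |
| # S3_ENDPOINT / S3_ACCESS_KEY / S3_SECRET_KEY settings; the function name is made up, and the s3v4 signing | |
| # config is a common choice for R2 rather than something confirmed by the upload helper. Treat as a sketch. | |
| def fetch_chunk_result(job_id, chunk_idx, bucket_name="intermediate"): | |
|     """Download and parse the per-chunk result JSON uploaded by process_audio (illustrative sketch).""" | |
|     s3 = boto3.client( | |
|         "s3", | |
|         endpoint_url=S3_ENDPOINT, | |
|         aws_access_key_id=S3_ACCESS_KEY, | |
|         aws_secret_access_key=S3_SECRET_KEY, | |
|         config=Config(signature_version="s3v4"), | |
|     ) | |
|     # Same key pattern as the upload in process_audio: ai-transcribe/split/{job_id}-{chunk_idx}.json | |
|     obj = s3.get_object(Bucket=bucket_name, Key=f"ai-transcribe/split/{job_id}-{chunk_idx}.json") | |
|     return json.loads(obj["Body"].read().decode("utf-8")) | |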
| def format_segments_for_display(result): | |
| """Format segments for display in Gradio""" | |
| if "error" in result: | |
| return f"β Error: {result['error']}" | |
| segments = result.get("segments", []) | |
| language = result.get("language", "unknown") | |
| num_speakers = result.get("num_speakers", 1) | |
| method = result.get("transcription_method", "unknown") | |
| batch_size = result.get("batch_size", "N/A") | |
| output = f"π― **Detection Results:**\n" | |
| output += f"- Language: {language}\n" | |
| output += f"- Speakers: {num_speakers}\n" | |
| output += f"- Segments: {len(segments)}\n" | |
| output += f"- Method: {method}\n" | |
| output += f"- Batch Size: {batch_size}\n\n" | |
| output += "π **Transcription:**\n\n" | |
| for i, segment in enumerate(segments, 1): | |
| start_time = str(datetime.timedelta(seconds=int(segment["start"]))) | |
| end_time = str(datetime.timedelta(seconds=int(segment["end"]))) | |
| speaker = segment.get("speaker", "SPEAKER_00") | |
| text = segment["text"] | |
| output += f"**{speaker}** ({start_time} β {end_time})\n" | |
| output += f"{text}\n\n" | |
| return output | |
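| # Example of one rendered entry (illustrative values only): | |
| #   **SPEAKER_00** (0:00:05 - 0:00:12) | |
| #   Hello, thanks for joining the call. | |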
| def audio_diarization_task(task_json, num_speakers): | |
| """Gradio interface function""" | |
| result = transcriber.process_audio_diarization( | |
| task_json=task_json, | |
| num_speakers=num_speakers if num_speakers > 0 else 0, | |
| ) | |
| #formatted_output = format_segments_for_display(result) | |
| return "OK", result | |
| def audio_transcribe_task(task_json, num_speakers, language, translate, prompt, group_segments, use_diarization, batch_size, model_name): | |
| """Gradio interface function""" | |
| result = transcriber.process_audio_transcribe( | |
| task_json=task_json, | |
| language=language if language != "auto" else None, | |
| translate=translate, | |
| prompt=prompt if prompt and prompt.strip() else None, | |
| batch_size=batch_size, | |
| model_name=model_name | |
| ) | |
| #formatted_output = format_segments_for_display(result) | |
| return "OK", result | |
| # Create Gradio interface | |
| demo = gr.Blocks( | |
| title="ποΈ Whisper Transcription with Speaker Diarization", | |
| theme="default" | |
| ) | |
| with demo: | |
| gr.Markdown(""" | |
| # Advanced Audio Transcription & Speaker Diarization | |
| Paste a per-chunk task JSON to get accurate transcription with speaker identification, powered by: | |
| - **Faster-Whisper Large V3 Turbo** with batched inference for higher throughput | |
| - **Pyannote 3.1** for speaker diarization | |
| - **ZeroGPU** for on-demand GPU acceleration | |
| """) | |
| with gr.Row(): | |
| with gr.Column(): | |
| task_json_input = gr.Textbox( | |
| label="π§Ύ Paste Task JSON", | |
| placeholder="Paste the per-chunk task JSON here...", | |
| lines=16, | |
| ) | |
| with gr.Accordion("βοΈ Advanced Settings", open=False): | |
| model_name_dropdown = gr.Dropdown( | |
| label="Whisper Model", | |
| choices=list(MODELS.keys()), | |
| value=DEFAULT_MODEL, | |
| info="Select the Whisper model to use for transcription." | |
| ) | |
| use_diarization = gr.Checkbox( | |
| label="Enable Speaker Diarization", | |
| value=True, | |
| info="Uncheck for faster transcription without speaker identification" | |
| ) | |
| batch_size = gr.Slider( | |
| minimum=1, | |
| maximum=128, | |
| value=16, | |
| step=1, | |
| label="Batch Size", | |
| info="Higher values = faster processing but more GPU memory usage. Recommended: 8-24" | |
| ) | |
| num_speakers = gr.Slider( | |
| minimum=0, | |
| maximum=20, | |
| value=0, | |
| step=1, | |
| label="Number of Speakers (0 = auto-detect)", | |
| visible=True | |
| ) | |
| language = gr.Dropdown( | |
| choices=["auto", "en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh"], | |
| value="auto", | |
| label="Language" | |
| ) | |
| translate = gr.Checkbox( | |
| label="Translate to English", | |
| value=False | |
| ) | |
| prompt = gr.Textbox( | |
| label="Vocabulary Prompt (names, acronyms, etc.)", | |
| placeholder="Enter names, technical terms, or context...", | |
| lines=2 | |
| ) | |
| group_segments = gr.Checkbox( | |
| label="Group segments by speaker/time", | |
| value=True | |
| ) | |
| process_btn = gr.Button("π Audio Transcribe Task", variant="primary") | |
| process_btn1 = gr.Button("π Audio Diarization Task", variant="primary") | |
| with gr.Column(): | |
| output_text = gr.Markdown( | |
| label="π Transcription Results", | |
| value="Paste task JSON and click 'Transcribe Audio' to get started!" | |
| ) | |
| output_json = gr.JSON( | |
| label="π§ Raw Output (JSON)", | |
| visible=False | |
| ) | |
| # Update visibility of num_speakers based on diarization toggle | |
| use_diarization.change( | |
| fn=lambda x: gr.update(visible=x), | |
| inputs=[use_diarization], | |
| outputs=[num_speakers] | |
| ) | |
| # Event handlers | |
| process_btn.click( | |
| fn=audio_transcribe_task, | |
| inputs=[ | |
| task_json_input, | |
| num_speakers, | |
| language, | |
| translate, | |
| prompt, | |
| group_segments, | |
| use_diarization, | |
| batch_size, | |
| model_name_dropdown | |
| ], | |
| outputs=[output_text, output_json] | |
| ) | |
| process_btn1.click( | |
| fn=audio_diarization_task, | |
| inputs=[ | |
| task_json_input, | |
| num_speakers | |
| ], | |
| outputs=[output_text, output_json] | |
| ) | |
| # Examples | |
| gr.Markdown("### π Usage Tips:") | |
| gr.Markdown(""" | |
| - Paste a single-chunk task JSON matching the preprocess schema | |
| - Batch Size: Higher values (16-24) = faster but uses more GPU memory | |
| - Speaker diarization: Enable for speaker identification (slower) | |
| - Languages: Whisper auto-detects the language and supports ~100 languages; the dropdown lists common choices | |
| - Vocabulary: Add names and technical terms in the prompt for better accuracy | |
| """) | |
| # Note: WhisperX transcribe models are loaded lazily on first use within GPU context | |
| # This is because @spaces.GPU creates separate contexts, so preloading at startup won't work | |
| print("WhisperX transcribe models will be loaded on first use (lazy loading)...") | |
| if __name__ == "__main__": | |
| demo.launch(debug=True) | |