| """ | |
| Gradio app wrapping your diarization + separation + enhancement + transcription pipeline. | |
| """ | |
| import os | |
| import tempfile | |
| import math | |
| import json | |
| import shutil | |
| import time | |
| from datetime import timedelta | |
| from pathlib import Path | |
| from typing import List, Tuple | |
| import re | |
| import numpy as np | |
| import soundfile as sf | |
| import librosa | |
| import noisereduce as nr | |
| import gradio as gr | |
| # Lazy imports (heavy models) will be done inside the worker function | |
| # to keep the app responsive on startup. | |
# -----------------------
# Configuration defaults
# -----------------------
SAMPLE_RATE = 16000
CHUNK_DURATION = 8.0
KEYWORDS = ["red", "yellow", "green"]
HF_TOKEN_E = os.environ.get("HF_TOKEN_E")

# -----------------------
# Helper utilities
# -----------------------
def time_to_samples(t: float, sr: int) -> int:
    return int(round(t * sr))


def save_wav(path: str, data: np.ndarray, sr: int = SAMPLE_RATE):
    sf.write(path, data.astype(np.float32), sr)
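
# e.g. time_to_samples(2.5, 16000) -> 40000; save_wav() casts samples to float32
# before handing them to soundfile.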

# -----------------------
# Transcription helpers
# -----------------------
def transcribe_audio_array_with_whisper(audio: np.ndarray, sr: int, whisper_model) -> dict:
    """Whisper expects a file path; write to a temp wav, then transcribe."""
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    try:
        sf.write(tmp.name, audio.astype(np.float32), sr)
        res = whisper_model.transcribe(tmp.name, task="transcribe", fp16=False, language=None)
        return res
    except Exception:
        return {"text": "", "segments": []}
    finally:
        try:
            tmp.close()
            os.unlink(tmp.name)
        except Exception:
            pass


def transcribe_file_with_whisper(wav_path: str, whisper_model) -> dict:
    try:
        res = whisper_model.transcribe(wav_path, task="transcribe", fp16=False, language=None)
        return res
    except Exception:
        return {"text": "", "segments": []}
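
# For reference: whisper_model.transcribe() returns a dict shaped roughly like
# {"text": "...", "segments": [{"start": 0.0, "end": 2.4, "text": "..."}, ...]}
# (segments carry extra fields that this app ignores); the keyword scan later
# relies on those per-segment start/end timestamps.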

# -----------------------
# Keyword finder
# -----------------------
def find_keywords_in_text(text: str, keywords: List[str]) -> List[Tuple[str, int]]:
    found = []
    for kw in keywords:
        for match in re.finditer(rf"\b{re.escape(kw)}\b", text, flags=re.IGNORECASE):
            found.append((kw, match.start()))
    return found
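
# Example (illustrative only): scanning one line of transcript text for the default keywords,
# find_keywords_in_text("Red light, then GREEN light", KEYWORDS)
# returns [("red", 0), ("green", 16)]: one (keyword, character offset) pair per
# case-insensitive, whole-word match.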

# -----------------------
# Main pipeline (wrapped for Gradio streaming)
# -----------------------
def pipeline_worker(video_file_path: str, keywords: List[str]):
    """
    Generator that yields progress logs and, as its final item,
    (log, file_list, keyword_log, transcripts_json_path).
    The Gradio interface calls this function and streams the logs.
    """
    # Prepare a temporary output directory per run
    run_dir = tempfile.mkdtemp(prefix="diarize_run_")
    out_dir = os.path.join(run_dir, "out")
    os.makedirs(out_dir, exist_ok=True)

    logs = []

    def emit(message: str):
        logs.append(message)
        yield "\n".join(logs), "", "", ""

    # 1) Convert mp4 to wav (using moviepy)
    yield from emit(f"Starting run — saving outputs to: {out_dir}")
    try:
        from moviepy.editor import VideoFileClip
    except Exception as e:
        yield from emit(f"ERROR: moviepy import failed: {e}")
        return

    wav_path = os.path.join(run_dir, "input_audio.wav")
    try:
        yield from emit("Extracting audio from video...")
        clip = VideoFileClip(video_file_path)
        clip.audio.write_audiofile(wav_path, codec="pcm_s16le")
        clip.close()
        yield from emit(f"Saved extracted audio: {wav_path}")
    except Exception as e:
        yield from emit(f"ERROR extracting audio: {e}")
        return
    # 2) Load audio (librosa)
    try:
        y, sr = librosa.load(wav_path, sr=SAMPLE_RATE, mono=True)
        duration = len(y) / sr
        yield from emit(f"Loaded audio: {duration:.1f}s @ {sr}Hz")
    except Exception as e:
        yield from emit(f"ERROR loading audio: {e}")
        return

    # Lazy-load the heavy models
    yield from emit("Loading diarization & embedding models (this can take a while)...")
    try:
        from pyannote.audio import Pipeline, Model
        # diarize_pipeline = Pipeline.from_pretrained("pyannote/[email protected]", use_auth_token=HF_TOKEN_E)
        diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=HF_TOKEN_E)
        embedding_model = Model.from_pretrained("pyannote/embedding", use_auth_token=HF_TOKEN_E)
        yield from emit("pyannote models loaded.")
    except Exception as e:
        yield from emit(f"WARNING: pyannote models failed to load: {e}\nDiarization may not work.")
        diarize_pipeline = None
        embedding_model = None

    # Lazily load separation & enhancement models (speechbrain)
    try:
        from speechbrain.pretrained import SepformerSeparation as Sepformer
        from speechbrain.pretrained import SpectralMaskEnhancement as Enhancer
        sepformer = Sepformer.from_hparams(source="speechbrain/sepformer-whamr", savedir=os.path.join(run_dir, "tmp_speechbrain_sepformer"))
        enhancer = Enhancer.from_hparams(source="speechbrain/metricgan-plus-voicebank", savedir=os.path.join(run_dir, "tmp_speechbrain_enh"))
        yield from emit("Speechbrain sepformer + enhancer loaded.")
    except Exception as e:
        yield from emit(f"WARNING: speechbrain models failed to load: {e}\nSeparation/enhancement fallbacks will be used.")
        sepformer = None
        enhancer = None

    # Lazily load the Whisper model
    try:
        import whisper
        whisper_model = whisper.load_model("large-v3", device="cpu")
        yield from emit("Whisper loaded (large-v3) on CPU.")
    except Exception as e:
        yield from emit(f"ERROR loading Whisper model: {e}")
        whisper_model = None
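
    # Note: the pyannote models above are gated on the Hugging Face Hub; HF_TOKEN_E must
    # belong to an account that has accepted their access conditions, otherwise loading
    # falls through to the WARNING branch above and diarization is skipped.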

    # Run diarization
    if diarize_pipeline is None:
        yield from emit("Skipping diarization (pipeline unavailable). Creating a single 'SPEAKER_0' segment covering the full audio.")
        diarization = None
        speakers = ["SPEAKER_0"]
    else:
        yield from emit("Running diarization... This may take a while.")
        try:
            diarization = diarize_pipeline({"audio": wav_path})
            speakers = sorted({label for segment, track, label in diarization.itertracks(yield_label=True)})
            yield from emit(f"Detected speakers: {speakers}")
        except Exception as e:
            yield from emit(f"ERROR during diarization: {e}")
            diarization = None
            speakers = ["SPEAKER_0"]

    # Prepare per-speaker audio buffers
    speaker_buffers = {sp: [] for sp in speakers}
    transcriptions = []

    # Helper to compute an embedding from a numpy audio array (if the model is available)
    def embedding_from_audio(audio_np: np.ndarray):
        if embedding_model is None:
            return np.zeros((1, 256))
        waveform = audio_np.reshape(1, -1)
        try:
            emb = embedding_model({'waveform': waveform, 'sample_rate': SAMPLE_RATE})
            return emb.data.numpy().reshape(1, -1)
        except Exception:
            # Embedding APIs differ across pyannote versions; fall back to a zero vector
            # so that similarity matching degrades gracefully instead of crashing.
            return np.zeros((1, 256))

    # Iterate over the diarized segments (or the single fallback segment)
    yield from emit("Processing diarized segments (separation/enhancement/transcription)...")
    if diarization is None:
        segments_iter = [(0.0, duration, "SPEAKER_0")]
    else:
        segments_iter = [(seg.start, seg.end, lbl) for seg, _, lbl in diarization.itertracks(yield_label=True)]
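
    # segments_iter holds (start_seconds, end_seconds, speaker_label) tuples,
    # e.g. (12.3, 17.8, "SPEAKER_01"); the labels come straight from pyannote.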

    for idx, (start, end, label) in enumerate(segments_iter):
        seg_dur = end - start
        a_samp = time_to_samples(start, sr)
        b_samp = time_to_samples(end, sr)
        seg_audio = y[a_samp:b_samp]
        yield from emit(f"Segment {idx+1}/{len(segments_iter)}: {label} [{start:.2f}-{end:.2f}] ({seg_dur:.2f}s)")

        # Detect overlaps (simple pairwise check against every other diarized track)
        is_overlap = False
        if diarization is not None:
            overlapped_labels = [lbl for s2, _, lbl in diarization.itertracks(yield_label=True) if s2.start < end and s2.end > start and lbl != label]
            is_overlap = len(overlapped_labels) > 0

        # Non-overlapping & short segment => enhance and append directly
        if not is_overlap and seg_dur <= CHUNK_DURATION:
            # Attempt the neural enhancer first; fall back to spectral noise reduction
            try:
                if enhancer is not None:
                    import torch
                    wav_tensor = torch.tensor(seg_audio).float().unsqueeze(0)
                    enhanced = enhancer.enhance_batch(wav_tensor).squeeze(0).detach().cpu().numpy()
                else:
                    raise Exception("enhancer unavailable")
            except Exception:
                enhanced = nr.reduce_noise(y=seg_audio, sr=sr)
            speaker_buffers[label].append(enhanced.flatten())

            # Transcribe the enhanced segment
            if whisper_model is not None:
                try:
                    res = transcribe_audio_array_with_whisper(enhanced, sr, whisper_model)
                    transcript_text = res.get("text", "").strip()
                except Exception:
                    transcript_text = "[Transcription failed]"
            else:
                transcript_text = "[Whisper unavailable]"
            transcriptions.append({
                "speaker": label,
                "start": float(start),
                "end": float(end),
                "duration": float(seg_dur),
                "text": transcript_text,
            })
        else:
            # Overlapped or long segment: chunk, separate, embed, and match to speaker prototypes
            samples = seg_audio
            n_chunks = max(1, math.ceil(len(samples) / int(CHUNK_DURATION * sr)))
            chunk_size = int(len(samples) / n_chunks)
            for i in range(n_chunks):
                a = i * chunk_size
                b = min(len(samples), (i + 1) * chunk_size)
                chunk = samples[a:b]
                if len(chunk) < 100:
                    continue

                # Try Sepformer separation: in-memory batch API first, then the file-based API
                est_sources = None
                try:
                    if sepformer is not None:
                        import torch
                        try:
                            mix = torch.tensor(chunk, dtype=torch.float32).unsqueeze(0)
                            est = sepformer.separate_batch(mix)
                        except Exception:
                            tmpf = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
                            sf.write(tmpf.name, chunk, sr)
                            est = sepformer.separate_file(tmpf.name)
                            tmpf.close()
                            os.unlink(tmpf.name)
                        # est has shape (batch, time, n_sources); split it into per-source numpy arrays
                        est_sources = [est[0, :, i].detach().cpu().numpy() for i in range(est.shape[-1])]
                except Exception:
                    est_sources = None
                if est_sources is None:
                    # Conservative fallback when separation is unavailable: duplicate the mixture
                    est_sources = [chunk, chunk]

                # Compute an embedding for each estimated source
                embeddings = []
                for src in est_sources:
                    try:
                        emb = embedding_from_audio(np.asarray(src).flatten())
                    except Exception:
                        emb = np.zeros((1, 256))
                    embeddings.append(emb)

                # Speaker prototypes built from audio already assigned to each speaker
                speaker_protos = {}
                for sp in speakers:
                    if len(speaker_buffers[sp]) > 0:
                        ex = np.concatenate([np.asarray(p).flatten() for p in speaker_buffers[sp][:1]])
                        speaker_protos[sp] = embedding_from_audio(ex)
                    else:
                        speaker_protos[sp] = None

                for src_idx, emb in enumerate(embeddings):
                    best_sp, best_sim = None, -1
                    for sp in speakers:
                        proto = speaker_protos[sp]
                        if proto is None:
                            continue
                        try:
                            from sklearn.metrics.pairwise import cosine_similarity
                            sim = cosine_similarity(emb, proto)[0, 0]
                        except Exception:
                            sim = -1
                        if sim > best_sim:
                            best_sim = sim
                            best_sp = sp
                    assign_to = best_sp if best_sp is not None else speakers[src_idx % len(speakers)]
                    speaker_buffers[assign_to].append(np.asarray(est_sources[src_idx]).flatten())

                    # Transcribe the separated chunk
                    if whisper_model is not None:
                        try:
                            res = transcribe_audio_array_with_whisper(np.asarray(est_sources[src_idx]).flatten(), sr, whisper_model)
                            transcript_text = res.get("text", "").strip()
                        except Exception:
                            transcript_text = "[Transcription failed]"
                    else:
                        transcript_text = "[Whisper unavailable]"
                    transcriptions.append({
                        "speaker": assign_to,
                        "start": float(start + a / sr),
                        "end": float(start + b / sr),
                        "duration": float((b - a) / sr),
                        "text": transcript_text,
                    })

        # Emit progress after each segment
        yield from emit(f"Processed segment {idx+1}/{len(segments_iter)}")
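
    # At this point every separated source has been routed to the speaker whose prototype
    # embedding it most resembled (cosine similarity), or assigned round-robin when no
    # prototype existed yet.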

    # After processing all segments: write per-speaker concatenated wavs
    yield from emit("Concatenating speaker buffers and saving speaker wav files...")
    generated_files = []
    for sp, pieces in speaker_buffers.items():
        if len(pieces) == 0:
            continue
        out = np.concatenate([np.asarray(p).flatten() for p in pieces])
        out_path = os.path.join(out_dir, f"{sp}.wav")
        save_wav(out_path, out, sr)
        generated_files.append(out_path)
        yield from emit(f"Saved speaker file: {out_path}")

    # Build a residual noise track (approximate reconstruction: speaker pieces are laid
    # down sequentially from a running cursor rather than at their original time offsets)
    yield from emit("Building residual noise track...")
    recon = np.zeros_like(y)
    cursor = 0
    for sp, pieces in speaker_buffers.items():
        if len(pieces) == 0:
            continue
        recon_piece = np.concatenate([np.asarray(p).flatten() for p in pieces])
        length = min(len(recon_piece), len(recon) - cursor)
        if length <= 0:
            continue
        recon[cursor:cursor+length] += recon_piece[:length]
        cursor += length
    residual = y - recon
    residual_path = os.path.join(out_dir, "noise_residual.wav")
    save_wav(residual_path, residual, sr)
    generated_files.append(residual_path)
    yield from emit(f"Saved residual: {residual_path}")

    # Save the timestamped transcriptions collected above
    transcript_file = os.path.join(out_dir, "timestamped_transcriptions.json")
    with open(transcript_file, "w", encoding="utf-8") as f:
        json.dump(transcriptions, f, indent=2, ensure_ascii=False)
    generated_files.append(transcript_file)
    yield from emit(f"Saved timestamped transcriptions: {transcript_file}")

    # Second pass: run Whisper on each speaker file to produce detailed per-speaker transcripts
    yield from emit("Running final Whisper pass on each speaker file to produce detailed transcripts...")
    detailed_paths = []
    for sp in speakers:
        sp_wav_path = os.path.join(out_dir, f"{sp}.wav")
        if not os.path.exists(sp_wav_path):
            continue
        if whisper_model is not None:
            res = transcribe_file_with_whisper(sp_wav_path, whisper_model)
            text = res.get("text", "").strip()
            segments = res.get("segments", [])
        else:
            text = ""
            segments = []
        json_path = os.path.join(out_dir, f"{sp}_transcript.json")
        with open(json_path, "w", encoding="utf-8") as fj:
            json.dump({"speaker": sp, "text": text, "segments": segments}, fj, indent=2, ensure_ascii=False)
        detailed_paths.append(json_path)
        generated_files.append(json_path)
        yield from emit(f"Saved detailed JSON: {json_path}")

    # Keyword scanning
    yield from emit("Scanning transcripts for keywords...")
    keyword_log_lines = []
    for sp in speakers:
        json_path = os.path.join(out_dir, f"{sp}_transcript.json")
        if not os.path.exists(json_path):
            continue
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        text = data.get("text", "")
        segments = data.get("segments", [])
        if segments:
            for seg in segments:
                seg_text = seg.get("text", "")
                seg_start = seg.get("start", 0)
                seg_end = seg.get("end", 0)
                hits = find_keywords_in_text(seg_text, keywords)
                if hits:
                    s_td = str(timedelta(seconds=float(seg_start)))
                    e_td = str(timedelta(seconds=float(seg_end)))
                    line = f"Speaker: {sp} [{s_td} --> {e_td}] Text: {seg_text.strip()}"
                    keyword_log_lines.append(line)
        else:
            hits = find_keywords_in_text(text, keywords)
            if hits:
                line = f"Speaker: {sp} [No segment timestamps available] Excerpt: {text.strip()[:200]}"
                keyword_log_lines.append(line)
    if len(keyword_log_lines) == 0:
        keyword_log = "No keyword matches found."
    else:
        keyword_log = "\n".join(keyword_log_lines)
    yield from emit("Keyword scan complete.")

    # Final yield: full log, newline-separated list of generated files, keyword matches,
    # and the path to the timestamped transcription JSON
    file_list_text = "\n".join(generated_files)
    yield "\n".join(logs), file_list_text, keyword_log, transcript_file

# -----------------------
# Gradio UI
# -----------------------
def build_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# Voice Analysis (Diarisation and Signal Identification)\nUpload an MP4 and click Run to start analysis.")
        with gr.Row():
            video_in = gr.Video(label="Input video (.mp4)")
            keywords_in = gr.Textbox(value=",".join(KEYWORDS), label="Keywords (comma separated)")
            run_btn = gr.Button("Run")
        with gr.Row():
            logs_out = gr.Textbox(label="Progress logs", lines=20)
            files_out = gr.Textbox(label="Generated files (saved in temp run folder)", lines=20)
        with gr.Row():
            keywords_out = gr.Textbox(label="Keyword matches (console-style)", lines=5)
            transcript_json_out = gr.Textbox(label="Timestamped transcript JSON path")

        # JSON viewer for transcript preview
        with gr.Accordion("📜 View Detailed Transcript JSON", open=False):
            transcript_view = gr.JSON(label="Transcript Data (Timestamps + Text)")

        # Open and display the transcript JSON file named in the textbox
        def open_transcript_json(json_path):
            if not json_path or not os.path.exists(json_path):
                return {"error": "File not found"}
            try:
                with open(json_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                return data
            except Exception as e:
                return {"error": str(e)}

        view_btn = gr.Button("Open Transcript JSON")
        view_btn.click(fn=open_transcript_json, inputs=transcript_json_out, outputs=transcript_view)

        def run_and_stream(video_path, keywords_text):
            keys = [k.strip() for k in keywords_text.split(",") if k.strip()]
            yield from pipeline_worker(video_path, keys)

        run_btn.click(fn=run_and_stream, inputs=[video_in, keywords_in], outputs=[logs_out, files_out, keywords_out, transcript_json_out])
    return demo
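
# Streaming note (an assumption about the runtime, not something this app pins): on older
# Gradio releases, generator-backed callbacks only stream if the request queue is enabled,
# e.g.
#
#     app = build_interface()
#     app.queue()   # Gradio 4.x enables the queue by default, so this may be unnecessary
#     app.launch()
#
# If progress logs only appear once the run finishes, this is the first thing to check.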

app = build_interface()

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)