import os

# Set before the Hub libraries are imported so the flag is actually picked up.
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'

import logging
import math
import re
import shutil
import time

import gradio as gr
import moviepy.editor as mp
import numpy as np
import pysrt
import torch
import yt_dlp
from datasets import load_dataset
from transformers import pipeline
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', force=True)
LOG = logging.getLogger(__name__)

BASEDIR = '/tmp/demo'
os.makedirs(BASEDIR, exist_ok=True)

CLIP_SECONDS = 20
SLICES = 4
# SLICE_DURATION = CLIP_SECONDS / SLICES
# Process at most MAX_CHUNKS clips (45 x 20 s = 15 minutes of audio)
MAX_CHUNKS = 45

SENTENCE_SPLIT = re.compile(r'([^.?!]*[.?!]+)([^.?!].*|$)')
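# In SENTENCE_SPLIT, group 1 is the first complete sentence (everything up to
# and including its terminal punctuation) and group 2 is the remainder, which
# may be empty.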

asr_kwargs = {
    "task": "automatic-speech-recognition",
    "model": "openai/whisper-medium.en"
}

translator_kwargs = {
    "task": "translation_en_to_fr",
    "model": "Helsinki-NLP/opus-mt-en-fr"
}

summarizer_kwargs = {
    "task": "summarization",
    "model": "facebook/bart-large-cnn"
}

if torch.cuda.is_available():
    LOG.info("GPU available")
    asr_kwargs['device'] = 'cuda:0'
    translator_kwargs['device'] = 'cuda:0'
    summarizer_kwargs['device'] = 'cuda:0'

# All three models should fit together on a single T4 GPU
LOG.info("Fetching ASR model from the Hub if not already there")
asr = pipeline(**asr_kwargs)
LOG.info("Fetching translation model from the Hub if not already there")
translator = pipeline(**translator_kwargs)
LOG.info("Fetching summarization model from the Hub if not already there")
summarizer = pipeline(**summarizer_kwargs)
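# Model weights are cached by the Hub client (under ~/.cache/huggingface by
# default), so later startups reuse the local copies instead of re-downloading.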

def demo(url: str, translate: bool):
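    """Full pipeline: download the video, transcribe its audio, optionally
    translate the transcript to French, and return the summary, the SRT file,
    the (video, subtitles) pair for the player and the full transcription."""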
    # Known issue: disk space under BASEDIR is not freed between runs (leak).
    basedir = BASEDIR
    LOG.info("Base directory %s", basedir)
    video_path, video = download(url, os.path.join(basedir, 'video.mp4'))
    audio_clips(video, basedir)
    srt_file, full_transcription, summary = process_video(basedir, video.duration, translate)
    return summary, srt_file, [video_path, srt_file], full_transcription

def download(url, dst):
    LOG.info("Downloading provided url %s", url)
    opts = {
        'skip_download': False,
        'overwrites': True,
        'format': 'mp4',
        'outtmpl': {'default': dst}
    }
    with yt_dlp.YoutubeDL(opts) as dl:
        dl.download([url])
    return dst, mp.VideoFileClip(dst)

def audiodir(basedir):
    return os.path.join(basedir, 'audio')


def audio_clips(video: mp.VideoFileClip, basedir: str):
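    """Split the video's audio track into CLIP_SECONDS-long OGG clips, written
    to <basedir>/audio with zero-padded indices so they sort chronologically."""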
| LOG.info("Building audio clips") | |
| clips_dir = audiodir(basedir) | |
| shutil.rmtree(clips_dir, ignore_errors=True) | |
| os.makedirs(clips_dir, exist_ok=True) | |
| audio = video.audio | |
| end = audio.duration | |
| digits = int(math.log(end / CLIP_SECONDS, 10)) + 1 | |
| for idx, i in enumerate(range(0, int(end), CLIP_SECONDS)): | |
| sub_end = min(i+CLIP_SECONDS, end) | |
| # print(sub_end) | |
| sub_clip = audio.subclip(t_start=i, t_end=sub_end) | |
| audio_file = os.path.join(clips_dir, f"audio_{idx:0{digits}d}" + ".ogg") | |
| # audio_file = os.path.join(AUDIO_CLIPS, "audio_" + str(idx)) | |
| sub_clip.write_audiofile(audio_file, fps=16000) | |

def process_video(basedir: str, duration, translate: bool):
    audio_dir = audiodir(basedir)
    transcriptions = transcription(audio_dir, duration)
    subs = translation(transcriptions, translate)
    srt_file = build_srt_clips(subs, basedir)
    summary = summarize(transcriptions, translate)
    return srt_file, ' '.join([s['text'].strip() for s in subs]).strip(), summary

def transcription(audio_dir: str, duration):
    LOG.info("Audio transcription")
    # Approximate number of clips; it does not need to be exact.
    chunks = int(duration / CLIP_SECONDS + 1)
    chunks = min(chunks, MAX_CHUNKS)
    LOG.debug("Loading audio clips dataset")
    dataset = load_dataset("audiofolder", data_dir=audio_dir)
    dataset = dataset['train']
    dataset = dataset['audio'][0:chunks]
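    # audiofolder resolves files in sorted filename order (hence the zero-padded
    # clip names above), so slicing [0:chunks] keeps the first clips of the video.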
    start = time.time()
    transcriptions = []
    for i, d in enumerate(np.array_split(dataset, 5)):
        d = list(d)
        LOG.info("ASR batch %d / 5, samples %d", i, len(d))
        t = asr(d, max_new_tokens=10000)
        transcriptions.extend(t)
    transcriptions = [
        {
            'text': t['text'].strip(),
            'start': i * CLIP_SECONDS * 1000,
            'end': (i + 1) * CLIP_SECONDS * 1000
        } for i, t in enumerate(transcriptions)
    ]
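    # Start/end timestamps are in milliseconds and derive from the fixed clip
    # length, not from the ASR output itself.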
    if transcriptions:
        transcriptions[0]['start'] += 2500
    # Will improve the translation
    segments = segments_on_sentence_boundaries(transcriptions)
    elapsed = time.time() - start
    LOG.info("Transcription done, elapsed %.2f seconds", elapsed)
    return segments

def segments_on_sentence_boundaries(segments):
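    """Re-cut segment boundaries so that each segment ends on a sentence
    boundary: when a segment stops mid-sentence, the beginning of the next
    segment is pulled into it and the timestamps are shifted proportionally
    to the amount of text that was moved."""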
| LOG.info("Segmenting along sentence boundaries for better translations") | |
| new_segments = [] | |
| i = 0 | |
| while i < len(segments): | |
| s = segments[i] | |
| text = s['text'].strip() | |
| if not text: | |
| i += 1 | |
| continue | |
| if i == len(segments)-1: | |
| new_segments.append(s) | |
| break | |
| next_s = segments[i+1] | |
| next_text = next_s['text'].strip() | |
| if not next_text or (text[-1] in ['.', '?', '!']): | |
| new_segments.append(s) | |
| i += 1 | |
| continue | |
| m = SENTENCE_SPLIT.match(next_s['text'].strip()) | |
| if not m: | |
| LOG.warning("Bad pattern matching on segment [%s], " | |
| "this should not be possible", next_s['text']) | |
| s['end'] = next_s['end'] | |
| s['text'] = '{} {}'.format(s['text'].strip(), next_s['text'].strip()) | |
| new_segments.append(s) | |
| i += 2 | |
| else: | |
| before = m.group(1) | |
| after = m.group(2) | |
| next_segment_duration = next_s['end'] - next_s['start'] | |
| ratio = len(before) / len(next_text) | |
| add_time = int(next_segment_duration * ratio) | |
| s['end'] = s['end'] + add_time | |
| s['text'] = '{} {}'.format(text, before) | |
| next_s['start'] = next_s['start'] + add_time | |
| next_s['text'] = after.strip() | |
| new_segments.append(s) | |
| i += 1 | |
| return new_segments | |

def translation(transcriptions, translate):
    translations_d = []
    if translate:
        LOG.info("Performing translation")
        start = time.time()
        translations = translator([t['text'] for t in transcriptions])
        for i, t in enumerate(transcriptions):
            tsl = t.copy()
            tsl['text'] = translations[i]['translation_text'].strip()
            translations_d.append(tsl)
        elapsed = time.time() - start
        LOG.info("Translation done, elapsed %.2f seconds", elapsed)
        LOG.info('Translations %s', translations_d)
    else:
        translations_d = transcriptions
    return translations_d

def summarize(transcriptions, translate):
    LOG.info("Generating video summary")
    whole_text = ' '.join([t['text'].strip() for t in transcriptions])
    # word_count = len(whole_text.split())
    summary = summarizer(whole_text)
    # Previously: summarizer(whole_text, min_length=word_count // 4 + 1,
    #                        max_length=word_count // 2 + 1)
    summary = translation([{'text': summary[0]['summary_text']}], translate)[0]
    return summary['text']

def segment_slices(subtitles: list[dict]):
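    """Split each subtitle segment into SLICES equal slices, dividing its words
    and its time span evenly, so individual captions stay short on screen."""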
| LOG.info("Building srt segments slices") | |
| slices = [] | |
| for sub in subtitles: | |
| chunks = np.array_split(sub['text'].split(' '), SLICES) | |
| start = sub['start'] | |
| duration = sub['end'] - start | |
| for i in range(0, SLICES): | |
| s = { | |
| 'text': ' '.join(chunks[i]), | |
| 'start': start + i * duration / SLICES, | |
| 'end': start + (i+1) * duration / SLICES | |
| } | |
| slices.append(s) | |
| return slices | |

def build_srt_clips(segments, basedir):
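    """Turn timed text segments into a pysrt SubRipFile and save it as
    <basedir>/video.srt; segments longer than ~45 characters are split in two
    so that captions remain readable."""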
| LOG.info("Generating subtitles") | |
| segments = segment_slices(segments) | |
| LOG.info("Building srt clips") | |
| max_text_len = 45 | |
| subtitles = pysrt.SubRipFile() | |
| for segment in segments: | |
| start = segment['start'] | |
| end = segment['end'] | |
| text = segment['text'] | |
| text = text.strip() | |
| if len(text) < max_text_len: | |
| o = pysrt.SubRipItem() | |
| o.start = pysrt.SubRipTime(0, 0, 0, start) | |
| o.end = pysrt.SubRipTime(0, 0, 0, end) | |
| o.text = text | |
| subtitles.append(o) | |
| else: | |
| # Just split in two, should be ok in most cases | |
| words = text.split() | |
| o = pysrt.SubRipItem() | |
| o.text = ' '.join(words[0:len(words)//2]) | |
| o.start = pysrt.SubRipTime(0, 0, 0, start) | |
| chkpt = (start + end) / 2 | |
| o.end = pysrt.SubRipTime(0, 0, 0, chkpt) | |
| subtitles.append(o) | |
| o = pysrt.SubRipItem() | |
| o.text = ' '.join(words[len(words)//2:]) | |
| o.start = pysrt.SubRipTime(0, 0, 0, chkpt) | |
| o.end = pysrt.SubRipTime(0, 0, 0, end) | |
| subtitles.append(o) | |
| srt_path = os.path.join(basedir, 'video.srt') | |
| subtitles.save(srt_path, encoding='utf-8') | |
| LOG.info("Subtitles saved in srt file %s", srt_path) | |
| return srt_path | |

iface = gr.Interface(
    fn=demo,
    inputs=[
        gr.Text(value="https://youtu.be/tiZFewofSLM", label="English video url"),
        gr.Checkbox(value=True, label='Translate to French')
    ],
    outputs=[
        gr.Text(label="Video summary"),
        gr.File(label="SRT file"),
        gr.Video(label="Video with subtitles"),
        gr.Text(label="Full transcription")
    ])

# iface.launch(server_name="0.0.0.0", server_port=6443)
iface.launch()