import glob
import io
import os
import random
import tarfile
import tempfile
import time
from collections import defaultdict
from io import BytesIO
from typing import Any, Dict, List, Optional, Union

import cv2
import decord
import kaldiio
import librosa
import numpy as np
import PIL
import PIL.Image
import requests
import soundfile as sf
import torch
import whisper
from decord import AudioReader, cpu
from transformers import PretrainedConfig

# Placeholder tokens inserted into the text stream for each media type.
MEDIA_TOKENS = {
    "image": "<image>",
    "video": "<vila/video>",
    "speech": "<speech>",
    "sound": "<sound>",
}

class Media:
    """Base class for media objects."""


class File(Media):
    """File-based media object."""

    def __init__(self, path: str) -> None:
        self.path = path


class Image(File):
    """Image media object."""


class Video(File):
    """Video media object."""


class Speech(File):
    """Speech audio media object."""

    def __init__(self, path, extension: Optional[str] = None) -> None:
        self.path = path
        self.extension = extension


class Sound(File):
    """Sound/music audio media object."""

    def __init__(self, path, extension: Optional[str] = None) -> None:
        self.path = path
        self.extension = extension

def make_list(obj: Any) -> List:
    """Convert an object to a list if it is not already a list."""
    return obj if isinstance(obj, list) else [obj]

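# Example: make_list("hello") -> ["hello"]; make_list(["a", "b"]) is returned unchanged.
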
def _extract_image(image: Union[Image, PIL.Image.Image]) -> PIL.Image.Image:
    """Extract a PIL image from an Image object, or return a PIL image as-is."""
    if isinstance(image, Image):
        if image.path.startswith("http://") or image.path.startswith("https://"):
            image = PIL.Image.open(requests.get(image.path, stream=True).raw)
        else:
            image = PIL.Image.open(image.path)
    return image

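# Illustrative calls (hypothetical path/URL):
#   _extract_image(Image("https://example.com/cat.png"))  # downloaded via requests
#   _extract_image(Image("/data/images/cat.png"))         # opened from disk
#   _extract_image(PIL.Image.new("RGB", (32, 32)))        # PIL images pass through unchanged
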
def _load_video_bytesio(
    video_bytesio: BytesIO, *, num_frames: int, config: PretrainedConfig, load_aud: bool = False
):
    """Load a video held in a BytesIO object by spilling it to a temporary file.

    Returns the same (frames, aud_feature, video_info) tuple as `_load_video`.
    """
    with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
        temp_video.write(video_bytesio.read())
        temp_video.flush()  # ensure the bytes are on disk before cv2/decord read them
        temp_video_name = temp_video.name
        return _load_video(temp_video_name, num_frames=num_frames, load_aud=load_aud, config=config)

def get_overlap(inp1, inp2):
    """
    Calculate the overlapping time span between two intervals, e.g. a video clip
    and an audio segment.

    Args:
        inp1 (list): [start_sec, end_sec]
        inp2 (list): [start_sec, end_sec]

    Returns:
        tuple or None: (overlap_start, overlap_end) if an overlap exists, else None.
    """
    overlap_start = max(inp1[0], inp2[0])
    overlap_end = min(inp1[1], inp2[1])

    if overlap_start < overlap_end:
        return (overlap_start, overlap_end)
    return None

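# Example: get_overlap([0.0, 4.0], [2.5, 9.0]) -> (2.5, 4.0)
#          get_overlap([0.0, 2.0], [3.0, 5.0]) -> None (no overlap)
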
def _load_video(
    video_path: str, *, num_frames: int, config: PretrainedConfig, load_aud: bool = False
):
    """Sample `num_frames` frames (and optionally audio) from a video file or frame directory."""
    # A directory of pre-extracted frames: sample uniformly by filename order.
    if os.path.isdir(video_path):
        frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
        indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int)
        frames = [PIL.Image.open(frame_paths[index]) for index in indices]
        # Frame directories carry no audio track; return the same 3-tuple as the video branch.
        return frames, None, {"video_path": video_path, "expected_frame_count": len(indices), "has_audio": False}

    vidcap = cv2.VideoCapture(video_path)

    # Optionally load the audio track of the same file.
    audio_info = None
    if load_aud:
        try:
            aud_feature, audio_info = _load_speech(video_path, config)
        except Exception:
            aud_feature = None
    else:
        aud_feature = None

    # Find the last readable frame; some containers over-report the frame count.
    frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    while frame_count > 0:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
        if vidcap.grab():
            break
        frame_count -= 1
    else:
        raise ValueError(f"Video '{video_path}' has no frames.")

    # Uniformly sample frame indices over the readable portion of the video.
    indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)

    fps = vidcap.get(cv2.CAP_PROP_FPS)
    video_duration = frame_count / fps

    # Split the video into fixed-duration segments and record, per segment, which
    # sampled frames and which audio STFT frames fall inside it, so visual and
    # audio tokens can be interleaved in time order.
    if config.load_audio_in_video and config.interleaved_vis_aud_in_video and aud_feature is not None:
        segment_duration = config.interleaved_video_segment_duration
        if segment_duration == -1:
            raise ValueError("interleaved_video_segment_duration is not set")

        segment_vis_indices_list = []
        segment_aud_indices_list = []
        segment_counts = np.ceil(video_duration / segment_duration).astype(int)

        if isinstance(aud_feature, dict):
            aud_feas = aud_feature["input_features"]
        else:
            aud_feas = aud_feature
        audio_start_sec = audio_info["audio_start_sec"]
        audio_end_sec = audio_info["audio_end_sample_sec"]

        # Number of STFT frames per second of audio (sampling rate / hop length).
        stft_frames_per_second = config.audio_sampling_rate // config.audio_hop_length

        _idx = 0
        aud_sample_start_idx = 0
        for i in range(segment_counts):
            end_frame = min((i + 1) * segment_duration * fps, frame_count)

            # Visual frame indices that fall inside this segment.
            _indices = []
            while _idx < len(indices) and indices[_idx] < end_frame:
                _indices.append(indices[_idx])
                _idx += 1
            segment_vis_indices_list.append(_indices)

            clip_start_sec = i * segment_duration
            clip_end_sec = min(clip_start_sec + segment_duration, video_duration)

            # Audio STFT frame range that overlaps this segment.
            overlap = get_overlap([clip_start_sec, clip_end_sec], [audio_start_sec, audio_end_sec])
            if overlap is not None:
                aud_sample_end_idx = round((overlap[1] - audio_start_sec) * stft_frames_per_second)
                segment_aud_indices_list.append([aud_sample_start_idx, aud_sample_end_idx])
                aud_sample_start_idx = aud_sample_end_idx
            else:
                segment_aud_indices_list.append([])

    # Decode the sampled frames (deduplicated) and record their timestamps.
    frames = {}
    frame_times = {}
    for index in indices:
        if index in frames:
            continue
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
        success, frame = vidcap.read()
        if not success:
            print(f"Failed to read frame {index} from video '{video_path}'. Skipped.")
            continue
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames[index] = PIL.Image.fromarray(frame)
        frame_times[index] = index / fps

    output_frames = [frames[index] for index in indices if index in frames]
    output_frame_times = [frame_times[index] for index in indices if index in frame_times]

    video_info = {}
    if config.load_audio_in_video and config.interleaved_vis_aud_in_video and aud_feature is not None:
        # Remap the per-segment frame indices from raw video indices to the
        # positions of the frames that were actually decoded.
        new_segment_vis_indices_list = []
        processed_frame_index = 0
        for i, segment_indices in enumerate(segment_vis_indices_list):
            new_segment_vis_indices_list.append([])
            for index in segment_indices:
                if index in frames:
                    new_segment_vis_indices_list[-1].append(processed_frame_index)
                    processed_frame_index += 1
        segment_vis_indices_list = new_segment_vis_indices_list

        video_info["segment_vis_indices_list"] = segment_vis_indices_list
        video_info["segment_aud_indices_list"] = segment_aud_indices_list

    video_info["expected_frame_count"] = len(indices)
    video_info["video_path"] = video_path
    if audio_info is not None:
        audio_info["video_path"] = video_path
    video_info["has_audio"] = aud_feature is not None
    video_info["video_duration"] = video_duration
    video_info["audio_info"] = audio_info
    video_info["video_frame_times"] = output_frame_times

    return output_frames, aud_feature, video_info

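# Note: _load_video returns a 3-tuple (frames, aud_feature, video_info):
#   frames      - list of PIL images sampled uniformly over the video
#   aud_feature - padded waveform returned by _load_speech, or None when audio is absent or not requested
#   video_info  - timing and segment metadata used to interleave visual and audio tokens
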
def _extract_video(video: Video, config: PretrainedConfig):
    num_frames = config.num_video_frames
    aud_fea = None

    if getattr(config, "fps", 0) != 0:
        print("Extracting frames from video with a specified FPS is not supported yet. Ignored.")

    if isinstance(video.path, BytesIO):
        frames, aud_fea, video_info = _load_video_bytesio(
            video.path, num_frames=num_frames, config=config, load_aud=config.load_audio_in_video
        )
    else:
        frames, aud_fea, video_info = _load_video(
            video.path, num_frames=num_frames, config=config, load_aud=config.load_audio_in_video
        )

    if config.load_audio_in_video:
        return frames, aud_fea, video_info
    return frames, video_info

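# Note: the arity of _extract_video's return value depends on the config:
#   config.load_audio_in_video == True  -> (frames, audio_waveform, video_info)
#   config.load_audio_in_video == False -> (frames, video_info)
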
def soundFile_read_audio(audio_file, offset=None, duration=None, dtype="float32"):
    """Read (a slice of) an audio file with soundfile, without resampling."""
    if dtype not in ["int32", "float32"]:
        print("audio dtype must be int32 or float32. Defaulting to float32.")
        dtype = "float32"

    if isinstance(audio_file, bytes):
        audio_file = io.BytesIO(audio_file)
    with sf.SoundFile(audio_file, "r") as f:
        sample_rate = f.samplerate
        if offset is not None and offset > 0:
            f.seek(int(offset * sample_rate))
        if duration is not None and duration > 0:
            samples = f.read(int(duration * sample_rate), dtype=dtype)
        else:
            samples = f.read(dtype=dtype)
    return samples, sample_rate

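# Illustrative call (hypothetical path): read seconds 3.0-5.0 of a clip at its
# native sample rate, as float32 samples:
#   samples, sr = soundFile_read_audio("/data/clip.wav", offset=3.0, duration=2.0)
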
def load_audio_from_tar(tar_file, audio_file):
    """Load an audio member from a tar archive with librosa."""
    with tarfile.open(tar_file, "r") as tar:
        audio_member = tar.getmember(audio_file)
        audio_file = tar.extractfile(audio_member)
        return librosa.load(audio_file)

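# Illustrative call (hypothetical archive and member names):
#   waveform, sr = load_audio_from_tar("shard-000.tar", "clips/utt_0001.wav")
# Note that librosa.load resamples to its default 22050 Hz here since no sr is passed.
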
def _load_audio_file(audio_path: str, config: PretrainedConfig):
    if audio_path is None:
        return None

    dirname = os.path.dirname(audio_path)
    filename = os.path.basename(audio_path)

    # Paths like ".../shard.tar/utt.wav" point at a member inside a tar archive.
    if dirname.endswith(".tar"):
        speech, sample_rate = load_audio_from_tar(dirname, filename)
    else:
        sample_rate = config.audio_sampling_rate
        speech = whisper.load_audio(audio_path, sr=sample_rate)

    return speech, sample_rate

def _load_audio(audio: Union[str, dict], config: PretrainedConfig):
    if isinstance(audio, str):
        return _load_audio_file(audio, config)
    elif isinstance(audio, dict):
        audio_sample = audio["sample"]
        if isinstance(audio_sample, (bytes, io.BytesIO)):
            offset = audio.get("offset", None)
            duration = audio.get("duration", None)
            dtype = audio.get("dtype", "float32")
            return soundFile_read_audio(audio_sample, offset=offset, duration=duration, dtype=dtype)
        elif isinstance(audio_sample, np.ndarray):
            return audio_sample, audio.get("sample_rate")
        else:
            raise ValueError(f"Expected the loaded audio to be a processed numpy array or raw bytes. Got {type(audio_sample)}")
    else:
        raise ValueError(f"Expected input to be a path string or dict. Got {type(audio)}")

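# Illustrative dict inputs (hypothetical values); both return (samples, sample_rate):
#   _load_audio({"sample": wav_bytes, "offset": 1.5, "duration": 4.0}, config)
#   _load_audio({"sample": np.zeros(16000, dtype=np.float32), "sample_rate": 16000}, config)
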
def _whisper_process(audio, sample_rate, audio_chunk_length, max_chunks_per_file):
    """Split audio into fixed-length chunks and compute Whisper log-mel features."""
    outputs = []
    num_audio_chunks = 0

    chunk_length = audio_chunk_length * sample_rate
    for i in range(0, len(audio), chunk_length):
        chunk = audio[i : i + chunk_length]
        chunk = whisper.pad_or_trim(chunk)
        if chunk.dtype != np.float32:
            chunk = chunk.astype(np.float32)
        mel = whisper.log_mel_spectrogram(chunk, n_mels=128)
        num_audio_chunks += 1
        outputs.append(mel)
        if num_audio_chunks == max_chunks_per_file:
            break

    frames = torch.stack(outputs, dim=0)
    return frames.numpy().tolist()

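# Note: with Whisper's defaults (16 kHz audio, hop length 160), each padded/trimmed
# 30 s chunk yields a (128, 3000) log-mel matrix, so the stacked output has shape
# (num_chunks, 128, 3000) before conversion to a nested Python list.
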
def _load_speech(speech, config: PretrainedConfig):
    """Load an audio waveform from a path/dict/BytesIO and pad it for the audio encoder.

    Returns (waveform, audio_info), where `audio_info` records the selected time span.
    """
    if isinstance(speech, str):
        speech_path = speech
    else:
        speech_path = speech.path

    if speech_path is None:
        return None, None
    speech_outputs = []

    # audio_chunk_length may be an int (seconds) or a string such as "max" /
    # "max_<seconds>" meaning "load as much audio as allowed".
    if config.audio_chunk_length and not (isinstance(config.audio_chunk_length, str) and "max" in config.audio_chunk_length):
        try:
            config.audio_chunk_length = int(config.audio_chunk_length)
        except Exception as e:
            print(f"Error setting audio_chunk_length: {e}")
            raise e

    audio_n_samples_limit = config.audio_chunk_length * config.audio_sampling_rate

    def load_wav(speech_path):
        speech, sr = librosa.load(speech_path, sr=config.audio_sampling_rate)
        cur_max_length = speech.shape[0]
        ori_audio_duration = cur_max_length / sr
        return speech, ori_audio_duration

    def get_audio(speech, audio_n_samples):
        """Crop the waveform according to the config and return it with its sample range."""
        if isinstance(speech, decord.audio_reader.AudioReader):
            ori_n_samples = speech.shape[1]
        else:
            ori_n_samples = speech.shape[0]

        audio_start_sample_id = 0
        audio_end_sample_id = ori_n_samples

        load_max_audio = isinstance(config.audio_chunk_length, str) and "max" in config.audio_chunk_length
        if hasattr(config, "random_audio_sample") and not load_max_audio:
            # Randomly crop a window of `audio_n_samples` if the clip is longer.
            if ori_n_samples > audio_n_samples:
                audio_start_sample_id = random.randint(0, ori_n_samples - audio_n_samples)
                audio_end_sample_id = audio_start_sample_id + audio_n_samples
        else:
            if load_max_audio:
                if "_" in config.audio_chunk_length:
                    # "max_<seconds>" caps the loaded audio at <seconds>.
                    max_audio_chunk_length = int(config.audio_chunk_length.split("_")[1])
                    max_audio_n_samples = max_audio_chunk_length * config.audio_sampling_rate
                    audio_n_samples = min(ori_n_samples, max_audio_n_samples)
                    audio_end_sample_id = audio_n_samples
                else:
                    audio_n_samples = ori_n_samples
                    audio_end_sample_id = audio_n_samples
            else:
                audio_end_sample_id = min(audio_n_samples, ori_n_samples)

        if isinstance(speech, decord.audio_reader.AudioReader):
            speech = speech[audio_start_sample_id:audio_end_sample_id].asnumpy()[0]
        else:
            speech = speech[audio_start_sample_id:audio_end_sample_id]

        return speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id

    if isinstance(speech_path, dict):
        if "offset" in speech_path:
            speech, ori_sample_rate = _load_audio(speech_path, config)
        else:
            speech = speech_path["sample"]
            ori_sample_rate = speech_path["sample_rate"]

        speech = librosa.resample(speech, orig_sr=ori_sample_rate, target_sr=config.audio_sampling_rate)
        ori_audio_duration = speech.shape[0] / config.audio_sampling_rate
        speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(speech, audio_n_samples_limit)

    elif isinstance(speech_path, BytesIO):
        if speech.extension == ".wav":
            speech, ori_audio_duration = load_wav(speech_path)
            speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(speech, audio_n_samples_limit)
        else:
            raise ValueError(f"Unsupported audio extension: {speech.extension}")

    elif ".mat" in speech_path or ".ark" in speech_path:
        # Kaldi-format archives.
        rate, speech = kaldiio.load_mat(speech_path)
        speech = librosa.resample(speech, orig_sr=rate, target_sr=config.audio_sampling_rate)
        speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(speech, audio_n_samples_limit)
        ori_audio_duration = speech.shape[0] / config.audio_sampling_rate
    elif ".mp4" in speech_path:
        # Extract the audio track from a video file with decord.
        ar = AudioReader(speech_path, ctx=cpu(0), sample_rate=config.audio_sampling_rate, mono=True)
        cur_max_length = ar.shape[1]
        ori_audio_duration = cur_max_length / config.audio_sampling_rate
        speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(ar, audio_n_samples_limit)
    else:
        assert os.path.exists(speech_path), f"File {speech_path} does not exist"
        speech, ori_audio_duration = load_wav(speech_path)
        speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(speech, audio_n_samples_limit)

    # Pad the selected waveform up to a whole number of 30-second windows, as
    # expected by the Whisper-style feature extractor.
    speech = speech.astype(np.float32)
    audio_n_samples = int(np.ceil(speech.shape[0] / (config.audio_sampling_rate * 30)) * (config.audio_sampling_rate * 30))
    speech = whisper.pad_or_trim(speech, length=audio_n_samples)

    new_audio_chunk_length = int(audio_n_samples // config.audio_sampling_rate)
    audio_start_sec = audio_start_sample_id / config.audio_sampling_rate
    audio_end_sample_sec = audio_end_sample_id / config.audio_sampling_rate

    audio_info = {}
    audio_info["new_audio_chunk_length"] = new_audio_chunk_length
    audio_info["new_audio_n_samples"] = audio_n_samples
    audio_info["ori_audio_duration"] = ori_audio_duration
    audio_info["audio_start_sec"] = audio_start_sec
    audio_info["audio_end_sample_sec"] = audio_end_sample_sec

    return speech, audio_info

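# Note: the audio_start_sec / audio_end_sample_sec entries in audio_info describe which
# portion of the source audio was kept; _load_video uses them to align audio STFT frames
# with video segments when interleaving is enabled.
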
def _extract_speech(speech: Speech, config: PretrainedConfig):
    frames, audio_info = _load_speech(speech, config)
    return frames, audio_info


# Sound (general audio/music) uses the same loading path as speech.
_extract_sound = _extract_speech

def extract_media(
    messages: List[Dict[str, Any]],
    config: Optional[PretrainedConfig] = None,
    draft: bool = False,
) -> Dict[str, List[Any]]:
    """Pull media objects out of conversation messages, replacing them with media tokens.

    When `draft` is True, media objects are collected without being decoded.
    """
    media = defaultdict(list)

    if not hasattr(config, "load_audio_in_video"):
        print("Warning: load_audio_in_video not in config, set to False")
        config.load_audio_in_video = False

    for message in messages:
        text = ""
        for part in make_list(message["value"]):
            if isinstance(part, str):
                # Strip any media tokens written directly into the text; tokens are
                # inserted only where actual media objects appear.
                for token in MEDIA_TOKENS.values():
                    if token in part:
                        print(f"Media token '{token}' found in text: '{part}'. Removed.")
                        part = part.replace(token, "").strip()
                text += part
            elif isinstance(part, (Image, PIL.Image.Image)):
                if draft:
                    media["image"].append(part)
                else:
                    media["image"].append(_extract_image(part))
                text += MEDIA_TOKENS["image"]
            elif isinstance(part, Video):
                if draft:
                    media["video"].append(part)
                else:
                    if config.load_audio_in_video:
                        output, aud_fea, video_info = _extract_video(part, config)
                        media["video"].append(output)
                        media["video_info"].append(video_info)
                        if aud_fea is not None:
                            media["sound"].append(aud_fea)
                            media["audio_info"].append(video_info["audio_info"])
                            text += MEDIA_TOKENS["sound"]
                    else:
                        output, video_info = _extract_video(part, config)
                        media["video"].append(output)
                        media["video_info"].append(video_info)
                text += MEDIA_TOKENS["video"]
            elif isinstance(part, Speech):
                if draft:
                    if config.unified_audio_encoder:
                        media["sound"].append(part)
                        text += MEDIA_TOKENS["sound"]
                    else:
                        media["speech"].append(part)
                        text += MEDIA_TOKENS["speech"]
                else:
                    output, audio_info = _extract_speech(part, config)
                    if output is not None:
                        if config.unified_audio_encoder:
                            media["sound"].append(output)
                            text += MEDIA_TOKENS["sound"]
                        else:
                            media["speech"].append(output)
                            text += MEDIA_TOKENS["speech"]
                        media["audio_info"].append(audio_info)
            elif isinstance(part, Sound):
                if draft:
                    media["sound"].append(part)
                    text += MEDIA_TOKENS["sound"]
                else:
                    output, audio_info = _extract_sound(part, config)
                    if output is not None:
                        media["sound"].append(output)
                        media["audio_info"].append(audio_info)
                        text += MEDIA_TOKENS["sound"]
            else:
                print(f"part: {part}")
                raise ValueError(f"Unsupported prompt part type: {type(part)}")
        message["value"] = text
    return media

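# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the config below is a bare
# PretrainedConfig with hypothetical values and the image path is a placeholder.
# With draft=True the media is collected but not decoded, so nothing is read
# from disk.
if __name__ == "__main__":
    example_config = PretrainedConfig()
    example_config.load_audio_in_video = False

    example_messages = [
        {"from": "human", "value": [Image("/path/to/example.jpg"), "Describe this image."]},
    ]
    example_media = extract_media(example_messages, config=example_config, draft=True)
    print(example_media["image"])        # -> [Image object]
    print(example_messages[0]["value"])  # -> "<image>Describe this image."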