import spaces
import os

import gradio as gr
import gc
import soundfile as sf
import shutil
import argparse
from moviepy.tools import verbose_print
from omegaconf import OmegaConf
import random
import numpy as np
import json
import librosa
import emage.mertic
from datetime import datetime
from decord import VideoReader
from PIL import Image
import copy
import cv2
import subprocess

import importlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
import smplx
from moviepy.editor import VideoFileClip, AudioFileClip, ImageSequenceClip
import igraph

import utils.rotation_conversions as rc
from utils.video_io import save_videos_from_pil
from utils.genextend_inference_utils import adjust_statistics_to_match_reference
from create_graph import path_visualization, graph_pruning, get_motion_reps_tensor, path_visualization_v2


def search_path_dp(graph, audio_low_np, audio_high_np, loop_penalty=0.1, top_k=1, search_mode="both", continue_penalty=0.1):
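    """Find the motion-graph path that best matches the audio features via dynamic programming.

    For each of the T audio frames, keeps the `top_k` cheapest states per graph node,
    each a tuple (cost, prev_node_index, prev_tuple_index, non_continue_count,
    visited_nodes). The per-step cost is the cosine distance between the audio and
    node motion embeddings (low-level, high-level, or both, depending on
    `search_mode`); revisiting a node adds loop_penalty * exp(visit_count), and each
    step also pays continue_penalty per non-"continue" edge taken so far.

    Returns (optimal_paths, is_continue_lists): the top_k paths as lists of igraph
    vertices, and the per-transition `is_continue` edge flags for each path.
    """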
    T = audio_low_np.shape[0]   # number of audio frames (time steps)
    N = len(graph.vs)           # number of graph nodes

    # min_cost[t][node_index] holds up to top_k tuples of
    # (cost, prev_node_index, prev_tuple_index, non_continue_count, visited_nodes)
    min_cost = [{} for _ in range(T)]

    # initialize the DP table at t=0 from nodes that can start a sequence
    start_nodes = [v for v in graph.vs if v['previous'] is None or v['previous'] == -1]
    for node in start_nodes:
        node_index = node.index
        motion_low = node['motion_low']       # per-node feature vector
        motion_high = node['motion_high']     # per-node feature vector

        # cosine distance between audio and motion features (features are L2-normalized)
        if search_mode == "both":
            cost = 2 - (np.dot(audio_low_np[0], motion_low.T) + np.dot(audio_high_np[0], motion_high.T))
        elif search_mode == "high_level":
            cost = 1 - np.dot(audio_high_np[0], motion_high.T)
        elif search_mode == "low_level":
            cost = 1 - np.dot(audio_low_np[0], motion_low.T)

        visited_nodes = {node_index: 1}   # node index -> number of visits so far

        min_cost[0][node_index] = [(cost, None, None, 0, visited_nodes)]

    # forward DP over time steps
    for t in range(1, T):
        for node in graph.vs:
            node_index = node.index
            candidates = []

            # relax over all edges entering this node
            incoming_edges = graph.es.select(_to=node_index)
            for edge in incoming_edges:
                prev_node_index = edge.source
                edge_id = edge.index
                is_continue_edge = graph.es[edge_id]['is_continue']
                prev_node = graph.vs[prev_node_index]
                if prev_node_index in min_cost[t-1]:
                    for tuple_index, (prev_cost, _, _, prev_non_continue_count, prev_visited) in enumerate(min_cost[t-1][prev_node_index]):
                        # penalize revisiting a node, growing exponentially with the visit count
                        if node_index in prev_visited:
                            loop_time = prev_visited[node_index]
                            loop_cost = prev_cost + loop_penalty * np.exp(loop_time)
                            new_visited = prev_visited.copy()
                            new_visited[node_index] = loop_time + 1
                        else:
                            loop_cost = prev_cost
                            new_visited = prev_visited.copy()
                            new_visited[node_index] = 1

                        motion_low = node['motion_low']
                        motion_high = node['motion_high']

                        if search_mode == "both":
                            cost_increment = 2 - (np.dot(audio_low_np[t], motion_low.T) + np.dot(audio_high_np[t], motion_high.T))
                        elif search_mode == "high_level":
                            cost_increment = 1 - np.dot(audio_high_np[t], motion_high.T)
                        elif search_mode == "low_level":
                            cost_increment = 1 - np.dot(audio_low_np[t], motion_low.T)

                        # count transitions that jump between clips (non-"continue" edges)
                        if not is_continue_edge:
                            non_continue_count = prev_non_continue_count + 1
                        else:
                            non_continue_count = prev_non_continue_count

                        continue_penalty_cost = continue_penalty * non_continue_count

                        total_cost = loop_cost + cost_increment + continue_penalty_cost

                        candidates.append((total_cost, prev_node_index, tuple_index, non_continue_count, new_visited))

            # keep only the top_k cheapest states for this node at time t
            if candidates:
                candidates.sort(key=lambda x: x[0])
                min_cost[t][node_index] = candidates[:top_k]
            else:
                # node is unreachable at this time step
                pass

    # collect finishing states at the final time step
    end_candidates = []
    for node_index, tuples in min_cost[T-1].items():
        for tuple_index, (cost, _, _, _, _) in enumerate(tuples):
            end_candidates.append((cost, node_index, tuple_index))

    if not end_candidates:
        print("No valid path found.")
        return [], []

    end_candidates.sort(key=lambda x: x[0])

    top_k_paths_info = end_candidates[:top_k]

    # backtrack each of the top_k end states to recover the full paths
    optimal_paths = []
    is_continue_lists = []
    for final_cost, node_index, tuple_index in top_k_paths_info:
        optimal_path_indices = []
        current_node_index = node_index
        current_tuple_index = tuple_index
        for t in range(T-1, -1, -1):
            optimal_path_indices.append(current_node_index)
            tuple_data = min_cost[t][current_node_index][current_tuple_index]
            _, prev_node_index, prev_tuple_index, _, _ = tuple_data
            current_node_index = prev_node_index
            current_tuple_index = prev_tuple_index
            if current_node_index is None:
                break   # reached a start node

        optimal_path_indices = optimal_path_indices[::-1]   # reverse to chronological order
        optimal_path = [graph.vs[idx] for idx in optimal_path_indices]
        optimal_paths.append(optimal_path)

        # record, for each transition, whether it continues the original clip
        is_continue = []
        for i in range(len(optimal_path) - 1):
            edge_id = graph.get_eid(optimal_path[i].index, optimal_path[i + 1].index)
            is_cont = graph.es[edge_id]['is_continue']
            is_continue.append(is_cont)
        is_continue_lists.append(is_continue)

    print("Top {} Paths:".format(len(optimal_paths)))
    for i, path in enumerate(optimal_paths):
        path_indices = [node.index for node in path]
        print("Path {}: Cost: {}, Nodes: {}".format(i+1, top_k_paths_info[i][0], path_indices))

    return optimal_paths, is_continue_lists


def test_fn(model, device, iteration, candidate_json_path, test_path, cfg, audio_path, **kwargs):
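    """Retrieve gesture video for `audio_path` from the cached motion graph.

    Re-embeds every graph node's motion with the current model, embeds the audio
    in pose_length-sized windows, and runs `search_path_dp` to pick the best node
    sequence. The retrieved path is rendered, color-matched to the reference,
    lip-synced with Wav2Lip, and saved under `test_path`. Unless a graph was just
    created, the nodes around the first result are then deleted and the search is
    repeated to produce a second candidate; returns the output .mp4/.npz paths.
    """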
    create_graph = kwargs["create_graph"]
    torch.set_grad_enabled(False)
    pool_path = candidate_json_path.replace("data_json", "cached_graph").replace(".json", ".pkl")
    graph = igraph.Graph.Read_Pickle(fname=pool_path)

    save_dir = os.path.join(test_path, f"retrieved_motions_{iteration}")
    os.makedirs(save_dir, exist_ok=True)

    actual_model = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
    actual_model.eval()

    # group per-node axis-angle poses by source clip name
    all_motions = {}
    for i, node in enumerate(graph.vs):
        if all_motions.get(node["name"]) is None:
            all_motions[node["name"]] = [node["axis_angle"].reshape(-1)]
        else:
            all_motions[node["name"]].append(node["axis_angle"].reshape(-1))
    for k, v in all_motions.items():
        all_motions[k] = np.stack(v)  # (num_frames, pose_dims)

    # embed every clip's motion in fixed-size windows
    window_size = cfg.data.pose_length
    motion_high_all = []
    motion_low_all = []
    for k, v in all_motions.items():
        motion_tensor = torch.from_numpy(v).float().to(device).unsqueeze(0)
        _, t, _ = motion_tensor.shape

        if t >= window_size:
            num_chunks = t // window_size
            motion_high_list = []
            motion_low_list = []

            for i in range(num_chunks):
                start_idx = i * window_size
                end_idx = start_idx + window_size
                motion_slice = motion_tensor[:, start_idx:end_idx, :]

                motion_features = actual_model.get_motion_features(motion_slice)

                motion_low = motion_features["motion_low"].cpu().numpy()
                motion_high = motion_features["motion_cls"].unsqueeze(0).repeat(1, motion_low.shape[1], 1).cpu().numpy()

                motion_high_list.append(motion_high[0])
                motion_low_list.append(motion_low[0])

            # re-embed the tail with a window ending at t, keeping only the new frames
            remain_length = t % window_size
            if remain_length > 0:
                start_idx = t - window_size
                motion_slice = motion_tensor[:, start_idx:, :]

                motion_features = actual_model.get_motion_features(motion_slice)

                motion_low = motion_features["motion_low"].cpu().numpy()
                motion_high = motion_features["motion_cls"].unsqueeze(0).repeat(1, motion_low.shape[1], 1).cpu().numpy()

                motion_high_list.append(motion_high[0][-remain_length:])
                motion_low_list.append(motion_low[0][-remain_length:])

            motion_high_all.append(np.concatenate(motion_high_list, axis=0))
            motion_low_all.append(np.concatenate(motion_low_list, axis=0))

        else:  # clip shorter than one window: zero-pad, then crop the features back to t
            gap = window_size - t
            motion_slice = torch.cat([motion_tensor, torch.zeros((motion_tensor.shape[0], gap, motion_tensor.shape[2])).to(motion_tensor.device)], 1)
            motion_features = actual_model.get_motion_features(motion_slice)

            motion_low = motion_features["motion_low"].cpu().numpy()
            motion_high = motion_features["motion_cls"].unsqueeze(0).repeat(1, motion_low.shape[1], 1).cpu().numpy()

            motion_high_all.append(motion_high[0][:t])
            motion_low_all.append(motion_low[0][:t])

    motion_high_all = np.concatenate(motion_high_all, axis=0)
    motion_low_all = np.concatenate(motion_low_all, axis=0)

    # L2-normalize so the dot products in search_path_dp are cosine similarities
    motion_low_all = motion_low_all / np.linalg.norm(motion_low_all, axis=1, keepdims=True)
    motion_high_all = motion_high_all / np.linalg.norm(motion_high_all, axis=1, keepdims=True)
    assert motion_high_all.shape[0] == len(graph.vs)
    assert motion_low_all.shape[0] == len(graph.vs)

    for i, node in enumerate(graph.vs):
        node["motion_high"] = motion_high_all[i]
        node["motion_low"] = motion_low_all[i]

    graph = graph_pruning(graph)

    # cap the graph at 1800 nodes, then keep its largest strongly connected component
    if len(graph.vs) > 1800:
        gap = len(graph.vs) - 1800
        start_d = random.randint(0, 1800)
        graph.delete_vertices(range(start_d, start_d + gap))
    ascc_2 = graph.clusters(mode="STRONG")
    graph = ascc_2.giant()

    # embed the audio with the same windowing scheme
    idx = 0  # audio index used in output filenames
    audio_waveform, sr = librosa.load(audio_path)
    audio_waveform = librosa.resample(audio_waveform, orig_sr=sr, target_sr=cfg.data.audio_sr)
    audio_tensor = torch.from_numpy(audio_waveform).float().to(device).unsqueeze(0)

    target_length = audio_tensor.shape[1] // cfg.data.audio_sr * 30  # audio length in 30-fps frames
    window_size = int(cfg.data.audio_sr * (cfg.data.pose_length / 30))
    _, t = audio_tensor.shape
    audio_low_list = []
    audio_high_list = []

    if t >= window_size:
        num_chunks = t // window_size

        for i in range(num_chunks):
            start_idx = i * window_size
            end_idx = start_idx + window_size

            audio_slice = audio_tensor[:, start_idx:end_idx]

            model_out_candidates = actual_model.get_audio_features(audio_slice)
            audio_low = model_out_candidates["audio_low"]
            audio_high = model_out_candidates["audio_cls"].unsqueeze(0).repeat(1, audio_low.shape[1], 1)

            audio_low = F.normalize(audio_low, dim=2)[0].cpu().numpy()
            audio_high = F.normalize(audio_high, dim=2)[0].cpu().numpy()

            audio_low_list.append(audio_low)
            audio_high_list.append(audio_high)

        remain_length = t % window_size
        if remain_length > 1:
            start_idx = t - window_size
            audio_slice = audio_tensor[:, start_idx:]

            model_out_candidates = actual_model.get_audio_features(audio_slice)
            audio_low = model_out_candidates["audio_low"]
            audio_high = model_out_candidates["audio_cls"].unsqueeze(0).repeat(1, audio_low.shape[1], 1)

            # frames still needed to reach target_length; shape[0] is the time axis
            # (the original indexed shape[1], the feature axis, which looks like a bug)
            gap = target_length - np.concatenate(audio_low_list, axis=0).shape[0]
            audio_low = F.normalize(audio_low, dim=2)[0][-gap:].cpu().numpy()
            audio_high = F.normalize(audio_high, dim=2)[0][-gap:].cpu().numpy()

            audio_low_list.append(audio_low)
            audio_high_list.append(audio_high)
    else:
        gap = window_size - t
        audio_slice = audio_tensor
        model_out_candidates = actual_model.get_audio_features(audio_slice)
        audio_low = model_out_candidates["audio_low"]
        audio_high = model_out_candidates["audio_cls"].unsqueeze(0).repeat(1, audio_low.shape[1], 1)
        audio_low = F.normalize(audio_low, dim=2)[0].cpu().numpy()
        audio_high = F.normalize(audio_high, dim=2)[0].cpu().numpy()
        audio_low_list.append(audio_low)
        audio_high_list.append(audio_high)

    audio_low_all = np.concatenate(audio_low_list, axis=0)
    audio_high_all = np.concatenate(audio_high_list, axis=0)
    path_list, is_continue_list = search_path_dp(graph, audio_low_all, audio_high_all, top_k=1, search_mode="both")

    res_motion = []
    counter = 0
    for path, is_continue in zip(path_list, is_continue_list):
        if False:
            # legacy renderer (intentionally disabled): muxes the audio directly into the visualization
            res_motion_current = path_visualization(
                graph, path, is_continue, os.path.join(save_dir, f"audio_{idx}_retri_{counter}.mp4"), audio_path=audio_path, return_motion=True, verbose_continue=True
            )
            video_temp_path = os.path.join(save_dir, f"audio_{idx}_retri_{counter}.mp4")
        else:
            res_motion_current = path_visualization_v2(
                graph, path, is_continue, os.path.join(save_dir, f"audio_{idx}_retri_{counter}.mp4"), audio_path=None, return_motion=True, verbose_continue=True
            )
            video_temp_path = os.path.join(save_dir, f"audio_{idx}_retri_{counter}.mp4")
            # drop the first frame, then match the color statistics to the reference
            video_reader = VideoReader(video_temp_path)
            video_np = []
            for i in range(len(video_reader)):
                if i == 0:
                    continue
                video_frame = video_reader[i].asnumpy()
                video_np.append(Image.fromarray(video_frame))
            adjusted_video_pil = adjust_statistics_to_match_reference([video_np])
            save_videos_from_pil(adjusted_video_pil[0], os.path.join(save_dir, f"audio_{idx}_retri_{counter}.mp4"), fps=graph.vs[0]['fps'], bitrate=2000000)

        # lip-sync the retrieved video to the input audio with Wav2Lip
        audio_temp_path = audio_path
        lipsync_output_path = os.path.join(save_dir, f"audio_{idx}_retri_{counter}.mp4")
        checkpoint_path = './Wav2Lip/checkpoints/wav2lip_gan.pth'
        os.system(f'python ./Wav2Lip/inference.py --checkpoint_path {checkpoint_path} --face {video_temp_path} --audio {audio_temp_path} --outfile {lipsync_output_path} --nosmooth')

        res_motion.append(res_motion_current)
        np.savez(os.path.join(save_dir, f"audio_{idx}_retri_{counter}.npz"), motion=res_motion_current)

        # remember where the retrieved path starts so those nodes can be removed below
        start_node = path[1].index
        end_node = start_node + 100

    if create_graph:
        # a freshly created graph yields only one retrieval; return it twice
        result = [
            os.path.join(save_dir, f"audio_{idx}_retri_0.mp4"),
            os.path.join(save_dir, f"audio_{idx}_retri_0.mp4"),
            os.path.join(save_dir, f"audio_{idx}_retri_0.npz"),
            os.path.join(save_dir, f"audio_{idx}_retri_0.npz")
        ]
        return result

    # remove the nodes used by the first path and search again for a distinct second sample
    print(f"delete gt-nodes {start_node}, {end_node}")
    nodes_to_delete = list(range(start_node, end_node))
    graph.delete_vertices(nodes_to_delete)
    graph = graph_pruning(graph)
    path_list, is_continue_list = search_path_dp(graph, audio_low_all, audio_high_all, top_k=1, search_mode="both")
    res_motion = []
    counter = 1
    for path, is_continue in zip(path_list, is_continue_list):
        res_motion_current = path_visualization_v2(
            graph, path, is_continue, os.path.join(save_dir, f"audio_{idx}_retri_{counter}.mp4"), audio_path=None, return_motion=True, verbose_continue=True
        )
        video_temp_path = os.path.join(save_dir, f"audio_{idx}_retri_{counter}.mp4")

        video_reader = VideoReader(video_temp_path)
        video_np = []
        for i in range(len(video_reader)):
            if i == 0:
                continue
            video_frame = video_reader[i].asnumpy()
            video_np.append(Image.fromarray(video_frame))
        adjusted_video_pil = adjust_statistics_to_match_reference([video_np])
        save_videos_from_pil(adjusted_video_pil[0], os.path.join(save_dir, f"audio_{idx}_retri_{counter}.mp4"), fps=graph.vs[0]['fps'], bitrate=2000000)

        audio_temp_path = audio_path
        lipsync_output_path = os.path.join(save_dir, f"audio_{idx}_retri_{counter}.mp4")
        checkpoint_path = './Wav2Lip/checkpoints/wav2lip_gan.pth'
        os.system(f'python ./Wav2Lip/inference.py --checkpoint_path {checkpoint_path} --face {video_temp_path} --audio {audio_temp_path} --outfile {lipsync_output_path} --nosmooth')
        res_motion.append(res_motion_current)
        np.savez(os.path.join(save_dir, f"audio_{idx}_retri_{counter}.npz"), motion=res_motion_current)

    result = [
        os.path.join(save_dir, f"audio_{idx}_retri_0.mp4"),
        os.path.join(save_dir, f"audio_{idx}_retri_1.mp4"),
        os.path.join(save_dir, f"audio_{idx}_retri_0.npz"),
        os.path.join(save_dir, f"audio_{idx}_retri_1.npz")
    ]
    return result


def init_class(module_name, class_name, config, **kwargs):
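    """Dynamically import `module_name` and instantiate `class_name` with the config."""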
    module = importlib.import_module(module_name)
    model_class = getattr(module, class_name)
    instance = model_class(config, **kwargs)
    return instance


def seed_everything(seed):
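    """Seed Python, NumPy, and PyTorch (CPU and all GPUs) RNGs for reproducibility."""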
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def prepare_all(yaml_name):
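    """Parse CLI args, load the YAML config with OmegaConf, and create its output dir.

    `exp_name` is derived from the config filename; non-.yaml configs raise.
    """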
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default=yaml_name)
    parser.add_argument("--debug", action="store_true", help="Enable debugging mode")
    parser.add_argument('overrides', nargs=argparse.REMAINDER)
    args = parser.parse_args()
    if args.config.endswith(".yaml"):
        config = OmegaConf.load(args.config)
        config.exp_name = args.config.split("/")[-1][:-5]
    else:
        raise ValueError("Unsupported config file format. Only .yaml files are allowed.")
    save_dir = os.path.join(config.output_dir, config.exp_name)
    os.makedirs(save_dir, exist_ok=True)
    return config


def save_first_10_seconds(video_path, output_path="./save_video.mp4", max_length=512):
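    """Trim and normalize a reference video for graph building.

    Note: despite the name, this keeps the first 20 seconds
    (frames_to_save = fps * 20). Frames are resized so the longer side equals
    `max_length`, written with OpenCV, then resampled to 30 fps with ffmpeg's
    motion-compensated `minterpolate` filter into `output_path`.
    """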
    if os.path.exists(output_path):
        os.remove(output_path)

    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        return

    fps = int(cap.get(cv2.CAP_PROP_FPS))
    original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # resize so the longer side equals max_length, keeping the aspect ratio
    if original_width >= original_height:
        new_width = max_length
        new_height = int(original_height * (max_length / original_width))
    else:
        new_height = max_length
        new_width = int(original_width * (max_length / original_height))

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path.replace(".mp4", "_fps.mp4"), fourcc, fps, (new_width, new_height))

    frames_to_save = fps * 20  # 20 seconds at the source frame rate
    frame_count = 0

    while cap.isOpened() and frame_count < frames_to_save:
        ret, frame = cap.read()
        if not ret:
            break
        resized_frame = cv2.resize(frame, (new_width, new_height))
        out.write(resized_frame)
        frame_count += 1

    cap.release()
    out.release()

    # motion-compensated interpolation to a fixed 30 fps
    command = [
        'ffmpeg',
        '-i', output_path.replace(".mp4", "_fps.mp4"),
        '-vf', 'minterpolate=fps=30:mi_mode=mci:mc_mode=aobmc:vsbmc=1',
        output_path
    ]
    subprocess.run(command)
    os.remove(output_path.replace(".mp4", "_fps.mp4"))


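# Reference videos that ship with precomputed motion graphs, mapped to their
# data JSONs; any other uploaded video triggers graph creation in tango().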
character_name_to_yaml = {
    "speaker8_jjRWaMCWs44_00-00-30.16_00-00-33.32.mp4": "./datasets/data_json/youtube_test/speaker8.json",
    "speaker7_iuYlGRnC7J8_00-00-0.00_00-00-3.25.mp4": "./datasets/data_json/youtube_test/speaker7.json",
    "speaker9_o7Ik1OB4TaE_00-00-38.15_00-00-42.33.mp4": "./datasets/data_json/youtube_test/speaker9.json",
    "1wrQ6Msp7wM_00-00-39.69_00-00-45.68.mp4": "./datasets/data_json/youtube_test/speaker1.json",
    "101099-00_18_09-00_18_19.mp4": "./datasets/data_json/show_oliver_test/Stupid_Watergate_-_Last_Week_Tonight_with_John_Oliver_HBO-FVFdsl29s_Q.mkv.json",
}


@spaces.GPU(duration=200)
def tango(audio_path, character_name, seed, threshold=1.0, create_graph=False, video_folder_path=None):
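    """Gradio entry point: retrieve two gesture videos for the given audio.

    `audio_path` arrives from gr.Audio as a (sample_rate, waveform) tuple; it is
    saved to disk, resampled to 16 kHz, and truncated to two 128-frame windows.
    Reference videos listed in `character_name_to_yaml` use their cached graphs;
    any other video is trimmed with save_first_10_seconds, fitted with SMPLer-X,
    and converted into a fresh motion graph before `test_fn` runs the retrieval.
    """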
						    os.system("rm -r ./outputs/") | 
					
					
						
						| 
							 | 
						    cfg = prepare_all("./configs/gradio.yaml") | 
					
					
						
						| 
							 | 
						    cfg.seed = seed | 
					
					
						
						| 
							 | 
						    seed_everything(cfg.seed) | 
					
					
						
						| 
							 | 
						    experiment_ckpt_dir = experiment_log_dir = os.path.join(cfg.output_dir, cfg.exp_name) | 
					
					
						
						| 
							 | 
						    saved_audio_path = "./saved_audio.wav" | 
					
					
						
						| 
							 | 
						    sample_rate, audio_waveform = audio_path  | 
					
					
						
						| 
							 | 
						    sf.write(saved_audio_path, audio_waveform, sample_rate) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    audio_waveform, sample_rate = librosa.load(saved_audio_path) | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    resampled_audio = librosa.resample(audio_waveform, orig_sr=sample_rate, target_sr=16000) | 
					
					
						
						| 
							 | 
						    required_length = int(16000 * (128 / 30)) * 2 | 
					
					
						
						| 
							 | 
						    resampled_audio = resampled_audio[:required_length] | 
					
					
						
						| 
							 | 
						    sf.write(saved_audio_path, resampled_audio, 16000) | 
					
					
						
						| 
							 | 
						    audio_path = saved_audio_path | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    yaml_name = character_name_to_yaml.get(character_name.split("/")[-1], "./datasets/data_json/youtube_test/speaker1.json") | 
					
					
						
						| 
							 | 
						    cfg.data.test_meta_paths = yaml_name | 
					
					
						
						| 
							 | 
						    print(yaml_name, character_name.split("/")[-1]) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    if character_name.split("/")[-1] not in character_name_to_yaml.keys(): | 
					
					
						
						| 
							 | 
						        create_graph=True | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        os.makedirs("./outputs/tmpvideo/", exist_ok=True) | 
					
					
						
						| 
							 | 
						        save_first_10_seconds(character_name, "./outputs/tmpvideo/save_video.mp4") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    if create_graph: | 
					
					
						
						| 
							 | 
						        video_folder_path = "./outputs/tmpvideo/" | 
					
					
						
						| 
							 | 
						        data_save_path = "./outputs/tmpdata/" | 
					
					
						
						| 
							 | 
						        json_save_path = "./outputs/save_video.json" | 
					
					
						
						| 
							 | 
						        graph_save_path = "./outputs/save_video.pkl" | 
					
					
						
						| 
							 | 
						        os.system(f"cd ./SMPLer-X/ && python app.py --video_folder_path .{video_folder_path} --data_save_path .{data_save_path} --json_save_path .{json_save_path} && cd ..") | 
					
					
						
						| 
							 | 
						        print(f"cd ./SMPLer-X/ && python app.py --video_folder_path .{video_folder_path} --data_save_path .{data_save_path} --json_save_path .{json_save_path} && cd ..") | 
					
					
						
						| 
							 | 
						        os.system(f"python ./create_graph.py --json_save_path {json_save_path} --graph_save_path {graph_save_path} --threshold {threshold}")  | 
					
					
						
						| 
							 | 
						        cfg.data.test_meta_paths = json_save_path | 
					
					
						
						| 
							 | 
						        gc.collect() | 
					
					
						
						| 
							 | 
						        torch.cuda.empty_cache() | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    smplx_model = smplx.create( | 
					
					
						
						| 
							 | 
						        "./emage/smplx_models/",  | 
					
					
						
						| 
							 | 
						        model_type='smplx', | 
					
					
						
						| 
							 | 
						        gender='NEUTRAL_2020',  | 
					
					
						
						| 
							 | 
						        use_face_contour=False, | 
					
					
						
						| 
							 | 
						        num_betas=300, | 
					
					
						
						| 
							 | 
						        num_expression_coeffs=100,  | 
					
					
						
						| 
							 | 
						        ext='npz', | 
					
					
						
						| 
							 | 
						        use_pca=False, | 
					
					
						
						| 
							 | 
						    ) | 
					
					
						
						| 
							 | 
						    model = init_class(cfg.model.name_pyfile, cfg.model.class_name, cfg) | 
					
					
						
						| 
							 | 
						    for param in model.parameters(): | 
					
					
						
						| 
							 | 
						        param.requires_grad = False | 
					
					
						
						| 
							 | 
						    model.smplx_model = smplx_model | 
					
					
						
						| 
							 | 
						    model.get_motion_reps = get_motion_reps_tensor | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    local_rank = 0   | 
					
					
						
						| 
							 | 
						    torch.cuda.set_device(local_rank) | 
					
					
						
						| 
							 | 
						    device = torch.device("cuda", local_rank) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    smplx_model = smplx_model.to(device).eval() | 
					
					
						
						| 
							 | 
						    model = model.to(device) | 
					
					
						
						| 
							 | 
						    model.smplx_model = model.smplx_model.to(device) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    checkpoint_path = "./datasets/cached_ckpts/ckpt.pth" | 
					
					
						
						| 
							 | 
						    checkpoint = torch.load(checkpoint_path) | 
					
					
						
						| 
							 | 
						    state_dict = checkpoint['model_state_dict'] | 
					
					
						
						| 
							 | 
						    new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} | 
					
					
						
						| 
							 | 
						    model.load_state_dict(new_state_dict, strict=False) | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    test_path = os.path.join(experiment_ckpt_dir, f"test_{0}") | 
					
					
						
						| 
							 | 
						    os.makedirs(test_path, exist_ok=True) | 
					
					
						
						| 
							 | 
						    result = test_fn(model, device, 0, cfg.data.test_meta_paths, test_path, cfg, audio_path, create_graph=create_graph) | 
					
					
						
						| 
							 | 
						    gc.collect() | 
					
					
						
						| 
							 | 
						    torch.cuda.empty_cache() | 
					
					
						
						| 
							 | 
						    return result | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
examples_audio = [
    ["./datasets/cached_audio/example_male_voice_9_seconds.wav"],
    ["./datasets/cached_audio/example_female_voice_9_seconds.wav"],
]

examples_video = [
    ["./datasets/cached_audio/speaker8_jjRWaMCWs44_00-00-30.16_00-00-33.32.mp4"],
    ["./datasets/cached_audio/speaker7_iuYlGRnC7J8_00-00-0.00_00-00-3.25.mp4"],
    ["./datasets/cached_audio/speaker9_o7Ik1OB4TaE_00-00-38.15_00-00-42.33.mp4"],
    ["./datasets/cached_audio/1wrQ6Msp7wM_00-00-39.69_00-00-45.68.mp4"],
    ["./datasets/cached_audio/101099-00_18_09-00_18_19.mp4"],
]

combined_examples = [
    ["./datasets/cached_audio/example_female_voice_9_seconds.wav", "./datasets/cached_audio/female_test_V1.mp4", 2024, 1.0],
]


def make_demo():
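    """Build the Gradio Blocks interface for the TANGO demo."""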
						    with gr.Blocks(analytics_enabled=False) as Interface: | 
					
					
						
						| 
							 | 
						        gr.Markdown( | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        <div style="display: flex; justify-content: center; align-items: center; text-align: center;"> | 
					
					
						
						| 
							 | 
						          <div> | 
					
					
						
						| 
							 | 
						            <h1>TANGO</h1> | 
					
					
						
						| 
							 | 
						            <span>Generating full-body talking videos from audio and reference video</span> | 
					
					
						
						| 
							 | 
						            <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\ | 
					
					
						
						| 
							 | 
						              <a href='https://h-liu1997.github.io/'>Haiyang Liu</a>, \ | 
					
					
						
						| 
							 | 
						              <a href='https://yangxingchao.github.io/'>Xingchao Yang</a>, \ | 
					
					
						
						| 
							 | 
						              <a href=''>Tomoya Akiyama</a>, \ | 
					
					
						
						| 
							 | 
						              <a href='https://sky24h.github.io/'> Yuantian Huang</a>, \ | 
					
					
						
						| 
							 | 
						              <a href=''>Qiaoge Li</a>, \ | 
					
					
						
						| 
							 | 
						              <a href='https://www.tut.ac.jp/english/university/faculty/cs/164.html'>Shigeru Kuriyama</a>, \ | 
					
					
						
						| 
							 | 
						              <a href='https://taketomitakafumi.sakura.ne.jp/web/en/'>Takafumi Taketomi</a>\ | 
					
					
						
						| 
							 | 
						            </h2> | 
					
					
						
						| 
							 | 
						            <br> | 
					
					
						
						| 
							 | 
						            <div style="display: flex; justify-content: center; align-items: center; text-align: center;"> | 
					
					
						
						| 
							 | 
						              <a href="https://arxiv.org/abs/2410.04221"><img src="https://img.shields.io/badge/arXiv-2410.04221-blue"></a> | 
					
					
						
						| 
							 | 
						                | 
					
					
						
						| 
							 | 
						              <a href="https://pantomatrix.github.io/TANGO/"><img src="https://img.shields.io/badge/Project_Page-TANGO-orange" alt="Project Page"></a> | 
					
					
						
						| 
							 | 
						                | 
					
					
						
						| 
							 | 
						              <a href="https://github.com/CyberAgentAILab/TANGO"><img src="https://img.shields.io/badge/Github-Code-green"></a> | 
					
					
						
						| 
							 | 
						                | 
					
					
						
						| 
							 | 
						              <a href="https://github.com/CyberAgentAILab/TANGO"><img src="https://img.shields.io/github/stars/CyberAgentAILab/TANGO | 
					
					
						
						| 
							 | 
						              "></a> | 
					
					
						
						| 
							 | 
						            </div> | 
					
					
						
						| 
							 | 
						          </div> | 
					
					
						
						| 
							 | 
						        </div> | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        ) | 

        with gr.Row():
            video1 = gr.Video(value="./datasets/cached_audio/demo1.mp4", label="Demo 0", watermark="./datasets/watermark.png")
            video2 = gr.Video(value="./datasets/cached_audio/demo2.mp4", label="Demo 1", watermark="./datasets/watermark.png")
            video3 = gr.Video(value="./datasets/cached_audio/demo3.mp4", label="Demo 2", watermark="./datasets/watermark.png")
            video4 = gr.Video(value="./datasets/cached_audio/demo4.mp4", label="Demo 3", watermark="./datasets/watermark.png")
            video5 = gr.Video(value="./datasets/cached_audio/demo5.mp4", label="Demo 4", watermark="./datasets/watermark.png")
        with gr.Row():
            video6 = gr.Video(value="./datasets/cached_audio/demo6.mp4", label="Demo 5", watermark="./datasets/watermark.png")
            video7 = gr.Video(value="./datasets/cached_audio/demo0.mp4", label="Demo 6", watermark="./datasets/watermark.png")
            video8 = gr.Video(value="./datasets/cached_audio/demo7.mp4", label="Demo 7", watermark="./datasets/watermark.png")
            video9 = gr.Video(value="./datasets/cached_audio/demo8.mp4", label="Demo 8", watermark="./datasets/watermark.png")
            video10 = gr.Video(value="./datasets/cached_audio/demo9.mp4", label="Demo 9", watermark="./datasets/watermark.png")

        with gr.Row():
            gr.Markdown(
              """
              <div style="display: flex; justify-content: center; align-items: center; text-align: left;">
              News:
              <br>
              This is an open-source project supported by Hugging Face's free L40S GPU. Runtime is limited, so the demo runs in low-quality mode; some results generated in high-quality mode are shown above.
              <br>
              [10/19]: Added a parameter to control "smooth" vs. "jumpy" results for custom uploaded videos. A lower threshold gives smoother motion but more repeated gestures.
              <br>
              [10/15]: Added a watermark, fixed custom-character bugs by downgrading to Python 3.9, and fixed support for audio shorter than 4 s.
              <br>
              [10/14]: Hugging Face now provides a free L40S GPU for this project!
              </div>
              """
            )

        with gr.Row():
            with gr.Column(scale=4):
                video_output_1 = gr.Video(
                    label="Generated video - 1",
                    interactive=False,
                    autoplay=False,
                    loop=False,
                    show_share_button=True,
                    watermark="./datasets/watermark.png",
                )
            with gr.Column(scale=4):
                video_output_2 = gr.Video(
                    label="Generated video - 2",
                    interactive=False,
                    autoplay=False,
                    loop=False,
                    show_share_button=True,
                    watermark="./datasets/watermark.png",
                )
            with gr.Column(scale=1):
                file_output_1 = gr.File(label="Download 3D motion (video 1) and visualize in Blender")
                file_output_2 = gr.File(label="Download 3D motion (video 2) and visualize in Blender")
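
                # Illustrative note (assumption): the downloaded files can be
                # inspected offline. Assuming a NumPy .npz archive of SMPL-X
                # parameters (the key name "poses" is hypothetical):
                #   motion = np.load("downloaded_motion.npz")
                #   poses = motion["poses"]  # per-frame axis-angle joint rotations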
                gr.Markdown("""
                <div style="display: flex; justify-content: center; align-items: center; text-align: left;">
                Details of the low-quality mode:
                <br>
                1. Lower resolution: videos are resized to 512 px on the long side, keeping the aspect ratio.
                <br>
                2. A subgraph is used instead of the full graph, which causes noticeable "frame jumps".
                <br>
                3. Only the first 8 s of the input audio are used.
                <br>
                4. Only the first 20 s of the input video are used for a custom character, and a custom character yields a single result (without "smoothing") to save time.
                <br>
                5. Open-source tools such as the SMPLer-X s-model, Wav2Lip, and FiLM are used for faster processing.
                <br>
                <br>
                Feel free to open an issue on GitHub or contact the authors if this does not meet your needs.
                </div>
                """)

        with gr.Row():
            with gr.Column(scale=1):
                audio_input = gr.Audio(label="Upload your audio")
                seed_input = gr.Number(label="Random Seed", value=2024, interactive=True)
            with gr.Column(scale=2):
                gr.Examples(
                    examples=examples_audio,
                    inputs=[audio_input],
                    outputs=[video_output_1, video_output_2, file_output_1, file_output_2],
                    label="Select existing Audio examples",
                    cache_examples=False,
                )
            with gr.Column(scale=1):
                video_input = gr.Video(label="Your Character", elem_classes="video")
            with gr.Column(scale=2):
                gr.Examples(
                    examples=examples_video,
                    inputs=[video_input],
                    outputs=[video_output_1, video_output_2, file_output_1, file_output_2],
                    label="Character Examples",
                    cache_examples=False,
                )
                threshold_input = gr.Slider(
                    label="Result smoothness (lower = smoother but more repeated gestures; only affects custom uploaded videos)",
                    minimum=0.7,
                    maximum=1.0,
                    value=1.0,
                    step=0.01,
                    interactive=True,
                )
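
        # Illustrative sketch (assumption): one way such a smoothness threshold
        # can gate transitions in the motion graph -- dropping transition edges
        # whose cost exceeds the threshold trades gesture variety for smoother
        # motion. The "cost" edge attribute is hypothetical; the real graph
        # logic lives in create_graph.py.
        def _prune_edges_by_cost(graph, threshold):
            too_jumpy = [e.index for e in graph.es if e["cost"] > threshold]
            graph.delete_edges(too_jumpy)  # igraph removes edges by id
            return graph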

        with gr.Row():
            run_button = gr.Button("Generate Video")

        run_button.click(
            fn=tango,
            inputs=[audio_input, video_input, seed_input, threshold_input],
            outputs=[video_output_1, video_output_2, file_output_1, file_output_2],
        )
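
        # This wiring implies `tango` (defined earlier in this file) accepts
        # (audio, video, seed, threshold) and returns the two rendered videos
        # plus the two downloadable motion files, in that order.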

        with gr.Row():
            with gr.Column(scale=4):
                gr.Examples(
                    examples=combined_examples,
                    inputs=[audio_input, video_input, seed_input, threshold_input],
                    outputs=[video_output_1, video_output_2, file_output_1, file_output_2],
                    fn=tango,
                    label="Select Combined Audio and Video Examples (Cached)",
                    cache_examples=True,
                )
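
        # With cache_examples=True and fn/outputs supplied, Gradio runs `tango`
        # on each combined example ahead of time (eagerly or lazily, depending
        # on the Gradio version) and replays the cached outputs on selection.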

    return Interface


						if __name__ == "__main__": | 
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "8675"
						    demo = make_demo() | 
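    # share=True additionally opens a temporary public *.gradio.live tunnel;
    # remove it to serve locally only.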
						    demo.launch(share=True) |