import os
import json
import torch
import imageio
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
import torchvision.transforms.v2 as T
from PIL import Image
from pathlib import Path
from gradio.themes.soft import Soft
from huggingface_hub import hf_hub_download

from data_loaders.get_data import get_dataset_loader
from data_loaders.humanml.data.dataset import KINEMATIC_CHAIN
from model.mdm_controlnet import MDMControlNet
from utils import dist_util
from utils.model_util import (
    sample_from_model,
    create_model_and_diffusion,
    load_saved_model,
)

PIL_TRANSFORM = T.Compose(
    [
        T.Resize((224, 224), antialias=True),
        T.ToTensor(),  # scales to [0,1]
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)

EXAMPLES_DIR = Path("examples")
EXAMPLE_IMAGES = [
    Image.open(str(p)).convert("RGB") for p in sorted(EXAMPLES_DIR.glob("*.jpg"))
]

CHECKPOINT = hf_hub_download(
    repo_id="hassanjbara/mdm",
    filename="model000050000.pt",
    cache_dir="save/injector_100k",
)

with open("./save/injector_100k/args.json") as f:
    args_dict = json.load(f)


class Args:
    pass


args = Args()
for k, v in args_dict.items():
    setattr(args, k, v)

dist_util.setup_dist(device="cpu")

# Build data loader & model once
device = dist_util.dev()
data = get_dataset_loader(
    name=args.dataset,
    batch_size=1,
    num_frames=args.num_frames,
    fixed_len=args.pred_len + args.context_len,
    pred_len=args.pred_len,
    device=device,
    num_samples=None,
    train_sample_indices=[8],
    fast_mode=True,
)

model, diffusion = create_model_and_diffusion(args, data)
model = MDMControlNet(model, args)
load_saved_model(model, CHECKPOINT, use_avg=False)
model.to(device)
model.eval()


def add_example_to_gallery(example_idx, image_list):
    if example_idx is None:
        raise gr.Error(
            "Please select an example image before adding it to the gallery."
        )
    # Add the selected example image to the gallery
    image_list.append(EXAMPLE_IMAGES[example_idx])
    return image_list, image_list


def write_video(motion: np.ndarray, fps: int = 20, out_path: str = "out.mp4") -> str:
    """
    motion: np.ndarray of shape [n_frames, n_joints, 3]

    Projects onto the XY plane, draws chains with matplotlib, and writes
    out_path via ffmpeg.
    """
    # remove old file
    if os.path.exists(out_path):
        os.remove(out_path)

    writer = imageio.get_writer(out_path, fps=fps, codec="libx264")

    for frame in motion:
        # frame shape [n_joints, 3]
        fig, ax = plt.subplots(figsize=(4, 4), dpi=80)
        for chain in KINEMATIC_CHAIN:
            xs = frame[chain, 0]
            ys = frame[chain, 1]
            ax.plot(xs, ys, linewidth=4, color="black")
        ax.set_xlim(-2, 2)
        ax.set_ylim(-2, 2)
        ax.axis("off")

        # render to numpy array
        fig.canvas.draw()
        img = np.frombuffer(fig.canvas.tostring_rgb(), dtype="uint8")
        w, h = fig.canvas.get_width_height()
        img = img.reshape((h, w, 3))
        writer.append_data(img)
        plt.close(fig)

    writer.close()
    return out_path


def generate_video(
    text_prompt: str, guidance: float, image_files: list[Image.Image], frames_csv: str
):
    # 1) load & preprocess cond images
    cond_images = []
    frame_indices = []
    if image_files:
        imgs = torch.stack([PIL_TRANSFORM(img) for img in image_files]).to(device)
        frame_indices = [
            int(x.strip()) for x in frames_csv.split(",") if x.strip().isdigit()
        ]

        # Check for mismatch between images and frame indices
        if len(frame_indices) != len(image_files):
            raise gr.Error(
                f"Number of images ({len(image_files)}) does not match number of frame indices ({len(frame_indices)})."
            )

        cond_images.append(imgs)
        frame_indices = [torch.tensor(frame_indices)]
    else:
        cond_images = None
        frame_indices = None

    # 2) sample
    results = sample_from_model(
        model=model,
        diffusion=diffusion,
        data=data,
        text_prompts=[text_prompt],
        num_samples=1,
        motion_length=10.0,
        guidance_param=guidance,
        device=device,
        cond_images=cond_images,
        frame_indices=frame_indices,
    )
    motion = results["motions"][0].transpose(2, 0, 1)  # [n_frames, njoints, 3]

    # 3) render to MP4 & return path
    video_path = write_video(motion, fps=20, out_path="out.mp4")
    return video_path


def save_image(drawing_data, image_list):
    # The output from gr.ImageEditor is a dictionary with a "composite" key
    # which holds the final image as a NumPy array.
    if drawing_data is not None and drawing_data["composite"] is not None:
        # Convert the NumPy array of the composite image to a PIL Image
        pil_image = Image.fromarray(drawing_data["composite"]).convert("RGB")
        image_list.append(pil_image)

    # Update the gallery with the new image list, and reset the editor
    return image_list, image_list, None


def store_selected_index(evt: gr.SelectData):
    # gr.SelectData contains the index of the selected item
    return evt.index, gr.update(interactive=True)


# ———————————————
# 4) Build Gradio UI
with gr.Blocks(
    theme=Soft(),
    css="""
    /* target the internal container of your gallery */
    #example-gallery .gallery-container {
        max-height: 250px; /* match the height you set above */
        overflow-y: auto;  /* enable vertical scrolling */
    }
    """,
) as demo:
    gr.Markdown("## Image Motion Diffusion Model (IMDM)")

    saved_images = gr.State([])
    selected_example = gr.State(None)

    with gr.Row():
        with gr.Column():
            editor = gr.Sketchpad(
                height=512,
                width=512,
                brush=gr.Brush(default_size=5, default_color="black"),
                show_label=False,
            )
            example_gallery = gr.Gallery(
                value=EXAMPLE_IMAGES,
                label="Examples",
                columns=6,  # More columns → smaller images
                height=250,
                elem_id="example-gallery",
            )
            with gr.Row():
                save_btn = gr.Button("Save Drawing")
                clear_btn = gr.Button("Clear saved images", variant="huggingface")
                add_example_btn = gr.Button("Add Example to Gallery", interactive=False)
        with gr.Column():
            gallery = gr.Gallery(label="Saved Images", height=250, columns=3)
            with gr.Row():
                prompt = gr.Textbox(
                    label="Text prompt", placeholder="e.g. “happy dancing”"
                )
                frames = gr.Textbox(
                    placeholder="20,40,60,80",
                    label="Frame indices (comma-separated)",
                )
            guidance = gr.Slider(0.0, 2.0, value=1.0, label="Image Guidance scale")
            generate_btn = gr.Button("Generate video", variant="primary")
            output_vid = gr.Video(label="Generated video", format="mp4")

    example_gallery.select(
        fn=store_selected_index,
        inputs=None,  # Event data is passed automatically
        outputs=[selected_example, add_example_btn],
    )
    save_btn.click(
        fn=save_image,
        inputs=[editor, saved_images],
        outputs=[saved_images, gallery, editor],
    )
    clear_btn.click(
        fn=lambda: ([], [], None),  # Reset saved images and gallery
        outputs=[saved_images, gallery, editor],
    )
    add_example_btn.click(
        fn=add_example_to_gallery,
        inputs=[selected_example, saved_images],
        outputs=[saved_images, gallery],
    )
    generate_btn.click(
        fn=generate_video,
        inputs=[prompt, guidance, saved_images, frames],
        outputs=output_vid,
    )


if __name__ == "__main__":
    demo.launch()
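
# Note (assumption, not part of the original script): if the demo needs to be
# reachable from other machines or shared via a temporary public link, Gradio's
# `launch()` accepts standard keyword arguments for that, e.g.:
#
#     demo.launch(server_name="0.0.0.0", server_port=7860, share=True)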