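"""Blur-to-video inference script.

Loads the fine-tuned blur2vid CogVideoX transformer from the Hugging Face Hub
together with the base CogVideoX text encoder, tokenizer, VAE, and scheduler,
then turns a single motion-blurred photograph into a short video covering a
chosen exposure interval.

Example invocation (script name and paths are illustrative):
    python inference.py --image_path inputs/blurry.jpg --output_path output/ --seed 42
"""
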
import argparse
import io
from pathlib import Path

import numpy as np
import torch
import yaml
from PIL import Image, ImageCms
from safetensors.torch import load_file
from transformers import T5Tokenizer, T5EncoderModel

from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler
from diffusers.utils import check_min_version, export_to_video
from huggingface_hub import hf_hub_download

from controlnet_pipeline import ControlnetCogVideoXPipeline
from cogvideo_transformer import CogVideoXTransformer3DModel
from training.helpers import get_conditioning

check_min_version("0.31.0.dev0")

def convert_to_srgb(img: Image.Image) -> Image.Image:
    """Convert an image to sRGB, honoring an embedded ICC profile if present."""
    if "icc_profile" in img.info:
        icc = img.info["icc_profile"]
        src_profile = ImageCms.ImageCmsProfile(io.BytesIO(icc))
        dst_profile = ImageCms.createProfile("sRGB")
        img = ImageCms.profileToProfile(img, src_profile, dst_profile, outputMode="RGB")
    else:
        img = img.convert("RGB")
    return img

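# Exposure-interval presets. Interpreting the fields from their use in
# get_conditioning: in_start/in_end mark which slice of the 16-slot exposure
# window the blurred input covers, out_start/out_end mark the interval the
# output video should span, and "mode" sets the playback speed. "present"
# reconstructs only the exposed interval, while "past, present and future"
# also extrapolates before and after the exposure (input slots 4-12).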
INTERVALS = {
    "present": {
        "in_start": 0,
        "in_end": 16,
        "out_start": 0,
        "out_end": 16,
        "center": 8,
        "window_size": 16,
        "mode": "1x",
        "fps": 240,
    },
    "past, present and future": {
        "in_start": 4,
        "in_end": 12,
        "out_start": 0,
        "out_end": 16,
        "center": 8,
        "window_size": 16,
        "mode": "2x",
        "fps": 240,
    },
}

def convert_to_batch(
    image,
    interval_key="present",
    image_size=(720, 1280),
):
    """Preprocess a PIL image into the model's input batch for the given interval."""
    interval = INTERVALS[interval_key]

    inp_int, out_int, num_frames = get_conditioning(
        in_start=interval["in_start"],
        in_end=interval["in_end"],
        out_start=interval["out_start"],
        out_end=interval["out_end"],
        mode=interval["mode"],
        fps=interval["fps"],
    )

    blur_img_original = convert_to_srgb(image)
    # PIL's Image.size is (width, height).
    W, H = blur_img_original.size

    # Resize to (width, height), then convert to a [1, C, H, W] float tensor in [-1, 1].
    blur_img = blur_img_original.resize((image_size[1], image_size[0]))
    blur_img = torch.from_numpy(np.array(blur_img)[None]).permute(0, 3, 1, 2).contiguous().float()
    blur_img = blur_img / 127.5 - 1.0

    data = {
        "original_size": (H, W),
        "blur_img": blur_img,
        "caption": "",
        "input_interval": inp_int,
        "output_interval": out_int,
        "height": image_size[0],
        "width": image_size[1],
        "num_frames": num_frames,
    }
    return data

def load_model(args):
    """Build the ControlnetCogVideoXPipeline from the base model plus blur2vid weights."""
    with open(args.model_config_path) as f:
        model_config = yaml.safe_load(f)

    # Load the base transformer in fp16, then overwrite it with the
    # fine-tuned blur2vid checkpoint from the Hugging Face Hub.
    load_dtype = torch.float16
    transformer = CogVideoXTransformer3DModel.from_pretrained(
        args.pretrained_model_path,
        subfolder="transformer",
        torch_dtype=load_dtype,
        revision=model_config["revision"],
        variant=model_config["variant"],
        low_cpu_mem_usage=False,
        attn_implementation="flash_attention_2",
    )
    weight_path = hf_hub_download(
        repo_id=args.blur2vid_hf_repo_path,
        filename="cogvideox-outsidephotos/checkpoint/model.safetensors",
    )
    transformer.load_state_dict(load_file(weight_path))

    text_encoder = T5EncoderModel.from_pretrained(
        args.pretrained_model_path,
        subfolder="text_encoder",
        revision=model_config["revision"],
    )

    tokenizer = T5Tokenizer.from_pretrained(
        args.pretrained_model_path,
        subfolder="tokenizer",
        revision=model_config["revision"],
    )

    vae = AutoencoderKLCogVideoX.from_pretrained(
        args.pretrained_model_path,
        subfolder="vae",
        revision=model_config["revision"],
        variant=model_config["variant"],
    )

    scheduler = CogVideoXDPMScheduler.from_pretrained(
        args.pretrained_model_path,
        subfolder="scheduler",
    )

    # Reduce peak VRAM during VAE decode.
    vae.enable_slicing()
    vae.enable_tiling()

    # Run inference in bf16.
    weight_dtype = torch.bfloat16
    text_encoder.to(dtype=weight_dtype)
    transformer.to(dtype=weight_dtype)
    vae.to(dtype=weight_dtype)

    pipe = ControlnetCogVideoXPipeline.from_pretrained(
        args.pretrained_model_path,
        tokenizer=tokenizer,
        transformer=transformer,
        text_encoder=text_encoder,
        vae=vae,
        scheduler=scheduler,
        torch_dtype=weight_dtype,
    )

    # "learned" variance types are not usable at inference time; fall back to
    # "fixed_small" when rebuilding the scheduler from its config.
    scheduler_args = {}
    if "variance_type" in pipe.scheduler.config:
        variance_type = pipe.scheduler.config.variance_type
        if variance_type in ["learned", "learned_range"]:
            variance_type = "fixed_small"
        scheduler_args["variance_type"] = variance_type
    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)

    return pipe, model_config

def inference_on_image(pipe, image, interval_key, model_config, args):
    """Run the pipeline on one image; return (processed input frame, output video)."""
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    # Check against None explicitly: `if args.seed` would silently skip seed 0.
    generator = (
        torch.Generator(device=args.device).manual_seed(args.seed)
        if args.seed is not None
        else None
    )

    with torch.autocast(device_type=args.device, dtype=torch.bfloat16, enabled=True):
        batch = convert_to_batch(image, interval_key, (args.video_height, args.video_width))

        # Convert the [-1, 1] tensor back to a uint8 HWC frame for the pipeline.
        frame = batch["blur_img"].permute(0, 2, 3, 1).cpu().numpy()
        frame = (frame + 1.0) * 127.5
        frame = frame.astype(np.uint8)

        pipeline_args = {
            "prompt": "",
            "negative_prompt": "",
            "image": frame,
            "input_intervals": torch.stack([batch["input_interval"]]),
            "output_intervals": torch.stack([batch["output_interval"]]),
            "guidance_scale": model_config["guidance_scale"],
            "use_dynamic_cfg": model_config["use_dynamic_cfg"],
            "height": batch["height"],
            "width": batch["width"],
            "num_frames": torch.tensor([[model_config["max_num_frames"]]]),
            "num_inference_steps": args.num_inference_steps,
        }

        input_image = frame
        num_frames = batch["num_frames"]

        print(f"Running inference for interval {interval_key}...")
        video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]

        # The pipeline always generates max_num_frames; keep only the frames
        # requested for this interval.
        video = video[0:num_frames]

    return input_image, video

def main(args):
    output_path = Path(args.output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    image_path = Path(args.image_path)
    if image_path.is_dir():
        image_paths = sorted(image_path.glob("*.*"))
    else:
        image_paths = [image_path]

    pipe, model_config = load_model(args)
    pipe = pipe.to(args.device)

    for image_path in image_paths:
        image = Image.open(image_path)

        processed_image, video = inference_on_image(pipe, image, "past, present and future", model_config, args)

        vid_output_path = output_path / f"{image_path.stem}.mp4"
        export_to_video(video, vid_output_path, fps=20)

        input_image_output_path = output_path / f"{image_path.stem}_input.png"
        Image.fromarray(processed_image[0]).save(input_image_output_path)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--image_path",
        type=str,
        required=True,
        help="path to an input image or a directory of input images",
    )
    parser.add_argument(
        "--blur2vid_hf_repo_path",
        type=str,
        default="tedlasai/blur2vid",
        help="Hugging Face repo containing the weight files",
    )
    parser.add_argument(
        "--pretrained_model_path",
        type=str,
        default="THUDM/CogVideoX-2b",
        help="repo id or path for the pretrained CogVideoX model",
    )
    parser.add_argument(
        "--model_config_path",
        type=str,
        default="training/configs/outsidephotos.yaml",
        help="path to the model config YAML",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="output/",
        help="output directory",
    )
    parser.add_argument(
        "--video_width",
        type=int,
        default=1280,
        help="video resolution width",
    )
    parser.add_argument(
        "--video_height",
        type=int,
        default=720,
        help="video resolution height",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=50,
        help="number of denoising steps",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="random generator seed",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="inference device",
    )
    args = parser.parse_args()
    main(args)