vaibhavpandeyvpz's picture
Import files for official repo
bef42b6
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import binascii
import os
import os.path as osp
import imageio
import torch
import torchvision
# Names exported by ``from <module> import *``; rand_name is deliberately
# left out and stays a module-internal helper.
__all__ = [
    "cache_video",
    "cache_image",
    "str2bool",
]
def rand_name(length=8, suffix=""):
    """Return a random hex file name, optionally with an extension.

    Args:
        length: Number of random bytes drawn from ``os.urandom``; the stem is
            therefore ``2 * length`` lowercase hex characters.
        suffix: Optional extension. A leading dot is added when missing; an
            empty/falsy value produces a bare stem.

    Returns:
        str: e.g. ``"3f9a0c...e1.mp4"``.
    """
    stem = os.urandom(length).hex()
    if not suffix:
        return stem
    extension = suffix if suffix.startswith(".") else "." + suffix
    return stem + extension
def _to_video_frames(tensor, nrow, normalize, value_range):
    """Tile a batch of clips into a single uint8, channels-last frame sequence.

    Each slice of ``tensor`` along dim 2 is tiled with
    ``torchvision.utils.make_grid`` (which returns a (C, H, W) grid); the grids
    are stacked along a new dim 1 and permuted so frames come first.
    NOTE(review): presumably ``tensor`` is (batch, channels, frames, height,
    width) — confirm against callers.

    Returns:
        A CPU ``torch.uint8`` tensor shaped (frames, height, width, channels).
    """
    clamped = tensor.clamp(min(value_range), max(value_range))
    grids = [
        torchvision.utils.make_grid(
            u, nrow=nrow, normalize=normalize, value_range=value_range
        )
        for u in clamped.unbind(2)
    ]
    # stack -> (C, T, H, W); permute -> (T, H, W, C), one image per frame.
    frames = torch.stack(grids, dim=1).permute(1, 2, 3, 0)
    return (frames * 255).type(torch.uint8).cpu()


def cache_video(
    tensor,
    save_file=None,
    fps=30,
    suffix=".mp4",
    nrow=8,
    normalize=True,
    value_range=(-1, 1),
    retry=5,
):
    """Encode a video tensor to an H.264 file, retrying on failure.

    Args:
        tensor: Video tensor; every slice along dim 2 becomes one frame,
            tiled with ``torchvision.utils.make_grid``.
        save_file: Destination path. When ``None``, a random name ending in
            ``suffix`` is generated inside the system temp directory
            (``tempfile.gettempdir()``, usually ``/tmp`` on Linux).
        fps: Frame rate passed to the imageio writer.
        suffix: Extension for the generated name; ignored when ``save_file``
            is given.
        nrow, normalize, value_range: Forwarded to ``make_grid``;
            ``value_range`` is also used to clamp the input beforehand.
        retry: Maximum number of attempts.

    Returns:
        The written path on success, or ``None`` once every attempt has failed
        (the last error is printed rather than raised).
    """
    import tempfile

    cache_file = (
        osp.join(tempfile.gettempdir(), rand_name(suffix=suffix))
        if save_file is None
        else save_file
    )

    error = None
    frames = None
    for _ in range(retry):
        try:
            # Build frames from the caller's untouched tensor, and only once:
            # a retry after a write failure must not re-process an already
            # permuted uint8 tensor.
            if frames is None:
                frames = _to_video_frames(tensor, nrow, normalize, value_range)
            # The context manager closes the writer even if append_data raises.
            with imageio.get_writer(
                cache_file, fps=fps, codec="libx264", quality=8
            ) as writer:
                for frame in frames.numpy():
                    writer.append_data(frame)
            return cache_file
        except Exception as e:
            error = e
    print(f"cache_video failed, error: {error}", flush=True)
    return None
def cache_image(
    tensor, save_file, nrow=8, normalize=True, value_range=(-1, 1), retry=5
):
    """Save an image tensor (a batch is tiled into a grid) to ``save_file``.

    The input is clamped to ``value_range`` and handed to
    ``torchvision.utils.save_image``, which picks the output format from the
    file extension.

    Args:
        tensor: Image tensor accepted by ``torchvision.utils.save_image``.
        save_file: Destination path.
        nrow, normalize, value_range: Forwarded to ``save_image``;
            ``value_range`` is also used for the clamp.
        retry: Maximum number of attempts.

    Returns:
        ``save_file`` on success, or ``None`` once every attempt has failed
        (the last error is printed, matching ``cache_video``).
    """
    error = None
    for _ in range(retry):
        try:
            # Use a separate name so each retry starts from the caller's data.
            clamped = tensor.clamp(min(value_range), max(value_range))
            torchvision.utils.save_image(
                clamped,
                save_file,
                nrow=nrow,
                normalize=normalize,
                value_range=value_range,
            )
            return save_file
        except Exception as e:
            error = e
    # Previously the failure was swallowed silently; surface the last error.
    print(f"cache_image failed, error: {error}", flush=True)
    return None
def str2bool(v):
    """Interpret a flag-like string as a ``bool``.

    Matching is case-insensitive against these spellings:
      * True:  'yes', 'true', 't', 'y', '1'
      * False: 'no', 'false', 'f', 'n', '0'
    A value that is already a ``bool`` is returned unchanged.

    Args:
        v (str | bool): Value to interpret.

    Returns:
        bool: The parsed truth value.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is a string outside both sets.
    """
    if isinstance(v, bool):
        return v
    truth_table = dict.fromkeys(("yes", "true", "t", "y", "1"), True)
    truth_table.update(dict.fromkeys(("no", "false", "f", "n", "0"), False))
    parsed = truth_table.get(v.lower())
    if parsed is None:
        raise argparse.ArgumentTypeError("Boolean value expected (True/False)")
    return parsed