# =============================================================
# 0️⃣ FORCE ALL CACHES TO EPHEMERAL /tmp (NOT COUNTED TOWARD THE 150 GB QUOTA)
# =============================================================
import os, shutil, pathlib
# -----------------------------------------------------------------
# Clean any leftover cache that may already be on the persistent volume.
# This runs **once** at container start, before any import that touches HF.
# -----------------------------------------------------------------
for p in [
pathlib.Path.home() / ".cache",
pathlib.Path("/workspace") / ".cache",
pathlib.Path("/tmp") / "hf_cache",
pathlib.Path("/tmp") / "torch_home",
]:
if p.exists():
shutil.rmtree(p, ignore_errors=True)
# -----------------------------------------------------------------
# Point every HF / torch cache to /tmp (ephemeral storage that is
# NOT counted against the Space's persistent-disk quota).
# -----------------------------------------------------------------
os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_cache"
os.environ["TORCH_HOME"] = "/tmp/torch_home"
# =============================================================
# 1️⃣ IMPORTS
# =============================================================
import spaces
import torch
import numpy as np
import random
import gc
import tempfile
import requests
import logging
from PIL import Image
import gradio as gr
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers.utils.export_utils import export_to_video
from torchao.quantization import quantize_, Int8WeightOnlyConfig, Float8DynamicActivationFloat8WeightConfig
import aoti
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# =============================================================
# 2️⃣ CONFIG
# =============================================================
MAX_DIM = 832
MIN_DIM = 480
SQUARE_DIM = 640
MULTIPLE_OF = 16
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 80
default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
default_negative_prompt = (
"colorful tones, overexposed, static, unclear details, subtitles, style, artwork, painting, screen, "
"still, overall gray, worst quality, low quality, JPEG compression artifacts, ugly, deformed, "
"extra fingers, poorly drawn hands, poorly drawn face, deformed, mutated, deformed limbs, "
"fused fingers, still screen, messy background, three legs, many people in background, walking backwards"
)
# ------------------------------------------------------------
# 3️⃣ TRANSLATOR (Albanian → English)
# ------------------------------------------------------------
def translate_albanian_to_english(text: str) -> str:
if not text.strip():
return text
for attempt in range(2):
try:
response = requests.post(
"https://hal1993-mdftranslation1234567890abcdef1234567890-fc073a6.hf.space/v1/translate",
json={"from_language": "sq", "to_language": "en", "input_text": text},
headers={"accept": "application/json", "Content-Type": "application/json"},
timeout=8,
)
response.raise_for_status()
translated = response.json().get("translate", text)
logger.info(f"Translated: {text[:50]}... → {translated[:50]}...")
return translated.strip() or text
except Exception as e:
logger.warning(f"Translation failed (attempt {attempt + 1}): {e}")
if attempt == 1:
return text
return text
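# Example (hypothetical strings, assuming the translation Space is reachable):
#   translate_albanian_to_english("një qen duke vrapuar")  ->  "a dog running"
# On timeout or HTTP error the original text is returned unchanged, so a dead
# translation endpoint degrades to untranslated prompts instead of failing.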
# ------------------------------------------------------------
# 4️⃣ MODEL LOADING (all caches forced to /tmp)
# ------------------------------------------------------------
pipe = WanImageToVideoPipeline.from_pretrained(
"Wan-AI/Wan2.2-I2V-A14B-Diffusers",
torch_dtype=torch.bfloat16,
cache_dir="/tmp/hf_cache",
).to("cuda")
pipe.transformer = WanTransformer3DModel.from_pretrained(
"cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers",
subfolder="transformer",
torch_dtype=torch.bfloat16,
device_map="cuda",
cache_dir="/tmp/hf_cache",
)
pipe.transformer_2 = WanTransformer3DModel.from_pretrained(
"cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers",
subfolder="transformer_2",
torch_dtype=torch.bfloat16,
device_map="cuda",
cache_dir="/tmp/hf_cache",
)
# ---- LoRA -------------------------------------------------
pipe.load_lora_weights(
"Kijai/WanVideo_comfy",
weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
adapter_name="lightx2v",
)
pipe.load_lora_weights(
"Kijai/WanVideo_comfy",
weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
adapter_name="lightx2v_2",
load_into_transformer_2=True,
)
pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
pipe.unload_lora_weights()
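# Note: fuse_lora() bakes the adapter deltas into the base weights at the given
# scale, so unload_lora_weights() afterwards only drops the now-redundant LoRA
# tensors; the distilled behaviour survives the quantisation applied below.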
# ---- Quantisation & AoT ------------------------------------
quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
aoti.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/Wan2", variant="fp8da")
aoti.aoti_blocks_load(pipe.transformer_2, "zerogpu-aoti/Wan2", variant="fp8da")
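# Optional sanity log (an added sketch, assumes a CUDA device is visible):
# record the VRAM footprint of the quantised pipeline so regressions show up.
if torch.cuda.is_available():
    logger.info(f"VRAM allocated after model load: {torch.cuda.memory_allocated() / 1e9:.1f} GB")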
# ------------------------------------------------------------
# 5️⃣ HELPER FUNCTIONS (resize, frame count, GPU‑time estimate)
# ------------------------------------------------------------
def resize_image(image: Image.Image) -> Image.Image:
"""Resize / crop the input image so the model receives a valid size."""
w, h = image.size
if w == h:
return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)
aspect = w / h
MAX_ASPECT = MAX_DIM / MIN_DIM
MIN_ASPECT = MIN_DIM / MAX_DIM
    img = image
    if aspect > MAX_ASPECT:  # very wide → crop width to the widest legal aspect
        target_w, target_h = MAX_DIM, MIN_DIM
        crop_w = int(round(h * MAX_ASPECT))
        left = (w - crop_w) // 2
        img = image.crop((left, 0, left + crop_w, h))
    elif aspect < MIN_ASPECT:  # very tall → crop height to the tallest legal aspect
        target_w, target_h = MIN_DIM, MAX_DIM
        crop_h = int(round(w / MIN_ASPECT))
        top = (h - crop_h) // 2
        img = image.crop((0, top, w, top + crop_h))
    else:
        if w > h:  # landscape
            target_w = MAX_DIM
            target_h = int(round(target_w / aspect))
        else:  # portrait
            target_h = MAX_DIM
            target_w = int(round(target_h * aspect))
final_w = round(target_w / MULTIPLE_OF) * MULTIPLE_OF
final_h = round(target_h / MULTIPLE_OF) * MULTIPLE_OF
final_w = max(MIN_DIM, min(MAX_DIM, final_w))
final_h = max(MIN_DIM, min(MAX_DIM, final_h))
return img.resize((final_w, final_h), Image.LANCZOS)
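# Quick check (hypothetical input): a 1920x1080 frame has aspect 1.78, wider
# than MAX_ASPECT (832/480 ≈ 1.73), so the width is centre-cropped to 1872 px
# and the result is resized to 832x480; both sides are multiples of 16.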
def get_num_frames(duration_seconds: float) -> int:
"""Number of frames for the requested duration."""
return 1 + int(
np.clip(
int(round(duration_seconds * FIXED_FPS)),
MIN_FRAMES_MODEL,
MAX_FRAMES_MODEL,
)
)
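# Worked example: duration_seconds=3.2 → round(3.2 * 16) = 51, already inside
# [8, 80], plus one → 52 frames, i.e. 3.25 s of video at 16 fps.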
def get_duration(
input_image,
prompt,
steps,
negative_prompt,
duration_seconds,
guidance_scale,
guidance_scale_2,
seed,
randomize_seed,
    progress,  # <- must mirror generate_video's signature for @spaces.GPU
):
"""
Rough estimate of the GPU run‑time.
The @spaces.GPU decorator will cut the job at 30 s.
"""
    BASE = 81 * 832 * 624  # frames × width × height reference from the original demo
    BASE_STEP = 15         # seconds per denoising step at the reference size
w, h = resize_image(input_image).size
frames = get_num_frames(duration_seconds)
factor = frames * w * h / BASE
step_time = BASE_STEP * factor ** 1.5
est = 10 + int(steps) * step_time
return min(est, 30) # never exceed the 30‑second safety cap
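# Worked example (assuming an input that resizes to 832x480): 52 frames gives
# factor = 52*832*480 / BASE ≈ 0.49, step_time ≈ 15 * 0.49**1.5 ≈ 5.2 s, so
# 6 steps estimate to 10 + 6*5.2 ≈ 41 s, which the cap trims to the 30 s budget.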
# ------------------------------------------------------------
# 6️⃣ GENERATION FUNCTION
# ------------------------------------------------------------
@spaces.GPU(duration=get_duration)
def generate_video(
input_image,
prompt_input,
steps=6,
negative_prompt=default_negative_prompt,
duration_seconds=3.2,
guidance_scale=1.5,
guidance_scale_2=1.5,
seed=42,
randomize_seed=False,
progress=gr.Progress(track_tqdm=True),
):
"""Run the model → return a temporary MP4 path and the seed used."""
if input_image is None:
raise gr.Error("Please upload an input image.")
# ---- translate prompt (Albanian → English) -----------------
prompt = translate_albanian_to_english(prompt_input)
# ---- prepare inputs ----------------------------------------
current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
resized = resize_image(input_image)
num_frames = get_num_frames(duration_seconds)
# ---- model inference ----------------------------------------
out = pipe(
image=resized,
prompt=prompt,
negative_prompt=negative_prompt,
height=resized.height,
width=resized.width,
num_frames=num_frames,
guidance_scale=float(guidance_scale),
guidance_scale_2=float(guidance_scale_2),
num_inference_steps=int(steps),
generator=torch.Generator(device="cuda").manual_seed(current_seed),
)
frames = out.frames[0]
# ---- write a temporary MP4 (still inside /tmp) -------------
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp:
video_path = tmp.name
export_to_video(frames, video_path, fps=FIXED_FPS)
    # ---- unload AoT blocks (frees a few GB of memory between runs) ----
aoti.aoti_blocks_unload(pipe.transformer)
aoti.aoti_blocks_unload(pipe.transformer_2)
# ---- GPU cleanup -------------------------------------------
gc.collect()
torch.cuda.empty_cache()
return video_path, current_seed
# ------------------------------------------------------------
# 7️⃣ UI – 100 % identical visual appearance to the original demo
# ------------------------------------------------------------
with gr.Blocks(
css="""
@import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;600;700&display=swap');
@keyframes glow {0%{box-shadow:0 0 14px rgba(0,255,128,0.5);}50%{box-shadow:0 0 14px rgba(0,255,128,0.7);}100%{box-shadow:0 0 14px rgba(0,255,128,0.5);}}
@keyframes glow-hover {0%{box-shadow:0 0 20px rgba(0,255,128,0.7);}50%{box-shadow:0 0 20px rgba(0,255,128,0.9);}100%{box-shadow:0 0 20px rgba(0,255,128,0.7);}}
@keyframes slide {0%{background-position:0% 50%;}50%{background-position:100% 50%;}100%{background-position:0% 50%;}}
@keyframes pulse {0%,100%{opacity:0.7;}50%{opacity:1;}}
body{
background:#000 !important;
color:#FFF !important;
font-family:'Orbitron',sans-serif;
min-height:100vh;
margin:0 !important;
padding:0 !important;
overflow-x:hidden !important;
display:flex !important;
justify-content:center;
align-items:center;
flex-direction:column;
}
body::before{
content:"";
display:block;
    height:600px; /* fixed top gap above the UI */
background:#000 !important;
}
.gr-blocks,.container{
width:100% !important;
max-width:100vw !important;
margin:0 !important;
padding:0 !important;
box-sizing:border-box !important;
overflow-x:hidden !important;
background:#000 !important;
color:#FFF !important;
}
#general_items{
width:100% !important;
max-width:100vw !important;
margin:2rem 0 !important;
display:flex !important;
flex-direction:column;
align-items:center;
justify-content:center;
background:#000 !important;
color:#FFF !important;
}
#input_column{
background:#000 !important;
border:none !important;
border-radius:8px;
padding:1rem !important;
box-shadow:0 0 10px rgba(255,255,255,0.3) !important;
width:100% !important;
max-width:100vw !important;
box-sizing:border-box !important;
color:#FFF !important;
}
h1{
font-size:5rem;
font-weight:700;
text-align:center;
color:#FFF !important;
text-shadow:0 0 8px rgba(255,255,255,0.3) !important;
margin:0 auto .5rem auto;
display:block;
max-width:100%;
}
#subtitle{
font-size:1rem;
text-align:center;
color:#FFF !important;
opacity:0.8;
margin-bottom:1rem;
display:block;
max-width:100%;
}
.gradio-component{
background:#000 !important;
border:none;
margin:.75rem 0;
width:100% !important;
max-width:100vw !important;
color:#FFF !important;
}
.image-container{
aspect-ratio:1/1;
width:100% !important;
max-width:100vw !important;
min-height:500px;
height:auto;
border:0.5px solid #FFF !important;
border-radius:4px;
box-sizing:border-box !important;
background:#000 !important;
box-shadow:0 0 10px rgba(255,255,255,0.3) !important;
position:relative;
color:#FFF !important;
overflow:hidden !important;
}
.image-container img,.image-container video{
width:100% !important;
height:auto;
box-sizing:border-box !important;
display:block !important;
}
/* Hide all Gradio progress UI */
.image-container[aria-label="Generated Video"] .progress-text,
.image-container[aria-label="Generated Video"] .gr-progress,
.image-container[aria-label="Generated Video"] .gr-progress-bar,
.image-container[aria-label="Generated Video"] .progress-bar,
.image-container[aria-label="Generated Video"] [data-testid="progress"],
.image-container[aria-label="Generated Video"] .status,
.image-container[aria-label="Generated Video"] .loading,
.image-container[aria-label="Generated Video"] .spinner,
.image-container[aria-label="Generated Video"] .gr-spinner,
.image-container[aria-label="Generated Video"] .gr-loading,
.image-container[aria-label="Generated Video"] .gr-status,
.image-container[aria-label="Generated Video"] .gpu-init,
.image-container[aria-label="Generated Video"] .initializing,
.image-container[aria-label="Generated Video"] .queue,
.image-container[aria-label="Generated Video"] .queued,
.image-container[aria-label="Generated Video"] .waiting,
.image-container[aria-label="Generated Video"] .processing,
.image-container[aria-label="Generated Video"] .gradio-progress,
.image-container[aria-label="Generated Video"] .gradio-status,
.image-container[aria-label="Generated Video"] div[class*="progress"],
.image-container[aria-label="Generated Video"] div[class*="loading"],
.image-container[aria-label="Generated Video"] div[class*="status"],
.image-container[aria-label="Generated Video"] div[class*="spinner"],
.image-container[aria-label="Generated Video"] *[class*="progress"],
.image-container[aria-label="Generated Video"] *[class*="loading"],
.image-container[aria-label="Generated Video"] *[class*="status"],
.image-container[aria-label="Generated Video"] *[class*="spinner"],
.progress-text,.gr-progress,.gr-progress-bar,.progress-bar,
[data-testid="progress"],.status,.loading,.spinner,.gr-spinner,
.gr-loading,.gr-status,.gpu-init,.initializing,.queue,
.queued,.waiting,.processing,.gradio-progress,.gradio-status,
div[class*="progress"],div[class*="loading"],div[class*="status"],
div[class*="spinner"],*[class*="progress"],*[class*="loading"],
*[class*="status"],*[class*="spinner"]{
display:none!important;
visibility:hidden!important;
opacity:0!important;
height:0!important;
width:0!important;
font-size:0!important;
line-height:0!important;
padding:0!important;
margin:0!important;
position:absolute!important;
left:-9999px!important;
top:-9999px!important;
z-index:-9999!important;
pointer-events:none!important;
overflow:hidden!important;
}
/* Toolbar hiding */
.image-container[aria-label="Input Image"] .file-upload,
.image-container[aria-label="Input Image"] .file-preview,
.image-container[aria-label="Input Image"] .image-actions,
.image-container[aria-label="Generated Video"] .file-upload,
.image-container[aria-label="Generated Video"] .file-preview,
.image-container[aria-label="Generated Video"] .image-actions{
display:none!important;
}
.image-container[aria-label="Generated Video"].processing{
background:#000!important;
position:relative;
}
.image-container[aria-label="Generated Video"].processing::before{
content:"PROCESSING...";
position:absolute!important;
top:50%!important;
left:50%!important;
transform:translate(-50%,-50%)!important;
color:#FFF;
font-family:'Orbitron',sans-serif;
font-size:1.8rem!important;
font-weight:700!important;
text-align:center;
text-shadow:0 0 10px rgba(0,255,128,0.8)!important;
animation:pulse 1.5s ease-in-out infinite,glow 2s ease-in-out infinite!important;
z-index:9999!important;
width:100%!important;
height:100%!important;
display:flex!important;
align-items:center!important;
justify-content:center!important;
pointer-events:none!important;
background:#000!important;
border-radius:4px!important;
box-sizing:border-box!important;
}
.image-container[aria-label="Generated Video"].processing *{
display:none!important;
}
input,textarea,.gr-dropdown,.gr-dropdown select{
background:#000!important;
color:#FFF!important;
border:1px solid #FFF!important;
border-radius:4px;
padding:.5rem;
width:100%!important;
max-width:100vw!important;
box-sizing:border-box!important;
}
.gr-button-primary{
background:linear-gradient(90deg,rgba(0,255,128,0.3),rgba(0,200,100,0.3),rgba(0,255,128,0.3))!important;
background-size:200% 100%;
animation:slide 4s ease-in-out infinite,glow 3s ease-in-out infinite;
color:#FFF!important;
border:1px solid #FFF!important;
border-radius:6px;
padding:.75rem 1.5rem;
font-size:1.1rem;
font-weight:600;
box-shadow:0 0 14px rgba(0,255,128,0.7)!important;
transition:box-shadow .3s,transform .3s;
width:100%!important;
max-width:100vw!important;
min-height:48px;
cursor:pointer;
}
.gr-button-primary:hover{
box-shadow:0 0 20px rgba(0,255,128,0.9)!important;
animation:slide 4s ease-in-out infinite,glow-hover 3s ease-in-out infinite;
transform:scale(1.05);
}
button[aria-label="Fullscreen"],button[aria-label="Share"]{
display:none!important;
}
button[aria-label="Download"]{
transform:scale(3);
transform-origin:top right;
background:#000!important;
color:#FFF!important;
border:1px solid #FFF!important;
border-radius:4px;
padding:.4rem!important;
margin:.5rem!important;
box-shadow:0 0 8px rgba(255,255,255,0.3)!important;
transition:box-shadow .3s;
}
button[aria-label="Download"]:hover{
box-shadow:0 0 12px rgba(255,255,255,0.5)!important;
}
footer,.gr-button-secondary{
display:none!important;
}
.gr-group{
background:#000!important;
border:none!important;
width:100% !important;
max-width:100vw !important;
}
@media (max-width:768px){
h1{font-size:4rem;}
#subtitle{font-size:.9rem;}
.gr-button-primary{
padding:.6rem 1rem;
font-size:1rem;
box-shadow:0 0 10px rgba(0,255,128,0.7)!important;
}
.gr-button-primary:hover{
box-shadow:0 0 12px rgba(0,255,128,0.9)!important;
}
.image-container{min-height:300px;}
.image-container[aria-label="Generated Video"].processing::before{
font-size:1.2rem!important;
}
}
""",
title="Fast Image to Video"
) as demo:
# -------------------------------------------------
    # 500-ERROR GUARD – serve the UI only on the unique obfuscated path
# -------------------------------------------------
gr.HTML("""
<script>
if (!window.location.pathname.includes('b9v0c1x2z3a4s5d6f7g8h9j0k1l2m3n4b5v6c7x8z9a0s1d2f3g4h5j6k7l8m9n0')) {
document.body.innerHTML = '<h1 style="color:#ef4444;font-family:Orbitron,sans-serif;text-align:center;margin-top:300px;">500 Internal Server Error</h1>';
throw new Error('Access denied');
}
</script>
""")
# -------------------------------------------------
# UI layout – identical visual hierarchy
# -------------------------------------------------
with gr.Row(elem_id="general_items"):
gr.Markdown("# ")
gr.Markdown(
"Convert an image into an animated video with prompt description.",
elem_id="subtitle",
)
with gr.Column(elem_id="input_column"):
input_image = gr.Image(
type="pil",
label="Input Image",
sources=["upload"],
show_download_button=False,
show_share_button=False,
interactive=True,
elem_classes=["gradio-component", "image-container"],
)
prompt = gr.Textbox(
label="Prompt",
value=default_prompt_i2v,
lines=3,
placeholder="Describe the desired animation or motion",
elem_classes=["gradio-component"],
)
generate_btn = gr.Button(
"Generate Video",
variant="primary",
elem_classes=["gradio-component", "gr-button-primary"],
)
output_video = gr.Video(
label="Generated Video",
autoplay=True,
interactive=False,
show_download_button=True,
show_share_button=False,
elem_classes=["gradio-component", "image-container"],
)
# -------------------------------------------------
# Wiring – order must match generate_video signature
# -------------------------------------------------
    # Note: Gradio's click() accepts only a boolean `postprocess` flag, not a
    # callback, so per-request temp-file cleanup cannot be hooked in here; the
    # MP4s live in ephemeral /tmp and are discarded when the Space restarts.
    generate_btn.click(
        fn=generate_video,
        inputs=[
            input_image,
            prompt,
            gr.State(value=6),                        # steps
            gr.State(value=default_negative_prompt),  # negative_prompt
            gr.State(value=3.2),                      # duration_seconds
            gr.State(value=1.5),                      # guidance_scale
            gr.State(value=1.5),                      # guidance_scale_2
            gr.State(value=42),                       # seed
            gr.State(value=True),                     # randomize_seed
            # progress is injected by Gradio via the gr.Progress default – do NOT pass it here
        ],
        outputs=[output_video, gr.State(value=42)],   # video path + seed actually used
    )
# ------------------------------------------------------------
# 8️⃣ MAIN
# ------------------------------------------------------------
if __name__ == "__main__":
demo.queue().launch(share=True)