Update app_more_lora.py
app_more_lora.py (+51 -54)
@@ -1,70 +1,62 @@
+import types
+import random
+import spaces
+
 import torch
-
+import numpy as np
+from diffusers import AutoencoderKLWan, UniPCMultistepScheduler
 from diffusers.utils import export_to_video
-from transformers import CLIPVisionModel
 import gradio as gr
 import tempfile
-import spaces
 from huggingface_hub import hf_hub_download
-
-from
-import
+
+from src.pipeline_wan_nag import NAGWanPipeline
+from src.transformer_wan_nag import NagWanTransformer3DModel
+
+
+MOD_VALUE = 32
+DEFAULT_DURATION_SECONDS = 4
+DEFAULT_STEPS = 4
+DEFAULT_SEED = 2025
+DEFAULT_H_SLIDER_VALUE = 480
+DEFAULT_W_SLIDER_VALUE = 832
+NEW_FORMULA_MAX_AREA = 480.0 * 832.0
+
+SLIDER_MIN_H, SLIDER_MAX_H = 128, 896
+SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
+MAX_SEED = np.iinfo(np.int32).max
+
+FIXED_FPS = 16
+MIN_FRAMES_MODEL = 8
+MAX_FRAMES_MODEL = 81
 
 DEFAULT_NAG_NEGATIVE_PROMPT = "Static, motionless, still, ugly, bad quality, worst quality, poorly drawn, low resolution, blurry, lack of details"
-# Base MODEL_ID (using original Wan model that's compatible with diffusers)
-MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
 
-
+
+MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
 LORA_REPO_ID = "Kijai/WanVideo_comfy"
 LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
 
-# Additional enhancement LoRAs for FusionX-like quality
-ACCVIDEO_LORA_REPO = "alibaba-pai/Wan2.1-Fun-Reward-LoRAs"
-MPS_LORA_FILENAME = "Wan2.1-MPS-Reward-LoRA.safetensors"
-
-# Load enhanced model components
-print("🚀 Loading FusionX Enhanced Wan2.1 I2V Model...")
-image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32)
 vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
-pipe =
-    MODEL_ID, vae=vae,
+pipe = NAGWanPipeline.from_pretrained(
+    MODEL_ID, vae=vae, torch_dtype=torch.bfloat16
 )
-
-# FusionX optimized scheduler settings
-pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)
 pipe.to("cuda")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-try:
-    # Load MPS Rewards LoRA (strength 0.7 as per FusionX)
-    mps_path = hf_hub_download(repo_id=ACCVIDEO_LORA_REPO, filename=MPS_LORA_FILENAME)
-    pipe.load_lora_weights(mps_path, adapter_name="mps_lora")
-    lora_adapters.append("mps_lora")
-    lora_weights.append(0.7)  # FusionX uses 0.7 for MPS
-    print("✅ MPS Rewards LoRA loaded (strength: 0.7)")
-except Exception as e:
-    print(f"⚠️ MPS LoRA not loaded: {e}")
-
-# Apply LoRA adapters if any were loaded
-if lora_adapters:
-    pipe.set_adapters(lora_adapters, adapter_weights=lora_weights)
-    pipe.fuse_lora()
-    print(f"🔥 FusionX Enhancement Applied: {len(lora_adapters)} LoRAs fused")
-else:
-    print("📝 No LoRAs loaded - using base Wan model")
+causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
+pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
+pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
+for name, param in pipe.transformer.named_parameters():
+    if "lora_B" in name:
+        if "blocks.0" in name:
+            param.data = param.data * 0.25
+pipe.fuse_lora()
+pipe.unload_lora_weights()
+
+pipe.transformer.__class__.attn_processors = NagWanTransformer3DModel.attn_processors
+pipe.transformer.__class__.set_attn_processor = NagWanTransformer3DModel.set_attn_processor
+pipe.transformer.__class__.forward = NagWanTransformer3DModel.forward
 
 examples = [
     ["A ginger cat passionately plays eletric guitar with intensity and emotion on a stage. The background is shrouded in deep darkness. Spotlights casts dramatic shadows.", DEFAULT_NAG_NEGATIVE_PROMPT, 11],
@@ -157,9 +149,14 @@ def generate_video_with_example(
 
 
 with gr.Blocks() as demo:
-    gr.Markdown('''# Normalized Attention Guidance (NAG) for fast 4 steps Wan2.1-T2V-14B with
+    gr.Markdown('''# Normalized Attention Guidance (NAG) for fast 4 steps Wan2.1-T2V-14B with CausVid LoRA
+NAG demos: [LTX Video Fast](https://huggingface.co/spaces/ChenDY/NAG_ltx-video-distilled), [FLUX.1-dev](https://huggingface.co/spaces/ChenDY/NAG_FLUX.1-dev), [FLUX.1-schnell](https://huggingface.co/spaces/ChenDY/NAG_FLUX.1-schnell)
+
 Implementation of [Normalized Attention Guidance](https://chendaryen.github.io/NAG.github.io/).
-
+
+[Paper](https://arxiv.org/abs/2505.21179), [GitHub](https://github.com/ChenDarYen/Normalized-Attention-Guidance), [ComfyUI](https://github.com/ChenDarYen/ComfyUI-NAG)
+
+[CausVid](https://github.com/tianweiy/CausVid) is a distilled version of Wan2.1 to run faster in just 4-8 steps, [extracted as LoRA by Kijai](https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_CausVid_14B_T2V_lora_rank32.safetensors).
 ''')
 
     with gr.Row():
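
For orientation, a minimal, hypothetical sketch of how the reconfigured pipeline might be driven at inference time. The generation function itself is outside the hunks shown above, so everything below is an assumption rather than the committed implementation: in particular the `nag_negative_prompt` and `nag_scale` arguments of `NAGWanPipeline`, the interpretation of the third value in the `examples` rows, and the frame/size arithmetic built from the new constants.

# Hypothetical usage sketch (not part of this commit). Assumes it runs in the same
# module after the setup above, and that NAGWanPipeline.__call__ accepts the
# NAG-specific arguments shown here.
import torch

prompt = "A ginger cat passionately plays electric guitar on a stage."
duration_seconds = DEFAULT_DURATION_SECONDS

# Snap the target resolution to the model stride and clamp the frame count.
height = max(MOD_VALUE, (DEFAULT_H_SLIDER_VALUE // MOD_VALUE) * MOD_VALUE)
width = max(MOD_VALUE, (DEFAULT_W_SLIDER_VALUE // MOD_VALUE) * MOD_VALUE)
num_frames = int(round(duration_seconds * FIXED_FPS))
num_frames = max(MIN_FRAMES_MODEL, min(MAX_FRAMES_MODEL, num_frames))

output = pipe(
    prompt=prompt,
    nag_negative_prompt=DEFAULT_NAG_NEGATIVE_PROMPT,  # assumed NAG-specific argument
    nag_scale=11.0,                                   # assumed; value taken from the examples rows
    height=height,
    width=width,
    num_frames=num_frames,
    num_inference_steps=DEFAULT_STEPS,
    guidance_scale=1.0,  # assumed; the CausVid-distilled model is normally run without real CFG
    generator=torch.Generator(device="cuda").manual_seed(DEFAULT_SEED),
)
export_to_video(output.frames[0], "output.mp4", fps=FIXED_FPS)

In the Space itself, these values presumably come from the Gradio controls that the new slider and frame constants parameterize.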
|