Update aduc_framework/managers/wan_manager.py
aduc_framework/managers/wan_manager.py
CHANGED
@@ -1,4 +1,5 @@
 # aduc_framework/managers/wan_manager.py
+# WanManager v0.0.1 (beta)
 
 import os
 import tempfile
@@ -17,20 +18,23 @@ from diffusers.utils.export_utils import export_to_video
 
 class WanManager:
     """
-
-
-
-
-
+    WanManager v0.0.1 (beta)
+    - image: first item (fixed weight 1.0) -> latent 0
+    - handle: second item (if present) -> latent 4, with handle_weight from the list
+    - last: last item -> final latent, with anchor_weight_last from the list
+    - Keeps the fused Lightning LoRA, FlowMatch Euler, device_map='auto' and the i2v contract.
+    - Fallback: if the pipeline does not support the new args, falls back to the original call without handle/weights.
     """
 
     MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
 
+    # Dimensions
     MAX_DIMENSION = 832
     MIN_DIMENSION = 480
     DIMENSION_MULTIPLE = 16
     SQUARE_SIZE = 480
 
+    # Video
     FIXED_FPS = 16
     MIN_FRAMES_MODEL = 8
     MAX_FRAMES_MODEL = 81
@@ -44,6 +48,7 @@ class WanManager:
     def __init__(self) -> None:
         print("Loading models into memory. This may take a few minutes...")
 
+        # i2v pipeline with two transformers (high/low noise)
         self.pipe = WanImageToVideoPipeline.from_pretrained(
             self.MODEL_ID,
             transformer=WanTransformer3DModel.from_pretrained(
@@ -60,10 +65,13 @@
             ),
             torch_dtype=torch.bfloat16,
         )
+
+        # Scheduler
         self.pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
             self.pipe.scheduler.config, shift=32.0
         )
 
+        # Lightning LoRA (fused)
         print("Applying 8-step Lightning LoRA...")
         try:
             self.pipe.load_lora_weights(
@@ -83,6 +91,7 @@
             print("Fusing LoRA weights into the main model...")
             self.pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
             self.pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
+
             self.pipe.unload_lora_weights()
             print("Lightning LoRA successfully fused. Model is ready for fast 8-step generation.")
         except Exception as e:
@@ -90,6 +99,8 @@
 
         print("All models loaded. Service is ready.")
 
+    # ===== Utils =====
+
     def process_image_for_video(self, image: Image.Image) -> Image.Image:
         width, height = image.size
         if width == height:
@@ -125,9 +136,11 @@
         left, top = (new_width - ref_width) // 2, (new_height - ref_height) // 2
         return resized.crop((left, top, left + ref_width, top + ref_height))
 
+    # ===== API =====
+
     def generate_video_from_conditions(
         self,
-        images_condition_items: List[List[Any]],
+        images_condition_items: List[List[Any]],  # [[patch(Image), frame(int|str), weight(float)], ...]
         prompt: str,
         negative_prompt: Optional[str],
         duration_seconds: float,
@@ -139,25 +152,37 @@
         output_type: str = "np",
     ) -> Tuple[str, int]:
         """
-        Primeiro item
+        - First item: image (fixed weight 1.0) at latent 0.
+        - Second item (optional): handle at latent 4 with its weight from the list.
+        - Last item: last at the final latent with its weight from the list.
         """
         if not images_condition_items or len(images_condition_items) < 2:
             raise ValueError("Forneça ao menos dois itens (início e fim).")
 
-
-
+        items = images_condition_items
+
+        # image (fixed weight 1.0)
+        start_image = items[0][0]
 
-
-
-
+        # handle (second item, if present)
+        handle_image = items[1][0] if len(items) >= 3 else None
+        handle_weight = float(items[1][2]) if len(items) >= 3 and items[1][2] is not None else 1.0
+
+        # last (always the final item)
+        end_image = items[-1][0]
+        end_weight = float(items[-1][2]) if len(items[-1]) >= 3 and items[-1][2] is not None else 1.0
 
         if start_image is None or end_image is None:
             raise ValueError("As imagens inicial e final não podem ser vazias.")
         if not isinstance(start_image, Image.Image) or not isinstance(end_image, Image.Image):
             raise TypeError("Os 'patches' devem ser PIL.Image.")
+        if handle_image is not None and not isinstance(handle_image, Image.Image):
+            raise TypeError("O 'patch' do handle deve ser PIL.Image.")
 
         processed_start = self.process_image_for_video(start_image)
         processed_end = self.resize_and_crop_to_match(end_image, processed_start)
+        processed_handle = self.resize_and_crop_to_match(handle_image, processed_start) if handle_image is not None else None
+
         target_height, target_width = processed_start.height, processed_start.width
 
         num_frames = int(round(duration_seconds * self.FIXED_FPS))
@@ -167,8 +192,8 @@
         generator = torch.Generator().manual_seed(current_seed)
 
         call_kwargs = dict(
-            image=processed_start,
-            last_image=processed_end,
+            image=processed_start,  # latent 0 (implicit weight 1.0)
+            last_image=processed_end,  # last latent (adjustable weight)
             prompt=prompt,
             negative_prompt=negative_prompt if negative_prompt is not None else self.default_negative_prompt,
             height=target_height,
@@ -181,22 +206,28 @@
             output_type=output_type,
         )
 
-        result = None
         try:
-
-
-
-
-
-
+            if processed_handle is not None:
+                # handle at latent 4 with its list weight; last at the final latent with end_weight
+                result = self.pipe(
+                    **call_kwargs,
+                    handle_image=processed_handle,
+                    handle_weight=float(handle_weight),
+                    handle_latent_index=4,
+                    anchor_weight_last=float(end_weight),
+                )
+            else:
+                # no handle; only the last-frame weight
+                result = self.pipe(
+                    **call_kwargs,
+                    anchor_weight_last=float(end_weight),
+                )
         except TypeError:
-            print("[WanManager]
+            print("[WanManager] handle/anchor args não suportados; usando chamada padrão.")
             result = self.pipe(**call_kwargs)
 
         frames = result.frames[0]
-
         with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            video_path = tmp.name
         export_to_video(frames, video_path, fps=self.FIXED_FPS)
-
         return video_path, current_seed
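For orientation, a minimal usage sketch of the conditioning contract introduced by this change. It is not part of the commit: the image files, prompt, and duration are illustrative, and it assumes the constructor and the keyword parameters omitted from this diff keep their defaults; whether the handle/anchor kwargs are honored depends on the patched pipeline (otherwise the TypeError fallback in the diff runs the standard call).

# Usage sketch (illustrative only; assumes WanManager from this commit is importable
# and that parameters not shown in this diff have defaults).
from PIL import Image

from aduc_framework.managers.wan_manager import WanManager

manager = WanManager()  # loads the Wan 2.2 i2v pipeline and fuses the Lightning LoRA

# Item format: [patch (PIL.Image), frame (int | str), weight (float)]
items = [
    [Image.open("start.png"), 0, 1.0],    # first item -> latent 0, weight fixed at 1.0
    [Image.open("handle.png"), 32, 0.8],  # optional handle -> latent 4, weight from the list
    [Image.open("end.png"), 64, 1.0],     # last item -> final latent, weight from the list
]

video_path, seed = manager.generate_video_from_conditions(
    images_condition_items=items,
    prompt="a slow cinematic pan across the scene",
    negative_prompt=None,   # None falls back to the manager's default negative prompt
    duration_seconds=4.0,   # 4 s * FIXED_FPS (16) = 64 frames, within the 8-81 frame range
)
print(video_path, seed)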