x2XcarleX2x committed
Commit c23fa4c · verified · 1 Parent(s): bcea26b

Update aduc_framework/managers/wan_manager.py
aduc_framework/managers/wan_manager.py CHANGED
@@ -1,4 +1,5 @@
 # aduc_framework/managers/wan_manager.py
+# WanManager v0.0.1 (beta)
 
 import os
 import tempfile
@@ -17,20 +18,23 @@ from diffusers.utils.export_utils import export_to_video
 
 class WanManager:
     """
-    Encapsulates:
-    - Wan I2V pipeline with two transformers (high/low noise).
-    - Lightning LoRA fusion (fast 8-step).
-    - Preprocessing and generation from images_condition_items.
-    - Anchoring of last_image at latent index 4 with adjustable weight (if supported).
+    WanManager v0.0.1 (beta)
+    - image: first item (fixed weight 1.0) -> latent 0
+    - handle: second item (if present) -> latent 4, with handle_weight from the list
+    - last: last item -> last latent, with anchor_weight_last from the list
+    - Keeps the fused Lightning LoRA, FlowMatch Euler, device_map='auto' and the i2v contract.
+    - Fallback: if the pipeline does not support the new args, it calls the original API without handle/weights.
     """
 
     MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
 
+    # Dimensions
     MAX_DIMENSION = 832
     MIN_DIMENSION = 480
     DIMENSION_MULTIPLE = 16
     SQUARE_SIZE = 480
 
+    # Video
     FIXED_FPS = 16
     MIN_FRAMES_MODEL = 8
     MAX_FRAMES_MODEL = 81
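
The new docstring defines a positional contract for images_condition_items. As a concrete illustration, assuming hypothetical PIL images img_start, img_handle and img_end (the middle frame field is the int|str tag carried by each item):

    # Hypothetical item list matching the documented contract.
    items = [
        [img_start, 0, None],   # image: weight fixed at 1.0 -> latent 0
        [img_handle, 24, 0.7],  # handle: weight 0.7 -> latent index 4
        [img_end, 80, 0.9],     # last: weight 0.9 -> final latent
    ]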
@@ -44,6 +48,7 @@ class WanManager:
     def __init__(self) -> None:
         print("Loading models into memory. This may take a few minutes...")
 
+        # i2v pipeline with two transformers (high/low noise)
        self.pipe = WanImageToVideoPipeline.from_pretrained(
            self.MODEL_ID,
            transformer=WanTransformer3DModel.from_pretrained(
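
This hunk is truncated inside the from_pretrained call; file lines 50-64 are elided by the diff context. As a hedged reconstruction only, assuming the standard Wan2.2 Diffusers layout with transformer and transformer_2 subfolders for the high- and low-noise experts (the device_map='auto' mentioned in the docstring would also sit in this elided region):

    # Assumed shape of the elided load, not the committed code.
    self.pipe = WanImageToVideoPipeline.from_pretrained(
        self.MODEL_ID,
        transformer=WanTransformer3DModel.from_pretrained(
            self.MODEL_ID, subfolder="transformer", torch_dtype=torch.bfloat16
        ),
        transformer_2=WanTransformer3DModel.from_pretrained(
            self.MODEL_ID, subfolder="transformer_2", torch_dtype=torch.bfloat16
        ),
        torch_dtype=torch.bfloat16,
    )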
@@ -60,10 +65,13 @@ class WanManager:
             ),
             torch_dtype=torch.bfloat16,
         )
+
+        # Scheduler
         self.pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
             self.pipe.scheduler.config, shift=32.0
         )
 
+        # Lightning LoRA (fused)
         print("Applying 8-step Lightning LoRA...")
         try:
             self.pipe.load_lora_weights(
@@ -83,6 +91,7 @@ class WanManager:
             print("Fusing LoRA weights into the main model...")
             self.pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
             self.pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
+
             self.pipe.unload_lora_weights()
             print("Lightning LoRA successfully fused. Model is ready for fast 8-step generation.")
         except Exception as e:
@@ -90,6 +99,8 @@ class WanManager:
 
         print("All models loaded. Service is ready.")
 
+    # ===== Utils =====
+
     def process_image_for_video(self, image: Image.Image) -> Image.Image:
         width, height = image.size
         if width == height:
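
Most of process_image_for_video sits outside the diff context, but the class constants pin down its intent: clamp each side into [MIN_DIMENSION, MAX_DIMENSION], snap it to a multiple of DIMENSION_MULTIPLE, and send square inputs to SQUARE_SIZE. A minimal sketch of that snapping, as an assumption about the elided logic rather than the committed code:

    # Hedged sketch of the dimension snapping implied by the constants.
    def snap_dimension(value: int) -> int:
        value = max(WanManager.MIN_DIMENSION, min(WanManager.MAX_DIMENSION, value))
        return (value // WanManager.DIMENSION_MULTIPLE) * WanManager.DIMENSION_MULTIPLE

For example, snap_dimension(700) returns 688, while the if width == height branch shown above presumably resizes to 480x480.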
@@ -125,9 +136,11 @@ class WanManager:
         left, top = (new_width - ref_width) // 2, (new_height - ref_height) // 2
         return resized.crop((left, top, left + ref_width, top + ref_height))
 
+    # ===== API =====
+
     def generate_video_from_conditions(
         self,
-        images_condition_items: List[List[Any]],
+        images_condition_items: List[List[Any]],  # [[patch(Image), frame(int|str), weight(float)], ...]
         prompt: str,
         negative_prompt: Optional[str],
         duration_seconds: float,
@@ -139,25 +152,37 @@ class WanManager:
         output_type: str = "np",
     ) -> Tuple[str, int]:
         """
-        First item = image; last item = last_image; last_image is anchored at latent index 4 with weight end_weight in [0, 1] (if supported).
+        - First item: image (fixed weight 1.0) at latent 0.
+        - Second item (optional): handle at latent index 4 with the weight from the list.
+        - Last item: last at the final latent with the weight from the list.
         """
         if not images_condition_items or len(images_condition_items) < 2:
             raise ValueError("Provide at least two items (start and end).")
 
-        first_item = images_condition_items[0]
-        last_item = images_condition_items[-1]
+        items = images_condition_items
+
+        # image (fixed weight 1.0)
+        start_image = items[0][0]
 
-        start_image = first_item[0]
-        end_image = last_item[0]
-        end_weight = float(last_item[2]) if len(last_item) >= 3 and last_item[2] is not None else 1.0
+        # handle (second item, if present)
+        handle_image = items[1][0] if len(items) >= 3 else None
+        handle_weight = float(items[1][2]) if len(items) >= 3 and items[1][2] is not None else 1.0
+
+        # last (always the final item)
+        end_image = items[-1][0]
+        end_weight = float(items[-1][2]) if len(items[-1]) >= 3 and items[-1][2] is not None else 1.0
 
         if start_image is None or end_image is None:
             raise ValueError("The start and end images cannot be empty.")
         if not isinstance(start_image, Image.Image) or not isinstance(end_image, Image.Image):
             raise TypeError("The 'patches' must be PIL.Image.")
+        if handle_image is not None and not isinstance(handle_image, Image.Image):
+            raise TypeError("The handle 'patch' must be PIL.Image.")
 
         processed_start = self.process_image_for_video(start_image)
         processed_end = self.resize_and_crop_to_match(end_image, processed_start)
+        processed_handle = self.resize_and_crop_to_match(handle_image, processed_start) if handle_image is not None else None
+
         target_height, target_width = processed_start.height, processed_start.width
 
         num_frames = int(round(duration_seconds * self.FIXED_FPS))
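
Tracing the parsing above with a hypothetical three-item list shows how the defaults resolve:

    # Hypothetical inputs: [patch, frame, weight] per item.
    items = [[img_a, 0, None], [img_b, 24, 0.7], [img_c, 80, 0.9]]
    # start_image  = img_a            (weight fixed at 1.0)
    # handle_image = img_b, handle_weight = 0.7
    # end_image    = img_c, end_weight    = 0.9
    # With exactly two items, len(items) >= 3 fails, so handle_image stays None.

At FIXED_FPS = 16, duration_seconds = 3.0 yields num_frames = 48, presumably clamped into [MIN_FRAMES_MODEL, MAX_FRAMES_MODEL] = [8, 81] in the lines elided between this hunk and the next.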
@@ -167,8 +192,8 @@ class WanManager:
         generator = torch.Generator().manual_seed(current_seed)
 
         call_kwargs = dict(
-            image=processed_start,
-            last_image=processed_end,
+            image=processed_start,     # latent 0 (implicit weight 1.0)
+            last_image=processed_end,  # last latent (adjustable weight)
             prompt=prompt,
             negative_prompt=negative_prompt if negative_prompt is not None else self.default_negative_prompt,
             height=target_height,
@@ -181,22 +206,28 @@ class WanManager:
             output_type=output_type,
         )
 
-        result = None
         try:
-            # Anchor at latent index 4 with weight end_weight
-            result = self.pipe(
-                **call_kwargs,
-                anchor_weight_last=float(end_weight),
-                anchor_latent_index=4,
-            )
+            if processed_handle is not None:
+                # handle at latent index 4 with the list weight; last at the final latent with end_weight
+                result = self.pipe(
+                    **call_kwargs,
+                    handle_image=processed_handle,
+                    handle_weight=float(handle_weight),
+                    handle_latent_index=4,
+                    anchor_weight_last=float(end_weight),
+                )
+            else:
+                # no handle; only the last-image weight
+                result = self.pipe(
+                    **call_kwargs,
+                    anchor_weight_last=float(end_weight),
+                )
         except TypeError:
-            print("[WanManager] anchor_latent_index/anchor_weight_last not supported; using the standard call.")
+            print("[WanManager] handle/anchor args not supported; using the standard call.")
             result = self.pipe(**call_kwargs)
 
         frames = result.frames[0]
-
         with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
             video_path = tmp.name
             export_to_video(frames, video_path, fps=self.FIXED_FPS)
-
         return video_path, current_seed
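
The try/except TypeError acts as a capability probe: handle_image, handle_weight, handle_latent_index and anchor_weight_last are custom extensions rather than stock diffusers WanImageToVideoPipeline arguments, so an unpatched pipeline raises TypeError for the unknown keywords and the call degrades to the plain image/last_image API. A minimal end-to-end sketch, assuming hypothetical input paths and defaults for every parameter elided from this diff:

    # Hypothetical usage; downloads Wan-AI/Wan2.2-I2V-A14B-Diffusers on first run.
    from PIL import Image

    manager = WanManager()
    img_a = Image.open("start.png").convert("RGB")  # placeholder path
    img_b = Image.open("end.png").convert("RGB")    # placeholder path

    video_path, seed = manager.generate_video_from_conditions(
        images_condition_items=[[img_a, 0, None], [img_b, 64, 0.9]],
        prompt="a slow cinematic pan",
        negative_prompt=None,   # falls back to default_negative_prompt
        duration_seconds=4.0,   # 64 frames at FIXED_FPS = 16
    )
    print(video_path, seed)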
 