eeuuia committed on
Commit
3a201e7
·
verified ·
1 Parent(s): b8a0748

Update api/ltx_server_refactored_complete.py

Browse files
Files changed (1) hide show
  1. api/ltx_server_refactored_complete.py +76 -15
api/ltx_server_refactored_complete.py CHANGED
@@ -54,9 +54,16 @@ add_deps_to_path()
54
 
55
  # --- PROJECT IMPORTS ---
56
  try:
 
 
 
 
 
 
 
57
  from api.gpu_manager import gpu_manager
58
  from ltx_video.models.autoencoders.vae_encode import (normalize_latents, un_normalize_latents)
59
- from ltx_video.pipelines.pipeline_ltx_video import (ConditioningItem, LTXMultiScalePipeline, adain_filter_latent, create_latent_upsampler, create_ltx_video_pipeline)
60
  from ltx_video.utils.inference_utils import load_image_to_tensor_with_resize_and_crop
61
  from managers.vae_manager import vae_manager_singleton
62
  from tools.video_encode_tool import video_encode_tool_singleton
@@ -158,28 +165,82 @@ class VideoService:
158
  with open(config_path, "r") as file:
159
  return yaml.safe_load(file)
160
 
161
- def _load_models(self) -> Tuple[LTXMultiScalePipeline, Optional[torch.nn.Module]]:
162
- """Loads models from cache to CPU."""
 
 
 
 
163
  t0 = time.perf_counter()
164
- logging.info("Loading LTX models from cache to CPU...")
165
-
166
- pipeline = create_ltx_video_pipeline(
167
- ckpt_path=self.config["checkpoint_path"],
168
- precision=self.config["precision"],
169
- text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
170
- sampler=self.config["sampler"],
171
- device="cpu",
172
- enhance_prompt=False,
173
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  latent_upsampler = None
176
  if self.config.get("spatial_upscaler_model_path"):
 
177
  spatial_path = self.config["spatial_upscaler_model_path"]
178
  latent_upsampler = create_latent_upsampler(spatial_path, device="cpu")
 
 
179
 
180
- logging.info(f"Models loaded on CPU in {time.perf_counter()-t0:.2f}s")
181
  return pipeline, latent_upsampler
182
-
 
183
  def move_to_device(self, main_device_str: str, vae_device_str: str):
184
  """Moves pipeline components to their target devices."""
185
  target_main_device = torch.device(main_device_str)
 
54
 
55
  # --- PROJECT IMPORTS ---
56
  try:
57
+ from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline, create_latent_upsampler # E outros...
58
+ from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
59
+ from ltx_video.models.transformers.transformer3d import Transformer3DModel
60
+ from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
61
+ from ltx_video.schedulers.rf import RectifiedFlowScheduler
62
+ from transformers import T5EncoderModel, T5Tokenizer
63
+ from safetensors import safe_open
64
  from api.gpu_manager import gpu_manager
65
  from ltx_video.models.autoencoders.vae_encode import (normalize_latents, un_normalize_latents)
66
+ from ltx_video.pipelines.pipeline_ltx_video import (ConditioningItem, LTXMultiScalePipeline, adain_filter_latent, create_latent_upsampler)
67
  from ltx_video.utils.inference_utils import load_image_to_tensor_with_resize_and_crop
68
  from managers.vae_manager import vae_manager_singleton
69
  from tools.video_encode_tool import video_encode_tool_singleton
 
165
  with open(config_path, "r") as file:
166
  return yaml.safe_load(file)
167
 
168
+ def _load_models(self) -> Tuple[LTXVideoPipeline, Optional[torch.nn.Module]]:
169
+ """
170
+ Loads every sub-model of the pipeline onto the CPU.
171
+ This function replaces the need to call the external `create_ltx_video_pipeline`,
172
+ giving us full control over the process.
173
+ """
174
  t0 = time.perf_counter()
175
+ logging.info("Carregando sub-modelos do LTX para a CPU...")
176
+
177
+ ckpt_path = Path(self.config["checkpoint_path"])
178
+ if not ckpt_path.is_file():
179
+ raise FileNotFoundError(f"Arquivo de checkpoint principal não encontrado em: {ckpt_path}")
180
+
181
+ # 1. Load the checkpoint metadata
182
+ with safe_open(ckpt_path, framework="pt") as f:
183
+ metadata = f.metadata() or {}
184
+ config_str = metadata.get("config", "{}")
185
+ configs = json.loads(config_str)
186
+ allowed_inference_steps = configs.get("allowed_inference_steps")
187
+
188
+ # 2. Load the individual components (all on the CPU)
189
+ # NOTE(review): `.from_pretrained(ckpt_path)` is expected to pick the matching weights out of the .safetensors file - confirm against the ltx_video loaders.
190
+ logging.info("Carregando VAE...")
191
+ vae = CausalVideoAutoencoder.from_pretrained(ckpt_path).to("cpu")
192
+
193
+ logging.info("Carregando Transformer...")
194
+ transformer = Transformer3DModel.from_pretrained(ckpt_path).to("cpu")
195
+
196
+ logging.info("Carregando Scheduler...")
197
+ scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
198
+
199
+ logging.info("Carregando Text Encoder e Tokenizer...")
200
+ text_encoder_path = self.config["text_encoder_model_name_or_path"]
201
+ text_encoder = T5EncoderModel.from_pretrained(text_encoder_path, subfolder="text_encoder").to("cpu")
202
+ tokenizer = T5Tokenizer.from_pretrained(text_encoder_path, subfolder="tokenizer")
203
+
204
+ patchifier = SymmetricPatchifier(patch_size=1)
205
+
206
+ # 3. Cast the models to the configured precision while they are still on the CPU
207
+ precision = self.config.get("precision", "bfloat16")
208
+ if precision == "bfloat16":
209
+ vae.to(torch.bfloat16)
210
+ transformer.to(torch.bfloat16)
211
+ text_encoder.to(torch.bfloat16)
212
 
213
+ # 4. Assemble the pipeline object from the loaded components
214
+ logging.info("Montando o objeto LTXVideoPipeline...")
215
+ submodel_dict = {
216
+ "transformer": transformer,
217
+ "patchifier": patchifier,
218
+ "text_encoder": text_encoder,
219
+ "tokenizer": tokenizer,
220
+ "scheduler": scheduler,
221
+ "vae": vae,
222
+ "allowed_inference_steps": allowed_inference_steps,
223
+ # The prompt enhancers are optional and are not loaded by default, to save memory
224
+ "prompt_enhancer_image_caption_model": None,
225
+ "prompt_enhancer_image_caption_processor": None,
226
+ "prompt_enhancer_llm_model": None,
227
+ "prompt_enhancer_llm_tokenizer": None,
228
+ }
229
+ pipeline = LTXVideoPipeline(**submodel_dict)
230
+
231
+ # 5. Load the latent upsampler (also on the CPU)
232
  latent_upsampler = None
233
  if self.config.get("spatial_upscaler_model_path"):
234
+ logging.info("Carregando Latent Upsampler...")
235
  spatial_path = self.config["spatial_upscaler_model_path"]
236
  latent_upsampler = create_latent_upsampler(spatial_path, device="cpu")
237
+ if precision == "bfloat16":
238
+ latent_upsampler.to(torch.bfloat16)
239
 
240
+ logging.info(f"Modelos LTX carregados na CPU em {time.perf_counter()-t0:.2f}s")
241
  return pipeline, latent_upsampler
242
+
243
+
244
  def move_to_device(self, main_device_str: str, vae_device_str: str):
245
  """Moves pipeline components to their target devices."""
246
  target_main_device = torch.device(main_device_str)