eeuuia committed
Commit df3a6b5 · verified · 1 Parent(s): f735b5a

Update app.py

Files changed (1)
  1. app.py +8 -62
app.py CHANGED
@@ -4,8 +4,6 @@ import numpy as np
 import tempfile
 import os
 from torchvision import transforms
-from safetensors import safe_open
-from diffusers.models.autoencoders import AutoencoderKLLTXVideo

 from diffusers import LTXLatentUpsamplePipeline
 #from pipeline_ltx_condition_control import LTXConditionPipeline, LTXVideoCondition
@@ -19,13 +17,13 @@ import cv2
 import shutil
 import glob
 from pathlib import Path
-import json
+
 import warnings
 import logging
 warnings.filterwarnings("ignore", category=UserWarning)
 warnings.filterwarnings("ignore", category=FutureWarning)
 warnings.filterwarnings("ignore", message=".*")
-from huggingface_hub import hf_hub_download, logging as ll
+from huggingface_hub import logging as ll
 ll.set_verbosity_error()
 ll.set_verbosity_warning()
 ll.set_verbosity_info()
@@ -39,66 +37,14 @@ dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"

 # Pipeline loading
-#pipeline = LTXConditionPipeline.from_pretrained(
-#    "Lightricks/LTX-Video-0.9.8-13B-distilled",
-#    offload_state_dict=False,
-#    torch_dtype=torch.bfloat16,
-#    cache_dir=os.getenv("HF_HOME_CACHE"),
-#    token=os.getenv("HF_TOKEN"),
-#)
-
-from huggingface_hub import hf_hub_download
-from safetensors.torch import load_file as safe_load
-
-# Download exactly the desired variant from the official repo:
-weight_path = hf_hub_download(
-    repo_id="Lightricks/LTX-Video",
-    filename="ltxv-13b-0.9.8-distilled-fp8.safetensors",
-    revision=os.getenv("LTXV_REVISION", "8984fa25007f376c1a299016d0957a37a2f797bb")
+pipeline = LTXConditionPipeline.from_pretrained(
+    "Lightricks/LTX-Video-0.9.8-13B-distilled",
+    offload_state_dict=False,
+    torch_dtype=torch.bfloat16,
+    cache_dir=os.getenv("HF_HOME_CACHE"),
+    token=os.getenv("HF_TOKEN"),
 )

-
-
-
-
-if True:
-    if True:
-        with safe_open(weight_path, framework="pt") as f:
-            metadata = f.metadata() or {}
-            config_str = metadata.get("config", "{}")
-            configs = json.loads(config_str)
-            allowed_inference_steps = configs.get("allowed_inference_steps")
-
-# 2. Load the individual components (all on the CPU)
-# `.from_pretrained(ckpt_path)` is smart enough to load the correct weights from the .safetensors file.
-logging.info("Loading VAE...")
-#vae = AutoencoderKLLTXVideo.from_pretrained(weight_path).to("cpu")
-
-logging.info("Loading Transformer...")
-#transformer = Transformer3DModel.from_pretrained(weight_path).to("cpu")
-
-logging.info("Loading Scheduler...")
-#scheduler = RectifiedFlowScheduler.from_pretrained(weight_path)
-
-logging.info("Loading Text Encoder and Tokenizer...")
-#text_encoder_path = "PixArt-alpha/PixArt-XL-2-1024-MS"  # self.config["text_encoder_model_name_or_path"]
-#text_encoder = T5EncoderModel.from_pretrained(text_encoder_path, subfolder="text_encoder").to("cpu")
-#tokenizer = T5Tokenizer.from_pretrained(text_encoder_path, subfolder="tokenizer")
-
-#patchifier = SymmetricPatchifier(patch_size=1)
-
-# 3. Set the model precision (still on the CPU; applied on the GPU later)
-
-# 4. Assemble the pipeline object from the loaded components
-
-pipeline = LTXConditionPipeline.from_pretrained(weight_path, cache_dir=os.getenv("HF_HOME_CACHE"), torch_dtype=dtype)
-
-
-# Load the state_dict and apply it to the transformer already created via model_index:
-state = safe_load(weight_path)
-pipeline.transformer.load_state_dict(state, strict=True)
-
-
 pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
     "Lightricks/ltxv-spatial-upscaler-0.9.7",
     cache_dir=os.getenv("HF_HOME_CACHE"),
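
For reference, a minimal self-contained sketch of the loading path the added lines above switch to. The `from diffusers import LTXConditionPipeline` import, the upsampler's `torch_dtype` argument, and the final `.to(device)` calls are assumptions (the hunk is truncated and does not show them); the repo IDs, environment variables, and remaining arguments are taken from the diff.

import os
import torch
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline  # import path is an assumption

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the distilled 13B LTX-Video checkpoint directly from the Hub,
# replacing the manual hf_hub_download + safetensors state-dict path.
pipeline = LTXConditionPipeline.from_pretrained(
    "Lightricks/LTX-Video-0.9.8-13B-distilled",
    offload_state_dict=False,
    torch_dtype=torch.bfloat16,
    cache_dir=os.getenv("HF_HOME_CACHE"),
    token=os.getenv("HF_TOKEN"),
)

# Spatial latent upsampler; arguments past cache_dir are not visible in the hunk,
# so torch_dtype here is an assumption.
pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
    "Lightricks/ltxv-spatial-upscaler-0.9.7",
    cache_dir=os.getenv("HF_HOME_CACHE"),
    torch_dtype=dtype,
)

pipeline.to(device)       # assumed device placement
pipe_upsample.to(device)  # assumed device placement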