Aduc-sdr-2_5 / download_models.py
x2XcarleX2x's picture
Update download_models.py
14db299 verified
raw
history blame
4.62 kB
# download_models.py (v4.0 - Versão Definitiva Completa)
import os
import yaml
import logging
from huggingface_hub import snapshot_download
# Configuração do log para ser claro e informativo
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(name)s] - %(message)s')
logger = logging.getLogger("MODEL_LOGISTICS")
def download_repo_snapshot(repo_id, local_dir, desc, allow_patterns=None):
"""
Baixa um snapshot de um repositório, verificando se ele já existe para evitar
downloads repetidos. É a forma mais robusta de baixar modelos.
"""
os.makedirs(local_dir, exist_ok=True)
# Um bom indicador de que o download foi concluído é a presença de um arquivo de configuração.
# Isso evita downloads parciais em caso de reinicialização.
completion_marker = os.path.join(local_dir, '.download_completed')
if os.path.exists(completion_marker):
logger.info(f"Modelos para '{desc}' parecem já existir e estão completos em: {local_dir}")
return
logger.info(f"Baixando modelos para '{desc}' de '{repo_id}' para '{local_dir}'...")
try:
snapshot_download(
repo_id=repo_id,
local_dir=local_dir,
local_dir_use_symlinks=False,
resume_download=True,
allow_patterns=allow_patterns,
ignore_patterns=["*.md", "*.txt", "*.gitattributes", "*onnx*", "*fp32*"], # Ignora arquivos desnecessários
)
# Cria o marcador de conclusão
with open(completion_marker, 'w') as f:
f.write('done')
logger.info(f"Download para '{desc}' concluído.")
except Exception as e:
logger.error(f"Falha CRÍTICA ao baixar o snapshot '{desc}'. Erro: {e}")
raise
def main():
"""
Ponto de entrada para baixar todos os modelos de IA necessários, lendo as
configurações do arquivo config.yaml.
"""
logger.info("--- Iniciando verificação e download de todos os modelos ---")
try:
with open("config.yaml", 'r', encoding='utf-8') as f:
# Passamos o 'f' (o stream do arquivo) para a função safe_load
config = yaml.safe_load(f).get('specialists', {})
if not config:
logger.warning("Seção 'specialists' não encontrada no config.yaml. Nenhum modelo será baixado.")
return
except FileNotFoundError:
logger.error("Arquivo config.yaml não encontrado! Não é possível determinar quais modelos baixar.")
raise
except Exception as e:
logger.error(f"Erro ao ler ou parsear o config.yaml: {e}")
raise
# --- 1. Modelos para LTX-Video ---
if config.get('ltx', {}).get('gpus_required', 0) > 0:
ltx_models_dir = "/app/LTX-Video/models_downloaded"
download_repo_snapshot("Lightricks/LTX-Video", ltx_models_dir, "LTX Models", allow_patterns=["*.safetensors", "*.json"])
# --- 2. Modelos para Wan2.2 e LoRA Lightning (Logística "Mix-and-Match") ---
if config.get('wan', {}).get('gpus_required', 0) > 0:
wan_config = config['wan']
main_model_path = f"/app/models/{wan_config['model_id']}"
opt_transformer_path = f"/app/models/{wan_config['optimized_transformer_id']}"
lora_dir = "/app/models/loras"
# Baixa os repositórios completos para garantir a estrutura correta dos arquivos
download_repo_snapshot(wan_config['model_id'], main_model_path, "Wan2.2 Base Components")
download_repo_snapshot(wan_config['optimized_transformer_id'], opt_transformer_path, "Wan2.2 Optimized Transformers")
lora_filename_only = os.path.basename(wan_config['lora_filename'])
download_repo_snapshot(wan_config['lora_repo'], lora_dir, "Wan2.2 LoRA Lightning", allow_patterns=f"*{lora_filename_only}")
# --- 3. Modelos para SeedVR ---
if config.get('seedvr', {}).get('gpus_required', 0) > 0:
seedvr_models_dir = "/app/ckpts"
download_repo_snapshot("batuhanince/seedvr_3b_fp16", seedvr_models_dir, "SeedVR Models (FP16)", allow_patterns=["*.safetensors", "*.pt"])
download_repo_snapshot("ByteDance-Seed/SeedVR2-3B", seedvr_models_dir, "SeedVR Embeddings", allow_patterns=["*.pt"])
# --- 4. Modelos para MMAudio ---
if config.get('mmaudio', {}).get('gpus_required', 0) > 0:
mmaudio_models_dir = "/app/MMAudio/ckpts"
download_repo_snapshot("hkchengrex/MMAudio-checkpoints", mmaudio_models_dir, "MMAudio Checkpoints")
logger.info("--- Verificação de modelos concluída com sucesso ---")
if __name__ == "__main__":
main()