# download_models.py (v4.0 - Versão Definitiva Completa) import os import yaml import logging from huggingface_hub import snapshot_download # Configuração do log para ser claro e informativo logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(name)s] - %(message)s') logger = logging.getLogger("MODEL_LOGISTICS") def download_repo_snapshot(repo_id, local_dir, desc, allow_patterns=None): """ Baixa um snapshot de um repositório, verificando se ele já existe para evitar downloads repetidos. É a forma mais robusta de baixar modelos. """ os.makedirs(local_dir, exist_ok=True) # Um bom indicador de que o download foi concluído é a presença de um arquivo de configuração. # Isso evita downloads parciais em caso de reinicialização. completion_marker = os.path.join(local_dir, '.download_completed') if os.path.exists(completion_marker): logger.info(f"Modelos para '{desc}' parecem já existir e estão completos em: {local_dir}") return logger.info(f"Baixando modelos para '{desc}' de '{repo_id}' para '{local_dir}'...") try: snapshot_download( repo_id=repo_id, local_dir=local_dir, local_dir_use_symlinks=False, resume_download=True, allow_patterns=allow_patterns, ignore_patterns=["*.md", "*.txt", "*.gitattributes", "*onnx*", "*fp32*"], # Ignora arquivos desnecessários ) # Cria o marcador de conclusão with open(completion_marker, 'w') as f: f.write('done') logger.info(f"Download para '{desc}' concluído.") except Exception as e: logger.error(f"Falha CRÍTICA ao baixar o snapshot '{desc}'. Erro: {e}") raise def main(): """ Ponto de entrada para baixar todos os modelos de IA necessários, lendo as configurações do arquivo config.yaml. """ logger.info("--- Iniciando verificação e download de todos os modelos ---") try: with open("config.yaml", 'r', encoding='utf-8') as f: # Passamos o 'f' (o stream do arquivo) para a função safe_load config = yaml.safe_load(f).get('specialists', {}) if not config: logger.warning("Seção 'specialists' não encontrada no config.yaml. Nenhum modelo será baixado.") return except FileNotFoundError: logger.error("Arquivo config.yaml não encontrado! Não é possível determinar quais modelos baixar.") raise except Exception as e: logger.error(f"Erro ao ler ou parsear o config.yaml: {e}") raise # --- 1. Modelos para LTX-Video --- if config.get('ltx', {}).get('gpus_required', 0) > 0: ltx_models_dir = "/app/LTX-Video/models_downloaded" download_repo_snapshot("Lightricks/LTX-Video", ltx_models_dir, "LTX Models", allow_patterns=["*.safetensors", "*.json"]) # --- 2. Modelos para Wan2.2 e LoRA Lightning (Logística "Mix-and-Match") --- if config.get('wan', {}).get('gpus_required', 0) > 0: wan_config = config['wan'] main_model_path = f"/app/models/{wan_config['model_id']}" opt_transformer_path = f"/app/models/{wan_config['optimized_transformer_id']}" lora_dir = "/app/models/loras" # Baixa os repositórios completos para garantir a estrutura correta dos arquivos download_repo_snapshot(wan_config['model_id'], main_model_path, "Wan2.2 Base Components") download_repo_snapshot(wan_config['optimized_transformer_id'], opt_transformer_path, "Wan2.2 Optimized Transformers") lora_filename_only = os.path.basename(wan_config['lora_filename']) download_repo_snapshot(wan_config['lora_repo'], lora_dir, "Wan2.2 LoRA Lightning", allow_patterns=f"*{lora_filename_only}") # --- 3. Modelos para SeedVR --- if config.get('seedvr', {}).get('gpus_required', 0) > 0: seedvr_models_dir = "/app/ckpts" download_repo_snapshot("batuhanince/seedvr_3b_fp16", seedvr_models_dir, "SeedVR Models (FP16)", allow_patterns=["*.safetensors", "*.pt"]) download_repo_snapshot("ByteDance-Seed/SeedVR2-3B", seedvr_models_dir, "SeedVR Embeddings", allow_patterns=["*.pt"]) # --- 4. Modelos para MMAudio --- if config.get('mmaudio', {}).get('gpus_required', 0) > 0: mmaudio_models_dir = "/app/MMAudio/ckpts" download_repo_snapshot("hkchengrex/MMAudio-checkpoints", mmaudio_models_dir, "MMAudio Checkpoints") logger.info("--- Verificação de modelos concluída com sucesso ---") if __name__ == "__main__": main()