Spaces:
Paused
Paused
| # download_models.py (v4.0 - Versão Definitiva Completa) | |
| import os | |
| import yaml | |
| import logging | |
| from huggingface_hub import snapshot_download | |
| # Configuração do log para ser claro e informativo | |
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(name)s] - %(message)s') | |
| logger = logging.getLogger("MODEL_LOGISTICS") | |
| def download_repo_snapshot(repo_id, local_dir, desc, allow_patterns=None): | |
| """ | |
| Baixa um snapshot de um repositório, verificando se ele já existe para evitar | |
| downloads repetidos. É a forma mais robusta de baixar modelos. | |
| """ | |
| os.makedirs(local_dir, exist_ok=True) | |
| # Um bom indicador de que o download foi concluído é a presença de um arquivo de configuração. | |
| # Isso evita downloads parciais em caso de reinicialização. | |
| completion_marker = os.path.join(local_dir, '.download_completed') | |
| if os.path.exists(completion_marker): | |
| logger.info(f"Modelos para '{desc}' parecem já existir e estão completos em: {local_dir}") | |
| return | |
| logger.info(f"Baixando modelos para '{desc}' de '{repo_id}' para '{local_dir}'...") | |
| try: | |
| snapshot_download( | |
| repo_id=repo_id, | |
| local_dir=local_dir, | |
| local_dir_use_symlinks=False, | |
| resume_download=True, | |
| allow_patterns=allow_patterns, | |
| ignore_patterns=["*.md", "*.txt", "*.gitattributes", "*onnx*", "*fp32*"], # Ignora arquivos desnecessários | |
| ) | |
| # Cria o marcador de conclusão | |
| with open(completion_marker, 'w') as f: | |
| f.write('done') | |
| logger.info(f"Download para '{desc}' concluído.") | |
| except Exception as e: | |
| logger.error(f"Falha CRÍTICA ao baixar o snapshot '{desc}'. Erro: {e}") | |
| raise | |
| def main(): | |
| """ | |
| Ponto de entrada para baixar todos os modelos de IA necessários, lendo as | |
| configurações do arquivo config.yaml. | |
| """ | |
| logger.info("--- Iniciando verificação e download de todos os modelos ---") | |
| try: | |
| with open("config.yaml", 'r', encoding='utf-8') as f: | |
| # Passamos o 'f' (o stream do arquivo) para a função safe_load | |
| config = yaml.safe_load(f).get('specialists', {}) | |
| if not config: | |
| logger.warning("Seção 'specialists' não encontrada no config.yaml. Nenhum modelo será baixado.") | |
| return | |
| except FileNotFoundError: | |
| logger.error("Arquivo config.yaml não encontrado! Não é possível determinar quais modelos baixar.") | |
| raise | |
| except Exception as e: | |
| logger.error(f"Erro ao ler ou parsear o config.yaml: {e}") | |
| raise | |
| # --- 1. Modelos para LTX-Video --- | |
| if config.get('ltx', {}).get('gpus_required', 0) > 0: | |
| ltx_models_dir = "/app/LTX-Video/models_downloaded" | |
| download_repo_snapshot("Lightricks/LTX-Video", ltx_models_dir, "LTX Models", allow_patterns=["*.safetensors", "*.json"]) | |
| # --- 2. Modelos para Wan2.2 e LoRA Lightning (Logística "Mix-and-Match") --- | |
| if config.get('wan', {}).get('gpus_required', 0) > 0: | |
| wan_config = config['wan'] | |
| main_model_path = f"/app/models/{wan_config['model_id']}" | |
| opt_transformer_path = f"/app/models/{wan_config['optimized_transformer_id']}" | |
| lora_dir = "/app/models/loras" | |
| # Baixa os repositórios completos para garantir a estrutura correta dos arquivos | |
| download_repo_snapshot(wan_config['model_id'], main_model_path, "Wan2.2 Base Components") | |
| download_repo_snapshot(wan_config['optimized_transformer_id'], opt_transformer_path, "Wan2.2 Optimized Transformers") | |
| lora_filename_only = os.path.basename(wan_config['lora_filename']) | |
| download_repo_snapshot(wan_config['lora_repo'], lora_dir, "Wan2.2 LoRA Lightning", allow_patterns=f"*{lora_filename_only}") | |
| # --- 3. Modelos para SeedVR --- | |
| if config.get('seedvr', {}).get('gpus_required', 0) > 0: | |
| seedvr_models_dir = "/app/ckpts" | |
| download_repo_snapshot("batuhanince/seedvr_3b_fp16", seedvr_models_dir, "SeedVR Models (FP16)", allow_patterns=["*.safetensors", "*.pt"]) | |
| download_repo_snapshot("ByteDance-Seed/SeedVR2-3B", seedvr_models_dir, "SeedVR Embeddings", allow_patterns=["*.pt"]) | |
| # --- 4. Modelos para MMAudio --- | |
| if config.get('mmaudio', {}).get('gpus_required', 0) > 0: | |
| mmaudio_models_dir = "/app/MMAudio/ckpts" | |
| download_repo_snapshot("hkchengrex/MMAudio-checkpoints", mmaudio_models_dir, "MMAudio Checkpoints") | |
| logger.info("--- Verificação de modelos concluída com sucesso ---") | |
| if __name__ == "__main__": | |
| main() |