#!/usr/bin/env python3
"""
SeedVR Server (CLI torchrun)

- Ensures the SeedVR repository and checkpoints are downloaded via snapshot_download.
- Creates the symlink SeedVR/ckpts/SeedVR2-3B -> CKPTS_ROOT.
- Runs projects/inference_seedvr2_3b.py with torchrun and NUM_GPUS.
- API: run_inference(file_path, seed, res_h, res_w, sp_size) -> (video_out, image_out, out_dir).
"""

import os
import shutil
import subprocess
import time
import mimetypes
from pathlib import Path
from typing import Optional, Tuple, List

from huggingface_hub import snapshot_download  # required inside the container


class SeedVRServer:
    def __init__(
        self,
        *,
        seedvr_root: Optional[str] = None,
        ckpts_root: Optional[str] = None,
        output_root: Optional[str] = None,
        input_root: Optional[str] = None,
        repo_url: Optional[str] = None,
        repo_id: Optional[str] = None,
        num_gpus: Optional[int] = None,
    ):
        # Paths and environment variables
        self.SEEDVR_ROOT = Path(seedvr_root or os.getenv("SEEDVR_ROOT", "/app/SeedVR"))
        self.CKPTS_ROOT = Path(ckpts_root or os.getenv("CKPTS_ROOT", "/app/ckpts/SeedVR2-3B"))
        self.OUTPUT_ROOT = Path(output_root or os.getenv("OUTPUT_ROOT", "/app/outputs"))
        self.INPUT_ROOT = Path(input_root or os.getenv("INPUT_ROOT", "/app/inputs"))
        self.REPO_URL = repo_url or os.getenv("SEEDVR_GIT_URL", "https://github.com/ByteDance-Seed/SeedVR.git")
        self.REPO_ID = repo_id or os.getenv("SEEDVR_REPO_ID", "ByteDance-Seed/SeedVR2-3B")
        self.NUM_GPUS = int(num_gpus or os.getenv("NUM_GPUS", "8"))
        self.HF_HOME = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
        self.HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN") or None

        # Required directories
        for p in [self.SEEDVR_ROOT, self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME]:
            p.mkdir(parents=True, exist_ok=True)

        # Bootstrap on construction
        self._ensure_repo()
        self._ensure_model()
        self._ensure_ckpt_symlink()

    # ---------- Setup ----------
    def _ensure_repo(self) -> None:
        if not (self.SEEDVR_ROOT / ".git").exists():
            print(f"[seed_server] cloning repo into {self.SEEDVR_ROOT}")
            subprocess.run(["git", "clone", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
        else:
            print(f"[seed_server] repo present at {self.SEEDVR_ROOT}")

    def _ensure_model(self) -> None:
        print(f"[seed_server] downloading model {self.REPO_ID} into {self.CKPTS_ROOT} (snapshot_download)")
        self.CKPTS_ROOT.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=self.REPO_ID,
            cache_dir=str(self.HF_HOME),
            local_dir=str(self.CKPTS_ROOT),
            local_dir_use_symlinks=False,
            resume_download=True,
            allow_patterns=["*.json", "*.safetensors", "*.pth", "*.bin", "*.py", "*.md", "*.txt"],
            token=self.HF_TOKEN,
        )
        print("[seed_server] model ready")

    def _ensure_ckpt_symlink(self) -> None:
        ckpts_repo_dir = self.SEEDVR_ROOT / "ckpts"
        ckpts_repo_dir.mkdir(parents=True, exist_ok=True)
        link = ckpts_repo_dir / "SeedVR2-3B"
        try:
            if link.is_symlink():
                try:
                    # Recreate the link if it points somewhere other than CKPTS_ROOT
                    if link.resolve() != self.CKPTS_ROOT.resolve():
                        link.unlink()
                except Exception:
                    link.unlink(missing_ok=True)
            if not link.exists():
                link.symlink_to(self.CKPTS_ROOT, target_is_directory=True)
            print(f"[seed_server] symlink ok: {link} -> {self.CKPTS_ROOT}")
        except Exception as e:
            print("[seed_server] warn: ckpt symlink failed:", e)

    # ---------- Utilities ----------
    @staticmethod
    def _is_video(path: str) -> bool:
        mime, _ = mimetypes.guess_type(path)
        return (mime or "").startswith("video") or str(path).lower().endswith(".mp4")

    @staticmethod
    def _is_image(path: str) -> bool:
        mime, _ = mimetypes.guess_type(path)
        if mime and mime.startswith("image"):
            return True
        return str(path).lower().endswith((".png", ".jpg", ".jpeg", ".webp"))

    def _prepare_job(self, input_file: str) -> Tuple[Path, Path]:
        # One job directory for the input copy, one run directory for the outputs
        ts = int(time.time())
        job_dir = self.INPUT_ROOT / f"job_{ts}"
        out_dir = self.OUTPUT_ROOT / f"run_{ts}"
        job_dir.mkdir(parents=True, exist_ok=True)
        out_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy2(input_file, job_dir / Path(input_file).name)
        return job_dir, out_dir

    # ---------- Execution ----------
    def run_inference(
        self,
        file_path: str,
        *,
        seed: int = 42,
        res_h: int = 720,
        res_w: int = 1280,
        sp_size: int = 4,
        extra_args: Optional[List[str]] = None,
    ) -> Tuple[Optional[str], Optional[str], Path]:
        """
        Runs inference via torchrun with NUM_GPUS:
        - file_path: .mp4 video or .png/.jpg/.jpeg/.webp image
        - Returns (video_out, image_out, out_dir). One of the first two will be non-null.
        """
        if not Path(file_path).exists():
            raise FileNotFoundError(f"input not found: {file_path}")

        script = self.SEEDVR_ROOT / "projects" / "inference_seedvr2_3b.py"
        if not script.exists():
            raise FileNotFoundError(f"inference script not found: {script}")

        job_dir, out_dir = self._prepare_job(file_path)
        self._ensure_ckpt_symlink()

        out_dir.mkdir(parents=True, exist_ok=True)
        os.chmod(out_dir, 0o777)
        job_dir.mkdir(parents=True, exist_ok=True)
        os.chmod(job_dir, 0o777)

        cmd = [
            "torchrun",
            f"--nproc-per-node={self.NUM_GPUS}",
            str(script),
            "--video_path", str(job_dir),
            "--output_dir", str(out_dir),
            "--seed", str(seed),
            "--res_h", str(res_h),
            "--res_w", str(res_w),
            "--sp_size", str(sp_size),
        ]
        if extra_args:
            cmd.extend(extra_args)

        env = os.environ.copy()
        env.setdefault("HF_HOME", str(self.HF_HOME))
        env.setdefault("NCCL_P2P_LEVEL", os.getenv("NCCL_P2P_LEVEL", "NVL"))
        # env.setdefault("NCCL_ASYNC_ERROR_HANDLING", os.getenv("NCCL_ASYNC_ERROR_HANDLING", "1"))
        env.setdefault("OMP_NUM_THREADS", os.getenv("OMP_NUM_THREADS", "8"))

        print("[seed_server] running:", " ".join(cmd))
        try:
            subprocess.run(cmd, cwd=str(self.SEEDVR_ROOT), check=True, env=env)
        except subprocess.CalledProcessError as e:
            print("[seed_server] torchrun error:", e)
            return None, None, out_dir

        # Collect artifacts
        videos = sorted(out_dir.rglob("*.mp4"), key=lambda p: p.stat().st_mtime)
        # Cover common formats in case upstream changes
        if not videos:
            videos = sorted([*out_dir.rglob("*.mov"), *out_dir.rglob("*.avi")], key=lambda p: p.stat().st_mtime)
        images = sorted(
            [*out_dir.rglob("*.png"), *out_dir.rglob("*.jpg"), *out_dir.rglob("*.jpeg"), *out_dir.rglob("*.webp")],
            key=lambda p: p.stat().st_mtime,
        )
        video_out = str(videos[-1]) if videos else None
        image_out = str(images[-1]) if images else None
        return video_out, image_out, out_dir
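

# Minimal usage sketch of the API described in the module docstring. This is an
# illustration, not part of the server itself; it assumes the script runs inside
# the container with GPUs available, and the input path below is hypothetical.
if __name__ == "__main__":
    # Constructing the server clones the repo, downloads the checkpoints,
    # and creates the ckpts symlink before any inference runs.
    server = SeedVRServer()
    video_out, image_out, out_dir = server.run_inference(
        "/app/inputs/sample.mp4",  # hypothetical example input
        seed=42,
        res_h=720,
        res_w=1280,
        sp_size=4,
    )
    print("video:", video_out)
    print("image:", image_out)
    print("outputs in:", out_dir)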