# Test / seed_server.py
# Carlos s
# Upload seed_server.py
# 1300afb verified
# raw
# history blame
# 7.45 kB
#!/usr/bin/env python3
"""
SeedVR Server (CLI torchrun)
- Garante repositório SeedVR e checkpoints baixados via snapshot_download.
- Cria symlink SeedVR/ckpts/SeedVR2-3B -> CKPTS_ROOT.
- Executa projects/inference_seedvr2_3b.py com torchrun e NUM_GPUS.
- API: run_inference(file_path, seed, res_h, res_w, sp_size) -> (video_out, image_out, out_dir).
"""
import os
import shutil
import subprocess
from pathlib import Path
from typing import Optional, Tuple, List
import time
import mimetypes
from huggingface_hub import snapshot_download # requerido no container
class SeedVRServer:
    """Bootstrap SeedVR (repo + checkpoints) and run inference through ``torchrun``.

    Construction has side effects: it creates the working directories, clones
    the SeedVR repository if it is missing, downloads the model snapshot and
    links it into ``SEEDVR_ROOT/ckpts/SeedVR2-3B``.
    """

    def __init__(
        self,
        *,
        seedvr_root: Optional[str] = None,
        ckpts_root: Optional[str] = None,
        output_root: Optional[str] = None,
        input_root: Optional[str] = None,
        repo_url: Optional[str] = None,
        repo_id: Optional[str] = None,
        num_gpus: Optional[int] = None,
    ):
        """Resolve configuration, then prepare the repo, model and checkpoint link.

        Each argument falls back to an environment variable and then to a
        built-in default, e.g. ``seedvr_root`` -> ``$SEEDVR_ROOT`` ->
        ``/app/SeedVR``. ``HF_HOME`` and the Hugging Face token
        (``$HF_TOKEN`` or ``$HUGGINGFACE_TOKEN``) come from the environment only.

        Raises:
            subprocess.CalledProcessError: if ``git clone`` fails.
            Whatever ``snapshot_download`` raises (e.g. network/auth errors).
        """
        # Paths and settings: explicit argument > environment variable > default.
        self.SEEDVR_ROOT = Path(seedvr_root or os.getenv("SEEDVR_ROOT", "/app/SeedVR"))
        self.CKPTS_ROOT = Path(ckpts_root or os.getenv("CKPTS_ROOT", "/app/ckpts/SeedVR2-3B"))
        self.OUTPUT_ROOT = Path(output_root or os.getenv("OUTPUT_ROOT", "/app/outputs"))
        self.INPUT_ROOT = Path(input_root or os.getenv("INPUT_ROOT", "/app/inputs"))
        self.REPO_URL = repo_url or os.getenv("SEEDVR_GIT_URL", "https://github.com/ByteDance-Seed/SeedVR.git")
        self.REPO_ID = repo_id or os.getenv("SEEDVR_REPO_ID", "ByteDance-Seed/SeedVR2-3B")
        self.NUM_GPUS = int(num_gpus or os.getenv("NUM_GPUS", "8"))
        self.HF_HOME = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
        # Empty-string tokens are normalised to None (anonymous access).
        self.HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN") or None

        # Create every directory we rely on up front.
        for p in (self.SEEDVR_ROOT, self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME):
            p.mkdir(parents=True, exist_ok=True)

        # Bootstrap eagerly so the instance is ready to run inference.
        self._ensure_repo()
        self._ensure_model()
        self._ensure_ckpt_symlink()

    # ---------- Setup ----------

    def _ensure_repo(self) -> None:
        """Clone ``REPO_URL`` into ``SEEDVR_ROOT`` unless a ``.git`` dir is already there."""
        if not (self.SEEDVR_ROOT / ".git").exists():
            print(f"[seed_server] cloning repo into {self.SEEDVR_ROOT}")
            subprocess.run(["git", "clone", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
        else:
            print(f"[seed_server] repo present at {self.SEEDVR_ROOT}")

    def _ensure_model(self) -> None:
        """Download (or refresh) the ``REPO_ID`` snapshot into ``CKPTS_ROOT``.

        Only files matching ``allow_patterns`` are fetched.
        """
        print(f"[seed_server] downloading model {self.REPO_ID} into {self.CKPTS_ROOT} (snapshot_download)")
        self.CKPTS_ROOT.mkdir(parents=True, exist_ok=True)
        # NOTE(review): `local_dir_use_symlinks` and `resume_download` are
        # deprecated in recent huggingface_hub releases. They are kept for
        # compatibility with older versions; confirm against the pinned version.
        snapshot_download(
            repo_id=self.REPO_ID,
            cache_dir=str(self.HF_HOME),
            local_dir=str(self.CKPTS_ROOT),
            local_dir_use_symlinks=False,
            resume_download=True,
            allow_patterns=["*.json", "*.safetensors", "*.pth", "*.bin", "*.py", "*.md", "*.txt"],
            token=self.HF_TOKEN,
        )
        print("[seed_server] model ready")

    @staticmethod
    def _same_path(a: Path, b: Path) -> bool:
        """Return True if *a* and *b* resolve to the same location (False if unresolvable)."""
        try:
            return a.resolve() == b.resolve()
        except (OSError, RuntimeError):  # RuntimeError: symlink loop on Python < 3.13
            return False

    def _ensure_ckpt_symlink(self) -> None:
        """Make ``SEEDVR_ROOT/ckpts/SeedVR2-3B`` point at ``CKPTS_ROOT``.

        A symlink that points elsewhere (or cannot be resolved) is replaced.
        A real file/directory occupying the path is left untouched and a
        warning is printed. After the parent dir is created, failures are
        logged as warnings rather than raised (best effort).
        """
        ckpts_repo_dir = self.SEEDVR_ROOT / "ckpts"
        ckpts_repo_dir.mkdir(parents=True, exist_ok=True)
        link = ckpts_repo_dir / "SeedVR2-3B"
        try:
            if link.is_symlink() and not self._same_path(link, self.CKPTS_ROOT):
                # Stale or broken link: drop it so it is recreated below.
                link.unlink(missing_ok=True)
            if not link.exists() and not link.is_symlink():
                link.symlink_to(self.CKPTS_ROOT, target_is_directory=True)
            elif not self._same_path(link, self.CKPTS_ROOT):
                # Something that is not our link occupies the path; don't clobber it.
                print(f"[seed_server] warn: {link} exists and does not point to {self.CKPTS_ROOT}; leaving it untouched")
                return
            print(f"[seed_server] symlink ok: {link} -> {self.CKPTS_ROOT}")
        except Exception as e:
            print("[seed_server] warn: ckpt symlink failed:", e)

    # ---------- Utilities ----------

    @staticmethod
    def _is_video(path: str) -> bool:
        """True if *path* has a video MIME type or an ``.mp4`` extension."""
        mime, _ = mimetypes.guess_type(path)
        return (mime or "").startswith("video") or str(path).lower().endswith(".mp4")

    @staticmethod
    def _is_image(path: str) -> bool:
        """True if *path* has an image MIME type or a common image extension."""
        mime, _ = mimetypes.guess_type(path)
        if mime and mime.startswith("image"):
            return True
        return str(path).lower().endswith((".png", ".jpg", ".jpeg", ".webp"))

    def _prepare_job(self, input_file: str) -> Tuple[Path, Path]:
        """Create fresh input/output directories for one run and copy *input_file* in.

        Returns:
            ``(job_dir, out_dir)``. ``job_dir`` (under ``INPUT_ROOT``) contains
            only the copied input; ``out_dir`` (under ``OUTPUT_ROOT``) is empty.
            Both are chmod'ed to 0o777.
        """
        import uuid  # local import keeps this block self-contained

        # Timestamp for readability, plus a random suffix so two jobs started
        # in the same second never share (and process) each other's inputs.
        run_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
        job_dir = self.INPUT_ROOT / f"job_{run_id}"
        out_dir = self.OUTPUT_ROOT / f"run_{run_id}"
        for d in (job_dir, out_dir):
            d.mkdir(parents=True)
            # Must be octal: decimal 777 is 0o1411 (sticky bit, owner read-only).
            # World-writable presumably so the torchrun workers can write
            # regardless of user. TODO confirm this is still required.
            os.chmod(d, 0o777)
        shutil.copy2(input_file, job_dir / Path(input_file).name)
        return job_dir, out_dir

    @staticmethod
    def _newest_artifact(root: Path, patterns: Tuple[str, ...]) -> Optional[str]:
        """Return the most recently modified file under *root* matching any glob in *patterns*."""
        matches = [p for pattern in patterns for p in root.rglob(pattern) if p.is_file()]
        if not matches:
            return None
        return str(max(matches, key=lambda p: p.stat().st_mtime))

    # ---------- Execution ----------

    def run_inference(
        self,
        file_path: str,
        *,
        seed: int = 42,
        res_h: int = 720,
        res_w: int = 1280,
        sp_size: int = 4,
        extra_args: Optional[List[str]] = None,
    ) -> Tuple[Optional[str], Optional[str], Path]:
        """Run ``projects/inference_seedvr2_3b.py`` under ``torchrun`` with ``NUM_GPUS`` processes.

        Args:
            file_path: input video (.mp4) or image (.png/.jpg/.jpeg/.webp).
            seed, res_h, res_w, sp_size: forwarded as the script's CLI flags.
            extra_args: additional CLI arguments appended verbatim.

        Returns:
            ``(video_out, image_out, out_dir)``: the newest video and the newest
            image found in ``out_dir``. Each is None if no such file exists.
            Both are None if ``torchrun`` exits with a non-zero status.

        Raises:
            FileNotFoundError: if the input file or the inference script is missing.
        """
        if not Path(file_path).is_file():
            raise FileNotFoundError(f"input not found: {file_path}")
        script = self.SEEDVR_ROOT / "projects" / "inference_seedvr2_3b.py"
        if not script.exists():
            raise FileNotFoundError(f"inference script not found: {script}")

        job_dir, out_dir = self._prepare_job(file_path)
        # Re-check the link in case it was removed since construction.
        self._ensure_ckpt_symlink()

        cmd = [
            "torchrun",
            f"--nproc-per-node={self.NUM_GPUS}",
            str(script),
            "--video_path", str(job_dir),
            "--output_dir", str(out_dir),
            "--seed", str(seed),
            "--res_h", str(res_h),
            "--res_w", str(res_w),
            "--sp_size", str(sp_size),
        ]
        if extra_args:
            cmd.extend(extra_args)

        # Defaults only: values already present in the environment win.
        env = os.environ.copy()
        env.setdefault("HF_HOME", str(self.HF_HOME))
        env.setdefault("NCCL_P2P_LEVEL", "NVL")
        env.setdefault("OMP_NUM_THREADS", "8")

        print("[seed_server] running:", " ".join(cmd))
        try:
            subprocess.run(cmd, cwd=str(self.SEEDVR_ROOT), check=True, env=env)
        except subprocess.CalledProcessError as e:
            # Failure is reported to callers through the (None, None, out_dir) result.
            print("[seed_server] torchrun error:", e)
            return None, None, out_dir

        # Prefer .mp4; fall back to other common containers in case upstream changes.
        video_out = self._newest_artifact(out_dir, ("*.mp4",)) or self._newest_artifact(
            out_dir, ("*.mov", "*.avi")
        )
        image_out = self._newest_artifact(out_dir, ("*.png", "*.jpg", "*.jpeg", "*.webp"))
        return video_out, image_out, out_dir