|
|
|
|
|
""" |
|
|
SeedVR Server (CLI torchrun) |
|
|
|
|
|
- Garante repositório SeedVR e checkpoints baixados via snapshot_download. |
|
|
- Cria symlink SeedVR/ckpts/SeedVR2-3B -> CKPTS_ROOT. |
|
|
- Executa projects/inference_seedvr2_3b.py com torchrun e NUM_GPUS. |
|
|
- API: run_inference(file_path, *, seed, res_h, res_w, sp_size, extra_args) -> (video_out, image_out, out_dir).
|
|
""" |
|
|
|
|
|
import os |
|
|
import shutil |
|
|
import subprocess |
|
|
from pathlib import Path |
|
|
from typing import Optional, Tuple, List |
|
|
import time |
|
|
import mimetypes |
|
|
|
|
|
from huggingface_hub import snapshot_download |
|
|
|
|
|
class SeedVRServer:
    """Manage a local SeedVR checkout and run SeedVR2-3B inference through ``torchrun``.

    Construction performs I/O:
    - it creates the working directories;
    - it clones the SeedVR repository if ``SEEDVR_ROOT/.git`` is missing;
    - it downloads the model snapshot with ``snapshot_download``;
    - it links ``SEEDVR_ROOT/ckpts/SeedVR2-3B`` to ``CKPTS_ROOT``.
    """

    def __init__(
        self,
        *,
        seedvr_root: Optional[str] = None,
        ckpts_root: Optional[str] = None,
        output_root: Optional[str] = None,
        input_root: Optional[str] = None,
        repo_url: Optional[str] = None,
        repo_id: Optional[str] = None,
        num_gpus: Optional[int] = None,
    ):
        """Resolve configuration, create directories and prepare repo/model/symlink.

        Each argument overrides an environment variable:
        - ``SEEDVR_ROOT``
        - ``CKPTS_ROOT``
        - ``OUTPUT_ROOT``
        - ``INPUT_ROOT``
        - ``SEEDVR_GIT_URL``
        - ``SEEDVR_REPO_ID``
        - ``NUM_GPUS``

        A built-in default is used when neither is set. Falsy arguments
        (e.g. ``""`` or ``num_gpus=0``) also fall back to the environment or
        the default, because the selection uses ``or``.

        ``HF_HOME`` is read from the environment. The Hugging Face token is
        read from ``HF_TOKEN`` or, failing that, ``HUGGINGFACE_TOKEN``.
        """
        self.SEEDVR_ROOT = Path(seedvr_root or os.getenv("SEEDVR_ROOT", "/app/SeedVR"))
        self.CKPTS_ROOT = Path(ckpts_root or os.getenv("CKPTS_ROOT", "/app/ckpts/SeedVR2-3B"))
        self.OUTPUT_ROOT = Path(output_root or os.getenv("OUTPUT_ROOT", "/app/outputs"))
        self.INPUT_ROOT = Path(input_root or os.getenv("INPUT_ROOT", "/app/inputs"))
        self.REPO_URL = repo_url or os.getenv("SEEDVR_GIT_URL", "https://github.com/ByteDance-Seed/SeedVR.git")
        self.REPO_ID = repo_id or os.getenv("SEEDVR_REPO_ID", "ByteDance-Seed/SeedVR2-3B")
        self.NUM_GPUS = int(num_gpus or os.getenv("NUM_GPUS", "8"))
        self.HF_HOME = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
        self.HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN") or None

        for p in [self.SEEDVR_ROOT, self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME]:
            p.mkdir(parents=True, exist_ok=True)

        self._ensure_repo()
        self._ensure_model()
        self._ensure_ckpt_symlink()

    def _ensure_repo(self) -> None:
        """Clone ``REPO_URL`` into ``SEEDVR_ROOT`` unless a ``.git`` entry is already there.

        Raises ``subprocess.CalledProcessError`` if ``git clone`` fails.
        """
        if not (self.SEEDVR_ROOT / ".git").exists():
            print(f"[seed_server] cloning repo into {self.SEEDVR_ROOT}")
            subprocess.run(["git", "clone", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
        else:
            print(f"[seed_server] repo present at {self.SEEDVR_ROOT}")

    def _ensure_model(self) -> None:
        """Download the files of ``REPO_ID`` that match the allow-list into ``CKPTS_ROOT``.

        Files are real copies rather than cache symlinks, and ``HF_HOME`` is
        used as the cache directory.
        """
        print(f"[seed_server] downloading model {self.REPO_ID} into {self.CKPTS_ROOT} (snapshot_download)")
        self.CKPTS_ROOT.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=self.REPO_ID,
            cache_dir=str(self.HF_HOME),
            local_dir=str(self.CKPTS_ROOT),
            # NOTE(review): both flags below are deprecated in recent
            # huggingface_hub releases; kept for compatibility with older ones.
            local_dir_use_symlinks=False,
            resume_download=True,
            allow_patterns=["*.json", "*.safetensors", "*.pth", "*.bin", "*.py", "*.md", "*.txt"],
            token=self.HF_TOKEN,
        )
        print("[seed_server] model ready")

    def _ensure_ckpt_symlink(self) -> None:
        """Make ``SEEDVR_ROOT/ckpts/SeedVR2-3B`` a symlink to ``CKPTS_ROOT`` (best effort).

        Behavior by case:
        - A symlink that resolves elsewhere, or cannot be resolved, is replaced.
        - A non-symlink entry already at that path is left untouched.
        - Filesystem errors while handling the link are printed as a warning
          rather than raised.
        """
        ckpts_repo_dir = self.SEEDVR_ROOT / "ckpts"
        ckpts_repo_dir.mkdir(parents=True, exist_ok=True)
        link = ckpts_repo_dir / "SeedVR2-3B"
        try:
            # Compare fully resolved paths: CKPTS_ROOT may be relative or pass
            # through symlinks, in which case a raw comparison never matches.
            target = self.CKPTS_ROOT.resolve()
            if link.is_symlink():
                try:
                    stale = link.resolve() != target
                except (OSError, RuntimeError):
                    # Unresolvable link (e.g. a symlink loop): replace it.
                    stale = True
                if stale:
                    link.unlink(missing_ok=True)
            # A correct-but-dangling symlink is still a symlink; don't try to
            # recreate it (that would raise FileExistsError).
            if not link.is_symlink() and not link.exists():
                # Use the absolute target: a relative one would be interpreted
                # relative to the link's own directory.
                link.symlink_to(target, target_is_directory=True)
            print(f"[seed_server] symlink ok: {link} -> {target}")
        except (OSError, RuntimeError) as e:
            print("[seed_server] warn: ckpt symlink failed:", e)

    @staticmethod
    def _is_video(path: str) -> bool:
        """Return True if *path* guesses as a ``video/*`` MIME type or ends in ``.mp4``."""
        mime, _ = mimetypes.guess_type(path)
        return (mime or "").startswith("video") or str(path).lower().endswith(".mp4")

    @staticmethod
    def _is_image(path: str) -> bool:
        """Return True if *path* guesses as ``image/*`` or has a png/jpg/jpeg/webp suffix."""
        mime, _ = mimetypes.guess_type(path)
        if mime and mime.startswith("image"):
            return True
        return str(path).lower().endswith((".png", ".jpg", ".jpeg", ".webp"))

    def _prepare_job(self, input_file: str) -> Tuple[Path, Path]:
        """Create fresh job input/output directories and copy *input_file* into the input one.

        Returns ``(job_dir, out_dir)``. ``job_dir`` lives under ``INPUT_ROOT``
        and holds the copied input. ``out_dir`` lives under ``OUTPUT_ROOT``.
        """
        import uuid

        # The timestamp keeps directories roughly sortable. The random suffix
        # stops two jobs started in the same second from sharing (and mixing)
        # directories.
        tag = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
        job_dir = self.INPUT_ROOT / f"job_{tag}"
        out_dir = self.OUTPUT_ROOT / f"run_{tag}"
        job_dir.mkdir(parents=True, exist_ok=False)
        out_dir.mkdir(parents=True, exist_ok=False)
        shutil.copy2(input_file, job_dir / Path(input_file).name)
        return job_dir, out_dir

    @staticmethod
    def _collect_outputs(out_dir: Path) -> Tuple[Optional[str], Optional[str]]:
        """Return ``(video_out, image_out)``: the most recently modified result of each kind.

        ``.mp4`` files are preferred; ``.mov``/``.avi`` are considered only when
        no ``.mp4`` exists. Either entry is None when nothing matches.
        """
        def newest(patterns: List[str]) -> Optional[str]:
            found = [p for pattern in patterns for p in out_dir.rglob(pattern)]
            if not found:
                return None
            return str(sorted(found, key=lambda p: p.stat().st_mtime)[-1])

        video_out = newest(["*.mp4"]) or newest(["*.mov", "*.avi"])
        image_out = newest(["*.png", "*.jpg", "*.jpeg", "*.webp"])
        return video_out, image_out

    def run_inference(
        self,
        file_path: str,
        *,
        seed: int = 42,
        res_h: int = 720,
        res_w: int = 1280,
        sp_size: int = 4,
        extra_args: Optional[List[str]] = None,
    ) -> Tuple[Optional[str], Optional[str], Path]:
        """Run ``projects/inference_seedvr2_3b.py`` under ``torchrun`` with ``NUM_GPUS`` processes.

        Args:
            file_path: input video (.mp4) or image (.png/.jpg/.jpeg/.webp).
            seed, res_h, res_w, sp_size: forwarded to the script as CLI flags.
            extra_args: extra CLI arguments appended verbatim.

        Returns:
            ``(video_out, image_out, out_dir)``: the newest video and image found
            in the job's output directory (each may be None), plus that
            directory. If ``torchrun`` exits non-zero, returns
            ``(None, None, out_dir)``.

        Raises:
            FileNotFoundError: if the input file or the inference script is missing.
        """
        if not Path(file_path).exists():
            raise FileNotFoundError(f"input not found: {file_path}")

        script = self.SEEDVR_ROOT / "projects" / "inference_seedvr2_3b.py"
        if not script.exists():
            raise FileNotFoundError(f"inference script not found: {script}")

        job_dir, out_dir = self._prepare_job(file_path)
        self._ensure_ckpt_symlink()

        # Octal mode (rwx for everyone). The previous decimal 777 equals
        # 0o1411, which removed the owner's write permission.
        os.chmod(out_dir, 0o777)
        os.chmod(job_dir, 0o777)

        cmd = [
            "torchrun",
            f"--nproc-per-node={self.NUM_GPUS}",
            str(script),
            "--video_path", str(job_dir),
            "--output_dir", str(out_dir),
            "--seed", str(seed),
            "--res_h", str(res_h),
            "--res_w", str(res_w),
            "--sp_size", str(sp_size),
        ]
        if extra_args:
            cmd.extend(extra_args)

        # env is a copy of os.environ, so setdefault already keeps any value
        # the caller exported and only fills in missing ones.
        env = os.environ.copy()
        env.setdefault("HF_HOME", str(self.HF_HOME))
        env.setdefault("NCCL_P2P_LEVEL", "NVL")
        env.setdefault("OMP_NUM_THREADS", "8")

        print("[seed_server] running:", " ".join(cmd))
        try:
            subprocess.run(cmd, cwd=str(self.SEEDVR_ROOT), check=True, env=env)
        except subprocess.CalledProcessError as e:
            print("[seed_server] torchrun error:", e)
            return None, None, out_dir

        video_out, image_out = self._collect_outputs(out_dir)
        return video_out, image_out, out_dir
|
|
|