Aduc-sdr-2_5 / vincie.py
carlex3321's picture
Upload vincie.py
6e9236e verified
raw
history blame
13 kB
#!/usr/bin/env python3
"""
VincieService
- Ensures the upstream VINCIE repository is present.
- Fetches the minimal checkpoint files (dit.pth, vae.pth) via hf_hub_download into /app/ckpt/VINCIE-3B.
- Creates a compatibility symlink /app/VINCIE/ckpt/VINCIE-3B -> /app/ckpt/VINCIE-3B for repo-relative paths.
- Runs the official VINCIE main.py with Hydra/YACS overrides for both multi-turn and multi-concept generation.
- Optionally injects a minimal 'apex.normalization' shim when NVIDIA Apex is not available (to avoid import errors).
Upstream reference: https://github.com/ByteDance-Seed/VINCIE
Developed by [email protected]
https://github.com/carlex22
Version 1.0.0
"""
import os
import sys
import json
import subprocess
from pathlib import Path
from typing import List, Optional
from huggingface_hub import hf_hub_download
class VincieService:
    """
    High-level service for preparing VINCIE runtime assets and invoking generation.

    Responsibilities:
    - Repository management: clone the official VINCIE repository when missing.
    - Checkpoint management: download dit.pth and vae.pth from the VINCIE-3B checkpoint on the Hub.
    - Path compatibility: ensure <repo_dir>/ckpt/VINCIE-3B points to the checkpoint directory.
    - Runners: execute main.py with generate.yaml overrides for multi-turn edits and
      multi-concept composition.
    - Apex shim: provide a minimal fallback for apex.normalization if Apex isn't installed.

    Defaults assume the Docker/container layout used by the Space:
    - Repository directory: /app/VINCIE
    - Checkpoint directory: /app/ckpt/VINCIE-3B
    - Output root: /app/outputs
    """

    def __init__(
        self,
        repo_dir: str = "/app/VINCIE",
        ckpt_dir: str = "/app/ckpt/VINCIE-3B",
        python_bin: str = "python",
        repo_id: str = "ByteDance-Seed/VINCIE-3B",
        output_root: str = "/app/outputs",
    ):
        """
        Initialize the service with paths and runtime settings.

        Args:
            repo_dir: Filesystem location of the upstream VINCIE repository clone.
            ckpt_dir: Filesystem location where dit.pth and vae.pth are stored.
            python_bin: Python executable to invoke for main.py (e.g., 'python' or a full path).
            repo_id: Hugging Face Hub repo id for the VINCIE-3B checkpoint.
            output_root: Directory under which per-run output directories are created.

        Side-effects:
            - Ensures the output root directory exists.
            - Ensures the repo ckpt/ directory exists (for symlink placement). Note that this
              also creates repo_dir itself, which is why ensure_repo() does not rely on the
              mere existence of repo_dir to decide whether a clone is needed.
        """
        self.repo_dir = Path(repo_dir)
        self.ckpt_dir = Path(ckpt_dir)
        self.python = python_bin
        self.repo_id = repo_id

        # Canonical config and paths within the upstream repo
        self.generate_yaml = self.repo_dir / "configs" / "generate.yaml"
        self.assets_dir = self.repo_dir / "assets"

        # Output root for generated media
        self.output_root = Path(output_root)
        self.output_root.mkdir(parents=True, exist_ok=True)

        # Ensure ckpt/ exists in the repo (the compatibility symlink is placed inside it)
        (self.repo_dir / "ckpt").mkdir(parents=True, exist_ok=True)

    # ---------- Setup ----------
    def _has_repo_checkout(self) -> bool:
        """Return True when repo_dir holds git metadata or the main.py entry point."""
        return (self.repo_dir / ".git").exists() or (self.repo_dir / "main.py").exists()

    def ensure_repo(self, git_url: str = "https://github.com/ByteDance-Seed/VINCIE") -> None:
        """
        Clone the official VINCIE repository when no checkout is present.

        repo_dir always exists after __init__ (it pre-creates repo_dir/ckpt), so a checkout is
        detected by looking for '.git' or 'main.py' rather than for the directory itself.
        Because git refuses to clone into a non-empty directory, the clone is made in a
        temporary sibling directory and then merged into repo_dir.

        Args:
            git_url: Source URL of the official VINCIE repo.

        Raises:
            subprocess.CalledProcessError: on git clone failure.
            OSError: if git cannot be executed or the files cannot be copied.
        """
        if self._has_repo_checkout():
            return

        # Local imports: only this method needs them.
        import shutil
        import tempfile

        parent = self.repo_dir.parent
        parent.mkdir(parents=True, exist_ok=True)
        # Stage next to the final location so the temporary clone lives on the same volume.
        staging_root = Path(tempfile.mkdtemp(prefix=".vincie_clone_", dir=str(parent)))
        try:
            staging = staging_root / "repo"
            subprocess.run(["git", "clone", git_url, str(staging)], check=True)
            # Merge into repo_dir, keeping whatever is already there (e.g. ckpt/).
            # dirs_exist_ok requires Python 3.8+.
            shutil.copytree(str(staging), str(self.repo_dir), symlinks=True, dirs_exist_ok=True)
        finally:
            shutil.rmtree(str(staging_root), ignore_errors=True)

    @staticmethod
    def _needs_download(path: Path) -> bool:
        """Return True when *path* is missing or not larger than 1 MB (treated as incomplete)."""
        try:
            return path.stat().st_size <= 1_000_000
        except (FileNotFoundError, NotADirectoryError):
            return True

    def _link_checkpoint_dir(self) -> None:
        """
        Best-effort creation of the symlink <repo_dir>/ckpt/VINCIE-3B -> ckpt_dir.

        A symlink that already resolves to ckpt_dir is left untouched; a stale symlink or a
        regular file is replaced; a real directory is never removed. Failures are reported
        with a warning instead of being raised.
        """
        link = self.repo_dir / "ckpt" / "VINCIE-3B"
        try:
            link.parent.mkdir(parents=True, exist_ok=True)
            wanted = os.path.realpath(str(self.ckpt_dir))
            if link.is_symlink():
                if os.path.realpath(str(link)) == wanted:
                    return  # already correct; avoid a window where the link is missing
                link.unlink()
            elif link.is_dir():
                # A real directory sits at the link path; never delete it.
                if os.path.realpath(str(link)) != wanted:
                    print(
                        "Warning: failed to create checkpoint symlink:",
                        f"{link} is an existing directory",
                    )
                return
            elif link.exists():
                link.unlink()
            link.symlink_to(self.ckpt_dir, target_is_directory=True)
        except (OSError, NotImplementedError) as e:
            print("Warning: failed to create checkpoint symlink:", e)

    def ensure_model(self, hf_token: Optional[str] = None) -> None:
        """
        Download the minimal VINCIE-3B checkpoint files if missing and create a repo-compatible symlink.

        Files fetched from the Hub (repo_id):
        - dit.pth
        - vae.pth

        The files are placed under self.ckpt_dir (default /app/ckpt/VINCIE-3B) and a symlink
        <repo_dir>/ckpt/VINCIE-3B -> ckpt_dir is created to match upstream relative paths.

        Args:
            hf_token: Optional Hugging Face token; defaults to env HF_TOKEN or HUGGINGFACE_TOKEN.

        Notes:
            - Uses hf_hub_download with local_dir, so files are placed directly in the target directory.
            - A basic size check (> 1MB) is used to decide whether to refetch a file.
            - Symlink problems only produce a warning; download errors propagate.
        """
        self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        token = hf_token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")

        for fname in ("dit.pth", "vae.pth"):
            if not self._needs_download(self.ckpt_dir / fname):
                continue
            print(f"Downloading {fname} from {self.repo_id} ...")
            hf_hub_download(
                repo_id=self.repo_id,
                filename=fname,
                local_dir=str(self.ckpt_dir),
                local_dir_use_symlinks=False,
                token=token,
                force_download=False,
                local_files_only=False,
            )

        # Compatibility symlink for repo-relative ckpt paths
        self._link_checkpoint_dir()

    def ensure_apex(self, enable_shim: bool = True, shim_root: str = "/app/shims") -> None:
        """
        Ensure apex.normalization importability.

        If importing 'apex.normalization' fails and enable_shim=True, write a minimal shim
        package implementing:
        - FusedRMSNorm via torch.nn.RMSNorm
        - FusedLayerNorm via torch.nn.LayerNorm

        This prevents import-time failures in code that references apex.normalization while
        sacrificing any Apex-specific kernel benefits.

        Args:
            enable_shim: Whether to install a local shim when 'apex.normalization' is missing.
            shim_root: Directory that receives the generated 'apex' package and is put on the
                import path of this process and (via PYTHONPATH) of child processes.

        Notes:
            NOTE(review): the shim classes wrap the torch module as ``self.mod``, so their
            parameters are registered under a ``mod.`` prefix. Presumably this differs from
            Apex's own parameter layout; verify if checkpoints are loaded through these classes.
        """
        import importlib

        try:
            importlib.import_module("apex.normalization")
            return
        except Exception:
            # Deliberately broad: any failure to import is treated as "Apex unavailable".
            if not enable_shim:
                return

        shim_dir = Path(shim_root)
        apex_pkg = shim_dir / "apex"
        apex_pkg.mkdir(parents=True, exist_ok=True)
        (apex_pkg / "__init__.py").write_text("from .normalization import *\n")
        (apex_pkg / "normalization.py").write_text(
            "import torch\n"
            "import torch.nn as nn\n"
            "\n"
            "class FusedRMSNorm(nn.Module):\n"
            "    def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True):\n"
            "        super().__init__()\n"
            "        self.mod = nn.RMSNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)\n"
            "    def forward(self, x):\n"
            "        return self.mod(x)\n"
            "\n"
            "class FusedLayerNorm(nn.Module):\n"
            "    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):\n"
            "        super().__init__()\n"
            "        self.mod = nn.LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)\n"
            "    def forward(self, x):\n"
            "        return self.mod(x)\n"
        )

        # Make shim importable in this process and child processes, without duplicating
        # entries on repeated calls and without leaving an empty PYTHONPATH component.
        shim_path = str(shim_dir)
        if shim_path not in sys.path:
            sys.path.insert(0, shim_path)
        current = os.environ.get("PYTHONPATH", "")
        inherited = [p for p in current.split(os.pathsep) if p != shim_path] if current else []
        os.environ["PYTHONPATH"] = os.pathsep.join([shim_path, *inherited])
        # The package was just written to disk; drop stale finder caches.
        importlib.invalidate_caches()

    def ready(self) -> bool:
        """
        Quick readiness probe for UI:
        - The repository and generate.yaml exist.
        - Minimal checkpoint files (dit.pth, vae.pth) exist.

        Only existence is checked here; file sizes are not inspected.

        Returns:
            True if the environment is ready to run generation tasks; otherwise False.
        """
        have_repo = self.repo_dir.exists() and self.generate_yaml.exists()
        have_ckpt = all((self.ckpt_dir / name).exists() for name in ("dit.pth", "vae.pth"))
        return bool(have_repo and have_ckpt)

    # ---------- Core runner ----------
    def _run_vincie(self, overrides: List[str], work_output: Path) -> None:
        """
        Invoke VINCIE's main.py with key=value overrides inside the upstream repo directory.

        The child process runs with cwd=repo_dir and a copy of the current environment, so a
        PYTHONPATH set by ensure_apex() is inherited.

        Args:
            overrides: A list of CLI overrides (e.g., generation.positive_prompt.*).
            work_output: Output directory path for generated assets (created if missing).

        Raises:
            subprocess.CalledProcessError: if the underlying process fails.
        """
        work_output.mkdir(parents=True, exist_ok=True)
        cmd = [
            self.python,
            "main.py",
            str(self.generate_yaml),
            *overrides,
            f"generation.output.dir={str(work_output)}",
        ]
        subprocess.run(cmd, cwd=str(self.repo_dir), check=True, env=os.environ.copy())

    # ---------- Multi-turn editing ----------
    def multi_turn_edit(
        self,
        input_image: str,
        turns: List[str],
        out_dir_name: Optional[str] = None,
    ) -> Path:
        """
        Run the official 'multi-turn' generation equivalent.

        This wraps generate.yaml using overrides:
        - generation.positive_prompt.image_path = [ "<input-image-path>" ]
        - generation.positive_prompt.prompts = [ "<turn1>", "<turn2>", ... ]
        - ckpt.path = <ckpt_dir>

        Args:
            input_image: Path to the single input image on disk. It is passed through as given;
                the child process runs with cwd=repo_dir.
            turns: A list of editing instructions, in the order they should be applied.
            out_dir_name: Optional name for the output subdirectory; when omitted it is
                derived from the input image name.

        Returns:
            Path to the output directory containing images and, if produced, a video.
        """
        out_dir = self.output_root / (out_dir_name or f"multi_turn_{self._slug(input_image)}")
        image_json = json.dumps([str(input_image)])
        prompts_json = json.dumps(list(turns))
        overrides = [
            f"generation.positive_prompt.image_path={image_json}",
            f"generation.positive_prompt.prompts={prompts_json}",
            f"ckpt.path={str(self.ckpt_dir)}",
        ]
        self._run_vincie(overrides, out_dir)
        return out_dir

    # ---------- Multi-concept composition ----------
    def multi_concept_compose(
        self,
        concept_images: List[str],
        concept_prompts: List[str],
        final_prompt: str,
        out_dir_name: Optional[str] = None,
    ) -> Path:
        """
        Run the 'multi-concept' composition pipeline.

        The service forms:
        - generation.positive_prompt.image_path = [ <concept-img-1>, ..., <concept-img-N> ]
        - generation.positive_prompt.prompts = [ <desc-1>, ..., <desc-N>, <final-prompt> ]
        - generation.pad_img_placehoder = False (preserves input shapes)
        - ckpt.path = /app/ckpt/VINCIE-3B (by default)

        Args:
            concept_images: Paths to concept images on disk.
            concept_prompts: Per-image descriptions in the same order as concept_images.
            final_prompt: Composition prompt appended after all per-image descriptions.
            out_dir_name: Optional name for the output subdirectory; defaults to 'multi_concept'.

        Returns:
            Path to the output directory containing images and, if produced, a video.
        """
        out_dir = self.output_root / (out_dir_name or "multi_concept")
        imgs_json = json.dumps([str(p) for p in concept_images])
        # Unpacking accepts any sequence of prompts, not only a list.
        prompts_json = json.dumps([*concept_prompts, final_prompt])
        overrides = [
            f"generation.positive_prompt.image_path={imgs_json}",
            f"generation.positive_prompt.prompts={prompts_json}",
            # Key spelling is kept exactly as this service has always sent it.
            # NOTE(review): presumably matches the upstream config key; confirm.
            "generation.pad_img_placehoder=False",
            f"ckpt.path={str(self.ckpt_dir)}",
        ]
        self._run_vincie(overrides, out_dir)
        return out_dir

    # ---------- Helpers ----------
    @staticmethod
    def _slug(path_or_text: str) -> str:
        """
        Produce a filesystem-friendly short name (max 64 chars) from a path or text.

        When the input names an existing filesystem entry, its stem (file name without the
        last suffix) is used; otherwise the whole input string is used.

        Args:
            path_or_text: An input path or arbitrary string.

        Returns:
            The first 64 characters of the input with every character that is neither
            alphanumeric (per str.isalnum, so non-ASCII letters and digits are kept) nor one
            of '-', '_', '.' replaced by an underscore.
        """
        text = str(path_or_text)
        try:
            candidate = Path(text)
            base = candidate.stem if candidate.exists() else text
        except (OSError, ValueError):
            # Arbitrary text may not be a probe-able path (e.g. a component that is too long).
            base = text
        return "".join(c if c.isalnum() or c in "-_." else "_" for c in base)[:64]