Spaces:
Paused
Paused
| #!/usr/bin/env python3 | |
| import os, sys, json, subprocess | |
| from pathlib import Path | |
| from typing import List, Optional | |
| from time import time, sleep | |
| from huggingface_hub import hf_hub_download | |
class VincieService:
    """Thin wrapper that prepares and runs the VINCIE generation script.

    The service manages three on-disk locations: a checkout of the VINCIE
    source tree (``repo_dir``), the model checkpoint directory (``ckpt_dir``)
    and a root directory for generated outputs (``output_root``). Generation
    is performed by running ``main.py`` from the checkout in a subprocess.
    """

    # Checkpoint files fetched by ensure_model() and required by ready().
    _CKPT_FILES = ("dit.pth", "vae.pth")
    # A checkpoint at or below this size is treated as missing/truncated.
    _MIN_CKPT_BYTES = 1_000_000

    def __init__(
        self,
        repo_dir: str = "/app/VINCIE",
        ckpt_dir: str = "/app/ckpt/VINCIE-3B",
        python_bin: str = "python3",
        repo_id: str = "ByteDance-Seed/VINCIE-3B",
        output_root: str = "/app/outputs",
    ):
        """Record paths and create the output and ``<repo_dir>/ckpt`` directories.

        Args:
            repo_dir: Directory holding (or that will hold) the VINCIE checkout.
            ckpt_dir: Directory holding (or that will hold) the checkpoints.
            python_bin: Interpreter used to launch ``main.py``.
            repo_id: Hugging Face Hub repository the checkpoints come from.
            output_root: Directory under which per-job output folders are made.
        """
        self.repo_dir = Path(repo_dir)
        self.ckpt_dir = Path(ckpt_dir)
        self.python = python_bin
        self.repo_id = repo_id
        self.generate_yaml = self.repo_dir / "configs" / "generate.yaml"
        self.output_root = Path(output_root)
        self.output_root.mkdir(parents=True, exist_ok=True)
        # NOTE: this also creates repo_dir itself, so "repo_dir exists" does
        # not imply that the sources were cloned (see _repo_is_populated).
        (self.repo_dir / "ckpt").mkdir(parents=True, exist_ok=True)

    def _repo_is_populated(self) -> bool:
        """Return True if repo_dir holds anything besides the ``ckpt`` placeholder."""
        try:
            return any(entry.name != "ckpt" for entry in self.repo_dir.iterdir())
        except FileNotFoundError:
            return False
        except NotADirectoryError:
            # Something that is not a directory occupies the path; leave it alone.
            return True

    def ensure_repo(self, git_url: str = "https://github.com/ByteDance-Seed/VINCIE") -> None:
        """Clone the VINCIE sources into ``repo_dir`` unless already present.

        The constructor pre-creates ``<repo_dir>/ckpt``, so a plain existence
        check on ``repo_dir`` would always skip the clone, and ``git clone``
        refuses a non-empty destination. The clone therefore goes into a
        scratch directory and its top-level entries are moved into place.

        Args:
            git_url: Repository URL passed to ``git clone``.

        Raises:
            subprocess.CalledProcessError: If ``git clone`` fails.
        """
        import shutil
        import tempfile

        if self._repo_is_populated():
            return

        self.repo_dir.mkdir(parents=True, exist_ok=True)
        # Scratch dir sits next to repo_dir so the moves stay on one filesystem.
        with tempfile.TemporaryDirectory(dir=self.repo_dir.parent) as scratch:
            checkout = Path(scratch) / "checkout"
            subprocess.run(
                ["git", "clone", "--depth", "1", git_url, str(checkout)],
                check=True,
            )
            for entry in checkout.iterdir():
                target = self.repo_dir / entry.name
                # Never clobber what is already there (e.g. the ckpt placeholder).
                if target.exists() or target.is_symlink():
                    continue
                shutil.move(str(entry), str(target))

    def _checkpoint_present(self, path: Path) -> bool:
        """Return True if ``path`` exists and is larger than the size threshold."""
        try:
            return path.stat().st_size > self._MIN_CKPT_BYTES
        except FileNotFoundError:
            return False

    def _link_checkpoints(self) -> None:
        """Best-effort: point ``<repo_dir>/ckpt/VINCIE-3B`` at ``ckpt_dir``.

        An existing symlink or regular file at that path is replaced; an
        existing real directory is left untouched. Failures only print a
        warning.
        """
        link = self.repo_dir / "ckpt" / "VINCIE-3B"
        try:
            link.parent.mkdir(parents=True, exist_ok=True)
            if link.is_symlink() or link.is_file():
                link.unlink()
            if not link.exists():
                link.symlink_to(self.ckpt_dir, target_is_directory=True)
        except (OSError, NotImplementedError) as e:
            print("[vince] symlink warning:", e)

    def ensure_model(self, hf_token: Optional[str] = None) -> None:
        """Download missing checkpoint files and link them into the checkout.

        Args:
            hf_token: Hub access token. Falls back to the ``HF_TOKEN`` and then
                ``HUGGINGFACE_TOKEN`` environment variables when not given.
        """
        self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        token = hf_token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")

        for fname in self._CKPT_FILES:
            if self._checkpoint_present(self.ckpt_dir / fname):
                continue
            print(f"[vince] downloading {fname} from {self.repo_id} ...")
            hf_hub_download(
                repo_id=self.repo_id,
                filename=fname,
                local_dir=str(self.ckpt_dir),
                token=token,
                force_download=False,
                local_files_only=False,
            )

        self._link_checkpoints()

    def ready(self) -> bool:
        """Return True when the checkout's config and both checkpoints are usable.

        Checkpoints are validated with the same size threshold ensure_model()
        uses, so a truncated file is not reported as ready.
        """
        have_repo = self.repo_dir.exists() and self.generate_yaml.exists()
        have_ckpts = all(
            self._checkpoint_present(self.ckpt_dir / fname) for fname in self._CKPT_FILES
        )
        return bool(have_repo and have_ckpts)

    def _wait_until_outputs(self, out_dir: Path, timeout_s: int = 300) -> None:
        """Poll ``out_dir`` once per second until an image/video file appears.

        Returns as soon as any file with a known media extension is found
        anywhere under ``out_dir``; prints a warning if none shows up before
        ``timeout_s`` seconds have elapsed.
        """
        exts = (".png", ".jpg", ".jpeg", ".gif", ".mp4")
        deadline = time() + timeout_s
        while time() < deadline:
            if any(p.is_file() and p.suffix.lower() in exts for p in out_dir.rglob("*")):
                print(f"[vince] outputs detected in {out_dir}")
                return
            sleep(1)
        print(f"[vince] warning: no outputs detected in {out_dir} within {timeout_s}s")

    def _run_vincie(self, overrides: List[str], work_output: Path, wait_outputs: bool = True) -> None:
        """Run ``main.py`` from the checkout with the given config overrides.

        Args:
            overrides: ``key=value`` strings appended after the config path.
            work_output: Output directory; created if needed and passed to the
                script as ``generation.output.dir``.
            wait_outputs: When true, poll ``work_output`` for media files after
                the process exits. The timeout comes from the
                ``VINCIE_WAIT_OUTPUTS_SEC`` environment variable (default 300).

        Raises:
            subprocess.CalledProcessError: If the script exits non-zero.
        """
        work_output.mkdir(parents=True, exist_ok=True)
        cmd = [
            self.python,
            "main.py",
            str(self.generate_yaml),
            *overrides,
            f"generation.output.dir={str(work_output)}",
        ]
        print("[vince] CWD=", self.repo_dir)
        print("[vince] CMD=", " ".join(cmd))
        subprocess.run(cmd, cwd=self.repo_dir, check=True, env=os.environ.copy())
        if wait_outputs:
            self._wait_until_outputs(
                work_output,
                timeout_s=int(os.getenv("VINCIE_WAIT_OUTPUTS_SEC", "300")),
            )

    def _multi_turn_overrides(self, input_image: str, turns: List[str], options: dict) -> List[str]:
        """Build the override list for a multi-turn edit job."""
        # NOTE(review): scalar strings are wrapped in double quotes verbatim; a
        # value containing '"' may not survive the script's override parser.
        # The parser is not visible from here - confirm before changing.
        return [
            f'generation.positive_prompt.image_path="{str(input_image)}"',
            f"generation.positive_prompt.prompts={json.dumps(turns)}",
            f"generation.seed={int(options.get('seed', 1))}",
            f"diffusion.timesteps.sampling.steps={int(options.get('steps', 50))}",
            f"diffusion.cfg.scale={float(options.get('cfg_scale', 7.5))}",
            f'generation.negative_prompt="{options.get("negative_prompt","")}"',
            f"generation.resolution={int(options.get('resolution', 512))}",
            f"generation.batch_size={int(options.get('batch_size', 1))}",
        ]

    def _multi_concept_overrides(
        self, files: List[str], descs: List[str], final_prompt: str, options: dict
    ) -> List[str]:
        """Build the override list for a multi-concept composition job."""
        # NOTE(review): same quoting caveat as _multi_turn_overrides applies to
        # final_prompt.
        return [
            f"generation.concepts.files={json.dumps(files)}",
            f"generation.concepts.descs={json.dumps(descs)}",
            f'generation.final_prompt="{final_prompt}"',
            f"generation.seed={int(options.get('seed', 1))}",
            f"diffusion.timesteps.sampling.steps={int(options.get('steps', 50))}",
            f"diffusion.cfg.scale={float(options.get('cfg_scale', 7.5))}",
            f"generation.resolution={int(options.get('resolution', 512))}",
            f"generation.batch_size={int(options.get('batch_size', 1))}",
        ]

    def multi_turn_edit(self, input_image: str, turns: List[str], **kwargs) -> Path:
        """Apply a sequence of edit prompts to one image.

        Args:
            input_image: Path of the image to edit.
            turns: Edit prompts, one per turn.
            **kwargs: Optional ``seed`` (1), ``steps`` (50), ``cfg_scale`` (7.5),
                ``negative_prompt`` (""), ``resolution`` (512) and
                ``batch_size`` (1).

        Returns:
            The output directory, ``<output_root>/multi_turn_<image stem>``.
        """
        out_dir = self.output_root / f"multi_turn_{Path(input_image).stem}"
        overrides = self._multi_turn_overrides(input_image, turns, kwargs)
        self._run_vincie(overrides, out_dir, wait_outputs=True)
        return out_dir

    def multi_concept_compose(self, files: List[str], descs: List[str], final_prompt: str, **kwargs) -> Path:
        """Compose several concept images into one generation.

        Args:
            files: Paths of the concept images.
            descs: Descriptions passed alongside ``files``.
            final_prompt: Prompt describing the final composition.
            **kwargs: Optional ``seed`` (1), ``steps`` (50), ``cfg_scale`` (7.5),
                ``resolution`` (512) and ``batch_size`` (1).

        Returns:
            The output directory, ``<output_root>/multi_concept_<len(files)>``.
        """
        out_dir = self.output_root / f"multi_concept_{len(files)}"
        overrides = self._multi_concept_overrides(files, descs, final_prompt, kwargs)
        self._run_vincie(overrides, out_dir, wait_outputs=True)
        return out_dir