carlex3321 commited on
Commit
6e9236e
·
verified ·
1 Parent(s): 1a79202

Upload vincie.py

Browse files
Files changed (1) hide show
  1. vincie.py +330 -0
vincie.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ VincieService
4
+ - Ensures the upstream VINCIE repository is present.
5
+ - Fetches the minimal checkpoint files (dit.pth, vae.pth) via hf_hub_download into /app/ckpt/VINCIE-3B.
6
+ - Creates a compatibility symlink /app/VINCIE/ckpt/VINCIE-3B -> /app/ckpt/VINCIE-3B for repo-relative paths.
7
+ - Runs the official VINCIE main.py with Hydra/YACS overrides for both multi-turn and multi-concept generation.
8
+ - Optionally injects a minimal 'apex.normalization' shim when NVIDIA Apex is not available (to avoid import errors).
9
+ Upstream reference: https://github.com/ByteDance-Seed/VINCIE
10
+
11
+ Developed by [email protected]
12
+ https://github.com/carlex22
13
+
14
+ Version 1.0.0
15
+ """
16
+
17
+ import os
18
+ import sys
19
+ import json
20
+ import subprocess
21
+ from pathlib import Path
22
+ from typing import List, Optional
23
+
24
+ from huggingface_hub import hf_hub_download
25
+
26
+
27
+ class VincieService:
28
+ """
29
+ High-level service for preparing VINCIE runtime assets and invoking generation.
30
+
31
+ Responsibilities:
32
+ - Repository management: clone the official VINCIE repository when missing.
33
+ - Checkpoint management: download dit.pth and vae.pth from the VINCIE-3B checkpoint on the Hub.
34
+ - Path compatibility: ensure /app/VINCIE/ckpt/VINCIE-3B points to /app/ckpt/VINCIE-3B.
35
+ - Runners: execute main.py with generate.yaml overrides for multi-turn edits and multi-concept composition.
36
+ - Apex shim: provide a minimal fallback for apex.normalization if Apex isn’t installed.
37
+
38
+ Defaults assume the Docker/container layout used by the Space:
39
+ - Repository directory: /app/VINCIE
40
+ - Checkpoint directory: /app/ckpt/VINCIE-3B
41
+ - Output root: /app/outputs
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ repo_dir: str = "/app/VINCIE",
47
+ ckpt_dir: str = "/app/ckpt/VINCIE-3B",
48
+ python_bin: str = "python",
49
+ repo_id: str = "ByteDance-Seed/VINCIE-3B",
50
+ ):
51
+ """
52
+ Initialize the service with paths and runtime settings.
53
+
54
+ Args:
55
+ repo_dir: Filesystem location of the upstream VINCIE repository clone.
56
+ ckpt_dir: Filesystem location where dit.pth and vae.pth are stored.
57
+ python_bin: Python executable to invoke for main.py (e.g., 'python' or a full path).
58
+ repo_id: Hugging Face Hub repo id for the VINCIE-3B checkpoint.
59
+
60
+ Side-effects:
61
+ - Ensures the output root directory exists.
62
+ - Ensures the repo ckpt/ directory exists (for symlink placement).
63
+ """
64
+ self.repo_dir = Path(repo_dir)
65
+ self.ckpt_dir = Path(ckpt_dir)
66
+ self.python = python_bin
67
+ self.repo_id = repo_id
68
+
69
+ # Canonical config and paths within the upstream repo
70
+ self.generate_yaml = self.repo_dir / "configs" / "generate.yaml"
71
+ self.assets_dir = self.repo_dir / "assets"
72
+
73
+ # Output root for generated media
74
+ self.output_root = Path("/app/outputs")
75
+ self.output_root.mkdir(parents=True, exist_ok=True)
76
+
77
+ # Ensure ckpt/ exists in the repo (symlink target lives here)
78
+ (self.repo_dir / "ckpt").mkdir(parents=True, exist_ok=True)
79
+
80
+ # ---------- Setup ----------
81
+
82
+ def ensure_repo(self, git_url: str = "https://github.com/ByteDance-Seed/VINCIE") -> None:
83
+ """
84
+ Clone the official VINCIE repository when missing.
85
+
86
+ Args:
87
+ git_url: Source URL of the official VINCIE repo.
88
+
89
+ Raises:
90
+ subprocess.CalledProcessError on git clone failure.
91
+ """
92
+ if not self.repo_dir.exists():
93
+ subprocess.run(["git", "clone", git_url, str(self.repo_dir)], check=True)
94
+
95
+ def ensure_model(self, hf_token: Optional[str] = None) -> None:
96
+ """
97
+ Download the minimal VINCIE-3B checkpoint files if missing and create a repo-compatible symlink.
98
+
99
+ Files fetched from the Hub (repo_id):
100
+ - dit.pth
101
+ - vae.pth
102
+
103
+ The files are placed under self.ckpt_dir (default /app/ckpt/VINCIE-3B) and a symlink
104
+ /app/VINCIE/ckpt/VINCIE-3B -> /app/ckpt/VINCIE-3B is created to match upstream relative paths.
105
+
106
+ Args:
107
+ hf_token: Optional Hugging Face token; defaults to env HF_TOKEN or HUGGINGFACE_TOKEN.
108
+
109
+ Notes:
110
+ - Uses hf_hub_download with local_dir, so files are placed directly in the target directory.
111
+ - A basic size check (> 1MB) is used to decide whether to refetch a file.
112
+ """
113
+ self.ckpt_dir.mkdir(parents=True, exist_ok=True)
114
+ token = hf_token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
115
+
116
+ def _need(p: Path) -> bool:
117
+ try:
118
+ return not (p.exists() and p.stat().st_size > 1_000_000)
119
+ except FileNotFoundError:
120
+ return True
121
+
122
+ for fname in ["dit.pth", "vae.pth"]:
123
+ dst = self.ckpt_dir / fname
124
+ if _need(dst):
125
+ print(f"Downloading {fname} from {self.repo_id} ...")
126
+ hf_hub_download(
127
+ repo_id=self.repo_id,
128
+ filename=fname,
129
+ local_dir=str(self.ckpt_dir),
130
+ local_dir_use_symlinks=False,
131
+ token=token,
132
+ force_download=False,
133
+ local_files_only=False,
134
+ )
135
+
136
+ # Compatibility symlink for repo-relative ckpt paths
137
+ link = self.repo_dir / "ckpt" / "VINCIE-3B"
138
+ try:
139
+ if link.is_symlink() or link.exists():
140
+ try:
141
+ link.unlink()
142
+ except IsADirectoryError:
143
+ # If a directory sits at that path, we leave it as-is or replace as needed
144
+ pass
145
+ if not link.exists():
146
+ link.symlink_to(self.ckpt_dir, target_is_directory=True)
147
+ except Exception as e:
148
+ print("Warning: failed to create checkpoint symlink:", e)
149
+
150
+ def ensure_apex(self, enable_shim: bool = True) -> None:
151
+ """
152
+ Ensure apex.normalization importability.
153
+
154
+ If NVIDIA Apex is not installed, and enable_shim=True, inject a minimal shim implementing:
155
+ - FusedRMSNorm via torch.nn.RMSNorm
156
+ - FusedLayerNorm via torch.nn.LayerNorm
157
+
158
+ This prevents import-time failures in code that references apex.normalization while
159
+ sacrificing any Apex-specific kernel benefits.
160
+
161
+ Args:
162
+ enable_shim: Whether to install a local shim when 'apex.normalization' is missing.
163
+ """
164
+ try:
165
+ import importlib
166
+ importlib.import_module("apex.normalization")
167
+ return
168
+ except Exception:
169
+ if not enable_shim:
170
+ return
171
+
172
+ shim_root = Path("/app/shims")
173
+ apex_pkg = shim_root / "apex"
174
+ apex_pkg.mkdir(parents=True, exist_ok=True)
175
+
176
+ (apex_pkg / "__init__.py").write_text("from .normalization import *\n")
177
+ (apex_pkg / "normalization.py").write_text(
178
+ "import torch\n"
179
+ "import torch.nn as nn\n"
180
+ "\n"
181
+ "class FusedRMSNorm(nn.Module):\n"
182
+ " def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True):\n"
183
+ " super().__init__()\n"
184
+ " self.mod = nn.RMSNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)\n"
185
+ " def forward(self, x):\n"
186
+ " return self.mod(x)\n"
187
+ "\n"
188
+ "class FusedLayerNorm(nn.Module):\n"
189
+ " def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):\n"
190
+ " super().__init__()\n"
191
+ " self.mod = nn.LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)\n"
192
+ " def forward(self, x):\n"
193
+ " return self.mod(x)\n"
194
+ )
195
+
196
+ # Make shim importable in this process and child processes
197
+ sys.path.insert(0, str(shim_root))
198
+ os.environ["PYTHONPATH"] = f"{str(shim_root)}:{os.environ.get('PYTHONPATH','')}"
199
+
200
+ def ready(self) -> bool:
201
+ """
202
+ Quick readiness probe for UI:
203
+ - The repository and generate.yaml exist.
204
+ - Minimal checkpoint files (dit.pth, vae.pth) exist.
205
+
206
+ Returns:
207
+ True if the environment is ready to run generation tasks; otherwise False.
208
+ """
209
+ have_repo = self.repo_dir.exists() and self.generate_yaml.exists()
210
+ dit_ok = (self.ckpt_dir / "dit.pth").exists()
211
+ vae_ok = (self.ckpt_dir / "vae.pth").exists()
212
+ return bool(have_repo and dit_ok and vae_ok)
213
+
214
+ # ---------- Core runner ----------
215
+
216
+ def _run_vincie(self, overrides: List[str], work_output: Path) -> None:
217
+ """
218
+ Invoke VINCIE's main.py with Hydra/YACS overrides inside the upstream repo directory.
219
+
220
+ Args:
221
+ overrides: A list of CLI overrides (e.g., generation.positive_prompt.*).
222
+ work_output: Output directory path for generated assets.
223
+
224
+ Raises:
225
+ subprocess.CalledProcessError if the underlying process fails.
226
+ """
227
+ work_output.mkdir(parents=True, exist_ok=True)
228
+ cmd = [
229
+ self.python,
230
+ "main.py",
231
+ str(self.generate_yaml),
232
+ *overrides,
233
+ f"generation.output.dir={str(work_output)}",
234
+ ]
235
+ env = os.environ.copy()
236
+ subprocess.run(cmd, cwd=self.repo_dir, check=True, env=env)
237
+
238
+ # ---------- Multi-turn editing ----------
239
+
240
+ def multi_turn_edit(
241
+ self,
242
+ input_image: str,
243
+ turns: List[str],
244
+ out_dir_name: Optional[str] = None,
245
+ ) -> Path:
246
+ """
247
+ Run the official 'multi-turn' generation equivalent.
248
+
249
+ This wraps generate.yaml using overrides:
250
+ - generation.positive_prompt.image_path = [ "<input-image-path>" ]
251
+ - generation.positive_prompt.prompts = [ "<turn1>", "<turn2>", ... ]
252
+
253
+ Args:
254
+ input_image: Path to the single input image on disk.
255
+ turns: A list of editing instructions, in the order they should be applied.
256
+ out_dir_name: Optional name for the output subdirectory; auto-generated if omitted.
257
+
258
+ Returns:
259
+ Path to the output directory containing images and, if produced, a video.
260
+ """
261
+ out_dir = self.output_root / (out_dir_name or f"multi_turn_{self._slug(input_image)}")
262
+ image_json = json.dumps([str(input_image)])
263
+ prompts_json = json.dumps(turns)
264
+
265
+ overrides = [
266
+ f"generation.positive_prompt.image_path={image_json}",
267
+ f"generation.positive_prompt.prompts={prompts_json}",
268
+ f"ckpt.path={str(self.ckpt_dir)}",
269
+ ]
270
+ self._run_vincie(overrides, out_dir)
271
+ return out_dir
272
+
273
+ # ---------- Multi-concept composition ----------
274
+
275
+ def multi_concept_compose(
276
+ self,
277
+ concept_images: List[str],
278
+ concept_prompts: List[str],
279
+ final_prompt: str,
280
+ out_dir_name: Optional[str] = None,
281
+ ) -> Path:
282
+ """
283
+ Run the 'multi-concept' composition pipeline.
284
+
285
+ The service forms:
286
+ - generation.positive_prompt.image_path = [ <concept-img-1>, ..., <concept-img-N> ]
287
+ - generation.positive_prompt.prompts = [ <desc-1>, ..., <desc-N>, <final-prompt> ]
288
+ - generation.pad_img_placehoder = False (preserves input shapes)
289
+ - ckpt.path = /app/ckpt/VINCIE-3B (by default)
290
+
291
+ Args:
292
+ concept_images: Paths to concept images on disk.
293
+ concept_prompts: Per-image descriptions in the same order as concept_images.
294
+ final_prompt: Composition prompt appended after all per-image descriptions.
295
+ out_dir_name: Optional name for the output subdirectory; defaults to 'multi_concept'.
296
+
297
+ Returns:
298
+ Path to the output directory containing images and, if produced, a video.
299
+ """
300
+ out_dir = self.output_root / (out_dir_name or "multi_concept")
301
+ imgs_json = json.dumps([str(p) for p in concept_images])
302
+ prompts_all = concept_prompts + [final_prompt]
303
+ prompts_json = json.dumps(prompts_all)
304
+
305
+ overrides = [
306
+ f"generation.positive_prompt.image_path={imgs_json}",
307
+ f"generation.positive_prompt.prompts={prompts_json}",
308
+ "generation.pad_img_placehoder=False",
309
+ f"ckpt.path={str(self.ckpt_dir)}",
310
+ ]
311
+ self._run_vincie(overrides, out_dir)
312
+ return out_dir
313
+
314
+ # ---------- Helpers ----------
315
+
316
+ @staticmethod
317
+ def _slug(path_or_text: str) -> str:
318
+ """
319
+ Produce a filesystem-friendly short name (max 64 chars) from a path or text.
320
+
321
+ Args:
322
+ path_or_text: An input path or arbitrary string.
323
+
324
+ Returns:
325
+ A sanitized string consisting of [A-Za-z0-9._-] with non-matching chars converted to underscores.
326
+ """
327
+ p = Path(path_or_text)
328
+ base = p.stem if p.exists() else str(path_or_text)
329
+ keep = "".join(c if c.isalnum() or c in "-_." else "_" for c in str(base))
330
+ return keep[:64]