Gorluxor commited on
Commit
5822221
·
1 Parent(s): 16eb15e

fixed cache path

Browse files
Files changed (5) hide show
  1. .gitignore +55 -0
  2. README.md +2 -0
  3. app.py +19 -13
  4. requirements.txt +2 -1
  5. stable_diffusion_xl_partedit.py +2 -5
.gitignore ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Python ###
2
+ # Byte-compiled / optimized / DLL files
3
+ __pycache__/
4
+ *.py[cod]
5
+ *$py.class
6
+ test/*
7
+ # C extensions
8
+ *.so
9
+ output
10
+ *labels.json
11
+ thirdparty
12
+ Inversion_SDXL
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # Installer logs
34
+ pip-log.txt
35
+ pip-delete-this-directory.txt
36
+
37
+
38
+ # Jupyter Notebook
39
+ .ipynb_checkpoints
40
+
41
+ # IPython
42
+ profile_default/
43
+ ipython_config.py
44
+
45
+ # Environments
46
+ .env
47
+ .venv
48
+ env/
49
+ venv/
50
+ ENV/
51
+ env.bak/
52
+ venv.bak/
53
+
54
+ # LSP config files
55
+ pyrightconfig.json
README.md CHANGED
@@ -31,6 +31,8 @@ tags:
31
  - research
32
  preload_from_hub:
33
  - stabilityai/stable-diffusion-xl-base-1.0
 
 
34
  ---
35
 
36
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
31
  - research
32
  preload_from_hub:
33
  - stabilityai/stable-diffusion-xl-base-1.0
34
+ - Aleksandar/PartEdit-Bench
35
+ - Aleksandar/PartEdit-extra
36
  ---
37
 
38
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -18,13 +18,23 @@ from io import BytesIO
18
  import tempfile
19
  import uuid
20
 
 
 
 
 
 
 
 
 
 
 
21
  MAX_SEED = np.iinfo(np.int32).max
22
  CACHE_EXAMPLES = os.environ.get("CACHE_EXAMPLES") == "1"
23
  AVAILABLE_TOKENS = list(PART_TOKENS.keys())
24
 
25
  # Download examples directly from the huggingface PartEdit-Bench
26
  # Login using e.g. `huggingface-cli login` or `hf login` if needed.
27
- bench = load_dataset("Aleksandar/PartEdit-Bench", revision="v1.1", split="synth")
28
 
29
  use_examples = None # all with None
30
  logo = "assets/partedit.png"
@@ -37,18 +47,14 @@ with open(logo, "rb") as f:
37
 
38
 
39
  def _save_image_for_download(edited: Union[PIL.Image.Image, np.ndarray, str, List]) -> str:
40
- """Save the first edited image to a temp file and return its filepath."""
41
- # clone to be sure we don't modify the input
42
- edited = edited.copy()
43
- img = edited[0] if isinstance(edited, list) else edited
44
- if isinstance(img, str):
45
- # path on disk already
46
- return img
47
- if isinstance(img, np.ndarray):
48
- img = PIL.Image.fromarray(img)
49
- assert isinstance(img, PIL.Image.Image), "Edited output must be PIL, ndarray, str path, or list of these."
50
  out_path = os.path.join(tempfile.gettempdir(), f"partedit_{uuid.uuid4().hex}.png")
51
- img.save(out_path)
52
  return out_path
53
 
54
 
@@ -288,4 +294,4 @@ if __name__ == "__main__":
288
  with gr.Tab(label="PartEdit", id="edit"):
289
  edit_demo(model)
290
 
291
- demo.queue(max_size=20).launch()
 
18
  import tempfile
19
  import uuid
20
 
21
+ import pathlib
22
+ HF_BASE = "/data/.huggingface" if os.getenv("SYSTEM") == "spaces" else "./.hf_cache" # /data/hf_cache
23
+ HF_BASE = str(pathlib.Path(HF_BASE).absolute())
24
+ os.environ.setdefault("HF_HOME", HF_BASE)
25
+ os.environ.setdefault("HF_HUB_CACHE", os.path.join(HF_BASE, "hub"))
26
+ os.environ.setdefault("HF_DATASETS_CACHE", os.path.join(HF_BASE, "datasets"))
27
+ os.environ.setdefault("TRANSFORMERS_CACHE", os.path.join(HF_BASE, "transformers"))
28
+ for k in ("HF_HUB_CACHE", "HF_DATASETS_CACHE", "TRANSFORMERS_CACHE"):
29
+ os.makedirs(os.environ[k], exist_ok=True)
30
+
31
  MAX_SEED = np.iinfo(np.int32).max
32
  CACHE_EXAMPLES = os.environ.get("CACHE_EXAMPLES") == "1"
33
  AVAILABLE_TOKENS = list(PART_TOKENS.keys())
34
 
35
  # Download examples directly from the huggingface PartEdit-Bench
36
  # Login using e.g. `huggingface-cli login` or `hf login` if needed.
37
+ bench = load_dataset("Aleksandar/PartEdit-Bench", revision="v1.1", split="synth", cache_dir=os.environ["HF_DATASETS_CACHE"])
38
 
39
  use_examples = None # all with None
40
  logo = "assets/partedit.png"
 
47
 
48
 
49
  def _save_image_for_download(edited: Union[PIL.Image.Image, np.ndarray, str, List]) -> str:
50
+ item = edited[0] if isinstance(edited, list) else edited # pick first
51
+ if isinstance(item, str):
52
+ return item # already a path
53
+ if isinstance(item, np.ndarray):
54
+ item = PIL.Image.fromarray(item)
55
+ assert isinstance(item, PIL.Image.Image), "Edited output must be PIL, ndarray, str path, or list of these."
 
 
 
 
56
  out_path = os.path.join(tempfile.gettempdir(), f"partedit_{uuid.uuid4().hex}.png")
57
+ item.save(out_path)
58
  return out_path
59
 
60
 
 
294
  with gr.Tab(label="PartEdit", id="edit"):
295
  edit_demo(model)
296
 
297
+ demo.queue(concurrency_count=1, max_size=20).launch(server_name="0.0.0.0")
requirements.txt CHANGED
@@ -4,6 +4,7 @@
4
 
5
  setuptools>=61.0
6
  numpy<1.24
 
7
  # ipywidgets
8
  # black[jupyter]
9
  # jupyterlab
@@ -16,7 +17,7 @@ tqdm
16
  # Core ML stack (PyTorch 2.1.0 + CUDA 11.8)
17
  torch==2.1.0
18
  torchvision==0.16.0
19
- # torchaudio==2.1.0
20
 
21
  # UI / HF stack
22
  gradio<5.0 # tested on 4.29; should work on 4.44.1 with pydantic fix
 
4
 
5
  setuptools>=61.0
6
  numpy<1.24
7
+ # Not needed for demo environment
8
  # ipywidgets
9
  # black[jupyter]
10
  # jupyterlab
 
17
  # Core ML stack (PyTorch 2.1.0 + CUDA 11.8)
18
  torch==2.1.0
19
  torchvision==0.16.0
20
+ # torchaudio==2.1.0
21
 
22
  # UI / HF stack
23
  gradio<5.0 # tested on 4.29; should work on 4.44.1 with pydantic fix
stable_diffusion_xl_partedit.py CHANGED
@@ -1132,9 +1132,6 @@ class DotDictExtra(dict):
1132
  if self.grounding.max() > 1.0:
1133
  self.grounding = self.grounding / self.grounding.max()
1134
 
1135
- # self.edit_mask = ToTensor()(self.edit_mask.convert('')).unsqueeze(0)
1136
- # # TODO(Alex): Fix this
1137
- # self.edit_mask = torch.load(self.edit_mask)
1138
  assert isinstance(self.th_strategy, Binarization), "th_strategy should be of type Binarization"
1139
  assert isinstance(self.pad_strategy, PaddingStrategy), "pad_strategy should be of type PaddingStrategy"
1140
 
@@ -1660,7 +1657,7 @@ class AttentionStore(AttentionControl):
1660
  class AttentionControlEdit(AttentionStore, abc.ABC):
1661
  def step_callback(self, x_t):
1662
  if self.local_blend is not None:
1663
- # x_t = self.local_blend(x_t, self.attention_store) # TODO:return back
1664
  x_t = self.local_blend(x_t, self)
1665
  return x_t
1666
 
@@ -1688,7 +1685,7 @@ class AttentionControlEdit(AttentionStore, abc.ABC):
1688
  if is_cross:
1689
  alpha_words = self.cross_replace_alpha[self.cur_step].to(attn_base.device)
1690
  attn_replace_new = self.replace_cross_attention(attn_base, attn_replace) * alpha_words + (1 - alpha_words) * attn_replace
1691
- # TODO(Alex): We want to control the actual
1692
 
1693
  attn[1:] = attn_replace_new
1694
  if self.has_maps() and self.extra_kwargs.get("force_cross_attn", False): # and self.cur_step <= 51:
 
1132
  if self.grounding.max() > 1.0:
1133
  self.grounding = self.grounding / self.grounding.max()
1134
 
 
 
 
1135
  assert isinstance(self.th_strategy, Binarization), "th_strategy should be of type Binarization"
1136
  assert isinstance(self.pad_strategy, PaddingStrategy), "pad_strategy should be of type PaddingStrategy"
1137
 
 
1657
  class AttentionControlEdit(AttentionStore, abc.ABC):
1658
  def step_callback(self, x_t):
1659
  if self.local_blend is not None:
1660
+ # x_t = self.local_blend(x_t, self.attention_store) # TODO: Check if there is more memory efficient way
1661
  x_t = self.local_blend(x_t, self)
1662
  return x_t
1663
 
 
1685
  if is_cross:
1686
  alpha_words = self.cross_replace_alpha[self.cur_step].to(attn_base.device)
1687
  attn_replace_new = self.replace_cross_attention(attn_base, attn_replace) * alpha_words + (1 - alpha_words) * attn_replace
1688
+
1689
 
1690
  attn[1:] = attn_replace_new
1691
  if self.has_maps() and self.extra_kwargs.get("force_cross_attn", False): # and self.cur_step <= 51: