Update pipeline.py
pipeline.py CHANGED (+62 -6)
@@ -12,7 +12,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import deprecate, logging
+from diffusers.utils import deprecate, is_accelerate_available, logging
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -40,7 +40,7 @@ re_attention = re.compile(

 def parse_prompt_attention(text):
     """
-    Parses a string with attention tokens and returns a list of pairs: text and its
+    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
     Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
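For readers skimming the diff: the parser's output shape is easiest to see by example. A quick sketch of the documented behaviour (return values inferred from the weighting rules quoted above, not output captured from this commit):

    parse_prompt_attention("a (very beautiful) masterpiece")
    # -> [["a ", 1.0], ["very beautiful", 1.1], [" masterpiece", 1.0]]

    parse_prompt_attention("(important:3.12) detail")
    # -> [["important", 3.12], [" detail", 1.0]]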
@@ -237,9 +237,9 @@ def get_weighted_text_embeddings(
     r"""
     Prompts can be assigned with local weights using brackets. For example,
     prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
-    and the embedding tokens corresponding to the words get
+    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.

-    Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the
+    Also, to regularize the embedding, the weighted embedding would be scaled to preserve the original mean.

     Args:
         pipe (`DiffusionPipeline`):
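The "preserve the original mean" sentence is the key regularization step: after per-token weights are applied, the embedding is rescaled so its mean matches the unweighted embedding's. A minimal PyTorch sketch of that idea (simplified; the function in this file works batch-wise and also handles prompts longer than the tokenizer limit):

    import torch

    def apply_weights_preserving_mean(embeddings: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        # embeddings: (batch, tokens, dim); weights: (batch, tokens)
        previous_mean = embeddings.float().mean()
        weighted = embeddings * weights.unsqueeze(-1)
        current_mean = weighted.float().mean()
        # rescale so the weighted embedding keeps the original mean
        return weighted * (previous_mean / current_mean)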
@@ -431,6 +431,19 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)

+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
         if safety_checker is None:
             logger.warn(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
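This new shim mirrors what diffusers' own pipelines do for legacy DDIM configs: warn via `deprecate`, then force `clip_sample=False` in a copied, frozen config. Checkpoint maintainers can avoid the warning by shipping a scheduler config with the flag already off; a sketch using the usual SD v1 scheduler settings (values illustrative, not taken from this commit):

    from diffusers import DDIMScheduler

    scheduler = DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,      # the setting this shim enforces
        set_alpha_to_one=False,
    )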
@@ -451,6 +464,24 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             feature_extractor=feature_extractor,
         )

+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(False)
+
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
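Usage note: assuming this file is loaded as a Hub custom pipeline (the model id and `custom_pipeline` name below are placeholders, not taken from this commit), the new toggle is a single call; the xformers package must be installed for `set_use_memory_efficient_attention_xformers` to succeed:

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",         # placeholder model id
        custom_pipeline="lpw_stable_diffusion",  # placeholder; point at this repo's pipeline
        torch_dtype=torch.float16,
    ).to("cuda")
    pipe.enable_xformers_memory_efficient_attention()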
@@ -478,6 +509,23 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)

+    def enable_sequential_cpu_offload(self):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = self.device
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
     @torch.no_grad()
     def __call__(
         self,
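With `pipe` constructed as in the previous note, the new offload hook trades speed for memory: weights stay on CPU (via accelerate's `cpu_offload`) and each submodule is moved to the GPU only for its own forward pass. A hypothetical sketch:

    pipe.enable_sequential_cpu_offload()  # requires `pip install accelerate`
    image = pipe("a (very beautiful) masterpiece").images[0]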
@@ -498,6 +546,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
         callback_steps: Optional[int] = 1,
         **kwargs,
     ):
@@ -560,11 +609,15 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. The function will be
                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.

         Returns:
+            `None` if cancelled by `is_cancelled_callback`,
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
             When returning a tuple, the first element is a list with the generated images, and the second element is a
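The cancellation contract documented here (the call returns `None` when `is_cancelled_callback` fires) makes the hook easy to wire to a UI stop flag, e.g. a `threading.Event` toggled from another thread. A hypothetical sketch, reusing `pipe` from above:

    import threading

    cancel_event = threading.Event()  # call cancel_event.set() elsewhere to stop

    result = pipe(
        "a (very beautiful) masterpiece",
        callback=lambda step, timestep, latents: print("step", step),
        is_cancelled_callback=cancel_event.is_set,
        callback_steps=5,
    )
    if result is None:
        print("generation cancelled before completion")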
@@ -757,8 +810,11 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
                     latents = (init_latents_proper * mask) + (latents * (1 - mask))

                 # call the callback, if provided
-                if callback is not None and i % callback_steps == 0:
-                    callback(i, t, latents)
+                if i % callback_steps == 0:
+                    if callback is not None:
+                        callback(i, t, latents)
+                    if is_cancelled_callback is not None and is_cancelled_callback():
+                        return None

         latents = 1 / 0.18215 * latents
         image = self.vae.decode(latents).sample
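One detail worth calling out in the unchanged context lines: the `1 / 0.18215` factor is Stable Diffusion v1's latent scaling constant, not something introduced by this commit; encoding and decoding are symmetric around it:

    # encode: latents = 0.18215 * vae.encode(image).latent_dist.sample()
    # decode: image   = vae.decode(latents / 0.18215).sample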