import os
import re
from typing import List, Optional, Union

import PIL
from PIL import Image
from einops import rearrange
from torch import Tensor
import numpy as np
import torch
from safetensors.torch import load_file as load_sft
from diffusers.image_processor import VaeImageProcessor

from ..modules.layers import (
    SingleStreamBlockLoraProcessor,
    DoubleStreamBlockLoraProcessor,
)
from ..pipelines.sampling import (
    denoise,
    get_noise,
    get_schedule,
    prepare,
    prepare_image_cond,
    prepare_with_redux,
    unpack,
)
from ..utils.model_utils import (
    load_ae,
    load_clip,
    load_ic_custom,
    load_t5,
    load_redux,
    resolve_model_path,
)

PipelineImageInput = Union[
    PIL.Image.Image,
    np.ndarray,
    torch.Tensor,
    List[PIL.Image.Image],
    List[np.ndarray],
    List[torch.Tensor],
]


class ICCustomPipeline:
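    """Image-customization pipeline built around a FLUX.1 Fill style DiT.

    Wires together the CLIP and T5 text encoders, an optional Redux/SigLIP image
    encoder, the autoencoder, and a LoRA-augmented DiT, and exposes a single
    ``__call__`` that goes from a prompt plus reference/target images to PIL outputs.
    """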

    def __init__(
        self,
        clip_path: str = "clip-vit-large-patch14",
        t5_path: str = "t5-v1_1-xxl",
        siglip_path: str = "siglip-so400m-patch14-384",
        ae_path: str = "flux-fill-dev-ae",
        dit_path: str = "flux-fill-dev-dit",
        redux_path: str = "flux1-redux-dev",
        lora_path: str = "dit_lora_0x1561",
        img_txt_in_path: str = "dit_txt_img_in_0x1561",
        boundary_embeddings_path: str = "dit_boundary_embeddings_0x1561",
        task_register_embeddings_path: str = "dit_task_register_embeddings_0x1561",
        network_alpha: Optional[int] = None,
        double_blocks_idx: Optional[str] = None,
        single_blocks_idx: Optional[str] = None,
        device: torch.device = torch.device("cuda"),
        offload: bool = False,
        weight_dtype: torch.dtype = torch.bfloat16,
        show_progress: bool = False,
        use_flash_attention: bool = False,
    ):
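        # Heavy sub-models are kept on CPU when `offload` is enabled and moved to
        # the target device only while they are needed during sampling.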
        self.device = device
        self.offload = offload
        self.weight_dtype = weight_dtype

        self.clip = load_clip(clip_path, self.device if not offload else "cpu", dtype=self.weight_dtype).eval()
        self.t5 = load_t5(t5_path, self.device if not offload else "cpu", max_length=512, dtype=self.weight_dtype).eval()
        self.ae = load_ae(ae_path, device="cpu" if offload else self.device).eval()
        self.model = load_ic_custom(dit_path, device="cpu" if offload else self.device, dtype=self.weight_dtype).eval()
        self.image_encoder = load_redux(redux_path, siglip_path, device="cpu" if offload else self.device, dtype=self.weight_dtype).eval()

        self.vae_scale_factor = 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
        self.mask_processor = VaeImageProcessor(resample="nearest", do_normalize=False)

        self.set_lora(lora_path, network_alpha, double_blocks_idx, single_blocks_idx)
        self.set_img_txt_in(img_txt_in_path)
        self.set_boundary_embeddings(boundary_embeddings_path)
        self.set_task_register_embeddings(task_register_embeddings_path)

        self.show_progress = show_progress
        self.use_flash_attention = use_flash_attention

    def set_show_progress(self, show_progress: bool):
        self.show_progress = show_progress

    def set_use_flash_attention(self, use_flash_attention: bool):
        self.use_flash_attention = use_flash_attention

    def set_pipeline_offload(self, offload: bool):
        self.ae = self.ae.to("cpu" if offload else self.device)
        self.model = self.model.to("cpu" if offload else self.device)
        self.image_encoder = self.image_encoder.to("cpu" if offload else self.device)
        self.clip = self.clip.to("cpu" if offload else self.device)
        self.t5 = self.t5.to("cpu" if offload else self.device)
        self.offload = offload

    def set_pipeline_gradient_checkpointing(self, enable: bool):
        def _recursive_set_gradient_checkpointing(module):
            self.model._set_gradient_checkpointing(module, enable)
            for child in module.children():
                _recursive_set_gradient_checkpointing(child)

        _recursive_set_gradient_checkpointing(self.model)
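
    # The LoRA rank is read from the checkpoint itself: the first ".down.weight"
    # tensor is found and its leading dimension is taken as the rank.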
    def get_lora_rank(self, weights):
        for k in weights.keys():
            if k.endswith(".down.weight"):
                return weights[k].shape[0]

    def load_model_weights(self, weights: dict, strict: bool = False):
        model_state_dict = self.model.state_dict()
        update_dict = {k: v for k, v in weights.items() if k in model_state_dict}
        unexpected_keys = [k for k in weights if k not in model_state_dict]
        assert len(unexpected_keys) == 0, f"Some keys in the file are not found in the model: {unexpected_keys}"
        self.model.load_state_dict(update_dict, strict=strict)
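
    # Resolve the LoRA checkpoint (falling back to the default "dit_lora_0x1561"
    # entry when the given path does not exist on disk), then attach LoRA attention
    # processors to the DiT and load their weights.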
    def set_lora(
        self,
        lora_path: Optional[str] = None,
        network_alpha: Optional[int] = None,
        double_blocks_idx: Optional[str] = None,
        single_blocks_idx: Optional[str] = None,
    ):
        if not os.path.exists(lora_path):
            lora_path = "dit_lora_0x1561"
        lora_path = resolve_model_path(
            name=lora_path,
            repo_id_field="repo_id",
            filename_field="filename",
            ckpt_path_field="ckpt_path",
            hf_download=True,
        )
        weights = load_sft(lora_path)
        self.update_model_with_lora(weights, network_alpha, double_blocks_idx, single_blocks_idx)
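
    # By default LoRA processors are attached to all 38 single-stream blocks and to
    # none of the double-stream blocks; comma-separated index strings (e.g. "0,1,2")
    # override the selection. Unselected blocks keep their original attention processors.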
    def update_model_with_lora(
        self,
        weights,
        network_alpha,
        double_blocks_idx,
        single_blocks_idx,
    ):
        rank = self.get_lora_rank(weights)
        network_alpha = network_alpha if network_alpha is not None else rank

        lora_attn_procs = {}
        if double_blocks_idx is None:
            double_blocks_idx = []
        else:
            double_blocks_idx = [int(idx) for idx in double_blocks_idx.split(",")]
        if single_blocks_idx is None:
            single_blocks_idx = list(range(38))
        else:
            single_blocks_idx = [int(idx) for idx in single_blocks_idx.split(",")]

        for name, attn_processor in self.model.attn_processors.items():
            match = re.search(r'\.(\d+)\.', name)
            if match:
                layer_index = int(match.group(1))
            if name.startswith("double_blocks") and layer_index in double_blocks_idx:
                lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(
                    dim=3072, rank=rank, network_alpha=network_alpha
                )
            elif name.startswith("single_blocks") and layer_index in single_blocks_idx:
                lora_attn_procs[name] = SingleStreamBlockLoraProcessor(
                    dim=3072, rank=rank, network_alpha=network_alpha
                )
            else:
                lora_attn_procs[name] = attn_processor

        self.model.set_attn_processor(lora_attn_procs)
        self.load_model_weights(weights, strict=False)
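
    # The three helpers below follow the same pattern as `set_lora`: resolve a
    # safetensors checkpoint (with a default fallback), then merge its tensors into
    # the matching DiT parameters via a non-strict state-dict load.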
    def set_img_txt_in(self, img_txt_in_path: str):
        if not os.path.exists(img_txt_in_path):
            img_txt_in_path = "dit_txt_img_in_0x1561"
        img_txt_in_path = resolve_model_path(
            name=img_txt_in_path,
            repo_id_field="repo_id",
            filename_field="filename",
            ckpt_path_field="ckpt_path",
            hf_download=True,
        )
        weights = load_sft(img_txt_in_path)
        self.load_model_weights(weights, strict=False)

    def set_boundary_embeddings(self, boundary_embeddings_path: str):
        if not os.path.exists(boundary_embeddings_path):
            boundary_embeddings_path = "dit_boundary_embeddings_0x1561"
        boundary_embeddings_path = resolve_model_path(
            name=boundary_embeddings_path,
            repo_id_field="repo_id",
            filename_field="filename",
            ckpt_path_field="ckpt_path",
            hf_download=True,
        )
        weights = load_sft(boundary_embeddings_path)
        self.load_model_weights(weights, strict=False)

    def set_task_register_embeddings(self, task_register_embeddings_path: str):
        if not os.path.exists(task_register_embeddings_path):
            task_register_embeddings_path = "dit_task_register_embeddings_0x1561"
        task_register_embeddings_path = resolve_model_path(
            name=task_register_embeddings_path,
            repo_id_field="repo_id",
            filename_field="filename",
            ckpt_path_field="ckpt_path",
            hf_download=True,
        )
        weights = load_sft(task_register_embeddings_path)
        self.load_model_weights(weights, strict=False)

    def offload_model_to_cpu(self, *models):
        for model in models:
            if model is not None:
                model.to("cpu")
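
    # Inputs that are already tensors are passed through untouched; PIL / ndarray
    # inputs are preprocessed with the diffusers processors set up in `__init__`
    # (images normalized to [-1, 1], masks kept in [0, 1] with nearest resampling).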
    def prepare_image(
        self,
        image,
        device,
        dtype,
        width=None,
        height=None,
    ):
        if not isinstance(image, torch.Tensor):
            image = self.image_processor.preprocess(image, height=height, width=width)
        image = image.to(device=device, dtype=dtype)
        return image

    def prepare_mask(
        self,
        mask,
        device,
        dtype,
        width: Optional[int] = None,
        height: Optional[int] = None,
    ):
        if not isinstance(mask, torch.Tensor):
            mask = self.mask_processor.preprocess(mask, height=height, width=width)
        mask = mask.to(device=device, dtype=dtype)
        return mask

    def __call__(
        self,
        prompt: Union[str, List[str], None],
        width: int = 512,
        height: int = 512,
        guidance: float = 4,
        num_steps: int = 50,
        seed: int = 123456789,
        true_gs: float = 1,
        neg_prompt: Optional[Union[str, List[str]]] = None,
        timestep_to_start_cfg: int = 0,
        img_ref: Optional[PipelineImageInput] = None,
        img_target: Optional[PipelineImageInput] = None,
        mask_target: Optional[PipelineImageInput] = None,
        img_ip: Optional[PipelineImageInput] = None,
        cond_w_regions: Optional[Union[List[int], int]] = None,
        mask_type_ids: Optional[Union[Tensor, int]] = None,
        use_background_preservation: bool = False,
        use_progressive_background_preservation: bool = True,
        background_blend_threshold: float = 0.8,
        num_images_per_prompt: int = 1,
        gradio_progress=None,
    ):
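        """Generate customized images and return them as a list of PIL images.

        Width and height are snapped down to the nearest multiple of 16 so they
        stay compatible with the 8x VAE downsampling and 2x2 latent packing.
        """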
        width = 16 * (width // 16)
        height = 16 * (height // 16)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = 1

        img_ref = self.prepare_image(
            img_ref,
            self.device,
            self.weight_dtype,
        )
        img_target = self.prepare_image(
            img_target,
            self.device,
            self.weight_dtype,
        )
        mask_target = self.prepare_mask(
            mask_target,
            self.device,
            self.weight_dtype,
        )

        if num_images_per_prompt > 1:
            mask_type_ids = mask_type_ids.repeat_interleave(num_images_per_prompt, dim=0)

        return self.forward(
            batch_size,
            num_images_per_prompt,
            prompt,
            width,
            height,
            guidance,
            num_steps,
            seed,
            timestep_to_start_cfg=timestep_to_start_cfg,
            true_gs=true_gs,
            neg_prompt=neg_prompt,
            img_ref=img_ref,
            img_target=img_target,
            mask_target=mask_target,
            img_ip=img_ip,
            cond_w_regions=cond_w_regions,
            mask_type_ids=mask_type_ids,
            use_background_preservation=use_background_preservation,
            use_progressive_background_preservation=use_progressive_background_preservation,
            background_blend_threshold=background_blend_threshold,
            gradio_progress=gradio_progress,
        )
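
    # `forward` runs the actual sampling: encode the conditioning, denoise the packed
    # latents, decode with the VAE, and crop the target region out of the result.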
    def forward(
        self,
        batch_size,
        num_images_per_prompt,
        prompt,
        width,
        height,
        guidance,
        num_steps,
        seed,
        timestep_to_start_cfg,
        true_gs,
        neg_prompt,
        img_ref,
        img_target,
        mask_target,
        img_ip,
        cond_w_regions,
        mask_type_ids,
        use_background_preservation,
        use_progressive_background_preservation,
        background_blend_threshold,
        gradio_progress=None,
    ):
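        # "True" classifier-free guidance is only applied when a negative prompt is
        # provided and true_gs > 1.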
        has_neg_prompt = neg_prompt is not None
        do_true_cfg = true_gs > 1 and has_neg_prompt

        x = get_noise(
            batch_size * num_images_per_prompt, height, width, device=self.device,
            dtype=self.weight_dtype, seed=seed,
        )
        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
        timesteps = get_schedule(
            num_steps,
            image_seq_len,
            shift=True,
        )
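
        # Encode the text (T5 + CLIP) and, when a Redux image encoder is available,
        # the IP reference image; negative conditioning is prepared only for true CFG.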
        with torch.no_grad():
            self.t5, self.clip, self.image_encoder = self.t5.to(self.device), self.clip.to(self.device), self.image_encoder.to(self.device)
            if self.image_encoder is not None:
                inp_cond = prepare_with_redux(t5=self.t5, clip=self.clip, image_encoder=self.image_encoder, img=x, img_ip=img_ip, prompt=prompt, num_images_per_prompt=num_images_per_prompt)
            else:
                inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt, num_images_per_prompt=num_images_per_prompt)

            neg_inp_cond = None
            if do_true_cfg:
                if self.image_encoder is not None:
                    neg_inp_cond = prepare_with_redux(t5=self.t5, clip=self.clip, image_encoder=self.image_encoder, img=x, img_ip=img_ip, prompt=neg_prompt, num_images_per_prompt=num_images_per_prompt)
                else:
                    neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt, num_images_per_prompt=num_images_per_prompt)

            if self.offload:
                self.offload_model_to_cpu(self.t5, self.clip, self.image_encoder)

            self.model = self.model.to(self.device)
            self.ae.encoder = self.ae.encoder.to(self.device)
            inp_img_cond = prepare_image_cond(
                ae=self.ae,
                img_ref=img_ref,
                img_target=img_target,
                mask_target=mask_target,
                dtype=self.weight_dtype,
                device=self.device,
                num_images_per_prompt=num_images_per_prompt,
            )
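
            # Run the iterative denoising loop over the packed latent sequence,
            # conditioned on the encoded reference/target images and mask.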
            x = denoise(
                self.model,
                img=inp_cond['img'],
                img_ids=inp_cond['img_ids'],
                txt=inp_cond['txt'],
                txt_ids=inp_cond['txt_ids'],
                txt_vec=inp_cond['txt_vec'],
                timesteps=timesteps,
                guidance=guidance,
                img_cond=inp_img_cond['img_cond'],
                mask_cond=inp_img_cond['mask_cond'],
                img_latent=inp_img_cond['img_latent'],
                cond_w_regions=cond_w_regions,
                mask_type_ids=mask_type_ids,
                height=height,
                width=width,
                use_background_preservation=use_background_preservation,
                use_progressive_background_preservation=use_progressive_background_preservation,
                background_blend_threshold=background_blend_threshold,
                true_gs=true_gs,
                timestep_to_start_cfg=timestep_to_start_cfg,
                neg_txt=neg_inp_cond['txt'] if neg_inp_cond is not None else None,
                neg_txt_ids=neg_inp_cond['txt_ids'] if neg_inp_cond is not None else None,
                neg_txt_vec=neg_inp_cond['txt_vec'] if neg_inp_cond is not None else None,
                show_progress=self.show_progress,
                use_flash_attention=self.use_flash_attention,
                gradio_progress=gradio_progress,
            )

            if self.offload:
                self.offload_model_to_cpu(self.model, self.ae.encoder)
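
            # Unpack the latent token sequence back to a spatial latent and decode it
            # with the VAE decoder (moved to the latents' device on demand).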
            x = unpack(x.float(), height, width)
            self.ae.decoder = self.ae.decoder.to(x.device)
            x = self.ae.decode(x)
            if self.offload:
                self.offload_model_to_cpu(self.ae.decoder)

        x1 = x.clamp(-1, 1)
        x1 = rearrange(x1, "b c h w -> b h w c")
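
        # The decoded canvas can be larger than the requested target when reference
        # content is concatenated with it; keep only the target-sized crop taken from
        # the bottom-right corner of each decoded image.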
        output_imgs_target = []
        for i in range(x1.shape[0]):
            output_img = Image.fromarray((127.5 * (x1[i] + 1.0)).cpu().byte().numpy())
            img_target_height, img_target_width = img_target.shape[2], img_target.shape[3]
            output_img_target = output_img.crop((
                output_img.width - img_target_width,
                output_img.height - img_target_height,
                output_img.width,
                output_img.height,
            ))
            output_imgs_target.append(output_img_target)
        return output_imgs_target