|  | import os | 
					
						
						|  | import glob | 
					
						
						|  | import numpy as np | 
					
						
						|  | from PIL import Image | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | import torch.nn as nn | 
					
						
						|  |  | 
					
						
						|  | from pipeline_flux_ipa import FluxPipeline | 
					
						
						|  | from transformer_flux import FluxTransformer2DModel | 
					
						
						|  | from attention_processor import IPAFluxAttnProcessor2_0 | 
					
						
						|  | from transformers import AutoProcessor, SiglipVisionModel | 
					
						
						|  |  | 
					
						
						# --- Image pre-processing helper ---
def resize_img(input_image, max_side=1280, min_side=1024, size=None,
               pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
    """Resize a PIL image so both sides are multiples of ``base_pixel_number``.

    Args:
        input_image: PIL image to resize.
        max_side: Target length of the longer side (before snapping); also the
            side length of the square canvas used when ``pad_to_max_side`` is set.
        min_side: Length the shorter side is first scaled to. The following
            ``max_side`` rescale overrides it, so it only affects rounding of
            the intermediate size.
        size: Optional explicit ``(width, height)``; when given, the
            aspect-ratio computation is skipped and the image is resized to it.
        pad_to_max_side: If True, centre the resized image on a white
            ``max_side`` x ``max_side`` RGB canvas.
        mode: PIL resampling filter.
        base_pixel_number: Each computed side is floored to a multiple of this
            value, but never below it.

    Returns:
        The resized (and optionally padded) PIL image.

    Raises:
        ValueError: if ``pad_to_max_side`` is set and the resized image does not
            fit inside the ``max_side`` x ``max_side`` canvas.
    """
    w, h = input_image.size
    if size is not None:
        w_resize_new, h_resize_new = size
    else:
        # Scale the short side to ``min_side`` (rounded), then rescale so the
        # long side equals ``max_side``. Only the target dimensions are computed
        # here; the pixels are resampled once, below, instead of being resized
        # to an intermediate size and immediately resized again.
        ratio = min_side / min(h, w)
        w, h = round(ratio * w), round(ratio * h)
        ratio = max_side / max(h, w)
        # Floor to a multiple of ``base_pixel_number`` but never down to zero,
        # which would make ``resize`` fail for extreme aspect ratios.
        w_resize_new = max(base_pixel_number,
                           (round(ratio * w) // base_pixel_number) * base_pixel_number)
        h_resize_new = max(base_pixel_number,
                           (round(ratio * h) // base_pixel_number) * base_pixel_number)
    input_image = input_image.resize([w_resize_new, h_resize_new], mode)

    if pad_to_max_side:
        if w_resize_new > max_side or h_resize_new > max_side:
            raise ValueError(
                f"resized image {w_resize_new}x{h_resize_new} does not fit in a "
                f"{max_side}x{max_side} canvas")
        res = np.full([max_side, max_side, 3], 255, dtype=np.uint8)
        offset_x = (max_side - w_resize_new) // 2
        offset_y = (max_side - h_resize_new) // 2
        # Convert so the pasted array is always (H, W, 3), matching the canvas
        # even for grayscale or RGBA inputs.
        patch = np.array(input_image.convert("RGB"))
        res[offset_y:offset_y + h_resize_new, offset_x:offset_x + w_resize_new] = patch
        input_image = Image.fromarray(res)
    return input_image
					
						
						|  |  | 
					
						
						# --- Projection from image-encoder embeddings to IP tokens ---
class MLPProjModel(torch.nn.Module):
    """Two-layer MLP mapping one embedding vector to ``num_tokens`` tokens.

    Each output token has width ``cross_attention_dim``. The sub-module names
    (``proj`` with Sequential indices 0/1/2, and ``norm``) determine the
    state-dict keys, so renaming them would break loading saved weights.
    """

    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        hidden_dim = id_embeddings_dim * 2
        output_dim = cross_attention_dim * num_tokens
        # Linear -> GELU -> Linear; the flat output is split into tokens in forward().
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_dim, output_dim),
        )
        # Normalises each token over its own feature dimension.
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        """Project ``id_embeds`` into a ``(-1, num_tokens, cross_attention_dim)`` tensor."""
        flat = self.proj(id_embeds)
        tokens = flat.reshape(-1, self.num_tokens, self.cross_attention_dim)
        return self.norm(tokens)
					
						
						|  |  | 
					
						
						# --- IP-Adapter wrapper around the Flux pipeline ---
class IPAdapter:
    """Attach an image-prompt adapter (IP-Adapter) to a Flux-style pipeline.

    The pooled output of a SigLIP vision encoder is projected by
    ``MLPProjModel`` into ``num_tokens`` context tokens, which are consumed by
    ``IPAFluxAttnProcessor2_0`` processors installed on the transformer blocks
    of ``sd_pipe.transformer``.

    Args:
        sd_pipe: Pipeline exposing a ``transformer`` attribute; moved to ``device``.
        image_encoder_path: Path or hub id passed to
            ``SiglipVisionModel.from_pretrained`` / ``AutoProcessor.from_pretrained``.
        ip_ckpt: Checkpoint file holding ``"image_proj"`` and ``"ip_adapter"``
            state dicts.
        device: Device for the pipeline, encoder and adapter layers.
        num_tokens: Number of image tokens produced per image.
    """

    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.num_tokens = num_tokens

        self.pipe = sd_pipe.to(self.device)
        self.set_ip_adapter()

        # Image encoder (bfloat16) and its pre-processor.
        self.image_encoder = SiglipVisionModel.from_pretrained(image_encoder_path).to(
            self.device, dtype=torch.bfloat16)
        self.clip_image_processor = AutoProcessor.from_pretrained(self.image_encoder_path)

        # The projection reads the encoder config to size its input, so it is
        # built after the encoder; trained weights are loaded last.
        self.image_proj_model = self.init_proj()
        self.load_ip_adapter()

    def init_proj(self):
        """Build the bfloat16 ``MLPProjModel`` on ``self.device``."""
        # Input width = the encoder's pooled-embedding size. Fall back to 1152
        # (the previously hard-coded value, siglip-so400m's hidden size) when
        # the encoder or its config is unavailable.
        encoder_config = getattr(getattr(self, "image_encoder", None), "config", None)
        id_embeddings_dim = getattr(encoder_config, "hidden_size", 1152)
        image_proj_model = MLPProjModel(
            cross_attention_dim=self.pipe.transformer.config.joint_attention_dim,
            id_embeddings_dim=id_embeddings_dim,
            num_tokens=self.num_tokens,
        ).to(self.device, dtype=torch.bfloat16)
        return image_proj_model

    def set_ip_adapter(self):
        """Install ``IPAFluxAttnProcessor2_0`` on the (single) transformer blocks.

        Processors whose names start with ``transformer_blocks.`` or
        ``single_transformer_blocks`` are replaced; all others are kept as-is.
        """
        transformer = self.pipe.transformer
        # Read the property once instead of once per non-IP processor inside
        # the loop; it may rebuild the mapping on every access.
        current_procs = transformer.attn_processors
        ip_attn_procs = {}
        for name, processor in current_procs.items():
            if name.startswith(("transformer_blocks.", "single_transformer_blocks")):
                ip_attn_procs[name] = IPAFluxAttnProcessor2_0(
                    hidden_size=transformer.config.num_attention_heads * transformer.config.attention_head_dim,
                    cross_attention_dim=transformer.config.joint_attention_dim,
                    num_tokens=self.num_tokens,
                ).to(self.device, dtype=torch.bfloat16)
            else:
                ip_attn_procs[name] = processor

        transformer.set_attn_processor(ip_attn_procs)

    def load_ip_adapter(self):
        """Load projection and attention-processor weights from ``self.ip_ckpt``.

        ``"image_proj"`` is loaded strictly into ``image_proj_model``.
        ``"ip_adapter"`` is loaded non-strictly into a ``ModuleList`` of all
        attention processors, so its keys are matched by processor position in
        ``attn_processors``. Missing or unexpected keys are reported with a
        ``RuntimeWarning`` instead of being silently ignored.
        """
        import warnings

        # NOTE(security): torch.load unpickles the file and can execute code;
        # only load checkpoints from trusted sources (or consider weights_only=True).
        state_dict = torch.load(self.ip_ckpt, map_location="cpu")
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        ip_layers = torch.nn.ModuleList(self.pipe.transformer.attn_processors.values())
        result = ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
        if result.missing_keys or result.unexpected_keys:
            # A mismatch here leaves some IP layers without trained weights.
            warnings.warn(
                f"IP-Adapter checkpoint {self.ip_ckpt!r} did not fully match the "
                f"attention processors: {len(result.missing_keys)} missing and "
                f"{len(result.unexpected_keys)} unexpected keys "
                f"(e.g. missing={result.missing_keys[:3]}, "
                f"unexpected={result.unexpected_keys[:3]})",
                RuntimeWarning,
            )

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        """Return IP tokens for an image prompt.

        ``pil_image`` (a PIL image or a list of them) takes precedence: it is
        pre-processed and encoded, and the encoder's ``pooler_output`` is used.
        Otherwise ``clip_image_embeds`` is moved to ``self.device`` as bfloat16.
        The embeddings are then passed through ``image_proj_model``.

        Raises:
            ValueError: if neither ``pil_image`` nor ``clip_image_embeds`` is given.
        """
        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
            clip_image_embeds = self.image_encoder(
                clip_image.to(self.device, dtype=self.image_encoder.dtype)).pooler_output
            clip_image_embeds = clip_image_embeds.to(dtype=torch.bfloat16)
        elif clip_image_embeds is not None:
            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.bfloat16)
        else:
            raise ValueError("Either `pil_image` or `clip_image_embeds` must be provided.")
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        return image_prompt_embeds

    def set_scale(self, scale):
        """Set the IP-Adapter strength on every ``IPAFluxAttnProcessor2_0``."""
        for attn_processor in self.pipe.transformer.attn_processors.values():
            if isinstance(attn_processor, IPAFluxAttnProcessor2_0):
                attn_processor.scale = scale

    def generate(
        self,
        pil_image=None,
        clip_image_embeds=None,
        prompt=None,
        scale=1.0,
        num_samples=1,
        seed=None,
        guidance_scale=3.5,
        num_inference_steps=24,
        **kwargs,
    ):
        """Generate images conditioned on ``prompt`` and an image prompt.

        Args:
            pil_image: Image prompt; see ``get_image_embeds``.
            clip_image_embeds: Precomputed embeddings used when ``pil_image`` is None.
            prompt: Text prompt forwarded to the pipeline.
            scale: IP-Adapter strength applied via ``set_scale``.
            num_samples: Currently unused by this method. NOTE(review):
                presumably meant to control images per prompt — confirm
                against the pipeline before wiring it up.
            seed: If given, seeds a ``torch.Generator`` on ``self.device``.
            guidance_scale: Forwarded to the pipeline.
            num_inference_steps: Forwarded to the pipeline.
            **kwargs: Extra pipeline arguments (e.g. ``width``/``height``).

        Returns:
            The ``images`` attribute of the pipeline output.
        """
        self.set_scale(scale)

        image_prompt_embeds = self.get_image_embeds(
            pil_image=pil_image, clip_image_embeds=clip_image_embeds
        )

        if seed is None:
            generator = None
        else:
            generator = torch.Generator(self.device).manual_seed(seed)

        images = self.pipe(
            prompt=prompt,
            image_emb=image_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						# --- Example: image-prompted generation with FLUX.1-dev ---
if __name__ == '__main__':

    model_path = "black-forest-labs/FLUX.1-dev"
    image_encoder_path = "google/siglip-so400m-patch14-384"
    ipadapter_path = "./ip-adapter.bin"
    output_dir = "results"

    # Create the output directory before the expensive model loading and
    # generation, so a missing directory cannot make the final save fail.
    os.makedirs(output_dir, exist_ok=True)

    transformer = FluxTransformer2DModel.from_pretrained(
        model_path, subfolder="transformer", torch_dtype=torch.bfloat16
    )

    pipe = FluxPipeline.from_pretrained(
        model_path, transformer=transformer, torch_dtype=torch.bfloat16
    )

    ip_model = IPAdapter(pipe, image_encoder_path, ipadapter_path, device="cuda", num_tokens=128)

    image_path = "./assets/images/2.jpg"
    image_name = os.path.basename(image_path)
    # The context manager closes the file handle; convert() returns a new image.
    with Image.open(image_path) as src:
        image = src.convert("RGB")
    image = resize_img(image)

    prompt = "a young girl"

    images = ip_model.generate(
        pil_image=image,
        prompt=prompt,
        scale=0.7,
        width=960, height=1280,
        seed=42
    )

    images[0].save(os.path.join(output_dir, image_name))