# filename: ip_adapter_multi_mode.py
"""Run IP-Adapter (SD 1.5) image-prompted generation in three modes:
text-to-image, image-to-image, and legacy inpainting. The matching
diffusers pipeline is (re)loaded for each mode before generation."""

import torch
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipelineLegacy,
    DDIMScheduler,
    AutoencoderKL,
)
from PIL import Image

from ip_adapter import IPAdapter


class IPAdapterRunner:
    def __init__(
        self,
        base_model_path="runwayml/stable-diffusion-v1-5",
        vae_model_path="stabilityai/sd-vae-ft-mse",
        image_encoder_path="models/image_encoder/",
        ip_ckpt="models/ip-adapter_sd15.bin",
        device="cuda"
    ):
        self.base_model_path = base_model_path
        self.vae_model_path = vae_model_path
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.device = device
        self.vae = self._load_vae()
        self.scheduler = self._create_scheduler()
        self.pipe = None
        self.ip_model = None

    def _create_scheduler(self):
        return DDIMScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )

    def _load_vae(self):
        return AutoencoderKL.from_pretrained(self.vae_model_path).to(dtype=torch.float16)

    def _clear_previous_pipe(self):
        # Drop references to the previous pipeline and IP-Adapter wrapper so
        # their GPU memory can be reclaimed before a new pipeline is loaded.
        # Resetting the attributes to None keeps later checks safe even if
        # loading the next pipeline fails partway through.
        if self.pipe is not None:
            del self.pipe
            del self.ip_model
            self.pipe = None
            self.ip_model = None
            torch.cuda.empty_cache()

    def _load_pipeline(self, mode):
        self._clear_previous_pipe()
        if mode == "text2img":
            self.pipe = StableDiffusionPipeline.from_pretrained(
                self.base_model_path,
                torch_dtype=torch.float16,
                scheduler=self.scheduler,
                vae=self.vae,
                feature_extractor=None,
                safety_checker=None,
            )
        elif mode == "img2img":
            self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                self.base_model_path,
                torch_dtype=torch.float16,
                scheduler=self.scheduler,
                vae=self.vae,
                feature_extractor=None,
                safety_checker=None,
            )
        elif mode == "inpaint":
            self.pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
                self.base_model_path,
                torch_dtype=torch.float16,
                scheduler=self.scheduler,
                vae=self.vae,
                feature_extractor=None,
                safety_checker=None,
            )
        else:
            raise ValueError(f"Unsupported mode: {mode}")
        self.ip_model = IPAdapter(self.pipe, self.image_encoder_path, self.ip_ckpt, self.device)

    def generate_text2img(self, pil_image, num_samples=4, num_inference_steps=50, seed=42):
        self._load_pipeline("text2img")
        pil_image = pil_image.resize((256, 256))
        return self.ip_model.generate(
            pil_image=pil_image,
            num_samples=num_samples,
            num_inference_steps=num_inference_steps,
            seed=seed,
        )

    def generate_img2img(self, pil_image, reference_image, strength=0.6, num_samples=4, num_inference_steps=50, seed=42):
        self._load_pipeline("img2img")
        return self.ip_model.generate(
            pil_image=pil_image,
            image=reference_image,
            strength=strength,
            num_samples=num_samples,
            num_inference_steps=num_inference_steps,
            seed=seed,
        )

    def generate_inpaint(self, pil_image, image, mask_image, strength=0.7, num_samples=4, num_inference_steps=50, seed=42):
        self._load_pipeline("inpaint")
        return self.ip_model.generate(
            pil_image=pil_image,
            image=image,
            mask_image=mask_image,
            strength=strength,
            num_samples=num_samples,
            num_inference_steps=num_inference_steps,
            seed=seed,
        )

    @staticmethod
    def image_grid(imgs, rows, cols):
        assert len(imgs) == rows * cols
        w, h = imgs[0].size
        grid = Image.new('RGB', size=(cols * w, rows * h))
        for i, img in enumerate(imgs):
            grid.paste(img, box=((i % cols) * w, (i // cols) * h))
        return grid
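

# --- Example usage: a minimal sketch of how the runner above might be driven.
# The image paths below are hypothetical placeholders; the model, VAE, image
# encoder, and checkpoint paths assume the defaults in IPAdapterRunner.__init__.
if __name__ == "__main__":
    runner = IPAdapterRunner()

    # Image prompt consumed by IP-Adapter in all three modes.
    ip_image = Image.open("assets/images/woman.png")  # hypothetical path

    # 1) Image-prompted text-to-image generation.
    images = runner.generate_text2img(ip_image, num_samples=4)
    runner.image_grid(images, rows=1, cols=4).save("text2img_grid.png")

    # 2) Image-to-image: keep the layout of a reference image, restyled by the prompt image.
    reference = Image.open("assets/images/reference.png")  # hypothetical path
    images = runner.generate_img2img(ip_image, reference, strength=0.6)
    runner.image_grid(images, rows=1, cols=4).save("img2img_grid.png")

    # 3) Inpainting: fill the masked region, guided by the image prompt.
    init_image = Image.open("assets/inpainting/image.png")  # hypothetical path
    mask_image = Image.open("assets/inpainting/mask.png")   # hypothetical path
    images = runner.generate_inpaint(ip_image, init_image, mask_image, strength=0.7)
    runner.image_grid(images, rows=1, cols=4).save("inpaint_grid.png")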