import logging
import os
import random
import subprocess

import numpy as np
import torch
from diffusers import (
    AutoencoderKL,
    EulerDiscreteScheduler,
    UNet2DConditionModel,
)
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.models.unet_2d_condition import (
    UNet2DConditionModel as UNet2DConditionModelIP,
)
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import (
    StableDiffusionXLPipeline,
)
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import (
    StableDiffusionXLPipeline as StableDiffusionXLPipelineIP,
)
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

__all__ = [
    "build_text2img_ip_pipeline",
    "build_text2img_pipeline",
    "text2img_gen",
    "download_kolors_weights",
]

PROMPT_APPEND = (
    "Angled 3D view of one {object}, centered, no cropping, no occlusion, isolated product photo, "
    "no surroundings, high-quality appearance, vivid colors, on a plain clean surface, 3D style revealing multiple surfaces"
)
PROMPT_KAPPEND = "Single {object}, in the center of the image, white background, 3D style, best quality"
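
# For reference, PROMPT_KAPPEND.format(object="banana") expands to:
#   "Single banana, in the center of the image, white background, 3D style, best quality"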


def download_kolors_weights(local_dir: str = "weights/Kolors") -> None:
    """Downloads Kolors model weights from HuggingFace.

    Args:
        local_dir (str, optional): Local directory to store weights.
    """
    logger.info("Downloading Kolors weights from HuggingFace...")
    os.makedirs(local_dir, exist_ok=True)
    subprocess.run(
        [
            "huggingface-cli",
            "download",
            "--resume-download",
            "Kwai-Kolors/Kolors",
            "--local-dir",
            local_dir,
        ],
        check=True,
    )

    # The IP-Adapter weights are stored in a sibling directory of `local_dir`.
    ip_adapter_path = f"{local_dir}/../Kolors-IP-Adapter-Plus"
    subprocess.run(
        [
            "huggingface-cli",
            "download",
            "--resume-download",
            "Kwai-Kolors/Kolors-IP-Adapter-Plus",
            "--local-dir",
            ip_adapter_path,
        ],
        check=True,
    )
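
# Usage sketch (assumptions: network access on the first run and the
# `huggingface-cli` binary on PATH; it ships with the `huggingface_hub`
# package):
#
#   download_kolors_weights("weights/Kolors")
#   # -> weights/Kolors/{text_encoder, vae, scheduler, unet, ...}
#   # -> weights/Kolors-IP-Adapter-Plus/{image_encoder, ip_adapter_plus_general.bin, ...}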


def build_text2img_ip_pipeline(
    ckpt_dir: str,
    ref_scale: float,
    device: str = "cuda",
) -> StableDiffusionXLPipelineIP:
    """Builds a Stable Diffusion XL pipeline with IP-Adapter for text-to-image generation.

    Args:
        ckpt_dir (str): Directory containing model checkpoints.
        ref_scale (float): Reference scale for IP-Adapter.
        device (str, optional): Device for inference.

    Returns:
        StableDiffusionXLPipelineIP: Configured pipeline.

    Example:
        ```py
        from embodied_gen.models.text_model import build_text2img_ip_pipeline
        pipe = build_text2img_ip_pipeline("weights/Kolors", ref_scale=0.3)
        ```
    """
    download_kolors_weights(ckpt_dir)

    text_encoder = ChatGLMModel.from_pretrained(
        f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
    ).half()
    tokenizer = ChatGLMTokenizer.from_pretrained(f"{ckpt_dir}/text_encoder")
    vae = AutoencoderKL.from_pretrained(
        f"{ckpt_dir}/vae", revision=None
    ).half()
    scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
    unet = UNet2DConditionModelIP.from_pretrained(
        f"{ckpt_dir}/unet", revision=None
    ).half()
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        f"{ckpt_dir}/../Kolors-IP-Adapter-Plus/image_encoder",
        ignore_mismatched_sizes=True,
    ).to(dtype=torch.float16)
    clip_image_processor = CLIPImageProcessor(size=336, crop_size=336)

    pipe = StableDiffusionXLPipelineIP(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        image_encoder=image_encoder,
        feature_extractor=clip_image_processor,
        force_zeros_for_empty_prompt=False,
    )

    # Alias diffusers' `encoder_hid_proj` under the `text_encoder_hid_proj`
    # name that the Kolors IP-Adapter pipeline looks up.
    if hasattr(pipe.unet, "encoder_hid_proj"):
        pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj

    pipe.load_ip_adapter(
        f"{ckpt_dir}/../Kolors-IP-Adapter-Plus",
        subfolder="",
        weight_name=["ip_adapter_plus_general.bin"],
    )
    pipe.set_ip_adapter_scale([ref_scale])

    pipe = pipe.to(device)
    pipe.image_encoder = pipe.image_encoder.to(device)

    return pipe
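
# Usage sketch pairing the IP-Adapter pipeline with `text2img_gen` (defined
# below); "reference.png" is a hypothetical local image, not part of the repo:
#
#   pipe = build_text2img_ip_pipeline("weights/Kolors", ref_scale=0.3)
#   images = text2img_gen(
#       prompt="banana",
#       n_sample=1,
#       guidance_scale=7.5,
#       pipeline=pipe,
#       ip_image="reference.png",
#   )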


def build_text2img_pipeline(
    ckpt_dir: str,
    device: str = "cuda",
) -> StableDiffusionXLPipeline:
    """Builds a Stable Diffusion XL pipeline for text-to-image generation.

    Args:
        ckpt_dir (str): Directory containing model checkpoints.
        device (str, optional): Device for inference.

    Returns:
        StableDiffusionXLPipeline: Configured pipeline.

    Example:
        ```py
        from embodied_gen.models.text_model import build_text2img_pipeline
        pipe = build_text2img_pipeline("weights/Kolors")
        ```
    """
    download_kolors_weights(ckpt_dir)

    text_encoder = ChatGLMModel.from_pretrained(
        f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
    ).half()
    tokenizer = ChatGLMTokenizer.from_pretrained(f"{ckpt_dir}/text_encoder")
    vae = AutoencoderKL.from_pretrained(
        f"{ckpt_dir}/vae", revision=None
    ).half()
    scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
    unet = UNet2DConditionModel.from_pretrained(
        f"{ckpt_dir}/unet", revision=None
    ).half()
    pipe = StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        force_zeros_for_empty_prompt=False,
    )
    pipe = pipe.to(device)

    return pipe
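
# Note: if GPU memory is tight, diffusers' standard CPU-offload hook can be
# applied to either builder's output (this assumes the Kolors pipelines
# inherit from `DiffusionPipeline`, which provides it, and that `accelerate`
# is installed):
#
#   pipe.enable_model_cpu_offload()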


def text2img_gen(
    prompt: str,
    n_sample: int,
    guidance_scale: float,
    pipeline: StableDiffusionXLPipeline | StableDiffusionXLPipelineIP,
    ip_image: Image.Image | str | None = None,
    image_wh: tuple[int, int] = (1024, 1024),
    infer_step: int = 50,
    ip_image_size: int = 512,
    seed: int | None = None,
) -> list[Image.Image]:
    """Generates images from text prompts using a Stable Diffusion XL pipeline.

    Args:
        prompt (str): Text prompt for image generation.
        n_sample (int): Number of images to generate.
        guidance_scale (float): Guidance scale for diffusion.
        pipeline (StableDiffusionXLPipeline | StableDiffusionXLPipelineIP): Pipeline instance.
        ip_image (Image.Image | str, optional): Reference image for IP-Adapter.
        image_wh (tuple[int, int], optional): Output image size (width, height).
        infer_step (int, optional): Number of inference steps.
        ip_image_size (int, optional): Size for IP-Adapter image.
        seed (int, optional): Random seed.

    Returns:
        list[Image.Image]: List of generated images.

    Example:
        ```py
        from embodied_gen.models.text_model import (
            build_text2img_pipeline,
            text2img_gen,
        )
        pipe = build_text2img_pipeline("weights/Kolors")
        images = text2img_gen(
            prompt="banana", n_sample=3, guidance_scale=7.5, pipeline=pipe
        )
        images[0].save("banana.png")
        ```
    """
    prompt = PROMPT_KAPPEND.format(object=prompt.strip())
    logger.info(f"Processing prompt: {prompt}")

    generator = None
    if seed is not None:
        # Seed the pipeline generator plus the global RNGs for reproducibility.
        generator = torch.Generator(pipeline.device).manual_seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    kwargs = dict(
        prompt=prompt,
        height=image_wh[1],
        width=image_wh[0],
        num_inference_steps=infer_step,
        guidance_scale=guidance_scale,
        num_images_per_prompt=n_sample,
        generator=generator,
    )
    if ip_image is not None:
        if isinstance(ip_image, str):
            ip_image = Image.open(ip_image)
        ip_image = ip_image.resize((ip_image_size, ip_image_size))
        kwargs.update(ip_adapter_image=[ip_image])

    return pipeline(**kwargs).images
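

if __name__ == "__main__":
    # Minimal end-to-end sketch, assuming weights land in `weights/Kolors`
    # (downloaded inside the builder) and a CUDA device is available.
    pipe = build_text2img_pipeline("weights/Kolors")
    images = text2img_gen(
        prompt="banana",
        n_sample=1,
        guidance_scale=7.5,
        pipeline=pipe,
        seed=0,
    )
    images[0].save("banana.png")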