import io
import os
from functools import cache, lru_cache
from pathlib import Path
from typing import Any

import gradio as gr
import spaces
import torch
from finegrain import CutoutResultWithImage, EditorAPIContext, ErrorResult
from finegrain_toolbox.flux import Model
from finegrain_toolbox.flux.prompt import prompt_with_embeds
from finegrain_toolbox.processors import product_placement
from gradio_image_annotation import image_annotator
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from typing_extensions import TypeIs

# initialize on CPU then move to GPU (Zero GPU)
DEVICE_CPU = torch.device("cpu")
DTYPE = torch.bfloat16
FG_API_KEY = os.getenv("FG_API_KEY")

model = Model.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", device=DEVICE_CPU, dtype=DTYPE)

lora_path = Path(
    hf_hub_download(
        repo_id="finegrain/finegrain-product-placement-lora",
        filename="finegrain-placement-v1-rank8.safetensors",
    )
)
prompt_path = Path(
    hf_hub_download(
        repo_id="finegrain/finegrain-product-placement-lora",
        filename="addinbox-prompt.safetensors",
    )
)

# the trained prompt ships with precomputed CLIP and T5 text embeddings
prompt_st = load_file(prompt_path, device="cpu")
prompt = prompt_with_embeds(
    text="Add this in the box",
    clip_prompt_embeds=prompt_st["clip"],
    t5_prompt_embeds=prompt_st["t5"],
)

# fuse the LoRA weights into the base transformer, then drop the adapter
model.transformer.load_lora_adapter(lora_path, adapter_name="placement")
model.transformer.fuse_lora()
model.transformer.unload_lora()

DEVICE = torch.device("cuda")
model = model.to(device=DEVICE, dtype=DTYPE)
prompt = prompt.to(device=DEVICE, dtype=DTYPE)


@cache
def _ctx() -> EditorAPIContext:
    # lazily create (and cache) the Finegrain Editor API context
    assert FG_API_KEY is not None
    return EditorAPIContext(
        api_key=FG_API_KEY,
        user_agent="fg-hf-placement",
        priority="low",
    )


def on_change(scene: dict[str, Any] | None, reference: Image.Image | None) -> tuple[dict[str, Any], str]:
    bbox_str = ""
    if scene is not None and isinstance(scene.get("boxes"), list) and len(scene["boxes"]) == 1:
        box = scene["boxes"][0]
        bbox_str = f"({box['xmin']}, {box['ymin']}, {box['xmax']}, {box['ymax']})"
    return (gr.update(interactive=reference is not None and bbox_str != ""), bbox_str)


@spaces.GPU(duration=120)
def _process(
    scene: dict[str, Any],
    reference: Image.Image,
    seed: int = 1234,
) -> tuple[tuple[Image.Image, Image.Image], Image.Image, Image.Image]:
    assert isinstance(scene_image := scene["image"], Image.Image)
    assert isinstance(boxes := scene["boxes"], list)
    assert len(boxes) == 1
    assert isinstance(box := boxes[0], dict)
    bbox = tuple(box[k] for k in ["xmin", "ymin", "xmax", "ymax"])

    result = product_placement.process(
        model=model,
        scene=scene_image,
        reference=reference,
        bbox=bbox,
        prompt=prompt,
        seed=seed,
        max_short_size=1024,
        max_long_size=2048,
    )
    output = result.output
    # resize the input scene so the before/after slider compares same-size images
    before_after = (scene_image.resize(output.size), output)
    return (before_after, result.reference, result.scene)


def _is_error(result: Any) -> TypeIs[ErrorResult]:
    # raise instead of returning True so API failures surface as exceptions
    if isinstance(result, ErrorResult):
        raise RuntimeError(result.error)
    return False


@lru_cache(maxsize=32)
def _cutout_reference(image_bytes: bytes) -> Image.Image:
    async def _process(ctx: EditorAPIContext, image_bytes: bytes) -> Image.Image:
        st_input = await ctx.call_async.upload_image(image_bytes)
        name_r = await ctx.call_async.infer_product_name(st_input)
        assert not _is_error(name_r)
        bbox_r = await ctx.call_async.infer_bbox(st_input, product_name=name_r.is_product)
        assert not _is_error(bbox_r)
        mask_r = await ctx.call_async.segment(st_input, bbox=bbox_r.bbox)
        assert not _is_error(mask_r)
        cutout_r = await ctx.call_async.cutout(st_input, mask_r.state_id, with_image=True)
        assert not _is_error(cutout_r)
        assert isinstance(cutout_r, CutoutResultWithImage)
        return Image.open(io.BytesIO(cutout_r.image))

    api_ctx = _ctx()
    try:
        cutout = api_ctx.run_one_sync(_process, image_bytes)
    except AssertionError:
        # retry once with a reset API context if the first attempt fails
        api_ctx.reset()
        cutout = api_ctx.run_one_sync(_process, image_bytes)
    return cutout


def cutout_reference(reference: Image.Image) -> Image.Image:
    buf = io.BytesIO()
    reference.save(buf, format="PNG")
    return _cutout_reference(buf.getvalue())


def process(
    scene: dict[str, Any],
    reference: Image.Image,
    seed: int = 1234,
    cut_out_reference: bool = False,
) -> tuple[tuple[Image.Image, Image.Image], Image.Image, Image.Image]:
    if cut_out_reference:
        reference = cutout_reference(reference)
    return _process(scene, reference, seed)


TITLE = """

Finegrain Product Placement LoRA

🧪 An experiment to extend Flux Kontext with product placement capabilities. The LoRA was trained using EditNet, our before/after image editing dataset.

Just draw a box to set where the subject should be blended, and at what size.

Model Card | Blog Post | EditNet

""" with gr.Blocks() as demo: gr.HTML(TITLE) with gr.Row(): with gr.Column(): scene = image_annotator( label="Scene", image_type="pil", disable_edit_boxes=True, show_download_button=False, show_share_button=False, single_box=True, image_mode="RGB", ) reference = gr.Image( label="Product Reference", visible=True, interactive=True, type="pil", image_mode="RGBA", ) with gr.Accordion("Options", open=False): seed = gr.Slider( minimum=0, maximum=10_000, value=1234, step=1, label="Seed", ) cut_out_reference = gr.Checkbox( label="Cut out reference", value=bool(FG_API_KEY), interactive=bool(FG_API_KEY), ) with gr.Row(): run_btn = gr.ClearButton(value="Blend", interactive=False) with gr.Column(): output_image = gr.ImageSlider(label="Output Image", show_fullscreen_button=False) with gr.Accordion("Debug", open=False): output_textbox = gr.Textbox(label="Bounding Box", interactive=False) output_reference = gr.Image( label="Reference", visible=True, interactive=False, type="pil", image_mode="RGB", ) output_scene = gr.Image( label="Scene", visible=True, interactive=False, type="pil", image_mode="RGB", ) run_btn.add(output_image) # Watch for changes (scene and reference) # i.e. the user must select a box in the scene and upload a reference image scene.change(fn=on_change, inputs=[scene, reference], outputs=[run_btn, output_textbox]) reference.change(fn=on_change, inputs=[scene, reference], outputs=[run_btn, output_textbox]) run_btn.click( fn=process, inputs=[scene, reference, seed, cut_out_reference], outputs=[output_image, output_reference, output_scene], ) examples = [ [ { "image": "examples/sunglasses/scene.jpg", "boxes": [{"xmin": 164, "ymin": 89, "xmax": 379, "ymax": 204}], }, "examples/sunglasses/reference.webp", ], [ { "image": "examples/kitchen/scene.webp", "boxes": [{"xmin": 165, "ymin": 765, "xmax": 332, "ymax": 883}], }, "examples/kitchen/reference.webp", ], [ { "image": "examples/glass/scene.webp", "boxes": [{"xmin": 389, "ymin": 509, "xmax": 611, "ymax": 1088}], }, "examples/glass/reference.webp", ], [ { "image": "examples/chair/scene.webp", "boxes": [{"xmin": 366, "ymin": 389, "xmax": 623, "ymax": 728}], }, "examples/chair/reference.webp", ], [ { "image": "examples/lantern/scene.webp", "boxes": [{"xmin": 497, "ymin": 690, "xmax": 618, "ymax": 873}], }, "examples/lantern/reference.webp", ], ] ex = gr.Examples( examples=examples, inputs=[scene, reference], outputs=[output_image, output_reference, output_scene], fn=process, cache_examples=True, cache_mode="eager", ) demo.launch(show_api=False, ssr_mode=False)