#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Business logic functions for IC-Custom application.
"""
import numpy as np
import torch
import cv2
import gradio as gr
from PIL import Image
from datetime import datetime
import json
import os
from scipy.ndimage import binary_dilation, binary_erosion

from constants import (
    DEFAULT_BACKGROUND_BLEND_THRESHOLD, DEFAULT_SEED, DEFAULT_NUM_IMAGES,
    DEFAULT_GUIDANCE, DEFAULT_TRUE_GS, DEFAULT_NUM_STEPS, DEFAULT_ASPECT_RATIO,
    DEFAULT_DILATION_KERNEL_SIZE, DEFAULT_MARKER_SIZE, DEFAULT_MARKER_THICKNESS,
    DEFAULT_MASK_ALPHA, DEFAULT_COLOR_ALPHA, TIMESTAMP_FORMAT, SEGMENTATION_COLORS, SEGMENTATION_MARKERS
)
from utils import run_vlm, construct_vlm_gen_prompt, construct_vlm_polish_prompt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Global holders for models injected from the app layer
MOBILE_PREDICTOR = None  # SAM mobile predictor
BEN2_MODEL = None  # BEN2 background-removal model
VLM_PROCESSOR = None  # declared here so the `is None` checks below cannot raise NameError
VLM_MODEL = None


def set_mobile_predictor(predictor):
    """Inject SAM mobile predictor into this module without changing function signatures."""
    global MOBILE_PREDICTOR
    MOBILE_PREDICTOR = predictor


def set_ben2_model(ben2_model):
    """Inject BEN2 model into this module without changing function signatures."""
    global BEN2_MODEL
    BEN2_MODEL = ben2_model


def set_vlm_processor(vlm_processor):
    """Inject VLM processor into this module without changing function signatures."""
    global VLM_PROCESSOR
    VLM_PROCESSOR = vlm_processor


def set_vlm_model(vlm_model):
    """Inject VLM model into this module without changing function signatures."""
    global VLM_MODEL
    VLM_MODEL = vlm_model
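
# Wiring sketch (hypothetical app-layer code; `build_models`, the checkpoint path,
# and the MobileSAM import are assumptions, not part of this module):
#
#   from mobile_sam import sam_model_registry, SamPredictor  # assumed dependency
#   import business_logic as bl                              # this module
#
#   sam = sam_model_registry["vit_t"](checkpoint="weights/mobile_sam.pt").to(device)
#   bl.set_mobile_predictor(SamPredictor(sam))
#   bl.set_ben2_model(ben2_model)        # however the app constructs BEN2
#   bl.set_vlm_processor(vlm_processor)  # e.g. a transformers processor
#   bl.set_vlm_model(vlm_model)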


def init_image_target_1(target_image):
    """Initialize UI state when a target image is uploaded."""
    # Handle both PIL Image (image_target_1) and ImageEditor dict (image_target_2)
    try:
        if isinstance(target_image, dict) and 'composite' in target_image:
            # ImageEditor format (user-drawn mask)
            image_target_state = np.array(target_image['composite'].convert("RGB"))
        else:
            # PIL Image format (precise mask)
            image_target_state = np.array(target_image.convert("RGB"))
    except Exception:
        # If the image cannot be processed, skip initialization
        # (14 outputs, matching the success path below)
        return (
            gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
            gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
            gr.skip(), gr.update(value="-1"),
        )
    selected_points = []
    mask_target_state = None
    prompt = None
    mask_gallery = []
    result_gallery = []
    use_background_preservation = False
    background_blend_threshold = DEFAULT_BACKGROUND_BLEND_THRESHOLD
    seed = DEFAULT_SEED
    num_images_per_prompt = DEFAULT_NUM_IMAGES
    guidance = DEFAULT_GUIDANCE
    true_gs = DEFAULT_TRUE_GS
    num_steps = DEFAULT_NUM_STEPS
    aspect_ratio_val = gr.update(value=DEFAULT_ASPECT_RATIO)
    return (image_target_state, selected_points, mask_target_state, prompt,
            mask_gallery, result_gallery, use_background_preservation,
            background_blend_threshold, seed, num_images_per_prompt, guidance,
            true_gs, num_steps, aspect_ratio_val)


def init_image_target_2(target_image):
    """Initialize UI state when a target image is uploaded via the ImageEditor."""
    # Handle both PIL Image (image_target_1) and ImageEditor dict (image_target_2)
    try:
        if isinstance(target_image, dict) and 'composite' in target_image:
            # ImageEditor format (user-drawn mask)
            image_target_state = np.array(target_image['composite'].convert("RGB"))
        else:
            # PIL Image format (precise mask)
            image_target_state = np.array(target_image.convert("RGB"))
    except Exception:
        # If the image cannot be processed, skip initialization
        # (14 outputs, matching the success path below)
        return (
            gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
            gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
            gr.skip(), gr.update(value="-1"),
        )
    # Only the target image state changes; every other output is skipped
    selected_points = gr.skip()
    mask_target_state = gr.skip()
    prompt = gr.skip()
    mask_gallery = gr.skip()
    result_gallery = gr.skip()
    use_background_preservation = gr.skip()
    background_blend_threshold = gr.skip()
    seed = gr.skip()
    num_images_per_prompt = gr.skip()
    guidance = gr.skip()
    true_gs = gr.skip()
    num_steps = gr.skip()
    aspect_ratio_val = gr.skip()
    return (image_target_state, selected_points, mask_target_state, prompt,
            mask_gallery, result_gallery, use_background_preservation,
            background_blend_threshold, seed, num_images_per_prompt, guidance,
            true_gs, num_steps, aspect_ratio_val)


def init_image_reference(image_reference):
    """Initialize all UI states when a reference image is uploaded."""
    image_reference_state = np.array(image_reference.convert("RGB"))
    image_reference_ori_state = image_reference_state
    image_reference_rmbg_state = None
    image_target_state = None
    mask_target_state = None
    prompt = None
    mask_gallery = []
    result_gallery = []
    image_target_1_val = None
    image_target_2_val = None
    selected_points = []
    input_mask_mode_val = gr.update(value="Precise mask")
    seg_ref_mode_val = gr.update(value="Full Ref")
    move_to_center = False
    use_background_preservation = False
    background_blend_threshold = DEFAULT_BACKGROUND_BLEND_THRESHOLD
    seed = DEFAULT_SEED
    num_images_per_prompt = DEFAULT_NUM_IMAGES
    guidance = DEFAULT_GUIDANCE
    true_gs = DEFAULT_TRUE_GS
    num_steps = DEFAULT_NUM_STEPS
    aspect_ratio_val = gr.update(value=DEFAULT_ASPECT_RATIO)
    return (
        image_reference_ori_state, image_reference_rmbg_state, image_target_state,
        mask_target_state, prompt, mask_gallery, result_gallery, image_target_1_val,
        image_target_2_val, selected_points, input_mask_mode_val, seg_ref_mode_val,
        move_to_center, use_background_preservation, background_blend_threshold,
        seed, num_images_per_prompt, guidance, true_gs, num_steps, aspect_ratio_val,
    )


def undo_seg_points(orig_img, sel_pix):
    """Remove the latest segmentation point and recompute the preview mask."""
    if len(sel_pix) != 0:
        temp = orig_img.copy()
        sel_pix.pop()
        # Recompute and show the segmentation mask online
        if len(sel_pix) != 0:
            temp, output_mask = segmentation(temp, sel_pix, MOBILE_PREDICTOR, SEGMENTATION_COLORS, SEGMENTATION_MARKERS)
            output_mask_pil = Image.fromarray(output_mask.astype("uint8"))
            masked_img_pil = Image.fromarray(np.where(output_mask > 0, orig_img, 0).astype("uint8"))
            mask_gallery = [masked_img_pil, output_mask_pil]
        else:
            output_mask = None
            mask_gallery = []
        return temp.astype(np.uint8), output_mask, mask_gallery
    else:
        gr.Warning("Nothing to undo")
        return orig_img, None, []


def segmentation(img, sel_pix, mobile_predictor, colors, markers):
    """Run SAM-based segmentation given selected points and return previews."""
    points = []
    labels = []
    for p, l in sel_pix:
        points.append(p)
        labels.append(l)
    mobile_predictor.set_image(img if isinstance(img, np.ndarray) else np.array(img))
    with torch.no_grad():
        masks, _, _ = mobile_predictor.predict(
            point_coords=np.array(points),
            point_labels=np.array(labels),
            multimask_output=False
        )
    # Mask convention: 0 = object, 255 = background
    output_mask = np.ones((masks.shape[1], masks.shape[2], 3)) * 255
    for i in range(3):
        output_mask[masks[0], i] = 0.0
    # Blend a random color over the segmented region for the preview image
    mask_all = np.ones((masks.shape[1], masks.shape[2], 3))
    color_mask = np.random.random((1, 3)).tolist()[0]
    for i in range(3):
        mask_all[masks[0], i] = color_mask[i]
    masked_img = img / 255 * DEFAULT_MASK_ALPHA + mask_all * DEFAULT_COLOR_ALPHA
    masked_img = masked_img * 255
    # Draw the clicked points on the preview
    for point, label in sel_pix:
        cv2.drawMarker(
            masked_img, point, colors[label],
            markerType=markers[label],
            markerSize=DEFAULT_MARKER_SIZE,
            thickness=DEFAULT_MARKER_THICKNESS
        )
    return masked_img, output_mask
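
# Convention check (illustrative only; `img_4x4` is a hypothetical 4x4x3 array):
# `segmentation` returns a 3-channel mask with 0 on the object and 255 on the
# background, so multiplying by the normalized mask blanks out the object:
#
#   mask = np.full((4, 4, 3), 255); mask[1:3, 1:3] = 0  # toy mask, object in center
#   kept = img_4x4 * (mask / 255)                       # center pixels become 0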


def get_point(img, sel_pix, evt: gr.SelectData):
    """Handle a user click on the target image to add a foreground point."""
    if evt is None or not hasattr(evt, 'index'):
        gr.Warning(f"Event object missing index attribute. Event type: {type(evt)}")
        return img, None, []
    sel_pix.append((evt.index, 1))  # append the foreground point
    # Recompute and show the segmentation mask online
    global MOBILE_PREDICTOR
    masked_img_seg, output_mask = segmentation(img, sel_pix, MOBILE_PREDICTOR, SEGMENTATION_COLORS, SEGMENTATION_MARKERS)
    # Dilate the object region. The mask is 0 on the object and 255 on the
    # background, so invert, dilate (which grows the object), and invert back.
    output_mask = 1 - output_mask
    kernel = np.ones((DEFAULT_DILATION_KERNEL_SIZE, DEFAULT_DILATION_KERNEL_SIZE), np.uint8)
    output_mask = cv2.dilate(output_mask, kernel, iterations=1)
    output_mask = 1 - output_mask
    output_mask_binary = output_mask / 255
    masked_img_seg = masked_img_seg.astype("uint8")
    output_mask = output_mask.astype("uint8")
    masked_img = img * output_mask_binary
    masked_img_pil = Image.fromarray(masked_img.astype("uint8"))
    output_mask_pil = Image.fromarray(output_mask.astype("uint8"))
    outputs_gallery = [masked_img_pil, output_mask_pil]
    return masked_img_seg, output_mask, outputs_gallery
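
# Equivalence note (illustrative): on a {0, 255} mask, invert -> dilate -> invert
# grows the 0-valued object region, which is exactly erosion of the background.
# The `1 - mask` variant above works too, since dilation (a max filter) only
# depends on the ordering of values, not on them being exactly {0, 255}:
#
#   m = np.full((5, 5), 255, np.uint8); m[2, 2] = 0  # single object pixel
#   k = np.ones((3, 3), np.uint8)
#   grown = 255 - cv2.dilate(255 - m, k)             # object grows to 3x3
#   assert (grown == cv2.erode(m, k)).all()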


def get_brush(img):
    """Extract a mask from ImageEditor brush layers or the composite/background diff."""
    if img is None or not isinstance(img, dict):
        return gr.skip(), gr.skip()
    layers = img.get("layers", [])
    background = img.get('background', None)
    composite = img.get('composite', None)
    output_mask = None
    if layers and layers[0] is not None and background is not None:
        # Brush strokes are non-zero in the layer; invert so painted pixels go to 0
        output_mask = 255 - np.array(layers[0].convert("RGB")).astype(np.uint8)
    elif composite is not None and background is not None:
        # No usable layer: recover painted pixels from the composite/background diff
        comp_rgb = np.array(composite.convert("RGB")).astype(np.int16)
        bg_rgb = np.array(background.convert("RGB")).astype(np.int16)
        diff = np.abs(comp_rgb - bg_rgb)
        painted = (diff.sum(axis=2) > 0).astype(np.uint8)
        output_mask = (1 - painted) * 255
        output_mask = np.repeat(output_mask[:, :, None], 3, axis=2).astype(np.uint8)
    else:
        return gr.skip(), gr.skip()
    if len(np.unique(output_mask)) == 1:
        # A uniform mask means nothing was painted
        return gr.skip(), gr.skip()
    img = np.array(background.convert("RGB")).astype(np.uint8)
    output_mask_binary = output_mask / 255
    masked_img = img * output_mask_binary
    masked_img_pil = Image.fromarray(masked_img.astype("uint8"))
    output_mask_pil = Image.fromarray(output_mask.astype("uint8"))
    mask_gallery = [masked_img_pil, output_mask_pil]
    return output_mask, mask_gallery
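
# For reference, a gr.ImageEditor value is a dict of PIL images with keys
# 'background' (the uploaded image), 'layers' (one image per brush layer), and
# 'composite' (background with all layers applied), which is what the branches
# above consume. Usage sketch (hypothetical PIL images):
#
#   editor_value = {"background": bg_pil, "layers": [layer_pil], "composite": comp_pil}
#   mask, gallery = get_brush(editor_value)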


def random_mask_func(mask, dilation_type='square_dilation', dilation_size=20):
    """Utility to dilate/erode a binary mask or expand it to a bounding box/ellipse."""
    binary_mask = mask[:, :, 0] < 128  # object pixels (mask is 0 on the object)
    if dilation_type == 'square_dilation':
        structure = np.ones((dilation_size, dilation_size), dtype=bool)
        dilated_mask = binary_dilation(binary_mask, structure=structure)
    elif dilation_type == 'square_erosion':
        structure = np.ones((dilation_size, dilation_size), dtype=bool)
        dilated_mask = binary_erosion(binary_mask, structure=structure)
    elif dilation_type == 'bounding_box':
        # Find the bounding extents of the object pixels
        rows, cols = np.where(binary_mask)
        if len(rows) == 0 or len(cols) == 0:
            return mask  # return the original mask if there are no object pixels
        min_row, max_row = np.min(rows), np.max(rows)
        min_col, max_col = np.min(cols), np.max(cols)
        # Create a bounding box
        dilated_mask = np.zeros_like(binary_mask, dtype=bool)
        dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
    elif dilation_type == 'bounding_ellipse':
        # Find the bounding extents of the object pixels
        rows, cols = np.where(binary_mask)
        if len(rows) == 0 or len(cols) == 0:
            return mask  # return the original mask if there are no object pixels
        min_row, max_row = np.min(rows), np.max(rows)
        min_col, max_col = np.min(cols), np.max(cols)
        # Calculate the center and semi-axes of the ellipse
        center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
        a = max((max_col - min_col) // 2, 1)  # horizontal semi-axis (guard against zero)
        b = max((max_row - min_row) // 2, 1)  # vertical semi-axis (guard against zero)
        # Create a bounding ellipse
        y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
        ellipse_mask = ((x - center[0]) ** 2 / a ** 2 + (y - center[1]) ** 2 / b ** 2) <= 1
        dilated_mask = np.zeros_like(binary_mask, dtype=bool)
        dilated_mask[ellipse_mask] = True
    else:
        raise ValueError("dilation_type must be 'square_dilation', 'square_erosion', 'bounding_box', or 'bounding_ellipse'")
    # Invert back to the 0 = object / 255 = background convention
    dilated_mask = 1 - dilated_mask
    dilated_mask = np.uint8(dilated_mask[:, :, np.newaxis]) * 255
    dilated_mask = np.concatenate([dilated_mask, dilated_mask, dilated_mask], axis=2)
    return dilated_mask
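
# Toy example (illustrative): a 1-pixel object at the center of a 9x9 mask.
#
#   m = np.full((9, 9, 3), 255, np.uint8); m[4, 4] = 0
#   random_mask_func(m, 'square_dilation', 3)  # object grows to a 3x3 square
#   random_mask_func(m, 'bounding_box', 3)     # 1x1 bounding box (unchanged here)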


def dilate_mask(mask, image):
    """Dilate the target mask for robustness and preview the result."""
    if mask is None:
        gr.Warning("Please input the target mask first")
        return None, None
    mask = random_mask_func(mask, dilation_type='square_dilation', dilation_size=DEFAULT_DILATION_KERNEL_SIZE)
    masked_img = image * (mask > 0)
    return mask, [masked_img, mask]


def erode_mask(mask, image):
    """Erode the target mask and preview the result."""
    if mask is None:
        gr.Warning("Please input the target mask first")
        return None, None
    mask = random_mask_func(mask, dilation_type='square_erosion', dilation_size=DEFAULT_DILATION_KERNEL_SIZE)
    masked_img = image * (mask > 0)
    return mask, [masked_img, mask]


def bounding_box(mask, image):
    """Create a bounding box mask and preview the result."""
    if mask is None:
        gr.Warning("Please input the target mask first")
        return None, None
    mask = random_mask_func(mask, dilation_type='bounding_box', dilation_size=DEFAULT_DILATION_KERNEL_SIZE)
    masked_img = image * (mask > 0)
    return mask, [masked_img, mask]
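
# Wiring sketch (hypothetical component names; the real app layer may differ):
#
#   dilate_btn.click(dilate_mask, inputs=[mask_target_state, image_target_state],
#                    outputs=[mask_target_state, mask_gallery])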


def change_input_mask_mode(input_mask_mode, custmization_mode):
    """Change visibility of input mask mode components."""
    if custmization_mode == "Position-free":
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        )
    elif input_mask_mode.lower() == "precise mask":
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
        )
    elif input_mask_mode.lower() == "user-drawn mask":
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
        )
    else:
        gr.Warning("Invalid input mask mode")
        return gr.skip(), gr.skip(), gr.skip()


def change_custmization_mode(custmization_mode, input_mask_mode):
    """Change visibility and interactivity based on customization mode."""
    if custmization_mode.lower() == "position-free":
        return (gr.update(interactive=False, visible=False),
                gr.update(interactive=False, visible=False),
                gr.update(interactive=False, visible=False),
                gr.update(interactive=False, visible=False),
                gr.update(interactive=False, visible=False),
                gr.update(interactive=False, visible=False),
                gr.update(value="<s>Select an input mask mode</s>", visible=False),
                gr.update(value="<s>Input target image & mask (iterate clicking or brushing until the target is covered)</s>", visible=False),
                gr.update(value="<s>View or modify the target mask</s>", visible=False),
                gr.update(value="3\. Input text prompt (necessary)"),
                gr.update(value="4\. Submit and view the output"),
                gr.update(visible=False),
                gr.update(visible=False),
                )
    else:
        if input_mask_mode.lower() == "precise mask":
            return (gr.update(interactive=True, visible=True),
                    gr.update(interactive=True, visible=False),
                    gr.update(interactive=True, visible=True),
                    gr.update(interactive=True, visible=True),
                    gr.update(interactive=True, visible=True),
                    gr.update(interactive=True, visible=True),
                    gr.update(value="3\. Select an input mask mode", visible=True),
                    gr.update(value="4\. Input target image & mask (iterate clicking or brushing until the target is covered)", visible=True),
                    gr.update(value="6\. View or modify the target mask", visible=True),
                    gr.update(value="5\. Input text prompt (optional)", visible=True),
                    gr.update(value="7\. Submit and view the output", visible=True),
                    gr.update(visible=True, value="Precise mask"),
                    gr.update(visible=True),
                    )
        elif input_mask_mode.lower() == "user-drawn mask":
            return (gr.update(interactive=True, visible=False),
                    gr.update(interactive=True, visible=True),
                    gr.update(interactive=False, visible=False),
                    gr.update(interactive=True, visible=True),
                    gr.update(interactive=True, visible=True),
                    gr.update(interactive=True, visible=True),
                    gr.update(value="3\. Select an input mask mode", visible=True),
                    gr.update(value="4\. Input target image & mask (iterate clicking or brushing until the target is covered)", visible=True),
                    gr.update(value="6\. View or modify the target mask", visible=True),
                    gr.update(value="5\. Input text prompt (optional)", visible=True),
                    gr.update(value="7\. Submit and view the output", visible=True),
                    gr.update(visible=True, value="User-drawn mask"),
                    gr.update(visible=True),
                    )
        else:
            gr.Warning("Invalid input mask mode")
            return tuple(gr.skip() for _ in range(13))


def change_seg_ref_mode(seg_ref_mode, image_reference_state, move_to_center):
    """Change segmentation reference mode and handle background removal."""
    if image_reference_state is None:
        gr.Warning("Please upload the reference image first")
        return None, None
    global BEN2_MODEL
    if seg_ref_mode == "Full Ref":
        return image_reference_state, None
    else:
        if BEN2_MODEL is None:
            gr.Warning("Please enable BEN2 for mask reference first")
            return gr.skip(), gr.skip()
        image_reference_pil = Image.fromarray(image_reference_state)
        image_reference_pil_rmbg = BEN2_MODEL.inference(image_reference_pil, move_to_center=move_to_center)
        image_reference_rmbg = np.array(image_reference_pil_rmbg)
        return image_reference_rmbg, image_reference_rmbg


def vlm_auto_generate(image_target_state, image_reference_state, mask_target_state,
                      custmization_mode):
    """Auto-generate a prompt using the VLM."""
    global VLM_PROCESSOR, VLM_MODEL
    if custmization_mode == "Position-aware":
        if image_target_state is None or mask_target_state is None:
            gr.Warning("Please upload the target image and get the mask first")
            return None
    if image_reference_state is None:
        gr.Warning("Please upload the reference image first")
        return None
    if VLM_PROCESSOR is None or VLM_MODEL is None:
        gr.Warning("Please enable the VLM for prompts first")
        return None
    messages = construct_vlm_gen_prompt(image_target_state, image_reference_state, mask_target_state, custmization_mode)
    output_text = run_vlm(VLM_PROCESSOR, VLM_MODEL, messages, device=device)
    return output_text


def vlm_auto_polish(prompt, custmization_mode):
    """Auto-polish a prompt using the VLM."""
    global VLM_PROCESSOR, VLM_MODEL
    if prompt is None:
        gr.Warning("Please input the text prompt first")
        return None
    if custmization_mode == "Position-aware":
        gr.Warning("Polishing only works in Position-free mode")
        return prompt
    if VLM_PROCESSOR is None or VLM_MODEL is None:
        gr.Warning("Please enable the VLM for prompts first")
        return prompt
    messages = construct_vlm_polish_prompt(prompt)
    output_text = run_vlm(VLM_PROCESSOR, VLM_MODEL, messages, device=device)
    return output_text
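
# Wiring sketch (hypothetical component names; the real app layer may differ):
#
#   vlm_gen_btn.click(vlm_auto_generate,
#                     inputs=[image_target_state, image_reference_ori_state,
#                             mask_target_state, custmization_mode],
#                     outputs=[prompt])
#   vlm_polish_btn.click(vlm_auto_polish, inputs=[prompt, custmization_mode],
#                        outputs=[prompt])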


def save_results(output_img, image_reference, image_target, mask_target, prompt,
                 custmization_mode, input_mask_mode, seg_ref_mode, seed, guidance,
                 num_steps, num_images_per_prompt, use_background_preservation,
                 background_blend_threshold, true_gs, assets_cache_dir):
    """Save generated results and their metadata to the assets cache."""
    save_name = datetime.now().strftime(TIMESTAMP_FORMAT)
    results = []
    for i in range(num_images_per_prompt):
        save_dir = os.path.join(assets_cache_dir, save_name)
        os.makedirs(save_dir, exist_ok=True)
        output_img[i].save(os.path.join(save_dir, f"img_gen_{i}.png"))
        image_reference.save(os.path.join(save_dir, f"img_ref_{i}.png"))
        image_target.save(os.path.join(save_dir, f"img_target_{i}.png"))
        mask_target.save(os.path.join(save_dir, f"mask_target_{i}.png"))
        with open(os.path.join(save_dir, f"hyper_params_{i}.json"), "w") as f:
            json.dump({
                "prompt": prompt,
                "custmization_mode": custmization_mode,
                "input_mask_mode": input_mask_mode,
                "seg_ref_mode": seg_ref_mode,
                "seed": seed,
                "guidance": guidance,
                "num_steps": num_steps,
                "num_images_per_prompt": num_images_per_prompt,
                "use_background_preservation": use_background_preservation,
                "background_blend_threshold": background_blend_threshold,
                "true_gs": true_gs,
            }, f)
        results.append(output_img[i])
    return results
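
# Resulting layout, assuming TIMESTAMP_FORMAT is something like "%Y%m%d_%H%M%S"
# (the actual format string lives in constants.py):
#
#   <assets_cache_dir>/<timestamp>/img_gen_0.png
#   <assets_cache_dir>/<timestamp>/img_ref_0.png
#   <assets_cache_dir>/<timestamp>/img_target_0.png
#   <assets_cache_dir>/<timestamp>/mask_target_0.png
#   <assets_cache_dir>/<timestamp>/hyper_params_0.json
#   ... (one set per generated image)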