Spaces:
Runtime error
Runtime error
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| IC-Custom Gradio Application | |
| This module defines the UI and glue logic to run the IC-Custom pipeline | |
| via Gradio. The code aims to keep UI text user-friendly while keeping the | |
| implementation readable and maintainable. | |
| """ | |
| import os | |
| import sys | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| import spaces | |
| from PIL import Image | |
| import time | |
| # Add current directory to path for imports | |
| sys.path.append(os.getcwd() + '/app') | |
| # Import modular components | |
| from config import parse_args, load_config, setup_environment | |
| from ui_components import ( | |
| create_theme, create_css, create_header_section, create_customization_section, | |
| create_image_input_section, create_prompt_section, create_advanced_options_section, | |
| create_mask_operation_section, create_output_section, create_examples_section, | |
| create_citation_section | |
| ) | |
| from event_handlers import setup_event_handlers | |
| from business_logic import ( | |
| init_image_target_1, init_image_target_2, init_image_reference, | |
| undo_seg_points, segmentation, get_point, get_brush, | |
| dilate_mask, erode_mask, bounding_box, | |
| change_input_mask_mode, change_custmization_mode, change_seg_ref_mode, | |
| vlm_auto_generate, vlm_auto_polish, save_results, set_mobile_predictor, | |
| set_ben2_model, set_vlm_processor, set_vlm_model, | |
| ) | |
| # Import other dependencies | |
| from utils import ( | |
| get_sam_predictor, get_vlm, get_ben2_model, | |
| prepare_input_images, get_mask_type_ids | |
| ) | |
| from examples import GRADIO_EXAMPLES, MASK_TGT, IMG_GEN | |
| from ic_custom.pipelines.ic_custom_pipeline import ICCustomPipeline | |
# Module-level slots filled in at startup through set_pipeline() and
# set_assets_cache_dir(); both stay None until those setters are called.
PIPELINE = None
ASSETS_CACHE_DIR = None

# Process-wide environment configuration for Hugging Face and Gradio.
# NOTE(review): these assignments execute after torch/gradio have already been
# imported above; a library that reads its environment at import time would
# not see them -- confirm they take effect where intended.
# NOTE(review): confirm HF_HUB_FORCE_DOWNLOAD is honored by the installed
# huggingface_hub version.
os.environ.update({
    "HF_HUB_FORCE_DOWNLOAD": "1",
    "HF_HUB_DISABLE_TELEMETRY": "1",
    # Temp directories are used for the model caches on Spaces.
    "TRANSFORMERS_CACHE": "/tmp/hf_cache",
    "HF_HOME": "/tmp/hf_home",
    # Gradio temp files live in a "gradio_cache" folder next to this script.
    "GRADIO_TEMP_DIR": os.path.abspath(
        os.path.join(os.path.dirname(__file__), "gradio_cache")
    ),
})
def set_pipeline(pipeline):
    """Store ``pipeline`` in the module-level ``PIPELINE`` slot.

    Functions in this module that need the pipeline read the global rather
    than taking it as a parameter, so their signatures stay unchanged.

    Args:
        pipeline: Object to expose as ``PIPELINE``.
    """
    global PIPELINE
    PIPELINE = pipeline
def set_assets_cache_dir(assets_cache_dir):
    """Store ``assets_cache_dir`` in the module-level ``ASSETS_CACHE_DIR`` slot.

    The value is kept as a global so other functions can read it without
    changing their signatures.

    Args:
        assets_cache_dir: Value to expose as ``ASSETS_CACHE_DIR``.
    """
    global ASSETS_CACHE_DIR
    ASSETS_CACHE_DIR = assets_cache_dir
def _resolve_checkpoint(path, fallback):
    """Return ``path`` if it names an existing filesystem entry, else ``fallback``.

    An unset path (``None`` or empty) selects the fallback instead of raising
    inside ``os.path.exists``.
    """
    if path and os.path.exists(path):
        return path
    return fallback


def initialize_models(args, cfg, device, weight_dtype):
    """Build the IC-Custom pipeline and the auxiliary models used by the UI.

    Args:
        args: Parsed CLI arguments; ``enable_vlm_for_prompt`` and
            ``enable_ben2_for_mask_ref`` decide whether the optional models
            are loaded.
        cfg: Loaded configuration exposing ``checkpoint_config`` (checkpoint
            paths) and ``model_config`` (``network_alpha``, ``double_blocks``,
            ``single_blocks``).
        device: Device handed to every model loader.
        weight_dtype: Dtype handed to the pipeline and the VLM loader.

    Returns:
        Tuple ``(pipeline, mobile_predictor, vlm_processor, vlm_model,
        ben2_model)``. The VLM entries and ``ben2_model`` are ``None`` when the
        corresponding feature flag is off.
    """
    ckpt = cfg.checkpoint_config
    model_cfg = cfg.model_config

    # Each checkpoint prefers the configured local path; when that path does
    # not exist the short identifier is passed instead.
    # NOTE(review): presumably ICCustomPipeline resolves those identifiers
    # itself (e.g. by downloading) -- confirm.
    pipeline = ICCustomPipeline(
        clip_path=_resolve_checkpoint(ckpt.clip_path, "clip-vit-large-patch14"),
        t5_path=_resolve_checkpoint(ckpt.t5_path, "t5-v1_1-xxl"),
        siglip_path=_resolve_checkpoint(ckpt.siglip_path, "siglip-so400m-patch14-384"),
        ae_path=_resolve_checkpoint(ckpt.ae_path, "flux-fill-dev-ae"),
        dit_path=_resolve_checkpoint(ckpt.dit_path, "flux-fill-dev-dit"),
        redux_path=_resolve_checkpoint(ckpt.redux_path, "flux1-redux-dev"),
        lora_path=_resolve_checkpoint(ckpt.lora_path, "dit_lora_0x1561"),
        img_txt_in_path=_resolve_checkpoint(
            ckpt.img_txt_in_path, "dit_txt_img_in_0x1561"
        ),
        boundary_embeddings_path=_resolve_checkpoint(
            ckpt.boundary_embeddings_path, "dit_boundary_embeddings_0x1561"
        ),
        task_register_embeddings_path=_resolve_checkpoint(
            ckpt.task_register_embeddings_path,
            "dit_task_register_embeddings_0x1561",
        ),
        network_alpha=model_cfg.network_alpha,
        double_blocks_idx=model_cfg.double_blocks,
        single_blocks_idx=model_cfg.single_blocks,
        device=device,
        weight_dtype=weight_dtype,
        offload=True,
    )
    pipeline.set_pipeline_offload(True)
    # pipeline.set_show_progress(True)

    # SAM predictor is always loaded.
    mobile_predictor = get_sam_predictor(ckpt.sam_path, device)

    # VLM is optional.
    vlm_processor, vlm_model = None, None
    if args.enable_vlm_for_prompt:
        vlm_processor, vlm_model = get_vlm(
            ckpt.vlm_path,
            device=device,
            torch_dtype=weight_dtype,
        )

    # BEN2 model is optional.
    ben2_model = None
    if args.enable_ben2_for_mask_ref:
        ben2_model = get_ben2_model(ckpt.ben2_path, device)

    return pipeline, mobile_predictor, vlm_processor, vlm_model, ben2_model
def _abort_generation(message, seed, prompt):
    """Show ``message`` as a warning and build run_model's "nothing generated" result.

    Every validation failure goes through here so the prompt placeholder has
    the same "Last Input: ..." format in all of them.
    """
    gr.Warning(message)
    return None, seed, gr.update(placeholder="Last Input: " + prompt, value="")


def run_model(
    image_target_state, mask_target_state, image_reference_ori_state,
    image_reference_rmbg_state, prompt, seed, guidance, true_gs, num_steps,
    num_images_per_prompt, use_background_preservation, background_blend_threshold,
    aspect_ratio, custmization_mode, seg_ref_mode, input_mask_mode,
    progress=gr.Progress()
):
    """Run the IC-Custom pipeline with the current UI state.

    Args:
        image_target_state: Target image, or ``None`` if not provided.
        mask_target_state: Target mask, or ``None`` if not provided.
        image_reference_ori_state: Reference image used unless ``seg_ref_mode``
            is ``"Masked Ref"``.
        image_reference_rmbg_state: Reference image used when ``seg_ref_mode``
            is ``"Masked Ref"``.
        prompt: Text prompt; ``None`` is treated as an empty string.
        seed: Seed for generation; ``-1`` requests a random seed.
        guidance, true_gs, num_steps, num_images_per_prompt,
        use_background_preservation, background_blend_threshold: Forwarded
            unchanged to the pipeline call.
        aspect_ratio: Key into ``ASPECT_RATIO_TEMPLATE``; a key containing
            ``"long edge"`` turns on ``force_resize_long_edge``.
        custmization_mode: ``"Position-aware"`` or ``"Position-free"``.
        seg_ref_mode: Chooses which reference state is used.
        input_mask_mode: Forwarded to ``get_mask_type_ids``.
        progress: Gradio progress tracker.

    Returns:
        A triple ``(images, seed_value, prompt_update)``. When an input is
        missing, ``images`` is ``None`` and the incoming seed is handed back;
        after a successful run the seed value is ``-1`` and the prompt box is
        emptied with the last prompt shown as its placeholder.

    Raises:
        gr.Error: If no pipeline has been registered via ``set_pipeline``.
    """
    start_ts = time.time()
    progress(0, desc="Starting generation...")

    # A textbox can yield None instead of "" (e.g. after being reset); the
    # placeholder strings below are built by concatenation, so normalise once.
    if prompt is None:
        prompt = ""

    # Select reference image and check inputs
    if seg_ref_mode == "Masked Ref":
        image_reference_state = image_reference_rmbg_state
    else:
        image_reference_state = image_reference_ori_state

    if image_reference_state is None:
        return _abort_generation('Please upload the reference image', seed, prompt)
    if image_target_state is None and custmization_mode != "Position-free":
        return _abort_generation(
            'Please upload the target image and mask it', seed, prompt
        )
    if custmization_mode == "Position-aware" and mask_target_state is None:
        return _abort_generation('Please select/draw the target mask', seed, prompt)

    if PIPELINE is None:
        # Fail with a readable message rather than "'NoneType' is not callable".
        raise gr.Error("The generation pipeline has not been initialized.")

    mask_type_ids = get_mask_type_ids(custmization_mode, input_mask_mode)

    # Imported here, as in the original code, rather than at module import time.
    from constants import ASPECT_RATIO_TEMPLATE
    output_w, output_h = ASPECT_RATIO_TEMPLATE[aspect_ratio]

    image_reference, image_target, mask_target = prepare_input_images(
        image_reference_state, custmization_mode, image_target_state, mask_target_state,
        width=output_w, height=output_h,
        force_resize_long_edge="long edge" in aspect_ratio,
        return_type="pil"
    )
    gr.Info(f"Output WH resolution: {image_target.size[0]}px x {image_target.size[1]}px")

    # Run the model
    if seed == -1:
        seed = torch.randint(0, 2147483647, (1,)).item()

    # The pipeline canvas is as wide as reference + target and as tall as the
    # target; the reference width is also passed as the conditioning region.
    width = image_target.size[0] + image_reference.size[0]
    height = image_target.size[1]

    with torch.no_grad():
        output_img = PIPELINE(
            prompt=prompt, width=width, height=height, guidance=guidance,
            num_steps=num_steps, seed=seed, img_ref=image_reference,
            img_target=image_target, mask_target=mask_target, img_ip=image_reference,
            cond_w_regions=[image_reference.size[0]], mask_type_ids=mask_type_ids,
            use_background_preservation=use_background_preservation,
            background_blend_threshold=background_blend_threshold, true_gs=true_gs,
            neg_prompt="worst quality, normal quality, low quality, low res, blurry,",
            num_images_per_prompt=num_images_per_prompt,
            gradio_progress=progress,
        )

    elapsed = time.time() - start_ts
    progress(1.0, desc=f"Completed in {elapsed:.2f}s!")
    gr.Info(f"Finished in {elapsed:.2f}s")
    return output_img, -1, gr.update(placeholder=f"Last Input ({elapsed:.2f}s): " + prompt, value="")
def _read_image_array(path):
    """Load the image file at ``path`` into a NumPy array and close the file."""
    with Image.open(path) as img:
        return np.array(img)


def example_pipeline(
    image_reference, image_target_1, image_target_2, custmization_mode,
    input_mask_mode, seg_ref_mode, prompt, seed, true_gs, eg_idx,
    num_steps, guidance
):
    """Populate UI state and galleries when an example row is selected.

    Args:
        image_reference: Reference image (object with ``.convert``).
        image_target_1: Target image used for precise-mask and position-free
            examples.
        image_target_2: Editor value; its ``'composite'`` entry is the target
            for position-aware examples that do not use a precise mask.
        custmization_mode: ``"Position-aware"`` or ``"Position-free"``.
        input_mask_mode: ``"Precise mask"`` or another mask mode.
        seg_ref_mode: ``"Full Ref"`` stores the reference as the original
            state; any other value stores it as the background-removed state.
        prompt, seed, true_gs, num_steps, guidance: Accepted because they are
            part of the example inputs; not used here.
        eg_idx: Index (string or int) into ``MASK_TGT`` and ``IMG_GEN``.

    Returns:
        An 8-tuple: original reference state, background-removed reference
        state, target image state, target mask state, mask gallery value,
        result gallery value, and two ``gr.update`` objects for the two target
        image inputs.
    """
    example_index = int(eg_idx)
    position_aware = custmization_mode == "Position-aware"
    precise_mask = input_mask_mode == "Precise mask"

    # The reference goes into exactly one of the two reference states.
    reference_rgb = np.array(image_reference.convert("RGB"))
    if seg_ref_mode == "Full Ref":
        image_reference_ori_state, image_reference_rmbg_state = reference_rgb, None
    else:
        image_reference_ori_state, image_reference_rmbg_state = None, reference_rgb

    # Only position-aware examples without a precise mask take the target from
    # the editor composite; everything else uses image_target_1.
    if position_aware and not precise_mask:
        target_source = image_target_2['composite']
    else:
        target_source = image_target_1
    image_target_state = np.array(target_source.convert("RGB"))
    mask_target_state = _read_image_array(MASK_TGT[example_index])

    if position_aware:
        mask_weights = mask_target_state / 255
        if mask_weights.ndim == 2:
            # Single-channel mask: add a channel axis so it broadcasts over RGB.
            mask_weights = mask_weights[..., None]
        masked_img = image_target_state * mask_weights
        mask_gallery = [
            Image.fromarray(masked_img.astype("uint8")),
            Image.fromarray(mask_target_state.astype("uint8")),
        ]
    else:
        # No preview is produced outside position-aware mode.
        mask_gallery = gr.skip()

    # convert() yields an image independent of the file, so the handle can be
    # closed as soon as the conversion is done.
    with Image.open(IMG_GEN[example_index]) as generated:
        result_gallery = [generated.convert("RGB")]

    shared_outputs = (
        image_reference_ori_state, image_reference_rmbg_state, image_target_state,
        mask_target_state, mask_gallery, result_gallery,
    )

    if custmization_mode == "Position-free":
        return (*shared_outputs, gr.update(visible=False), gr.update(visible=False))
    if precise_mask:
        return (*shared_outputs, gr.update(visible=True), gr.update(visible=False))

    # Ensure ImageEditor has a proper background so brush + undo work
    if hasattr(image_target_2, "get"):
        bg_img = image_target_2.get('background')
        if bg_img is None:
            bg_img = image_target_2.get('composite')
    else:
        # Not a mapping: use the value itself as the background.
        bg_img = image_target_2
    return (
        *shared_outputs,
        gr.update(visible=False),
        gr.update(visible=True, value={"background": bg_img, "layers": [], "composite": bg_img}),
    )
def create_application():
    """Create the main Gradio application.

    Builds the ``gr.Blocks`` layout from the section factories in
    ``ui_components``, wires the examples to ``example_pipeline``, passes all
    components and callbacks to ``setup_event_handlers`` and registers the
    components handled by the clear button.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    # NOTE(review): the container nesting below was reconstructed from a
    # flattened copy of this file -- confirm that the examples/citation rows
    # and the handler wiring sit at the intended level.
    # Create theme and CSS
    theme = create_theme()
    css = create_css()
    with gr.Blocks(theme=theme, css=css) as demo:
        with gr.Column(elem_id="global_glass_container"):
            # Create UI sections
            create_header_section()
            # Hidden components
            # Hidden textbox holding the selected example index; it is one of
            # the example inputs consumed by example_pipeline.
            eg_idx = gr.Textbox(label="eg_idx", visible=False, value="-1")
            # State variables
            image_target_state = gr.State(value=None)
            mask_target_state = gr.State(value=None)
            image_reference_ori_state = gr.State(value=None)
            image_reference_rmbg_state = gr.State(value=None)
            selected_points = gr.State(value=[])
            # Main UI content with optimized left-right layout
            with gr.Column(elem_id="glass_card"):
                # Top section - Mode selection (full width)
                custmization_mode, md_custmization_mode = create_customization_section()
                # Main layout: Left for inputs, Right for outputs
                with gr.Row(equal_height=False):
                    # LEFT COLUMN - ALL INPUTS
                    with gr.Column(scale=3, min_width=400):
                        # Image input section
                        (image_reference, input_mask_mode, image_target_1, image_target_2,
                         undo_target_seg_button, md_image_reference, md_input_mask_mode,
                         md_target_image) = create_image_input_section()
                        # Text prompt section
                        prompt, vlm_generate_btn, vlm_polish_btn, md_prompt = create_prompt_section()
                        # Advanced options (collapsible)
                        (aspect_ratio, seg_ref_mode, move_to_center, use_background_preservation,
                         background_blend_threshold, seed, num_images_per_prompt, guidance,
                         num_steps, true_gs) = create_advanced_options_section()
                    # RIGHT COLUMN - ALL OUTPUTS
                    with gr.Column(scale=2, min_width=350):
                        # Mask preview and operations
                        (mask_gallery, dilate_button, erode_button, bounding_box_button,
                         md_mask_operation) = create_mask_operation_section()
                        # Generation controls and results
                        result_gallery, submit_button, clear_btn, md_submit = create_output_section()
            with gr.Row(elem_id="glass_card"):
                # Examples section
                # The inputs list follows example_pipeline's parameter order and
                # the outputs list follows the order of its returned tuple.
                examples = create_examples_section(
                    GRADIO_EXAMPLES,
                    inputs=[
                        image_reference,
                        image_target_1,
                        image_target_2,
                        custmization_mode,
                        input_mask_mode,
                        seg_ref_mode,
                        prompt,
                        seed,
                        true_gs,
                        eg_idx,
                        num_steps,
                        guidance
                    ],
                    outputs=[
                        image_reference_ori_state,
                        image_reference_rmbg_state,
                        image_target_state,
                        mask_target_state,
                        mask_gallery,
                        result_gallery,
                        image_target_1,
                        image_target_2,
                    ],
                    fn=example_pipeline,
                )
            with gr.Row(elem_id="glass_card"):
                # Citation section
                create_citation_section()
            # Setup event handlers
            # All arguments are positional, so their order must match the
            # signature of setup_event_handlers (defined in event_handlers).
            setup_event_handlers(
                ## UI components
                input_mask_mode, image_target_1, image_target_2, undo_target_seg_button,
                custmization_mode, dilate_button, erode_button, bounding_box_button,
                mask_gallery, md_input_mask_mode, md_target_image, md_mask_operation,
                md_prompt, md_submit, result_gallery, image_target_state, mask_target_state,
                seg_ref_mode, image_reference_ori_state, move_to_center,
                image_reference, image_reference_rmbg_state,
                ## Functions
                change_input_mask_mode, change_custmization_mode,
                change_seg_ref_mode,
                init_image_target_1, init_image_target_2, init_image_reference,
                get_point, undo_seg_points,
                get_brush,
                # VLM buttons
                vlm_generate_btn, vlm_polish_btn,
                # VLM functions
                vlm_auto_generate,
                vlm_auto_polish,
                dilate_mask, erode_mask, bounding_box,
                run_model,
                ## Other components
                selected_points, prompt,
                use_background_preservation, background_blend_threshold, seed,
                num_images_per_prompt, guidance, true_gs, num_steps, aspect_ratio,
                submit_button,
                eg_idx,
            )
            # Setup clear button
            # Register the components and states handled by the clear button
            # (presumably a gr.ClearButton returned by create_output_section --
            # TODO confirm).
            clear_btn.add(
                [image_reference, image_target_1,image_target_2, mask_gallery, result_gallery,
                 selected_points, image_target_state, mask_target_state, prompt,
                 image_reference_ori_state, image_reference_rmbg_state]
            )
    return demo
def main():
    """Parse the configuration, load the models, then build and launch the UI.

    The loaded models are handed to this module and to ``business_logic``
    through their setter functions before the Gradio app is created. The app
    is served on ``0.0.0.0:7860`` and may serve files from the ``gradio_cache``
    and ``results`` directories located next to this script.
    """
    # Configuration
    args = parse_args()
    cfg = load_config(args.config)
    setup_environment(args)

    # Device and models
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    weight_dtype = torch.bfloat16
    loaded = initialize_models(args, cfg, device, weight_dtype)
    pipeline, mobile_predictor, vlm_processor, vlm_model, ben2_model = loaded

    # Hand each object to the module that uses it, so the event callbacks can
    # reach the models without lambdas or extra parameters.
    injections = (
        (set_pipeline, pipeline),
        (set_assets_cache_dir, args.assets_cache_dir),
        (set_mobile_predictor, mobile_predictor),
        (set_ben2_model, ben2_model),
        (set_vlm_processor, vlm_processor),
        (set_vlm_model, vlm_model),
    )
    for inject, value in injections:
        inject(value)

    # Build the UI
    demo = create_application()

    # Directories next to this script that Gradio is allowed to serve from.
    script_dir = os.path.dirname(__file__)
    served_dirs = [
        os.path.abspath(os.path.join(script_dir, folder))
        for folder in ("gradio_cache", "results")
    ]

    # Launch
    demo.launch(server_port=7860, server_name="0.0.0.0", allowed_paths=served_dirs)


if __name__ == "__main__":
    main()