Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	fix md and pipeline
Browse files
- app/business_logic.py +12 -12
- app/ui_components.py +7 -7
- ic_custom/pipelines/ic_custom_pipeline.py +14 -1
- ic_custom/utils/model_utils.py +26 -4
    	
        app/business_logic.py
    CHANGED
    
    | @@ -412,8 +412,8 @@ def change_custmization_mode(custmization_mode, input_mask_mode): | |
| 412 | 
             
                            gr.update(value="<s>Select a input mask mode</s>", visible=False),
         | 
| 413 | 
             
                            gr.update(value="<s>Input target image & mask (Iterate clicking or brushing until the target is covered)</s>", visible=False),
         | 
| 414 | 
             
                            gr.update(value="<s>View or modify the target mask</s>", visible=False),
         | 
| 415 | 
            -
                            gr.update(value="3 | 
| 416 | 
            -
                            gr.update(value="4 | 
| 417 | 
             
                            gr.update(visible=False),
         | 
| 418 | 
             
                            gr.update(visible=False),
         | 
| 419 |  | 
| @@ -426,11 +426,11 @@ def change_custmization_mode(custmization_mode, input_mask_mode): | |
| 426 | 
             
                                gr.update(interactive=True, visible=True),
         | 
| 427 | 
             
                                gr.update(interactive=True, visible=True),
         | 
| 428 | 
             
                                gr.update(interactive=True, visible=True),
         | 
| 429 | 
            -
                                gr.update(value="3 | 
| 430 | 
            -
                                gr.update(value="4 | 
| 431 | 
            -
                                gr.update(value="6 | 
| 432 | 
            -
                                gr.update(value="5 | 
| 433 | 
            -
                                gr.update(value="7 | 
| 434 | 
             
                                gr.update(visible=True, value="Precise mask"),
         | 
| 435 | 
             
                                gr.update(visible=True),
         | 
| 436 | 
             
                                )
         | 
| @@ -441,11 +441,11 @@ def change_custmization_mode(custmization_mode, input_mask_mode): | |
| 441 | 
             
                                gr.update(interactive=True, visible=True),
         | 
| 442 | 
             
                                gr.update(interactive=True, visible=True),
         | 
| 443 | 
             
                                gr.update(interactive=True, visible=True),
         | 
| 444 | 
            -
                                gr.update(value="3 | 
| 445 | 
            -
                                gr.update(value="4 | 
| 446 | 
            -
                                gr.update(value="6 | 
| 447 | 
            -
                                gr.update(value="5 | 
| 448 | 
            -
                                gr.update(value="7 | 
| 449 | 
             
                                gr.update(visible=True, value="User-drawn mask"),
         | 
| 450 | 
             
                                gr.update(visible=True),
         | 
| 451 | 
             
                                )
         | 
|  | |
| 412 | 
             
                            gr.update(value="<s>Select a input mask mode</s>", visible=False),
         | 
| 413 | 
             
                            gr.update(value="<s>Input target image & mask (Iterate clicking or brushing until the target is covered)</s>", visible=False),
         | 
| 414 | 
             
                            gr.update(value="<s>View or modify the target mask</s>", visible=False),
         | 
| 415 | 
            +
                            gr.update(value="3\. Input text prompt (necessary)"),
         | 
| 416 | 
            +
                            gr.update(value="4\. Submit and view the output"),
         | 
| 417 | 
             
                            gr.update(visible=False),
         | 
| 418 | 
             
                            gr.update(visible=False),
         | 
| 419 |  | 
|  | |
| 426 | 
             
                                gr.update(interactive=True, visible=True),
         | 
| 427 | 
             
                                gr.update(interactive=True, visible=True),
         | 
| 428 | 
             
                                gr.update(interactive=True, visible=True),
         | 
| 429 | 
            +
                                gr.update(value="3\. Select a input mask mode", visible=True),
         | 
| 430 | 
            +
                                gr.update(value="4\. Input target image & mask (Iterate clicking or brushing until the target is covered)", visible=True),
         | 
| 431 | 
            +
                                gr.update(value="6\. View or modify the target mask", visible=True),
         | 
| 432 | 
            +
                                gr.update(value="5\. Input text prompt (optional)", visible=True),
         | 
| 433 | 
            +
                                gr.update(value="7\. Submit and view the output", visible=True),
         | 
| 434 | 
             
                                gr.update(visible=True, value="Precise mask"),
         | 
| 435 | 
             
                                gr.update(visible=True),
         | 
| 436 | 
             
                                )
         | 
|  | |
| 441 | 
             
                                gr.update(interactive=True, visible=True),
         | 
| 442 | 
             
                                gr.update(interactive=True, visible=True),
         | 
| 443 | 
             
                                gr.update(interactive=True, visible=True),
         | 
| 444 | 
            +
                                gr.update(value="3\. Select a input mask mode", visible=True),
         | 
| 445 | 
            +
                                gr.update(value="4\. Input target image & mask (Iterate clicking or brushing until the target is covered)", visible=True),
         | 
| 446 | 
            +
                                gr.update(value="6\. View or modify the target mask", visible=True),
         | 
| 447 | 
            +
                                gr.update(value="5\. Input text prompt (optional)", visible=True),
         | 
| 448 | 
            +
                                gr.update(value="7\. Submit and view the output", visible=True),
         | 
| 449 | 
             
                                gr.update(visible=True, value="User-drawn mask"),
         | 
| 450 | 
             
                                gr.update(visible=True),
         | 
| 451 | 
             
                                )
         | 
    	
        app/ui_components.py
    CHANGED
    
    | @@ -44,7 +44,7 @@ def create_customization_section(): | |
| 44 | 
             
                with gr.Row():
         | 
| 45 | 
             
                    # Add a note to remind users to click Clear before starting
         | 
| 46 | 
             
                    md_custmization_mode = gr.Markdown(
         | 
| 47 | 
            -
                        "1 | 
| 48 | 
             
                    )
         | 
| 49 | 
             
                with gr.Row():
         | 
| 50 | 
             
                    custmization_mode = gr.Radio(
         | 
| @@ -61,7 +61,7 @@ def create_customization_section(): | |
| 61 | 
             
            def create_image_input_section():
         | 
| 62 | 
             
                """Create image input section optimized for left column layout."""
         | 
| 63 | 
             
                # Reference image section
         | 
| 64 | 
            -
                md_image_reference = gr.Markdown("2 | 
| 65 | 
             
                with gr.Group():
         | 
| 66 | 
             
                    image_reference = gr.Image(
         | 
| 67 | 
             
                        label="Reference Image", 
         | 
| @@ -73,7 +73,7 @@ def create_image_input_section(): | |
| 73 | 
             
                    )
         | 
| 74 |  | 
| 75 | 
             
                # Input mask mode selection
         | 
| 76 | 
            -
                md_input_mask_mode = gr.Markdown("3 | 
| 77 | 
             
                with gr.Group():
         | 
| 78 | 
             
                    input_mask_mode = gr.Radio(
         | 
| 79 | 
             
                        ["Precise mask", "User-drawn mask"],
         | 
| @@ -84,7 +84,7 @@ def create_image_input_section(): | |
| 84 | 
             
                    )
         | 
| 85 |  | 
| 86 | 
             
                # Target image section
         | 
| 87 | 
            -
                md_target_image = gr.Markdown("4 | 
| 88 |  | 
| 89 | 
             
                # Precise mask mode
         | 
| 90 | 
             
                with gr.Group():
         | 
| @@ -129,7 +129,7 @@ def create_image_input_section(): | |
| 129 |  | 
| 130 | 
             
            def create_prompt_section():
         | 
| 131 | 
             
                """Create the text prompt input section with improved layout."""
         | 
| 132 | 
            -
                md_prompt = gr.Markdown("5 | 
| 133 | 
             
                with gr.Group():
         | 
| 134 | 
             
                    prompt = gr.Textbox(
         | 
| 135 | 
             
                        placeholder="Please input the description for the target scene.", 
         | 
| @@ -243,7 +243,7 @@ def create_advanced_options_section(): | |
| 243 |  | 
| 244 | 
             
            def create_mask_operation_section():
         | 
| 245 | 
             
                """Create mask operation section optimized for right column (outputs)."""
         | 
| 246 | 
            -
                md_mask_operation = gr.Markdown("6 | 
| 247 |  | 
| 248 | 
             
                with gr.Group():
         | 
| 249 | 
             
                    # Mask gallery with responsive layout
         | 
| @@ -293,7 +293,7 @@ def create_mask_operation_section(): | |
| 293 |  | 
| 294 | 
             
            def create_output_section():
         | 
| 295 | 
             
                """Create the output section optimized for right column."""
         | 
| 296 | 
            -
                md_submit = gr.Markdown("7 | 
| 297 |  | 
| 298 | 
             
                # Generation controls at top for better workflow
         | 
| 299 | 
             
                with gr.Group():
         | 
|  | |
| 44 | 
             
                with gr.Row():
         | 
| 45 | 
             
                    # Add a note to remind users to click Clear before starting
         | 
| 46 | 
             
                    md_custmization_mode = gr.Markdown(
         | 
| 47 | 
            +
                        "1\. Select a Customization Mode\n\n*Tip: Please click the Clear button first to reset all states before starting a new task.*"
         | 
| 48 | 
             
                    )
         | 
| 49 | 
             
                with gr.Row():
         | 
| 50 | 
             
                    custmization_mode = gr.Radio(
         | 
|  | |
| 61 | 
             
            def create_image_input_section():
         | 
| 62 | 
             
                """Create image input section optimized for left column layout."""
         | 
| 63 | 
             
                # Reference image section
         | 
| 64 | 
            +
                md_image_reference = gr.Markdown("2\. Input reference image")
         | 
| 65 | 
             
                with gr.Group():
         | 
| 66 | 
             
                    image_reference = gr.Image(
         | 
| 67 | 
             
                        label="Reference Image", 
         | 
|  | |
| 73 | 
             
                    )
         | 
| 74 |  | 
| 75 | 
             
                # Input mask mode selection
         | 
| 76 | 
            +
                md_input_mask_mode = gr.Markdown("3\. Select input mask mode")
         | 
| 77 | 
             
                with gr.Group():
         | 
| 78 | 
             
                    input_mask_mode = gr.Radio(
         | 
| 79 | 
             
                        ["Precise mask", "User-drawn mask"],
         | 
|  | |
| 84 | 
             
                    )
         | 
| 85 |  | 
| 86 | 
             
                # Target image section
         | 
| 87 | 
            +
                md_target_image = gr.Markdown("4\. Input target image & mask (Iterate clicking or brushing until the target is covered)")
         | 
| 88 |  | 
| 89 | 
             
                # Precise mask mode
         | 
| 90 | 
             
                with gr.Group():
         | 
|  | |
| 129 |  | 
| 130 | 
             
            def create_prompt_section():
         | 
| 131 | 
             
                """Create the text prompt input section with improved layout."""
         | 
| 132 | 
            +
                md_prompt = gr.Markdown("5\. Input text prompt (optional)")
         | 
| 133 | 
             
                with gr.Group():
         | 
| 134 | 
             
                    prompt = gr.Textbox(
         | 
| 135 | 
             
                        placeholder="Please input the description for the target scene.", 
         | 
|  | |
| 243 |  | 
| 244 | 
             
            def create_mask_operation_section():
         | 
| 245 | 
             
                """Create mask operation section optimized for right column (outputs)."""
         | 
| 246 | 
            +
                md_mask_operation = gr.Markdown("6\. View or modify the target mask")
         | 
| 247 |  | 
| 248 | 
             
                with gr.Group():
         | 
| 249 | 
             
                    # Mask gallery with responsive layout
         | 
|  | |
| 293 |  | 
| 294 | 
             
            def create_output_section():
         | 
| 295 | 
             
                """Create the output section optimized for right column."""
         | 
| 296 | 
            +
                md_submit = gr.Markdown("7\. Submit and view the output")
         | 
| 297 |  | 
| 298 | 
             
                # Generation controls at top for better workflow
         | 
| 299 | 
             
                with gr.Group():
         | 
    	
        ic_custom/pipelines/ic_custom_pipeline.py
    CHANGED
    
    | @@ -1,4 +1,4 @@ | |
| 1 | 
            -
             | 
| 2 | 
             
            import re
         | 
| 3 | 
             
            from typing import List, Optional, Union
         | 
| 4 |  | 
| @@ -128,6 +128,10 @@ class ICCustomPipeline: | |
| 128 | 
             
                    double_blocks_idx: str = None, 
         | 
| 129 | 
             
                    single_blocks_idx: str = None,
         | 
| 130 | 
             
                    ):
         | 
|  | |
|  | |
|  | |
|  | |
| 131 | 
             
                    lora_path = resolve_model_path(
         | 
| 132 | 
             
                        name=lora_path,
         | 
| 133 | 
             
                        repo_id_field="repo_id",
         | 
| @@ -181,6 +185,9 @@ class ICCustomPipeline: | |
| 181 | 
             
                    self.load_model_weights(weights, strict=False)
         | 
| 182 |  | 
| 183 | 
             
                def set_img_txt_in(self, img_txt_in_path: str):
         | 
|  | |
|  | |
|  | |
| 184 | 
             
                    img_txt_in_path = resolve_model_path(
         | 
| 185 | 
             
                        name=img_txt_in_path,
         | 
| 186 | 
             
                        repo_id_field="repo_id",
         | 
| @@ -192,6 +199,9 @@ class ICCustomPipeline: | |
| 192 | 
             
                    self.load_model_weights(weights, strict=False)
         | 
| 193 |  | 
| 194 | 
             
                def set_boundary_embeddings(self, boundary_embeddings_path: str):
         | 
|  | |
|  | |
|  | |
| 195 | 
             
                    boundary_embeddings_path = resolve_model_path(
         | 
| 196 | 
             
                        name=boundary_embeddings_path,
         | 
| 197 | 
             
                        repo_id_field="repo_id",
         | 
| @@ -203,6 +213,9 @@ class ICCustomPipeline: | |
| 203 | 
             
                    self.load_model_weights(weights, strict=False)
         | 
| 204 |  | 
| 205 | 
             
                def set_task_register_embeddings(self, task_register_embeddings_path: str):
         | 
|  | |
|  | |
|  | |
| 206 | 
             
                    task_register_embeddings_path = resolve_model_path(
         | 
| 207 | 
             
                        name=task_register_embeddings_path,
         | 
| 208 | 
             
                        repo_id_field="repo_id",
         | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
             
            import re
         | 
| 3 | 
             
            from typing import List, Optional, Union
         | 
| 4 |  | 
|  | |
| 128 | 
             
                    double_blocks_idx: str = None, 
         | 
| 129 | 
             
                    single_blocks_idx: str = None,
         | 
| 130 | 
             
                    ):
         | 
| 131 | 
            +
                    if not os.path.exists(lora_path):
         | 
| 132 | 
            +
                        lora_path = "dit_lora_0x1561"
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                        
         | 
| 135 | 
             
                    lora_path = resolve_model_path(
         | 
| 136 | 
             
                        name=lora_path,
         | 
| 137 | 
             
                        repo_id_field="repo_id",
         | 
|  | |
| 185 | 
             
                    self.load_model_weights(weights, strict=False)
         | 
| 186 |  | 
| 187 | 
             
                def set_img_txt_in(self, img_txt_in_path: str):
         | 
| 188 | 
            +
                    if not os.path.exists(img_txt_in_path):
         | 
| 189 | 
            +
                        img_txt_in_path = "dit_txt_img_in_0x1561"
         | 
| 190 | 
            +
                        
         | 
| 191 | 
             
                    img_txt_in_path = resolve_model_path(
         | 
| 192 | 
             
                        name=img_txt_in_path,
         | 
| 193 | 
             
                        repo_id_field="repo_id",
         | 
|  | |
| 199 | 
             
                    self.load_model_weights(weights, strict=False)
         | 
| 200 |  | 
| 201 | 
             
                def set_boundary_embeddings(self, boundary_embeddings_path: str):
         | 
| 202 | 
            +
                    if not os.path.exists(boundary_embeddings_path):
         | 
| 203 | 
            +
                        boundary_embeddings_path = "dit_boundary_embeddings_0x1561"
         | 
| 204 | 
            +
                        
         | 
| 205 | 
             
                    boundary_embeddings_path = resolve_model_path(
         | 
| 206 | 
             
                        name=boundary_embeddings_path,
         | 
| 207 | 
             
                        repo_id_field="repo_id",
         | 
|  | |
| 213 | 
             
                    self.load_model_weights(weights, strict=False)
         | 
| 214 |  | 
| 215 | 
             
                def set_task_register_embeddings(self, task_register_embeddings_path: str):
         | 
| 216 | 
            +
                    if not os.path.exists(task_register_embeddings_path):
         | 
| 217 | 
            +
                        task_register_embeddings_path = "dit_task_register_embeddings_0x1561"
         | 
| 218 | 
            +
                        
         | 
| 219 | 
             
                    task_register_embeddings_path = resolve_model_path(
         | 
| 220 | 
             
                        name=task_register_embeddings_path,
         | 
| 221 | 
             
                        repo_id_field="repo_id",
         | 
    	
        ic_custom/utils/model_utils.py
    CHANGED
    
    | @@ -206,6 +206,9 @@ def load_dit( | |
| 206 | 
             
                    model: Loaded Flux model
         | 
| 207 | 
             
                """
         | 
| 208 | 
             
                # Loading Flux
         | 
|  | |
|  | |
|  | |
| 209 | 
             
                logger.info("Initializing Flux model")
         | 
| 210 |  | 
| 211 | 
             
                # Resolve checkpoint path
         | 
| @@ -249,9 +252,11 @@ def load_ic_custom( | |
| 249 | 
             
                    model: Loaded IC_Custom model
         | 
| 250 | 
             
                """
         | 
| 251 | 
             
                logger.info("Initializing IC-Custom model")
         | 
| 252 | 
            -
             | 
| 253 | 
             
                # Resolve checkpoint path
         | 
| 254 | 
            -
             | 
|  | |
|  | |
| 255 | 
             
                ckpt_path = resolve_model_path(
         | 
| 256 | 
             
                    name=name,
         | 
| 257 | 
             
                    repo_id_field="repo_id",
         | 
| @@ -312,8 +317,7 @@ def load_embedder( | |
| 312 | 
             
                    path, 
         | 
| 313 | 
             
                    max_length=max_length, 
         | 
| 314 | 
             
                    is_clip=is_clip, 
         | 
| 315 | 
            -
             | 
| 316 | 
            -
                ).to(device)
         | 
| 317 |  | 
| 318 | 
             
                return model
         | 
| 319 |  | 
| @@ -336,7 +340,11 @@ def load_t5( | |
| 336 | 
             
                Returns:
         | 
| 337 | 
             
                    model: Loaded T5 model
         | 
| 338 | 
             
                """
         | 
|  | |
|  | |
|  | |
| 339 | 
             
                logger.info(f"Loading T5 model: {name}")
         | 
|  | |
| 340 | 
             
                return load_embedder(
         | 
| 341 | 
             
                    name=name,
         | 
| 342 | 
             
                    is_clip=False,
         | 
| @@ -362,7 +370,11 @@ def load_clip( | |
| 362 | 
             
                Returns:
         | 
| 363 | 
             
                    model: Loaded CLIP model
         | 
| 364 | 
             
                """
         | 
|  | |
|  | |
|  | |
| 365 | 
             
                logger.info(f"Loading CLIP model: {name}")
         | 
|  | |
| 366 | 
             
                return load_embedder(
         | 
| 367 | 
             
                    name=name,
         | 
| 368 | 
             
                    is_clip=True,
         | 
| @@ -387,6 +399,10 @@ def load_ae( | |
| 387 | 
             
                Returns:
         | 
| 388 | 
             
                    model: Loaded AutoEncoder model
         | 
| 389 | 
             
                """
         | 
|  | |
|  | |
|  | |
|  | |
| 390 | 
             
                logger.info(f"Loading AutoEncoder model: {name}")
         | 
| 391 |  | 
| 392 | 
             
                # Convert device string to torch.device if needed
         | 
| @@ -429,6 +445,12 @@ def load_redux( | |
| 429 | 
             
                Returns:
         | 
| 430 | 
             
                    model: Loaded Redux Image Encoder model
         | 
| 431 | 
             
                """
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 432 | 
             
                logger.info(f"Loading Redux Image Encoder: redux={redux_name}, siglip={siglip_name}")
         | 
| 433 |  | 
| 434 | 
             
                # Convert device string to torch.device if needed
         | 
|  | |
| 206 | 
             
                    model: Loaded Flux model
         | 
| 207 | 
             
                """
         | 
| 208 | 
             
                # Loading Flux
         | 
| 209 | 
            +
                if not os.path.exists(name):
         | 
| 210 | 
            +
                    name = "flux-fill-dev-dit"
         | 
| 211 | 
            +
                
         | 
| 212 | 
             
                logger.info("Initializing Flux model")
         | 
| 213 |  | 
| 214 | 
             
                # Resolve checkpoint path
         | 
|  | |
| 252 | 
             
                    model: Loaded IC_Custom model
         | 
| 253 | 
             
                """
         | 
| 254 | 
             
                logger.info("Initializing IC-Custom model")
         | 
| 255 | 
            +
             | 
| 256 | 
             
                # Resolve checkpoint path
         | 
| 257 | 
            +
                if not os.path.exists(name):
         | 
| 258 | 
            +
                    name = "flux-fill-dev-dit"
         | 
| 259 | 
            +
                
         | 
| 260 | 
             
                ckpt_path = resolve_model_path(
         | 
| 261 | 
             
                    name=name,
         | 
| 262 | 
             
                    repo_id_field="repo_id",
         | 
|  | |
| 317 | 
             
                    path, 
         | 
| 318 | 
             
                    max_length=max_length, 
         | 
| 319 | 
             
                    is_clip=is_clip, 
         | 
| 320 | 
            +
                ).to(device).to(dtype)
         | 
|  | |
| 321 |  | 
| 322 | 
             
                return model
         | 
| 323 |  | 
|  | |
| 340 | 
             
                Returns:
         | 
| 341 | 
             
                    model: Loaded T5 model
         | 
| 342 | 
             
                """
         | 
| 343 | 
            +
                if not os.path.exists(name):
         | 
| 344 | 
            +
                    name = "t5-v1_1-xxl"
         | 
| 345 | 
            +
             | 
| 346 | 
             
                logger.info(f"Loading T5 model: {name}")
         | 
| 347 | 
            +
                
         | 
| 348 | 
             
                return load_embedder(
         | 
| 349 | 
             
                    name=name,
         | 
| 350 | 
             
                    is_clip=False,
         | 
|  | |
| 370 | 
             
                Returns:
         | 
| 371 | 
             
                    model: Loaded CLIP model
         | 
| 372 | 
             
                """
         | 
| 373 | 
            +
                if not os.path.exists(name):
         | 
| 374 | 
            +
                    name = "clip-vit-large-patch14"
         | 
| 375 | 
            +
             | 
| 376 | 
             
                logger.info(f"Loading CLIP model: {name}")
         | 
| 377 | 
            +
                
         | 
| 378 | 
             
                return load_embedder(
         | 
| 379 | 
             
                    name=name,
         | 
| 380 | 
             
                    is_clip=True,
         | 
|  | |
| 399 | 
             
                Returns:
         | 
| 400 | 
             
                    model: Loaded AutoEncoder model
         | 
| 401 | 
             
                """
         | 
| 402 | 
            +
             | 
| 403 | 
            +
                if not os.path.exists(name):
         | 
| 404 | 
            +
                    name = "flux-fill-dev-ae"
         | 
| 405 | 
            +
                
         | 
| 406 | 
             
                logger.info(f"Loading AutoEncoder model: {name}")
         | 
| 407 |  | 
| 408 | 
             
                # Convert device string to torch.device if needed
         | 
|  | |
| 445 | 
             
                Returns:
         | 
| 446 | 
             
                    model: Loaded Redux Image Encoder model
         | 
| 447 | 
             
                """
         | 
| 448 | 
            +
             | 
| 449 | 
            +
                if not os.path.exists(redux_name):
         | 
| 450 | 
            +
                    redux_name = "flux1-redux-dev"
         | 
| 451 | 
            +
                if not os.path.exists(siglip_name):
         | 
| 452 | 
            +
                    siglip_name = "siglip-so400m-patch14-384"
         | 
| 453 | 
            +
                
         | 
| 454 | 
             
                logger.info(f"Loading Redux Image Encoder: redux={redux_name}, siglip={siglip_name}")
         | 
| 455 |  | 
| 456 | 
             
                # Convert device string to torch.device if needed
         | 
