# Spaces: Running on Zero
| import os | |
| import torch | |
| import gradio as gr | |
| import spaces | |
| from PIL import Image | |
| from diffusers import DiffusionPipeline | |
| from huggingface_hub import snapshot_download | |
| from test_ccsr_tile import load_pipeline | |
| import argparse | |
| from accelerate import Accelerator | |
# Module-level state
class ModelContainer:
    """Mutable holder for the lazily loaded inference objects.

    Every slot starts out empty.  ``initialize_models`` fills them in and
    sets ``is_initialized`` once loading has succeeded.
    """

    def __init__(self):
        # Nothing is loaded at construction time.
        self.pipeline = self.generator = self.accelerator = None
        self.is_initialized = False


# Single shared instance used by the functions in this module.
model_container = ModelContainer()
class Args:
    """Plain attribute bag: each keyword argument becomes an attribute."""

    def __init__(self, **kwargs):
        # Write the options straight into the instance namespace.
        instance_namespace = vars(self)
        instance_namespace.update(kwargs)
def initialize_models():
    """Load the pipeline into ``model_container`` the first time it is needed.

    Downloads the ``NightRaven109/CCSRModels`` snapshot, builds the pipeline
    with ``load_pipeline``, puts its sub-models in eval mode, moves it to
    CUDA and creates a CUDA ``torch.Generator``.  Once initialization has
    succeeded, later calls return immediately.

    Returns:
        bool: ``True`` when the models are ready, ``False`` if any step
        failed.  Failures are printed (message and traceback), not raised,
        and leave ``model_container`` untouched.
    """
    if model_container.is_initialized:
        return True

    try:
        # Download model repository (only once).  The token is optional: an
        # unset or empty ``Read2`` variable is passed as ``None`` so that
        # huggingface_hub falls back to its default credentials instead of
        # this function dying on a KeyError.
        model_path = snapshot_download(
            repo_id="NightRaven109/CCSRModels",
            token=os.environ.get("Read2") or None,
        )

        # Set up default arguments
        args = Args(
            pretrained_model_path=os.path.join(model_path, "stable-diffusion-2-1-base"),
            controlnet_model_path=os.path.join(model_path, "Controlnet"),
            vae_model_path=os.path.join(model_path, "vae"),
            mixed_precision="fp16",
            tile_vae=False,
            sample_method="ddpm",
            vae_encoder_tile_size=1024,
            vae_decoder_tile_size=224,
        )

        # Build everything in locals first; the shared container is only
        # updated after every step has succeeded.
        accelerator = Accelerator(mixed_precision=args.mixed_precision)
        pipeline = load_pipeline(
            args,
            accelerator,
            enable_xformers_memory_efficient_attention=False,
        )

        # Set models to eval mode
        for component in (
            pipeline.unet,
            pipeline.controlnet,
            pipeline.vae,
            pipeline.text_encoder,
        ):
            component.eval()

        # Move pipeline to CUDA once and create the matching generator
        pipeline = pipeline.to("cuda")
        generator = torch.Generator("cuda")
    except Exception as e:
        # Top-level boundary: report in the same way as process_image does.
        print(f"Error initializing models: {str(e)}")
        import traceback
        traceback.print_exc()
        return False

    model_container.accelerator = accelerator
    model_container.pipeline = pipeline
    model_container.generator = generator
    model_container.is_initialized = True
    return True
# spaces.GPU requests a GPU for the duration of the call on ZeroGPU hardware
# (see the Hugging Face ZeroGPU docs; it is documented as having no effect
# elsewhere).  torch.no_grad disables autograd bookkeeping during inference.
@spaces.GPU
@torch.no_grad()
def process_image(
    input_image,
    prompt="clean, texture, high-resolution, 8k",
    negative_prompt="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
    guidance_scale=2.5,
    conditioning_scale=1.0,
    num_inference_steps=6,
    seed=None,
    upscale_factor=4,
    color_fix_method="adain"
):
    """Upscale one image with the shared pipeline.

    Args:
        input_image: Array accepted by ``PIL.Image.fromarray``; ``None``
            (nothing uploaded) makes the function return ``None``.
        prompt: Text prompt passed positionally to the pipeline.
        negative_prompt: Negative prompt forwarded to the pipeline.
        guidance_scale: Forwarded to the pipeline unchanged.
        conditioning_scale: Forwarded to the pipeline unchanged.
        num_inference_steps: Number of sampling steps (coerced to ``int``).
        seed: Optional seed for the shared generator (coerced to ``int``);
            ``None`` leaves the generator state as it is.
        upscale_factor: Integer enlargement factor (coerced to ``int``).
        color_fix_method: ``"none"`` skips colour correction, ``"wavelet"``
            uses ``wavelet_color_fix``, any other value ``adain_color_fix``.

    Returns:
        The first image produced by the pipeline (after optional colour
        fixing and resizing), or ``None`` if model loading or processing
        failed.  Errors are printed, never raised.
    """
    # Nothing to do without an image; bail out before loading any models.
    if input_image is None:
        print("Error processing image: no input image provided")
        return None

    # Initialize models if not already done
    if not model_container.is_initialized and not initialize_models():
        return None

    try:
        # UI widgets may deliver whole numbers as floats (e.g. 42.0); the
        # generator seed and the PIL resize dimensions are used as integers.
        upscale_factor = int(upscale_factor)
        num_inference_steps = int(num_inference_steps)
        if seed is not None:
            seed = int(seed)

        # Create args object
        args = Args(
            added_prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            conditioning_scale=conditioning_scale,
            num_inference_steps=num_inference_steps,
            seed=seed,
            upscale=upscale_factor,
            process_size=512,
            align_method=color_fix_method,
            t_max=0.6666,
            t_min=0.0,
            tile_diffusion=False,
            tile_diffusion_size=None,
            tile_diffusion_stride=None,
            start_steps=999,
            start_point='lr',
            use_vae_encode_condition=True,
            sample_times=1,
        )

        # Set seed if provided
        if seed is not None:
            model_container.generator.manual_seed(seed)

        # Process input image
        validation_image = Image.fromarray(input_image)
        ori_width, ori_height = validation_image.size

        # Inputs whose shorter side is below process_size // upscale are
        # enlarged first so the upscaled image reaches process_size.
        min_side = args.process_size // args.upscale
        resize_flag = False
        if ori_width < min_side or ori_height < min_side:
            scale = min_side / min(ori_width, ori_height)
            validation_image = validation_image.resize(
                (round(scale * ori_width), round(scale * ori_height))
            )
            resize_flag = True

        # Enlarge by the upscale factor, then snap both sides down to a
        # multiple of 8.
        validation_image = validation_image.resize(
            (validation_image.size[0] * args.upscale, validation_image.size[1] * args.upscale)
        )
        validation_image = validation_image.resize(
            (validation_image.size[0] // 8 * 8, validation_image.size[1] // 8 * 8)
        )
        width, height = validation_image.size

        # Generate image (the reported inference time is not used here)
        inference_time, output = model_container.pipeline(
            args.t_max,
            args.t_min,
            args.tile_diffusion,
            args.tile_diffusion_size,
            args.tile_diffusion_stride,
            args.added_prompt,
            validation_image,
            num_inference_steps=args.num_inference_steps,
            generator=model_container.generator,
            height=height,
            width=width,
            guidance_scale=args.guidance_scale,
            negative_prompt=args.negative_prompt,
            conditioning_scale=args.conditioning_scale,
            start_steps=args.start_steps,
            start_point=args.start_point,
            use_vae_encode_condition=args.use_vae_encode_condition,
        )
        image = output.images[0]

        # Apply color fixing if specified
        if args.align_method != "none":
            from myutils.wavelet_color_fix import wavelet_color_fix, adain_color_fix
            fix_func = wavelet_color_fix if args.align_method == "wavelet" else adain_color_fix
            image = fix_func(image, validation_image)

        # If the input was pre-enlarged, bring the result back to exactly
        # original size * upscale factor.
        if resize_flag:
            image = image.resize((ori_width * args.upscale, ori_height * args.upscale))

        return image
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
# Initial widget values; the Clear button restores these as well.
DEFAULT_VALUES = dict(
    prompt="clean, texture, high-resolution, 8k",
    negative_prompt="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
    guidance_scale=3,
    conditioning_scale=1.0,
    num_steps=6,
    seed=None,
    upscale_factor=4,
    color_fix_method="adain",
)
# Example rows for gr.Examples.  Column order per row:
# image path, prompt, negative prompt, guidance scale, conditioning scale,
# number of steps, seed, upscale factor, color fix method.
# Only the path, guidance scale and seed differ between the examples.
EXAMPLES = [
    [
        example_path,
        "clean, texture, high-resolution, 8k",
        "blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
        example_guidance,
        1.0,
        6,
        example_seed,
        4,
        "wavelet",
    ]
    for example_path, example_guidance, example_seed in (
        ("examples/1.png", 3.0, 42),
        ("examples/22.png", 3.0, 123),
        ("examples/4.png", 3.0, 123),
        ("examples/9D03D7F206775949.png", 3.0, 123),
        ("examples/3.jpeg", 2.5, 456),
    )
]
# Create interface components
# NOTE(review): the widget nesting below was reconstructed from a copy of the
# file that had lost its indentation -- confirm it against the intended layout.
with gr.Blocks(title="Texture Super-Resolution") as demo:
    gr.Markdown("## Texture Super-Resolution")
    gr.Markdown("Upload a texture to enhance its resolution.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image")
            with gr.Accordion("Advanced Options", open=False):
                prompt = gr.Textbox(label="Prompt", value=DEFAULT_VALUES["prompt"])
                negative_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_VALUES["negative_prompt"])
                guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, value=DEFAULT_VALUES["guidance_scale"], label="Guidance Scale")
                conditioning_scale = gr.Slider(minimum=0.1, maximum=2.0, value=DEFAULT_VALUES["conditioning_scale"], label="Conditioning Scale")
                num_steps = gr.Slider(minimum=1, maximum=50, value=DEFAULT_VALUES["num_steps"], step=1, label="Number of Steps")
                # precision=0 makes Gradio round the value and hand it to
                # the handler as an int, which is what a generator seed
                # has to be.
                seed = gr.Number(label="Seed", value=DEFAULT_VALUES["seed"], precision=0)
                upscale_factor = gr.Slider(minimum=1, maximum=8, value=DEFAULT_VALUES["upscale_factor"], step=1, label="Upscale Factor")
                color_fix_method = gr.Dropdown(
                    choices=["none", "wavelet", "adain"],
                    label="Color Fix Method",
                    value=DEFAULT_VALUES["color_fix_method"],
                )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil", format="png")

    # One list, in process_image's positional parameter order, shared by the
    # examples, the submit handler and the clear handler so they cannot
    # drift apart.
    control_inputs = [
        input_image, prompt, negative_prompt, guidance_scale,
        conditioning_scale, num_steps, seed, upscale_factor,
        color_fix_method,
    ]

    # Add examples
    gr.Examples(
        examples=EXAMPLES,
        inputs=control_inputs,
        outputs=output_image,
        fn=process_image,
        cache_examples=True,  # Cache the results for faster loading
    )

    # Define submit action
    submit_btn.click(
        fn=process_image,
        inputs=control_inputs,
        outputs=output_image,
    )

    # Define clear action that resets to default values
    def reset_to_defaults():
        """Return one value per entry of ``control_inputs``, in order."""
        return [
            None,  # input_image
            DEFAULT_VALUES["prompt"],
            DEFAULT_VALUES["negative_prompt"],
            DEFAULT_VALUES["guidance_scale"],
            DEFAULT_VALUES["conditioning_scale"],
            DEFAULT_VALUES["num_steps"],
            DEFAULT_VALUES["seed"],
            DEFAULT_VALUES["upscale_factor"],
            DEFAULT_VALUES["color_fix_method"],
        ]

    clear_btn.click(
        fn=reset_to_defaults,
        inputs=None,
        outputs=control_inputs,
    )


if __name__ == "__main__":
    demo.launch()