| import gradio as gr |
| from attacks.mist import update_args_with_config, main |
|
|
| ''' |
| TODO: |
| - SDXL |
| - model changing |
| ''' |
|
|
|
|
def process_image(
    eps,
    device,
    mode,
    resize,
    data_path,
    output_path,
    model_path,
    class_path,
    prompt,
    class_prompt,
    max_train_steps,
    max_f_train_steps,
    max_adv_train_steps,
    lora_lr,
    pgd_lr,
    rank,
    prior_loss_weight,
    fused_weight,
    constraint_mode,
    lpips_bound,
    lpips_weight,
):
    """Run Mist with the values submitted from the web UI.

    All parameters are packed, in signature order, into a single tuple.
    That tuple is handed to ``update_args_with_config`` (with ``None`` as the
    starting args object) and the resulting args are passed to ``main``.

    Returns:
        None. The return value of ``main`` is not propagated.
    """
    # The order of this tuple must stay identical to the signature order;
    # it is passed positionally as one config object.
    ui_values = (
        eps,
        device,
        mode,
        resize,
        data_path,
        output_path,
        model_path,
        class_path,
        prompt,
        class_prompt,
        max_train_steps,
        max_f_train_steps,
        max_adv_train_steps,
        lora_lr,
        pgd_lr,
        rank,
        prior_loss_weight,
        fused_weight,
        constraint_mode,
        lpips_bound,
        lpips_weight,
    )
    main(update_args_with_config(None, ui_values))
|
|
if __name__ == "__main__":
    # Build the Gradio UI: a logo, the basic settings, a collapsed accordion
    # with advanced settings, and a button that calls process_image.
    # NOTE(review): the nesting of the layout containers below was
    # reconstructed; confirm it matches the intended page layout.
    with gr.Blocks() as demo:
        with gr.Column():
            gr.Image("MIST_logo.png", show_label=False)
            with gr.Row():
                with gr.Column():
                    # --- Basic settings ---
                    eps = gr.Slider(
                        0, 32, step=1, value=10, label='Strength',
                        info="Larger strength results in stronger but more visible defense.",
                    )
                    device = gr.Radio(
                        ["cpu", "gpu"], value="gpu", label="Device",
                        info="If you do not have good GPUs on your PC, choose 'CPU'.",
                    )
                    resize = gr.Checkbox(
                        value=True,
                        label="Resizing the output image to the original resolution",
                    )
                    # NOTE(review): the info text says "Two modes" but three
                    # choices are offered; confirm the intended wording.
                    mode = gr.Radio(
                        ["Mode 1", "Mode 2", "Mode 3"], value="Mode 1", label="Mode",
                        info="Two modes both work with different visualization.",
                    )

                    # --- Paths and prompt ---
                    data_path = gr.Textbox(label="Data Path", lines=1, placeholder="Path to your images")
                    output_path = gr.Textbox(label="Output Path", lines=1, placeholder="Path to store the outputs")
                    model_path = gr.Textbox(label="Target Model Path", lines=1, placeholder="Path to the target model")
                    # Placeholder previously duplicated the model-path hint.
                    class_path = gr.Textbox(
                        label="Path to place contrast images ", lines=1,
                        placeholder="Path to the contrast images",
                    )
                    prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Describe your images")

                    # --- Advanced settings (collapsed by default) ---
                    with gr.Accordion("Professional Setups", open=False):
                        class_prompt = gr.Textbox(
                            label="Class prompt", lines=1,
                            placeholder="Prompt for contrast images.",
                        )
                        max_train_steps = gr.Slider(
                            1, 20, step=1, value=5, label='Epochs',
                            info="Training epochs of Mist-V2",
                        )
                        max_f_train_steps = gr.Slider(
                            0, 30, step=1, value=10, label='LoRA Steps',
                            info="Training steps of LoRA in one epoch",
                        )
                        max_adv_train_steps = gr.Slider(
                            0, 100, step=5, value=30, label='Attacking Steps',
                            info="Training steps of attacking in one epoch",
                        )
                        lora_lr = gr.Number(minimum=0.0, maximum=1.0, label="The learning rate of LoRA", value=0.0001)
                        pgd_lr = gr.Number(minimum=0.0, maximum=1.0, label="The learning rate of PGD", value=0.005)
                        rank = gr.Slider(
                            4, 32, step=4, value=4, label='LoRA Ranks',
                            info="Ranks of LoRA (Bigger ranks need better GPUs)",
                        )
                        prior_loss_weight = gr.Number(minimum=0.0, maximum=1.0, label="The weight of prior loss", value=0.1)
                        fused_weight = gr.Number(minimum=0.0, maximum=1.0, label="The weight of vae loss", value=0.00001)
                        constraint_mode = gr.Radio(
                            ["Epsilon", "LPIPS"], value="Epsilon", label="Constraint Mode",
                            info="The mode to constraint the watermark",
                        )
                        lpips_bound = gr.Number(minimum=0.0, maximum=0.2, label="The LPIPS bound", value=0.1)
                        # Label typo fixed: "LPIPI" -> "LPIPS".
                        lpips_weight = gr.Number(minimum=0.0, maximum=1.0, label="The weight of LPIPS constraint", value=0.5)

                    # Order must match the positional parameters of process_image.
                    inputs = [
                        eps, device, mode, resize, data_path, output_path, model_path, class_path, prompt,
                        class_prompt, max_train_steps, max_f_train_steps, max_adv_train_steps, lora_lr, pgd_lr,
                        rank, prior_loss_weight, fused_weight, constraint_mode, lpips_bound, lpips_weight,
                    ]

                    image_button = gr.Button("Mist")

            # No output components are registered for the click handler.
            image_button.click(process_image, inputs=inputs)

    # NOTE(review): share=True requests a public Gradio share link; the form
    # accepts filesystem paths, so confirm public exposure is intended.
    demo.queue().launch(share=True)
|
|