Commit 6853258 — Initial commit (no parents)

Files added:
- .gitattributes +34 -0
- README.md +12 -0
- __init__.py +0 -0
- app.py +4 -0
- image_transformation.py +96 -0
- requirements.txt +5 -0
- tool_config.json +3 -0
    	
        .gitattributes
    ADDED
    
    | @@ -0,0 +1,34 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            *.7z filter=lfs diff=lfs merge=lfs -text
         | 
| 2 | 
            +
            *.arrow filter=lfs diff=lfs merge=lfs -text
         | 
| 3 | 
            +
            *.bin filter=lfs diff=lfs merge=lfs -text
         | 
| 4 | 
            +
            *.bz2 filter=lfs diff=lfs merge=lfs -text
         | 
| 5 | 
            +
            *.ckpt filter=lfs diff=lfs merge=lfs -text
         | 
| 6 | 
            +
            *.ftz filter=lfs diff=lfs merge=lfs -text
         | 
| 7 | 
            +
            *.gz filter=lfs diff=lfs merge=lfs -text
         | 
| 8 | 
            +
            *.h5 filter=lfs diff=lfs merge=lfs -text
         | 
| 9 | 
            +
            *.joblib filter=lfs diff=lfs merge=lfs -text
         | 
| 10 | 
            +
            *.lfs.* filter=lfs diff=lfs merge=lfs -text
         | 
| 11 | 
            +
            *.mlmodel filter=lfs diff=lfs merge=lfs -text
         | 
| 12 | 
            +
            *.model filter=lfs diff=lfs merge=lfs -text
         | 
| 13 | 
            +
            *.msgpack filter=lfs diff=lfs merge=lfs -text
         | 
| 14 | 
            +
            *.npy filter=lfs diff=lfs merge=lfs -text
         | 
| 15 | 
            +
            *.npz filter=lfs diff=lfs merge=lfs -text
         | 
| 16 | 
            +
            *.onnx filter=lfs diff=lfs merge=lfs -text
         | 
| 17 | 
            +
            *.ot filter=lfs diff=lfs merge=lfs -text
         | 
| 18 | 
            +
            *.parquet filter=lfs diff=lfs merge=lfs -text
         | 
| 19 | 
            +
            *.pb filter=lfs diff=lfs merge=lfs -text
         | 
| 20 | 
            +
            *.pickle filter=lfs diff=lfs merge=lfs -text
         | 
| 21 | 
            +
            *.pkl filter=lfs diff=lfs merge=lfs -text
         | 
| 22 | 
            +
            *.pt filter=lfs diff=lfs merge=lfs -text
         | 
| 23 | 
            +
            *.pth filter=lfs diff=lfs merge=lfs -text
         | 
| 24 | 
            +
            *.rar filter=lfs diff=lfs merge=lfs -text
         | 
| 25 | 
            +
            *.safetensors filter=lfs diff=lfs merge=lfs -text
         | 
| 26 | 
            +
            saved_model/**/* filter=lfs diff=lfs merge=lfs -text
         | 
| 27 | 
            +
            *.tar.* filter=lfs diff=lfs merge=lfs -text
         | 
| 28 | 
            +
            *.tflite filter=lfs diff=lfs merge=lfs -text
         | 
| 29 | 
            +
            *.tgz filter=lfs diff=lfs merge=lfs -text
         | 
| 30 | 
            +
            *.wasm filter=lfs diff=lfs merge=lfs -text
         | 
| 31 | 
            +
            *.xz filter=lfs diff=lfs merge=lfs -text
         | 
| 32 | 
            +
            *.zip filter=lfs diff=lfs merge=lfs -text
         | 
| 33 | 
            +
            *.zst filter=lfs diff=lfs merge=lfs -text
         | 
| 34 | 
            +
            *tfevents* filter=lfs diff=lfs merge=lfs -text
         | 
    	
        README.md
    ADDED
    
    | @@ -0,0 +1,12 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            ---
         | 
| 2 | 
            +
            title: Image Transformation
         | 
| 3 | 
            +
            emoji: ⚡
         | 
| 4 | 
            +
            colorFrom: blue
         | 
| 5 | 
            +
            colorTo: purple
         | 
| 6 | 
            +
            sdk: gradio
         | 
| 7 | 
            +
            sdk_version: 3.27.0
         | 
| 8 | 
            +
            app_file: app.py
         | 
| 9 | 
            +
            pinned: false
         | 
| 10 | 
            +
            tags:
         | 
| 11 | 
            +
            - tool
         | 
| 12 | 
            +
            ---
         | 
    	
        __init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,4 @@ | |
|  | |
|  | |
|  | |
|  | 
|  | |
# Entry point for the Hugging Face Space: expose ImageTransformationTool
# through an auto-generated Gradio demo built by transformers' tool runner.
from transformers.tools.base import launch_gradio_demo
from image_transformation import ImageTransformationTool

# Blocks and serves the demo; the tool class (not an instance) is passed so
# the demo can instantiate it lazily.
launch_gradio_demo(ImageTransformationTool)
    	
        image_transformation.py
    ADDED
    
    | @@ -0,0 +1,96 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |

import numpy as np
import torch

from transformers.tools.base import Tool, get_default_device
from transformers.utils import (
    is_accelerate_available,
    is_diffusers_available,
    is_opencv_available,
    is_vision_available,
)


# Heavy optional dependencies are imported only when present so the module
# itself always imports; ImageTransformationTool.__init__ raises a clear
# ImportError later if any of them is missing.
if is_vision_available():
    from PIL import Image

if is_diffusers_available():
    from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler

if is_opencv_available():
    import cv2

# Natural-language contract surfaced to the agent framework so it knows when
# and how to call this tool (two inputs: `image`, `prompt`; returns an image).
# Keep in sync with ImageTransformationTool.__call__'s actual signature.
IMAGE_TRANSFORMATION_DESCRIPTION = (
    "This is a tool that transforms an image according to a prompt. It takes two inputs: `image`, which should be "
    "the image to transform, and `prompt`, which should be the prompt to use to change it. It returns the "
    "modified image."
)
| 29 | 
            +
             | 
| 30 | 
            +
             | 
class ImageTransformationTool(Tool):
    """Transform an image according to a text prompt.

    Runs Stable Diffusion conditioned by a Canny-edge ControlNet: the input
    image is reduced to its Canny edges, and the pipeline generates a new
    image that follows those edges while matching the prompt.

    Heavy models are loaded lazily in ``setup()`` on first call, so
    constructing the tool is cheap.
    """

    default_stable_diffusion_checkpoint = "runwayml/stable-diffusion-v1-5"
    default_controlnet_checkpoint = "lllyasviel/sd-controlnet-canny"
    description = IMAGE_TRANSFORMATION_DESCRIPTION

    def __init__(self, device=None, controlnet=None, stable_diffusion=None, **hub_kwargs) -> None:
        """Validate optional dependencies and record configuration.

        Args:
            device: Torch device to run on; resolved via ``get_default_device()``
                in ``setup()`` when ``None``.
            controlnet: ControlNet checkpoint id; defaults to
                ``default_controlnet_checkpoint``.
            stable_diffusion: Stable Diffusion checkpoint id; defaults to
                ``default_stable_diffusion_checkpoint``.
            **hub_kwargs: Extra kwargs kept for Hub-related loading
                (stored, not used directly in this class).

        Raises:
            ImportError: If accelerate, diffusers, Pillow, or OpenCV is missing.
        """
        # Fail fast with actionable messages if an optional dependency is absent.
        if not is_accelerate_available():
            raise ImportError("Accelerate should be installed in order to use tools.")
        if not is_diffusers_available():
            raise ImportError("Diffusers should be installed in order to use the StableDiffusionTool.")
        if not is_vision_available():
            raise ImportError("Pillow should be installed in order to use the StableDiffusionTool.")
        if not is_opencv_available():
            raise ImportError("opencv should be installed in order to use the StableDiffusionTool.")

        super().__init__()

        # Fall back to the class-level default checkpoints when none is given.
        if controlnet is None:
            controlnet = self.default_controlnet_checkpoint
        self.controlnet_checkpoint = controlnet

        if stable_diffusion is None:
            stable_diffusion = self.default_stable_diffusion_checkpoint
        self.stable_diffusion_checkpoint = stable_diffusion

        self.device = device  # resolved lazily in setup() when None
        self.hub_kwargs = hub_kwargs

    def setup(self):
        """Load the ControlNet + Stable Diffusion pipeline (one-time, lazy).

        Called automatically from ``__call__`` on first use. Uses fp16 weights
        and accelerate's model CPU offload to limit GPU memory usage.
        """
        if self.device is None:
            self.device = get_default_device()

        self.controlnet = ControlNetModel.from_pretrained(self.controlnet_checkpoint, torch_dtype=torch.float16)
        self.pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            self.stable_diffusion_checkpoint, controlnet=self.controlnet, torch_dtype=torch.float16
        )
        # UniPC converges in few steps, matching num_inference_steps=20 below.
        self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
        self.pipeline.enable_model_cpu_offload()

        self.is_initialized = True

    def __call__(self, image, prompt):
        """Return the transformed image.

        Args:
            image: Input image (PIL image or anything ``np.array`` converts to
                an HxW or HxWxC uint8 array — presumably RGB; confirm at caller).
            prompt: Text describing the desired transformation.

        Returns:
            The first generated PIL image from the pipeline.
        """
        if not self.is_initialized:
            self.setup()

        # Style prefix prepended to every user prompt.
        # BUG FIX: the original concatenated with no separator, fusing
        # "...extremely detailed" into the user's first word.
        initial_prompt = "super-hero character, best quality, extremely detailed"
        prompt = initial_prompt + ", " + prompt

        # Canny hysteresis thresholds (standard 100/200 defaults).
        low_threshold = 100
        high_threshold = 200

        # Build the 3-channel edge map ControlNet expects.
        image = np.array(image)
        image = cv2.Canny(image, low_threshold, high_threshold)
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
        canny_image = Image.fromarray(image)

        # Fixed seed for reproducible outputs across calls.
        generator = torch.Generator(device="cpu").manual_seed(2)

        return self.pipeline(
            prompt,
            canny_image,
            negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
            num_inference_steps=20,
            generator=generator,
        ).images[0]
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,5 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            transformers @ git+https://github.com/huggingface/transformers@test_composition
         | 
| 2 | 
            +
            diffusers
         | 
| 3 | 
            +
            accelerate
         | 
| 4 | 
            +
            opencv-python
         | 
| 5 | 
            +
            torch
         | 
    	
        tool_config.json
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "custom_tools": {"image-transformation": "image_transformation.ImageTransformationTool"}
         | 
| 3 | 
            +
            }
         |