Spaces:
Paused
Paused
| import gradio as gr | |
| import requests | |
| import io | |
| import random | |
| import os | |
| from PIL import Image | |
| from deep_translator import GoogleTranslator | |
| from langdetect import detect | |
| import cv2 | |
| import torch | |
| from basicsr.archs.srvgg_arch import SRVGGNetCompact | |
| from gfpgan.utils import GFPGANer | |
| from realesrgan.utils import RealESRGANer | |
os.system("pip freeze")  # print the installed package versions to the start-up log

# Model weights fetched on first start-up: (local file name, release URL).
# `wget ... -P .` saves each URL under its basename, i.e. the local file name.
WEIGHT_DOWNLOADS = (
    ('realesr-general-x4v3.pth', 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'),
    ('GFPGANv1.2.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth'),
    ('GFPGANv1.3.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'),
    ('GFPGANv1.4.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'),
    ('RestoreFormer.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'),
    ('CodeFormer.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth'),
)
for _weight_file, _weight_url in WEIGHT_DOWNLOADS:
    if os.path.exists(_weight_file):
        continue
    _status = os.system(f"wget {_weight_url} -P .")
    if _status != 0:
        # os.system never raises, so report the failure explicitly. An interrupted
        # download can leave a partial file that the exists() check above would
        # accept on the next start, so delete it to force a retry.
        print(f'Download of {_weight_file} failed (exit status {_status})')
        if os.path.exists(_weight_file):
            os.remove(_weight_file)
# Background enhancer: Real-ESRGAN "general x4v3" compact network, wrapped in a
# RealESRGANer that the face enhancers below receive as their bg_upsampler.
model_us_path = 'realesr-general-x4v3.pth'
model_us = SRVGGNetCompact(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_conv=32,
    upscale=4,
    act_type='prelu',
)
# Use fp16 inference only when a CUDA device is available.
half = bool(torch.cuda.is_available())
upsampler = RealESRGANer(
    scale=4,
    model_path=model_us_path,
    model=model_us,
    tile=0,
    tile_pad=10,
    pre_pad=0,
    half=half,
)
# Directory that receives the enhanced images.
os.makedirs('output', exist_ok=True)
# Inference API endpoint of the DALL-E 3 XL model.
API_URL = "https://api-inference.huggingface.co/models/openskyml/dalle-3-xl"
# Hugging Face read token taken from the environment (Space secret); may be unset.
API_TOKEN = os.getenv("HF_READ_TOKEN")
# Only send an Authorization header when a token is configured; otherwise the
# request would carry the literal string "Bearer None".
headers = {"Authorization": f"Bearer {API_TOKEN}"} if API_TOKEN else {}
# Display names offered in the model selector.
models_list = [
    "AbsoluteReality 1.8.1",
    "DALL-E 3 XL",
    "Playground 2",
    "Openjourney 4",
    "Lyriel 1.6",
    "Animagine XL 2.0",
    "Counterfeit 2.5",
    "Realistic Vision 5.1",
    "Incursios 1.6",
    "Anime Detailer XL LoRA",
    "epiCRealism",
    "PixelArt XL",
    "NewReality XL",
]
# Hugging Face Inference API endpoint for each model display name.
MODEL_URLS = {
    "DALL-E 3 XL": "https://api-inference.huggingface.co/models/openskyml/dalle-3-xl",
    "Playground 2": "https://api-inference.huggingface.co/models/playgroundai/playground-v2-1024px-aesthetic",
    "Openjourney 4": "https://api-inference.huggingface.co/models/prompthero/openjourney-v4",
    "AbsoluteReality 1.8.1": "https://api-inference.huggingface.co/models/digiplay/AbsoluteReality_v1.8.1",
    "Lyriel 1.6": "https://api-inference.huggingface.co/models/stablediffusionapi/lyrielv16",
    "Animagine XL 2.0": "https://api-inference.huggingface.co/models/Linaqruf/animagine-xl-2.0",
    "Counterfeit 2.5": "https://api-inference.huggingface.co/models/gsdf/Counterfeit-V2.5",
    "Realistic Vision 5.1": "https://api-inference.huggingface.co/models/stablediffusionapi/realistic-vision-v51",
    "Incursios 1.6": "https://api-inference.huggingface.co/models/digiplay/incursiosMemeDiffusion_v1.6",
    "Anime Detailer XL LoRA": "https://api-inference.huggingface.co/models/Linaqruf/anime-detailer-xl-lora",
    "epiCRealism": "https://api-inference.huggingface.co/models/emilianJR/epiCRealism",
    "PixelArt XL": "https://api-inference.huggingface.co/models/nerijs/pixel-art-xl",
    "NewReality XL": "https://api-inference.huggingface.co/models/stablediffusionapi/newrealityxl-global-nsfw",
}


def query(prompt, model, is_negative=False, steps=20, cfg_scale=7, seed=None):
    """Generate an image for *prompt* with the selected model via the HF Inference API.

    Prompts detected as Russian are machine-translated to English before
    being sent.

    Args:
        prompt: Text description of the image.
        model: Display name of the model (a key of ``MODEL_URLS``). Unknown or
            missing names fall back to the DALL-E 3 XL endpoint.
        is_negative: Sent as the ``is_negative`` payload field; the UI passes
            the negative-prompt text here.
        steps: Sent as the ``steps`` payload field.
        cfg_scale: Sent as the ``cfg_scale`` payload field.
        seed: Sent as the ``seed`` payload field; a random value is drawn when None.

    Returns:
        The generated image as a ``PIL.Image.Image``.

    Raises:
        gr.Error: If the API answers with an HTTP error status, so the message
            is shown in the UI instead of an opaque image-decoding error.
    """
    from langdetect.lang_detect_exception import LangDetectException

    try:
        language = detect(prompt)
    except LangDetectException:
        # langdetect raises on empty or feature-less text (e.g. digits only);
        # such prompts are simply sent untranslated.
        language = None
    if language == 'ru':
        prompt = GoogleTranslator(source='ru', target='en').translate(prompt)
    print(f'\033[1mГенерация:\033[0m {prompt}')

    # A dict lookup replaces the old if-chain, which assigned a function-local
    # API_URL and raised UnboundLocalError for any model it did not list.
    api_url = MODEL_URLS.get(model, MODEL_URLS["DALL-E 3 XL"])
    # NOTE(review): the HF text-to-image task documents generation options under
    # a "parameters" object (negative_prompt, num_inference_steps,
    # guidance_scale); these top-level keys may be ignored by the API — confirm.
    payload = {
        "inputs": prompt,
        "is_negative": is_negative,
        "steps": steps,
        "cfg_scale": cfg_scale,
        "seed": seed if seed is not None else random.randint(-1, 2147483647),
    }
    response = requests.post(api_url, headers=headers, json=payload)
    if not response.ok:
        # Error responses carry a text/JSON body, not image bytes.
        raise gr.Error(f"Image generation failed (HTTP {response.status_code}): {response.text[:300]}")
    return Image.open(io.BytesIO(response.content))
# Face-restoration models selectable in the UI: version -> (weights file, GFPGANer arch).
FACE_ENHANCER_MODELS = {
    'v1.2': ('GFPGANv1.2.pth', 'clean'),
    'v1.3': ('GFPGANv1.3.pth', 'clean'),
    'v1.4': ('GFPGANv1.4.pth', 'clean'),
    'RestoreFormer': ('RestoreFormer.pth', 'RestoreFormer'),
    'CodeFormer': ('CodeFormer.pth', 'CodeFormer'),
}


def _load_bgr_image(img):
    """Return *img* (file path, PIL image or RGB array) as an OpenCV-ordered array.

    Paths are read with ``cv2.imread`` and may yield None when unreadable.
    In-memory images are converted from RGB(A) to BGR(A); grayscale stays 2-D.
    """
    if isinstance(img, (str, os.PathLike)):
        return cv2.imread(os.fspath(img), cv2.IMREAD_UNCHANGED)
    import numpy as np

    if hasattr(img, 'getbands'):  # PIL image
        # Normalise palette / 16-bit / other modes to something numpy maps cleanly.
        if img.mode not in ('RGB', 'RGBA', 'L'):
            img = img.convert('RGBA' if 'A' in img.getbands() else 'RGB')
    array = np.asarray(img)
    if array.ndim == 2:
        return array
    code = cv2.COLOR_RGBA2BGRA if array.shape[2] == 4 else cv2.COLOR_RGB2BGR
    return cv2.cvtColor(array, code)


def up(img, version, scale, weight):
    """Restore faces in *img* and upscale the result.

    Faces are restored with the GFPGANer model selected by *version*, and the
    background is handled by the module-level ``upsampler``. The result is resized
    to ``scale`` times the input size, written to ``output/out.png`` (inputs
    with an alpha channel) or ``output/out.jpg``, and returned.

    Args:
        img: Input image: a PIL image (what a ``gr.Image(type="pil")`` input
            delivers), a file path, or None.
        version: Key of ``FACE_ENHANCER_MODELS`` selecting weights and arch.
        scale: Final size factor relative to the input image.
        weight: 0-100 slider value; divided by 100 and passed to
            ``GFPGANer.enhance`` (per the UI label only CodeFormer uses it).

    Returns:
        The enhanced image as an RGB numpy array, or None when anything fails
        (errors are printed, not raised).
    """
    weight /= 100
    print(img, version, scale, weight)
    if img is None:
        print('no input image')
        return None
    if version not in FACE_ENHANCER_MODELS:
        print('unknown model version', version)
        return None
    try:
        img = _load_bgr_image(img)
        if img is None:
            print('could not read input image')
            return None
        if img.ndim == 3 and img.shape[2] == 4:
            img_mode = 'RGBA'
        elif img.ndim == 2:  # for gray inputs
            img_mode = None
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        else:
            img_mode = None

        model_path, arch = FACE_ENHANCER_MODELS[version]
        # GFPGANer's keyword is `model_path`; the previous `model_us_path`
        # raised TypeError on every call.
        face_enhancer = GFPGANer(
            model_path=model_path, upscale=2, arch=arch, channel_multiplier=2, bg_upsampler=upsampler)
        try:
            _, _, output = face_enhancer.enhance(
                img, has_aligned=False, only_center_face=False, paste_back=True, weight=weight)
        except RuntimeError as error:
            # Without a result there is nothing to resize or save.
            print('Error', error)
            return None

        try:
            interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
            h, w = img.shape[0:2]
            output = cv2.resize(output, (int(w * scale), int(h * scale)), interpolation=interpolation)
        except Exception as error:
            # Best effort: keep the un-resized enhancer output.
            print('wrong scale input.', error)

        # RGBA images should be saved in png format
        extension = 'png' if img_mode == 'RGBA' else 'jpg'
        save_path = f'output/out.{extension}'
        cv2.imwrite(save_path, output)
        return cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
    except Exception as error:
        print('global exception', error)
        return None
# Extra stylesheet for the Blocks app: hide Gradio's footer.
css = "\nfooter {visibility: hidden !important;}\n"
# Gradio UI: prompt, model and upscaling settings in tabs, followed by the
# generate / enhance buttons and their output images.
with gr.Blocks(css=css) as dalle:
    with gr.Tab("Базовые настройки"):
        with gr.Row():
            with gr.Column(elem_id="prompt-container"):
                text_prompt = gr.Textbox(label="Prompt", placeholder="Описание изображения", lines=3, elem_id="prompt-text-input")
                model = gr.Radio(label="Модель", value="DALL-E 3 XL", choices=models_list)
    with gr.Tab("Расширенные настройки"):
        negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Чего не должно быть на изображении", value="[deformed | disfigured], poorly drawn, [bad : wrong] anatomy, [extra | missing | floating | disconnected] limb, (mutated hands and fingers), blurry, text, fuzziness", lines=3, elem_id="negative-prompt-text-input")
    with gr.Tab("Настройки апскейлинга"):
        # No trailing commas: `x = gr.Radio(...),` binds a 1-tuple rather than a
        # component, which breaks the up_button.click() wiring below.
        up_1 = gr.Radio(choices=['v1.2', 'v1.3', 'v1.4', 'RestoreFormer', 'CodeFormer'], value='v1.4', label='Версия')
        up_2 = gr.Slider(label="Коэффициент масштабирования", value=2, minimum=2, maximum=6)
        up_3 = gr.Slider(0, 100, label='Weight, только для CodeFormer. 0 для лучшего качества, 100 для лучшей идентичности', value=50)
    with gr.Row():
        text_button = gr.Button("Генерация", variant='primary', elem_id="gen-button")
    with gr.Row():
        image_output = gr.Image(type="pil", label="Изображение", elem_id="gallery")
    with gr.Row():
        # elem_id values must be unique on the page; "gen-button" and "gallery"
        # are already used above.
        up_button = gr.Button("Улучшить изображение", variant='primary', elem_id="up-button")
    with gr.Row():
        up_output = gr.Image(type="pil", label="Улучшенное изображение", elem_id="up-gallery")
    # The negative-prompt text arrives in query()'s third parameter (is_negative).
    text_button.click(query, inputs=[text_prompt, model, negative_prompt], outputs=image_output)
    up_button.click(up, inputs=[image_output, up_1, up_2, up_3], outputs=up_output)

dalle.launch(show_api=False)