Spaces:
Paused
Paused
| import os, io, tempfile | |
| from typing import Optional | |
| from PIL import Image | |
| import torch | |
| import gradio as gr | |
| from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter | |
| from diffusers.utils import export_to_gif | |
# Default Hub checkpoints. The UI pre-fills its model/adapter text boxes with
# these, and generate() falls back to them when a box is left empty.
MODEL_ID = "SG161222/Realistic_Vision_V5.1_noVAE"  # fine-tuned Stable Diffusion 1.5
ADAPTER_ID = "guoyww/animatediff-motion-adapter-v1-5-2"  # MotionAdapter for SD 1.4/1.5

# Pipeline cache shared by all requests; stays None until load_pipe() builds it.
pipe: Optional["AnimateDiffPipeline"] = None
_pipe_key = None  # (model_id, adapter_id, offload) that the cached `pipe` was built with


def load_pipe(model_id: str, adapter_id: str, cpu_offload: bool):
    """Return the shared AnimateDiff pipeline, rebuilding it when the settings change.

    The pipeline is cached in the module-level ``pipe`` global together with the
    settings it was built from (``_pipe_key``). A call with different settings
    discards the old pipeline and loads a new one. Without this, edits to the
    model ID, adapter ID or CPU-offload controls in the UI would be ignored.

    Args:
        model_id: Hub repo ID or local path of the Stable Diffusion checkpoint.
        adapter_id: Hub repo ID or local path of the MotionAdapter weights.
        cpu_offload: Use model CPU offload instead of moving the whole pipeline
            to the GPU. Only honoured when CUDA is available.

    Returns:
        The configured ``AnimateDiffPipeline`` instance.
    """
    import gc
    import warnings

    global pipe, _pipe_key

    has_cuda = torch.cuda.is_available()
    # Offload only matters with CUDA. Normalising it avoids pointless reloads on CPU.
    offload = bool(cpu_offload) and has_cuda
    key = (model_id, adapter_id, offload)
    if pipe is not None and _pipe_key == key:
        return pipe

    if pipe is not None:
        # Release the previous pipeline first so two models are never resident at once.
        pipe, _pipe_key = None, None
        gc.collect()
        if has_cuda:
            torch.cuda.empty_cache()

    # Preferred precision: float16 on CUDA, float32 otherwise.
    dtype = torch.float16 if has_cuda else torch.float32

    # No dtype argument here (reported as rejected by the diffusers version in use).
    # NOTE(review): confirm the adapter weights end up cast to `dtype` once the
    # pipeline merges them into its UNet.
    adapter = MotionAdapter.from_pretrained(adapter_id)

    # `torch_dtype` is the keyword diffusers documents for from_pretrained. An
    # unrecognised spelling such as `dtype=` is, in current diffusers releases,
    # logged and ignored instead of raising TypeError. That would silently load
    # the model in float32.
    p = AnimateDiffPipeline.from_pretrained(
        model_id,
        motion_adapter=adapter,
        torch_dtype=dtype,
    )

    # Scheduler settings recommended for temporal stability.
    p.scheduler = DDIMScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
        clip_sample=False,
        timestep_spacing="linspace",
        beta_schedule="linear",
        steps_offset=1,
    )

    # VRAM savers on the VAE: slicing always, tiling where the VAE supports it.
    p.vae.enable_slicing()
    try:
        p.vae.enable_tiling()  # useful at higher resolutions
    except Exception as exc:  # best-effort optimisation: warn instead of failing the load
        warnings.warn(f"VAE tiling unavailable, continuing without it: {exc}")

    # Device placement: offload lowers peak VRAM; otherwise move everything at once.
    if offload:
        p.enable_model_cpu_offload()
    else:
        p.to("cuda" if has_cuda else "cpu")

    pipe, _pipe_key = p, key
    return pipe
def _snap_dim(value) -> Optional[int]:
    """Map an optional UI dimension to a size the pipeline accepts.

    Empty or zero input means "use the pipeline default" and returns None.
    Any other value is rounded down to a multiple of 8 (minimum 8), because
    Stable Diffusion based pipelines reject sizes that are not multiples of 8.
    """
    if not value:
        return None
    return max(8, int(value) // 8 * 8)


def _temp_path(suffix: str) -> str:
    """Create a unique, empty temporary file ending in *suffix* and return its path."""
    fd, path = tempfile.mkstemp(prefix="animation_", suffix=suffix)
    os.close(fd)
    return path


def _write_mp4(frames, fps: int) -> str:
    """Encode PIL *frames* to an H.264 MP4 (via imageio-ffmpeg) and return its path.

    Raises whatever imageio/ffmpeg raises; the partially written file is removed first.
    """
    import imageio
    import numpy as np

    path = _temp_path(".mp4")
    try:
        with imageio.get_writer(path, fps=fps, codec="libx264", quality=8) as writer:
            for frame in frames:
                # The writer wants an H x W x 3 uint8 array. PIL's tobytes() is raw,
                # headerless pixel data that no image decoder can read back.
                writer.append_data(np.asarray(frame.convert("RGB")))
    except Exception:
        try:
            os.remove(path)  # do not leave an empty/truncated file behind
        except OSError:
            pass
        raise
    return path


def generate(
    image: Image.Image,
    prompt: str,
    negative_prompt: str,
    num_frames: int,
    steps: int,
    guidance: float,
    seed: int,
    width: Optional[int],
    height: Optional[int],
    fps: int,
    save_mp4: bool,
    model_id_ui: str,
    adapter_id_ui: str,
    cpu_offload: bool
):
    """Run one AnimateDiff generation for the UI.

    NOTE(review): `image` is only checked for presence. It is never passed to
    the pipeline, so the animation is driven by the text prompt alone.

    Args:
        image: Uploaded start image (required by the UI, see note above).
        prompt / negative_prompt: Text conditioning; the negative prompt may be empty.
        num_frames, steps, guidance: Sampling settings forwarded to the pipeline.
        seed: RNG seed; an emptied field (None) falls back to 0.
        width, height: Optional output size, snapped down to a multiple of 8.
        fps: Frame rate for the GIF (and MP4).
        save_mp4: Also encode an MP4 (best-effort).
        model_id_ui, adapter_id_ui: Checkpoint overrides; empty means MODEL_ID / ADAPTER_ID.
        cpu_offload: Forwarded to load_pipe.

    Returns:
        ``(gif_path, mp4_path_or_None, status_message)``, or ``(None, None, message)``
        when the image or prompt is missing.
    """
    if image is None or not prompt or not prompt.strip():
        return None, None, "Envie uma imagem e um prompt válidos."

    p = load_pipe(model_id_ui or MODEL_ID, adapter_id_ui or ADAPTER_ID, cpu_offload)

    # An emptied Seed box arrives as None; int(None) would raise.
    seed_value = 0 if seed is None else int(seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    gen = torch.Generator(device=device).manual_seed(seed_value)

    out = p(
        prompt=prompt,
        negative_prompt=negative_prompt or "",
        num_frames=int(num_frames),
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        generator=gen,
        width=_snap_dim(width),
        height=_snap_dim(height),
    )
    frames = out.frames[0]  # single prompt -> first video in the batch: a list of PIL images

    fps_value = int(fps)
    # Unique file per request so concurrent users do not overwrite each other.
    # The .gif suffix lets PIL infer the output format.
    # NOTE(review): these files are never deleted here; consider periodic cleanup.
    gif_path = _temp_path(".gif")
    export_to_gif(frames, gif_path, fps=fps_value)

    msg = f"Gerado {len(frames)} frames @ {fps_value} fps."
    mp4_path = None
    if save_mp4:
        try:
            mp4_path = _write_mp4(frames, fps_value)
        except Exception as exc:  # best-effort: the GIF is still returned
            msg += f" MP4 falhou: {exc}"
    return gif_path, mp4_path, msg
def ui():
    """Assemble the Gradio Blocks app and return it without launching it.

    Layout: a settings column (inputs and the run button) beside an output
    column (preview, MP4 download, status). The button forwards every input,
    in order, to ``generate`` and routes its three results to the outputs.
    """
    with gr.Blocks(title="AnimateDiff img2vid") as demo:
        gr.Markdown("## AnimateDiff img2vid")

        with gr.Row():
            # Settings column.
            with gr.Column(scale=1):
                image = gr.Image(type="pil", label="Imagem inicial")
                prompt = gr.Textbox(
                    label="Prompt",
                    lines=3,
                    placeholder="Descreva estilo/movimento...",
                )
                negative = gr.Textbox(
                    label="Negative prompt",
                    lines=2,
                    value="low quality, worst quality",
                )
                with gr.Row():
                    frames = gr.Slider(8, 64, value=16, step=1, label="Frames")
                    steps = gr.Slider(4, 60, value=25, step=1, label="Steps")
                with gr.Row():
                    guidance = gr.Slider(0.5, 15.0, value=7.5, step=0.5, label="Guidance")
                    fps = gr.Slider(4, 30, value=8, step=1, label="FPS")
                with gr.Row():
                    seed = gr.Number(value=42, precision=0, label="Seed")
                    width = gr.Number(value=None, precision=0, label="Largura (opcional)")
                    height = gr.Number(value=None, precision=0, label="Altura (opcional)")
                # Checkpoint selection, pre-filled with the module defaults.
                with gr.Row():
                    model_id_ui = gr.Textbox(value=MODEL_ID, label="Model ID (SD1.5 finetune)")
                    adapter_id_ui = gr.Textbox(value=ADAPTER_ID, label="MotionAdapter ID")
                with gr.Row():
                    cpu_offload = gr.Checkbox(value=False, label="CPU offload")
                    save_mp4 = gr.Checkbox(value=False, label="Salvar MP4")
                run_btn = gr.Button("Gerar animação")

            # Output column.
            with gr.Column(scale=1):
                video_out = gr.Video(label="Preview (GIF)")
                file_mp4 = gr.File(label="MP4 (download)", interactive=False)
                status = gr.Textbox(label="Status", interactive=False)

        # Order must match generate()'s positional parameters exactly.
        inputs = [
            image, prompt, negative, frames, steps, guidance, seed,
            width, height, fps, save_mp4, model_id_ui, adapter_id_ui, cpu_offload,
        ]
        outputs = [video_out, file_mp4, status]

        def _run(*args):
            return generate(*args)

        run_btn.click(_run, inputs=inputs, outputs=outputs)

    return demo
if __name__ == "__main__":
    # Serve on all interfaces at port 7860 and try to open a local browser tab.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "inbrowser": True}
    demo = ui()
    demo.launch(**launch_options)