import os, tempfile
from typing import Optional
from PIL import Image
import torch
import gradio as gr
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif
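
# NOTE: assumed runtime dependencies (not pinned here): torch, diffusers,
# transformers, accelerate, gradio, pillow, and imageio + imageio-ffmpeg
# (the last two only if MP4 export is enabled).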

# Default models (adjust as desired)
MODEL_ID = "SG161222/Realistic_Vision_V5.1_noVAE"         # fine-tuned SD1.5 checkpoint
ADAPTER_ID = "guoyww/animatediff-motion-adapter-v1-5-2"   # MotionAdapter for SD1.4/1.5

pipe = None

def load_pipe(model_id: str, adapter_id: str, cpu_offload: bool):
    global pipe
    if pipe is not None:
        return pipe
    # Preferred dtype: float16 on CUDA, otherwise float32
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    # MotionAdapter does not take a dtype argument in from_pretrained in current versions
    adapter = MotionAdapter.from_pretrained(adapter_id)

    # Load the pipeline with the chosen dtype
    try:
        p = AnimateDiffPipeline.from_pretrained(
            model_id,
            motion_adapter=adapter,
            dtype=dtype  # newer versions accept 'dtype'
        )
    except TypeError:
        p = AnimateDiffPipeline.from_pretrained(
            model_id,
            motion_adapter=adapter,
            torch_dtype=dtype  # fallback for versions that still use 'torch_dtype'
        )

    # Scheduler recommended for temporal stability
    p.scheduler = DDIMScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
        clip_sample=False,
        timestep_spacing="linspace",
        beta_schedule="linear",
        steps_offset=1
    )

    # VRAM optimizations (newer APIs exposed on the VAE)
    p.vae.enable_slicing()
    try:
        p.vae.enable_tiling()  # helpful at higher resolutions
    except Exception:
        pass

    # Device placement / offload
    if cpu_offload and torch.cuda.is_available():
        p.enable_model_cpu_offload()  # reduces peak VRAM usage
    else:
        p.to("cuda" if torch.cuda.is_available() else "cpu")

    pipe = p
    return pipe
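
# Example usage outside the UI (a minimal sketch; the prompt text is illustrative):
#   p = load_pipe(MODEL_ID, ADAPTER_ID, cpu_offload=False)
#   out = p(prompt="a ship sailing at sunset, cinematic", num_frames=16, num_inference_steps=25)
#   export_to_gif(out.frames[0], "sample.gif")
# Note: load_pipe caches the first pipeline globally, so later calls with a
# different model_id or adapter_id return the already-loaded pipeline.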

def generate(
    image: Image.Image,
    prompt: str,
    negative_prompt: str,
    num_frames: int,
    steps: int,
    guidance: float,
    seed: int,
    width: Optional[int],
    height: Optional[int],
    fps: int,
    save_mp4: bool,
    model_id_ui: str,
    adapter_id_ui: str,
    cpu_offload: bool
):
    if image is None or not prompt or not prompt.strip():
        return None, None, "Envie uma imagem e um prompt válidos."  # [attached_file:1]

    p = load_pipe(model_id_ui or MODEL_ID, adapter_id_ui or ADAPTER_ID, cpu_offload)

    gen = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu").manual_seed(int(seed))

    # img2vid without IP-Adapter: do NOT pass ip_adapter_image
    # (the initial image is only validated above; it does not condition the pipeline here)
    out = p(
        prompt=prompt,
        negative_prompt=negative_prompt or "",
        num_frames=int(num_frames),
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        generator=gen,
        width=int(width) if width else None,
        height=int(height) if height else None
    )

    frames = out.frames[0]  # list of PIL images

    # Save the GIF to a temporary path with a .gif extension (avoids a PIL error)
    temp_gif = os.path.join(tempfile.gettempdir(), "animation.gif")
    export_to_gif(frames, temp_gif, fps=int(fps))

    # Optional: write an MP4 with imageio / imageio-ffmpeg
    mp4_path = None
    if save_mp4:
        try:
            import imageio
            import numpy as np
            mp4_path = os.path.join(tempfile.gettempdir(), "animation.mp4")
            # Convert each PIL frame to the ndarray expected by the writer
            with imageio.get_writer(mp4_path, fps=int(fps), codec="libx264", quality=8) as writer:
                for fr in frames:
                    writer.append_data(np.array(fr.convert("RGB")))
        except Exception:
            mp4_path = None  # if this fails, simply don't return an MP4

    return temp_gif, mp4_path, f"Generated {len(frames)} frames @ {fps} fps."

def ui():
    with gr.Blocks(title="AnimateDiff img2vid") as demo:
        gr.Markdown("## AnimateDiff img2vid")  # [attached_file:1]
        with gr.Row():
            with gr.Column(scale=1):
                image = gr.Image(type="pil", label="Imagem inicial")  # [attached_file:1]
                prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Descreva estilo/movimento...")  # [attached_file:1]
                negative = gr.Textbox(label="Negative prompt", lines=2, value="low quality, worst quality")  # [attached_file:1]
                with gr.Row():
                    frames = gr.Slider(8, 64, value=16, step=1, label="Frames")  # [attached_file:1]
                    steps = gr.Slider(4, 60, value=25, step=1, label="Steps")  # [attached_file:1]
                with gr.Row():
                    guidance = gr.Slider(0.5, 15.0, value=7.5, step=0.5, label="Guidance")  # [attached_file:1]
                    fps = gr.Slider(4, 30, value=8, step=1, label="FPS")  # [attached_file:1]
                with gr.Row():
                    seed = gr.Number(value=42, precision=0, label="Seed")  # [attached_file:1]
                    width = gr.Number(value=None, precision=0, label="Largura (opcional)")  # [attached_file:1]
                    height = gr.Number(value=None, precision=0, label="Altura (opcional)")  # [attached_file:1]
                with gr.Row():
                    model_id_ui = gr.Textbox(value=MODEL_ID, label="Model ID (SD1.5 finetune)")  # [attached_file:1]
                    adapter_id_ui = gr.Textbox(value=ADAPTER_ID, label="MotionAdapter ID")  # [attached_file:1]
                with gr.Row():
                    cpu_offload = gr.Checkbox(value=False, label="CPU offload")  # [attached_file:1]
                    save_mp4 = gr.Checkbox(value=False, label="Salvar MP4")  # [attached_file:1]
                run_btn = gr.Button("Gerar animação")  # [attached_file:1]
            with gr.Column(scale=1):
                video_out = gr.Video(label="Preview (GIF)")  # [attached_file:1]
                file_mp4 = gr.File(label="MP4 (download)", interactive=False)  # [attached_file:1]
                status = gr.Textbox(label="Status", interactive=False)  # [attached_file:1]

        def _run(*args):
            temp_gif, mp4_path, msg = generate(*args)
            return temp_gif, mp4_path, msg

        run_btn.click(
            _run,
            inputs=[image, prompt, negative, frames, steps, guidance, seed, width, height, fps, save_mp4, model_id_ui, adapter_id_ui, cpu_offload],
            outputs=[video_out, file_mp4, status]
        )
    return demo

if __name__ == "__main__":
    demo = ui()
    demo.launch(server_name="0.0.0.0", server_port=7860, inbrowser=True)