Bertoin commited on
Commit
066a555
·
verified ·
1 Parent(s): 3ce0f3d

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +155 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import random
4
+ import spaces
5
+ import torch
6
+ from diffusers.pipelines.prx import PRXPipeline
7
+
8
+ dtype = torch.bfloat16
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+
11
+ # Load the PRX pipeline with the distilled DC-AE model
12
+ pipe = PRXPipeline.from_pretrained(
13
+ "Photoroom/prx-512-t2i-dc-ae-sft-distilled",
14
+ torch_dtype=dtype
15
+ )
16
+ pipe = pipe.to(device)
17
+
18
+ MAX_SEED = np.iinfo(np.int32).max
19
+ MAX_IMAGE_SIZE = 1024
20
+
21
+
22
+ @spaces.GPU()
23
+ def infer(
24
+ prompt,
25
+ seed=42,
26
+ randomize_seed=False,
27
+ width=512,
28
+ height=512,
29
+ num_inference_steps=8,
30
+ guidance_scale=1.0,
31
+ progress=gr.Progress(track_tqdm=True)
32
+ ):
33
+ if randomize_seed:
34
+ seed = random.randint(0, MAX_SEED)
35
+
36
+ generator = torch.Generator(device=device).manual_seed(seed)
37
+
38
+ image = pipe(
39
+ prompt=prompt,
40
+ width=width,
41
+ height=height,
42
+ num_inference_steps=num_inference_steps,
43
+ generator=generator,
44
+ guidance_scale=guidance_scale,
45
+ ).images[0]
46
+
47
+ return image, seed
48
+
49
+
50
+ examples = [
51
+ "A front-facing portrait of a lion on the golden savanna at sunset.",
52
+ "A serene mountain landscape with a crystal clear lake reflecting snow-capped peaks.",
53
+ "A futuristic cityscape at night with neon lights and flying vehicles.",
54
+ "A whimsical illustration of a cat wearing a wizard hat and casting spells.",
55
+ "A cozy coffee shop interior with warm lighting and plants on the windowsill.",
56
+ ]
57
+
58
+ css = """
59
+ #col-container {
60
+ margin: 0 auto;
61
+ max-width: 640px;
62
+ }
63
+ """
64
+
65
+ with gr.Blocks(css=css) as demo:
66
+ with gr.Column(elem_id="col-container"):
67
+ gr.Markdown("# PRX Image Generator")
68
+ gr.Markdown(
69
+ "Generate high-quality images using the PRX distilled model with DC-AE compression. "
70
+ "This model uses 8-step distillation for fast inference with cfg=1.0. "
71
+ "Works best with less detailed prompts in natural language."
72
+ )
73
+
74
+ with gr.Row():
75
+ prompt = gr.Text(
76
+ label="Prompt",
77
+ show_label=False,
78
+ max_lines=1,
79
+ placeholder="Enter your prompt",
80
+ container=False,
81
+ )
82
+
83
+ run_button = gr.Button("Run", scale=0)
84
+
85
+ result = gr.Image(label="Result", show_label=False)
86
+
87
+ with gr.Accordion("Advanced Settings", open=False):
88
+ seed = gr.Slider(
89
+ label="Seed",
90
+ minimum=0,
91
+ maximum=MAX_SEED,
92
+ step=1,
93
+ value=0,
94
+ )
95
+
96
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
97
+
98
+ with gr.Row():
99
+ width = gr.Slider(
100
+ label="Width",
101
+ minimum=256,
102
+ maximum=MAX_IMAGE_SIZE,
103
+ step=32,
104
+ value=512,
105
+ )
106
+
107
+ height = gr.Slider(
108
+ label="Height",
109
+ minimum=256,
110
+ maximum=MAX_IMAGE_SIZE,
111
+ step=32,
112
+ value=512,
113
+ )
114
+
115
+ with gr.Row():
116
+ num_inference_steps = gr.Slider(
117
+ label="Number of inference steps",
118
+ minimum=1,
119
+ maximum=28,
120
+ step=1,
121
+ value=8,
122
+ )
123
+
124
+ guidance_scale = gr.Slider(
125
+ label="Guidance scale",
126
+ minimum=0.0,
127
+ maximum=5.0,
128
+ step=0.1,
129
+ value=1.0,
130
+ )
131
+
132
+ gr.Examples(
133
+ examples=examples,
134
+ fn=infer,
135
+ inputs=[prompt],
136
+ outputs=[result, seed],
137
+ cache_examples="lazy"
138
+ )
139
+
140
+ gr.on(
141
+ triggers=[run_button.click, prompt.submit],
142
+ fn=infer,
143
+ inputs=[
144
+ prompt,
145
+ seed,
146
+ randomize_seed,
147
+ width,
148
+ height,
149
+ num_inference_steps,
150
+ guidance_scale,
151
+ ],
152
+ outputs=[result, seed]
153
+ )
154
+
155
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ accelerate
2
+ diffusers
3
+ torch
4
+ transformers
5
+ sentencepiece
6
+ spaces
7
+ ftfy