HBDing commited on
Commit
9a25740
·
1 Parent(s): 2c01375
Files changed (1) hide show
  1. app.py +32 -10
app.py CHANGED
@@ -1,6 +1,17 @@
1
  import gradio as gr
2
  from gradio_image_annotation import image_annotator
 
3
  import os
 
 
 
 
 
 
 
 
 
 
4
  example_annotation = {
5
  "image": os.path.join(os.path.dirname(__file__), "background.png"),
6
  "boxes": [],
@@ -12,23 +23,34 @@ def get_boxes_json(annotations):
12
  width = image.shape[1]
13
  height = image.shape[0]
14
  boxes = annotations["boxes"]
 
15
  for box in boxes:
16
  box["xmin"] = box["xmin"] / width
17
  box["xmax"] = box["xmax"] / width
18
  box["ymin"] = box["ymin"] / height
19
  box["ymax"] = box["ymax"] / height
20
- return annotations["boxes"]
 
 
 
 
 
21
 
22
  with gr.Blocks() as demo:
23
- with gr.Tab("DreamRenderer", id="DreamRenderer"):
24
- annotator = image_annotator(example_annotation,
25
- height=512,
26
- width=512
27
- )
28
-
29
- button_get = gr.Button("Get bounding boxes")
30
- json_boxes = gr.JSON()
31
- button_get.click(get_boxes_json, annotator, json_boxes)
 
 
 
 
 
32
 
33
  if __name__ == "__main__":
34
  demo.launch()
 
1
  import gradio as gr
2
  from gradio_image_annotation import image_annotator
3
+ from diffusers import StableDiffusionPipeline
4
  import os
5
+ import torch
6
+
7
+
8
+ # Load model
9
+ pipe = StableDiffusionPipeline.from_pretrained(
10
+ "runwayml/stable-diffusion-v1-5",
11
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
12
+ ).to("cuda" if torch.cuda.is_available() else "cpu")
13
+ pipe.safety_checker = None
14
+
15
  example_annotation = {
16
  "image": os.path.join(os.path.dirname(__file__), "background.png"),
17
  "boxes": [],
 
23
  width = image.shape[1]
24
  height = image.shape[0]
25
  boxes = annotations["boxes"]
26
+ prompt_final = [[]]
27
  for box in boxes:
28
  box["xmin"] = box["xmin"] / width
29
  box["xmax"] = box["xmax"] / width
30
  box["ymin"] = box["ymin"] / height
31
  box["ymax"] = box["ymax"] / height
32
+ prompt_final[0].append(box["label"])
33
+ # import pdb; pdb.set_trace()
34
+ prompt = ", ".join(prompt_final[0])
35
+ image = pipe(prompt).images[0]
36
+ return image
37
+ # return annotations["boxes"]
38
 
39
  with gr.Blocks() as demo:
40
+ with gr.Tab("DreamRenderer", id="DreamRenderer"):
41
+ with gr.Row():
42
+ with gr.Column(scale=1):
43
+ annotator = image_annotator(
44
+ example_annotation,
45
+ height=512,
46
+ width=512
47
+ )
48
+ with gr.Column(scale=1):
49
+ generated_image = gr.Image(label="Generated Image", height=512, width=512)
50
+
51
+ button_get = gr.Button("Generation")
52
+ button_get.click(get_boxes_json, inputs=annotator, outputs=generated_image)
53
+
54
 
55
  if __name__ == "__main__":
56
  demo.launch()