Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,20 +4,20 @@ from utils import get_patch_embeddings, compute_patch_similarity, overlay_simila
|
|
| 4 |
|
| 5 |
selected_patch = {"row": 0, "col": 0}
|
| 6 |
|
| 7 |
-
def init_states(img):
|
| 8 |
if img is None:
|
| 9 |
return gr.update(value=None), None
|
| 10 |
patch_embs, patch_embs_norm, rows, cols = get_patch_embeddings(img, ps=16, device=device)
|
| 11 |
|
| 12 |
sim_map = compute_patch_similarity(patch_embs, patch_embs_norm, 0, 0)
|
| 13 |
-
result_img = overlay_similarity(img, sim_map, alpha=
|
| 14 |
|
| 15 |
state = {
|
| 16 |
"img": img,
|
| 17 |
"patch_embs": patch_embs,
|
| 18 |
"patch_embs_norm": patch_embs_norm,
|
| 19 |
"grid_size": rows,
|
| 20 |
-
"alpha":
|
| 21 |
"overlay_img":result_img,
|
| 22 |
}
|
| 23 |
|
|
@@ -49,7 +49,7 @@ def store_patch(evt, state):
|
|
| 49 |
return state
|
| 50 |
|
| 51 |
|
| 52 |
-
def reload_overlay(evt: gr.SelectData,state):
|
| 53 |
if state is None:
|
| 54 |
return None
|
| 55 |
store_patch(evt, state)
|
|
@@ -57,7 +57,6 @@ def reload_overlay(evt: gr.SelectData,state):
|
|
| 57 |
img = state["img"]
|
| 58 |
patch_embs = state["patch_embs"]
|
| 59 |
patch_embs_norm = state["patch_embs_norm"]
|
| 60 |
-
alpha = state["alpha"]
|
| 61 |
sim_map = compute_patch_similarity(patch_embs, patch_embs_norm, row, col)
|
| 62 |
result_img = overlay_similarity(img, sim_map, alpha=alpha, cmap="hot")
|
| 63 |
return result_img
|
|
@@ -125,14 +124,15 @@ with gr.Blocks() as demo:
|
|
| 125 |
|
| 126 |
img_input.change(
|
| 127 |
fn=init_states,
|
| 128 |
-
inputs=[img_input],
|
| 129 |
outputs=[state_store, output_img]
|
| 130 |
)
|
| 131 |
|
| 132 |
output_img.select(
|
| 133 |
fn=reload_overlay,
|
| 134 |
-
inputs=[state_store],
|
| 135 |
outputs=[output_img]
|
| 136 |
)
|
| 137 |
|
|
|
|
| 138 |
demo.launch()
|
|
|
|
| 4 |
|
| 5 |
selected_patch = {"row": 0, "col": 0}
|
| 6 |
|
| 7 |
+
def init_states(img, alpha):
|
| 8 |
if img is None:
|
| 9 |
return gr.update(value=None), None
|
| 10 |
patch_embs, patch_embs_norm, rows, cols = get_patch_embeddings(img, ps=16, device=device)
|
| 11 |
|
| 12 |
sim_map = compute_patch_similarity(patch_embs, patch_embs_norm, 0, 0)
|
| 13 |
+
result_img = overlay_similarity(img, sim_map, alpha=alpha, cmap="hot")
|
| 14 |
|
| 15 |
state = {
|
| 16 |
"img": img,
|
| 17 |
"patch_embs": patch_embs,
|
| 18 |
"patch_embs_norm": patch_embs_norm,
|
| 19 |
"grid_size": rows,
|
| 20 |
+
"alpha": alpha,
|
| 21 |
"overlay_img":result_img,
|
| 22 |
}
|
| 23 |
|
|
|
|
| 49 |
return state
|
| 50 |
|
| 51 |
|
| 52 |
+
def reload_overlay(evt: gr.SelectData,state,alpha):
|
| 53 |
if state is None:
|
| 54 |
return None
|
| 55 |
store_patch(evt, state)
|
|
|
|
| 57 |
img = state["img"]
|
| 58 |
patch_embs = state["patch_embs"]
|
| 59 |
patch_embs_norm = state["patch_embs_norm"]
|
|
|
|
| 60 |
sim_map = compute_patch_similarity(patch_embs, patch_embs_norm, row, col)
|
| 61 |
result_img = overlay_similarity(img, sim_map, alpha=alpha, cmap="hot")
|
| 62 |
return result_img
|
|
|
|
| 124 |
|
| 125 |
img_input.change(
|
| 126 |
fn=init_states,
|
| 127 |
+
inputs=[img_input, alpha_slider],
|
| 128 |
outputs=[state_store, output_img]
|
| 129 |
)
|
| 130 |
|
| 131 |
output_img.select(
|
| 132 |
fn=reload_overlay,
|
| 133 |
+
inputs=[state_store, alpha_slider],
|
| 134 |
outputs=[output_img]
|
| 135 |
)
|
| 136 |
|
| 137 |
+
|
| 138 |
demo.launch()
|