NeerjaK commited on
Commit
6800efc
·
verified ·
1 Parent(s): 1ed363b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -4,20 +4,20 @@ from utils import get_patch_embeddings, compute_patch_similarity, overlay_simila
4
 
5
  selected_patch = {"row": 0, "col": 0}
6
 
7
- def init_states(img):
8
  if img is None:
9
  return gr.update(value=None), None
10
  patch_embs, patch_embs_norm, rows, cols = get_patch_embeddings(img, ps=16, device=device)
11
 
12
  sim_map = compute_patch_similarity(patch_embs, patch_embs_norm, 0, 0)
13
- result_img = overlay_similarity(img, sim_map, alpha=0.6, cmap="hot")
14
 
15
  state = {
16
  "img": img,
17
  "patch_embs": patch_embs,
18
  "patch_embs_norm": patch_embs_norm,
19
  "grid_size": rows,
20
- "alpha": 0.6,
21
  "overlay_img":result_img,
22
  }
23
 
@@ -49,7 +49,7 @@ def store_patch(evt, state):
49
  return state
50
 
51
 
52
- def reload_overlay(evt: gr.SelectData,state):
53
  if state is None:
54
  return None
55
  store_patch(evt, state)
@@ -57,7 +57,6 @@ def reload_overlay(evt: gr.SelectData,state):
57
  img = state["img"]
58
  patch_embs = state["patch_embs"]
59
  patch_embs_norm = state["patch_embs_norm"]
60
- alpha = state["alpha"]
61
  sim_map = compute_patch_similarity(patch_embs, patch_embs_norm, row, col)
62
  result_img = overlay_similarity(img, sim_map, alpha=alpha, cmap="hot")
63
  return result_img
@@ -125,14 +124,15 @@ with gr.Blocks() as demo:
125
 
126
  img_input.change(
127
  fn=init_states,
128
- inputs=[img_input],
129
  outputs=[state_store, output_img]
130
  )
131
 
132
  output_img.select(
133
  fn=reload_overlay,
134
- inputs=[state_store],
135
  outputs=[output_img]
136
  )
137
 
 
138
  demo.launch()
 
4
 
5
  selected_patch = {"row": 0, "col": 0}
6
 
7
+ def init_states(img, alpha):
8
  if img is None:
9
  return gr.update(value=None), None
10
  patch_embs, patch_embs_norm, rows, cols = get_patch_embeddings(img, ps=16, device=device)
11
 
12
  sim_map = compute_patch_similarity(patch_embs, patch_embs_norm, 0, 0)
13
+ result_img = overlay_similarity(img, sim_map, alpha=alpha, cmap="hot")
14
 
15
  state = {
16
  "img": img,
17
  "patch_embs": patch_embs,
18
  "patch_embs_norm": patch_embs_norm,
19
  "grid_size": rows,
20
+ "alpha": alpha,
21
  "overlay_img":result_img,
22
  }
23
 
 
49
  return state
50
 
51
 
52
+ def reload_overlay(evt: gr.SelectData,state,alpha):
53
  if state is None:
54
  return None
55
  store_patch(evt, state)
 
57
  img = state["img"]
58
  patch_embs = state["patch_embs"]
59
  patch_embs_norm = state["patch_embs_norm"]
 
60
  sim_map = compute_patch_similarity(patch_embs, patch_embs_norm, row, col)
61
  result_img = overlay_similarity(img, sim_map, alpha=alpha, cmap="hot")
62
  return result_img
 
124
 
125
  img_input.change(
126
  fn=init_states,
127
+ inputs=[img_input, alpha_slider],
128
  outputs=[state_store, output_img]
129
  )
130
 
131
  output_img.select(
132
  fn=reload_overlay,
133
+ inputs=[state_store, alpha_slider],
134
  outputs=[output_img]
135
  )
136
 
137
+
138
  demo.launch()