shunk031 commited on
Commit
9a0312e
·
1 Parent(s): 94fe6fb

deploy: fb8481effdf5a0b23ff86fad414906046d7620bd

Browse files
Files changed (1) hide show
  1. layout-unreadability.py +75 -15
layout-unreadability.py CHANGED
@@ -15,7 +15,32 @@ Computes the non-flatness of regions that text elements are solely put on, refer
15
  """
16
 
17
  _KWARGS_DESCRIPTION = """\
18
- FIXME
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  """
20
 
21
  _CITATION = """\
@@ -35,8 +60,8 @@ ReqType = Literal["pil2cv", "cv2pil"]
35
  class LayoutUnreadability(evaluate.Metric):
36
  def __init__(
37
  self,
38
- canvas_width: int,
39
- canvas_height: int,
40
  text_label_index: int = 1,
41
  decoration_label_index: int = 3,
42
  **kwargs,
@@ -98,6 +123,8 @@ class LayoutUnreadability(evaluate.Metric):
98
  def load_image_canvas(
99
  self,
100
  filepath: Union[os.PathLike, List[os.PathLike]],
 
 
101
  ) -> npt.NDArray[np.float64]:
102
  if isinstance(filepath, list):
103
  assert len(filepath) == 1, filepath
@@ -105,8 +132,8 @@ class LayoutUnreadability(evaluate.Metric):
105
 
106
  canvas_pil = Image.open(filepath) # type: ignore
107
  canvas_pil = canvas_pil.convert("RGB") # type: ignore
108
- if canvas_pil.size != (self.canvas_width, self.canvas_height):
109
- canvas_pil = canvas_pil.resize((self.canvas_width, self.canvas_height)) # type: ignore
110
 
111
  canvas_pil = self.img_to_g_xy(canvas_pil)
112
  assert isinstance(canvas_pil, PilImage)
@@ -115,20 +142,24 @@ class LayoutUnreadability(evaluate.Metric):
115
  return canvas_arr
116
 
117
  def get_rid_of_invalid(
118
- self, predictions: npt.NDArray[np.float64], gold_labels: npt.NDArray[np.int64]
 
 
 
 
119
  ) -> npt.NDArray[np.int64]:
120
  assert len(predictions) == len(gold_labels)
121
 
122
- w = self.canvas_width / 100
123
- h = self.canvas_height / 100
124
 
125
  for i, prediction in enumerate(predictions):
126
  for j, b in enumerate(prediction):
127
  xl, yl, xr, yr = b
128
  xl = max(0, xl)
129
  yl = max(0, yl)
130
- xr = min(self.canvas_width, xr)
131
- yr = min(self.canvas_height, yr)
132
  if abs((xr - xl) * (yr - yl)) < w * h * 10:
133
  if gold_labels[i, j]:
134
  gold_labels[i, j] = 0
@@ -140,15 +171,42 @@ class LayoutUnreadability(evaluate.Metric):
140
  predictions: Union[npt.NDArray[np.float64], List[List[float]]],
141
  gold_labels: Union[npt.NDArray[np.int64], List[int]],
142
  image_canvases: List[os.PathLike],
 
 
 
 
143
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  predictions = np.array(predictions)
145
  gold_labels = np.array(gold_labels)
146
 
147
- predictions[:, :, ::2] *= self.canvas_width
148
- predictions[:, :, 1::2] *= self.canvas_height
149
 
150
  gold_labels = self.get_rid_of_invalid(
151
- predictions=predictions, gold_labels=gold_labels
 
 
 
152
  )
153
  score = 0.0
154
 
@@ -159,16 +217,18 @@ class LayoutUnreadability(evaluate.Metric):
159
  for prediction, gold_label, image_canvas in it:
160
  canvas_arr = self.load_image_canvas(
161
  image_canvas,
 
 
162
  )
163
  cal_mask = np.zeros_like(canvas_arr)
164
 
165
  prediction = np.array(prediction, dtype=int)
166
  gold_label = np.array(gold_label, dtype=int)
167
 
168
- is_text = (gold_label == self.text_label_index).reshape(-1)
169
  prediction_text = prediction[is_text]
170
 
171
- is_decoration = (gold_label == self.decoration_label_index).reshape(-1)
172
  prediction_deco = prediction[is_decoration]
173
 
174
  for mp in prediction_text:
 
15
  """
16
 
17
  _KWARGS_DESCRIPTION = """\
18
+ Args:
19
+ predictions (`list` of `list` of `float`): A list of lists of floats representing normalized `ltrb`-format bounding boxes.
20
+ gold_labels (`list` of `list` of `int`): A list of lists of integers representing class labels.
21
+ image_canvases (`list` of `str`): A list of file paths to canvas images (background images).
22
+ canvas_width (`int`, *optional*): Width of the canvas in pixels. Can be provided at initialization or during computation.
23
+ canvas_height (`int`, *optional*): Height of the canvas in pixels. Can be provided at initialization or during computation.
24
+ text_label_index (`int`, *optional*, defaults to 1): The label index for text elements.
25
+ decoration_label_index (`int`, *optional*, defaults to 3): The label index for decoration (underlay) elements.
26
+
27
+ Returns:
28
+ float: The unreadability score measuring the non-flatness of regions where text elements are placed. Computed using gradient analysis (Sobel operator) on the canvas image. Lower values indicate better readability (text on flatter/cleaner backgrounds).
29
+
30
+ Examples:
31
+ >>> import evaluate
32
+ >>> metric = evaluate.load("creative-graphic-design/layout-unreadability")
33
+ >>> predictions = [[[0.1, 0.1, 0.5, 0.3], [0.6, 0.6, 0.9, 0.8]]]
34
+ >>> gold_labels = [[1, 2]] # 1 is text, 2 is other element
35
+ >>> image_canvases = ["/path/to/canvas.png"]
36
+ >>> result = metric.compute(
37
+ ... predictions=predictions,
38
+ ... gold_labels=gold_labels,
39
+ ... image_canvases=image_canvases,
40
+ ... canvas_width=512,
41
+ ... canvas_height=512
42
+ ... )
43
+ >>> print(f"Unreadability score: {result:.4f}")
44
  """
45
 
46
  _CITATION = """\
 
60
  class LayoutUnreadability(evaluate.Metric):
61
  def __init__(
62
  self,
63
+ canvas_width: int | None = None,
64
+ canvas_height: int | None = None,
65
  text_label_index: int = 1,
66
  decoration_label_index: int = 3,
67
  **kwargs,
 
123
  def load_image_canvas(
124
  self,
125
  filepath: Union[os.PathLike, List[os.PathLike]],
126
+ canvas_width: int,
127
+ canvas_height: int,
128
  ) -> npt.NDArray[np.float64]:
129
  if isinstance(filepath, list):
130
  assert len(filepath) == 1, filepath
 
132
 
133
  canvas_pil = Image.open(filepath) # type: ignore
134
  canvas_pil = canvas_pil.convert("RGB") # type: ignore
135
+ if canvas_pil.size != (canvas_width, canvas_height):
136
+ canvas_pil = canvas_pil.resize((canvas_width, canvas_height)) # type: ignore
137
 
138
  canvas_pil = self.img_to_g_xy(canvas_pil)
139
  assert isinstance(canvas_pil, PilImage)
 
142
  return canvas_arr
143
 
144
  def get_rid_of_invalid(
145
+ self,
146
+ predictions: npt.NDArray[np.float64],
147
+ gold_labels: npt.NDArray[np.int64],
148
+ canvas_width: int,
149
+ canvas_height: int,
150
  ) -> npt.NDArray[np.int64]:
151
  assert len(predictions) == len(gold_labels)
152
 
153
+ w = canvas_width / 100
154
+ h = canvas_height / 100
155
 
156
  for i, prediction in enumerate(predictions):
157
  for j, b in enumerate(prediction):
158
  xl, yl, xr, yr = b
159
  xl = max(0, xl)
160
  yl = max(0, yl)
161
+ xr = min(canvas_width, xr)
162
+ yr = min(canvas_height, yr)
163
  if abs((xr - xl) * (yr - yl)) < w * h * 10:
164
  if gold_labels[i, j]:
165
  gold_labels[i, j] = 0
 
171
  predictions: Union[npt.NDArray[np.float64], List[List[float]]],
172
  gold_labels: Union[npt.NDArray[np.int64], List[int]],
173
  image_canvases: List[os.PathLike],
174
+ canvas_width: int | None = None,
175
+ canvas_height: int | None = None,
176
+ text_label_index: int | None = None,
177
+ decoration_label_index: int | None = None,
178
  ):
179
+ # パラメータの優先順位処理
180
+ canvas_width = canvas_width if canvas_width is not None else self.canvas_width
181
+ canvas_height = (
182
+ canvas_height if canvas_height is not None else self.canvas_height
183
+ )
184
+ text_label_index = (
185
+ text_label_index if text_label_index is not None else self.text_label_index
186
+ )
187
+ decoration_label_index = (
188
+ decoration_label_index
189
+ if decoration_label_index is not None
190
+ else self.decoration_label_index
191
+ )
192
+
193
+ if canvas_width is None or canvas_height is None:
194
+ raise ValueError(
195
+ "canvas_width and canvas_height must be provided either "
196
+ "at initialization or during computation"
197
+ )
198
+
199
  predictions = np.array(predictions)
200
  gold_labels = np.array(gold_labels)
201
 
202
+ predictions[:, :, ::2] *= canvas_width
203
+ predictions[:, :, 1::2] *= canvas_height
204
 
205
  gold_labels = self.get_rid_of_invalid(
206
+ predictions=predictions,
207
+ gold_labels=gold_labels,
208
+ canvas_width=canvas_width,
209
+ canvas_height=canvas_height,
210
  )
211
  score = 0.0
212
 
 
217
  for prediction, gold_label, image_canvas in it:
218
  canvas_arr = self.load_image_canvas(
219
  image_canvas,
220
+ canvas_width,
221
+ canvas_height,
222
  )
223
  cal_mask = np.zeros_like(canvas_arr)
224
 
225
  prediction = np.array(prediction, dtype=int)
226
  gold_label = np.array(gold_label, dtype=int)
227
 
228
+ is_text = (gold_label == text_label_index).reshape(-1)
229
  prediction_text = prediction[is_text]
230
 
231
+ is_decoration = (gold_label == decoration_label_index).reshape(-1)
232
  prediction_deco = prediction[is_decoration]
233
 
234
  for mp in prediction_text: