MikkoLipsanen commited on
Commit
c99fa8d
·
verified ·
1 Parent(s): f097c5b

Update segmentation to use rfdetr model

Browse files
Files changed (1) hide show
  1. segment_image.py +456 -316
segment_image.py CHANGED
@@ -1,344 +1,484 @@
1
- from huggingface_hub import hf_hub_download
2
  from shapely.validation import make_valid
3
  from shapely.geometry import Polygon
4
- from ultralytics import YOLO
5
- from PIL import Image
6
  import numpy as np
 
7
  import os
8
 
9
- from reading_order import OrderPolygons
 
 
 
 
 
 
 
 
10
 
11
  class SegmentImage:
12
- """Class for segmenting document image regions and text lines."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def __init__(self,
14
- line_model_path,
15
- device,
16
- line_iou=0.5,
17
- region_iou=0.5,
18
- line_overlap=0.5,
19
- line_nms_iou=0.7,
20
- region_nms_iou=0.3,
21
- line_conf_threshold=0.25,
22
- region_conf_threshold=0.25,
23
- region_model_path=None,
24
- order_regions=True,
25
- region_half_precision=False,
26
- line_half_precision=False):
27
-
28
- # Path to text line detection model
29
- self.line_model_path = line_model_path
30
- # Path to text region detection model
31
- self.region_model_path = region_model_path
32
- # Defines the IoU threshold used in the non-maximum suppression (NMS) process to
33
- # determine which prediction boxes should be suppressed or discarded based on their overlap with other boxes
34
- self.line_nms_iou = line_nms_iou
35
- self.region_nms_iou = region_nms_iou
36
- # Defines the IoU threshold for text lines
37
  self.line_iou = line_iou
38
- # Defines the IoU threshold for text regions
39
  self.region_iou = region_iou
40
- # Defines the extent of line polygon overlap used for merging the polygons
41
- self.line_overlap = line_overlap
42
- # Defines confidence threshold for line detection
43
- self.line_conf_threshold = line_conf_threshold
44
- # Defines confidence threshold for region detection
45
- self.region_conf_threshold = region_conf_threshold
46
- # Defines the device to be used ('cpu', gpu '0', gpu '1' etc.)
47
- self.device = device
48
- # Defines whether a reading order is also estimated for the region detections
49
- self.order_regions = order_regions
50
- # Defines whether half precision (FP16) is used by the region and line prediction models
51
- self.region_half_precision = region_half_precision
52
- self.line_half_precision = line_half_precision
53
- self.order_poly = OrderPolygons()
54
- # Initialize segmentation model(s)
55
- self.line_model = self.init_line_model()
56
- if self.region_model_path:
57
- self.region_model = self.init_region_model()
58
-
59
- def init_line_model(self):
60
- """Function for initializing the line detection model."""
61
- try:
62
- # Load the trained line detection model
63
- cached_model_path = hf_hub_download(repo_id=self.line_model_path, filename="lines_20240827.pt")
64
- line_model = YOLO(cached_model_path)
65
- return line_model
66
- except Exception as e:
67
- print('Failed to load the line detection model: %s' % e)
68
 
69
- def init_region_model(self):
70
- """Function for initializing the region detection model."""
 
71
  try:
72
- # Load the trained line detection model
73
- cached_model_path = hf_hub_download(repo_id=self.region_model_path, filename="tuomiokirja_regions_04122023.pt")
74
- region_model = YOLO(cached_model_path)
75
- return region_model
76
  except Exception as e:
77
- print('Failed to load the region detection model: %s' % e)
78
 
79
- def get_region_ids(self, coords, max_min, classes, names, box_confs, img_shape):
80
- """Function for creating unique id for each detected region."""
81
- n = min(len(classes), len(coords))
82
- res = []
83
- for i in range(n):
84
- # Creates a simple index-based id for each region
85
- region_id = str(i)
86
- # Extracts region name corresponding to the index
87
- region_type = names[classes[i]]
88
- poly_dict = {'coords': coords[i],
89
- 'max_min': max_min[i],
90
- 'class': str(classes[i]),
91
- 'name': region_type,
92
- 'conf': box_confs[i],
93
- 'id': region_id,
94
- 'img_shape': img_shape}
95
- res.append(poly_dict)
96
- return res
97
-
98
- def get_max_min(self, polygons):
99
- """Creates an array with the minimum and maximum
100
- x and y values of the input polygons."""
101
- n_rows = len(polygons)
102
- xy_array = np.zeros([n_rows, 4])
103
- for i, poly in enumerate(polygons):
104
- x = [point[0] for point in poly]
105
- y = [point[1] for point in poly]
106
- if x:
107
- xy_array[i,0] = max(x)
108
- xy_array[i,1] = min(x)
109
- if y:
110
- xy_array[i,2] = max(y)
111
- xy_array[i,3] = min(y)
112
- return xy_array
113
-
114
- def validate_polygon(self, polygon):
115
- """"Function for testing and correcting the validity of polygons."""
116
  if len(polygon) > 2:
117
- polygon = Polygon(polygon)
118
- if not polygon.is_valid:
119
- polygon = make_valid(polygon)
120
- return polygon
 
 
 
 
121
  else:
122
  return None
123
 
124
- def get_iou(self, poly1, poly2):
125
- """Function for calculating Intersection over Union (IoU) values."""
126
- # If the polygons don't intersect, IoU is 0
127
- iou = 0
128
- poly1 = self.validate_polygon(poly1)
129
- poly2 = self.validate_polygon(poly2)
130
-
131
- if poly1 and poly2:
132
- if poly1.intersects(poly2):
133
- # Calculates intersection of the 2 polygons
134
- intersect = poly1.intersection(poly2).area
135
- # Calculates union of the 2 polygons
136
- uni = poly1.union(poly2)
137
- # Calculates intersection over union
138
- iou = intersect / uni.area
139
- return iou
140
-
141
- def merge_polygons(self, polygons, iou_threshold, overlap_threshold = None):
142
- """Merges polygons that have an IoU value
143
- above the given threshold."""
144
- new_polygons = []
145
- dropped = set()
146
- # Loops over all input polygons and merges them if the
147
- # IoU value is over the given threshold
148
- for i in range(0, len(polygons)):
149
- poly1 = self.validate_polygon(polygons[i])
150
- merged = None
151
- for j in range(i+1, len(polygons)):
152
- poly2 = self.validate_polygon(polygons[j])
153
- if poly1 and poly2:
154
- if poly1.intersects(poly2):
155
- overlap = False
156
- intersect = poly1.intersection(poly2)
157
- uni = poly1.union(poly2)
158
- # Calculates intersection over union
159
- iou = intersect.area / uni.area
160
- if overlap_threshold:
161
- overlap = intersect.area > (overlap_threshold * min(poly1.area, poly2.area))
162
- if (iou > iou_threshold) or overlap:
163
- if merged:
164
- # If there are multiple overlapping polygons
165
- # with IoU over the threshold, they are all merged together
166
- merged = uni.union(merged)
167
- dropped.add(j)
168
- else:
169
- merged = uni
170
- # Polygons that are merged together are dropped from
171
- # the list
172
- dropped.add(i)
173
- dropped.add(j)
174
- if merged:
175
- if merged.geom_type in ['GeometryCollection','MultiPolygon']:
176
- for geom in merged.geoms:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  if geom.geom_type == 'Polygon':
178
- new_polygons.append(list(geom.exterior.coords))
179
- elif merged.geom_type == 'Polygon':
180
- new_polygons.append(list(merged.exterior.coords))
181
- res = [i for j, i in enumerate(polygons) if j not in dropped]
182
- res += new_polygons
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
- return res
185
-
186
- def get_region_preds(self, img):
187
- """Function for predicting text region coordinates."""
188
- results = self.region_model.predict(source=img,
189
- device=self.device,
190
- conf=self.region_conf_threshold,
191
- half=bool(self.region_half_precision),
192
- iou=self.region_nms_iou)
193
- results = results[0].cpu()
194
- if results.masks:
195
- # Extracts detected region polygons
196
- coords = results.masks.xy
197
- # Merge overlapping polygons
198
- coords = self.merge_polygons(coords, self.region_iou)
199
- # Maximum and minimum x and y axis values for detected polygons used for ordering the polygons
200
- max_min = self.get_max_min(coords).tolist()
201
- # Gets a list of the predicted class labels for detected regions
202
- classes = results.boxes.cls.tolist()
203
- # A dictionary with class ids as keys and class names as values
204
- names = results.names
205
- # Confidence values for detections
206
- box_confs = results.boxes.conf.tolist()
207
- # A tuple containing the shape of the original image
208
- img_shape = results.orig_shape
209
- res = self.get_region_ids(list(coords), max_min, classes, names, box_confs, img_shape)
210
- return res
211
  else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  return None
213
 
 
 
 
214
 
215
- def get_line_preds(self, img):
216
- """Function for predicting text line coordinates."""
217
- results = self.line_model.predict(source=img,
218
- device=self.device,
219
- conf=self.line_conf_threshold,
220
- half=bool(self.line_half_precision),
221
- iou=self.line_nms_iou)
222
- results = results[0].cpu()
223
- if results.masks:
224
- # Detected text line polygons
225
- coords = results.masks.xy
226
- # Merge overlapping polygons
227
- coords = self.merge_polygons(coords, self.line_iou, self.line_overlap)
228
- # Maximum and minimum x and y axis values for detected polygons
229
- max_min = self.get_max_min(coords).tolist()
230
- # Confidence values for detections
231
- box_confs = results.boxes.conf.tolist()
232
- res_dict = {'coords': list(coords), 'max_min': max_min, 'confs': box_confs}
233
- return res_dict
234
- else:
 
235
  return None
236
 
237
- def get_dist(self, line_polygon, regions):
238
- """Function for finding the closest region to the text line."""
239
- dist, reg_id = 1000000, None
240
- line_polygon = self.validate_polygon(line_polygon)
241
-
242
- if line_polygon:
243
- for region in regions:
244
- # Calculates dictance between line and regions polygons
245
- region_polygon = self.validate_polygon(region['coords'])
246
- if region_polygon:
247
- line_reg_dist = line_polygon.distance(region_polygon)
248
- if line_reg_dist < dist:
249
- dist = line_reg_dist
250
- reg_id = region['id']
251
- return reg_id
252
-
253
- def get_line_regions(self, lines, regions):
254
- """Function for connecting each text line to one region."""
255
- lines_list = []
256
- for i in range(len(lines['coords'])):
257
- iou, reg_id, conf = 0, '', 0.0
258
- max_min = [0.0, 0.0, 0.0, 0.0]
259
- polygon = lines['coords'][i]
260
- for region in regions:
261
- line_reg_iou = self.get_iou(polygon, region['coords'])
262
- if line_reg_iou > iou:
263
- iou = line_reg_iou
264
- reg_id = region['id']
265
- # If line polygon does not intersect with any region, a distance metric is used for defining
266
- # the region that the line belongs to
267
- if iou == 0:
268
- reg_id = self.get_dist(polygon, regions)
269
-
270
- if (len(lines['max_min']) - 1) >= i:
271
- max_min = lines['max_min'][i]
272
-
273
- if (len(lines['confs']) - 1) >= i:
274
- conf = lines['confs'][i]
275
-
276
- new_line = {'polygon': polygon, 'reg_id': reg_id, 'max_min': max_min, 'conf': conf}
277
- lines_list.append(new_line)
278
- return lines_list
279
-
280
- def order_regions_lines(self, lines, regions):
281
- """Function for ordering line predictions inside each region."""
282
- regions_with_rows = []
283
- region_max_mins = []
284
- for i, region in enumerate(regions):
285
- line_max_mins = []
286
- line_confs = []
287
- line_polygons = []
288
- for line in lines:
289
- if line['reg_id'] == region['id']:
290
- line_max_mins.append(line['max_min'])
291
- line_confs.append(line['conf'])
292
- line_polygons.append(line['polygon'])
293
- if line_polygons:
294
- # If one or more lines are connected to a region, line order inside the region is defined
295
- # and the predicted text lines are joined in the same python dict
296
- line_order = self.order_poly.order(line_max_mins)
297
- line_polygons = [line_polygons[i] for i in line_order]
298
- line_confs = [line_confs[i] for i in line_order]
299
- new_region = {'region_coords': region['coords'],
300
- 'region_name': region['name'],
301
- 'lines': line_polygons,
302
- 'line_confs': line_confs,
303
- 'region_conf': region['conf'],
304
- 'img_shape': region['img_shape']}
305
- region_max_mins.append(region['max_min'])
306
- regions_with_rows.append(new_region)
307
- else:
308
- continue
309
- # Creates an ordering of the detected regions based on their polygon coordinates
310
- if self.order_regions:
311
- region_order = self.order_poly.order(region_max_mins)
312
- regions_with_rows = [regions_with_rows[i] for i in region_order]
313
-
314
- return regions_with_rows
315
-
316
- def get_default_region(self, image):
317
- """Function for creating a default region if no regions are detected."""
318
- w, h = image.size
319
- region = {'coords': [[0.0, 0.0], [w, 0.0], [w, h], [0.0, h]],
320
- 'max_min': [w, 0.0, h, 0.0],
321
- 'class': '0',
322
- 'name': "paragraph",
323
- 'conf': 0.0,
324
- 'id': '0',
325
- 'img_shape': (h, w)}
326
- return [region]
327
-
328
- def get_segmentation(self, image):
329
- """Segment input image into ordered text lines or ordered text regions and text lines."""
330
- line_preds = self.get_line_preds(image)
331
- if line_preds:
332
- # If region detection model is defined, text regions and text lines are detected
333
- region_preds = self.get_region_preds(image)
334
- if not region_preds:
335
- region_preds = self.get_default_region(image)
336
- print(f'No regions detected from image {image}')
337
- lines_with_regions = self.get_line_regions(line_preds, region_preds)
338
- ordered_regions = self.order_regions_lines(lines_with_regions, region_preds)
339
- return ordered_regions
340
  else:
341
- print(f'No text lines detected from image {image}')
342
- return None
 
 
 
 
 
 
 
 
 
 
 
 
343
 
344
-
 
1
+ from typing import List, Tuple, Optional, Dict, Any
2
  from shapely.validation import make_valid
3
  from shapely.geometry import Polygon
4
+ from rfdetr import RFDETRSegPreview
5
+ from collections import defaultdict
6
  import numpy as np
7
+ import cv2
8
  import os
9
 
10
+ from image_processing import (
11
+ load_with_torchvision,
12
+ preprocess_resize_torch_transform,
13
+ upscale_bbox,
14
+ upscale_mask_opencv,
15
+ crop_line
16
+ )
17
+
18
+ from utils import get_default_region, get_line_regions, order_regions_lines
19
 
20
  class SegmentImage:
21
+ """
22
+ Document image segmentation for detecting text regions and lines.
23
+
24
+ Uses an RFDETR segmentation model to detect and extract text regions and lines
25
+ from document images. Includes polygon merging, validation, and ordering.
26
+
27
+ Args:
28
+ model_path: Path to the RFDETR segmentation model weights
29
+ max_size: Maximum dimension (height or width) for image preprocessing (default: 768)
30
+ confidence_threshold: Minimum confidence score for detections (default: 0.15, range: 0-1)
31
+ line_percentage_threshold: Minimum polygon area as fraction of image area for lines
32
+ (default: 7e-05, i.e., 0.007% of image)
33
+ region_percentage_threshold: Minimum polygon area as fraction of image area for regions
34
+ (default: 7e-05, i.e., 0.007% of image)
35
+ line_iou: IoU threshold for merging overlapping line polygons (default: 0.3, range: 0-1)
36
+ region_iou: IoU threshold for merging overlapping region polygons (default: 0.3, range: 0-1)
37
+ line_overlap_threshold: Area overlap ratio threshold for merging lines (default: 0.5, range: 0-1)
38
+ region_overlap_threshold: Area overlap ratio threshold for merging regions (default: 0.5, range: 0-1)
39
+ class_id_region: Class ID constant for identifying regions in segmentation model output
40
+ class_id_line: Class ID constant for identifying lines in segmentation model output
41
+ min_polygon_points: Minimum number of points to form a valid polygon
42
+ """
43
  def __init__(self,
44
+ model_path: str,
45
+ max_size: int = 768,
46
+ confidence_threshold: float = 0.15,
47
+ line_percentage_threshold: float = 7e-05,
48
+ region_percentage_threshold: float = 7e-05,
49
+ line_iou: float = 0.3,
50
+ region_iou: float = 0.3,
51
+ line_overlap_threshold: float = 0.5,
52
+ region_overlap_threshold: float = 0.5,
53
+ class_id_region: int = 1,
54
+ class_id_line: int = 2,
55
+ min_polygon_points: int = 3):
56
+
57
+ self.model_path = model_path
58
+ self.max_size = max_size
59
+ self.confidence_threshold = confidence_threshold
60
+ self.line_percentage_threshold = line_percentage_threshold
61
+ self.region_percentage_threshold = region_percentage_threshold
 
 
 
 
 
62
  self.line_iou = line_iou
 
63
  self.region_iou = region_iou
64
+ self.line_overlap_threshold = line_overlap_threshold
65
+ self.region_overlap_threshold = region_overlap_threshold
66
+ self.class_id_region = class_id_region
67
+ self.class_id_line = class_id_line
68
+ self.min_polygon_points = min_polygon_points
69
+
70
+ # Validate model path
71
+ if not os.path.exists(self.model_path):
72
+ raise FileNotFoundError(f"Model path does not exist: {self.model_path}")
73
+
74
+ self.init_model()
75
+
76
+ def init_model(self) -> None:
77
+ """
78
+ Load and optimize an RFDETR segmentation model for inference.
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
+ Raises:
81
+ Exception: If model initialization fails
82
+ """
83
  try:
84
+ self.model = RFDETRSegPreview(pretrain_weights=self.model_path)
85
+ self.model.optimize_for_inference()
86
+ print(f"✓ Segmentation model initialized successfully")
 
87
  except Exception as e:
88
+ raise RuntimeError(f'Failed to initialize segmentation model: {e}')
89
 
90
+ def validate_polygon(self, polygon: np.ndarray) -> Optional[Polygon]:
91
+ """
92
+ Test and correct the validity of a polygon using Shapely.
93
+
94
+ Converts numpy array to Shapely Polygon, validates it, and attempts
95
+ to fix invalid geometries using make_valid().
96
+
97
+ Args:
98
+ polygon: Array of polygon coordinates with shape (N, 2)
99
+
100
+ Returns:
101
+ Valid Shapely Polygon object, or None if polygon has fewer than 3 points
102
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  if len(polygon) > 2:
104
+ try:
105
+ shapely_polygon = Polygon(polygon)
106
+ if not shapely_polygon.is_valid:
107
+ shapely_polygon = make_valid(shapely_polygon)
108
+ return shapely_polygon
109
+ except Exception as e:
110
+ print(f"Warning: Failed to validate polygon: {e}")
111
+ return None
112
  else:
113
  return None
114
 
115
+ def merge_polygons(self,
116
+ polygons: List[np.ndarray],
117
+ polygon_iou: float,
118
+ overlap_threshold: float) -> Tuple[List[np.ndarray], List[int]]:
119
+ """
120
+ Merge overlapping polygons using connected components (union-find algorithm).
121
+
122
+ Uses IoU (Intersection over Union) and area overlap ratio to determine which
123
+ polygons should be merged. Implements union-find to group connected components
124
+ of overlapping polygons, then merges each component into a single polygon.
125
+
126
+ Args:
127
+ polygons: List of polygon coordinate arrays, each with shape (N, 2)
128
+ polygon_iou: IoU threshold for merging (0-1)
129
+ overlap_threshold: Minimum area overlap ratio for merging (0-1)
130
+
131
+ Returns:
132
+ Tuple of:
133
+ - merged_polygons: List of merged polygon coordinate arrays
134
+ - polygon_mapping: List mapping each input polygon index to its output
135
+ polygon index (-1 if invalid/skipped)
136
+ """
137
+ n = len(polygons)
138
+ if n == 0:
139
+ return [], []
140
+
141
+ # Validate all polygons
142
+ validated = [self.validate_polygon(p) for p in polygons]
143
+
144
+ # Build adjacency graph of overlapping polygons
145
+ parent = list(range(n))
146
+
147
+ def find(x: int) -> int:
148
+ """Find root of element x with path compression."""
149
+ if parent[x] != x:
150
+ parent[x] = find(parent[x])
151
+ return parent[x]
152
+
153
+ def union(x: int, y: int) -> None:
154
+ """Union two sets containing x and y."""
155
+ px, py = find(x), find(y)
156
+ if px != py:
157
+ parent[px] = py
158
+
159
+ # Build adjacency graph by checking all pairs for overlap
160
+ for i in range(n):
161
+ poly1 = validated[i]
162
+ if not poly1:
163
+ continue
164
+
165
+ for j in range(i + 1, n):
166
+ poly2 = validated[j]
167
+ if not poly2 or not poly1.intersects(poly2):
168
+ continue
169
+
170
+ # Calculate intersection and union for IoU
171
+ intersection = poly1.intersection(poly2)
172
+ union_geom = poly1.union(poly2)
173
+ iou = intersection.area / union_geom.area if union_geom.area > 0 else 0
174
+
175
+ # Check merge criteria
176
+ should_merge = iou > polygon_iou
177
+
178
+ # If IoU threshold not met, check area overlap ratio
179
+ if not should_merge and overlap_threshold > 0:
180
+ smaller_area = min(poly1.area, poly2.area)
181
+ overlap_ratio = intersection.area / smaller_area if smaller_area > 0 else 0
182
+ should_merge = overlap_ratio > overlap_threshold
183
+
184
+ # Merge polygons by updating union-find structure
185
+ if should_merge:
186
+ union(i, j)
187
+
188
+ # Group polygons by their connected component
189
+ components = defaultdict(list)
190
+ for i in range(n):
191
+ if validated[i]:
192
+ root = find(i)
193
+ components[root].append(i)
194
+
195
+ # Merge each connected component
196
+ merged_polygons = []
197
+ polygon_mapping = [-1] * n # -1 indicates invalid/unmapped polygon
198
+
199
+ for root, indices in components.items():
200
+ output_idx = len(merged_polygons)
201
+
202
+ if len(indices) == 1:
203
+ # Single polygon, no merging needed
204
+ idx = indices[0]
205
+ merged_polygons.append(polygons[idx])
206
+ polygon_mapping[idx] = output_idx
207
+
208
+ else:
209
+ # Merge all polygons in this component using Shapely union
210
+ merged = validated[indices[0]]
211
+ for idx in indices[1:]:
212
+ merged = merged.union(validated[idx])
213
+
214
+ # Extract polygon coordinates from merged geometry
215
+ if merged.geom_type == 'Polygon':
216
+ # Single polygon result
217
+ merged_polygons.append(
218
+ np.array(merged.exterior.coords).astype(np.int32)
219
+ )
220
+ for idx in indices:
221
+ polygon_mapping[idx] = output_idx
222
+
223
+ elif merged.geom_type in ['MultiPolygon', 'GeometryCollection']:
224
+ # Multiple polygons resulted from merge (e.g., touching at single point)
225
+ for geom in merged.geoms:
226
  if geom.geom_type == 'Polygon':
227
+ merged_polygons.append(
228
+ np.array(geom.exterior.coords).astype(np.int32)
229
+ )
230
+ # Map all source polygons to first output polygon
231
+ for idx in indices:
232
+ polygon_mapping[idx] = output_idx
233
+
234
+ return merged_polygons, polygon_mapping
235
+
236
+ def calculate_polygon_area(self, vertices: np.ndarray) -> float:
237
+ """
238
+ Calculate polygon area using the Shoelace formula (surveyor's formula).
239
+
240
+ Computes area using coordinate cross products. Works for simple polygons
241
+ (non-self-intersecting) regardless of vertex ordering.
242
+
243
+ Args:
244
+ vertices: Array of polygon coordinates with shape (N, 2)
245
+
246
+ Returns:
247
+ Area of the polygon in square pixels
248
+ """
249
+ x = vertices[:, 0]
250
+ y = vertices[:, 1]
251
+ # Shoelace formula implementation using array operations
252
+ area = 0.5 * np.abs(np.sum(x[:-1] * y[1:]) - np.sum(y[:-1] * x[1:]) + x[-1] * y[0] - y[-1] * x[0])
253
+ return area
254
+
255
+ def mask_to_polygon_cv2(self,
256
+ mask: np.ndarray,
257
+ original_shape: Tuple[int, int]) -> Tuple[List[np.ndarray], np.ndarray]:
258
+ """
259
+ Convert binary segmentation mask to polygon coordinates using OpenCV contours.
260
 
261
+ Extracts contours from mask, converts them to polygons, and scales coordinates
262
+ back to original image dimensions. Also calculates area percentages for filtering.
263
+
264
+ Args:
265
+ mask: Binary mask as numpy array (bool or uint8, 0-255)
266
+ original_shape: Tuple of (height, width) of original image
267
+
268
+ Returns:
269
+ Tuple of:
270
+ - scaled_polygons: List of polygon coordinate arrays scaled to original size
271
+ - area_percentages: Array of polygon areas as fraction of mask size
272
+ """
273
+ # Ensure mask is uint8
274
+ if mask.dtype == bool:
275
+ mask_uint8 = mask.astype(np.uint8) * 255
 
 
 
 
 
 
 
 
 
 
 
 
276
  else:
277
+ mask_uint8 = mask.astype(np.uint8)
278
+
279
+ # Find external contours (only outer boundaries)
280
+ contours, _ = cv2.findContours(
281
+ mask_uint8,
282
+ cv2.RETR_EXTERNAL,
283
+ cv2.CHAIN_APPROX_SIMPLE
284
+ )
285
+
286
+ # Convert contours to polygons (filter out degenerate contours)
287
+ polygons = [
288
+ contour.squeeze()
289
+ for contour in contours
290
+ if len(contour) >= self.min_polygon_points
291
+ ]
292
+
293
+ # Calculate scaling factors from mask to original image
294
+ orig_height, orig_width = original_shape
295
+ mask_height, mask_width = mask.shape[:2]
296
+ scale_x = orig_width / mask_width
297
+ scale_y = orig_height / mask_height
298
+
299
+ # Scale polygons and calculate areas
300
+ scaled_polygons = []
301
+ area_percentages = []
302
+ mask_area = mask_height * mask_width
303
+
304
+ for poly in polygons:
305
+ # Calculate area on mask coordinates (before scaling)
306
+ area = self.calculate_polygon_area(
307
+ poly if len(poly.shape) > 1 else poly.reshape(1, -1)
308
+ )
309
+ area_percentage = area / mask_area if mask_area > 0 else 0
310
+ area_percentages.append(area_percentage)
311
+
312
+ # Scale polygon coordinates to original image size
313
+ if len(poly.shape) == 1: # Single point edge case
314
+ scaled_poly = np.round(poly * np.array([scale_x, scale_y])).astype(int)
315
+ else: # Normal case with multiple points
316
+ scaled_poly = np.round(poly * np.array([scale_x, scale_y])).astype(int)
317
+
318
+ scaled_polygons.append(scaled_poly)
319
+
320
+ return scaled_polygons, np.array(area_percentages)
321
+
322
+
323
+ def process_polygons(self,
324
+ poly_masks: np.ndarray,
325
+ image_shape: Tuple[int, int],
326
+ percentage_threshold: float,
327
+ overlap_threshold: float,
328
+ iou_threshold: float) -> Tuple[List[np.ndarray], List[Tuple[int, int, int, int]]]:
329
+ """
330
+ Extract polygons from segmentation masks, filter by area, and merge overlapping ones.
331
+
332
+ Converts masks to polygons, filters out small detections based on area percentage,
333
+ and merges overlapping polygons based on IoU and overlap criteria.
334
+
335
+ Args:
336
+ poly_masks: Array of binary segmentation masks from model
337
+ image_shape: Tuple of (height, width) of original image
338
+ percentage_threshold: Minimum polygon area as fraction of image
339
+ overlap_threshold: Minimum overlap ratio for merging polygons
340
+ iou_threshold: Minimum IoU for merging polygons
341
+
342
+ Returns:
343
+ Tuple of:
344
+ - merged_polygons: List of polygon coordinate arrays
345
+ - merged_max_mins: List of bounding boxes as (xmin, ymin, xmax, ymax) tuples
346
+ """
347
+ all_polygons = []
348
+ all_area_percentages = []
349
+
350
+ # Extract polygons from all masks
351
+ for mask in poly_masks:
352
+ polygons, area_percentages = self.mask_to_polygon_cv2(
353
+ mask=mask,
354
+ original_shape=image_shape
355
+ )
356
+ all_polygons.extend(polygons)
357
+ all_area_percentages.extend(area_percentages)
358
+
359
+ all_area_percentages = np.array(all_area_percentages)
360
+
361
+ # Filter polygons by minimum area threshold
362
+ if len(all_area_percentages) == 0:
363
+ return [], []
364
+
365
+ valid_indices = np.where(all_area_percentages > percentage_threshold)[0]
366
+ filtered_polygons = [all_polygons[idx] for idx in valid_indices]
367
+
368
+ if not filtered_polygons:
369
+ return [], []
370
+
371
+ # Merge overlapping polygons
372
+ merged_polygons, _ = self.merge_polygons(
373
+ filtered_polygons,
374
+ iou_threshold,
375
+ overlap_threshold
376
+ )
377
+
378
+ # Calculate bounding boxes for merged polygons
379
+ merged_max_mins = []
380
+ for poly in merged_polygons:
381
+ if len(poly) > 0:
382
+ xmax, ymax = np.max(poly, axis=0)
383
+ xmin, ymin = np.min(poly, axis=0)
384
+ merged_max_mins.append((xmin, ymin, xmax, ymax))
385
+
386
+ return merged_polygons, merged_max_mins
387
+
388
+ def get_segmentation(self, image) -> Optional[List[Dict[str, Any]]]:
389
+ """
390
+ Detect and extract ordered text lines and regions from a document image.
391
+
392
+ Runs the segmentation model on the image, extracts line and region polygons,
393
+ merges overlapping detections, associates lines with regions, and orders them
394
+ for reading sequence.
395
+
396
+ Args:
397
+ image: PIL Image object in any mode (will be converted to RGB)
398
+
399
+ Returns:
400
+ List of ordered line dictionaries with region associations, or None if
401
+ no lines were detected. Each line dict contains coordinates, region ID,
402
+ and other metadata.
403
+ """
404
+ image_shape = (image.shape[0], image.shape[1])
405
+
406
+ # Preprocess image (resize for model input)
407
+ preprocessed_image = preprocess_resize_torch_transform(
408
+ image,
409
+ max_size=self.max_size
410
+ )
411
+
412
+ # Run segmentation model
413
+ try:
414
+ detections = self.model.predict(
415
+ preprocessed_image,
416
+ threshold=self.confidence_threshold
417
+ )
418
+ except Exception as e:
419
+ print(f"Error during segmentation prediction: {e}")
420
  return None
421
 
422
+ # Separate line and region masks by class ID
423
+ line_mask = detections.mask[detections.class_id == self.class_id_line]
424
+ region_mask = detections.mask[detections.class_id == self.class_id_region]
425
 
426
+ # Process line polygons
427
+ merged_line_polygons, merged_line_max_mins = self.process_polygons(
428
+ line_mask,
429
+ image_shape,
430
+ self.line_percentage_threshold,
431
+ self.line_overlap_threshold,
432
+ self.line_iou
433
+ )
434
+
435
+ # Process region polygons
436
+ merged_region_polygons, merged_region_max_mins = self.process_polygons(
437
+ region_mask,
438
+ image_shape,
439
+ self.region_percentage_threshold,
440
+ self.region_overlap_threshold,
441
+ self.region_iou
442
+ )
443
+
444
+ # If no lines detected, return None
445
+ if not merged_line_polygons:
446
+ print('No text lines detected from image.')
447
  return None
448
 
449
+ # Prepare line predictions dictionary
450
+ line_preds = {
451
+ 'coords': merged_line_polygons,
452
+ 'max_min': merged_line_max_mins
453
+ }
454
+
455
+ # Prepare region predictions (or use default if none detected)
456
+ if merged_region_polygons:
457
+ region_preds = []
458
+ for num, (region_polygon, region_max_min) in enumerate(
459
+ zip(merged_region_polygons, merged_region_max_mins)
460
+ ):
461
+ region_preds.append({
462
+ 'coords': region_polygon,
463
+ 'id': str(num),
464
+ 'max_min': region_max_min,
465
+ 'name': 'paragraph',
466
+ 'img_shape': image_shape
467
+ })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
  else:
469
+ # No regions detected, create default region covering entire image
470
+ region_preds = get_default_region(image_shape=image_shape)
471
+
472
+ # Associate lines with their containing regions
473
+ lines_connected_to_regions = get_line_regions(
474
+ lines=line_preds,
475
+ regions=region_preds
476
+ )
477
+
478
+ # Order lines within regions for proper reading sequence
479
+ ordered_lines = order_regions_lines(
480
+ lines=lines_connected_to_regions,
481
+ regions=region_preds
482
+ )
483
 
484
+ return ordered_lines