Spaces:
Sleeping
Sleeping
| from huggingface_hub import hf_hub_download | |
| from shapely.validation import make_valid | |
| from shapely.geometry import Polygon | |
| from ultralytics import YOLO | |
| from PIL import Image | |
| import numpy as np | |
| import os | |
| from reading_order import OrderPolygons | |
| class SegmentImage: | |
| """Class for segmenting document image regions and text lines.""" | |
| def __init__(self, | |
| line_model_path, | |
| device, | |
| line_iou=0.5, | |
| region_iou=0.5, | |
| line_overlap=0.5, | |
| line_nms_iou=0.7, | |
| region_nms_iou=0.3, | |
| line_conf_threshold=0.25, | |
| region_conf_threshold=0.25, | |
| region_model_path=None, | |
| order_regions=True, | |
| region_half_precision=False, | |
| line_half_precision=False): | |
| # Path to text line detection model | |
| self.line_model_path = line_model_path | |
| # Path to text region detection model | |
| self.region_model_path = region_model_path | |
| # Defines the IoU threshold used in the non-maximum suppression (NMS) process to | |
| # determine which prediction boxes should be suppressed or discarded based on their overlap with other boxes | |
| self.line_nms_iou = line_nms_iou | |
| self.region_nms_iou = region_nms_iou | |
| # Defines the IoU threshold for text lines | |
| self.line_iou = line_iou | |
| # Defines the IoU threshold for text regions | |
| self.region_iou = region_iou | |
| # Defines the extent of line polygon overlap used for merging the polygons | |
| self.line_overlap = line_overlap | |
| # Defines confidence threshold for line detection | |
| self.line_conf_threshold = line_conf_threshold | |
| # Defines confidence threshold for region detection | |
| self.region_conf_threshold = region_conf_threshold | |
| # Defines the device to be used ('cpu', gpu '0', gpu '1' etc.) | |
| self.device = device | |
| # Defines whether a reading order is also estimated for the region detections | |
| self.order_regions = order_regions | |
| # Defines whether half precision (FP16) is used by the region and line prediction models | |
| self.region_half_precision = region_half_precision | |
| self.line_half_precision = line_half_precision | |
| self.order_poly = OrderPolygons() | |
| # Initialize segmentation model(s) | |
| self.line_model = self.init_line_model() | |
| if self.region_model_path: | |
| self.region_model = self.init_region_model() | |
| def init_line_model(self): | |
| """Function for initializing the line detection model.""" | |
| try: | |
| # Load the trained line detection model | |
| cached_model_path = hf_hub_download(repo_id=self.line_model_path, filename="lines_20240827.pt") | |
| line_model = YOLO(cached_model_path) | |
| return line_model | |
| except Exception as e: | |
| print('Failed to load the line detection model: %s' % e) | |
| def init_region_model(self): | |
| """Function for initializing the region detection model.""" | |
| try: | |
| # Load the trained line detection model | |
| cached_model_path = hf_hub_download(repo_id=self.region_model_path, filename="tuomiokirja_regions_04122023.pt") | |
| region_model = YOLO(cached_model_path) | |
| return region_model | |
| except Exception as e: | |
| print('Failed to load the region detection model: %s' % e) | |
| def get_region_ids(self, coords, max_min, classes, names, box_confs, img_shape): | |
| """Function for creating unique id for each detected region.""" | |
| n = min(len(classes), len(coords)) | |
| res = [] | |
| for i in range(n): | |
| # Creates a simple index-based id for each region | |
| region_id = str(i) | |
| # Extracts region name corresponding to the index | |
| region_type = names[classes[i]] | |
| poly_dict = {'coords': coords[i], | |
| 'max_min': max_min[i], | |
| 'class': str(classes[i]), | |
| 'name': region_type, | |
| 'conf': box_confs[i], | |
| 'id': region_id, | |
| 'img_shape': img_shape} | |
| res.append(poly_dict) | |
| return res | |
| def get_max_min(self, polygons): | |
| """Creates an array with the minimum and maximum | |
| x and y values of the input polygons.""" | |
| n_rows = len(polygons) | |
| xy_array = np.zeros([n_rows, 4]) | |
| for i, poly in enumerate(polygons): | |
| x = [point[0] for point in poly] | |
| y = [point[1] for point in poly] | |
| if x: | |
| xy_array[i,0] = max(x) | |
| xy_array[i,1] = min(x) | |
| if y: | |
| xy_array[i,2] = max(y) | |
| xy_array[i,3] = min(y) | |
| return xy_array | |
| def validate_polygon(self, polygon): | |
| """"Function for testing and correcting the validity of polygons.""" | |
| if len(polygon) > 2: | |
| polygon = Polygon(polygon) | |
| if not polygon.is_valid: | |
| polygon = make_valid(polygon) | |
| return polygon | |
| else: | |
| return None | |
| def get_iou(self, poly1, poly2): | |
| """Function for calculating Intersection over Union (IoU) values.""" | |
| # If the polygons don't intersect, IoU is 0 | |
| iou = 0 | |
| poly1 = self.validate_polygon(poly1) | |
| poly2 = self.validate_polygon(poly2) | |
| if poly1 and poly2: | |
| if poly1.intersects(poly2): | |
| # Calculates intersection of the 2 polygons | |
| intersect = poly1.intersection(poly2).area | |
| # Calculates union of the 2 polygons | |
| uni = poly1.union(poly2) | |
| # Calculates intersection over union | |
| iou = intersect / uni.area | |
| return iou | |
| def merge_polygons(self, polygons, iou_threshold, overlap_threshold = None): | |
| """Merges polygons that have an IoU value | |
| above the given threshold.""" | |
| new_polygons = [] | |
| dropped = set() | |
| # Loops over all input polygons and merges them if the | |
| # IoU value is over the given threshold | |
| for i in range(0, len(polygons)): | |
| poly1 = self.validate_polygon(polygons[i]) | |
| merged = None | |
| for j in range(i+1, len(polygons)): | |
| poly2 = self.validate_polygon(polygons[j]) | |
| if poly1 and poly2: | |
| if poly1.intersects(poly2): | |
| overlap = False | |
| intersect = poly1.intersection(poly2) | |
| uni = poly1.union(poly2) | |
| # Calculates intersection over union | |
| iou = intersect.area / uni.area | |
| if overlap_threshold: | |
| overlap = intersect.area > (overlap_threshold * min(poly1.area, poly2.area)) | |
| if (iou > iou_threshold) or overlap: | |
| if merged: | |
| # If there are multiple overlapping polygons | |
| # with IoU over the threshold, they are all merged together | |
| merged = uni.union(merged) | |
| dropped.add(j) | |
| else: | |
| merged = uni | |
| # Polygons that are merged together are dropped from | |
| # the list | |
| dropped.add(i) | |
| dropped.add(j) | |
| if merged: | |
| if merged.geom_type in ['GeometryCollection','MultiPolygon']: | |
| for geom in merged.geoms: | |
| if geom.geom_type == 'Polygon': | |
| new_polygons.append(list(geom.exterior.coords)) | |
| elif merged.geom_type == 'Polygon': | |
| new_polygons.append(list(merged.exterior.coords)) | |
| res = [i for j, i in enumerate(polygons) if j not in dropped] | |
| res += new_polygons | |
| return res | |
| def get_region_preds(self, img): | |
| """Function for predicting text region coordinates.""" | |
| results = self.region_model.predict(source=img, | |
| device=self.device, | |
| conf=self.region_conf_threshold, | |
| half=bool(self.region_half_precision), | |
| iou=self.region_nms_iou) | |
| results = results[0].cpu() | |
| if results.masks: | |
| # Extracts detected region polygons | |
| coords = results.masks.xy | |
| # Merge overlapping polygons | |
| coords = self.merge_polygons(coords, self.region_iou) | |
| # Maximum and minimum x and y axis values for detected polygons used for ordering the polygons | |
| max_min = self.get_max_min(coords).tolist() | |
| # Gets a list of the predicted class labels for detected regions | |
| classes = results.boxes.cls.tolist() | |
| # A dictionary with class ids as keys and class names as values | |
| names = results.names | |
| # Confidence values for detections | |
| box_confs = results.boxes.conf.tolist() | |
| # A tuple containing the shape of the original image | |
| img_shape = results.orig_shape | |
| res = self.get_region_ids(list(coords), max_min, classes, names, box_confs, img_shape) | |
| return res | |
| else: | |
| return None | |
| def get_line_preds(self, img): | |
| """Function for predicting text line coordinates.""" | |
| results = self.line_model.predict(source=img, | |
| device=self.device, | |
| conf=self.line_conf_threshold, | |
| half=bool(self.line_half_precision), | |
| iou=self.line_nms_iou) | |
| results = results[0].cpu() | |
| if results.masks: | |
| # Detected text line polygons | |
| coords = results.masks.xy | |
| # Merge overlapping polygons | |
| coords = self.merge_polygons(coords, self.line_iou, self.line_overlap) | |
| # Maximum and minimum x and y axis values for detected polygons | |
| max_min = self.get_max_min(coords).tolist() | |
| # Confidence values for detections | |
| box_confs = results.boxes.conf.tolist() | |
| res_dict = {'coords': list(coords), 'max_min': max_min, 'confs': box_confs} | |
| return res_dict | |
| else: | |
| return None | |
| def get_dist(self, line_polygon, regions): | |
| """Function for finding the closest region to the text line.""" | |
| dist, reg_id = 1000000, None | |
| line_polygon = self.validate_polygon(line_polygon) | |
| if line_polygon: | |
| for region in regions: | |
| # Calculates dictance between line and regions polygons | |
| region_polygon = self.validate_polygon(region['coords']) | |
| if region_polygon: | |
| line_reg_dist = line_polygon.distance(region_polygon) | |
| if line_reg_dist < dist: | |
| dist = line_reg_dist | |
| reg_id = region['id'] | |
| return reg_id | |
| def get_line_regions(self, lines, regions): | |
| """Function for connecting each text line to one region.""" | |
| lines_list = [] | |
| for i in range(len(lines['coords'])): | |
| iou, reg_id, conf = 0, '', 0.0 | |
| max_min = [0.0, 0.0, 0.0, 0.0] | |
| polygon = lines['coords'][i] | |
| for region in regions: | |
| line_reg_iou = self.get_iou(polygon, region['coords']) | |
| if line_reg_iou > iou: | |
| iou = line_reg_iou | |
| reg_id = region['id'] | |
| # If line polygon does not intersect with any region, a distance metric is used for defining | |
| # the region that the line belongs to | |
| if iou == 0: | |
| reg_id = self.get_dist(polygon, regions) | |
| if (len(lines['max_min']) - 1) >= i: | |
| max_min = lines['max_min'][i] | |
| if (len(lines['confs']) - 1) >= i: | |
| conf = lines['confs'][i] | |
| new_line = {'polygon': polygon, 'reg_id': reg_id, 'max_min': max_min, 'conf': conf} | |
| lines_list.append(new_line) | |
| return lines_list | |
| def order_regions_lines(self, lines, regions): | |
| """Function for ordering line predictions inside each region.""" | |
| regions_with_rows = [] | |
| region_max_mins = [] | |
| for i, region in enumerate(regions): | |
| line_max_mins = [] | |
| line_confs = [] | |
| line_polygons = [] | |
| for line in lines: | |
| if line['reg_id'] == region['id']: | |
| line_max_mins.append(line['max_min']) | |
| line_confs.append(line['conf']) | |
| line_polygons.append(line['polygon']) | |
| if line_polygons: | |
| # If one or more lines are connected to a region, line order inside the region is defined | |
| # and the predicted text lines are joined in the same python dict | |
| line_order = self.order_poly.order(line_max_mins) | |
| line_polygons = [line_polygons[i] for i in line_order] | |
| line_confs = [line_confs[i] for i in line_order] | |
| new_region = {'region_coords': region['coords'], | |
| 'region_name': region['name'], | |
| 'lines': line_polygons, | |
| 'line_confs': line_confs, | |
| 'region_conf': region['conf'], | |
| 'img_shape': region['img_shape']} | |
| region_max_mins.append(region['max_min']) | |
| regions_with_rows.append(new_region) | |
| else: | |
| continue | |
| # Creates an ordering of the detected regions based on their polygon coordinates | |
| if self.order_regions: | |
| region_order = self.order_poly.order(region_max_mins) | |
| regions_with_rows = [regions_with_rows[i] for i in region_order] | |
| return regions_with_rows | |
| def get_default_region(self, image): | |
| """Function for creating a default region if no regions are detected.""" | |
| w, h = image.size | |
| region = {'coords': [[0.0, 0.0], [w, 0.0], [w, h], [0.0, h]], | |
| 'max_min': [w, 0.0, h, 0.0], | |
| 'class': '0', | |
| 'name': "paragraph", | |
| 'conf': 0.0, | |
| 'id': '0', | |
| 'img_shape': (h, w)} | |
| return [region] | |
| def get_segmentation(self, image): | |
| """Segment input image into ordered text lines or ordered text regions and text lines.""" | |
| line_preds = self.get_line_preds(image) | |
| if line_preds: | |
| # If region detection model is defined, text regions and text lines are detected | |
| region_preds = self.get_region_preds(image) | |
| if not region_preds: | |
| region_preds = self.get_default_region(image) | |
| print(f'No regions detected from image {image}') | |
| lines_with_regions = self.get_line_regions(line_preds, region_preds) | |
| ordered_regions = self.order_regions_lines(lines_with_regions, region_preds) | |
| return ordered_regions | |
| else: | |
| print(f'No text lines detected from image {image}') | |
| return None | |