Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| import random | |
| import torch | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| import numpy as np | |
| import shutil | |
| from prismer.utils import create_ade20k_label_colormap | |
# Switch matplotlib to the non-interactive Agg backend before any figure is
# drawn by the helpers below.
matplotlib.use('agg')

# Label-name tables, read once at import time from the 'labels' entry of the
# serialized feature files. coco_label_map is indexed by integer class id in
# seg_prettify; obj_label_map is indexed by the ids stored in the detection
# JSON in obj_detection_prettify.
coco_label_map = torch.load('prismer/dataset/coco_features.pt')['labels']
obj_label_map = torch.load('prismer/dataset/detection_features.pt')['labels']

# Colour table used by seg_prettify, indexed by integer class id.
ade_color = create_ade20k_label_colormap()
def islight(rgb):
    """Return True when a colour is bright enough to need dark text on top.

    Brightness is the weighted root of the squared channels,
    ``sqrt(0.299*r**2 + 0.587*g**2 + 0.114*b**2)``, compared against 127.5
    (the midpoint of a 0-255 scale).

    Args:
        rgb: Iterable of exactly three channel values (red, green, blue),
            expected on a 0-255 scale. Python numbers and NumPy scalars of
            any width are accepted.

    Returns:
        bool: True if the brightness is strictly greater than 127.5.
    """
    # Convert to Python floats before squaring. Narrow NumPy integer scalars
    # (e.g. an int16 pixel) wrap around on 255 * 255, which turns the sum
    # negative, the square root into NaN and every bright colour into "dark".
    r, g, b = (float(channel) for channel in rgb)
    hsp = np.sqrt(0.299 * (r * r) + 0.587 * (g * g) + 0.114 * (b * b))
    return bool(hsp > 127.5)
def depth_prettify(file_path):
    """Write a rainbow-coloured copy of a depth image next to the original.

    The output name is ``file_path`` with each ``.png`` replaced by
    ``_p.png``. If that file already exists the function does nothing.

    Args:
        file_path: Path of the depth image to read.
    """
    target = file_path.replace('.png', '_p.png')
    if os.path.exists(target):
        return  # already rendered on an earlier call
    plt.imsave(target, plt.imread(file_path), cmap='rainbow')
def obj_detection_prettify(rgb_path, path_name):
    """Render detected-object labels on top of the RGB image.

    The figure is written to ``path_name`` with each ``.png`` replaced by
    ``_p.png``; nothing is done when that file already exists. The label
    image itself is never modified.

    Args:
        rgb_path: Path of the image drawn underneath the labels.
        path_name: Path of the label PNG. A JSON file with the same name
            (``.json`` instead of ``.png``) is read alongside it; it maps the
            stringified 0-255 pixel value of a region to an index into
            ``obj_label_map``.
    """
    pretty_path = path_name.replace('.png', '_p.png')
    if os.path.exists(pretty_path):
        return

    rgb = plt.imread(rgb_path)
    obj_labels = plt.imread(path_name)
    # Context manager so the JSON file handle is closed deterministically.
    with open(path_name.replace('.png', '.json')) as label_file:
        obj_labels_dict = json.load(label_file)

    label_values = np.unique(obj_labels)  # sorted ascending
    try:
        plt.imshow(rgb)
        if len(label_values) > 1:
            # The largest pixel value is skipped when naming regions
            # (presumably the background marker -- confirm with the producer
            # of the label image).
            object_values = label_values[:-1]
            num_objs = object_values.max()
            vmin = label_values[0]
            vmax = num_objs + 1 / 255.
            # vmin equals the data minimum, which is what imshow would pick
            # by itself; passing it explicitly lets the text-colour lookup
            # below reuse the exact same normalisation.
            plt.imshow(obj_labels, cmap='terrain', vmin=vmin, vmax=vmax, alpha=0.8)
            cmap = matplotlib.colormaps.get_cmap('terrain')
            for value in object_values:
                region = np.where(obj_labels == value)
                x, y = region[1].mean(), region[0].mean()
                # Pixel values are float32 fractions of 255; round rather
                # than truncate so a product such as 2.9999998 maps to 3.
                pixel_id = int(round(float(value) * 255))
                obj_name = obj_label_map[obj_labels_dict[str(pixel_id)]]
                obj_name = obj_name.split(',')[0]
                # vmax > vmin always holds, so this cannot divide by zero
                # even when the only object has pixel value 0.
                shade = cmap((value - vmin) / (vmax - vmin))[:3]
                text_colour = 'black' if islight([c * 255 for c in shade]) else 'white'
                plt.text(x, y, obj_name, c=text_colour, horizontalalignment='center', verticalalignment='center', clip_on=True)
        # With a single pixel value there is nothing to annotate and the
        # plain RGB image is saved. It goes to pretty_path like every other
        # result, so the label file at path_name is not overwritten.
        plt.axis('off')
        plt.savefig(pretty_path, bbox_inches='tight', transparent=True, pad_inches=0)
    finally:
        # Always release the figure so a failure cannot leak drawings into
        # the next call.
        plt.close()
def seg_prettify(rgb_path, file_name):
    """Render a coloured segmentation overlay with class captions.

    The figure is written to ``file_name`` with each ``.png`` replaced by
    ``_p.png``; nothing is done when that file already exists.

    Every distinct pixel value of the label image is painted with the
    ``ade_color`` entry for its class id. Regions larger than 20 pixels are
    captioned with the first comma-separated name from ``coco_label_map``,
    placed on a randomly chosen pixel of the region.

    Args:
        rgb_path: Path of the image drawn underneath the overlay.
        file_name: Path of the segmentation label PNG.
    """
    pretty_path = file_name.replace('.png', '_p.png')
    if os.path.exists(pretty_path):
        return

    rgb = plt.imread(rgb_path)
    seg_labels = plt.imread(file_name)

    # Pixel values are float32 fractions of 255; round rather than truncate
    # so a product such as 2.9999998 maps to class 3, not class 2.
    classes = [(value, int(round(float(value) * 255))) for value in np.unique(seg_labels)]

    try:
        plt.imshow(rgb)
        seg_map = np.zeros(list(seg_labels.shape) + [3], dtype=np.int16)
        for value, class_id in classes:
            seg_map[seg_labels == value] = ade_color[class_id]
        plt.imshow(seg_map, alpha=0.8)

        for value, class_id in classes:
            region = np.where(seg_labels == value)
            pixel_count = len(region[0])
            if pixel_count <= 20:
                continue  # too small to caption legibly
            pick = random.randint(0, pixel_count - 1)
            x, y = region[1][pick], region[0][pick]
            obj_name = coco_label_map[class_id].split(',')[0]
            # .tolist() yields plain Python ints: squaring int16 scalars
            # inside islight would overflow for channel values above 181.
            pixel_colour = seg_map[int(y), int(x)].tolist()
            text_colour = 'black' if islight(pixel_colour) else 'white'
            plt.text(x, y, obj_name, c=text_colour, horizontalalignment='center', verticalalignment='center', clip_on=True)

        plt.axis('off')
        plt.savefig(pretty_path, bbox_inches='tight', transparent=True, pad_inches=0)
    finally:
        # Always release the figure so a failure cannot leak drawings into
        # the next call.
        plt.close()
def ocr_detection_prettify(rgb_path, file_name):
    """Render detected text on top of the RGB image.

    The figure is written to ``file_name`` with each ``.png`` replaced by
    ``_p.png``; nothing is done when that file already exists.

    When ``file_name`` exists, each distinct pixel value except the largest
    is captioned at the centroid of its region with the ``'text'`` entry
    stored for it in the companion ``.pt`` file. When ``file_name`` is
    missing, a white overlay with the caption "No text detected" is drawn
    instead.

    Args:
        rgb_path: Path of the image drawn underneath the overlay.
        file_name: Path of the OCR label PNG (may not exist).
    """
    pretty_path = file_name.replace('.png', '_p.png')
    if os.path.exists(pretty_path):
        return

    has_labels = os.path.exists(file_name)
    rgb = plt.imread(rgb_path)
    if has_labels:
        ocr_labels = plt.imread(file_name)
        # NOTE(review): torch.load unpickles the file, so it must come from a
        # trusted source -- confirm these .pt files are only ever produced by
        # this project's own OCR step.
        ocr_labels_dict = torch.load(file_name.replace('.png', '.pt'))
    else:
        # Pass the dtype class itself; the original passed an instance
        # (np.float32()), which only worked by accident of dtype coercion.
        ocr_labels = np.ones_like(rgb, dtype=np.float32)

    try:
        plt.imshow(rgb)
        plt.imshow(ocr_labels, cmap='gray', alpha=0.8)
        if has_labels:
            for value in np.unique(ocr_labels)[:-1]:
                region = np.where(ocr_labels == value)
                x, y = region[1].mean(), region[0].mean()
                # Pixel values are float32 fractions of 255; round rather
                # than truncate so a product such as 2.9999998 maps to 3.
                text = ocr_labels_dict[int(round(float(value) * 255))]['text']
                plt.text(x, y, text, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)
        else:
            x, y = rgb.shape[1] / 2, rgb.shape[0] / 2
            plt.text(x, y, 'No text detected', c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
            # The label file was never written, so its folder may be missing
            # too. Skip makedirs for a bare file name: os.makedirs('') raises.
            out_dir = os.path.dirname(pretty_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
        plt.axis('off')
        plt.savefig(pretty_path, bbox_inches='tight', transparent=True, pad_inches=0)
    finally:
        # Always release the figure so a failure cannot leak drawings into
        # the next call.
        plt.close()
def label_prettify(rgb_path, expert_paths):
    """Create the ``_p.png`` rendering for every expert label file.

    Each path is routed by the first keyword it contains, checked in the
    order ``depth``, ``seg``, ``ocr``, ``obj``. A path matching none of them
    is copied unchanged to its ``_p.png`` name unless that file already
    exists.

    Args:
        rgb_path: Path of the RGB image handed to the overlay renderers.
        expert_paths: Iterable of label file paths to process.
    """
    # Ordered (keyword, renderer) pairs; the first keyword found wins.
    renderers = (
        ('depth', lambda path: depth_prettify(path)),
        ('seg', lambda path: seg_prettify(rgb_path, path)),
        ('ocr', lambda path: ocr_detection_prettify(rgb_path, path)),
        ('obj', lambda path: obj_detection_prettify(rgb_path, path)),
    )
    for expert_path in expert_paths:
        render = next((fn for keyword, fn in renderers if keyword in expert_path), None)
        if render is not None:
            render(expert_path)
            continue
        # Unknown expert type: the label image is used as its own rendering.
        pretty_path = expert_path.replace('.png', '_p.png')
        if not os.path.exists(pretty_path):
            shutil.copyfile(expert_path, pretty_path)