Spaces:
Runtime error
Runtime error
| from PIL import Image, ImageDraw, ImageFont | |
| from skimage.measure import label, regionprops | |
| import gradio as gr | |
| import tensorflow as tf | |
| import numpy as np | |
| from PIL import Image | |
| from tensorflow.keras.preprocessing.image import array_to_img | |
| import json | |
| import os | |
| from transformers import AutoModel | |
| from transformers import TFSegformerForSemanticSegmentation | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| import matplotlib.font_manager as fm | |
| from sklearn.cluster import KMeans | |
| from skimage import color | |
| import io | |
| import pandas as pd | |
| # Set the font to support Chinese characters | |
| #font_path = 'simhei.ttf' | |
| #font_prop = fm.FontProperties(fname=font_path) | |
| #matplotlib.rcParams['font.family'] = font_prop.get_name() | |
| #matplotlib.rcParams['font.family'] = 'Droid Sans Fallback' | |
# Single source of truth for the segmentation classes:
# (class id, RGB colour used in the mask, Chinese label, English material name).
# The three lookup tables below are derived from it so they always share the
# same ids in the same order.
_MATERIAL_CLASSES = [
    (1, [209, 35, 69], '动物皮', 'Animal skin'),
    (2, [216, 208, 246], '骨/牙/角', 'Bone/teeth/horn'),
    (3, [172, 196, 170], '砖块', 'Brickwork'),
    (4, [178, 80, 80], '纸板/纸', 'Cardboard/Paper'),
    (6, [89, 89, 89], '天花板瓦片', 'Ceiling tile'),
    (7, [160, 146, 229], '瓷', 'Ceramic'),
    (8, [18, 17, 20], '黑板', 'Chalkboard/blackboard'),
    (10, [190, 209, 189], '混凝土', 'Concrete'),
    (13, [37, 12, 156], '织物/布/地毯', 'Fabric/cloth'),
    (15, [250, 50, 83], '火', 'Fire'),
    (16, [61, 245, 61], '树叶', 'Foliage'),
    (17, [230, 203, 104], '食物', 'Food'),
    (18, [125, 104, 227], '毛皮', 'Fur'),
    (19, [228, 225, 249], '宝石/石英', 'Gemstone/quartz'),
    (20, [51, 221, 255], '玻璃', 'Glass'),
    (21, [95, 95, 95], '毛发', 'Hair'),
    (23, [156, 239, 255], '冰', 'Ice'),
    (24, [153, 102, 51], '皮革', 'Leather'),
    (26, [0, 0, 226], '金属', 'Metal'),
    (27, [254, 242, 208], '镜子', 'Mirror'),
    (29, [89, 134, 179], '油漆/抹灰/石膏', 'Paint/plaster'),
    (32, [255, 0, 204], '照片/绘画/布面招牌', 'Photograph/painting'),
    (33, [170, 240, 209], '透明塑料', 'Plastic, clear'),
    (34, [140, 120, 240], '非透明塑料', 'Plastic, non-clear'),
    (35, [118, 255, 166], '橡胶/乳胶', 'Rubber/latex'),
    (36, [250, 250, 55], '沙', 'Sand'),
    (37, [243, 232, 208], '皮肤/嘴唇', 'Skin/lips'),
    (38, [1, 118, 141], '天空', 'Sky'),
    (39, [243, 241, 255], '雪', 'Snow'),
    (41, [158, 108, 4], '土壤/泥土', 'Soil/mud'),
    (43, [132, 0, 0], '天然石材', 'natural stone'),
    (44, [245, 147, 49], '抛光石材', 'polished stone & engineered stone'),
    (46, [240, 120, 240], '片状地砖/石地砖/瓷地砖', 'Tile'),
    (47, [149, 83, 203], '壁纸', 'Wallpaper'),
    (48, [52, 209, 183], '水', 'Water'),
    (49, [200, 101, 0], '蜡', 'Wax'),
    (50, [65, 112, 192], '白板', 'Whiteboard'),
    (52, [255, 204, 51], '木材', 'Wood'),
    (53, [36, 179, 83], '树木', 'tree'),
    (56, [90, 98, 89], '沥青', 'Asphalt'),
    (57, [255, 191, 0], '珐琅/琉璃', 'enamel'),
    (58, [204, 153, 51], '夯土', 'Rammed earth'),
    (59, [31, 73, 125], '塑钢复合装饰板', 'composite decorative board'),
    (60, [155, 149, 205], '水泥', 'Cement'),
    (61, [154, 150, 169], '陶', 'Pottery'),
    (62, [128, 128, 128], '屋顶防水卷材', 'Roofing waterproof material'),
    (63, [163, 160, 172], '金属网窗(远景)', 'Metal mesh window (perspective)'),
    (64, [255, 106, 77], '砖雕', 'carved brick'),
    (65, [115, 51, 128], '纱窗', 'window screen'),
    (0, [10, 9, 10], '背景/未知', 'background'),
]

# Class id -> RGB colour painted into the visualised mask.
id2color = {class_id: rgb for class_id, rgb, _, _ in _MATERIAL_CLASSES}
# Class id -> Chinese display label (used for the statistics table).
id2label = {class_id: zh for class_id, _, zh, _ in _MATERIAL_CLASSES}
# Class id -> English material name (used for on-image labels and palettes).
id2material = {class_id: en for class_id, _, _, en in _MATERIAL_CLASSES}
# Hugging Face Hub repository id of the fine-tuned SegFormer checkpoint.
model_save_path = 'jinfengxie/BFM_segformer0821'
# Loaded once when the module is imported and reused by every request.
model = TFSegformerForSemanticSegmentation.from_pretrained(
    model_save_path
)
def _load_label_font(size):
    """Return Arial at *size* pixels, or Pillow's built-in font if Arial is missing.

    ``ImageFont.truetype("arial.ttf", ...)`` raises ``OSError`` when the font
    file cannot be found (common on Linux servers); without the fallback the
    whole request fails while drawing labels.
    """
    # Images shorter than 30 px would give size 0, which is not a valid font size.
    size = max(1, int(size))
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        return ImageFont.load_default()


def _to_rgb_array(image):
    """Return *image* as an (H, W, 3) array: grayscale is expanded, alpha dropped."""
    image_np = np.asarray(image)
    if image_np.ndim == 2:
        image_np = image_np[..., np.newaxis]
    if image_np.shape[-1] == 1:
        image_np = np.repeat(image_np, 3, axis=-1)
    elif image_np.shape[-1] == 4:
        image_np = image_np[..., :3]
    return image_np


def predict_and_visualize(image, max_side=1500):
    """Segment *image* into material classes and render a labelled colour mask.

    Args:
        image: the input picture as an array-like (what ``gr.Image()`` passes
            by default). An (H, W, 3) RGB array is expected; grayscale and
            RGBA inputs are converted to RGB first.
        max_side: if the longer side exceeds this many pixels, the image is
            downscaled (keeping its aspect ratio) before inference. The
            predicted mask is always resized back to the input resolution.

    Returns:
        (pred_mask, result_image, counts_dict):
            pred_mask: (H, W) ``tf.Tensor`` of predicted class ids at the
                input resolution.
            result_image: PIL image in which each class present in ``id2color``
                is painted with its colour, and its ``id2material`` name is
                drawn at the mean position of that class's pixels.
            counts_dict: {class id: pixel count} for every id in the mask.
    """
    image_np = _to_rgb_array(image)
    height, width = image_np.shape[:2]

    pixels = tf.convert_to_tensor(image_np, dtype=tf.float32)
    if max(height, width) > max_side:
        if height >= width:
            new_size = (max_side, max(1, int(max_side * width / height)))
        else:
            new_size = (max(1, int(max_side * height / width)), max_side)
        pixels = tf.image.resize(pixels, new_size)

    # Scale to [0, 1] and move channels first: the model is fed (1, 3, H, W).
    batch = tf.expand_dims(tf.transpose(pixels / 255.0, perm=[2, 0, 1]), axis=0)
    logits = model.predict(batch).logits

    # Class scores are on axis 1; keep the best class per pixel of the single image.
    pred_mask = tf.argmax(logits, axis=1)[0]
    # Back to input resolution; nearest-neighbour keeps the ids integral.
    # Index the channel axis instead of tf.squeeze so 1-pixel-wide images keep 2 dims.
    pred_mask = tf.image.resize(
        pred_mask[..., tf.newaxis], (height, width), method='nearest'
    )[..., 0]

    mask_np = pred_mask.numpy()
    unique, counts = np.unique(mask_np, return_counts=True)
    counts_dict = dict(zip(unique, counts))

    color_mask = np.zeros((height, width, 3), dtype=np.uint8)
    label_positions = {}
    for class_id, rgb in id2color.items():
        # Skip absent classes instead of scanning the whole mask for each of them.
        if class_id not in counts_dict:
            continue
        rows, cols = np.nonzero(mask_np == class_id)
        color_mask[rows, cols] = rgb
        # Label anchor: mean (x, y) of the class's pixels.
        label_positions[class_id] = (np.mean(cols), np.mean(rows))

    result_image = Image.fromarray(color_mask)
    draw = ImageDraw.Draw(result_image)
    font = _load_label_font(height / 30)
    for class_id, position in label_positions.items():
        if class_id in id2material:
            draw.text(position, str(id2material[class_id]), font=font, fill='white')
    return pred_mask, result_image, counts_dict
def ext_colors(image_path, mask, n_clusters=4, random_state=None):
    """Find the dominant colours of each class in *mask* with K-means.

    Args:
        image_path: despite the name, the image itself as an array-like whose
            first two dimensions match *mask*.
        mask: 2-D array-like of class ids aligned pixel-for-pixel with the image.
        n_clusters: colours to extract per class. Classes with fewer pixels
            than this get one cluster per pixel instead; KMeans would
            otherwise raise because n_samples < n_clusters.
        random_state: forwarded to ``KMeans``; pass an int for reproducible
            palettes. The default ``None`` keeps the previous behaviour.

    Returns:
        {class id: int array of cluster centres, one row per colour}.
    """
    image_np = np.asarray(image_path)
    mask_np = np.asarray(mask)

    colors_per_class = {}
    for cls in np.unique(mask_np):
        # Boolean indexing yields the class's pixels in row-major order;
        # reshape so single-channel images still give a 2-D sample matrix.
        pixels = image_np[mask_np == cls]
        pixels = pixels.reshape(len(pixels), -1)
        k = min(n_clusters, len(pixels))
        kmeans = KMeans(n_clusters=k, n_init=10, random_state=random_state)
        kmeans.fit(pixels)
        colors_per_class[cls] = kmeans.cluster_centers_.astype(int)
    return colors_per_class
def plot_material_color_palette_grid(material_dict, materials_per_row=4):
    """Render every material's colour swatches in a grid and return it as a PIL image.

    Materials are placed ``materials_per_row`` per grid row, in dict order.
    Each material column shows a bold header with the material name, then one
    swatch per colour, annotated on its right with ``str(color)``. A grid row
    is as tall as its longest palette plus the header line.

    Args:
        material_dict: {material name: sequence of RGB triples in 0-255}.
        materials_per_row: number of material columns per grid row (>= 1).

    Returns:
        PIL image of the rendered PNG. An empty ``material_dict`` produces a
        blank one-row figure instead of failing on a zero-height figure.

    Raises:
        ValueError: if ``materials_per_row`` is less than 1.
    """
    if materials_per_row < 1:
        raise ValueError("materials_per_row must be at least 1")

    # Layout constants, in data units (the axes span the whole figure).
    block_width = 1
    block_height = 0.5
    text_gap = 0.2
    row_gap = 0.2
    column_gap = 1.5  # gap between material columns within the same row
    column_step = block_width + text_gap + column_gap
    row_step = block_height + row_gap

    names = list(material_dict)
    grid_rows = [names[i:i + materials_per_row]
                 for i in range(0, len(names), materials_per_row)]
    row_heights = [max(len(material_dict[name]) for name in row) + 1
                   for row in grid_rows]
    # Matplotlib rejects a zero-height figure, so always reserve at least one row.
    total_rows = max(sum(row_heights), 1)

    fig_width = materials_per_row * column_step
    fig_height = total_rows * row_step
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    ax.axis('off')
    ax.invert_yaxis()  # origin at top-left so rows grow downwards

    current_row = 0
    for row, row_height in zip(grid_rows, row_heights):
        for j, material in enumerate(row):
            x = j * column_step
            ax.text(x, current_row * row_step + 0.5, material,
                    va='center', fontsize=12, fontweight='bold', ha='left')
            for k, color in enumerate(material_dict[material]):
                y_pos = (current_row + 1 + k) * row_step
                # Matplotlib expects RGB components in [0, 1].
                ax.add_patch(plt.Rectangle((x, y_pos), block_width, block_height,
                                           color=np.array(color) / 255.0))
                ax.text(x + block_width + text_gap, y_pos + block_height / 2,
                        str(color), va='center', fontsize=10)
        current_row += row_height

    ax.set_xlim(0, fig_width)
    ax.set_ylim(max(current_row, 1) * row_step, 0)

    # Render to memory instead of displaying the figure.
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# Convert the current matplotlib figure into an image.
def plt_to_image():
    """Rasterise pyplot's current figure as a 300-dpi PNG and return it as a PIL image.

    The current figure is closed afterwards, so callers must not reuse it.
    """
    rendered = io.BytesIO()
    plt.savefig(rendered, format='png', dpi=300)
    plt.close()
    # Hand PIL a fresh stream positioned at the start of the PNG bytes.
    return Image.open(io.BytesIO(rendered.getvalue()))
def calculate_slice_statistics(one_mask, slice_size=256):
    """Compute the class composition of every full square tile of a class-id mask.

    The mask is cut into non-overlapping ``slice_size`` x ``slice_size`` tiles
    starting at the top-left corner. Leftover rows and columns at the bottom
    and right edges, narrower than a tile, are ignored.

    Returns:
        {(tile_row, tile_col): {class id: fraction of the tile's pixels}},
        with tiles in row-major order.
    """
    def _tile_fractions(tile):
        ids, freq = np.unique(tile, return_counts=True)
        n_pixels = freq.sum()
        return {cls: count / n_pixels for cls, count in zip(ids, freq)}

    s = slice_size
    n_tile_rows = one_mask.shape[0] // s
    n_tile_cols = one_mask.shape[1] // s
    return {
        (r, c): _tile_fractions(one_mask[r * s:(r + 1) * s, c * s:(c + 1) * s])
        for r in range(n_tile_rows)
        for c in range(n_tile_cols)
    }
def find_top_slices(slice_stats, exclusion_list, min_percent=0.7, min_slices=1, top_k=3):
    """Pick, for each material, the tiles it dominates most.

    A tile qualifies for a material when the material id is not in
    ``exclusion_list`` and covers at least ``min_percent`` of the tile. At most
    ``top_k`` qualifying tiles with the highest coverage are kept per material.
    A material is reported only if it has strictly MORE than ``min_slices``
    such tiles.

    Args:
        slice_stats: {(tile_row, tile_col): {material id: fraction}}, as
            returned by ``calculate_slice_statistics``.
        exclusion_list: material ids to ignore entirely.

    Returns:
        {material id: [(fraction, (tile_row, tile_col)), ...]}, each list
        sorted by descending fraction.
    """
    from collections import defaultdict
    import heapq

    # Per-material min-heaps holding the best top_k (fraction, position) pairs.
    best = defaultdict(list)
    for pos, fractions in slice_stats.items():
        for mat, frac in fractions.items():
            if mat in exclusion_list or frac < min_percent:
                continue
            heap = best[mat]
            entry = (frac, pos)
            if len(heap) >= top_k:
                heapq.heappushpop(heap, entry)
            else:
                heapq.heappush(heap, entry)

    return {
        mat: sorted(heap, reverse=True)
        for mat, heap in best.items()
        if len(heap) > min_slices
    }
def extract_and_visualize_top_slices(image, top_slices, slice_size=256):
    """Plot the selected texture tiles, one row per material, and return the figure as a PIL image.

    Args:
        image: the source image, either an array (converted with
            ``Image.fromarray``) or already a PIL image.
        top_slices: {material id: [(fraction, (tile_row, tile_col)), ...]},
            as returned by ``find_top_slices``. Tile (i, j) covers the pixel box
            ``(j*slice_size, i*slice_size, (j+1)*slice_size, (i+1)*slice_size)``.
        slice_size: tile edge length in pixels. It must match the value used
            to compute the statistics.

    Returns:
        PIL image of the rendered PNG. If ``top_slices`` is empty, a small
        notice figure is returned instead of raising from
        ``plt.subplots(nrows=0)``.
    """
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))

    if not top_slices:
        fig, ax = plt.subplots(figsize=(6, 1))
        ax.text(0.5, 0.5, 'No texture slices met the selection criteria',
                ha='center', va='center', fontsize=12)
        ax.axis('off')
    else:
        # Three columns keeps the original layout for the default top_k=3;
        # widen if a caller asked for more tiles per material.
        ncols = max(3, max(len(slices) for slices in top_slices.values()))
        nrows = len(top_slices)
        # squeeze=False always yields a 2-D array of axes, even for a single row.
        fig, axs = plt.subplots(nrows=nrows, ncols=ncols,
                                figsize=(5 * ncols, 5 * nrows), squeeze=False)
        for row_axes, (material_id, slices) in zip(axs, top_slices.items()):
            # Hide every axis first so unused slots in short rows stay blank.
            for ax in row_axes:
                ax.axis('off')
            name = id2material.get(material_id, material_id)
            for ax, (_, pos) in zip(row_axes, slices):
                i, j = pos
                tile = image.crop((j * slice_size, i * slice_size,
                                   (j + 1) * slice_size, (i + 1) * slice_size))
                ax.imshow(tile)
                ax.set_title(f'Material {name} - Slice {pos}')
        fig.tight_layout()

    # Render to memory instead of displaying the figure.
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# Main program: the function the Gradio interface calls for each upload.
def process_image(image_path):
    """Run segmentation, palette extraction and texture-slice selection on one image.

    Args:
        image_path: the uploaded image as delivered by ``gr.Image()`` (an array
            with the default settings). Despite the name, it is not a file path.

    Returns:
        (color_mask_img, palette_image, slice_image, percentage_df):
            color_mask_img: the labelled colour mask re-rendered via matplotlib.
            palette_image: dominant-colour swatches per material.
            slice_image: the best-covered 128-px tiles per material.
            percentage_df: DataFrame with columns '类别' (Chinese class label),
                'pixels' and 'percentage (%)'.

    Raises:
        gr.Error: if no image was supplied.
    """
    if image_path is None:
        raise gr.Error("Please upload an image.")

    one_mask, color_mask, counts_dict = predict_and_visualize(image_path)

    # Dominant colours keyed by material name. Ids missing from the lookup
    # tables fall back to their numeric id instead of raising KeyError.
    colors_per_class = ext_colors(image_path, one_mask, n_clusters=4)
    colors_per_label = {id2material.get(key, str(key)): value
                        for key, value in colors_per_class.items()}
    # Materials whose colour is not useful in a facade palette. Names must match
    # id2material exactly ('Water', not 'water', which never matched).
    labels_to_remove = ['Sky', 'background', 'Glass', 'tree', 'Water', 'Plastic, clear']
    colors_per_label = {key: value for key, value in colors_per_label.items()
                        if key not in labels_to_remove}
    palette_image = plot_material_color_palette_grid(colors_per_label)

    plt.figure(figsize=(5, 5))
    plt.imshow(color_mask)
    plt.tight_layout()
    plt.axis('off')
    color_mask_img = plt_to_image()

    counts_df = pd.DataFrame(
        [(id2label.get(key, str(key)), value) for key, value in counts_dict.items()],
        columns=['类别', '计数'],
    )
    total_count = counts_df['计数'].sum()
    counts_df['百分比'] = (counts_df['计数'] / total_count * 100).round(2)
    percentage_df = counts_df.rename(columns={'计数': 'pixels', '百分比': 'percentage (%)'})

    slice_size = 128
    exclusion_list = [38]  # id 38 is 'Sky' in id2material
    slice_stats = calculate_slice_statistics(one_mask, slice_size=slice_size)
    top_slices = find_top_slices(slice_stats, exclusion_list=exclusion_list,
                                 min_percent=0.5, min_slices=1)
    slice_image = extract_and_visualize_top_slices(image_path, top_slices, slice_size=slice_size)
    return color_mask_img, palette_image, slice_image, percentage_df
# Gradio UI: one uploaded image in; colour mask, palette, texture slices and a
# per-class pixel table out.
input_component = gr.Image()
output_components = [
    gr.Image(type="pil", label="Color Mask"),
    gr.Image(type="pil", label="Color Palette"),
    gr.Image(type="pil", label="Texture Slices"),
    gr.DataFrame(),
]
iface = gr.Interface(
    fn=process_image,
    inputs=input_component,
    outputs=output_components,
    title="Building Facade Material Segmentation",
    description="Upload an image to segment material masks, and get color palettes.",
)
iface.launch()