jskvrna committed on
Commit 1904e97 · 1 Parent(s): e8f3517

Improves 3D wireframe prediction and extraction


Refactors the wireframe prediction pipeline to improve the
accuracy and robustness of 3D wireframe extraction from images.

This involves:
- Incorporating camera intrinsics (K), rotation (R), and
translation (t) matrices for more accurate point projections
(see the sketch after this list).
- Implementing depth fitting and sparse depth retrieval for
improved depth estimation.
- Adding a mechanism to filter occluded ground truth vertices
for more accurate visibility determination.
- Refining point cloud segmentation and filtering to extract
relevant features.
- Improving COLMAP point cloud visualization by colorizing apex/eave points.
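
For orientation, a minimal, hypothetical sketch of the two mechanisms the bullets above describe: projecting world points into an image with K, R, t and testing them against a fitted dense depth map, plus a median-based scale fit between dense and sparse depth. The function names (project_and_check_occlusion, fit_depth_scale), the tolerance value, and the exact conventions are illustrative assumptions, not the code in this commit; the real logic lives in predict.py and the hoho2025 helpers.

import numpy as np

def project_and_check_occlusion(points_w, K, R, t, depth_fitted, tol=0.1):
    # World -> camera: X_c = R @ X_w + t, then the pinhole projection in K.
    H, W = depth_fitted.shape
    pts_cam = (R @ points_w.T + t.reshape(3, 1)).T            # (N, 3)
    uv = np.full((len(points_w), 2), -1.0)
    visible = np.zeros(len(points_w), dtype=bool)
    for i, (x, y, z) in enumerate(pts_cam):
        if z <= 0:                                            # behind the camera
            continue
        u = K[0, 0] * x / z + K[0, 2]
        v = K[1, 1] * y / z + K[1, 2]
        ui, vi = int(round(u)), int(round(v))
        if 0 <= ui < W and 0 <= vi < H:
            uv[i] = (u, v)
            # Visible iff the vertex is not behind the surface stored in the map.
            visible[i] = z <= depth_fitted[vi, ui] + tol
    return uv, visible

def fit_depth_scale(depth_dense, depth_sparse, mask):
    # Assumed behaviour of the robust scale fit: pick k so that k * dense depth
    # matches the sparse COLMAP depth in the median sense over masked pixels.
    valid = mask & (depth_sparse > 0) & (depth_dense > 0)
    if not np.any(valid):
        return 1.0, depth_dense
    k = np.median(depth_sparse[valid] / depth_dense[valid])
    return k, k * depth_dense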

Files changed (3)
  1. predict.py +621 -6
  2. train.py +2 -3
  3. visu.py +59 -0
predict.py CHANGED
@@ -1,7 +1,14 @@
1
  import numpy as np
2
  from typing import Tuple, List
3
- from hoho2025.example_solutions import empty_solution, read_colmap_rec, get_vertices_and_edges_from_segmentation, create_3d_wireframe_single_image, merge_vertices_3d, prune_not_connected, prune_too_far
4
  from hoho2025.color_mappings import ade20k_color_mapping, gestalt_color_mapping
5
 
6
  def convert_entry_to_human_readable(entry):
7
  out = {}
@@ -15,11 +22,377 @@ def convert_entry_to_human_readable(entry):
15
  out['__key__'] = entry['order_id']
16
  return out
17
 
18
  def predict_wireframe(entry) -> Tuple[np.ndarray, List[int]]:
19
  """
20
  Predict 3D wireframe from a dataset entry.
21
  """
22
  good_entry = convert_entry_to_human_readable(entry)
23
  vert_edge_per_image = {}
24
  for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
25
  good_entry['depth'],
@@ -29,17 +402,42 @@ def predict_wireframe(entry) -> Tuple[np.ndarray, List[int]]:
29
  good_entry['image_ids'],
30
  good_entry['ade'] # Added ade20k segmentation
31
  )):
32
- colmap_rec = good_entry['colmap_binary']
33
  K = np.array(K)
34
  R = np.array(R)
35
  t = np.array(t)
 
36
  # Resize gestalt segmentation to match depth map size
37
  depth_size = (np.array(depth).shape[1], np.array(depth).shape[0]) # W, H
38
  gest_seg = gest.resize(depth_size)
39
  gest_seg_np = np.array(gest_seg).astype(np.uint8)
40
 
41
  # Get 2D vertices and edges first
42
- vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th=10.)
43
 
44
  # Check if we have enough to proceed
45
  if (len(vertices) < 2) or (len(connections) < 1):
@@ -49,19 +447,236 @@ def predict_wireframe(entry) -> Tuple[np.ndarray, List[int]]:
49
 
50
  # Call the refactored function to get 3D points
51
  vertices_3d = create_3d_wireframe_single_image(
52
- vertices, connections, depth, colmap_rec, img_id, ade_seg
53
  )
 
54
  # Store original 2D vertices, connections, and computed 3D points
55
  vert_edge_per_image[i] = vertices, connections, vertices_3d
56
 
57
  # Merge vertices from all images
58
  all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 0.5)
59
  all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d, keep_largest=False)
60
- all_3d_vertices_clean, connections_3d_clean = prune_too_far(all_3d_vertices_clean, connections_3d_clean, colmap_rec, th = 4.0)
61
 
62
-
63
  if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1:
64
  print (f'Not enough vertices or connections in the 3D vertices')
65
  return empty_solution()
66
 
67
  return all_3d_vertices_clean, connections_3d_clean
1
  import numpy as np
2
  from typing import Tuple, List
3
+ from hoho2025.example_solutions import empty_solution, read_colmap_rec, get_vertices_and_edges_from_segmentation, get_house_mask, fit_scale_robust_median, get_uv_depth, merge_vertices_3d, prune_not_connected, prune_too_far
4
  from hoho2025.color_mappings import ade20k_color_mapping, gestalt_color_mapping
5
+ from PIL import Image, ImageDraw
6
+ from visu import save_gestalt_with_proj, draw_crosses_on_image
7
+ import os
8
+ import pycolmap
9
+ from PIL import Image as PImage
10
+ import cv2
11
+ import open3d as o3d
12
 
13
  def convert_entry_to_human_readable(entry):
14
  out = {}
 
22
  out['__key__'] = entry['order_id']
23
  return out
24
 
25
+ def get_gt_vertices_and_edges(entry, i, depth, colmap_rec, k, r, t, img_id, ade_seg):
26
+ depth_fitted, depth_sparse, found_sparse, col_img = get_fitted_dense_depth(depth, colmap_rec, img_id, ade_seg, k, r, t)
27
+
28
+ #old_k, old_r, old_t = k.copy(), r.copy(), t.copy()
29
+
30
+ #k = col_img.camera.calibration_matrix()
31
+ #world_to_cam = np.eye(4)
32
+ #world_to_cam = col_img.cam_from_world.matrix()
33
+ #r = world_to_cam[:3, :3]
34
+ #t = world_to_cam[:3, 3]
35
+
36
+ wf_vertices = np.array(entry['wf_vertices'])
37
+ wf_edges = entry['wf_edges']
38
+
39
+ # Project world frame vertices into the current image
40
+ if wf_vertices.shape[0] > 0:
41
+ # Transform vertices to camera coordinates
42
+ wf_vertices_cam = (r @ wf_vertices.T) + t.reshape(3, 1)
43
+ # Project to image plane
44
+ wf_vertices_img_homogeneous = k @ wf_vertices_cam
45
+ # Convert to 2D pixel coordinates
46
+ wf_vertices_img = wf_vertices_img_homogeneous[:2, :] / wf_vertices_img_homogeneous[2, :]
47
+ projected_gt_vertices_2d = wf_vertices_img.T
48
+
49
+ # Initialize lists to store corresponding depth values from depth maps
50
+ gt_projected_depth_fitted_values = []
51
+ gt_projected_depth_sparse_values = []
52
+
53
+ # Get dimensions of the depth maps for bounds checking
54
+ # Assuming depth_fitted and depth_sparse have the same dimensions
55
+ map_height, map_width = depth_fitted.shape
56
+
57
+ for idx in range(projected_gt_vertices_2d.shape[0]):
58
+ # Get the 2D projected coordinates (x, y)
59
+ px, py = projected_gt_vertices_2d[idx]
60
+
61
+ # Round to nearest integer to use as indices for the depth maps
62
+ ix, iy = int(round(px)), int(round(py))
63
+
64
+ # Get corresponding depth_fitted value
65
+ if 0 <= iy < map_height and 0 <= ix < map_width:
66
+ gt_projected_depth_fitted_values.append(depth_fitted[iy, ix])
67
+ else:
68
+ # Projected point is outside the depth map bounds
69
+ gt_projected_depth_fitted_values.append(np.nan)
70
+
71
+ # Get corresponding depth_sparse value
72
+ if 0 <= iy < map_height and 0 <= ix < map_width: # Assuming same dimensions for depth_sparse
73
+ gt_projected_depth_sparse_values.append(depth_sparse[iy, ix])
74
+ else:
75
+ # Projected point is outside the depth map bounds
76
+ gt_projected_depth_sparse_values.append(np.nan)
77
+
78
+ # Determine occlusion status for each ground truth vertex
79
+ occlusion_status = [] # True if occluded, False otherwise
80
+
81
+ # This block executes only if there were ground truth vertices to begin with.
82
+ # wf_vertices_cam and projected_gt_vertices_2d would have been computed.
83
+ # gt_projected_depth_fitted_values list has one entry per vertex.
84
+ if wf_vertices.shape[0] > 0:
85
+ # These are the Z-coordinates (depths) of the original 3D wf_vertices
86
+ # when transformed into the camera's coordinate system.
87
+ # This is effectively the "true" depth of each vertex from the camera.
88
+ gt_vertices_depth_in_camera_system = wf_vertices_cam[2, :]
89
+
90
+ for idx in range(projected_gt_vertices_2d.shape[0]):
91
+ true_depth_of_vertex = gt_vertices_depth_in_camera_system[idx]
92
+
93
+ # This is the depth value read from the (dense) depth_fitted map
94
+ # at the 2D projection of the current wf_vertex.
95
+ depth_from_fitted_map = gt_projected_depth_fitted_values[idx]
96
+
97
+ # A vertex is considered occluded if its true depth is greater than
98
+ # the depth of the surface recorded in the depth_fitted map.
99
+ # This means the vertex is behind the observed surface.
100
+ # Comparisons against NaN evaluate to False, so if depth_from_fitted_map is
101
+ # NaN (the vertex projected outside the depth map bounds) the vertex is not
102
+ # marked as occluded by the map.
103
+ if np.isnan(true_depth_of_vertex) or true_depth_of_vertex > depth_from_fitted_map + 200.:
104
+ occlusion_status.append(True) # Vertex is occluded
105
+ else:
106
+ occlusion_status.append(False) # Vertex is not occluded or out of map bounds
107
+
108
+ if wf_vertices.shape[0] > 0:
109
+ # Filter vertices based on occlusion status
110
+ visible_vertices_indices = [idx for idx, occluded in enumerate(occlusion_status) if not occluded]
111
+
112
+ # Create a mapping from old vertex indices to new (filtered) vertex indices
113
+ old_to_new_indices_map = {old_idx: new_idx for new_idx, old_idx in enumerate(visible_vertices_indices)}
114
+
115
+ # Filter the projected_gt_vertices_2d and transform to the new structure
116
+ new_wf_vertices = []
117
+ if projected_gt_vertices_2d.shape[0] > 0: # Ensure projected_gt_vertices_2d is not empty
118
+ for idx in visible_vertices_indices:
119
+ xy_coords = projected_gt_vertices_2d[idx]
120
+ new_wf_vertices.append({'xy': xy_coords, 'type': 'apex'})
121
+ wf_vertices = new_wf_vertices
122
+
123
+ # Filter the edges
124
+ # An edge is kept if both its vertices are in the visible_vertices_indices list
125
+ visible_edges = []
126
+ for edge_start, edge_end in wf_edges:
127
+ if edge_start in old_to_new_indices_map and edge_end in old_to_new_indices_map:
128
+ # Remap to new indices
129
+ visible_edges.append((old_to_new_indices_map[edge_start], old_to_new_indices_map[edge_end]))
130
+ wf_edges = visible_edges
131
+ else:
132
+ # If there are no original vertices, wf_vertices should be an empty list
133
+ wf_vertices = []
134
+ wf_edges = []
135
+
136
+ wf_vertices_3d_visible = np.empty((0, 3))
137
+ original_gt_3d_vertices = np.array(entry['wf_vertices'])
138
+
139
+ # Check if there were original vertices and if occlusion_status was computed for them
140
+ if original_gt_3d_vertices.shape[0] > 0 and len(occlusion_status) == original_gt_3d_vertices.shape[0]:
141
+ # Determine indices of visible vertices based on occlusion_status
142
+ # occlusion_status is True if occluded, False otherwise. We want not occluded.
143
+ visible_indices = [idx for idx, occluded_flag in enumerate(occlusion_status) if not occluded_flag]
144
+
145
+ if visible_indices: # If the list of visible_indices is not empty
146
+ wf_vertices_3d_visible = original_gt_3d_vertices[visible_indices]
147
+ # If no original_gt_3d_vertices, or if all are occluded (visible_indices is empty),
148
+ # or if occlusion_status length doesn't match (which implies an issue earlier, but defensively handled),
149
+ # wf_vertices_3d_visible will remain the initialized np.empty((0, 3)).
150
+
151
+ return wf_vertices, wf_edges, wf_vertices_3d_visible
152
+
153
+ def project_vertices_to_3d(uv: np.ndarray, depth_vert: np.ndarray, col_img: pycolmap.Image, K, R, t) -> np.ndarray:
154
+ """
155
+ Projects 2D vertex coordinates with associated depths to 3D world coordinates.
156
+
157
+ Parameters
158
+ ----------
159
+ uv : np.ndarray
160
+ (N, 2) array of 2D vertex coordinates (u, v).
161
+ depth_vert : np.ndarray
162
+ (N,) array of depth values for each vertex.
163
+ col_img : pycolmap.Image
164
+ Matching COLMAP image (kept for reference; the projection below uses K, R, t).
+ K, R, t : np.ndarray
+ Camera intrinsics (3x3), rotation (3x3) and translation (3,) defining the
+ world-to-camera transform.
+
165
+ Returns
166
+ -------
167
+ vertices_3d : np.ndarray
168
+ (N, 3) array of vertex coordinates in 3D world space.
169
+ """
170
+ # Backproject to 3D local camera coordinates
171
+ xy_local = np.ones((len(uv), 3))
172
+ #k = col_img.camera.calibration_matrix()
173
+ k = K
174
+ xy_local[:, 0] = (uv[:, 0] - k[0, 2]) / k[0, 0]
175
+ xy_local[:, 1] = (uv[:, 1] - k[1, 2]) / k[1, 1]
176
+ # Get the 3D vertices
177
+ vertices_3d_local = xy_local * depth_vert[...,None]
178
+
179
+ # Build the world-to-camera transform, then invert it to get camera-to-world
180
+ world_to_cam = np.eye(4)
181
+ world_to_cam[:3, :3] = R
182
+ world_to_cam[:3, 3] = t.reshape(3)
183
+ #world_to_cam[:3] = col_img.cam_from_world.matrix()
184
+ cam_to_world = np.linalg.inv(world_to_cam)
185
+
186
+ # Transform local 3D points to world coordinates
187
+ vertices_3d_homogeneous = cv2.convertPointsToHomogeneous(vertices_3d_local)
188
+ vertices_3d = cv2.transform(vertices_3d_homogeneous, cam_to_world)
189
+ vertices_3d = cv2.convertPointsFromHomogeneous(vertices_3d).reshape(-1, 3)
190
+ return vertices_3d
191
+
192
+ def get_fitted_dense_depth(depth, colmap_rec, img_id, ade20k_seg, K, R, t):
193
+ """
194
+ Gets sparse depth from COLMAP, computes a house mask, fits dense depth to sparse
195
+ depth within the mask, and returns the fitted dense depth.
196
+
197
+ Parameters
198
+ ----------
199
+ depth : np.ndarray
200
+ Initial dense depth map (H, W).
201
+ colmap_rec : pycolmap.Reconstruction
202
+ COLMAP reconstruction data.
203
+ img_id : str
204
+ Identifier for the current image within the COLMAP reconstruction.
205
+ ade20k_seg : PIL.Image
206
+ ADE20k segmentation map for the image.
207
+ K : np.ndarray
208
+ Camera intrinsic matrix (3x3).
209
+ R : np.ndarray
210
+ Camera rotation matrix (3x3).
211
+ t : np.ndarray
212
+ Camera translation vector (3,).
213
+
214
+ Returns
215
+ -------
216
+ depth_fitted : np.ndarray
217
+ Dense depth map scaled and shifted to align with sparse depth within the house mask (H, W).
218
+ depth_sparse : np.ndarray
219
+ The sparse depth map obtained from COLMAP (H, W).
220
+ found_sparse : bool
221
+ True if sparse depth points were found for this image, False otherwise.
222
+ """
223
+ depth_np = np.array(depth) / 1000. # Convert mm to meters if needed
224
+ depth_sparse, found_sparse, col_img = get_sparse_depth_custom(colmap_rec, img_id, depth_np, K, R, t)
225
+ #print(depth_sparse.sum())
226
+ #depth_sparse, found_sparse, col_img = get_sparse_depth(colmap_rec, img_id, depth_np)
227
+
228
+ if not found_sparse:
229
+ print(f'No sparse depth found for image {img_id}')
230
+ # Return original (meter-scaled) depth if no sparse data
231
+ return depth_np, np.zeros_like(depth_np), False, None
232
+
233
+ # Get house mask to focus fitting on relevant areas
234
+ house_mask = get_house_mask(ade20k_seg)
235
+
236
+ # Fit dense depth to sparse depth (scale only), using only points within the house mask
237
+ k, depth_fitted = fit_scale_robust_median(depth_np, depth_sparse, validity_mask=house_mask)
238
+ print(f"Fitted depth scale k={k:.4f} for image {img_id}")
239
+ #depth_fitted = depth_np# * house_mask.astype(np.float32)
240
+ depth_sparse = depth_sparse# * house_mask.astype(np.float32)
241
+ return depth_fitted, depth_sparse, True, col_img
242
+
243
+ def get_sparse_depth_custom(colmap_rec, img_id_substring, depth, K, R, t):
244
+ """
245
+ Return a sparse depth map for the COLMAP image whose name contains
246
+ `img_id_substring`. The output has the same (H, W) shape as `depth`,
247
+ where only the projected 3D points get a depth > 0, else 0.
248
+ Uses provided K, R, t for projection instead of COLMAP's image projection.
249
+ """
250
+ H, W = depth.shape
251
+
252
+ # 1) Find the matching COLMAP image to get its associated 3D points
253
+ # This part remains to identify which 3D points are relevant for this image view
254
+ found_img = None
255
+ for img_id_c, col_img_obj in colmap_rec.images.items(): # Renamed col_img to col_img_obj to avoid conflict
256
+ if img_id_substring in col_img_obj.name:
257
+ found_img = col_img_obj
258
+ break
259
+ if found_img is None:
260
+ print(f"Image substring {img_id_substring} not found in COLMAP.")
261
+ return np.zeros((H, W), dtype=np.float32), False, None
262
+
263
+ # 2) Gather 3D points that this image sees (according to COLMAP)
264
+ points_xyz_world = []
265
+ for pid, p3D in colmap_rec.points3D.items():
266
+ if found_img.has_point3D(pid):
267
+ points_xyz_world.append(p3D.xyz) # world coords
268
+ if not points_xyz_world:
269
+ print(f"No 3D points associated with {found_img.name} in COLMAP.")
270
+ return np.zeros((H, W), dtype=np.float32), False, found_img # Return found_img for consistency
271
+
272
+ points_xyz_world = np.array(points_xyz_world) # (N, 3)
273
+
274
+ # 3) Project points_xyz_world to camera coordinates using R, t
275
+ # points_cam = R @ points_xyz_world.T + t.reshape(3,1)
276
+ # points_cam = points_cam.T (N,3)
277
+ # More robustly:
278
+ points_xyz_world_h = np.hstack((points_xyz_world, np.ones((points_xyz_world.shape[0], 1)))) # (N, 4)
279
+
280
+ # World to Camera transformation matrix
281
+ world_to_cam_mat = np.eye(4)
282
+ world_to_cam_mat[:3, :3] = R
283
+ world_to_cam_mat[:3, 3] = t.flatten()
284
+
285
+ points_cam_h = (world_to_cam_mat @ points_xyz_world_h.T).T # (N, 4)
286
+ points_cam = points_cam_h[:, :3] / points_cam_h[:, 3, np.newaxis] # (N, 3) in camera coordinates
287
+
288
+ uv = []
289
+ z_vals = []
290
+
291
+ for i in range(points_cam.shape[0]):
292
+ p_cam = points_cam[i]
293
+
294
+ # Project to image plane using K
295
+ # p_img_h = K @ p_cam
296
+ # u = p_img_h[0] / p_img_h[2]
297
+ # v = p_img_h[1] / p_img_h[2]
298
+ # z = p_cam[2]
299
+
300
+ # Ensure p_cam[2] (depth) is positive
301
+ if p_cam[2] <= 0: # Point is behind or on the camera plane
302
+ continue
303
+
304
+ # Project to image plane using K
305
+ # K is [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]
306
+ u_i = (K[0, 0] * p_cam[0] / p_cam[2]) + K[0, 2]
307
+ v_i = (K[1, 1] * p_cam[1] / p_cam[2]) + K[1, 2]
308
+
309
+ u_i_int = int(round(u_i))
310
+ v_i_int = int(round(v_i))
311
+
312
+ # Check in-bounds
313
+ if 0 <= u_i_int < W and 0 <= v_i_int < H:
314
+ uv.append((u_i_int, v_i_int))
315
+ z_vals.append(p_cam[2]) # Depth is the Z coordinate in camera space
316
+
317
+ if not uv:
318
+ print(f"No points projected into image bounds for {img_id_substring} using K,R,t.")
319
+ return np.zeros((H, W), dtype=np.float32), False, found_img
320
+
321
+ uv = np.array(uv, dtype=int) # shape (M,2)
322
+ z_vals = np.array(z_vals) # shape (M,)
323
+
324
+ depth_out = np.zeros((H, W), dtype=np.float32)
325
+ # Ensure z_vals are positive before assignment, though already checked
326
+ valid_depth_mask = z_vals > 0
327
+ if np.any(valid_depth_mask):
328
+ depth_out[uv[valid_depth_mask, 1], uv[valid_depth_mask, 0]] = z_vals[valid_depth_mask]
329
+
330
+ return depth_out, True, found_img
331
+
332
+
333
+ def create_3d_wireframe_single_image(vertices: List[dict],
334
+ connections: List[Tuple[int, int]],
335
+ depth: PImage,
336
+ colmap_rec: pycolmap.Reconstruction,
337
+ img_id: str,
338
+ ade_seg: PImage,
339
+ K, R, t) -> np.ndarray:
340
+ """
341
+ Processes a single image view to generate 3D vertex coordinates from existing 2D vertices/edges.
342
+
343
+ Parameters
344
+ ----------
345
+ vertices : List[dict]
346
+ List of 2D vertex dictionaries (e.g., {"xy": (x, y), "type": ...}).
347
+ connections : List[Tuple[int, int]]
348
+ List of 2D edge connections (indices into the vertices list).
349
+ depth : PIL.Image
350
+ Initial dense depth map as a PIL Image.
351
+ colmap_rec : pycolmap.Reconstruction
352
+ COLMAP reconstruction data.
353
+ img_id : str
354
+ Identifier for the current image within the COLMAP reconstruction.
355
+ ade_seg : PIL.Image
356
+ ADE20k segmentation map for the image.
+ K, R, t : np.ndarray
+ Camera intrinsics (3x3), rotation (3x3) and translation (3,) for this view.
357
+
358
+ Returns
359
+ -------
360
+ vertices_3d : np.ndarray
361
+ (N, 3) array of vertex coordinates in 3D world space.
362
+ Returns an empty array if processing fails (e.g., missing sparse depth).
363
+ """
364
+ # Check if initial vertices/connections are valid
365
+ if (len(vertices) < 2) or (len(connections) < 1):
366
+ # This case should ideally be handled before calling, but good to double check.
367
+ print(f'Warning: create_3d_wireframe_single_image called with insufficient vertices/connections for image {img_id}')
368
+ return np.empty((0, 3))
369
+
370
+ # Get fitted dense depth and sparse depth
371
+ depth_fitted, depth_sparse, found_sparse, col_img = get_fitted_dense_depth(
372
+ depth, colmap_rec, img_id, ade_seg, K, R, t
373
+ )
374
+
375
+ # Get UV coordinates and depth for each vertex
376
+ uv, depth_vert = get_uv_depth(vertices, depth_fitted, depth_sparse, 10)
377
+
378
+ # Backproject to 3D
379
+ vertices_3d = project_vertices_to_3d(uv, depth_vert, col_img, K, R, t)
380
+
381
+ return vertices_3d
382
+
383
+
384
  def predict_wireframe(entry) -> Tuple[np.ndarray, List[int]]:
385
  """
386
  Predict 3D wireframe from a dataset entry.
387
  """
388
  good_entry = convert_entry_to_human_readable(entry)
389
+ colmap_rec = good_entry['colmap_binary']
390
+
391
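+ # Reset every COLMAP point to black; points that project into the apex/eave masks are re-colored red per image below, purely for the visualization.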
+ colmap_pcloud = []
392
+ for i, p3D in colmap_rec.points3D.items():
393
+ p3D.color = np.array([0, 0, 0])
394
+ colmap_pcloud.append(p3D)
395
+
396
  vert_edge_per_image = {}
397
  for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
398
  good_entry['depth'],
 
402
  good_entry['image_ids'],
403
  good_entry['ade'] # Added ade20k segmentation
404
  )):
405
+ # Convert camera parameters to numpy arrays
406
  K = np.array(K)
407
  R = np.array(R)
408
  t = np.array(t)
409
+
410
  # Resize gestalt segmentation to match depth map size
411
  depth_size = (np.array(depth).shape[1], np.array(depth).shape[0]) # W, H
412
  gest_seg = gest.resize(depth_size)
413
  gest_seg_np = np.array(gest_seg).astype(np.uint8)
414
 
415
+ pcloud_segmented, pcloud_idxs = extract_segmented_pcloud(gest_seg_np, colmap_rec, img_id, ade_seg, depth, K=K, R=R, t=t)
416
+ for pid, p3D in colmap_rec.points3D.items():
417
+ if pid in pcloud_idxs:
418
+ p3D.color = np.array([255, 0, 0])
419
+
420
  # Get 2D vertices and edges first
421
+ vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th=20.)
422
+
423
+ gt_verts = []
424
+ #gt_verts, gt_connects, gt_verts3d = get_gt_vertices_and_edges(good_entry, i, depth, colmap_rec, K, R, t, img_id, ade_seg)
425
+ #vertices, connections = gt_verts, gt_connects
426
+
427
+ if False:
428
+ gest.save(f'gestalt/{img_id}.png')
429
+ # Save ADE20k segmentation
430
+ # ade_seg is already a PIL Image
431
+ try:
432
+ ade_seg.save(f'ade_segmentations/{img_id}_ade.png')
433
+ except Exception as e:
434
+ print(f"Could not save ADE segmentation for {img_id}: {e}")
435
+ save_gestalt_with_proj(gest_seg_np, gt_verts, img_id)
436
+ # Define a local helper function to draw crosses and save the image
437
+
438
+ # Draw crosses on the ADE segmentation image and save it
439
+ # 'vertices' here refers to gt_verts
440
+ draw_crosses_on_image(ade_seg, vertices, f'crosses_{img_id}.png', color=(0, 0, 0), size=5)
441
 
442
  # Check if we have enough to proceed
443
  if (len(vertices) < 2) or (len(connections) < 1):
 
447
 
448
  # Call the refactored function to get 3D points
449
  vertices_3d = create_3d_wireframe_single_image(
450
+ vertices, connections, depth, colmap_rec, img_id, ade_seg, K, R, t
451
  )
452
+ #vertices_3d = gt_verts3d
453
  # Store original 2D vertices, connections, and computed 3D points
454
  vert_edge_per_image[i] = vertices, connections, vertices_3d
455
+
456
+ # Visualize colored COLMAP point cloud with Open3D
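+ # Note: draw_geometries opens a blocking window; close it to continue processing.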
457
+
458
+ # Create Open3D point cloud from COLMAP reconstruction
459
+ pcd = o3d.geometry.PointCloud()
460
+
461
+ # Extract points and colors
462
+ points = []
463
+ colors = []
464
+ for p3D in colmap_rec.points3D.values():
465
+ points.append(p3D.xyz)
466
+ # Normalize color to [0,1] range for Open3D
467
+ colors.append(p3D.color / 255.0)
468
+
469
+ if points:
470
+ pcd.points = o3d.utility.Vector3dVector(np.array(points))
471
+ pcd.colors = o3d.utility.Vector3dVector(np.array(colors))
472
+
473
+ # Visualize the point cloud
474
+ o3d.visualization.draw_geometries([pcd], window_name="COLMAP Point Cloud")
475
 
476
  # Merge vertices from all images
477
  all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 0.5)
478
  all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d, keep_largest=False)
479
+ all_3d_vertices_clean, connections_3d_clean = prune_too_far(all_3d_vertices_clean, connections_3d_clean, colmap_rec, th = 1.5)
480
 
 
481
  if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1:
482
  print (f'Not enough vertices or connections in the 3D vertices')
483
  return empty_solution()
484
 
485
  return all_3d_vertices_clean, connections_3d_clean
486
+
487
+
488
+ def extract_segmented_pcloud(gest_seg_np, colmap_rec, img_id_substring, ade_seg, depth, K=None, R=None, t=None):
489
+ """
490
+ Build apex and eave-end masks from the gestalt segmentation and return the
491
+ COLMAP 3D points (and their ids) that project inside those masks and the house mask.
492
+ """
493
+ #--------------------------------------------------------------------------------
494
+ # Step A: Collect apex and eave_end vertices
495
+ #--------------------------------------------------------------------------------
496
+ if not isinstance(gest_seg_np, np.ndarray):
497
+ gest_seg_np = np.array(gest_seg_np)
498
+
499
+ # Apex
500
+ apex_color = np.array(gestalt_color_mapping['apex'])
501
+ apex_mask = cv2.inRange(gest_seg_np, apex_color-10, apex_color+10)
502
+
503
+ # Eave end
504
+ eave_end_color = np.array(gestalt_color_mapping['eave_end_point'])
505
+ eave_end_mask = cv2.inRange(gest_seg_np, eave_end_color-10, eave_end_color+10)
506
+
507
+ # Combined mask for apex and eave_end
508
+ combined_mask = cv2.bitwise_or(apex_mask, eave_end_mask)
509
+
510
+ H, W = gest_seg_np.shape[:2]
511
+
512
+ # 1) Find the matching COLMAP image to get its associated 3D points
513
+ # This part remains to identify which 3D points are relevant for this image view
514
+ found_img = None
515
+ for img_id_c, col_img_obj in colmap_rec.images.items(): # Renamed col_img to col_img_obj to avoid conflict
516
+ if img_id_substring in col_img_obj.name:
517
+ found_img = col_img_obj
518
+ break
519
+ if found_img is None:
520
+ print(f"Image substring {img_id_substring} not found in COLMAP.")
521
+ return np.empty((0, 3)), np.empty((0,), dtype=int)  # image not found, nothing to segment
522
+
523
+ # 2) Gather 3D points that this image sees (according to COLMAP)
524
+ points_xyz_world = []
525
+ points_idxs = []
526
+ for pid, p3D in colmap_rec.points3D.items():
527
+ if found_img.has_point3D(pid):
528
+ points_xyz_world.append(p3D.xyz) # world coords
529
+ points_idxs.append(pid)
530
+ if not points_xyz_world:
531
+ print(f"No 3D points associated with {found_img.name} in COLMAP.")
532
+ return np.empty((0, 3)), np.empty((0,), dtype=int)  # no 3D points for this view
533
+
534
+ points_xyz_world = np.array(points_xyz_world) # (N, 3)
535
+ points_idxs = np.array(points_idxs) # (N,)
536
+
537
+ # 3) Project points_xyz_world to camera coordinates using R, t
538
+ # points_cam = R @ points_xyz_world.T + t.reshape(3,1)
539
+ # points_cam = points_cam.T (N,3)
540
+ # More robustly:
541
+ points_xyz_world_h = np.hstack((points_xyz_world, np.ones((points_xyz_world.shape[0], 1)))) # (N, 4)
542
+
543
+ # World to Camera transformation matrix
544
+ world_to_cam_mat = np.eye(4)
545
+ world_to_cam_mat[:3, :3] = R
546
+ world_to_cam_mat[:3, 3] = t.flatten()
547
+
548
+ points_cam_h = (world_to_cam_mat @ points_xyz_world_h.T).T # (N, 4)
549
+ points_cam = points_cam_h[:, :3] / points_cam_h[:, 3, np.newaxis] # (N, 3) in camera coordinates
550
+
551
+ uv = []
552
+ valid_indices = [] # Track which original points are valid
553
+
554
+ for i in range(points_cam.shape[0]):
555
+ p_cam = points_cam[i]
556
+
557
+ # Ensure p_cam[2] (depth) is positive
558
+ if p_cam[2] <= 0:
559
+ continue
560
+
561
+ # Project to image plane using K
562
+ u_i = (K[0, 0] * p_cam[0] / p_cam[2]) + K[0, 2]
563
+ v_i = (K[1, 1] * p_cam[1] / p_cam[2]) + K[1, 2]
564
+
565
+ u_i_int = int(round(u_i))
566
+ v_i_int = int(round(v_i))
567
+
568
+ # Check in-bounds
569
+ if 0 <= u_i_int < W and 0 <= v_i_int < H:
570
+ uv.append((u_i_int, v_i_int))
571
+ valid_indices.append(i) # Store original index
572
+
573
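+ # Second projection pass using COLMAP's own camera model, for comparison with the K/R/t projection above (uv_colmap is not used further below).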
+ uv_colmap = []
574
+ valid_indices_colmap = []
575
+ for i, xyz in enumerate(points_xyz_world):
576
+ proj = found_img.project_point(xyz) # returns (u, v) in image coords or None
577
+ if proj is not None:
578
+ u_i, v_i = proj
579
+ u_i = int(round(u_i))
580
+ v_i = int(round(v_i))
581
+ # Check in-bounds
582
+ if 0 <= u_i < W and 0 <= v_i < H:
583
+ uv_colmap.append((u_i, v_i))
584
+ valid_indices_colmap.append(i) # Store original index
585
+
586
+ if not uv:
587
+ print(f"No points projected into image bounds for {img_id_substring} using K,R,t.")
588
+ return np.empty((0, 3)), np.empty((0,), dtype=int)  # nothing projected into this view
589
+
590
+ house_mask = get_house_mask(ade_seg)
591
+
592
+ uv = np.array(uv, dtype=int)
593
+ valid_indices = np.array(valid_indices)
594
+
595
+ # Filter points that fall within the apex or eave_end masks
596
+ filtered_points_xyz = []
597
+ filtered_point_idxs = []
598
+
599
+ for i, (u, v) in enumerate(uv):
600
+ # Check if this projected point falls within the combined maskvalid_indices
601
+ if combined_mask[v, u] > 0 and house_mask[v, u] > 0:
602
+ original_idx = valid_indices[i] # Get original index
603
+ filtered_points_xyz.append(points_xyz_world[original_idx])
604
+ filtered_point_idxs.append(points_idxs[original_idx])
605
+
606
+ filtered_points_xyz = np.array(filtered_points_xyz) if filtered_points_xyz else np.empty((0, 3))
607
+ filtered_point_idxs = np.array(filtered_point_idxs) if filtered_point_idxs else np.empty((0,))
608
+
609
+ '''
610
+ depth_fitted, depth_sparse, _, col_img = get_fitted_dense_depth(depth, colmap_rec, img_id_substring, ade_seg, K, R, t)
611
+
612
+ # Segment the depth_fitted to get points in apex/eave_end regions
613
+ segmented_points_3d = []
614
+
615
+ # Get coordinates where the combined mask is active
616
+ mask_coords = np.where(combined_mask > 0)
617
+ v_coords, u_coords = mask_coords
618
+
619
+ # Also apply house mask for additional filtering
620
+ house_coords = np.where(house_mask > 0)
621
+ house_v, house_u = house_coords
622
+
623
+ # Find intersection of combined_mask and house_mask
624
+ valid_mask = np.logical_and(combined_mask > 0, house_mask > 0)
625
+ valid_coords = np.where(valid_mask)
626
+ v_valid, u_valid = valid_coords
627
+
628
+ if len(v_valid) > 0:
629
+ # Get depth values at these coordinates
630
+ depth_values = depth_fitted[v_valid, u_valid]
631
+
632
+ # Filter out zero or invalid depth values
633
+ valid_depth_mask = depth_values > 0
634
+ if np.any(valid_depth_mask):
635
+ u_final = u_valid[valid_depth_mask]
636
+ v_final = v_valid[valid_depth_mask]
637
+ depth_final = depth_values[valid_depth_mask]
638
+
639
+ # Create UV coordinates for backprojection
640
+ uv_depth = np.column_stack((u_final, v_final))
641
+
642
+ # Backproject to 3D world coordinates
643
+ segmented_points_3d = project_vertices_to_3d(uv_depth, depth_final, col_img, K, R, t)
644
+ '''
645
+ segmented_points_3d = []
646
+
647
+ # Visualize with the segmented depth points in blue
648
+ pcd_all = o3d.geometry.PointCloud()
649
+ pcd_filtered = o3d.geometry.PointCloud()
650
+ pcd_depth = o3d.geometry.PointCloud()
651
+
652
+ # All points in gray
653
+ all_points = []
654
+ all_colors = []
655
+ for p3D in colmap_rec.points3D.values():
656
+ all_points.append(p3D.xyz)
657
+ all_colors.append([0.5, 0.5, 0.5]) # Gray color
658
+
659
+ if all_points:
660
+ pcd_all.points = o3d.utility.Vector3dVector(np.array(all_points))
661
+ pcd_all.colors = o3d.utility.Vector3dVector(np.array(all_colors))
662
+
663
+ # Filtered COLMAP points in red
664
+ if len(filtered_points_xyz) > 0:
665
+ pcd_filtered.points = o3d.utility.Vector3dVector(filtered_points_xyz)
666
+ pcd_filtered.colors = o3d.utility.Vector3dVector(np.full((len(filtered_points_xyz), 3), [1.0, 0.0, 0.0]))
667
+
668
+ # Segmented depth points in blue
669
+ if len(segmented_points_3d) > 0:
670
+ pcd_depth.points = o3d.utility.Vector3dVector(segmented_points_3d)
671
+ pcd_depth.colors = o3d.utility.Vector3dVector(np.full((len(segmented_points_3d), 3), [0.0, 0.0, 1.0]))
672
+
673
+ # Visualize all point clouds
674
+ geometries = [pcd_all]
675
+ if len(filtered_points_xyz) > 0:
676
+ geometries.append(pcd_filtered)
677
+ if len(segmented_points_3d) > 0:
678
+ geometries.append(pcd_depth)
679
+
680
+ o3d.visualization.draw_geometries(geometries, window_name=f"Combined Point Cloud - {img_id_substring}")
681
+
682
+ return filtered_points_xyz, filtered_point_idxs
train.py CHANGED
@@ -23,16 +23,15 @@ show_visu = False
23
 
24
  idx = 0
25
  for a in ds['train']:
26
- colmap = read_colmap_rec(a['colmap_binary'])
27
-
28
  #plot_all_modalities(a)
29
-
30
  try:
31
  pred_vertices, pred_edges = predict_wireframe(a)
32
  except:
33
  pred_vertices, pred_edges = empty_solution()
34
 
35
  if show_visu:
 
36
  pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
37
  wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
38
  wireframe2 = plot_wireframe_local(None, pred_vertices, pred_edges, None, color='rgb(255, 0, 0)')
 
23
 
24
  idx = 0
25
  for a in ds['train']:
 
 
26
  #plot_all_modalities(a)
27
+ #pred_vertices, pred_edges = predict_wireframe(a)
28
  try:
29
  pred_vertices, pred_edges = predict_wireframe(a)
30
  except:
31
  pred_vertices, pred_edges = empty_solution()
32
 
33
  if show_visu:
34
+ colmap = read_colmap_rec(a['colmap_binary'])
35
  pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
36
  wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
37
  wireframe2 = plot_wireframe_local(None, pred_vertices, pred_edges, None, color='rgb(255, 0, 0)')
visu.py CHANGED
@@ -5,6 +5,7 @@ import pycolmap
5
  import tempfile,zipfile
6
  import io
7
  import open3d as o3d
 
8
 
9
  def _plotly_rgb_to_normalized_o3d_color(color_val) -> list[float]:
10
  """
@@ -28,6 +29,64 @@ def _plotly_rgb_to_normalized_o3d_color(color_val) -> list[float]:
28
  return [c/255.0 for c in color_val]
29
  raise ValueError(f"Unsupported color type for Open3D conversion: {type(color_val)}. Expected string or 3-element tuple/list.")
30
 
31
 
32
  def plot_reconstruction_local(
33
  fig: go.Figure,
 
5
  import tempfile,zipfile
6
  import io
7
  import open3d as o3d
8
+ from PIL import Image, ImageDraw
9
 
10
  def _plotly_rgb_to_normalized_o3d_color(color_val) -> list[float]:
11
  """
 
29
  return [c/255.0 for c in color_val]
30
  raise ValueError(f"Unsupported color type for Open3D conversion: {type(color_val)}. Expected string or 3-element tuple/list.")
31
 
32
+ def draw_crosses_on_image(image_pil, vertices_data, output_file_path, size=5, color=(0, 0, 0)):
33
+ """
34
+ Draws crosses on a PIL Image at specified vertex locations and saves it.
35
+ Args:
36
+ image_pil (PIL.Image.Image): The image to draw on.
37
+ vertices_data (list): List of dictionaries, each with an 'xy' key
38
+ holding [x, y] coordinates.
39
+ output_file_path (str): Path to save the modified image.
40
+ size (int): Size of the cross arms.
41
+ color (tuple): RGB color for the cross.
42
+ """
43
+ # Work on a copy to avoid modifying the original image
44
+ img_to_draw_on = image_pil.copy()
45
+ drawer = ImageDraw.Draw(img_to_draw_on)
46
+
47
+ for vert_info in vertices_data:
48
+ if 'xy' in vert_info:
49
+ x, y = vert_info['xy']
50
+ # Ensure coordinates are integers for drawing
51
+ x_int, y_int = int(round(x)), int(round(y))
52
+
53
+ # Draw horizontal line
54
+ drawer.line([(x_int - size, y_int), (x_int + size, y_int)], fill=color, width=1)
55
+ # Draw vertical line
56
+ drawer.line([(x_int, y_int - size), (x_int, y_int + size)], fill=color, width=1)
57
+
58
+ img_to_draw_on.save(output_file_path)
59
+
60
+ def save_gestalt_with_proj(gest_seg_np, gt_verts, img_id):
61
+ # Convert gest_seg_np (which is a numpy array) to a PIL Image
62
+ # Assuming gest_seg_np is a 2D grayscale or a 3-channel RGB image
63
+ if gest_seg_np.ndim == 2:
64
+ img_to_draw_on = Image.fromarray(gest_seg_np, mode='L')
65
+ elif gest_seg_np.ndim == 3 and gest_seg_np.shape[2] == 3:
66
+ img_to_draw_on = Image.fromarray(gest_seg_np, mode='RGB')
67
+ else:
68
+ # Fallback or error handling if the format is unexpected
69
+ # For simplicity, let's assume it can be converted directly or handle specific cases
70
+ img_to_draw_on = Image.fromarray(gest_seg_np.astype(np.uint8))
71
+
72
+ # Ensure the image is in a mode that allows color drawing (e.g., RGB)
73
+ if img_to_draw_on.mode == 'L':
74
+ img_to_draw_on = img_to_draw_on.convert('RGB')
75
+
76
+ draw = ImageDraw.Draw(img_to_draw_on)
77
+ cross_size = 5 # Size of the cross arms
78
+ cross_color = (0, 0, 0) # Black color for the cross
79
+
80
+ for vert_dict in gt_verts:
81
+ x, y = vert_dict['xy']
82
+ # Draw horizontal line of the cross
83
+ draw.line([(x - cross_size, y), (x + cross_size, y)], fill=cross_color, width=1)
84
+ # Draw vertical line of the cross
85
+ draw.line([(x, y - cross_size), (x, y + cross_size)], fill=cross_color, width=1)
86
+
87
+ # Save the image with drawn crosses
88
+ # You might want to use a different filename or path
89
+ img_to_draw_on.save(f'gestalt_cross/{img_id}.png')
90
 
91
  def plot_reconstruction_local(
92
  fig: go.Figure,