Spaces:
Runtime error
Runtime error
| import torch | |
| from einops import repeat | |
def sample_farthest_points(pts, k, return_index=False):
    """Pick ``k`` points per batch element by iterative farthest point sampling.

    The first point of each batch element is drawn uniformly at random with
    ``torch.randint`` (so the result depends on the global RNG state). Every
    following point is the one whose Euclidean distance to its nearest
    already-selected point is largest.

    Args:
        pts: tensor of shape ``(b, c, n)``: ``b`` batch elements, ``c``
            coordinates per point, ``n`` points (channels-first layout).
        k: number of points to select per batch element. ``k == 0`` returns
            empty outputs.
        return_index: when True, also return the selected point indices.

    Returns:
        ``farthest_pts`` of shape ``(b, c, k)``, with the dtype and device of
        ``pts``. If ``return_index`` is True, the tuple
        ``(farthest_pts, indexes)`` is returned instead, where ``indexes`` is
        an ``int64`` tensor of shape ``(b, k)`` holding the column of ``pts``
        each selected point came from.

    Note:
        When ``k > n`` the surplus picks repeat already-selected points,
        because by then every remaining distance has dropped to zero.
    """
    b, c, n = pts.shape
    # Size the output from the input's channel count. It used to be a fixed 3,
    # which made the column assignments below fail for any c != 3.
    farthest_pts = torch.zeros((b, c, k), device=pts.device, dtype=pts.dtype)
    indexes = torch.zeros((b, k), device=pts.device, dtype=torch.int64)
    if k == 0:
        # Nothing to select; indexing column 0 below would be out of range.
        return (farthest_pts, indexes) if return_index else farthest_pts

    batch_idx = torch.arange(b, device=pts.device)

    # Random starting point for every batch element.
    index = torch.randint(n, [b], device=pts.device)
    # Advanced indices on dims 0 and 2 (separated by a slice) put the batch
    # axis first, giving the chosen point of every batch element: (b, c).
    point = pts[batch_idx, :, index]
    farthest_pts[:, :, 0] = point
    indexes[:, 0] = index
    # distances[j, m]: distance from point m to its nearest selected point
    # in batch element j, shape (b, n).
    distances = torch.norm(point[:, :, None] - pts, dim=1)

    for i in range(1, k):
        index = torch.argmax(distances, dim=1)
        point = pts[batch_idx, :, index]
        farthest_pts[:, :, i] = point
        indexes[:, i] = index
        distances = torch.min(distances, torch.norm(point[:, :, None] - pts, dim=1))

    if return_index:
        return farthest_pts, indexes
    return farthest_pts
def line_segment_distance(a, b, points, sqrt=True, eps=1e-6):
    """Distance from each query point to the line segment from ``a`` to ``b``.

    ``a`` and ``b`` get a singleton axis inserted before their last
    dimension, so each segment is compared against a whole set of points
    (not against a single point of the same shape as ``a``).

    Args:
        a, b: segment endpoints, shape ``(..., D)``.
        points: query points, broadcastable against ``(..., 1, D)``;
            typically ``(..., P, D)``, i.e. ``P`` points per segment.
        sqrt: if True return ``sqrt(d2 + eps)``; otherwise return the squared
            distance ``d2``.
        eps: floor on the squared segment length, which keeps degenerate
            segments (``a == b``) from dividing by zero. It is also the offset
            added before the square root, which keeps the result and its
            gradient finite at zero distance (the result is then never
            exactly 0). Defaults to the previously hard-coded ``1e-6``.

    Returns:
        The broadcast of ``(..., 1, D)`` with ``points``, minus the last
        (coordinate) axis: ``(..., P)`` for ``points`` of shape ``(..., P, D)``.
    """
    a_exp = a[..., None, :]
    ab = b[..., None, :] - a_exp
    # Projection parameter of each point onto the infinite line through a, b.
    # Clamping it to [0, 1] keeps the closest point on the segment itself.
    t = (points - a_exp).mul(ab).sum(dim=-1, keepdim=True)
    t = t / ab.mul(ab).sum(dim=-1, keepdim=True).clamp(min=eps)
    t = t.clamp(0.0, 1.0)
    # Closest point on the segment for every query point.
    closest = a_exp + t * ab
    diff = closest - points
    distance = diff.mul(diff).sum(dim=-1)
    if sqrt:
        distance = torch.sqrt(distance + eps)
    return distance