|
20 | 20 | from typing import Literal, Tuple
|
21 | 21 |
|
22 | 22 | import torch
|
23 |
| -from jaxtyping import Bool, Float |
| 23 | +from jaxtyping import Bool, Float, Int |
24 | 24 | from torch import Tensor
|
25 | 25 |
|
26 | 26 | from nerfstudio.data.scene_box import OrientedBox
|
27 | 27 |
|
28 | 28 |
|
29 |
def components_from_spherical_harmonics(
    levels: int, directions: Float[Tensor, "*batch 3"]
) -> Float[Tensor, "*batch components"]:
    """
    Returns value for each component of spherical harmonics.

    Args:
        levels: Number of spherical harmonic levels to compute. Must be between 1 and 5 inclusive.
        directions: Direction vectors whose last dimension holds the (x, y, z) components at which
            the basis is evaluated. NOTE(review): the inputs are not normalized here; the l >= 2
            terms appear to be written for unit-length directions -- confirm callers normalize.

    Returns:
        Tensor with the same leading dimensions as ``directions`` and ``levels**2`` entries in the
        last dimension. It is allocated with torch's default dtype on the device of ``directions``.

    Raises:
        AssertionError: If ``levels`` is outside [1, 5] or the last dimension of ``directions`` is not 3.
    """
    # Validate first so that invalid arguments fail before the output buffer is allocated.
    assert 1 <= levels <= 5, f"SH levels must be in [1,5], got {levels}"
    assert directions.shape[-1] == 3, f"Direction input should have three dimensions. Got {directions.shape[-1]}"

    num_components = levels**2
    components = torch.zeros((*directions.shape[:-1], num_components), device=directions.device)

    x = directions[..., 0]
    y = directions[..., 1]
    z = directions[..., 2]

    xx = x**2
    yy = y**2
    zz = z**2

    # l0 (constant term, independent of direction)
    components[..., 0] = 0.28209479177387814

    # l1
    if levels > 1:
        components[..., 1] = 0.4886025119029199 * y
        components[..., 2] = 0.4886025119029199 * z
        components[..., 3] = 0.4886025119029199 * x

    # l2
    if levels > 2:
        components[..., 4] = 1.0925484305920792 * x * y
        components[..., 5] = 1.0925484305920792 * y * z
        components[..., 6] = 0.9461746957575601 * zz - 0.31539156525251999
        components[..., 7] = 1.0925484305920792 * x * z
        components[..., 8] = 0.5462742152960396 * (xx - yy)

    # l3
    if levels > 3:
        components[..., 9] = 0.5900435899266435 * y * (3 * xx - yy)
        components[..., 10] = 2.890611442640554 * x * y * z
        components[..., 11] = 0.4570457994644658 * y * (5 * zz - 1)
        components[..., 12] = 0.3731763325901154 * z * (5 * zz - 3)
        components[..., 13] = 0.4570457994644658 * x * (5 * zz - 1)
        components[..., 14] = 1.445305721320277 * z * (xx - yy)
        components[..., 15] = 0.5900435899266435 * x * (xx - 3 * yy)

    # l4
    if levels > 4:
        components[..., 16] = 2.5033429417967046 * x * y * (xx - yy)
        components[..., 17] = 1.7701307697799304 * y * z * (3 * xx - yy)
        components[..., 18] = 0.9461746957575601 * x * y * (7 * zz - 1)
        components[..., 19] = 0.6690465435572892 * y * z * (7 * zz - 3)
        components[..., 20] = 0.10578554691520431 * (35 * zz * zz - 30 * zz + 3)
        components[..., 21] = 0.6690465435572892 * x * z * (7 * zz - 3)
        components[..., 22] = 0.47308734787878004 * (xx - yy) * (7 * zz - 1)
        components[..., 23] = 1.7701307697799304 * x * z * (xx - 3 * yy)
        components[..., 24] = 0.6258357354491761 * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))

    return components
93 |
| - |
94 |
| - |
95 | 29 | @dataclass
|
96 | 30 | class Gaussians:
|
97 | 31 | """Stores Gaussians
|
@@ -323,7 +257,9 @@ def masked_reduction(
|
323 | 257 |
|
324 | 258 |
|
325 | 259 | def normalized_depth_scale_and_shift(
|
326 |
| - prediction: Float[Tensor, "1 32 mult"], target: Float[Tensor, "1 32 mult"], mask: Bool[Tensor, "1 32 mult"] |
| 260 | + prediction: Float[Tensor, "1 32 mult"], |
| 261 | + target: Float[Tensor, "1 32 mult"], |
| 262 | + mask: Bool[Tensor, "1 32 mult"], |
327 | 263 | ):
|
328 | 264 | """
|
329 | 265 | More info here: https://arxiv.org/pdf/2206.00665.pdf supplementary section A2 Depth Consistency Loss
|
@@ -405,7 +341,10 @@ def _compute_tesselation_weights(v: int) -> Tensor:
|
405 | 341 |
|
406 | 342 |
|
407 | 343 | def _tesselate_geodesic(
|
408 |
| - vertices: Float[Tensor, "N 3"], faces: Float[Tensor, "M 3"], v: int, eps: float = 1e-4 |
| 344 | + vertices: Float[Tensor, "N 3"], |
| 345 | + faces: Float[Tensor, "M 3"], |
| 346 | + v: int, |
| 347 | + eps: float = 1e-4, |
409 | 348 | ) -> Tensor:
|
410 | 349 | """Tesselate the vertices of a geodesic polyhedron.
|
411 | 350 |
|
@@ -518,3 +457,58 @@ def generate_polyhedron_basis(
|
518 | 457 |
|
519 | 458 | basis = verts.flip(-1)
|
520 | 459 | return basis
|
| 460 | + |
| 461 | + |
def random_quat_tensor(N: int) -> Float[Tensor, "*batch 4"]:
    """
    Draw a batch of random quaternions.

    Three independent uniform samples per quaternion are drawn with ``torch.rand`` and combined
    into four components whose squares sum to one.

    Args:
        N: Number of quaternions to generate

    Returns:
        a random quaternion tensor of shape (N, 4)
    """
    # Keep the draw order (u, then v, then w) so seeded runs produce the same values.
    u = torch.rand(N)
    v = torch.rand(N)
    w = torch.rand(N)

    radius_a = torch.sqrt(1 - u)
    radius_b = torch.sqrt(u)
    angle_a = 2 * math.pi * v
    angle_b = 2 * math.pi * w

    parts = [
        radius_a * torch.sin(angle_a),
        radius_a * torch.cos(angle_a),
        radius_b * torch.sin(angle_b),
        radius_b * torch.cos(angle_b),
    ]
    return torch.stack(parts, dim=-1)
| 485 | + |
| 486 | + |
def k_nearest_sklearn(
    x: torch.Tensor, k: int, metric: str = "euclidean"
) -> Tuple[Float[Tensor, "*batch k"], Int[Tensor, "*batch k"]]:
    """
    Find k-nearest neighbors using sklearn's NearestNeighbors.

    The model is fitted on ``x`` and then queried with the same ``x`` for ``k + 1`` neighbors;
    the first column of the result is dropped so that each point is not reported as its own
    neighbor. NOTE(review): this presumes the first returned neighbor is the query point itself,
    which may not hold when ``x`` contains exact duplicates -- confirm for such inputs.

    Args:
        x: input tensor
        k: number of neighbors to find
        metric: metric to use for distance computation

    Returns:
        distances: distances to the k-nearest neighbors, as a float32 tensor
        indices: indices of the k-nearest neighbors, as an int64 tensor

        Both tensors are newly created from NumPy arrays, so they are not placed on the device
        of ``x`` and carry no gradient history.
    """
    # Convert tensor to numpy array. detach() is required because Tensor.numpy() refuses
    # tensors that require grad (e.g. trainable point positions).
    x_np = x.detach().cpu().numpy()

    # Imported lazily so sklearn is only needed when this function is actually called.
    from sklearn.neighbors import NearestNeighbors

    # Build the nearest neighbors model; k + 1 because the query set equals the fitted set.
    nn_model = NearestNeighbors(n_neighbors=k + 1, algorithm="auto", metric=metric).fit(x_np)

    # Find the k-nearest neighbors
    distances, indices = nn_model.kneighbors(x_np)

    # Exclude the point itself from the result and return
    return torch.tensor(distances[:, 1:], dtype=torch.float32), torch.tensor(indices[:, 1:], dtype=torch.int64)
0 commit comments