Skip to content

Commit

Permalink
Add sphere collider (#1310)
Browse files Browse the repository at this point in the history
Co-authored-by: David McAllister <[email protected]>
  • Loading branch information
tancik and mcallisterdavid authored Jan 30, 2023
1 parent 91d6116 commit 61dc520
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions nerfstudio/model_components/scene_colliders.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,65 @@ def set_nears_and_fars(self, ray_bundle: RayBundle) -> RayBundle:
return ray_bundle


@torch.jit.script
def _intersect_with_sphere(
rays_o: torch.Tensor, rays_d: torch.Tensor, center: torch.Tensor, radius: float = 1.0, near_plane: float = 0.0
):
a = (rays_d * rays_d).sum(dim=-1, keepdim=True)
b = 2 * (rays_o - center) * rays_d
b = b.sum(dim=-1, keepdim=True)
c = (rays_o - center) * (rays_o - center)
c = c.sum(dim=-1, keepdim=True) - radius**2

# clamp to near plane
nears = (-b - torch.sqrt(torch.square(b) - 4 * a * c)) / (2 * a)
fars = (-b + torch.sqrt(torch.square(b) - 4 * a * c)) / (2 * a)

nears = torch.clamp(nears, min=near_plane)
fars = torch.maximum(fars, nears + 1e-6)

nears = torch.nan_to_num(nears, nan=0.0)
fars = torch.nan_to_num(fars, nan=0.0)

return nears, fars


class SphereCollider(SceneCollider):
"""Module for colliding rays with the scene box to compute near and far values.
Args:
center: center of sphere to intersect [3]
redius: radius of sphere to intersect
near_plane: near plane to clamp to
"""

def __init__(self, center: torch.Tensor, radius: float, near_plane: float = 0.0, **kwargs) -> None:
super().__init__(**kwargs)
self.center = center
self.radius = radius
self.near_plane = near_plane

def set_nears_and_fars(self, ray_bundle: RayBundle) -> RayBundle:
"""Intersects the rays with the scene box and updates the near and far values.
Populates nears and fars fields and returns the ray_bundle.
Args:
ray_bundle: specified ray bundle to operate on
"""
self.center = self.center.to(ray_bundle.origins.device)
near_plane = self.near_plane if self.training else 0
nears, fars = _intersect_with_sphere(
rays_o=ray_bundle.origins,
rays_d=ray_bundle.directions,
center=self.center,
radius=self.radius,
near_plane=near_plane,
)
ray_bundle.nears = nears
ray_bundle.fars = fars
return ray_bundle


class NearFarCollider(SceneCollider):
"""Sets the nears and fars with fixed values.
Expand Down

0 comments on commit 61dc520

Please sign in to comment.