Skip to content

Commit

Permalink
Replace rays.py with full raygen functionality
Browse files Browse the repository at this point in the history
Fix issues with custom resolution pixel grid

Fix camera raygen recipe

MR fixes round 1

Fix broken test

Fix broken test 2

Add raygen tests

Fix docs

Signed-off-by: operel <[email protected]>
  • Loading branch information
operel committed Jul 22, 2024
1 parent 0abf080 commit 0b0eeaf
Show file tree
Hide file tree
Showing 15 changed files with 791 additions and 161 deletions.
Binary file added docs/img/camera_raygen_grid.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions docs/kaolin_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def run_apidoc(_):
"kaolin/render/camera/intrinsics_pinhole.py",
"kaolin/render/camera/intrinsics.py",
"kaolin/render/camera/legacy.py",
"kaolin/render/camera/raygen.py",
"kaolin/non_commercial/flexicubes/",
"kaolin/non_commercial/flexicubes/flexicubes.py",
"kaolin/non_commercial/flexicubes/tables.py"
Expand Down
1 change: 0 additions & 1 deletion docs/modules/kaolin.render.camera.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,3 @@ Functions
:undoc-members:
:show-inheritance:


4 changes: 3 additions & 1 deletion docs/modules/kaolin.render.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
.. _kaolin.render:

This submodule contains all the features related to rendering, including differentiable rasterizer, raytracer, lighting, materials and our Camera API.


kaolin.render
=============

Expand All @@ -13,4 +16,3 @@ kaolin.render
kaolin.render.materials
kaolin.render.mesh
kaolin.render.spc
kaolin.render.rays
14 changes: 14 additions & 0 deletions docs/notes/differentiable_camera.rst
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,20 @@ CameraIntrinsics class
currently there are two subclasses of intrinsics: :class:`kaolin.render.camera.OrthographicIntrinsics` and
:class:`kaolin.render.camera.PinholeIntrinsics`.

Ray Generation
======================
Starting with kaolin 0.16.0, :class:`kaolin.render.camera.Camera` supports ray generation of pinhole and
orthographic cameras via
:func:`generate_rays() <Camera.generate_rays()>`.

The full functional api is included in the :ref:`raygen.py module<kaolin.render.camera.raygen>`,
and allows for lower level operations such as :func:`generate_centered_custom_resolution_pixel_coords()`
for creating a custom pixel-grid to guide ray-generation.
This is useful for supporting ray-tracing with lower resolution image planes,
or implementing more advanced effects like ray-jittering.

.. image:: ../img/camera_raygen_grid.png

API Documentation:
------------------

Expand Down
82 changes: 36 additions & 46 deletions examples/recipes/camera/camera_ray_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,46 +5,9 @@

import torch
import numpy as np
from typing import Tuple
from kaolin.render.camera import Camera, CameraFOV

def generate_pixel_grid(res_x=None, res_y=None, device='cuda'):
h_coords = torch.arange(res_x, device=device)
w_coords = torch.arange(res_y, device=device)
pixel_y, pixel_x = torch.meshgrid(h_coords, w_coords)
pixel_x = pixel_x + 0.5
pixel_y = pixel_y + 0.5
return pixel_y, pixel_x


def generate_perspective_rays(camera: Camera, pixel_grid: Tuple[torch.Tensor, torch.Tensor]):
# coords_grid should remain immutable (a new tensor is implicitly created here)
pixel_y, pixel_x = pixel_grid
pixel_x = pixel_x.to(camera.device, camera.dtype)
pixel_y = pixel_y.to(camera.device, camera.dtype)

# Account for principal point offset from canvas center
pixel_x = pixel_x - camera.x0
pixel_y = pixel_y + camera.y0

# pixel values are now in range [-1, 1], both tensors are of shape res_y x res_x
pixel_x = 2 * (pixel_x / camera.width) - 1.0
pixel_y = 2 * (pixel_y / camera.height) - 1.0

ray_dir = torch.stack((pixel_x * camera.tan_half_fov(CameraFOV.HORIZONTAL),
-pixel_y * camera.tan_half_fov(CameraFOV.VERTICAL),
-torch.ones_like(pixel_x)), dim=-1)

ray_dir = ray_dir.reshape(-1, 3) # Flatten grid rays to 1D array
ray_orig = torch.zeros_like(ray_dir)

# Transform from camera to world coordinates
ray_orig, ray_dir = camera.extrinsics.inv_transform_rays(ray_orig, ray_dir)
ray_dir /= torch.linalg.norm(ray_dir, dim=-1, keepdim=True)
ray_orig, ray_dir = ray_orig[0], ray_dir[0] # Assume a single camera

return ray_orig, ray_dir, camera.near, camera.far

from kaolin.render.camera import Camera, \
generate_rays, generate_pinhole_rays, \
generate_centered_pixel_coords, generate_centered_custom_resolution_pixel_coords

camera = Camera.from_args(
eye=torch.tensor([4.0, 4.0, 4.0]),
Expand All @@ -58,14 +21,41 @@ def generate_perspective_rays(camera: Camera, pixel_grid: Tuple[torch.Tensor, to
device='cuda'
)

pixel_grid = generate_pixel_grid(200, 200)
ray_orig, ray_dir, near, far = generate_perspective_rays(camera, pixel_grid)
# General raygen functiontional version -- will invoke raygen according to the camera lens type
ray_orig, ray_dir = generate_rays(camera)
print(f'Created a ray grid of dimensions: {ray_orig.shape}')
print('Ray origins:')
print(ray_orig)
print('Ray directions:')
print(ray_dir)
print('\n')

# General raygen function OOP version -- can also be invoked directly on the camera object
ray_orig, ray_dir = camera.generate_rays()
print(f'Created a ray grid of dimensions: {ray_orig.shape}')
print('Ray origins:')
print(ray_orig)
print('Ray directions:')
print(ray_dir)
print('\n')

# A specific raygen function can also be invoked directly. You may also add your own custom raygen functions that way
ray_orig, ray_dir = generate_pinhole_rays(camera)
print(f'Created a ray grid of dimensions: {ray_orig.shape}')
print('Ray origins:')
print(ray_orig)
print('Ray directions:')
print(ray_dir)
print('\n')

# By using a custom grid input, other effects like lower resolution images can be supported
height = 200
width = 400
pixel_grid = generate_centered_custom_resolution_pixel_coords(camera.width, camera.height, width, height, camera.device)
ray_orig, ray_dir = generate_pinhole_rays(camera, pixel_grid)
print(f'Created a ray grid of different dimensions from camera image plane resolution: {ray_orig.shape}')
print('Ray origins:')
print(ray_orig)
print('Ray directions:')
print(ray_dir)
print('Near clipping plane:')
print(near)
print('Far clipping plane:')
print(far)
print('\n')
8 changes: 5 additions & 3 deletions examples/tutorial/sg_specular_lighting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"import math\n",
"from matplotlib import pyplot as plt\n",
"\n",
"from kaolin.render.rays import generate_pinhole_rays_dir \n",
"from kaolin.render.camera import generate_pinhole_rays \n",
"\n",
"def disp_imgs(imgs, title=None):\n",
" \"\"\"scatter images plotting\"\"\"\n",
Expand Down Expand Up @@ -170,7 +170,9 @@
"# Compute the rays\n",
"rays_d = []\n",
"for cam in cams:\n",
" rays_d.append(generate_pinhole_rays_dir(cam).squeeze(0))\n",
" _, per_cam_to_ray_d = generate_pinhole_rays(cam)\n",
" per_cam_to_ray_d = per_cam_to_ray_d.reshape(cam.height, cam.width, 3)\n",
" rays_d.append(per_cam_to_ray_d)\n",
"# Rays must be toward the camera\n",
"rays_d = -torch.stack(rays_d, dim=0)"
]
Expand Down Expand Up @@ -426,4 +428,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
1 change: 0 additions & 1 deletion kaolin/render/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,4 @@
from . import lighting
from . import materials
from . import mesh
from . import rays
from . import spc
1 change: 1 addition & 0 deletions kaolin/render/camera/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
from .intrinsics_ortho import *
from .coordinates import *
from .legacy import *
from .raygen import *

__all__ = [k for k in locals().keys() if not k.startswith('__')]
23 changes: 22 additions & 1 deletion kaolin/render/camera/camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
from copy import deepcopy
import torch
import inspect
from typing import Sequence, List, Dict, Union, Tuple, Type, FrozenSet, Callable
from typing import Sequence, List, Dict, Union, Tuple, Type, FrozenSet, Callable, Optional
from torch.types import _float, _bool
from .extrinsics import CameraExtrinsics, ExtrinsicsParamsDefEnum
from .intrinsics import CameraIntrinsics, IntrinsicsParamsDefEnum
from .intrinsics_ortho import OrthographicIntrinsics
from .intrinsics_pinhole import PinholeIntrinsics


__all__ = [
'Camera',
'allclose'
Expand Down Expand Up @@ -543,6 +544,26 @@ def view_projection_matrix(self):
projection = self.intrinsics.projection_matrix()
return torch.bmm(projection, view)

def generate_rays(self, coords_grid: Optional[torch.Tensor] = None):
r"""Default ray generation function for kaolin cameras.
The camera lens type will determine the exact raygen logic that runs (i.e. pinhole, ortho..)
Args:
camera (kaolin.render.camera): The camera class.
coords_grid (optional, torch.FloatTensor):
Pixel grid of ray-intersecting coordinates of shape :math:`(\text{H, W, 2})`.
Coordinates integer parts represent the pixel (i,j) coords, and the fraction part of [0,1]
represents the location within the pixel itself.
For example, a coordinate of (0.5, 0.5) represents the center of the top-left pixel.
Returns:
(torch.FloatTensor, torch.FloatTensor):
The generated camera rays according to the camera lens type, as ray origins and ray direction tensors of
:math:`(\text{HxW, 3})`.
"""
from .raygen import generate_rays as raygen
return raygen(self, coords_grid)

@classmethod
def cat(cls, cameras: Sequence[Camera]):
"""Concatenate multiple Camera's.
Expand Down
Loading

0 comments on commit 0b0eeaf

Please sign in to comment.