Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify SPC query API #442

Merged
merged 3 commits into from
Sep 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions kaolin/csrc/ops/spc/spc_query.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@ __global__ void spc_query_kernel(

at::Tensor spc_query(
at::Tensor octree,
at::Tensor points,
at::Tensor pyramid,
at::Tensor prefixsum,
at::Tensor query_points,
uint targetLevel) {
Expand Down
2 changes: 0 additions & 2 deletions kaolin/csrc/ops/spc/spc_query.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@

at::Tensor spc_query(
at::Tensor octree,
at::Tensor points,
at::Tensor pyramid,
at::Tensor prefixsum,
at::Tensor query_points,
uint targetLevel);
Expand Down
14 changes: 4 additions & 10 deletions kaolin/ops/spc/spc.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,27 +239,21 @@ def feature_grids_to_spc(feature_grids, masks=None):
coalescent_features = torch.cat(coalescent_features, dim=0)
return octrees, lengths, coalescent_features

def unbatched_query(octree, point_hierarchy, pyramid, exsum, query_points, level):
def unbatched_query(octree, exsum, query_points, level):
r"""Query point indices from the octree.

Given a point hierarchy, this function will efficiently find the corresponding indices of the
points in the points tensor. For each input in query_points, returns a index to the points tensor.
Returns -1 if the point does not exist.

Args:
octree (torch.ByteTensor): The octree, of shape :math:`(\text{num_bytes})`.
point_hierarchy (torch.ShortTensor):
The points hierarchy, of shape :math:`(\text{num_points}, 3)`.
See :ref:`spc_points` for more details.
pyramid (torch.IntTensor): The pyramid info of the point hierarchy,
of shape :math:`(2, \text{max_level} + 2)`.
See :ref:`spc_pyramids` for more details.
exsum (torch.IntTensor): The exclusive sum of the octree bytes,
of shape :math:`(\text{num_bytes} + 1)`.
See :ref:`spc_pyramids` for more details.
query_points (torch.ShortTensor): A collection of query indices,
of shape :math:`(\text{num_query}, 3)`.
level (int): The level of the octree to query from.
"""
return _C.ops.spc.spc_query(octree.contiguous(), point_hierarchy.contiguous(),
pyramid.contiguous(), exsum.contiguous(),
query_points.contiguous(), level)
return _C.ops.spc.spc_query(octree.contiguous(), exsum.contiguous(),
query_points.contiguous(), level).long()
30 changes: 30 additions & 0 deletions tests/python/kaolin/ops/spc/test_spc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from kaolin.rep import Spc

from kaolin.ops.spc import scan_octrees, generate_points, to_dense, feature_grids_to_spc
from kaolin.ops.spc import unbatched_query, unbatched_points_to_octree

from kaolin.utils.testing import FLOAT_TYPES, with_seed, check_tensor

Expand Down Expand Up @@ -158,6 +159,35 @@ def test_generate_points(self, octrees, lengths, max_level):
dim=0).cuda().short()
assert torch.equal(point_hierarchies, expected_point_hierarchies)

class TestQuery:
def test_query(self):
points = torch.tensor(
[[3,2,0],
[3,1,1],
[0,0,0],
[3,3,3]], device='cuda', dtype=torch.short)
octree = unbatched_points_to_octree(points, 2)
length = torch.tensor([len(octree)], dtype=torch.int32)
_, pyramid, prefix = scan_octrees(octree, length)

query_points = torch.tensor(
[[3,2,0],
[3,1,1],
[0,0,0],
[3,3,3],
[2,2,2],
[1,1,1]], device='cuda', dtype=torch.short)

point_hierarchy = generate_points(octree, pyramid, prefix)

results = unbatched_query(octree, prefix, query_points, 2)

expected_results = torch.tensor(
[7,6,5,8,-1,-1], dtype=torch.long, device='cuda')

assert torch.equal(point_hierarchy[results[:-2]], query_points[:-2])
assert torch.equal(expected_results, results)

class TestToDense:
@pytest.mark.parametrize('with_spc_to_dict', [False, True])
def test_simple(self, with_spc_to_dict):
Expand Down