diff --git a/docs/README.md b/docs/README.md
index afc0c936d..7811a8ddb 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -2,6 +2,10 @@
This guide is for developers who write API documentation. To build the documentation, run
+`pip install -r tools/doc_requirements.txt` to install the dependencies for documentation.
+
+Then, run
+
```make html``` on Linux
```make.bat html``` on Windows
diff --git a/docs/kaolin_ext.py b/docs/kaolin_ext.py
index 5d52180ca..2f6de91a9 100644
--- a/docs/kaolin_ext.py
+++ b/docs/kaolin_ext.py
@@ -1,5 +1,5 @@
# Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
-#
+#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -36,6 +36,7 @@ def run_apidoc(_):
"kaolin/ops/conversions/tetmesh.py",
"kaolin/ops/mesh/check_sign.py",
"kaolin/ops/mesh/mesh.py",
+ "kaolin/ops/mesh/tetmesh.py",
"kaolin/ops/mesh/trianglemesh.py",
"kaolin/ops/spc/spc.py",
"kaolin/ops/spc/convolution.py",
diff --git a/docs/modules/kaolin.metrics.rst b/docs/modules/kaolin.metrics.rst
index 620e8cdaf..76c84b1c1 100644
--- a/docs/modules/kaolin.metrics.rst
+++ b/docs/modules/kaolin.metrics.rst
@@ -7,6 +7,7 @@ Metrics are differentiable operators that can be used to compute loss or accurac
We currently provide an IoU for voxelgrid, sided distance based metrics such as chamfer distance,
point_to_mesh_distance and other simple regularization such as uniform_laplacian_smoothing.
+For tetrahedral mesh, we support the equivolume and AMIPS losses.
.. toctree::
:maxdepth: 2
@@ -16,3 +17,4 @@ point_to_mesh_distance and other simple regularization such as uniform_laplacian
kaolin.metrics.render
kaolin.metrics.trianglemesh
kaolin.metrics.voxelgrid
+ kaolin.metrics.tetmesh
diff --git a/docs/modules/kaolin.metrics.tetmesh.rst b/docs/modules/kaolin.metrics.tetmesh.rst
new file mode 100644
index 000000000..81c9000cc
--- /dev/null
+++ b/docs/modules/kaolin.metrics.tetmesh.rst
@@ -0,0 +1,12 @@
+.. _kaolin.metrics.tetmesh:
+
+kaolin.metrics.tetmesh
+======================
+
+API
+---
+
+.. automodule:: kaolin.metrics.tetmesh
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/modules/kaolin.ops.mesh.rst b/docs/modules/kaolin.ops.mesh.rst
index 68656c93b..85e678e07 100644
--- a/docs/modules/kaolin.ops.mesh.rst
+++ b/docs/modules/kaolin.ops.mesh.rst
@@ -1,7 +1,33 @@
.. _kaolin.ops.mesh:
kaolin.ops.mesh
-===============
+***********************
+
+A mesh is a 3D object representation consisting of a collection of vertices and polygons.
+
+Triangular meshes
+==================
+
+Triangular meshes comprise of a set of triangles that are connected by their common edges or corners. In Kaolin, they are usually represented as a set of two tensors:
+
+* ``vertices``: A :class:`torch.Tensor`, of shape :math:`(\text{batch_size}, \text{num_vertices}, 3)`, contains the vertices coordinates.
+
+* ``faces``: A :class:`torch.LongTensor`, of shape :math:`(\text{batch_size}, \text{num_faces}, 3)`, contains the mesh topology, by listing the vertices index for each face.
+
+Both tensors can be combined using :func:`kaolin.ops.mesh.index_vertices_by_faces`, to form ``face_vertices``, of shape :math:`(\text{batch_size}, \text{num_faces}, 3, 3)`, listing the vertices coordinate for each face.
+
+
+Tetrahedral meshes
+==================
+
+A tetrahedron or triangular pyramid is a polyhedron composed of four triangular faces, six straight edges, and four vertex corners. Tetrahedral meshes inside Kaolin are composed of two tensors:
+
+* ``vertices``: A :class:`torch.Tensor`, of shape :math:`(\text{batch_size}, \text{num_vertices}, 3)`, contains the vertices coordinates.
+
+* ``tet``: A :class:`torch.LongTensor`, of shape :math:`(\text{batch_size}, \text{num_tet}, 4)`, contains the tetrahedral mesh topology, by listing the vertices index for each tetrahedron.
+
+Both tensors can be combined, to form ``tet_vertices``, of shape :math:`(\text{batch_size}, \text{num_tet}, 4, 3)`, listing the tetrahedrons vertices coordinates for each face.
+
API
---
diff --git a/kaolin/metrics/__init__.py b/kaolin/metrics/__init__.py
index 6353fcb4b..b40c88c4b 100644
--- a/kaolin/metrics/__init__.py
+++ b/kaolin/metrics/__init__.py
@@ -2,3 +2,4 @@
from . import trianglemesh
from . import pointcloud
from . import render
+from . import tetmesh
diff --git a/kaolin/metrics/tetmesh.py b/kaolin/metrics/tetmesh.py
new file mode 100644
index 000000000..f250f01b5
--- /dev/null
+++ b/kaolin/metrics/tetmesh.py
@@ -0,0 +1,192 @@
+# Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from kaolin.ops.mesh.tetmesh import _validate_tet_vertices
+
+
+def tetrahedron_volume(tet_vertices):
+ r"""Compute the volume of tetrahedrons.
+
+ Args:
+ tet_vertices (torch.Tensor):
+ Batched tetrahedrons, of shape
+ :math:`(\\text{batch_size}, \\text{num_tetrahedrons}, 4, 3)`.
+ Returns:
+ (torch.Tensor):
+ volume of each tetrahedron in each mesh, of shape
+ :math:`(\\text{batch_size}, \\text{num_tetrahedrons})`.
+
+ Example:
+ >>> tet_vertices = torch.tensor([[[[0.5000, 0.5000, 0.4500],
+ ... [0.4500, 0.5000, 0.5000],
+ ... [0.4750, 0.4500, 0.4500],
+ ... [0.5000, 0.5000, 0.5000]]]])
+ >>> tetrahedron_volume(tet_vertices)
+ tensor([[2.0833e-05]])
+ """
+ _validate_tet_vertices(tet_vertices)
+
+ # split the tensor
+ A, B, C, D = [split.squeeze(2) for split in
+ torch.split(tet_vertices, split_size_or_sections=1, dim=2)]
+
+ # compute the volume of each tetrahedron directly by using V = |(a - d) * ((b - d) x (c - d))| / 6
+ volumes = torch.div(
+ ((A - D) * torch.cross(input=(B - D), other=(C - D), dim=2)).sum(dim=2), 6)
+
+ return volumes
+
+def equivolume(tet_vertices, tetrahedrons_mean=None, pow=4):
+ r"""Compute the EquiVolume loss as devised by *Gao et al.* in `Learning Deformable Tetrahedral Meshes for 3D
+ Reconstruction `_ NeurIPS 2020.
+ See `supplementary material `_ for the definition of the loss function.
+
+ Args:
+ tet_vertices (torch.Tensor):
+ Batched tetrahedrons, of shape
+ :math:`(\\text{batch_size}, \\text{num_tetrahedrons}, 4, 3)`.
+ tetrahedrons_mean (torch.Tensor):
+ Mean volume of all tetrahedrons in a grid, of shape :math:`(1, 1)`.
+ pow (int):
+ Power for the equivolume loss.
+ Increasing power puts more emphasis on the larger tetrahedron deformation.
+ Default: 4.
+
+ Returns:
+ (torch.Tensor):
+ EquiVolume loss for each mesh, of shape :math:`(\\text{batch_size})`.
+
+ Example:
+ >>> tet_vertices = torch.tensor([[[[0.5000, 0.5000, 0.7500],
+ ... [0.4500, 0.8000, 0.6000],
+ ... [0.4750, 0.4500, 0.2500],
+ ... [0.5000, 0.3000, 0.3000]],
+ ... [[0.4750, 0.4500, 0.2500],
+ ... [0.5000, 0.9000, 0.3000],
+ ... [0.4500, 0.4000, 0.9000],
+ ... [0.4500, 0.4500, 0.7000]]],
+ ... [[[0.7000, 0.3000, 0.4500],
+ ... [0.4800, 0.2000, 0.3000],
+ ... [0.9000, 0.4500, 0.4500],
+ ... [0.2000, 0.5000, 0.1000]],
+ ... [[0.3750, 0.4500, 0.2500],
+ ... [0.9000, 0.8000, 0.7000],
+ ... [0.6000, 0.9000, 0.3000],
+ ... [0.5500, 0.3500, 0.9000]]]])
+ >>> equivolume(tet_vertices, pow=4)
+ tensor([[2.2898e-15],
+ [2.9661e-10]])
+ """
+ _validate_tet_vertices(tet_vertices)
+
+ # compute the volume of each tetrahedron
+ volumes = tetrahedron_volume(tet_vertices)
+
+ if tetrahedrons_mean is None:
+ # finding the mean volume of all tetrahedrons in the tetrahedron grid
+ tetrahedrons_mean = torch.mean(volumes, dim=-1, keepdim=True)
+
+ # compute EquiVolume loss
+ equivolume_loss = torch.mean(torch.pow(
+ torch.abs(volumes - tetrahedrons_mean), exponent=pow),
+ dim=-1, keepdim=True)
+
+ return equivolume_loss
+
+
+def amips(tet_vertices, inverse_offset_matrix):
+ r"""Compute the AMIPS (Advanced MIPS) loss as devised by *Fu et al.* in
+ `Computing Locally Injective Mappings by Advanced MIPS. \
+ `_
+ ACM Transactions on Graphics (TOG) - Proceedings of ACM SIGGRAPH 2015.
+
+ The Jacobian can be derived as: :math:`J = (g(x) - g(x_0)) / (x - x_0)`
+
+ Only components where the determinant of the Jacobian is positive, are included in the calculation of AMIPS.
+ This is because the AMIPS Loss is only defined for tetrahedrons whose determinant of the Jacobian is positive.
+
+ Args:
+ tet_vertices (torch.Tensor):
+ Batched tetrahedrons, of shape
+ :math:`(\\text{batch_size}, \\text{num_tetrahedrons}, 4, 3)`.
+ inverse_offset_matrix (torch.LongTensor): The inverse of the offset matrix is of shape
+ :math:`(\\text{batch_size}, \\text{num_tetrahedrons}, 3, 3)`.
+ Refer to :func:`kaolin.ops.mesh.tetmesh.inverse_vertices_offset`.
+ Returns:
+ (torch.Tensor):
+ AMIPS loss for each mesh, of shape :math:`(\\text{batch_size})`.
+
+ Example:
+ >>> tet_vertices = torch.tensor([[[[1.7000, 2.3000, 4.4500],
+ ... [3.4800, 0.2000, 5.3000],
+ ... [4.9000, 9.4500, 6.4500],
+ ... [6.2000, 8.5000, 7.1000]],
+ ... [[-1.3750, 1.4500, 3.2500],
+ ... [4.9000, 1.8000, 2.7000],
+ ... [3.6000, 1.9000, 2.3000],
+ ... [1.5500, 1.3500, 2.9000]]],
+ ... [[[1.7000, 2.3000, 4.4500],
+ ... [3.4800, 0.2000, 5.3000],
+ ... [4.9000, 9.4500, 6.4500],
+ ... [6.2000, 8.5000, 7.1000]],
+ ... [[-1.3750, 1.4500, 3.2500],
+ ... [4.9000, 1.8000, 2.7000],
+ ... [3.6000, 1.9000, 2.3000],
+ ... [1.5500, 1.3500, 2.9000]]]])
+ >>> inverse_offset_matrix = torch.tensor([[[[ -1.1561, -1.1512, -1.9049],
+ ... [1.5138, 1.0108, 3.4302],
+ ... [1.6538, 1.0346, 4.2223]],
+ ... [[ 2.9020, -1.0995, -1.8744],
+ ... [ 1.1554, 1.1519, 1.7780],
+ ... [-0.0766, 1.6350, 1.1064]]],
+ ... [[[-0.9969, 1.4321, -0.3075],
+ ... [-1.3414, 1.5795, -1.6571],
+ ... [-0.1775, -0.4349, 1.1772]],
+ ... [[-1.1077, -1.2441, 1.8037],
+ ... [-0.5722, 0.1755, -2.4364],
+ ... [-0.5263, 1.5765, 1.5607]]]])
+ >>> amips(tet_vertices, inverse_offset_matrix)
+ tensor([[13042.3408],
+ [ 2376.2517]])
+ """
+ _validate_tet_vertices(tet_vertices)
+
+ # split the tensor
+ A, B, C, D = torch.split(tet_vertices, split_size_or_sections=1, dim=2)
+
+ # compute the offset matrix of the tetrahedrons w.r.t. vertex A.
+ offset_matrix = torch.cat([B - A, C - A, D - A], dim=2)
+
+ # compute the Jacobian for each tetrahedron - the Jacobian represents the unique 3D deformation that transforms the
+ # tetrahedron t into a regular tetrahedron.
+ jacobian = torch.matmul(offset_matrix, inverse_offset_matrix)
+
+ # compute determinant of Jacobian
+ j_det = torch.det(jacobian)
+
+ # compute the trace of J * J.T
+ jacobian_squared = torch.matmul(jacobian, torch.transpose(jacobian, -2, -1))
+ trace = torch.diagonal(jacobian_squared, dim1=-2, dim2=-1).sum(-1)
+
+ # compute the determinant of the Jacobian to the 2/3
+ EPS = 1e-10
+ denominator = torch.pow(torch.pow(j_det, 2) + EPS, 1 / 3)
+
+ # compute amips energy for positive tetrahedrons whose determinant of their Jacobian is positive
+ amips_energy = torch.mean(torch.div(trace, denominator) * (j_det >= 0).float(),
+ dim=1, keepdim=True)
+
+ return amips_energy
diff --git a/kaolin/ops/mesh/__init__.py b/kaolin/ops/mesh/__init__.py
index 906c7af10..2136f9d8a 100644
--- a/kaolin/ops/mesh/__init__.py
+++ b/kaolin/ops/mesh/__init__.py
@@ -1,5 +1,6 @@
from .mesh import *
from .trianglemesh import *
from .check_sign import check_sign
+from .tetmesh import *
__all__ = [k for k in locals().keys() if not k.startswith('__')]
diff --git a/kaolin/ops/mesh/tetmesh.py b/kaolin/ops/mesh/tetmesh.py
new file mode 100644
index 000000000..ed682c7b9
--- /dev/null
+++ b/kaolin/ops/mesh/tetmesh.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+def _validate_tet_vertices(tet_vertices):
+ """Helper method to validate the dimensions of the batched tetrahedrons tensor.
+
+ Args:
+ tet_vertices (torch.Tensor):
+ Batched tetrahedrons, of shape
+ :math:`(\text{batch_size}, \text{num_tetrahedrons}, 4, 3)`.
+ """
+ assert tet_vertices.ndim == 4, \
+ f"tetrahedrons has {tetrahedrons.ndim} but must have 4 dimensions."
+ assert tet_vertices.shape[2] == 4, \
+ f"The third dimension of the tetrahedrons must be 4 " \
+ f"but the input has {tetrahedrons.shape[2]}. Each tetrahedron has 4 vertices."
+ assert tet_vertices.shape[3] == 3, \
+ f"The fourth dimension of the tetrahedrons must be 3 " \
+ f"but the input has {tetrahedrons.shape[3]}. Each vertex must have 3 dimensions."
+
+
+def inverse_vertices_offset(tet_vertices):
+ r"""Given tetrahedrons with 4 vertices A, B, C, D. Compute the inverse of the offset matrix w.r.t. vertex A for each
+ tetrahedron. The offset matrix is obtained by the concatenation of `B - A`, `C - A` and `D - A`. The resulting shape
+ of the offset matrix is :math:`(\\text{batch_size}, \\text{num_tetrahedrons}, 3, 3)`. The inverse of the offset matrix
+ is computed by this function.
+
+ Args:
+ tet_vertices (torch.Tensor):
+ Batched tetrahedrons, of shape
+ :math:`(\\text{batch_size}, \\text{num_tetrahedrons}, 4, 3)`.
+
+ Returns:
+ (torch.Tensor):
+ Batched inverse offset matrix, of shape
+ :math:`(\\text{batch_size}, \\text{num_tetrahedrons}, 3, 3)`.
+ Each offset matrix is of shape :math:`(3, 3)`,
+ hence its inverse is also of shape :math:`(3, 3)`.
+
+ Example:
+ >>> tet_vertices = torch.tensor([[[[-0.0500, 0.0000, 0.0500],
+ ... [-0.0250, -0.0500, 0.0000],
+ ... [ 0.0000, 0.0000, 0.0500],
+ ... [0.5000, 0.5000, 0.4500]]]])
+ >>> inverse_vertices_offset(tet_vertices)
+ tensor([[[[ 0.0000, 20.0000, 0.0000],
+ [ 79.9999, -149.9999, 10.0000],
+ [ -99.9999, 159.9998, -10.0000]]]])
+ """
+ _validate_tet_vertices(tet_vertices)
+
+ # split the tensor
+ A, B, C, D = torch.split(tet_vertices, split_size_or_sections=1, dim=2)
+
+ # compute the offset matrix w.r.t. vertex A
+ offset_matrix = torch.cat([B - A, C - A, D - A], dim=2)
+
+ # compute the inverse of the offset matrix
+ inverse_offset_matrix = torch.inverse(offset_matrix)
+
+ return inverse_offset_matrix
diff --git a/kaolin/render/mesh/deftet.py b/kaolin/render/mesh/deftet.py
index f07c784fe..11326a76d 100644
--- a/kaolin/render/mesh/deftet.py
+++ b/kaolin/render/mesh/deftet.py
@@ -264,7 +264,7 @@ def forward(ctx, pixel_coords, render_ranges, face_vertices_z,
sorted_w2 = (sorted_face_idx != -1).float() - (sorted_w0 + sorted_w1)
_idx = sorted_face_idx + 1
_idx = _idx.reshape(batch_size, -1, 1, 1).expand(
- batch_size, pixel_num * knum, 3, feat_dim)
+ batch_size, pixel_num * knum, 3, feat_dim)
selected_features = torch.gather(
torch.nn.functional.pad(face_features, (0, 0, 0, 0, 1, 0), value=0.), 1, _idx).reshape(
batch_size, pixel_num, knum, 3, feat_dim)
diff --git a/tests/python/kaolin/metrics/test_tetmesh.py b/tests/python/kaolin/metrics/test_tetmesh.py
new file mode 100644
index 000000000..e320e9ec8
--- /dev/null
+++ b/tests/python/kaolin/metrics/test_tetmesh.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from kaolin.metrics import tetmesh
+
+
+class TestTetMeshMetrics:
+
+ def test_tetrahedron_volume(self):
+ tetrahedrons = torch.tensor([[[[0.5000, 0.5000, 0.4500],
+ [0.4500, 0.5000, 0.5000],
+ [0.4750, 0.4500, 0.4500],
+ [0.5000, 0.5000, 0.5000]]]])
+ assert torch.allclose(tetmesh.tetrahedron_volume(tetrahedrons), torch.tensor([[-2.0833e-05]]))
+
+ def test_amips(self):
+ tetrahedrons = torch.tensor([[[
+ [1.7000, 2.3000, 4.4500],
+ [3.4800, 0.2000, 5.3000],
+ [4.9000, 9.4500, 6.4500],
+ [6.2000, 8.5000, 7.1000]],
+ [[-1.3750, 1.4500, 3.2500],
+ [4.9000, 1.8000, 2.7000],
+ [3.6000, 1.9000, 2.3000],
+ [1.5500, 1.3500, 2.9000]]],
+ [[[1.7000, 2.3000, 4.4500],
+ [3.4800, 0.2000, 5.3000],
+ [4.9000, 9.4500, 6.4500],
+ [6.2000, 8.5000, 7.1000]],
+ [[-1.3750, 1.4500, 3.2500],
+ [4.9000, 1.8000, 2.7000],
+ [3.6000, 1.9000, 2.3000],
+ [1.5500, 1.3500, 2.9000]]]])
+ inverse_offset_matrix = torch.tensor([[[[-1.1561, -1.1512, -1.9049],
+ [1.5138, 1.0108, 3.4302],
+ [1.6538, 1.0346, 4.2223]],
+ [[2.9020, -1.0995, -1.8744],
+ [1.1554, 1.1519, 1.7780],
+ [-0.0766, 1.6350, 1.1064]]],
+ [[[-0.9969, 1.4321, -0.3075],
+ [-1.3414, 1.5795, -1.6571],
+ [-0.1775, -0.4349, 1.1772]],
+ [[-1.1077, -1.2441, 1.8037],
+ [-0.5722, 0.1755, -2.4364],
+ [-0.5263, 1.5765, 1.5607]]]])
+ torch.allclose(tetmesh.amips(tetrahedrons, inverse_offset_matrix), torch.tensor([[13042.3408], [2376.2517]]))
+
+ def test_equivolume(self):
+ tetrahedrons = torch.tensor([[[[0.5000, 0.5000, 0.7500],
+ [0.4500, 0.8000, 0.6000],
+ [0.4750, 0.4500, 0.2500],
+ [0.5000, 0.3000, 0.3000]],
+ [[0.4750, 0.4500, 0.2500],
+ [0.5000, 0.9000, 0.3000],
+ [0.4500, 0.4000, 0.9000],
+ [0.4500, 0.4500, 0.7000]]],
+ [[[0.7000, 0.3000, 0.4500],
+ [0.4800, 0.2000, 0.3000],
+ [0.9000, 0.4500, 0.4500],
+ [0.2000, 0.5000, 0.1000]],
+ [[0.3750, 0.4500, 0.2500],
+ [0.9000, 0.8000, 0.7000],
+ [0.6000, 0.9000, 0.3000],
+ [0.5500, 0.3500, 0.9000]]]])
+ assert torch.allclose(tetmesh.equivolume(tetrahedrons, pow=4), torch.tensor([[2.2898e-15], [2.9661e-10]]))
diff --git a/tests/python/kaolin/ops/test_mesh.py b/tests/python/kaolin/ops/mesh/test_mesh.py
similarity index 99%
rename from tests/python/kaolin/ops/test_mesh.py
rename to tests/python/kaolin/ops/mesh/test_mesh.py
index 970307d98..bc89c48ac 100644
--- a/tests/python/kaolin/ops/test_mesh.py
+++ b/tests/python/kaolin/ops/mesh/test_mesh.py
@@ -25,7 +25,8 @@
from kaolin.ops.mesh.trianglemesh import _unbatched_subdivide_vertices
from kaolin.io import obj
-ROOT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../../samples/')
+ROOT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)),
+ os.pardir, os.pardir, os.pardir, os.pardir, 'samples/')
@pytest.mark.parametrize("device,dtype", FLOAT_TYPES)
class TestFaceAreas:
diff --git a/tests/python/kaolin/ops/mesh/test_tetmesh.py b/tests/python/kaolin/ops/mesh/test_tetmesh.py
new file mode 100644
index 000000000..34bf56985
--- /dev/null
+++ b/tests/python/kaolin/ops/mesh/test_tetmesh.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+
+from kaolin.ops.mesh import tetmesh
+
+
+class TestTetMeshOps:
+
+ def test_validate_tetrahedrons_wrong_ndim(self):
+ wrong_ndim_tet = torch.randn(size=(2, 2))
+ with pytest.raises(Exception):
+ tetmesh._validate_tetrahedrons(wrong_ndim_tet)
+
+ def test_validate_tetrahedrons_wrong_third_dimension(self):
+ wrong_third_dim_tet = torch.randn(size=(2, 2, 3))
+ with pytest.raises(Exception):
+ tetmesh._validate_tetrahedrons(wrong_third_dim_tet)
+
+ def test_validate_tetrahedrons_wrong_fourth_dimension(self):
+ wrong_fourth_dim_tet = torch.randn(size=(2, 2, 4, 2))
+ with pytest.raises(Exception):
+ tetmesh._validate_tetrahedrons(wrong_fourth_dim_tet)
+
+ def test_inverse_vertices_offset(self):
+ tetrahedrons = torch.tensor([[[[-0.0500, 0.0000, 0.0500],
+ [-0.0250, -0.0500, 0.0000],
+ [0.0000, 0.0000, 0.0500],
+ [0.5000, 0.5000, 0.4500]]]])
+ oracle = torch.tensor([[[[0.0000, 20.0000, 0.0000],
+ [79.9999, -149.9999, 10.0000],
+ [-99.9999, 159.9998, -10.0000]]]])
+ torch.allclose(tetmesh.inverse_vertices_offset(tetrahedrons), oracle)