Skip to content

Commit

Permalink
finishing add modelnet (#382)
Browse files Browse the repository at this point in the history
* finishing add modelnet

Signed-off-by: jiehanw <[email protected]>

change test path to absolute

Signed-off-by: jiehanw <[email protected]>

* rebase master and fix docstring

Signed-off-by: jiehanw <[email protected]>

* add kaolin.io.modelnet.res to toctree

Signed-off-by: jiehanw <[email protected]>

* add kaolin.io.modelnet.res to toctree

Signed-off-by: jiehanw <[email protected]>

* add face_color support for modelnet

Signed-off-by: jiehanw <[email protected]>

* change test data path

Signed-off-by: jiehanw <[email protected]>
  • Loading branch information
JerryJiehanWang authored Apr 28, 2021
1 parent 89fb808 commit 8f65c5f
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/modules/kaolin.io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ and :ref:`materials module<kaolin.io.materials>` contains Materials definition t
kaolin.io.render
kaolin.io.shapenet
kaolin.io.usd
kaolin.io.modelnet
1 change: 1 addition & 0 deletions kaolin/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from . import render
from . import shapenet
from . import usd
from . import modelnet
80 changes: 80 additions & 0 deletions kaolin/io/modelnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

from kaolin.io.dataset import KaolinDataset
from kaolin.io.off import import_mesh

import os

class ModelNet(KaolinDataset):
r"""Dataset class for the ModelNet dataset.
The `__getitem__` method will return a `KaolinDatasetItem`, with its data field
containing a namedtuple returned by :func:`kaolin.io.off.import_mesh`.
Args:
root (str): Path to the base directory of the ModelNet dataset.
split (str): Split to load ('train' vs 'test', default: 'train').
categories (list):
List of categories to load. If None is provided,
all categories will be loaded. (default: None).
"""
def __init__(self, root, categories=None, split='train'):
assert split in ['train', 'test'], f'Split must be either train or test ,but got {split}.'

self.root = Path(root)
self.paths = []
self.labels = []

if not os.path.exists(root):
raise ValueError(f'ModelNet was not found at "{root}.')

all_categories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]

# If categories is None, load all categories
if categories is None:
categories = all_categories

for idx, category in enumerate(categories):
assert category in all_categories, f'Object class {category} not in \
list of available classes: {all_categories}'

model_paths = sorted((self.root / category / split.lower()).glob('*'))

self.paths += model_paths
self.labels += [category] * len(model_paths)


self.names = [p.name for p in self.paths]

def __len__(self):
return len(self.paths)

def get_data(self, index):
obj_location = self.paths[index]
mesh = import_mesh(str(obj_location), with_face_colors=True)
return mesh

def get_attributes(self, index):
attributes = {
'name': self.names[index],
'path': self.paths[index],
'label': self.labels[index]
}
return attributes

def get_cache_key(self, index):
return self.names[index]
72 changes: 72 additions & 0 deletions tests/python/kaolin/io/test_modelnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

import os

import pytest
import torch

from kaolin.io.off import return_type
from kaolin.io.modelnet import ModelNet

MODELNET_PATH = '/data/ModelNet'
MODELNET_TEST_CATEGORY_LABELS = ['bathtub']
MODELNET_TEST_CATEGORY_LABELS_2 = ['desk']
MODELNET_TEST_CATEGORY_LABELS_MULTI = ['bathtub', 'desk']

ALL_CATEGORIES = [
MODELNET_TEST_CATEGORY_LABELS,
MODELNET_TEST_CATEGORY_LABELS_2,
MODELNET_TEST_CATEGORY_LABELS_MULTI,
]

# Skip test in a CI environment
@pytest.mark.skipif(os.getenv('CI') == 'true', reason="CI does not have dataset")
@pytest.mark.parametrize('categories', ALL_CATEGORIES)
@pytest.mark.parametrize('split', ['train', 'test'])
@pytest.mark.parametrize('index', [0, -1])
class TestModelNet(object):

@pytest.fixture(autouse=True)
def modelnet_dataset(self, categories, split):
return ModelNet(root=MODELNET_PATH,
categories=categories,
split=split)

def test_basic_getitem(self, modelnet_dataset, index):
assert len(modelnet_dataset) > 0

if index == -1:
index = len(modelnet_dataset) - 1

item = modelnet_dataset[index]
data = item.data
attributes = item.attributes
assert isinstance(data, return_type)
assert isinstance(attributes, dict)

assert isinstance(data.vertices, torch.Tensor)
assert len(data.vertices.shape) == 2
assert data.vertices.shape[1] == 3
assert isinstance(data.faces, torch.Tensor)
assert len(data.faces.shape) == 2

assert isinstance(attributes['name'], str)
assert isinstance(attributes['path'], Path)
assert isinstance(attributes['label'], str)

assert isinstance(data.face_colors, torch.Tensor)

0 comments on commit 8f65c5f

Please sign in to comment.