-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
bad1d42
commit 03035fb
Showing
61 changed files
with
4,425 additions
and
0 deletions.
There are no files selected for viewing
5 changes: 5 additions & 0 deletions
5
Any2Point_clip_lang_modelnet/cfgs/dataset_configs/ModelNet40.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
NAME: ModelNet | ||
DATA_PATH: ./data/modelnet40_normal_resampled | ||
N_POINTS: 8192 | ||
NUM_CATEGORY: 40 | ||
USE_NORMALS: FALSE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
optimizer : { | ||
type: AdamW, | ||
kwargs: { | ||
lr : 0.0005, | ||
weight_decay : 0.05 | ||
}} | ||
|
||
scheduler: { | ||
type: CosLR, | ||
kwargs: { | ||
epochs: 300, | ||
initial_epochs : 10 | ||
}} | ||
|
||
dataset : { | ||
train : { _base_: cfgs/dataset_configs/ModelNet40.yaml, | ||
others: {subset: 'train'}}, | ||
val : { _base_: cfgs/dataset_configs/ModelNet40.yaml, | ||
others: {subset: 'test'}}, | ||
test : { _base_: cfgs/dataset_configs/ModelNet40.yaml, | ||
others: {subset: 'test'}}} | ||
model : { | ||
NAME: PointTransformer, | ||
trans_dim: 768, | ||
depth: 12, | ||
drop_path_rate: 0.1, | ||
cls_dim: 40, | ||
num_heads: 12, | ||
group_size: 32, | ||
num_group: 64, | ||
encoder_dims: 768, | ||
adapter_dim: 16, | ||
drop_rate_adapter: 0.1, | ||
patchknn: 64, | ||
attn1d_dim: 12 | ||
} | ||
|
||
|
||
npoints: 1024 | ||
total_bs : 32 | ||
step_per_update : 1 | ||
max_epoch : 300 | ||
grad_norm_clip : 10 |
149 changes: 149 additions & 0 deletions
149
Any2Point_clip_lang_modelnet/datasets/ModelNetDataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
''' | ||
@author: Xu Yan | ||
@file: ModelNet.py | ||
@time: 2021/3/19 15:51 | ||
''' | ||
import os | ||
import numpy as np | ||
import warnings | ||
import pickle | ||
|
||
from tqdm import tqdm | ||
from torch.utils.data import Dataset | ||
from .build import DATASETS | ||
from utils.logger import * | ||
import torch | ||
|
||
warnings.filterwarnings('ignore') | ||
|
||
|
||
def pc_normalize(pc): | ||
centroid = np.mean(pc, axis=0) | ||
pc = pc - centroid | ||
m = np.max(np.sqrt(np.sum(pc**2, axis=1))) | ||
pc = pc / m | ||
return pc | ||
|
||
|
||
|
||
def farthest_point_sample(point, npoint): | ||
""" | ||
Input: | ||
xyz: pointcloud data, [N, D] | ||
npoint: number of samples | ||
Return: | ||
centroids: sampled pointcloud index, [npoint, D] | ||
""" | ||
N, D = point.shape | ||
xyz = point[:,:3] | ||
centroids = np.zeros((npoint,)) | ||
distance = np.ones((N,)) * 1e10 | ||
farthest = np.random.randint(0, N) | ||
for i in range(npoint): | ||
centroids[i] = farthest | ||
centroid = xyz[farthest, :] | ||
dist = np.sum((xyz - centroid) ** 2, -1) | ||
mask = dist < distance | ||
distance[mask] = dist[mask] | ||
farthest = np.argmax(distance, -1) | ||
point = point[centroids.astype(np.int32)] | ||
return point | ||
|
||
@DATASETS.register_module() | ||
class ModelNet(Dataset): | ||
def __init__(self, config): | ||
self.root = config.DATA_PATH | ||
self.npoints = config.N_POINTS | ||
self.use_normals = config.USE_NORMALS | ||
self.num_category = config.NUM_CATEGORY | ||
self.process_data = True | ||
self.uniform = True | ||
split = config.subset | ||
self.subset = config.subset | ||
|
||
if self.num_category == 10: | ||
self.catfile = os.path.join(self.root, 'modelnet10_shape_names.txt') | ||
else: | ||
self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt') | ||
|
||
self.cat = [line.rstrip() for line in open(self.catfile)] | ||
self.classes = dict(zip(self.cat, range(len(self.cat)))) | ||
|
||
shape_ids = {} | ||
if self.num_category == 10: | ||
shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_train.txt'))] | ||
shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_test.txt'))] | ||
else: | ||
shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))] | ||
shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))] | ||
|
||
assert (split == 'train' or split == 'test') | ||
shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]] | ||
self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i | ||
in range(len(shape_ids[split]))] | ||
print_log('The size of %s data is %d' % (split, len(self.datapath)), logger = 'ModelNet') | ||
|
||
if self.uniform: | ||
self.save_path = os.path.join(self.root, 'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, split, self.npoints)) | ||
else: | ||
self.save_path = os.path.join(self.root, 'modelnet%d_%s_%dpts.dat' % (self.num_category, split, self.npoints)) | ||
|
||
if self.process_data: | ||
if not os.path.exists(self.save_path): | ||
print_log('Processing data %s (only running in the first time)...' % self.save_path, logger = 'ModelNet') | ||
self.list_of_points = [None] * len(self.datapath) | ||
self.list_of_labels = [None] * len(self.datapath) | ||
|
||
for index in tqdm(range(len(self.datapath)), total=len(self.datapath)): | ||
fn = self.datapath[index] | ||
cls = self.classes[self.datapath[index][0]] | ||
cls = np.array([cls]).astype(np.int32) | ||
point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32) | ||
|
||
if self.uniform: | ||
point_set = farthest_point_sample(point_set, self.npoints) | ||
else: | ||
point_set = point_set[0:self.npoints, :] | ||
|
||
self.list_of_points[index] = point_set | ||
self.list_of_labels[index] = cls | ||
|
||
with open(self.save_path, 'wb') as f: | ||
pickle.dump([self.list_of_points, self.list_of_labels], f) | ||
else: | ||
print_log('Load processed data from %s...' % self.save_path, logger = 'ModelNet') | ||
with open(self.save_path, 'rb') as f: | ||
self.list_of_points, self.list_of_labels = pickle.load(f) | ||
|
||
def __len__(self): | ||
return len(self.datapath) | ||
|
||
def _get_item(self, index): | ||
if self.process_data: | ||
point_set, label = self.list_of_points[index], self.list_of_labels[index] | ||
else: | ||
fn = self.datapath[index] | ||
cls = self.classes[self.datapath[index][0]] | ||
label = np.array([cls]).astype(np.int32) | ||
point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32) | ||
|
||
if self.uniform: | ||
point_set = farthest_point_sample(point_set, self.npoints) | ||
else: | ||
point_set = point_set[0:self.npoints, :] | ||
|
||
point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) | ||
if not self.use_normals: | ||
point_set = point_set[:, 0:3] | ||
|
||
return point_set, label[0] | ||
|
||
|
||
def __getitem__(self, index): | ||
points, label = self._get_item(index) | ||
pt_idxs = np.arange(0, points.shape[0]) # 2048 | ||
if self.subset == 'train': | ||
np.random.shuffle(pt_idxs) | ||
current_points = points[pt_idxs].copy() | ||
current_points = torch.from_numpy(current_points).float() | ||
return 'ModelNet', 'sample', (current_points, label) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .build import build_dataset_from_cfg | ||
import datasets.ModelNetDataset |
Binary file added
BIN
+5.15 KB
Any2Point_clip_lang_modelnet/datasets/__pycache__/ModelNetDataset.cpython-37.pyc
Binary file not shown.
Binary file added
BIN
+261 Bytes
Any2Point_clip_lang_modelnet/datasets/__pycache__/__init__.cpython-37.pyc
Binary file not shown.
Binary file added
BIN
+559 Bytes
Any2Point_clip_lang_modelnet/datasets/__pycache__/build.cpython-37.pyc
Binary file not shown.
Binary file added
BIN
+5.22 KB
Any2Point_clip_lang_modelnet/datasets/__pycache__/data_transforms.cpython-37.pyc
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from utils import registry | ||
|
||
|
||
DATASETS = registry.Registry('dataset') | ||
|
||
|
||
def build_dataset_from_cfg(cfg, default_args = None): | ||
""" | ||
Build a dataset, defined by `dataset_name`. | ||
Args: | ||
cfg (eDICT): | ||
Returns: | ||
Dataset: a constructed dataset specified by dataset_name. | ||
""" | ||
return DATASETS.build(cfg, default_args = default_args) | ||
|
||
|
117 changes: 117 additions & 0 deletions
117
Any2Point_clip_lang_modelnet/datasets/data_transforms.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import numpy as np | ||
import torch | ||
import random | ||
|
||
|
||
class PointcloudRotate(object): | ||
def __call__(self, pc): | ||
bsize = pc.size()[0] | ||
for i in range(bsize): | ||
rotation_angle = np.random.uniform() * 2 * np.pi | ||
cosval = np.cos(rotation_angle) | ||
sinval = np.sin(rotation_angle) | ||
rotation_matrix = np.array([[cosval, 0, sinval], | ||
[0, 1, 0], | ||
[-sinval, 0, cosval]]) | ||
R = torch.from_numpy(rotation_matrix.astype(np.float32)).to(pc.device) | ||
pc[i, :, :] = torch.matmul(pc[i], R) | ||
return pc | ||
|
||
class PointcloudScaleAndTranslate(object): | ||
def __init__(self, scale_low=2. / 3., scale_high=3. / 2., translate_range=0.2): | ||
self.scale_low = scale_low | ||
self.scale_high = scale_high | ||
self.translate_range = translate_range | ||
|
||
def __call__(self, pc): | ||
bsize = pc.size()[0] | ||
for i in range(bsize): | ||
xyz1 = np.random.uniform(low=self.scale_low, high=self.scale_high, size=[3]) | ||
xyz2 = np.random.uniform(low=-self.translate_range, high=self.translate_range, size=[3]) | ||
|
||
pc[i, :, 0:3] = torch.mul(pc[i, :, 0:3], torch.from_numpy(xyz1).float().cuda()) + torch.from_numpy(xyz2).float().cuda() | ||
|
||
return pc | ||
|
||
class PointcloudJitter(object): | ||
def __init__(self, std=0.01, clip=0.05): | ||
self.std, self.clip = std, clip | ||
|
||
def __call__(self, pc): | ||
bsize = pc.size()[0] | ||
for i in range(bsize): | ||
jittered_data = pc.new(pc.size(1), 3).normal_( | ||
mean=0.0, std=self.std | ||
).clamp_(-self.clip, self.clip) | ||
pc[i, :, 0:3] += jittered_data | ||
|
||
return pc | ||
|
||
class PointcloudScale(object): | ||
def __init__(self, scale_low=2. / 3., scale_high=3. / 2.): | ||
self.scale_low = scale_low | ||
self.scale_high = scale_high | ||
|
||
def __call__(self, pc): | ||
bsize = pc.size()[0] | ||
for i in range(bsize): | ||
xyz1 = np.random.uniform(low=self.scale_low, high=self.scale_high, size=[3]) | ||
|
||
pc[i, :, 0:3] = torch.mul(pc[i, :, 0:3], torch.from_numpy(xyz1).float().cuda()) | ||
|
||
return pc | ||
|
||
class PointcloudTranslate(object): | ||
def __init__(self, translate_range=0.2): | ||
self.translate_range = translate_range | ||
|
||
def __call__(self, pc): | ||
bsize = pc.size()[0] | ||
for i in range(bsize): | ||
xyz2 = np.random.uniform(low=-self.translate_range, high=self.translate_range, size=[3]) | ||
|
||
pc[i, :, 0:3] = pc[i, :, 0:3] + torch.from_numpy(xyz2).float().cuda() | ||
|
||
return pc | ||
|
||
|
||
class PointcloudRandomInputDropout(object): | ||
def __init__(self, max_dropout_ratio=0.5): | ||
assert max_dropout_ratio >= 0 and max_dropout_ratio < 1 | ||
self.max_dropout_ratio = max_dropout_ratio | ||
|
||
def __call__(self, pc): | ||
bsize = pc.size()[0] | ||
for i in range(bsize): | ||
dropout_ratio = np.random.random() * self.max_dropout_ratio # 0~0.875 | ||
drop_idx = np.where(np.random.random((pc.size()[1])) <= dropout_ratio)[0] | ||
if len(drop_idx) > 0: | ||
cur_pc = pc[i, :, :] | ||
cur_pc[drop_idx.tolist(), 0:3] = cur_pc[0, 0:3].repeat(len(drop_idx), 1) # set to the first point | ||
pc[i, :, :] = cur_pc | ||
|
||
return pc | ||
|
||
class RandomHorizontalFlip(object): | ||
|
||
|
||
def __init__(self, upright_axis = 'z', is_temporal=False): | ||
""" | ||
upright_axis: axis index among x,y,z, i.e. 2 for z | ||
""" | ||
self.is_temporal = is_temporal | ||
self.D = 4 if is_temporal else 3 | ||
self.upright_axis = {'x': 0, 'y': 1, 'z': 2}[upright_axis.lower()] | ||
# Use the rest of axes for flipping. | ||
self.horz_axes = set(range(self.D)) - set([self.upright_axis]) | ||
|
||
|
||
def __call__(self, coords): | ||
bsize = coords.size()[0] | ||
for i in range(bsize): | ||
if random.random() < 0.95: | ||
for curr_ax in self.horz_axes: | ||
if random.random() < 0.5: | ||
coord_max = torch.max(coords[i, :, curr_ax]) | ||
coords[i, :, curr_ax] = coord_max - coords[i, :, curr_ax] | ||
return coords |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import h5py | ||
import numpy as np | ||
import os | ||
|
||
class IO: | ||
@classmethod | ||
def get(cls, file_path): | ||
_, file_extension = os.path.splitext(file_path) | ||
|
||
if file_extension in ['.npy']: | ||
return cls._read_npy(file_path) | ||
# elif file_extension in ['.pcd']: | ||
# return cls._read_pcd(file_path) | ||
elif file_extension in ['.h5']: | ||
return cls._read_h5(file_path) | ||
elif file_extension in ['.txt']: | ||
return cls._read_txt(file_path) | ||
else: | ||
raise Exception('Unsupported file extension: %s' % file_extension) | ||
|
||
# References: https://github.com/numpy/numpy/blob/master/numpy/lib/format.py | ||
@classmethod | ||
def _read_npy(cls, file_path): | ||
return np.load(file_path) | ||
|
||
# References: https://github.com/dimatura/pypcd/blob/master/pypcd/pypcd.py#L275 | ||
# Support PCD files without compression ONLY! | ||
# @classmethod | ||
# def _read_pcd(cls, file_path): | ||
# pc = open3d.io.read_point_cloud(file_path) | ||
# ptcloud = np.array(pc.points) | ||
# return ptcloud | ||
|
||
@classmethod | ||
def _read_txt(cls, file_path): | ||
return np.loadtxt(file_path) | ||
|
||
@classmethod | ||
def _read_h5(cls, file_path): | ||
f = h5py.File(file_path, 'r') | ||
return f['data'][()] |
Oops, something went wrong.