Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton committed Jul 14, 2021
1 parent 51db981 commit 47da576
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 83 deletions.
5 changes: 2 additions & 3 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ def _train_dataloader(self) -> DataLoader:
train_ds: Dataset = self._train_ds() if isinstance(self._train_ds, Callable) else self._train_ds
shuffle: bool = False
collate_fn = self._resolve_collate_fn(train_ds, RunningStage.TRAINING)
sampler = self.sampler
drop_last = False
pin_memory = True

Expand All @@ -292,14 +291,14 @@ def _train_dataloader(self) -> DataLoader:
shuffle=shuffle,
drop_last=drop_last,
collate_fn=collate_fn,
sampler=sampler
sampler=self.sampler
)

return DataLoader(
train_ds,
batch_size=self.batch_size,
shuffle=shuffle,
sampler=sampler,
sampler=self.sampler,
num_workers=self.num_workers,
pin_memory=pin_memory,
drop_last=drop_last,
Expand Down
1 change: 1 addition & 0 deletions flash/core/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ def per_batch_transform(self, batch: Any) -> Any:
def collate(self, samples: Sequence) -> Any:
""" Transform to convert a sequence of samples to a collated batch. """

# the model can provide a custom ``collate_fn``.
collate_fn = self.get_state(CollateFn)
if collate_fn is not None:
return collate_fn.collate_fn(samples)
Expand Down
2 changes: 2 additions & 0 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def __init__(
# TODO: create enum values to define what are the exact states
self._data_pipeline_state: Optional[DataPipelineState] = None

# model own internal state shared with the data pipeline.
self._state: Dict[Type[ProcessState], ProcessState] = {}

# Explicitly set the serializer to call the setter
Expand Down Expand Up @@ -204,6 +205,7 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:

@staticmethod
def apply_filtering(y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function is used to filter some labels or predictions which aren't conform."""
return y, y_hat

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion flash/pointcloud/segmentation/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.core.registry import FlashRegistry
from flash.pointcloud.segmentation.open3d_ml import register_open_3d_ml
from flash.pointcloud.segmentation.open3d_ml.backbones import register_open_3d_ml

POINTCLOUD_SEGMENTATION_BACKBONES = FlashRegistry("backbones")

Expand Down
79 changes: 0 additions & 79 deletions flash/pointcloud/segmentation/open3d_ml/__init__.py
Original file line number Diff line number Diff line change
@@ -1,79 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Callable

import torch
from pytorch_lightning.utilities.cloud_io import load as pl_load

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE

ROOT_URL = "https://storage.googleapis.com/open3d-releases/model-zoo/"


def register_open_3d_ml(register: FlashRegistry):
if _POINTCLOUD_AVAILABLE:
import open3d
import open3d.ml as _ml3d
from open3d.ml.torch.dataloaders import ConcatBatcher, DefaultBatcher
from open3d.ml.torch.models import RandLANet

CONFIG_PATH = os.path.join(os.path.dirname(open3d.__file__), "_ml3d/configs")

def get_collate_fn(model) -> Callable:
batcher_name = model.cfg.batcher
if batcher_name == 'DefaultBatcher':
batcher = DefaultBatcher()
elif batcher_name == 'ConcatBatcher':
batcher = ConcatBatcher(torch, model.__class__.__name__)
else:
batcher = None
return batcher.collate_fn

@register
def randlanet_s3dis(*args, use_fold_5: bool = True, **kwargs) -> RandLANet:
cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_s3dis.yml"))
model = RandLANet(**cfg.model)
if use_fold_5:
weight_url = os.path.join(ROOT_URL, "randlanet_s3dis_area5_202010091333utc.pth")
else:
weight_url = os.path.join(ROOT_URL, "randlanet_s3dis_202010091238.pth")
model.load_state_dict(pl_load(weight_url, map_location='cpu')['model_state_dict'])
return model, 32, get_collate_fn(model)

@register
def randlanet_toronto3d(*args, **kwargs) -> RandLANet:
cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_toronto3d.yml"))
model = RandLANet(**cfg.model)
model.load_state_dict(
pl_load(os.path.join(ROOT_URL, "randlanet_toronto3d_202010091306utc.pth"),
map_location='cpu')['model_state_dict'],
)
return model, 32, get_collate_fn(model)

@register
def randlanet_semantic_kitti(*args, **kwargs) -> RandLANet:
cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_semantickitti.yml"))
model = RandLANet(**cfg.model)
model.load_state_dict(
pl_load(os.path.join(ROOT_URL, "randlanet_semantickitti_202009090354utc.pth"),
map_location='cpu')['model_state_dict'],
)
return model, 32, get_collate_fn(model)

@register
def randlanet(*args, **kwargs) -> RandLANet:
model = RandLANet(*args, **kwargs)
return model, 32, get_collate_fn(model)
79 changes: 79 additions & 0 deletions flash/pointcloud/segmentation/open3d_ml/backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Callable

import torch
from pytorch_lightning.utilities.cloud_io import load as pl_load

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE

ROOT_URL = "https://storage.googleapis.com/open3d-releases/model-zoo/"


def register_open_3d_ml(register: FlashRegistry):
if _POINTCLOUD_AVAILABLE:
import open3d
import open3d.ml as _ml3d
from open3d.ml.torch.dataloaders import ConcatBatcher, DefaultBatcher
from open3d.ml.torch.models import RandLANet

CONFIG_PATH = os.path.join(os.path.dirname(open3d.__file__), "_ml3d/configs")

def get_collate_fn(model) -> Callable:
batcher_name = model.cfg.batcher
if batcher_name == 'DefaultBatcher':
batcher = DefaultBatcher()
elif batcher_name == 'ConcatBatcher':
batcher = ConcatBatcher(torch, model.__class__.__name__)
else:
batcher = None
return batcher.collate_fn

@register
def randlanet_s3dis(*args, use_fold_5: bool = True, **kwargs) -> RandLANet:
cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_s3dis.yml"))
model = RandLANet(**cfg.model)
if use_fold_5:
weight_url = os.path.join(ROOT_URL, "randlanet_s3dis_area5_202010091333utc.pth")
else:
weight_url = os.path.join(ROOT_URL, "randlanet_s3dis_202010091238.pth")
model.load_state_dict(pl_load(weight_url, map_location='cpu')['model_state_dict'])
return model, 32, get_collate_fn(model)

@register
def randlanet_toronto3d(*args, **kwargs) -> RandLANet:
cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_toronto3d.yml"))
model = RandLANet(**cfg.model)
model.load_state_dict(
pl_load(os.path.join(ROOT_URL, "randlanet_toronto3d_202010091306utc.pth"),
map_location='cpu')['model_state_dict'],
)
return model, 32, get_collate_fn(model)

@register
def randlanet_semantic_kitti(*args, **kwargs) -> RandLANet:
cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_semantickitti.yml"))
model = RandLANet(**cfg.model)
model.load_state_dict(
pl_load(os.path.join(ROOT_URL, "randlanet_semantickitti_202009090354utc.pth"),
map_location='cpu')['model_state_dict'],
)
return model, 32, get_collate_fn(model)

@register
def randlanet(*args, **kwargs) -> RandLANet:
model = RandLANet(*args, **kwargs)
return model, 32, get_collate_fn(model)

0 comments on commit 47da576

Please sign in to comment.