9 changes: 3 additions & 6 deletions warpconvnet/geometry/coords/grid.py
@@ -272,25 +272,22 @@ def dtype(self):

     def half(self):
         if not self.is_initialized:
-            lazy_init = self._lazy_params.replace(dtype=torch.float16)
             return GridCoords.from_tensor(
-                self.batched_tensor.half(), self.offsets, self.grid_shape, self.bounds, lazy_init
+                self.batched_tensor.half(), self.offsets, self.grid_shape, self.bounds
             )
         return super().half()
 
     def float(self):
         if not self.is_initialized:
-            lazy_init = self._lazy_params.replace(dtype=torch.float32)
             return GridCoords.from_tensor(
-                self.batched_tensor.float(), self.offsets, self.grid_shape, self.bounds, lazy_init
+                self.batched_tensor.float(), self.offsets, self.grid_shape, self.bounds
             )
         return super().float()
 
     def double(self):
         if not self.is_initialized:
-            lazy_init = self._lazy_params.replace(dtype=torch.float64)
             return GridCoords.from_tensor(
-                self.batched_tensor.double(), self.offsets, self.grid_shape, self.bounds, lazy_init
+                self.batched_tensor.double(), self.offsets, self.grid_shape, self.bounds
            )
         return super().double()
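For context, the pattern these dtype helpers follow is: if the coordinates are not yet materialized, convert the backing tensor and rebuild the object through the factory method; otherwise defer to the parent implementation. Below is a minimal, self-contained sketch of that rebuild-on-convert pattern using a toy stand-in class (`ToyCoords` is hypothetical, not the real `GridCoords`):

```python
import torch


class ToyCoords:
    """Toy stand-in used only to illustrate the rebuild-on-convert pattern."""

    def __init__(self, tensor: torch.Tensor, grid_shape: tuple):
        self.tensor = tensor
        self.grid_shape = grid_shape

    @classmethod
    def from_tensor(cls, tensor: torch.Tensor, grid_shape: tuple) -> "ToyCoords":
        return cls(tensor, grid_shape)

    def half(self) -> "ToyCoords":
        # Convert the backing tensor and rebuild; metadata (grid_shape) is carried over unchanged.
        return ToyCoords.from_tensor(self.tensor.half(), self.grid_shape)


coords = ToyCoords(torch.rand(8, 3), grid_shape=(4, 4, 4))
assert coords.half().tensor.dtype == torch.float16
```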

8 changes: 3 additions & 5 deletions warpconvnet/geometry/coords/ops/voxel.py
@@ -53,7 +53,7 @@ def voxel_downsample_csr_mapping(
     batched_points: Float[Tensor, "N 3"],  # noqa: F722,F821
     offsets: Int[Tensor, "B + 1"],  # noqa: F722,F821
     voxel_size: float,
-    unique_method: Literal["torch", "ravel", "morton"] | None = None,
+    unique_method: Literal["torch", "ravel", "morton"] = "torch",
 ) -> Tuple[
     Int[Tensor, "M 3"],  # noqa: F821
     Int[Tensor, "B+1"],  # noqa: F821
@@ -221,10 +221,8 @@ def voxel_downsample_mapping(
     down_batched_points = down_batched_points.int()
 
     # Get the batch index
-    up_bcoords = batch_indexed_coordinates(up_batched_points, up_offsets, return_type="torch")
-    down_bcoords = batch_indexed_coordinates(
-        down_batched_points, down_offsets, return_type="torch"
-    )
+    up_bcoords = batch_indexed_coordinates(up_batched_points, up_offsets)
+    down_bcoords = batch_indexed_coordinates(down_batched_points, down_offsets)
 
     down_table = TorchHashTable.from_keys(down_bcoords)
     # Get the map that maps up_batched_points[up_map] ~= down_batched_points.
3 changes: 1 addition & 2 deletions warpconvnet/geometry/coords/search/torch_discrete.py
@@ -376,8 +376,7 @@ def generate_kernel_map(
    Generate the kernel map for the spatially sparse convolution using TorchHashTable.
 
     in_to_out_stride_ratio: the ratio of the input stride to the output stride. This will be multiplied to output coordinates to find matching input coordinates.
-    method: 'query' directly queries the hash table for each offset point (can be slower for large kernels but flexible).
-        'offset' pre-calculates all kernel offsets and uses a custom kernel to find matches (generally faster).
+    method: 'offset' pre-calculates all kernel offsets and uses a custom kernel to find matches (generally faster).
         'size' uses a specialized kernel for 4D coordinates if applicable, otherwise falls back to 'offset'.
     skip_symmetric_kernel_map: If True, skip symmetric parts of the kernel map for odd-sized kernels.
     """
@@ -89,7 +89,7 @@ def _wmma_implicit_gemm_backward_logic(
     min_dtype = _min_dtype(in_features.dtype, weight.dtype, grad_output.dtype)
     if min_dtype not in [torch.float16, torch.bfloat16]:
         # wmma not supported for data types other than float16 and bfloat16
-        return int(_C.gemm.GemmStatus.kErrorInvalidParameters)
+        return int(_C.gemm.GemmStatus.kErrorInvalidParameters), -1
     _grad_output_detached = grad_output.contiguous().detach().to(dtype=min_dtype)
     _in_features_detached = in_features.contiguous().detach().to(dtype=min_dtype)
     _weight_detached = weight.contiguous().detach().to(dtype=min_dtype)
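The change above appears to align the early return for unsupported dtypes with the arity of the function's normal return value (a status code plus a second value), so callers that unpack the result do not fail on the error path. A generic illustration of the pitfall, with a toy function rather than the real API:

```python
def run_kernel(supported: bool):
    if not supported:
        # The error path must return the same shape as the success path,
        # otherwise `status, elapsed = run_kernel(...)` raises a TypeError on unpacking.
        return -1, -1.0
    return 0, 3.2  # (status, elapsed_ms)


status, elapsed = run_kernel(False)  # unpacks cleanly because both paths return 2-tuples
```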
10 changes: 8 additions & 2 deletions warpconvnet/nn/functional/sparse_conv_depth.py
@@ -26,7 +26,7 @@
     generic_benchmark_update_entry,
     mark_generic_benchmark_cache_dirty,
 )
-from warpconvnet.nn.functional.sparse_conv import _BENCHMARK_NUM_RUNS
+from warpconvnet.nn.functional.sparse_conv.detail.unified import _BENCHMARK_NUM_RUNS
 from warpconvnet.utils.benchmark_cache import SpatiallySparseConvConfig
 from warpconvnet.utils.type_cast import _min_dtype, _max_dtype, _maybe_cast
 from warpconvnet.utils.timer import CUDATimer
@@ -359,7 +359,7 @@ def _run_depthwise_forward_benchmarks(
     warmup_iters = max(warmup_iters, 1)
     benchmark_iters = max(benchmark_iters, 1)
 
-    logger.warn(
+    logger.warning(
         "Using benchmarked depthwise forward algo. Until the algorithm finds the best parameters, forward performance will be slow."
     )
     all_benchmark_results: List[
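Side note on the rename: `Logger.warn` is a deprecated alias of `Logger.warning` in the Python standard library, so switching to `warning` avoids a `DeprecationWarning` without changing behavior. For reference (the logger name here is illustrative):

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("warpconvnet.sparse_conv_depth")  # illustrative logger name

# Preferred spelling; `logger.warn(...)` still works but emits a DeprecationWarning.
logger.warning("Using benchmarked depthwise forward algo; the first runs will be slow.")
```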
@@ -681,6 +681,9 @@ def forward(
         )
         chosen_fwd_algo, chosen_fwd_params, _ = all_fwd_benchmark_results[0]
 
+        if isinstance(chosen_fwd_algo, str):
+            chosen_fwd_algo = SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE(chosen_fwd_algo.lower())
+
         # Step 5: Execute with optimal algorithm and parameters
         if chosen_fwd_algo == SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE.EXPLICIT:
             output_feature_tensor = _explicit_depthwise_forward_logic(
@@ -834,6 +837,9 @@ def backward(ctx, grad_output: Float[Tensor, "M C"]) -> Tuple[
         )
         chosen_bwd_algo, chosen_bwd_params, _ = all_bwd_benchmark_results[0]
 
+        if isinstance(chosen_bwd_algo, str):
+            chosen_bwd_algo = SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE(chosen_bwd_algo.lower())
+
         # Step 5: Execute with optimal algorithm and parameters
         if chosen_bwd_algo == SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE.EXPLICIT:
             grad_in_features, grad_weight = _explicit_depthwise_backward_logic(
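The guards added in both `forward` and `backward` normalize an algorithm choice that may come back as a plain string (for example, from the benchmark cache) into the corresponding enum member before dispatch. Calling an `Enum` class with a value performs a by-value lookup, which is presumably why `.lower()` is applied first. A self-contained sketch with a hypothetical enum standing in for `SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE`:

```python
from enum import Enum


class FwdAlgoMode(Enum):
    # Hypothetical stand-in; the real enum and its values live in warpconvnet.
    EXPLICIT = "explicit"
    IMPLICIT = "implicit"


chosen = "EXPLICIT"  # e.g. an algorithm name read back from a cache as a string
if isinstance(chosen, str):
    chosen = FwdAlgoMode(chosen.lower())  # by-value lookup: "explicit" -> FwdAlgoMode.EXPLICIT

assert chosen is FwdAlgoMode.EXPLICIT
```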
3 changes: 2 additions & 1 deletion warpconvnet/nn/modules/factor_grid.py
@@ -28,6 +28,7 @@
 from warpconvnet.geometry.types.points import Points
 from warpconvnet.nn.functional.encodings import sinusoidal_encoding, get_freqs
 from warpconvnet.nn.modules.base_module import BaseSpatialModule
+from warpconvnet.nn.modules.sequential import Sequential
 from warpconvnet.nn.functional.factor_grid import (
     factor_grid_transform,
     factor_grid_cat,
@@ -513,7 +514,7 @@ def __init__(
         for compressed_spatial_dim, compressed_memory_format in zip(
             compressed_spatial_dims, compressed_memory_formats
         ):
-            block = nn.Sequential(
+            block = Sequential(
                 nn.Conv2d(
                     in_channels * compressed_spatial_dim,
                     out_channels * compressed_spatial_dim,
8 changes: 4 additions & 4 deletions warpconvnet/nn/modules/point_pool.py
@@ -35,7 +35,7 @@ class PointPoolBase(BaseSpatialModule):
         Output geometry type. Defaults to ``"point"``.
     unique_method : {"torch", "ravel", "morton"}, optional
         Method used to find unique voxel indices. Defaults to ``"torch"``.
-    avereage_pooled_coordinates : bool, optional
+    average_pooled_coordinates : bool, optional
         If ``True`` average coordinates of points within each voxel. Defaults to ``False``.
     return_neighbor_search_result : bool, optional
         If ``True`` also return the neighbor search result. Defaults to ``False``.
@@ -48,7 +48,7 @@ def __init__(
         downsample_voxel_size: Optional[float] = None,
         return_type: Literal["point", "sparse"] = "point",
         unique_method: Literal["torch", "ravel", "morton"] = "torch",
-        avereage_pooled_coordinates: bool = False,
+        average_pooled_coordinates: bool = False,
         return_neighbor_search_result: bool = False,
     ):
         super().__init__()
@@ -60,7 +60,7 @@ def __init__(
         self.return_type = return_type
         self.return_neighbor_search_result = return_neighbor_search_result
         self.unique_method = unique_method
-        self.avereage_pooled_coordinates = avereage_pooled_coordinates
+        self.average_pooled_coordinates = average_pooled_coordinates
 
     def forward(self, pc: Points) -> Union[Geometry, Tuple[Geometry, RealSearchResult]]:
         return point_pool(
@@ -71,7 +71,7 @@ def forward(self, pc: Points) -> Union[Geometry, Tuple[Geometry, RealSearchResul
             return_type=self.return_type,
             return_neighbor_search_result=self.return_neighbor_search_result,
             unique_method=self.unique_method,
-            avereage_pooled_coordinates=self.avereage_pooled_coordinates,
+            average_pooled_coordinates=self.average_pooled_coordinates,
         )
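A hedged usage sketch for the renamed keyword, using only the arguments visible in this diff; the full constructor and the `Points` construction may involve parameters not shown here:

```python
from warpconvnet.nn.modules.point_pool import PointPoolBase

pool = PointPoolBase(
    downsample_voxel_size=0.05,
    return_type="sparse",
    unique_method="torch",
    average_pooled_coordinates=True,  # corrected spelling of the former `avereage_*` keyword
    return_neighbor_search_result=False,
)
# out = pool(points)  # `points` would be a warpconvnet Points instance (construction not shown)
```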

