diff --git a/warpconvnet/geometry/coords/grid.py b/warpconvnet/geometry/coords/grid.py index b1fe9ef..0cdb33e 100644 --- a/warpconvnet/geometry/coords/grid.py +++ b/warpconvnet/geometry/coords/grid.py @@ -272,25 +272,22 @@ def dtype(self): def half(self): if not self.is_initialized: - lazy_init = self._lazy_params.replace(dtype=torch.float16) return GridCoords.from_tensor( - self.batched_tensor.half(), self.offsets, self.grid_shape, self.bounds, lazy_init + self.batched_tensor.half(), self.offsets, self.grid_shape, self.bounds ) return super().half() def float(self): if not self.is_initialized: - lazy_init = self._lazy_params.replace(dtype=torch.float32) return GridCoords.from_tensor( - self.batched_tensor.float(), self.offsets, self.grid_shape, self.bounds, lazy_init + self.batched_tensor.float(), self.offsets, self.grid_shape, self.bounds ) return super().float() def double(self): if not self.is_initialized: - lazy_init = self._lazy_params.replace(dtype=torch.float64) return GridCoords.from_tensor( - self.batched_tensor.double(), self.offsets, self.grid_shape, self.bounds, lazy_init + self.batched_tensor.double(), self.offsets, self.grid_shape, self.bounds ) return super().double() diff --git a/warpconvnet/geometry/coords/ops/voxel.py b/warpconvnet/geometry/coords/ops/voxel.py index 6617023..d1c0188 100644 --- a/warpconvnet/geometry/coords/ops/voxel.py +++ b/warpconvnet/geometry/coords/ops/voxel.py @@ -53,7 +53,7 @@ def voxel_downsample_csr_mapping( batched_points: Float[Tensor, "N 3"], # noqa: F722,F821 offsets: Int[Tensor, "B + 1"], # noqa: F722,F821 voxel_size: float, - unique_method: Literal["torch", "ravel", "morton"] | None = None, + unique_method: Literal["torch", "ravel", "morton"] = "torch", ) -> Tuple[ Int[Tensor, "M 3"], # noqa: F821 Int[Tensor, "B+1"], # noqa: F821 @@ -221,10 +221,8 @@ def voxel_downsample_mapping( down_batched_points = down_batched_points.int() # Get the batch index - up_bcoords = batch_indexed_coordinates(up_batched_points, up_offsets, return_type="torch") - down_bcoords = batch_indexed_coordinates( - down_batched_points, down_offsets, return_type="torch" - ) + up_bcoords = batch_indexed_coordinates(up_batched_points, up_offsets) + down_bcoords = batch_indexed_coordinates(down_batched_points, down_offsets) down_table = TorchHashTable.from_keys(down_bcoords) # Get the map that maps up_batched_points[up_map] ~= down_batched_points. diff --git a/warpconvnet/geometry/coords/search/torch_discrete.py b/warpconvnet/geometry/coords/search/torch_discrete.py index 780f9a6..7f242bd 100644 --- a/warpconvnet/geometry/coords/search/torch_discrete.py +++ b/warpconvnet/geometry/coords/search/torch_discrete.py @@ -376,8 +376,7 @@ def generate_kernel_map( Generate the kernel map for the spatially sparse convolution using TorchHashTable. in_to_out_stride_ratio: the ratio of the input stride to the output stride. This will be multiplied to output coordinates to find matching input coordinates. - method: 'query' directly queries the hash table for each offset point (can be slower for large kernels but flexible). - 'offset' pre-calculates all kernel offsets and uses a custom kernel to find matches (generally faster). + method: 'offset' pre-calculates all kernel offsets and uses a custom kernel to find matches (generally faster). 'size' uses a specialized kernel for 4D coordinates if applicable, otherwise falls back to 'offset'. skip_symmetric_kernel_map: If True, skip symmetric parts of the kernel map for odd-sized kernels. """ diff --git a/warpconvnet/nn/functional/sparse_conv/detail/implicit_wmma.py b/warpconvnet/nn/functional/sparse_conv/detail/implicit_wmma.py index eb7f932..4b4ac49 100644 --- a/warpconvnet/nn/functional/sparse_conv/detail/implicit_wmma.py +++ b/warpconvnet/nn/functional/sparse_conv/detail/implicit_wmma.py @@ -89,7 +89,7 @@ def _wmma_implicit_gemm_backward_logic( min_dtype = _min_dtype(in_features.dtype, weight.dtype, grad_output.dtype) if min_dtype not in [torch.float16, torch.bfloat16]: # wmma not supported for data types other than float16 and bfloat16 - return int(_C.gemm.GemmStatus.kErrorInvalidParameters) + return int(_C.gemm.GemmStatus.kErrorInvalidParameters), -1 _grad_output_detached = grad_output.contiguous().detach().to(dtype=min_dtype) _in_features_detached = in_features.contiguous().detach().to(dtype=min_dtype) _weight_detached = weight.contiguous().detach().to(dtype=min_dtype) diff --git a/warpconvnet/nn/functional/sparse_conv_depth.py b/warpconvnet/nn/functional/sparse_conv_depth.py index 69285fa..5897b03 100644 --- a/warpconvnet/nn/functional/sparse_conv_depth.py +++ b/warpconvnet/nn/functional/sparse_conv_depth.py @@ -26,7 +26,7 @@ generic_benchmark_update_entry, mark_generic_benchmark_cache_dirty, ) -from warpconvnet.nn.functional.sparse_conv import _BENCHMARK_NUM_RUNS +from warpconvnet.nn.functional.sparse_conv.detail.unified import _BENCHMARK_NUM_RUNS from warpconvnet.utils.benchmark_cache import SpatiallySparseConvConfig from warpconvnet.utils.type_cast import _min_dtype, _max_dtype, _maybe_cast from warpconvnet.utils.timer import CUDATimer @@ -359,7 +359,7 @@ def _run_depthwise_forward_benchmarks( warmup_iters = max(warmup_iters, 1) benchmark_iters = max(benchmark_iters, 1) - logger.warn( + logger.warning( "Using benchmarked depthwise forward algo. Until the algorithm finds the best parameters, forward performance will be slow." ) all_benchmark_results: List[ @@ -681,6 +681,9 @@ def forward( ) chosen_fwd_algo, chosen_fwd_params, _ = all_fwd_benchmark_results[0] + if isinstance(chosen_fwd_algo, str): + chosen_fwd_algo = SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE(chosen_fwd_algo.lower()) + # Step 5: Execute with optimal algorithm and parameters if chosen_fwd_algo == SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE.EXPLICIT: output_feature_tensor = _explicit_depthwise_forward_logic( @@ -834,6 +837,9 @@ def backward(ctx, grad_output: Float[Tensor, "M C"]) -> Tuple[ ) chosen_bwd_algo, chosen_bwd_params, _ = all_bwd_benchmark_results[0] + if isinstance(chosen_bwd_algo, str): + chosen_bwd_algo = SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE(chosen_bwd_algo.lower()) + # Step 5: Execute with optimal algorithm and parameters if chosen_bwd_algo == SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE.EXPLICIT: grad_in_features, grad_weight = _explicit_depthwise_backward_logic( diff --git a/warpconvnet/nn/modules/factor_grid.py b/warpconvnet/nn/modules/factor_grid.py index da5622c..3adc39a 100644 --- a/warpconvnet/nn/modules/factor_grid.py +++ b/warpconvnet/nn/modules/factor_grid.py @@ -28,6 +28,7 @@ from warpconvnet.geometry.types.points import Points from warpconvnet.nn.functional.encodings import sinusoidal_encoding, get_freqs from warpconvnet.nn.modules.base_module import BaseSpatialModule +from warpconvnet.nn.modules.sequential import Sequential from warpconvnet.nn.functional.factor_grid import ( factor_grid_transform, factor_grid_cat, @@ -513,7 +514,7 @@ def __init__( for compressed_spatial_dim, compressed_memory_format in zip( compressed_spatial_dims, compressed_memory_formats ): - block = nn.Sequential( + block = Sequential( nn.Conv2d( in_channels * compressed_spatial_dim, out_channels * compressed_spatial_dim, diff --git a/warpconvnet/nn/modules/point_pool.py b/warpconvnet/nn/modules/point_pool.py index d4e1359..0e71b58 100644 --- a/warpconvnet/nn/modules/point_pool.py +++ b/warpconvnet/nn/modules/point_pool.py @@ -35,7 +35,7 @@ class PointPoolBase(BaseSpatialModule): Output geometry type. Defaults to ``"point"``. unique_method : {"torch", "ravel", "morton"}, optional Method used to find unique voxel indices. Defaults to ``"torch"``. - avereage_pooled_coordinates : bool, optional + average_pooled_coordinates : bool, optional If ``True`` average coordinates of points within each voxel. Defaults to ``False``. return_neighbor_search_result : bool, optional If ``True`` also return the neighbor search result. Defaults to ``False``. @@ -48,7 +48,7 @@ def __init__( downsample_voxel_size: Optional[float] = None, return_type: Literal["point", "sparse"] = "point", unique_method: Literal["torch", "ravel", "morton"] = "torch", - avereage_pooled_coordinates: bool = False, + average_pooled_coordinates: bool = False, return_neighbor_search_result: bool = False, ): super().__init__() @@ -60,7 +60,7 @@ def __init__( self.return_type = return_type self.return_neighbor_search_result = return_neighbor_search_result self.unique_method = unique_method - self.avereage_pooled_coordinates = avereage_pooled_coordinates + self.average_pooled_coordinates = average_pooled_coordinates def forward(self, pc: Points) -> Union[Geometry, Tuple[Geometry, RealSearchResult]]: return point_pool( @@ -71,7 +71,7 @@ def forward(self, pc: Points) -> Union[Geometry, Tuple[Geometry, RealSearchResul return_type=self.return_type, return_neighbor_search_result=self.return_neighbor_search_result, unique_method=self.unique_method, - avereage_pooled_coordinates=self.avereage_pooled_coordinates, + average_pooled_coordinates=self.average_pooled_coordinates, )