Skip to content

Commit 7ddd5cc

Browse files
committed
optimize by reviewing comments
1 parent 1fe0a97 commit 7ddd5cc

File tree

4 files changed

+5
-12
lines changed

4 files changed

+5
-12
lines changed

cpp/include/raft/core/bitmap.hpp

-2
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,6 @@ struct bitmap_view : public bitset_view<bitmap_t, index_t> {
125125
* The bitmap is interpreted as a row-major matrix, with rows and columns defined by
126126
* the dimensions of the bitmap.
127127
*
128-
* @tparam bitmap_t The data type of the elements in the bitmap matrix.
129-
* @tparam index_t The data type used for indexing the elements in the matrices.
130128
* @tparam csr_matrix_t Specifies the CSR matrix type, constrained to raft::device_csr_matrix.
131129
*
132130
* @param[in] res RAFT resources for managing CUDA streams and execution policies.

cpp/include/raft/core/bitset.hpp

-2
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,6 @@ struct bitset_view {
231231
* // 1, 1, 1, 1];
232232
* @endcode
233233
*
234-
* @tparam bitset_t The data type of the elements in the bitset matrix.
235-
* @tparam index_t The data type used for indexing the elements in the matrices.
236234
* @tparam csr_matrix_t Specifies the CSR matrix type, constrained to raft::device_csr_matrix.
237235
*
238236
* @param[in] res RAFT resources for managing CUDA streams and execution policies.

cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh

+2-2
Original file line numberDiff line numberDiff line change
@@ -330,9 +330,9 @@ void bitmap_to_csr(raft::resources const& handle,
330330
thrust_policy, sub_nnz.data(), sub_nnz.data() + sub_nnz_size + 1, sub_nnz.data());
331331

332332
if constexpr (is_device_csr_sparsity_owning_v<csr_matrix_t>) {
333-
index_t nnz = 0;
333+
nnz_t nnz = 0;
334334
RAFT_CUDA_TRY(cudaMemcpyAsync(
335-
&nnz, sub_nnz.data() + sub_nnz_size, sizeof(index_t), cudaMemcpyDeviceToHost, stream));
335+
&nnz, sub_nnz.data() + sub_nnz_size, sizeof(nnz_t), cudaMemcpyDeviceToHost, stream));
336336
resource::sync_stream(handle);
337337
csr.initialize_sparsity(nnz);
338338
}

cpp/include/raft/sparse/convert/detail/bitset_to_csr.cuh

+3-6
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,8 @@ RAFT_KERNEL repeat_csr_kernel(const index_t* indptr,
5858

5959
__syncthreads();
6060

61-
int block_offset = blockIdx.x * blockDim.x;
62-
6361
index_t item;
64-
int idx = block_offset + threadIdx.x;
65-
item = (idx < nnz) ? indices[idx] : -1;
62+
item = (global_id < nnz) ? indices[global_id] : -1;
6663

6764
__syncthreads();
6865

@@ -144,10 +141,10 @@ void bitset_to_csr(raft::resources const& handle,
144141
thrust::exclusive_scan(
145142
thrust_policy, sub_nnz.data(), sub_nnz.data() + sub_nnz_size + 1, sub_nnz.data());
146143

147-
index_t bitset_nnz = 0;
144+
nnz_t bitset_nnz = 0;
148145
if constexpr (is_device_csr_sparsity_owning_v<csr_matrix_t>) {
149146
RAFT_CUDA_TRY(cudaMemcpyAsync(
150-
&bitset_nnz, sub_nnz.data() + sub_nnz_size, sizeof(index_t), cudaMemcpyDeviceToHost, stream));
147+
&bitset_nnz, sub_nnz.data() + sub_nnz_size, sizeof(nnz_t), cudaMemcpyDeviceToHost, stream));
151148
resource::sync_stream(handle);
152149
csr.initialize_sparsity(bitset_nnz * csr_view.get_n_rows());
153150
} else {

0 commit comments

Comments
 (0)