fix non-determinism of dot(csr.T, dns) = dns with tests
Hao Jin committed Jul 19, 2018
1 parent d86f954 commit f63c106
Showing 2 changed files with 123 additions and 200 deletions.
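This commit replaces the non-deterministic trans_lhs path of dot(csr, dns) = dns on GPU. The deleted warp, thread-block, and warp-block kernels below all scatter partial products into the output with atomicAdd; because floating-point addition is not associative and the ordering of atomic updates varies from run to run, dot(csr.T, dns) could return slightly different results on identical inputs. The new code instead converts the CSR lhs to CSC form in a temporary workspace (range fill, SortByKey, row scatter, histogram, exclusive scan) and then reuses the existing scalar/vector dot kernels, which reduce each output element in a fixed order. A plain-Python reminder (not MXNet code) of why summation order matters:

    a, b, c = 0.1, 0.2, 0.3
    print((a + b) + c)   # 0.6000000000000001
    print(a + (b + c))   # 0.6
    # atomicAdd gives no guarantee which grouping a GPU run will use, so two
    # runs of the old kernels could produce bit-different outputs.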
314 changes: 118 additions & 196 deletions src/operator/tensor/dot-inl.cuh
@@ -35,6 +35,13 @@
namespace mxnet {
namespace op {

// Returns integer log2(a) rounded up
inline int log2i(size_t a) {
int k = 1;
while (a >>= 1) k++;
return k;
}
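A side note on the helper just added: despite its comment, log2i(a) returns the bit width of a (floor(log2 a) + 1, with a minimum of 1), which is exactly the number of radix-sort bits needed later for keys no larger than a. A hypothetical Python equivalent, useful only for sanity checking and not part of the commit:

    def log2i(a):
        # bit width of a, floored at 1 -- mirrors the C++ helper above
        return max(1, int(a).bit_length())

    assert [log2i(x) for x in (0, 1, 2, 7, 8)] == [1, 1, 2, 3, 4]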

/*!
* \brief GPU scalar kernel of dot(csr, dns1) = dns2
* Parallelization by output matrix elements: 1 thread/element
@@ -176,119 +183,6 @@ struct DotCsrTransDnsDnsScalarKernel {
}
};

/*!
* \brief GPU warp kernel of dot(csr.T, dns1) = dns2
* Parallelization by columns: 1 warp computes one lhs column for one rhs column
*/
struct DotCsrTransDnsDnsWarpKernel {
/*!
* \brief see DotCsrTransDnsDnsScalarKernel Map for documentation.
*/
template<typename DType, typename IType, typename CType>
__device__ __forceinline__ static void Map(int tid,
DType* out,
const DType* data_l,
const IType* indptr_l,
const CType* col_idx_l,
const DType* data_r,
const nnvm::dim_t num_cols_r) {
using nnvm::dim_t;
const dim_t warp_id = tid / 32; // global warp id
const dim_t lane = tid & (32-1); // local thread id within warp
const dim_t icol = warp_id / num_cols_r; // lhs column that this warp computes
const dim_t kcol = warp_id % num_cols_r; // rhs column that this warp computes

// Compute range of nnz elements in this column
const dim_t low = static_cast<dim_t>(indptr_l[icol]);
const dim_t high = static_cast<dim_t>(indptr_l[icol+1]);

// Iterate through the nnz elements in this column
for (dim_t j = low+lane; j < high; j+=32) {
const dim_t irow = static_cast<dim_t>(col_idx_l[j]);
const DType val = data_l[j]*data_r[icol*num_cols_r+kcol];
atomicAdd(static_cast<DType *>(&(out[irow*num_cols_r+kcol])), val);
}
}
};

/*!
* \brief GPU thread block kernel of dot(csr.T, dns1) = dns2
* Parallelization by columns: 1 thread block computes one lhs column for all rhs columns
*/
struct DotCsrTransDnsDnsThreadBlockKernel {
/*!
* \brief see DotCsrTransDnsDnsScalarKernel Map for documentation.
*/
template<typename DType, typename IType, typename CType>
__device__ __forceinline__ static void Map(int tid,
DType* out,
const DType* data_l,
const IType* indptr_l,
const CType* col_idx_l,
const DType* data_r,
const nnvm::dim_t num_cols_r) {
using nnvm::dim_t;
const dim_t warps_per_block = blockDim.x / 32; // number of warps in this thread block
const dim_t warp_id = tid / 32; // global warp id
const dim_t lane = tid & (32-1); // local thread id within warp
const dim_t icol = blockIdx.x; // lhs column that this thread block computes
const dim_t kcol = warp_id % warps_per_block; // rhs column where warp starts computing (offset)

// Compute range of nnz elements in this lhs column
const dim_t low = static_cast<dim_t>(indptr_l[icol]);
const dim_t high = static_cast<dim_t>(indptr_l[icol+1]);

// Iterate through the nnz elements in this lhs column
for (dim_t j = low+lane; j < high; j+=32) {
const dim_t irow = static_cast<dim_t>(col_idx_l[j]);
const DType datum_l = data_l[j];
// Iterate over rhs columns that this warp computes
for (dim_t k = kcol; k < num_cols_r; k+=warps_per_block) {
const DType val = datum_l*data_r[icol*num_cols_r+k];
atomicAdd(static_cast<DType *>(&(out[irow*num_cols_r+k])), val);
}
}
}
};

/*!
* \brief GPU warp block kernel of dot(csr.T, dns1) = dns2
* Parallelization by columns: 1 warp computes one lhs column for all rhs columns
*/
struct DotCsrTransDnsDnsWarpBlockKernel {
/*!
* \brief see DotCsrTransDnsDnsScalarKernel Map for documentation.
*/
template<typename DType, typename IType, typename CType>
__device__ __forceinline__ static void Map(int tid,
DType* out,
const DType* data_l,
const IType* indptr_l,
const CType* col_idx_l,
const DType* data_r,
const nnvm::dim_t num_cols_r) {
using nnvm::dim_t;
const dim_t warp_id = tid / 32; // global warp id
const dim_t lane = tid & (32-1); // local thread id within warp
const dim_t icol = warp_id; // lhs column that this warp computes

// Compute range of nnz elements in this column
const dim_t low = static_cast<dim_t>(indptr_l[icol]);
const dim_t high = static_cast<dim_t>(indptr_l[icol+1]);

// Iterate through the nnz elements in lhs column
for (dim_t j = low+lane; j < high; j+=32) {
const dim_t irow = static_cast<dim_t>(col_idx_l[j]);
const DType datum_l = data_l[j];
// Iterate over all rhs columns
for (dim_t k = 0; k < num_cols_r; k++) {
const DType val = datum_l*data_r[icol*num_cols_r+k];
atomicAdd(static_cast<DType *>(&(out[irow*num_cols_r+k])), val);
}
}
}
};

/*!
* \brief GPU Kernel of dot(csr.T, rsp1) = rsp2
* Parallelization by rows: 1 thread/row
@@ -510,6 +404,7 @@ inline void DotCsrDnsDnsImpl(const OpContext& ctx,
return;
}

using namespace mshadow;
using mshadow::cuda::kBaseThreadNum;
using mxnet_op::Kernel;
using mxnet_op::set_zero;
@@ -539,86 +434,120 @@ inline void DotCsrDnsDnsImpl(const OpContext& ctx,
Kernel<set_zero, gpu>::Launch(s, num_threads, data_out.dptr<DType>());
}
if (trans_lhs) {
// Different kernel versions are optimized for different matrix instances
// TODO: switch between kernel versions depending on input
// (1) 'Scalar kernel' (one thread computing one output element )
// (2) 'Warp kernel' (one warp computing one lhs column for one rhs column )
// (3) 'Thread block kernel' (one thread block computing one lhs column for all rhs columns)
// (4) 'Warp block kernel' (one warp computing one lhs column for all rhs columns)
const int kernel_version = 0;
switch (kernel_version) {
case 1:
num_threads = data_out.Size();
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
Kernel<DotCsrTransDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_rows_l, num_cols_r);
});
break;
case 2:
num_threads = threads_per_warp * num_rows_l * num_cols_r;
Kernel<DotCsrTransDnsDnsWarpKernel, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
break;
case 3:
num_threads = threads_per_block * num_rows_l;
Kernel<DotCsrTransDnsDnsThreadBlockKernel, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
break;
case 4:
num_threads = threads_per_warp * num_rows_l;
Kernel<DotCsrTransDnsDnsWarpBlockKernel, gpu>::Launch(s, num_threads,
// TODO(haojin2): Switching to deterministic algorithm for now.
// Further optimizations to come later.
const nnvm::dim_t num_csr_rows = lhs.shape()[0];
const nnvm::dim_t num_csr_cols = lhs.shape()[1];
const nnvm::dim_t num_dns_rows = rhs.shape_[0];
const nnvm::dim_t nnz = lhs.storage_shape().Size();

IType* original_idx_ptr = nullptr;
IType* csc_indices_ptr = nullptr;
IType* csc_cols_ptr = nullptr;
CType* csr_rows_ptr = nullptr;
CType* csc_indptr_ptr = nullptr;
DType* csc_data_ptr = nullptr;
char* temp_storage_ptr = nullptr;
size_t original_idx_bytes = nnz*sizeof(IType);
size_t csc_indices_bytes = nnz*sizeof(IType);
size_t csc_cols_bytes = nnz*sizeof(IType);
size_t csr_rows_bytes = nnz*sizeof(CType);
size_t csc_indptr_bytes = (num_csr_cols+1)*sizeof(CType);
size_t csc_data_bytes = nnz*sizeof(DType);
size_t scan_temp_storage_bytes = 0;
size_t temp_storage_bytes = SortByKeyWorkspaceSize<IType, IType, gpu>(nnz);
IType* csr_indices_ptr = col_idx_l.dptr<IType>();
cub::DeviceScan::ExclusiveSum(temp_storage_ptr,
scan_temp_storage_bytes,
csc_indptr_ptr,
csc_indptr_ptr,
num_csr_cols+1,
mshadow::Stream<gpu>::GetStream(s));
temp_storage_bytes = std::max(temp_storage_bytes, scan_temp_storage_bytes);
temp_storage_bytes += (sizeof(dim_t) - temp_storage_bytes % sizeof(dim_t));
size_t total_workspace_bytes =
original_idx_bytes + csc_indices_bytes + csc_cols_bytes + csr_rows_bytes +
csc_indptr_bytes + csc_data_bytes + temp_storage_bytes;
total_workspace_bytes += (sizeof(IType) - total_workspace_bytes % sizeof(IType));
Tensor<gpu, 1, char> workspace = ctx.requested[0]
.get_space_typed<gpu, 1, char>(Shape1(total_workspace_bytes), s);
original_idx_ptr = reinterpret_cast<IType*>(workspace.dptr_);
csc_indices_ptr = reinterpret_cast<IType*>(workspace.dptr_ + original_idx_bytes);
csc_cols_ptr = reinterpret_cast<IType*>(workspace.dptr_ + original_idx_bytes +
csc_indices_bytes);
csr_rows_ptr = reinterpret_cast<CType*>(workspace.dptr_ + original_idx_bytes +
csc_indices_bytes + csc_cols_bytes);
csc_indptr_ptr = reinterpret_cast<CType*>(workspace.dptr_ + original_idx_bytes +
csc_indices_bytes + csc_cols_bytes +
csr_rows_bytes);
temp_storage_ptr = workspace.dptr_ + original_idx_bytes + csc_indices_bytes +
csc_cols_bytes + csr_rows_bytes + csc_indptr_bytes;
csc_data_ptr = reinterpret_cast<DType*>(
workspace.dptr_ + total_workspace_bytes - csc_data_bytes);

// Fill original_idx
mxnet_op::Kernel<range_fwd, gpu>::Launch(
s, nnz, 1, IType(0), IType(1), kWriteTo, original_idx_ptr);
// Fill csc_cols with copy of csr_indices
mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, kWriteTo>, gpu>::Launch(
s, nnz, csc_cols_ptr, csr_indices_ptr);
// Allocate the tensors needed for SortByKey
Tensor<gpu, 1, IType> original_idx(original_idx_ptr, Shape1(nnz), s);
Tensor<gpu, 1, IType> csc_cols(csc_cols_ptr, Shape1(nnz), s);
Tensor<gpu, 1, char> temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s);

int num_bits = log2i(num_csr_cols - 1);
SortByKey(csc_cols, original_idx, true, &temp_storage, 0, num_bits);

// Scatter csr indptr to row id
mxnet_op::Kernel<CsrRowScatterKernel, gpu>::Launch(
s, num_csr_rows, indptr_l.dptr<CType>(), csr_rows_ptr, num_csr_rows);
// Reset indptr to zero
mxnet_op::Kernel<mxnet_op::set_zero, gpu>::Launch(s, num_csr_cols+1, csc_indptr_ptr);
// Histogram on the sorted cols
mxnet_op::Kernel<HistogramKernel, gpu>::Launch(
s, nnz, csc_indptr_ptr, csc_cols_ptr, nnz);
// Scan the bin counts for every column to get csc_indptr
cub::DeviceScan::ExclusiveSum(temp_storage_ptr,
temp_storage_bytes,
csc_indptr_ptr,
csc_indptr_ptr,
num_csr_cols+1,
mshadow::Stream<gpu>::GetStream(s));
// Assign data to csc matrix arrays
mxnet_op::Kernel<CscDataIndicesKernel, gpu>::Launch(
s, nnz, original_idx_ptr, data_l.dptr<DType>(), csr_rows_ptr, csc_data_ptr,
csc_indices_ptr, nnz);
if (num_cols_r > 4) {
num_threads = data_out.Size();
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
Kernel<DotCsrDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), csc_data_ptr, csc_indptr_ptr,
csc_indices_ptr, data_r.dptr<DType>(), num_cols_r);
});
} else {
num_threads = threads_per_warp * num_rows_l * num_cols_r;
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
Kernel<DotCsrDnsDnsVectorKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), csc_data_ptr, csc_indptr_ptr,
csc_indices_ptr, data_r.dptr<DType>(), num_cols_r);
});
}
} else {
if (num_cols_r > 4) {
num_threads = data_out.Size();
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
Kernel<DotCsrDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
break;
default:
num_threads = threads_per_warp * num_rows_l * num_cols_r;
Kernel<DotCsrTransDnsDnsWarpKernel, gpu>::Launch(s, num_threads,
});
} else {
num_threads = threads_per_warp * num_rows_l * num_cols_r;
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
Kernel<DotCsrDnsDnsVectorKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
break;
}
} else {
// Different kernel versions are optimized for different matrix instances
// (1) 'Scalar kernel' (one thread computing one output element)
// (2) 'Vector kernel' (one warp computing one output element)
const int kernel_version = 0;
switch (kernel_version) {
case 1:
num_threads = data_out.Size();
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
Kernel<DotCsrDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
break;
case 2:
num_threads = threads_per_warp * num_rows_l * num_cols_r;
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
Kernel<DotCsrDnsDnsVectorKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
break;
default:
if (num_cols_r > 4) {
num_threads = data_out.Size();
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
Kernel<DotCsrDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
} else {
num_threads = threads_per_warp * num_rows_l * num_cols_r;
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
Kernel<DotCsrDnsDnsVectorKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
}
break;
});
}
}
});
@@ -671,13 +600,6 @@ struct DotCsrTransDnsRspKernel {
}
};

// Returns integer log2(a) rounded up
inline int log2i(size_t a) {
int k = 1;
while (a >>= 1) k++;
return k;
}

/*!
* \brief GPU Impl of dot(csr.T, dns) = rsp
*/
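To make the new deterministic path easier to follow, here is a NumPy sketch of the same algorithm (an illustration assuming a 2-D dense rhs, not the MXNet implementation): the CSR lhs is transposed into CSC form with a stable sort, and each output element is then reduced in a fixed order, so no atomics are needed.

    import numpy as np

    def dot_csrT_dns(data, indices, indptr, shape, rhs):
        # NumPy mirror of the GPU steps in the diff above.
        num_rows, num_cols = shape                      # lhs is num_rows x num_cols, CSR

        # range_fwd + SortByKey: sort column indices, remembering original positions
        original_idx = np.argsort(indices, kind='stable')

        # CsrRowScatterKernel: expand indptr into one row id per stored element
        csr_rows = np.repeat(np.arange(num_rows), np.diff(indptr))

        # CscDataIndicesKernel: gather values and row ids into CSC order
        csc_data = data[original_idx]
        csc_indices = csr_rows[original_idx]

        # HistogramKernel + cub ExclusiveSum: per-column counts -> csc_indptr
        counts = np.bincount(indices, minlength=num_cols)
        csc_indptr = np.concatenate(([0], np.cumsum(counts)))

        # DotCsrDnsDns{Scalar,Vector}Kernel on the CSC matrix: every output row
        # is accumulated in a fixed order, so repeated runs match bit for bit
        out = np.zeros((num_cols, rhs.shape[1]), dtype=rhs.dtype)
        for i in range(num_cols):
            for j in range(csc_indptr[i], csc_indptr[i + 1]):
                out[i] += csc_data[j] * rhs[csc_indices[j]]
        return out

    # tiny check against a dense reference
    lhs = np.array([[0., 2., 0.],
                    [1., 0., 3.]])
    rhs = np.array([[1., 4.],
                    [2., 5.]])
    data, indices, indptr = np.array([2., 1., 3.]), np.array([1, 0, 2]), np.array([0, 1, 3])
    assert np.allclose(dot_csrT_dns(data, indices, indptr, lhs.shape, rhs), lhs.T @ rhs)

Sorting by column index once costs extra workspace and a radix sort over nnz keys, but it removes every atomicAdd from the hot loop, which is what makes the result reproducible.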
9 changes: 5 additions & 4 deletions tests/python/unittest/test_sparse_operator.py
@@ -1424,7 +1424,7 @@ def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols):

@with_seed()
def test_sparse_dot_determinism():
def test_dot_determinism(lhs_stype, rhs_stype, lhs_density, rhs_density, transpose_a, transpose_b, forward_stype):
def check_dot_determinism(lhs_stype, rhs_stype, lhs_density, rhs_density, transpose_a, transpose_b, forward_stype):
lhs_row = rnd.randint(50, 100)
lhs_col = rnd.randint(50, 100)
if transpose_a:
@@ -1444,10 +1444,11 @@ def test_dot_determinism(lhs_stype, rhs_stype, lhs_density, rhs_density, transpo
res2 = mx.nd.sparse.dot(lhs, rhs, transpose_a=transpose_a, transpose_b=transpose_b, forward_stype=forward_stype)
assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.0, atol=0.0)

test_dot_determinism('csr', 'default', 0.1, 1.0, True, False, 'row_sparse')
check_dot_determinism('csr', 'default', 0.1, 1.0, True, False, 'row_sparse')
forward_stype = 'csr' if default_context() == mx.cpu() else 'default'
test_dot_determinism('default', 'csr', 1.0, 0.1, False, False, forward_stype)
test_dot_determinism('default', 'csr', 1.0, 0.1, False, True, forward_stype)
check_dot_determinism('default', 'csr', 1.0, 0.1, False, False, forward_stype)
check_dot_determinism('default', 'csr', 1.0, 0.1, False, True, forward_stype)
check_dot_determinism('csr', 'default', 1.0, 1.0, True, False, 'default')


@with_seed()
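The last added line of the test, check_dot_determinism('csr', 'default', 1.0, 1.0, True, False, 'default'), covers exactly the path changed above: dot of a transposed CSR lhs with a dense rhs producing a dense output. A stripped-down version of the same check outside the test harness might look as follows (a sketch that assumes an MXNet 1.x build where NDArray.tostype('csr') is available; only the sparse.dot arguments shown in the diff are used):

    import mxnet as mx
    import numpy as np

    lhs = mx.nd.array(np.random.rand(64, 50)).tostype('csr')   # fully dense CSR lhs (density 1.0)
    rhs = mx.nd.array(np.random.rand(64, 20))                  # dense rhs

    # run dot(csr.T, dns) = dns twice; with the deterministic kernel the two
    # results must match exactly (rtol=0.0, atol=0.0 in the real test)
    res1 = mx.nd.sparse.dot(lhs, rhs, transpose_a=True, forward_stype='default')
    res2 = mx.nd.sparse.dot(lhs, rhs, transpose_a=True, forward_stype='default')
    assert np.array_equal(res1.asnumpy(), res2.asnumpy())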
