Implement dot(csr, rsp)=dns and dot(csr.T, rsp)=rsp and refactor (#6902)
* Initial checkin

Add dot(csr.T, rsp)=rsp2

Add infer storage for dot(csr, rsp)=dns and dot(csr.T, rsp)=rsp2

* Fix comments

* Replace std::lower_bound with own impl for gpu use too

* Add time profiling

* Revert "Add time profiling"

This reverts commit 8f5bb98.

* Move dot and batch_dot to a single file

* Move dot gpu impl to a .cuh file

* More refactor

* Fix include error
reminisce authored and piiswrong committed Jul 4, 2017
1 parent 8ed829f commit 0b8d901
Showing 15 changed files with 1,242 additions and 929 deletions.
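
The commit title uses MXNet's storage-type shorthand: csr is compressed sparse row (a data/indptr/col_idx triplet), rsp is row_sparse (only the non-empty rows are stored, together with their row indices), and dns is a plain dense tensor. As orientation for the kernels added below, here is a minimal host-side C++ sketch of the CSR layout and of the dense dot(csr, dns) product; the container and function names are illustrative only and are not part of this commit or of MXNet's API.

#include <vector>
#include <cstddef>

// Illustrative CSR container; the fields mirror the data/indptr/col_idx
// triplet used by the kernels in this commit, but this is not MXNet's NDArray.
struct CsrMatrix {
  std::vector<float> data;     // non-zero values in row-major order
  std::vector<int>   indptr;   // indptr[r]..indptr[r+1] bounds row r's non-zeros
  std::vector<int>   col_idx;  // column index of each non-zero
  int num_rows, num_cols;
};

// Reference dot(csr, dns) = dns: out[r][c] += data[j] * rhs[col_idx[j]][c]
// for each non-zero j in row r. The GPU kernel DotCsrDnsDns below performs
// exactly this accumulation, one thread per output element.
void DotCsrDnsReference(const CsrMatrix& lhs, const std::vector<float>& rhs,
                        int rhs_cols, std::vector<float>* out) {
  out->assign(static_cast<size_t>(lhs.num_rows) * rhs_cols, 0.0f);
  for (int r = 0; r < lhs.num_rows; ++r) {
    for (int j = lhs.indptr[r]; j < lhs.indptr[r + 1]; ++j) {
      const int k = lhs.col_idx[j];  // row of rhs touched by this non-zero
      for (int c = 0; c < rhs_cols; ++c) {
        (*out)[r * rhs_cols + c] += lhs.data[j] * rhs[k * rhs_cols + c];
      }
    }
  }
}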
2 changes: 1 addition & 1 deletion src/common/utils.cc
@@ -5,7 +5,7 @@
*/

#include "./utils.h"
#include "../operator/nn/cast_storage-inl.h"
#include "../operator/tensor/cast_storage-inl.h"

namespace mxnet {
namespace common {
2 changes: 1 addition & 1 deletion src/common/utils.cu
@@ -5,7 +5,7 @@
*/

#include "./utils.h"
#include "../operator/nn/cast_storage-inl.h"
#include "../operator/tensor/cast_storage-inl.h"

namespace mxnet {
namespace common {
src/operator/nn/cast_storage-inl.cuh → src/operator/tensor/cast_storage-inl.cuh
@@ -3,8 +3,8 @@
* \file cast_storage-inl.cuh
* \brief implementation of cast_storage op on GPU
*/
#ifndef MXNET_OPERATOR_NN_CAST_STORAGE_INL_CUH_
#define MXNET_OPERATOR_NN_CAST_STORAGE_INL_CUH_
#ifndef MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_CUH_
#define MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_CUH_

#include <mxnet/base.h>
#include <mxnet/operator.h>
@@ -23,4 +23,4 @@ inline void CastStorageDnsCsrImpl(mshadow::Stream<gpu>* s, const TBlob& dns, NDA
} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_NN_CAST_STORAGE_INL_CUH_
#endif // MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_CUH_
src/operator/nn/cast_storage-inl.h → src/operator/tensor/cast_storage-inl.h
@@ -3,8 +3,8 @@
* \file cast_storage-inl.h
* \brief cast_storage implementation for dense and sparse tensors
*/
#ifndef MXNET_OPERATOR_NN_CAST_STORAGE_INL_H_
#define MXNET_OPERATOR_NN_CAST_STORAGE_INL_H_
#ifndef MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_H_
#define MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_H_

#include <dmlc/timer.h>
#include <mxnet/ndarray.h>
@@ -333,4 +333,4 @@ void CastStorageComputeEx(const nnvm::NodeAttrs& attrs,
} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_NN_CAST_STORAGE_INL_H_
#endif // MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_H_
File renamed without changes.
File renamed without changes.
161 changes: 161 additions & 0 deletions src/operator/tensor/dot-inl.cuh
@@ -0,0 +1,161 @@
/*!
* Copyright (c) 2017 by Contributors
* \file dot-inl.cuh
* \brief implementation of matrix dot op on GPU
*/
#ifndef MXNET_OPERATOR_TENSOR_DOT_INL_CUH_
#define MXNET_OPERATOR_TENSOR_DOT_INL_CUH_

#include <mxnet/base.h>
#include <mxnet/operator.h>

namespace mxnet {
namespace op {

/*!
* \brief Kernel of dot(csr, dns1) = dns2
* Parallelization by output matrix elements
*/
template<int req>
struct DotCsrDnsDns {
/*!
* \brief This function represents performing an inner product between a row of lhs
* and a column of rhs and then assigning the value to out[i].
* \param i i-th element in out 1D view
* \param out output matrix
* \param data_l csr values of lhs
* \param indptr_l csr indptr of lhs
* \param col_idx_l csr col_idx of lhs
* \param data_r dense data of rhs
* \param num_cols number of columns of output
*/
template<typename DType, typename IType, typename CType>
MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l,
const CType* col_idx_l, const DType* data_r,
const int num_cols) {
const int irow = i / num_cols; // row id of the lhs
const int icol = i % num_cols; // col id of the rhs
DType sum = 0;
for (IType j = indptr_l[irow]; j < indptr_l[irow+1]; ++j) {
const CType cur_col = col_idx_l[j]; // corresponding row id of the rhs
sum += data_l[j] * data_r[cur_col*num_cols+icol];
}
KERNEL_ASSIGN(out[i], req, sum);
}
};

/*!
* \brief Kernel of dot(csr.T(), dns1) = dns2
* Parallelization by output matrix elements
*/
template<int req>
struct DotCsrTransDnsDns {
/*!
* \brief This function represents performing an inner product between a column of lhs
* and a column of rhs and then assigning the value to out[i].
* \param i i-th element in out 1D view
* \param out output matrix
* \param data_l csr values of lhs
* \param indptr_l csr indptr of lhs
* \param col_idx_l csr col_idx of lhs
* \param data_r dense data of rhs
* \param num_rows_l number of rows of lhs
* \param num_cols number of columns of the output
*/
template<typename DType, typename IType, typename CType>
MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l,
const CType* col_idx_l, const DType* data_r, const int num_rows_l,
const int num_cols) {
const int irow = i / num_cols; // col id of the lhs
const int icol = i % num_cols; // col id of the rhs
DType sum = 0;
for (int k = 0; k < num_rows_l; ++k) {
const IType low = indptr_l[k];
const IType high = indptr_l[k+1];
if (low == high || irow < col_idx_l[low] || irow > col_idx_l[high-1]) continue;
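// binary search col_idx_l[low, high) of lhs row k for column id irow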
int j = -1, l = low, r = high - 1;
while (l <= r) {
int m = l + (r - l) / 2;
if (col_idx_l[m] == irow) {
j = m; break;
}
if (col_idx_l[m] < irow) {
l = m + 1;
} else {
r = m - 1;
}
}
if (j >= 0) {
sum += data_l[j] * data_r[k*num_cols+icol];
}
}
KERNEL_ASSIGN(out[i], req, sum);
}
};

inline void DotCsrDnsDnsImpl(mshadow::Stream<gpu>* s,
const NDArray& lhs,
const TBlob& rhs,
const OpReqType req,
const bool trans_lhs,
TBlob* ret) {
if (kNullOp == req) return;
CHECK_EQ(lhs.storage_type(), kCSRStorage);
if (!lhs.storage_initialized()) return;

const TBlob data_l = lhs.data();
const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
const TBlob& data_r = rhs;
const TBlob data_out = *ret;

MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type
MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type
if (trans_lhs) {
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrTransDnsDns<ReqType>, gpu>::Launch(s, data_out.Size(),
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), lhs.shape()[0],
data_out.shape_[1]);
});
} else {
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrDnsDns<ReqType>, gpu>::Launch(s, data_out.Size(),
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), rhs.shape_[1]);
});
}
});
});
});
}

/*!
* \brief Impl of dot(csr.T, dns) = rsp
*/
inline void DotCsrDnsRspImpl(mshadow::Stream<gpu>* s,
const NDArray& lhs,
const TBlob& rhs,
const OpReqType req,
const bool trans_lhs,
NDArray* ret) {
LOG(FATAL) << "DotCsrDnsRspImpl gpu version is not implemented.";
}

/*!
* \brief Impl of dot(csr.T, rsp) = rsp2
*/
inline void DotCsrRspRspImpl(mshadow::Stream<gpu>* s,
const NDArray& lhs,
const NDArray& rhs,
const OpReqType req,
const bool trans_lhs,
NDArray* ret) {
LOG(FATAL) << "DotCsrRspRspImpl gpu version is not implemented.";
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_TENSOR_DOT_INL_CUH_
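
To make the transposed kernel's strategy concrete, the following host-side C++ sketch computes one output element of dot(csr.T, dns) the same way DotCsrTransDnsDns does: for each row k of the untransposed CSR lhs it binary-searches that row's column indices for the output row id and accumulates when found. The names and the flat row-major rhs layout are assumptions for illustration, not code from this commit.

#include <vector>

// out[irow][icol] of dot(csr.T, dns): scan every lhs row k, look for column
// id `irow` among that row's non-zeros by binary search, and accumulate
// data_l[j] * rhs[k][icol] when it is present.
float DotCsrTransDnsElement(const std::vector<float>& data_l,
                            const std::vector<int>& indptr_l,
                            const std::vector<int>& col_idx_l,
                            const std::vector<float>& rhs,  // num_rows_l x num_cols, row-major
                            int num_rows_l, int num_cols,
                            int irow, int icol) {
  float sum = 0.0f;
  for (int k = 0; k < num_rows_l; ++k) {
    const int low = indptr_l[k], high = indptr_l[k + 1];
    if (low == high || irow < col_idx_l[low] || irow > col_idx_l[high - 1]) continue;
    int l = low, r = high - 1, j = -1;
    while (l <= r) {  // binary search over the sorted column indices of row k
      const int m = l + (r - l) / 2;
      if (col_idx_l[m] == irow) { j = m; break; }
      if (col_idx_l[m] < irow) l = m + 1; else r = m - 1;
    }
    if (j >= 0) sum += data_l[j] * rhs[k * num_cols + icol];
  }
  return sum;
}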