From 0b8d90194ce1ad34eba443eea49fc58a9029128e Mon Sep 17 00:00:00 2001 From: reminisce Date: Mon, 3 Jul 2017 21:10:34 -0700 Subject: [PATCH] Implement dot(csr, rsp)=dns and dot(csr.T, rsp)=rsp and refactor (#6902) * Initial checkin Add dot(csr.T, rsp)=rsp2 Add infer storage for dot(csr, rsp)=dns and dot(csr.T, rsp)=rsp2 * Fix comments * Replace std::lower_bound with own impl for gpu use too * Add time profiling * Revert "Add time profiling" This reverts commit 8f5bb982867731df0305148b1b150b05661f8529. * Move dot and batch_dot to a single file * Move dot gpu impl to a .cuh file * More refactor * Fix include error --- src/common/utils.cc | 2 +- src/common/utils.cu | 2 +- .../{nn => tensor}/cast_storage-inl.cuh | 6 +- .../{nn => tensor}/cast_storage-inl.h | 6 +- src/operator/{nn => tensor}/cast_storage.cc | 0 src/operator/{nn => tensor}/cast_storage.cu | 0 src/operator/tensor/dot-inl.cuh | 161 +++ src/operator/tensor/dot-inl.h | 924 ++++++++++++++++++ src/operator/tensor/dot.cc | 114 +++ src/operator/tensor/dot.cu | 27 + src/operator/tensor/indexing_op.h | 6 +- src/operator/tensor/matrix_op-inl.h | 799 --------------- src/operator/tensor/matrix_op.cc | 101 -- src/operator/tensor/matrix_op.cu | 15 - tests/python/unittest/test_sparse_operator.py | 8 +- 15 files changed, 1242 insertions(+), 929 deletions(-) rename src/operator/{nn => tensor}/cast_storage-inl.cuh (78%) rename src/operator/{nn => tensor}/cast_storage-inl.h (98%) rename src/operator/{nn => tensor}/cast_storage.cc (100%) rename src/operator/{nn => tensor}/cast_storage.cu (100%) create mode 100644 src/operator/tensor/dot-inl.cuh create mode 100644 src/operator/tensor/dot-inl.h create mode 100644 src/operator/tensor/dot.cc create mode 100644 src/operator/tensor/dot.cu diff --git a/src/common/utils.cc b/src/common/utils.cc index 5bfb959fdf34..4bcae02e990c 100644 --- a/src/common/utils.cc +++ b/src/common/utils.cc @@ -5,7 +5,7 @@ */ #include "./utils.h" -#include "../operator/nn/cast_storage-inl.h" +#include "../operator/tensor/cast_storage-inl.h" namespace mxnet { namespace common { diff --git a/src/common/utils.cu b/src/common/utils.cu index a249be5bb9f5..7221a2b6ec6c 100644 --- a/src/common/utils.cu +++ b/src/common/utils.cu @@ -5,7 +5,7 @@ */ #include "./utils.h" -#include "../operator/nn/cast_storage-inl.h" +#include "../operator/tensor/cast_storage-inl.h" namespace mxnet { namespace common { diff --git a/src/operator/nn/cast_storage-inl.cuh b/src/operator/tensor/cast_storage-inl.cuh similarity index 78% rename from src/operator/nn/cast_storage-inl.cuh rename to src/operator/tensor/cast_storage-inl.cuh index b99d875eb612..0d4e601d0d2e 100644 --- a/src/operator/nn/cast_storage-inl.cuh +++ b/src/operator/tensor/cast_storage-inl.cuh @@ -3,8 +3,8 @@ * \file cast_storage-inl.cuh * \brief implementation of cast_storage op on GPU */ -#ifndef MXNET_OPERATOR_NN_CAST_STORAGE_INL_CUH_ -#define MXNET_OPERATOR_NN_CAST_STORAGE_INL_CUH_ +#ifndef MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_CUH_ +#define MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_CUH_ #include #include @@ -23,4 +23,4 @@ inline void CastStorageDnsCsrImpl(mshadow::Stream* s, const TBlob& dns, NDA } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_NN_CAST_STORAGE_INL_CUH_ +#endif // MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_CUH_ diff --git a/src/operator/nn/cast_storage-inl.h b/src/operator/tensor/cast_storage-inl.h similarity index 98% rename from src/operator/nn/cast_storage-inl.h rename to src/operator/tensor/cast_storage-inl.h index f0268c797c74..9273b996d48e 100644 --- 
a/src/operator/nn/cast_storage-inl.h +++ b/src/operator/tensor/cast_storage-inl.h @@ -3,8 +3,8 @@ * \file cast_storage-inl.h * \brief cast_storage implementation for dense and sparse tensors */ -#ifndef MXNET_OPERATOR_NN_CAST_STORAGE_INL_H_ -#define MXNET_OPERATOR_NN_CAST_STORAGE_INL_H_ +#ifndef MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_H_ +#define MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_H_ #include #include @@ -333,4 +333,4 @@ void CastStorageComputeEx(const nnvm::NodeAttrs& attrs, } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_NN_CAST_STORAGE_INL_H_ +#endif // MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_H_ diff --git a/src/operator/nn/cast_storage.cc b/src/operator/tensor/cast_storage.cc similarity index 100% rename from src/operator/nn/cast_storage.cc rename to src/operator/tensor/cast_storage.cc diff --git a/src/operator/nn/cast_storage.cu b/src/operator/tensor/cast_storage.cu similarity index 100% rename from src/operator/nn/cast_storage.cu rename to src/operator/tensor/cast_storage.cu diff --git a/src/operator/tensor/dot-inl.cuh b/src/operator/tensor/dot-inl.cuh new file mode 100644 index 000000000000..513fde306bab --- /dev/null +++ b/src/operator/tensor/dot-inl.cuh @@ -0,0 +1,161 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file dot-inl.cuh + * \brief implementation of matrix dot op on GPU + */ +#ifndef MXNET_OPERATOR_TENSOR_DOT_INL_CUH_ +#define MXNET_OPERATOR_TENSOR_DOT_INL_CUH_ + +#include +#include + +namespace mxnet { +namespace op { + +/*! + * \brief Kernel of dot(csr, dns1) = dns2 + * Parallelization by output matrix elements + */ +template +struct DotCsrDnsDns { + /*! + * \brief This function represents performing an inner product between a row of lhs + * and a column of rhs and then assigning the value to out[i]. + * \param i i-th element in out 1D view + * \param out output matrix + * \param data_l csr values of lhs + * \param indptr_l csr indptr of lhs + * \param col_idx_l csr col_idx of lhs + * \param data_r dense data of rhs + * \param num_cols number of columns of output + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, + const int num_cols) { + const int irow = i / num_cols; // row id of the lhs + const int icol = i % num_cols; // col id of the rhs + DType sum = 0; + for (IType j = indptr_l[irow]; j < indptr_l[irow+1]; ++j) { + const CType cur_col = col_idx_l[j]; // corresponding row id of the rhs + sum += data_l[j] * data_r[cur_col*num_cols+icol]; + } + KERNEL_ASSIGN(out[i], req, sum); + } +}; + +/*! + * \brief Kernel of dot(csr.T(), dns1) = dns2 + * Parallelization by output matrix elements + */ +template +struct DotCsrTransDnsDns { + /*! + * \brief This function represents performing an inner product between a column of lhs + * and a column of rhs and then assigning the value to out[i]. 
+ * \param i i-th element in out 1D view + * \param out output matrix + * \param data_l csr values of lhs + * \param indptr_l csr indptr of lhs + * \param col_idx_l csr col_idx of lhs + * \param data_r dense data of rhs + * \param num_rows_l number of rows of lhs + * \param num_cols number of columns of outputs + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, const int num_rows_l, + const int num_cols) { + const int irow = i / num_cols; // col id of the lhs + const int icol = i % num_cols; // col id of the rhs + DType sum = 0; + for (int k = 0; k < num_rows_l; ++k) { + const IType low = indptr_l[k]; + const IType high = indptr_l[k+1]; + if (low == high || irow < col_idx_l[low] || irow > col_idx_l[high-1]) continue; + int j = -1, l = low, r = high - 1; + while (l <= r) { + int m = l + (r - l) / 2; + if (col_idx_l[m] == irow) { + j = m; break; + } + if (col_idx_l[m] < irow) { + l = m + 1; + } else { + r = m - 1; + } + } + if (j >= 0) { + sum += data_l[j] * data_r[k*num_cols+icol]; + } + } + KERNEL_ASSIGN(out[i], req, sum); + } +}; + +inline void DotCsrDnsDnsImpl(mshadow::Stream* s, + const NDArray& lhs, + const TBlob& rhs, + const OpReqType req, + const bool trans_lhs, + TBlob* ret) { + if (kNullOp == req) return; + CHECK_EQ(lhs.storage_type(), kCSRStorage); + if (!lhs.storage_initialized()) return; + + const TBlob data_l = lhs.data(); + const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); + const TBlob col_idx_l = lhs.aux_data(csr::kIdx); + const TBlob& data_r = rhs; + const TBlob data_out = *ret; + + MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type + if (trans_lhs) { + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + mxnet_op::Kernel, gpu>::Launch(s, data_out.Size(), + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), lhs.shape()[0], + data_out.shape_[1]); + }); + } else { + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + mxnet_op::Kernel, gpu>::Launch(s, data_out.Size(), + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), rhs.shape_[1]); + }); + } + }); + }); + }); +} + +/*! + * \brief Impl of dot(csr.T, dns) = rsp + */ +inline void DotCsrDnsRspImpl(mshadow::Stream* s, + const NDArray& lhs, + const TBlob& rhs, + const OpReqType req, + const bool trans_lhs, + NDArray* ret) { + LOG(FATAL) << "DotCsrDnsRspImpl gpu version is not implemented."; +} + +/*! + * \brief Impl of dot(csr.T, rsp) = rsp2 + */ +inline void DotCsrRspRspImpl(mshadow::Stream* s, + const NDArray& lhs, + const NDArray& rhs, + const OpReqType req, + const bool trans_lhs, + NDArray* ret) { + LOG(FATAL) << "DotCsrRspRspImpl gpu version is not implemented."; +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_TENSOR_DOT_INL_CUH_ diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h new file mode 100644 index 000000000000..33cc095c0cee --- /dev/null +++ b/src/operator/tensor/dot-inl.h @@ -0,0 +1,924 @@ +/*! 
+ * Copyright (c) 2017 by Contributors + * \file dot-inl.h + * \brief Function definition of matrix dot operator + */ + +#ifndef MXNET_OPERATOR_TENSOR_DOT_INL_H_ +#define MXNET_OPERATOR_TENSOR_DOT_INL_H_ + +#include +#include +#include +#include +#include +#include "../mshadow_op.h" +#include "../elemwise_op_common.h" +#include "../mxnet_op.h" +#ifdef __CUDACC__ +#include "./dot-inl.cuh" +#endif // __CUDACC__ + +namespace mxnet { +namespace op { + +struct DotParam : public dmlc::Parameter { + bool transpose_a; + bool transpose_b; + DMLC_DECLARE_PARAMETER(DotParam) { + DMLC_DECLARE_FIELD(transpose_a) + .describe("If true then transpose the first input before dot.") + .set_default(false); + DMLC_DECLARE_FIELD(transpose_b) + .describe("If true then transpose the second input before dot.") + .set_default(false); + } +}; + +template +void DotForward_(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + const DotParam& param = nnvm::get(attrs.parsed); + Stream *s = ctx.get_stream(); + CHECK_EQ(outputs[0].type_flag_, inputs[0].type_flag_) + << "Binary function only support input/output with the same type"; + CHECK_EQ(outputs[0].type_flag_, inputs[1].type_flag_) + << "Binary function only support input/output with the same type"; + CHECK_EQ(outputs[0].type_flag_, kFloat32) + << "dot only support 32 bit float so far"; + + if (inputs[0].ndim() == 1 && inputs[1].ndim() == 1) { + CHECK_NE(req[0], kAddTo) << "AddTo not yet suported"; + Tensor out = outputs[0].get(s); + VectorDot(out, + inputs[0].get(s), + inputs[1].get(s)); + } else { + int ma, na, mb, nb, m, n; + if (param.transpose_a) { + ma = inputs[0].size(0); + na = inputs[0].Size()/ma; + m = na; + } else { + na = inputs[0].size(inputs[0].ndim()-1); + ma = inputs[0].Size()/na; + m = ma; + } + if (param.transpose_b) { + nb = inputs[1].size(inputs[1].ndim()-1); + mb = inputs[1].Size()/nb; + n = mb; + } else { + mb = inputs[1].size(0); + nb = inputs[1].Size()/mb; + n = nb; + } + + Tensor input0 = + inputs[0].get_with_shape(Shape2(ma, na), s); + Tensor input1 = + inputs[1].get_with_shape(Shape2(mb, nb), s); + Tensor out = + outputs[0].get_with_shape(Shape2(m, n), s); + if (param.transpose_a && param.transpose_b) { + ASSIGN_DISPATCH(out, req[0], dot(input0.T(), input1.T())); + } else if (!param.transpose_a && param.transpose_b) { + ASSIGN_DISPATCH(out, req[0], dot(input0, input1.T())); + } else if (param.transpose_a && !param.transpose_b) { + ASSIGN_DISPATCH(out, req[0], dot(input0.T(), input1)); + } else { + ASSIGN_DISPATCH(out, req[0], dot(input0, input1)); + } + } +} + +template +void DotBackward_(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + const DotParam& param = nnvm::get(attrs.parsed); + Stream *s = ctx.get_stream(); + CHECK_NE(req[0], kWriteInplace); + CHECK_NE(req[1], kWriteInplace); + + if (inputs[1].ndim() == 1 && inputs[2].ndim() == 1) { + Tensor mout_grad = inputs[0].get(s); + Tensor mlhs_data = inputs[1].get(s); + Tensor mrhs_data = inputs[2].get(s); + Tensor mlhs_grad = outputs[0].get(s); + Tensor mrhs_grad = outputs[1].get(s); + ASSIGN_DISPATCH(mrhs_grad, req[1], + broadcast_scalar(mout_grad, mlhs_data.shape_) * mlhs_data); + ASSIGN_DISPATCH(mlhs_grad, req[0], + broadcast_scalar(mout_grad, mlhs_data.shape_) * mrhs_data); + } else { + 
int ma, na, mb, nb, m, n; + if (param.transpose_a) { + ma = outputs[0].size(0); + na = outputs[0].Size()/ma; + m = na; + } else { + na = outputs[0].size(outputs[0].ndim()-1); + ma = outputs[0].Size()/na; + m = ma; + } + if (param.transpose_b) { + nb = outputs[1].size(outputs[1].ndim()-1); + mb = outputs[1].Size()/nb; + n = mb; + } else { + mb = outputs[1].size(0); + nb = outputs[1].Size()/mb; + n = nb; + } + + Tensor mout_grad = + inputs[0].get_with_shape(Shape2(m, n), s); + Tensor mlhs_data = + inputs[1].get_with_shape(Shape2(ma, na), s); + Tensor mrhs_data = + inputs[2].get_with_shape(Shape2(mb, nb), s); + Tensor mlhs_grad = + outputs[0].get_with_shape(Shape2(ma, na), s); + Tensor mrhs_grad = + outputs[1].get_with_shape(Shape2(mb, nb), s); + if (param.transpose_a && param.transpose_b) { + // Gradient of z = dot(x.T, y.T) + // dy = dot(x, dz).T = dot(dz.T, x.T) + // dx = dot(dz, y).T = dot(y.T, dz.T) + ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mout_grad.T(), mlhs_data.T())); + ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mrhs_data.T(), mout_grad.T())); + } else if (!param.transpose_a && param.transpose_b) { + // Gradient of z = dot(x, y.T) + // dy = dot(x.T, dz).T = dot(dz.T, x) + // dx = dot(dz, y) + ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mout_grad.T(), mlhs_data)); + ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mout_grad, mrhs_data)); + } else if (param.transpose_a && !param.transpose_b) { + // Gradient of z = dot(x.T, y) + // dy = dot(x, dz) + // dx = dot(dz, y.T).T = dot(y, dz.T) + ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mlhs_data, mout_grad)); + ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mrhs_data, mout_grad.T())); + } else { + // Gradient of z = dot(x, y) + // dy = dot(x.T, dz) + // dx = dot(dz, y.T) + ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mlhs_data.T(), mout_grad)); + ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mout_grad, mrhs_data.T())); + } + } +} + +inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + const DotParam& param = nnvm::get(attrs.parsed); + // csr has many zero columns, so the result of dot(csr.T, matrix) should be rsp + if (param.transpose_a && kCSRStorage == (*in_attrs)[0]) { + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kRowSparseStorage); + } else { + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage); + } + return true; +} + +inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 3U); + CHECK_EQ(out_attrs->size(), 2U); + const DotParam& param = nnvm::get(attrs.parsed); + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage); + if (!param.transpose_a && kCSRStorage == (*in_attrs)[1]) { + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 1, kRowSparseStorage); + } else { + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 1, kDefaultStorage); + } + return true; +} + +/*! + * \brief Kernel of dot(csr, dns1) = dns2 + * Parallelization by row blocks + */ +struct DotCsrDnsDnsByRowBlocks { + /*! + * \brief + * \param i the i-th thread + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, const size_t seg_len, + const size_t num_rows, const size_t num_cols) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (seg_start+seg_len < num_rows? 
seg_start+seg_len : num_rows); + for (size_t j = seg_start; j < seg_end; ++j) { + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_out = j * num_cols; + for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { + const auto val = data_l[k]; + const size_t offset_r = col_idx_l[k] * num_cols; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_r[offset_r+l] * val; + } + } + } + } +}; + +/*! + * \brief Kernel of dot(csr.T(), dns1) = dns2 + * Parallelization by row blocks + */ +struct DotCsrTransDnsDnsByRowBlocks { + /*! + * \brief + * \param i the i-th thread + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, const size_t seg_len, + const size_t num_rows_l, const size_t num_rows, + const size_t num_cols) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (i + 1) * seg_len; + for (size_t j = 0; j < num_rows_l; ++j) { + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_r = j * num_cols; + for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { + const auto col_idx = col_idx_l[k]; + if (col_idx < seg_start || col_idx >= seg_end) continue; + const size_t offset_out = col_idx * num_cols; + const auto val = data_l[k]; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_r[offset_r+l] * val; + } + } + } + } +}; + +/*! + * \brief Kernel of dot(csr.T(), dns) = rsp + * Parallelization by row blocks. + * This kernel fills up the row_idx array + * of the rsp with 1 for nonzero rows and 0 + * for zero rows. + * The matrix will be compacted after this kernel call. + */ +struct DotCsrTransDnsRspByRowBlocks { + /*! + * \brief + * \param i the i-th thread + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, RType* row_idx, const DType* data_l, + const IType* indptr_l, const CType* col_idx_l, + const DType* data_r, const size_t seg_len, + const size_t num_rows_l, const size_t num_rows, + const size_t num_cols) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (i + 1) * seg_len; + for (size_t j = 0; j < num_rows_l; ++j) { + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_r = j * num_cols; + for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { + const auto col_idx = col_idx_l[k]; + if (col_idx < seg_start || col_idx >= seg_end) continue; + const size_t offset_out = col_idx * num_cols; + row_idx[col_idx] = 1; + const auto val = data_l[k]; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_r[offset_r+l] * val; + } + } + } + } +}; + +/*! + * \brief Kernel of dot(csr, rsp) = dns + * Parallelization by row blocks + */ +struct DotCsrRspDnsByRowBlocks { + /*! + * \brief + * \param i the i-th thread + * \param nnr_r storage_shape[0] of the rsp + * \param num_rows dns.shape[0] + * \param num_cols dns.shape[1] + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, + const IType* indptr_l, const CType* col_idx_l, + const DType* data_r, const RType* row_idx_r, + const size_t nnr_r, const size_t num_rows, + const size_t num_cols, const size_t seg_len) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (seg_start+seg_len < num_rows? 
seg_start+seg_len : num_rows); + for (size_t j = seg_start; j < seg_end; ++j) { + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_out = j * num_cols; + // Use binary search to find the lower_bound of val in row_idx array + const RType* first = row_idx_r; + const RType* last = row_idx_r + nnr_r; + const auto val = col_idx_l[indptr_l[j]]; + const RType* it; + int count = last - first, step; + while (count > 0) { + it = first; + step = count / 2; + it += step; + if (*it < val) { + first = ++it; + count -= step + 1; + } else { + count = step; + } + } + const RType* row_idx_ptr = first; + // end of binary search + if (row_idx_ptr == row_idx_r+nnr_r || *row_idx_ptr> col_idx_l[indptr_l[j+1]-1]) continue; + for (auto k = indptr_l[j]; k < indptr_l[j+1] && row_idx_ptr != row_idx_r+nnr_r;) { + if (col_idx_l[k] == *row_idx_ptr) { + const size_t offset_r = (row_idx_ptr - row_idx_r) * num_cols; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_l[k] * data_r[offset_r+l]; + } + ++k; + ++row_idx_ptr; + } else if (col_idx_l[k] < *row_idx_ptr) { + ++k; + } else { + ++row_idx_ptr; + } + } + } + } +}; + +/*! + * \brief Kernel of dot(csr.T(), rsp) = dns with row_idx marked for non-zero rows + * Parallelization by row blocks + */ +struct DotCsrTransRspRspByRowBlocks { + /*! + * \brief + * \param i the i-th thread + * \param num_rows_l number of rows of lhs matrix + * \param nnr_r number of non-zero rows of rhs matrix + * \param num_rows number of rows of out matrix + * \param num_cols number of cols of out matrix + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, RType* row_idx_out, + const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, + const RType* row_idx_r, const size_t num_rows_l, + const size_t nnr_r, const size_t num_rows, + const size_t num_cols, const size_t seg_len) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (i + 1) * seg_len; + for (size_t rid = 0; rid < nnr_r; ++rid) { + const auto j = row_idx_r[rid]; + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_r = rid * num_cols; + for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { + const auto col_idx = col_idx_l[k]; + if (col_idx < seg_start || col_idx >= seg_end) continue; + row_idx_out[col_idx] = 1; // mark nonzero row as 1 + const size_t offset_out = col_idx * num_cols; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_r[offset_r+l] * data_l[k]; + } + } + } + } +}; + +inline void DotCsrDnsDnsImpl(mshadow::Stream* s, + const NDArray& lhs, + const TBlob& rhs, + const OpReqType req, + const bool trans_lhs, + TBlob* ret) { + if (kNullOp == req) return; + CHECK_EQ(lhs.storage_type(), kCSRStorage); + if (!lhs.storage_initialized()) return; + + const TBlob data_l = lhs.data(); + const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); + const TBlob col_idx_l = lhs.aux_data(csr::kIdx); + const TBlob& data_r = rhs; + const TBlob data_out = *ret; + + MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type + if (kWriteTo == req) { + mxnet_op::Kernel::Launch( + s, data_out.Size(), data_out.dptr()); + } + int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); + size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; + if (trans_lhs) { + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), 
indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), seg_len, + lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]); + } else { + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), seg_len, + data_out.shape_[0], data_out.shape_[1]); + } + }); + }); + }); +} + +/*! + * \brief Impl of dot(csr, rsp) + */ +inline void DotCsrDnsRspImpl(mshadow::Stream* s, + const NDArray& lhs, + const TBlob& rhs, + const OpReqType req, + const bool trans_lhs, + NDArray* ret) { + if (kNullOp == req) return; + CHECK_EQ(lhs.storage_type(), kCSRStorage); + CHECK_EQ(ret->storage_type(), kRowSparseStorage); + if (!lhs.storage_initialized()) return; + + const TBlob data_l = lhs.data(); + const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); + const TBlob col_idx_l = lhs.aux_data(csr::kIdx); + const TBlob& data_r = rhs; + + // pre-allocate spaces for ret using the dense dimension size + ret->CheckAndAlloc({mshadow::Shape1(lhs.shape()[1])}); + const TBlob data_out = ret->data(); + const TBlob row_idx_out = ret->aux_data(rowsparse::kIdx); + + MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type + MSHADOW_IDX_TYPE_SWITCH(row_idx_out.type_flag_, RType, { // col idx type + if (kWriteTo == req) { + mxnet_op::Kernel::Launch( + s, data_out.Size(), data_out.dptr()); + } + RType* row_idx = row_idx_out.dptr(); + mxnet_op::Kernel::Launch( + s, row_idx_out.Size(), row_idx); + int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); + size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; + if (trans_lhs) { + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), row_idx, data_l.dptr(), + indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), + seg_len, lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]); + index_t nnr = 0; + nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr); + ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr)); + ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1])); + if (0 == nnr) return; + mshadow::Tensor rsp_data = data_out.FlatTo2D(s); + size_t idx = 0; + for (index_t i = 0; i < ret->shape()[0]; ++i) { + if (row_idx[i] > 0) { + row_idx[idx] = i; + mshadow::Copy(rsp_data[idx], rsp_data[i], s); + ++idx; + } + } + } else { + LOG(FATAL) << "DotCsrDnsRspImpl has not implemented dot(csr, dns)=rsp yet." 
+ " Only the cpu version of dot(csr.T, dns)=rsp is supported now"; + } + }); + }); + }); + }); +} + +template +void DotCsrRspDnsImpl(mshadow::Stream* s, + const NDArray& lhs, + const NDArray& rhs, + const OpReqType req, + const bool trans_lhs, + TBlob* ret) { + // reuse csr dns implementation when storage_shape == shape for rhs + if (rhs.storage_shape()[0] == rhs.shape()[0]) { // if rsp is actually dense + DotCsrDnsDnsImpl(s, lhs, rhs.data(), req, trans_lhs, ret); + return; + } + + if (kNullOp == req) return; + CHECK_EQ(lhs.storage_type(), kCSRStorage); + CHECK_EQ(rhs.storage_type(), kRowSparseStorage); + if (!lhs.storage_initialized() || !rhs.storage_initialized()) { + if (kWriteTo == req) { + MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { // data type + mxnet_op::Kernel::Launch( + s, ret->Size(), ret->dptr()); + }); + } + return; + } + + const TBlob data_l = lhs.data(); + const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); + const TBlob col_idx_l = lhs.aux_data(csr::kIdx); + const TBlob data_r = rhs.data(); + const TBlob row_idx_r = rhs.aux_data(rowsparse::kIdx); + + MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type + MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, { // col idx type + if (kWriteTo == req) { + mxnet_op::Kernel::Launch( + s, ret->Size(), ret->dptr()); + } + int num_threads = mxnet_op::get_num_threads(ret->shape_[0]); + size_t seg_len = (ret->shape_[0] + num_threads - 1) / num_threads; + if (trans_lhs) { + LOG(FATAL) << "DotCsrRspDnsImpl has not implemented dot(csr.T, rsp) = dns yet"; + } else { + mxnet_op::Kernel::Launch(s, num_threads, + ret->dptr(), data_l.dptr(), + indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), + row_idx_r.dptr(), rhs.storage_shape()[0], + ret->shape_[0], ret->shape_[1], seg_len); + } + }); + }); + }); + }); +} + +/*! 
+ * \brief Impl of dot(csr.T, rsp) = rsp2 + */ +inline void DotCsrRspRspImpl(mshadow::Stream* s, + const NDArray& lhs, + const NDArray& rhs, + const OpReqType req, + const bool trans_lhs, + NDArray* ret) { + // reuse csr dns implementation when storage_shape == shape for rhs + if (rhs.storage_shape()[0] == rhs.shape()[0]) { // if rsp is actually dense + DotCsrDnsRspImpl(s, lhs, rhs.data(), req, trans_lhs, ret); + return; + } + + if (kNullOp == req) return; + CHECK_EQ(lhs.storage_type(), kCSRStorage); + CHECK_EQ(rhs.storage_type(), kRowSparseStorage); + CHECK_EQ(ret->storage_type(), kRowSparseStorage); + if (!lhs.storage_initialized() || !rhs.storage_initialized()) return; + + const TBlob data_l = lhs.data(); + const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); + const TBlob col_idx_l = lhs.aux_data(csr::kIdx); + const TBlob data_r = rhs.data(); + const TBlob row_idx_r = rhs.aux_data(rowsparse::kIdx); + + // pre-allocate spaces for ret using the dense dimension size + if (ret->storage_type() == kRowSparseStorage) { + ret->CheckAndAlloc({mshadow::Shape1(lhs.shape()[1])}); + } + const TBlob data_out = ret->data(); + const TBlob row_idx_out = ret->aux_data(rowsparse::kIdx); + + MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type + MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, { // col idx type + if (kWriteTo == req) { + mxnet_op::Kernel::Launch( + s, data_out.Size(), data_out.dptr()); + } + int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); + size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; + if (trans_lhs) { + RType* row_idx = row_idx_out.dptr(); + mxnet_op::Kernel::Launch( + s, row_idx_out.Size(), row_idx); + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), row_idx, data_l.dptr(), + indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), + row_idx_r.dptr(), lhs.shape()[0], rhs.storage_shape()[0], + ret->shape()[0], ret->shape()[1], seg_len); + index_t nnr = 0; + nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr); + ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr)); + ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1])); + if (0 == nnr) return; + mshadow::Tensor rsp_data = data_out.FlatTo2D(s); + size_t idx = 0; + for (index_t i = 0; i < ret->shape()[0]; ++i) { + if (row_idx[i] > 0) { + row_idx[idx] = i; + mshadow::Copy(rsp_data[idx], rsp_data[i], s); + ++idx; + } + } + } else { + LOG(FATAL) << "DotCsrRspRspImpl has not implemented dot(csr.T, rsp) = rsp2 yet"; + } + }); + }); + }); + }); +} + +inline bool DotShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const DotParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + TShape& lshape = (*in_attrs)[0]; + TShape& rshape = (*in_attrs)[1]; + if (lshape.ndim() == 1 && rshape.ndim() == 1) { + CHECK(!param.transpose_a && !param.transpose_b) << "Cannot transpose vectors"; + CHECK_EQ(lshape[0], rshape[0]) << "dot shape error: " << lshape << " X " << rshape; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape1(1)); + } else { + bool Ta = param.transpose_a, Tb = param.transpose_b; + TShape L[2], R[2]; + if (Ta) { + L[0] = mshadow::Shape1(lshape[0]); + L[1] = lshape.ndim() > 1 ? TShape(&lshape[1], &lshape[lshape.ndim()]) : TShape(1); + } else { + L[0] = lshape.ndim() > 1 ? 
TShape(&lshape[0], &lshape[lshape.ndim()-1]) : TShape(1); + L[1] = mshadow::Shape1(lshape[lshape.ndim()-1]); + } + if (Tb) { + R[0] = rshape.ndim() > 1 ? TShape(&rshape[0], &rshape[rshape.ndim()-1]) : TShape(1); + R[1] = mshadow::Shape1(rshape[rshape.ndim()-1]); + } else { + R[0] = mshadow::Shape1(rshape[0]); + R[1] = rshape.ndim() > 1 ? TShape(&rshape[1], &rshape[rshape.ndim()]) : TShape(1); + } + + if (L[!Ta].Size() != 0 && R[Tb].Size() != 0) { + CHECK_EQ(L[!Ta].Size(), R[Tb].Size()) + << "dot shape error: " << lshape << " X " << rshape; + } + std::vector buf; + if (lshape.ndim() > 1) buf.insert(buf.end(), &L[Ta][0], &L[Ta][L[Ta].ndim()]); + if (rshape.ndim() > 1) buf.insert(buf.end(), &R[!Tb][0], &R[!Tb][R[!Tb].ndim()]); + TShape oshape(buf.begin(), buf.end()); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); + } + return true; +} + +template +void DotForwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + const DotParam& param = nnvm::get(attrs.parsed); + CHECK(!param.transpose_b) << "tranposing rhs of the op dot is not supported"; + auto lhs_stype = inputs[0].storage_type(); + auto rhs_stype = inputs[1].storage_type(); + auto out_stype = outputs[0].storage_type(); + mshadow::Stream* s = ctx.get_stream(); + if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kDefaultStorage) { + TBlob ret = outputs[0].data(); + DotCsrDnsDnsImpl(s, inputs[0], inputs[1].data(), req[0], param.transpose_a, &ret); + } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage + && out_stype == kDefaultStorage) { + TBlob ret = outputs[0].data(); + DotCsrRspDnsImpl(s, inputs[0], inputs[1], req[0], param.transpose_a, &ret); + } else if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage + && out_stype == kRowSparseStorage) { + NDArray out = outputs[0]; + DotCsrDnsRspImpl(s, inputs[0], inputs[1].data(), req[0], param.transpose_a, &out); + } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage + && out_stype == kRowSparseStorage) { + NDArray ret = outputs[0]; + DotCsrRspRspImpl(s, inputs[0], inputs[1], req[0], param.transpose_a, &ret); + } else { + FCompExFallback(attrs, ctx, inputs, req, outputs, DotForward_, "DotForward_"); + } +} + +template +void DotBackwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 2U); + CHECK_EQ(req.size(), 2U); + CHECK_EQ(kNullOp, req[0]) + << "sparse dot does not support computing the gradient of the csr/lhs"; + CHECK_NE(req[1], kWriteInplace) << "DotBackwardEx does not support WriteInplace"; + + const DotParam& param = nnvm::get(attrs.parsed); + CHECK(!param.transpose_b) << "sparse dot only supports dot(A, X) and dot(A.T(), X)"; + const auto ograd_stype = inputs[0].storage_type(); + const auto lhs_stype = inputs[1].storage_type(); + const auto rhs_stype = inputs[2].storage_type(); + const auto grad_rhs_stype = outputs[1].storage_type(); + mshadow::Stream* s = ctx.get_stream(); + if (ograd_stype == kDefaultStorage // ograd dns format + && lhs_stype == kCSRStorage // csr input lhs of the op + && grad_rhs_stype == kDefaultStorage) { // grad(rhs) dns format + TBlob ret = outputs[1].data(); + DotCsrDnsDnsImpl(s, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); + } else if (ograd_stype 
== kDefaultStorage + && lhs_stype == kCSRStorage + && grad_rhs_stype == kRowSparseStorage) { + NDArray ret = outputs[1]; + DotCsrDnsRspImpl(s, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); + } else { + FCompExFallback(attrs, ctx, inputs, req, outputs, DotBackward_, "DotBackward_"); + } +} + +template +void BatchDotForward_(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow::expr; + mshadow::Stream *s = ctx.get_stream(); + const DotParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(outputs[0].type_flag_, inputs[0].type_flag_) + << "Binary function only support input/output with the same type"; + CHECK_EQ(outputs[0].type_flag_, inputs[1].type_flag_) + << "Binary function only support input/output with the same type"; + CHECK_EQ(outputs[0].type_flag_, mshadow::kFloat32) + << "dot only support 32 bit float so far"; + + mshadow::Tensor out = outputs[0].get(s); + mshadow::Tensor mlhs = inputs[0].get(s); + mshadow::Tensor mrhs = inputs[1].get(s); + mshadow::Tensor workspace = + ctx.requested[0].get_space_typed(mshadow::Shape1(3 * out.size(0)), s); + if (kNullOp != req[0]) { + if (param.transpose_a && param.transpose_b) { + mshadow::BatchGEMM(out, mlhs, mrhs, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + workspace); + } else if (!param.transpose_a && param.transpose_b) { + mshadow::BatchGEMM(out, mlhs, mrhs, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + workspace); + } else if (param.transpose_a && !param.transpose_b) { + mshadow::BatchGEMM(out, mlhs, mrhs, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + workspace); + } else { + mshadow::BatchGEMM(out, mlhs, mrhs, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + workspace); + } + } +} + +template +void BatchDotBackward_(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow::expr; + mshadow::Stream *s = ctx.get_stream(); + const DotParam& param = nnvm::get(attrs.parsed); + CHECK_NE(req[1], kWriteInplace); + CHECK_NE(req[0], kWriteInplace); + + mshadow::Tensor mout_grad = inputs[0].get(s); + mshadow::Tensor mlhs_data = inputs[1].get(s); + mshadow::Tensor mrhs_data = inputs[2].get(s); + mshadow::Tensor mlhs_grad = outputs[0].get(s); + mshadow::Tensor mrhs_grad = outputs[1].get(s); + mshadow::Tensor workspace = + ctx.requested[0].get_space_typed( + mshadow::Shape2(2, 3 * mout_grad.size(0)), s); + mshadow::Tensor rhs_workspace = workspace[0]; + mshadow::Tensor lhs_workspace = workspace[1]; + if (param.transpose_a && param.transpose_b) { + // Gradient of z = dot(x.T, y.T) + // dy = dot(x, dz).T = dot(dz.T, x.T) + // dx = dot(dz, y).T = dot(y.T, dz.T) + if (kNullOp != req[1]) { + mshadow::BatchGEMM(mrhs_grad, mout_grad, mlhs_data, 1.0f, + (kAddTo == req[1]) ? 1.0f : 0.0f, + rhs_workspace); + } + if (kNullOp != req[0]) { + mshadow::BatchGEMM(mlhs_grad, mrhs_data, mout_grad, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + lhs_workspace); + } + } else if (!param.transpose_a && param.transpose_b) { + // Gradient of z = dot(x, y.T) + // dy = dot(x.T, dz).T = dot(dz.T, x) + // dx = dot(dz, y) + if (kNullOp != req[1]) { + mshadow::BatchGEMM(mrhs_grad, mout_grad, mlhs_data, 1.0f, + (kAddTo == req[1]) ? 1.0f : 0.0f, + rhs_workspace); + } + if (kNullOp != req[0]) { + mshadow::BatchGEMM(mlhs_grad, mout_grad, mrhs_data, 1.0f, + (kAddTo == req[0]) ? 
1.0f : 0.0f, + lhs_workspace); + } + } else if (param.transpose_a && !param.transpose_b) { + // Gradient of z = dot(x.T, y) + // dy = dot(x, dz) + // dx = dot(dz, y.T).T = dot(y, dz.T) + if (kNullOp != req[1]) { + mshadow::BatchGEMM(mrhs_grad, mlhs_data, mout_grad, 1.0f, + (kAddTo == req[1]) ? 1.0f : 0.0f, + rhs_workspace); + } + if (kNullOp != req[0]) { + mshadow::BatchGEMM(mlhs_grad, mrhs_data, mout_grad, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + lhs_workspace); + } + } else { + // Gradient of z = dot(x, y) + // dy = dot(x.T, dz) + // dx = dot(dz, y.T) + if (kNullOp != req[1]) { + mshadow::BatchGEMM(mrhs_grad, mlhs_data, mout_grad, 1.0f, + (kAddTo == req[1]) ? 1.0f : 0.0f, + rhs_workspace); + } + if (kNullOp != req[0]) { + mshadow::BatchGEMM(mlhs_grad, mout_grad, mrhs_data, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + lhs_workspace); + } + } +} + +inline bool BatchDotShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + const DotParam& param = nnvm::get(attrs.parsed); + TShape& lshape = (*in_attrs)[0]; + TShape& rshape = (*in_attrs)[1]; + if (lshape.ndim() == 3 && rshape.ndim() == 3) { + CHECK(lshape[0] == rshape[0]) + << "batch_dot shape error(batch_size must be equal): " << lshape << " X " << rshape + << " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b; + index_t out_m = param.transpose_a ? lshape[2] : lshape[1]; + index_t lshape_k = param.transpose_a ? lshape[1] : lshape[2]; + index_t out_n = param.transpose_b ? rshape[1] : rshape[2]; + index_t rshape_k = param.transpose_b ? rshape[2] : rshape[1]; + CHECK(lshape_k == rshape_k) + << "batch_dot shape error(shape mismatch): " << lshape << " X " << rshape + << " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape3(lshape[0], out_m, out_n)); + } else { + LOG(FATAL) << "batch_dot currently only support 3D*3D array" + << lshape << " v.s. " << rshape; + } + return true; +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_TENSOR_DOT_INL_H_ diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc new file mode 100644 index 000000000000..fc476a75eec8 --- /dev/null +++ b/src/operator/tensor/dot.cc @@ -0,0 +1,114 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file dot.cc + * \brief CPU Implementation of matrix dot + */ + +#include "./dot-inl.h" + +namespace mxnet { +namespace op { +DMLC_REGISTER_PARAMETER(DotParam); + +NNVM_REGISTER_OP(dot) +.describe(R"doc(Dot product of two arrays. + +``dot``'s behavior depends on the input array dimensions: + +- 1-D arrays: inner product of vectors +- 2-D arrays: matrix multiplication +- N-D arrays: a sum product over the last axis of the first input and the first + axis of the second input + + For example, given 3-D ``x`` with shape `(n,m,k)` and ``y`` with shape `(k,r,s)`, the + result array will have shape `(n,m,r,s)`. 
It is computed by:: + + dot(x,y)[i,j,a,b] = sum(x[i,j,:]*y[:,a,b]) + + Example:: + + x = reshape([0,1,2,3,4,5,6,7], shape=(2,2,2)) + y = reshape([7,6,5,4,3,2,1,0], shape=(2,2,2)) + dot(x,y)[0,0,1,1] = 0 + sum(x[0,0,:]*y[:,1,1]) = 0 +)doc" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"lhs", "rhs"}; + }) +.set_attr("FInferShape", DotShape) +.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FInferStorageType", DotForwardInferStorageType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", DotForward_) +.set_attr("FComputeEx", DotForwardEx) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_dot"}) +.add_argument("lhs", "NDArray-or-Symbol", "The first input") +.add_argument("rhs", "NDArray-or-Symbol", "The second input") +.add_arguments(DotParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_dot) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr_parser(ParamParser) +.set_attr("TIsBackward", true) +.set_attr("FInferStorageType", DotBackwardInferStorageType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", DotBackward_) +.set_attr("FComputeEx", DotBackwardEx) +.add_arguments(DotParam::__FIELDS__()); + +NNVM_REGISTER_OP(batch_dot) +.describe(R"doc(Batchwise dot product. + +``batch_dot`` is used to compute dot product of ``x`` and ``y`` when ``x`` and +``y`` are data in batch, namely 3D arrays in shape of `(batch_size, :, :)`. + +For example, given ``x`` with shape `(batch_size, n, m)` and ``y`` with shape +`(batch_size, m, k)`, the result array will have shape `(batch_size, n, k)`, +which is computed by:: + + batch_dot(x,y)[i,:,:] = dot(x[i,:,:], y[i,:,:]) + +)doc" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"lhs", "rhs"}; + }) +.set_attr("FInferShape", BatchDotShape) +.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", BatchDotForward_) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_batch_dot"}) +.add_argument("lhs", "NDArray-or-Symbol", "The first input") +.add_argument("rhs", "NDArray-or-Symbol", "The second input") +.add_arguments(DotParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_batch_dot) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr_parser(ParamParser) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("TIsBackward", true) +.set_attr("FCompute", BatchDotBackward_); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/tensor/dot.cu b/src/operator/tensor/dot.cu new file mode 100644 index 000000000000..ae00566d5d45 --- /dev/null +++ b/src/operator/tensor/dot.cu @@ -0,0 +1,27 @@ +/*! 
+ * Copyright (c) 2017 by Contributors + * \file dot.cu + * \brief GPU Implementation of matrix dot + */ + +#include "./dot-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(dot) +.set_attr("FCompute", DotForward_) +.set_attr("FComputeEx", DotForwardEx); + +NNVM_REGISTER_OP(_backward_dot) +.set_attr("FCompute", DotBackward_) +.set_attr("FComputeEx", DotBackwardEx); + +NNVM_REGISTER_OP(batch_dot) +.set_attr("FCompute", BatchDotForward_); + +NNVM_REGISTER_OP(_backward_batch_dot) +.set_attr("FCompute", BatchDotBackward_); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 2a1590f9a8a5..46aa6fcd73a4 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -22,7 +22,7 @@ #include "../elemwise_op_common.h" #include "../mxnet_op.h" #include "./sort_op.h" -#include "./matrix_op-inl.h" +#include "./dot-inl.h" namespace mxnet { namespace op { @@ -215,7 +215,7 @@ void SparseEmbeddingForwardRspImpl(const nnvm::NodeAttrs& attrs, TBlob out_blob = out->data(); // forward to dns implementation when storage_shape equals shape bool transpose_a = false; - DotCsrRspDnsImpl(ctx, data, weight, req, transpose_a, &out_blob); + DotCsrRspDnsImpl(ctx.get_stream(), data, weight, req, transpose_a, &out_blob); } template @@ -408,7 +408,7 @@ void SparseEmbeddingBackwardEx(const nnvm::NodeAttrs& attrs, if (data_stype == kCSRStorage && grad_stype == kDefaultStorage && output_stype == kDefaultStorage) { TBlob ret = outputs[1].data(); - DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0].data(), req[1], true, &ret); + DotCsrDnsDnsImpl(ctx.get_stream(), inputs[1], inputs[0].data(), req[1], true, &ret); } else { LOG(FATAL) << "Not supported dot backward for sparse input(s) with sparse gradients"; } diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index c684c7ad6057..e8b3936f3627 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -320,805 +320,6 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs, return true; } -struct DotParam : public dmlc::Parameter { - bool transpose_a; - bool transpose_b; - DMLC_DECLARE_PARAMETER(DotParam) { - DMLC_DECLARE_FIELD(transpose_a) - .describe("If true then transpose the first input before dot.") - .set_default(false); - DMLC_DECLARE_FIELD(transpose_b) - .describe("If true then transpose the second input before dot.") - .set_default(false); - } -}; - -template -void DotForward_(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow; - using namespace mshadow::expr; - const DotParam& param = nnvm::get(attrs.parsed); - Stream *s = ctx.get_stream(); - CHECK_EQ(outputs[0].type_flag_, inputs[0].type_flag_) - << "Binary function only support input/output with the same type"; - CHECK_EQ(outputs[0].type_flag_, inputs[1].type_flag_) - << "Binary function only support input/output with the same type"; - CHECK_EQ(outputs[0].type_flag_, kFloat32) - << "dot only support 32 bit float so far"; - - if (inputs[0].ndim() == 1 && inputs[1].ndim() == 1) { - CHECK_NE(req[0], kAddTo) << "AddTo not yet suported"; - Tensor out = outputs[0].get(s); - VectorDot(out, - inputs[0].get(s), - inputs[1].get(s)); - } else { - int ma, na, mb, nb, m, n; - if (param.transpose_a) { - ma = inputs[0].size(0); - na = inputs[0].Size()/ma; - m = na; - } else { - na = inputs[0].size(inputs[0].ndim()-1); - ma = 
inputs[0].Size()/na; - m = ma; - } - if (param.transpose_b) { - nb = inputs[1].size(inputs[1].ndim()-1); - mb = inputs[1].Size()/nb; - n = mb; - } else { - mb = inputs[1].size(0); - nb = inputs[1].Size()/mb; - n = nb; - } - - Tensor input0 = - inputs[0].get_with_shape(Shape2(ma, na), s); - Tensor input1 = - inputs[1].get_with_shape(Shape2(mb, nb), s); - Tensor out = - outputs[0].get_with_shape(Shape2(m, n), s); - if (param.transpose_a && param.transpose_b) { - ASSIGN_DISPATCH(out, req[0], dot(input0.T(), input1.T())); - } else if (!param.transpose_a && param.transpose_b) { - ASSIGN_DISPATCH(out, req[0], dot(input0, input1.T())); - } else if (param.transpose_a && !param.transpose_b) { - ASSIGN_DISPATCH(out, req[0], dot(input0.T(), input1)); - } else { - ASSIGN_DISPATCH(out, req[0], dot(input0, input1)); - } - } -} - -template -void DotBackward_(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow; - using namespace mshadow::expr; - const DotParam& param = nnvm::get(attrs.parsed); - Stream *s = ctx.get_stream(); - CHECK_NE(req[0], kWriteInplace); - CHECK_NE(req[1], kWriteInplace); - - if (inputs[1].ndim() == 1 && inputs[2].ndim() == 1) { - Tensor mout_grad = inputs[0].get(s); - Tensor mlhs_data = inputs[1].get(s); - Tensor mrhs_data = inputs[2].get(s); - Tensor mlhs_grad = outputs[0].get(s); - Tensor mrhs_grad = outputs[1].get(s); - ASSIGN_DISPATCH(mrhs_grad, req[1], - broadcast_scalar(mout_grad, mlhs_data.shape_) * mlhs_data); - ASSIGN_DISPATCH(mlhs_grad, req[0], - broadcast_scalar(mout_grad, mlhs_data.shape_) * mrhs_data); - } else { - int ma, na, mb, nb, m, n; - if (param.transpose_a) { - ma = outputs[0].size(0); - na = outputs[0].Size()/ma; - m = na; - } else { - na = outputs[0].size(outputs[0].ndim()-1); - ma = outputs[0].Size()/na; - m = ma; - } - if (param.transpose_b) { - nb = outputs[1].size(outputs[1].ndim()-1); - mb = outputs[1].Size()/nb; - n = mb; - } else { - mb = outputs[1].size(0); - nb = outputs[1].Size()/mb; - n = nb; - } - - Tensor mout_grad = - inputs[0].get_with_shape(Shape2(m, n), s); - Tensor mlhs_data = - inputs[1].get_with_shape(Shape2(ma, na), s); - Tensor mrhs_data = - inputs[2].get_with_shape(Shape2(mb, nb), s); - Tensor mlhs_grad = - outputs[0].get_with_shape(Shape2(ma, na), s); - Tensor mrhs_grad = - outputs[1].get_with_shape(Shape2(mb, nb), s); - if (param.transpose_a && param.transpose_b) { - // Gradient of z = dot(x.T, y.T) - // dy = dot(x, dz).T = dot(dz.T, x.T) - // dx = dot(dz, y).T = dot(y.T, dz.T) - ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mout_grad.T(), mlhs_data.T())); - ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mrhs_data.T(), mout_grad.T())); - } else if (!param.transpose_a && param.transpose_b) { - // Gradient of z = dot(x, y.T) - // dy = dot(x.T, dz).T = dot(dz.T, x) - // dx = dot(dz, y) - ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mout_grad.T(), mlhs_data)); - ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mout_grad, mrhs_data)); - } else if (param.transpose_a && !param.transpose_b) { - // Gradient of z = dot(x.T, y) - // dy = dot(x, dz) - // dx = dot(dz, y.T).T = dot(y, dz.T) - ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mlhs_data, mout_grad)); - ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mrhs_data, mout_grad.T())); - } else { - // Gradient of z = dot(x, y) - // dy = dot(x.T, dz) - // dx = dot(dz, y.T) - ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mlhs_data.T(), mout_grad)); - ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mout_grad, mrhs_data.T())); - } - } -} - -inline 
bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, - const Context& ctx, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), 2U); - CHECK_EQ(out_attrs->size(), 1U); - const DotParam& param = nnvm::get(attrs.parsed); - if (param.transpose_a && kCSRStorage == (*in_attrs)[0] - && kDefaultStorage == (*in_attrs)[1]) { - STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kRowSparseStorage); - } else { - STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage); - } - return true; -} - -inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs, - const Context& ctx, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), 3U); - CHECK_EQ(out_attrs->size(), 2U); - const DotParam& param = nnvm::get(attrs.parsed); - STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage); - if (!param.transpose_a && kDefaultStorage == (*in_attrs)[0] - && kCSRStorage == (*in_attrs)[1] && kDefaultStorage == (*in_attrs)[2]) { - STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 1, kRowSparseStorage); - } else { - STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 1, kDefaultStorage); - } - return true; -} - -/*! - * \brief Kernel of dot(csr, dns1) = dns2 - * Parallelization by output matrix elements - */ -template -struct DotCsrDnsDns { - /*! - * \brief This function represents performing an inner product between a row of lhs - * and a column of rhs and then assigning the value to out[i]. - * \param i i-th element in out 1D view - * \param out output matrix - * \param data_l csr values of lhs - * \param indptr_l csr indptr of lhs - * \param col_idx_l csr col_idx of lhs - * \param data_r dense data of rhs - * \param num_cols number of columns of output - */ - template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, - const CType* col_idx_l, const DType* data_r, - const int num_cols) { - const int irow = i / num_cols; // row id of the lhs - const int icol = i % num_cols; // col id of the rhs - DType sum = 0; - for (IType j = indptr_l[irow]; j < indptr_l[irow+1]; ++j) { - const CType cur_col = col_idx_l[j]; // corresponding row id of the rhs - sum += data_l[j] * data_r[cur_col*num_cols+icol]; - } - KERNEL_ASSIGN(out[i], req, sum); - } -}; - -/*! - * \brief Kernel of dot(csr.T(), dns1) = dns2 - * Parallelization by output matrix elements - */ -template -struct DotCsrTransDnsDns { - /*! - * \brief This function represents performing an inner product between a column of lhs - * and a column of rhs and then assigning the value to out[i]. 
- * \param i i-th element in out 1D view - * \param out output matrix - * \param data_l csr values of lhs - * \param indptr_l csr indptr of lhs - * \param col_idx_l csr col_idx of lhs - * \param data_r dense data of rhs - * \param num_rows_l number of rows of lhs - * \param num_cols number of columns of outputs - */ - template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, - const CType* col_idx_l, const DType* data_r, const int num_rows_l, - const int num_cols) { - const int irow = i / num_cols; // col id of the lhs - const int icol = i % num_cols; // col id of the rhs - DType sum = 0; - for (int k = 0; k < num_rows_l; ++k) { - const IType low = indptr_l[k]; - const IType high = indptr_l[k+1]; - if (low == high || irow < col_idx_l[low] || irow > col_idx_l[high-1]) continue; - int j = -1, l = low, r = high - 1; - while (l <= r) { - int m = l + (r - l) / 2; - if (col_idx_l[m] == irow) { - j = m; break; - } - if (col_idx_l[m] < irow) { - l = m + 1; - } else { - r = m - 1; - } - } - if (j >= 0) { - sum += data_l[j] * data_r[k*num_cols+icol]; - } - } - KERNEL_ASSIGN(out[i], req, sum); - } -}; - -/*! - * \brief Kernel of dot(csr, dns1) = dns2 - * Parallelization by row blocks - */ -struct DotCsrDnsDnsByRowBlocks { - /*! - * \brief - * \param i the i-th thread - */ - template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, - const CType* col_idx_l, const DType* data_r, const size_t seg_len, - const size_t num_rows, const size_t num_cols) { - const size_t seg_start = i * seg_len; - if (seg_start >= num_rows) return; - const size_t seg_end = (seg_start+seg_len < num_rows? seg_start+seg_len : num_rows); - for (size_t j = seg_start; j < seg_end; ++j) { - if (indptr_l[j] == indptr_l[j+1]) continue; - const size_t offset_out = j * num_cols; - for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { - const auto val = data_l[k]; - const size_t offset_r = col_idx_l[k] * num_cols; - for (size_t l = 0; l < num_cols; ++l) { - out[offset_out+l] += data_r[offset_r+l] * val; - } - } - } - } -}; - -/*! - * \brief Kernel of dot(csr.T(), dns1) = dns2 - * Parallelization by row blocks - */ -struct DotCsrTransDnsDnsByRowBlocks { - /*! - * \brief - * \param i the i-th thread - */ - template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, - const CType* col_idx_l, const DType* data_r, const size_t seg_len, - const size_t num_rows_l, const size_t num_rows, - const size_t num_cols) { - const size_t seg_start = i * seg_len; - if (seg_start >= num_rows) return; - const size_t seg_end = (i + 1) * seg_len; - for (size_t j = 0; j < num_rows_l; ++j) { - if (indptr_l[j] == indptr_l[j+1]) continue; - const size_t offset_r = j * num_cols; - for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { - const auto col_idx = col_idx_l[k]; - if (col_idx < seg_start || col_idx >= seg_end) continue; - const size_t offset_out = col_idx * num_cols; - const auto val = data_l[k]; - for (size_t l = 0; l < num_cols; ++l) { - out[offset_out+l] += data_r[offset_r+l] * val; - } - } - } - } -}; - -/*! - * \brief Kernel of dot(csr.T(), dns) = rsp - * Parallelization by row blocks. - * This kernel fills up the row_idx array - * of the rsp with 1 for nonzero rows and 0 - * for zero rows. - * The matrix will be compacted after this kernel call. - */ -struct DotCsrTransDnsRspByRowBlocks { - /*! 
- * \brief - * \param i the i-th thread - */ - template - MSHADOW_XINLINE static void Map(int i, DType* out, RType* row_idx, const DType* data_l, - const IType* indptr_l, const CType* col_idx_l, - const DType* data_r, const size_t seg_len, - const size_t num_rows_l, const size_t num_rows, - const size_t num_cols) { - const size_t seg_start = i * seg_len; - if (seg_start >= num_rows) return; - const size_t seg_end = (i + 1) * seg_len; - for (size_t j = 0; j < num_rows_l; ++j) { - if (indptr_l[j] == indptr_l[j+1]) continue; - const size_t offset_r = j * num_cols; - for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { - const auto col_idx = col_idx_l[k]; - if (col_idx < seg_start || col_idx >= seg_end) continue; - const size_t offset_out = col_idx * num_cols; - row_idx[col_idx] = 1; - const auto val = data_l[k]; - for (size_t l = 0; l < num_cols; ++l) { - out[offset_out+l] += data_r[offset_r+l] * val; - } - } - } - } -}; - -template -void DotCsrDnsDnsImpl(const OpContext& ctx, - const NDArray& lhs, - const TBlob& rhs, - const OpReqType req, - const bool trans_lhs, - TBlob* ret) { - if (kNullOp == req) return; - CHECK_EQ(lhs.storage_type(), kCSRStorage); - if (!lhs.storage_initialized()) return; - - mshadow::Stream *s = ctx.get_stream(); - const TBlob data_l = lhs.data(); - const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); - const TBlob col_idx_l = lhs.aux_data(csr::kIdx); - const TBlob& data_r = rhs; - const TBlob data_out = *ret; - - MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type - MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type - MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type - if (std::is_same::value) { // cpu parallelization by row blocks - if (kWriteTo == req) { - mxnet_op::Kernel::Launch( - s, data_out.Size(), data_out.dptr()); - } - int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); - size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; - if (trans_lhs) { - mxnet_op::Kernel::Launch(s, num_threads, - data_out.dptr(), data_l.dptr(), indptr_l.dptr(), - col_idx_l.dptr(), data_r.dptr(), seg_len, - lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]); - } else { - mxnet_op::Kernel::Launch(s, num_threads, - data_out.dptr(), data_l.dptr(), indptr_l.dptr(), - col_idx_l.dptr(), data_r.dptr(), seg_len, - data_out.shape_[0], data_out.shape_[1]); - } - } else { // gpu parallelization by output elements - if (trans_lhs) { - MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { - mxnet_op::Kernel, xpu>::Launch(s, data_out.Size(), - data_out.dptr(), data_l.dptr(), indptr_l.dptr(), - col_idx_l.dptr(), data_r.dptr(), lhs.shape()[0], - data_out.shape_[1]); - }); - } else { - MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { - mxnet_op::Kernel, xpu>::Launch(s, data_out.Size(), - data_out.dptr(), data_l.dptr(), indptr_l.dptr(), - col_idx_l.dptr(), data_r.dptr(), rhs.shape_[1]); - }); - } - } - }); - }); - }); -} - -template -void DotCsrDnsRspImpl(const OpContext& ctx, - const NDArray& lhs, - const TBlob& rhs, - const OpReqType req, - const bool trans_lhs, - NDArray* ret) { - if (kNullOp == req) return; - CHECK_EQ(lhs.storage_type(), kCSRStorage); - CHECK_EQ(ret->storage_type(), kRowSparseStorage); - if (!lhs.storage_initialized()) return; - - mshadow::Stream *s = ctx.get_stream(); - const TBlob data_l = lhs.data(); - const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); - const TBlob col_idx_l = lhs.aux_data(csr::kIdx); - const TBlob& data_r = rhs; - - // pre-allocate spaces for ret using the dense dimension size - 
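On the CPU path, the ByRowBlocks kernels split the output rows into contiguous segments, one per thread: seg_len = ceil(num_rows / num_threads), and a thread only writes output rows that fall inside its own [seg_start, seg_end) range, so no two threads ever touch the same output row and no atomics are needed. A plain-Python sketch of that partitioning arithmetic, assuming nothing beyond the seg_len formula in the diff::

    def row_segments(num_rows, num_threads):
        """Thread i owns output rows [i*seg_len, min((i+1)*seg_len, num_rows))."""
        seg_len = (num_rows + num_threads - 1) // num_threads  # ceiling division
        segs = []
        for i in range(num_threads):
            start = i * seg_len
            if start >= num_rows:
                break
            segs.append((start, min(start + seg_len, num_rows)))
        return segs

    print(row_segments(10, 4))  # [(0, 3), (3, 6), (6, 9), (9, 10)]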
ret->CheckAndAlloc({mshadow::Shape1(lhs.shape()[1])}); - const TBlob data_out = ret->data(); - const TBlob row_idx_out = ret->aux_data(rowsparse::kIdx); - - MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type - MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type - MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type - MSHADOW_IDX_TYPE_SWITCH(row_idx_out.type_flag_, RType, { // col idx type - if (std::is_same::value) { // cpu parallelization by row blocks - if (kWriteTo == req) { - mxnet_op::Kernel::Launch( - s, data_out.Size(), data_out.dptr()); - } - RType* row_idx = row_idx_out.dptr(); - mxnet_op::Kernel::Launch( - s, row_idx_out.Size(), row_idx); - int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); - size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; - if (trans_lhs) { - mxnet_op::Kernel::Launch(s, num_threads, - data_out.dptr(), row_idx, data_l.dptr(), - indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), - seg_len, lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]); - index_t nnr = 0; - nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr); - ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr)); - ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1])); - if (0 == nnr) return; - mshadow::Tensor rsp_data = data_out.FlatTo2D(s); - size_t idx = 0; - for (index_t i = 0; i < ret->shape()[0]; ++i) { - if (row_idx[i] > 0) { - row_idx[idx] = i; - mshadow::Copy(rsp_data[idx], rsp_data[i], s); - ++idx; - } - } - } else { - LOG(FATAL) << "DotCsrDnsRspImpl has not implemented dot(csr, dns)=rsp yet." - " Only the cpu version of dot(csr.T, dns)=rsp is supported now"; - } - } else { - LOG(FATAL) << "DotCsrDnsRspImpl has not implemented GPU version yet."; - } - }); - }); - }); - }); -} - -template -void DotCsrRspDnsImpl(const OpContext& ctx, - const NDArray& lhs, - const NDArray& rhs, - const OpReqType req, - const bool trans_lhs, - TBlob* ret) { - CHECK_RSP_ALL_ROWS_NON_ZERO(rhs, "Dot", "rhs"); - // reuse csr dns implementation when storage_shape == shape for rhs - DotCsrDnsDnsImpl(ctx, lhs, rhs.data(), req, trans_lhs, ret); -} - -template -void DotBackwardCsrDnsDns(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - const DotParam& param = nnvm::get(attrs.parsed); - TBlob ret = outputs[1].data(); - DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); -} - -template -void DotBackwardCsrRspDns(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - const auto& rhs = inputs[2]; - CHECK_RSP_ALL_ROWS_NON_ZERO(rhs, "Dot", "rhs"); - // reuse csr dns implementation when storage_shape == shape for rhs - const DotParam& param = nnvm::get(attrs.parsed); - TBlob ret = outputs[1].data(); - DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); -} - -inline bool DotShape(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - const DotParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(in_attrs->size(), 2U); - CHECK_EQ(out_attrs->size(), 1U); - TShape& lshape = (*in_attrs)[0]; - TShape& rshape = (*in_attrs)[1]; - if (lshape.ndim() == 1 && rshape.ndim() == 1) { - CHECK(!param.transpose_a && !param.transpose_b) << "Cannot transpose vectors"; - CHECK_EQ(lshape[0], rshape[0]) << "dot shape error: " << lshape << " X " << rshape; - 
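DotCsrDnsRspImpl first writes into a buffer with the full dense number of rows while marking touched rows with 1 in row_idx, then counts the non-zero rows (nnr) and compacts row_idx and the data rows in place; the in-place row copy is safe because the destination index never overtakes the source index. A NumPy sketch of the compaction step, illustrative only::

    import numpy as np

    def compact_rsp(buf, row_mark):
        """buf: (num_rows, num_cols) dense buffer; row_mark[i] == 1 if row i was written.
        Returns (row_idx, storage) in row_sparse form."""
        nnr = int(row_mark.sum())            # number of non-zero rows
        row_idx = np.zeros(nnr, dtype=np.int64)
        idx = 0
        for i in range(buf.shape[0]):
            if row_mark[i] > 0:
                row_idx[idx] = i             # record the original row id
                buf[idx] = buf[i]            # move the row up (same role as mshadow::Copy)
                idx += 1
        return row_idx, buf[:nnr]

    buf = np.array([[0., 0.], [1., 2.], [0., 0.], [3., 4.]])
    row_idx, storage = compact_rsp(buf, np.array([0, 1, 0, 1]))
    # row_idx == [1, 3], storage == [[1, 2], [3, 4]]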
SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape1(1)); - } else { - bool Ta = param.transpose_a, Tb = param.transpose_b; - TShape L[2], R[2]; - if (Ta) { - L[0] = mshadow::Shape1(lshape[0]); - L[1] = lshape.ndim() > 1 ? TShape(&lshape[1], &lshape[lshape.ndim()]) : TShape(1); - } else { - L[0] = lshape.ndim() > 1 ? TShape(&lshape[0], &lshape[lshape.ndim()-1]) : TShape(1); - L[1] = mshadow::Shape1(lshape[lshape.ndim()-1]); - } - if (Tb) { - R[0] = rshape.ndim() > 1 ? TShape(&rshape[0], &rshape[rshape.ndim()-1]) : TShape(1); - R[1] = mshadow::Shape1(rshape[rshape.ndim()-1]); - } else { - R[0] = mshadow::Shape1(rshape[0]); - R[1] = rshape.ndim() > 1 ? TShape(&rshape[1], &rshape[rshape.ndim()]) : TShape(1); - } - - if (L[!Ta].Size() != 0 && R[Tb].Size() != 0) { - CHECK_EQ(L[!Ta].Size(), R[Tb].Size()) - << "dot shape error: " << lshape << " X " << rshape; - } - std::vector buf; - if (lshape.ndim() > 1) buf.insert(buf.end(), &L[Ta][0], &L[Ta][L[Ta].ndim()]); - if (rshape.ndim() > 1) buf.insert(buf.end(), &R[!Tb][0], &R[!Tb][R[!Tb].ndim()]); - TShape oshape(buf.begin(), buf.end()); - SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); - } - return true; -} - -template -void DotForwardEx(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - CHECK_EQ(inputs.size(), 2U); - CHECK_EQ(outputs.size(), 1U); - CHECK_EQ(req.size(), 1U); - const DotParam& param = nnvm::get(attrs.parsed); - CHECK(!param.transpose_b) << "tranposing rhs of the op dot is not supported"; - auto lhs_stype = inputs[0].storage_type(); - auto rhs_stype = inputs[1].storage_type(); - auto out_stype = outputs[0].storage_type(); - if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kDefaultStorage) { - TBlob ret = outputs[0].data(); - DotCsrDnsDnsImpl(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &ret); - } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage && - out_stype == kDefaultStorage) { - TBlob ret = outputs[0].data(); - DotCsrRspDnsImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); - } else if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage - && out_stype == kRowSparseStorage) { - NDArray out = outputs[0]; - DotCsrDnsRspImpl(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &out); - } else { - FCompExFallback(attrs, ctx, inputs, req, outputs, DotForward_, "DotForward_"); - } -} - -template -void DotBackwardEx(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - CHECK_EQ(inputs.size(), 3U); - CHECK_EQ(outputs.size(), 2U); - CHECK_EQ(req.size(), 2U); - CHECK_EQ(kNullOp, req[0]) - << "sparse dot does not support computing the gradient of the csr/lhs"; - CHECK_NE(req[1], kWriteInplace) << "DotBackwardEx does not support WriteInplace"; - - const DotParam& param = nnvm::get(attrs.parsed); - CHECK(!param.transpose_b) << "sparse dot only supports dot(A, X) and dot(A.T(), X)"; - auto ograd_stype = inputs[0].storage_type(); - auto lhs_stype = inputs[1].storage_type(); - auto rhs_stype = inputs[2].storage_type(); - if (ograd_stype == kDefaultStorage // ograd dns format - && lhs_stype == kCSRStorage // csr input lhs of the op - && rhs_stype == kDefaultStorage // dns input rhs of the op - && outputs[1].storage_type() == kDefaultStorage) { // grad(rhs) dns format - // dns, csr, dns => *, dns - DotBackwardCsrDnsDns(attrs, ctx, inputs, req, outputs); - } else if (ograd_stype == 
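Taken together, DotForwardInferStorageType and DotForwardEx mean the output storage depends on the input storage types and on transpose_a. A hedged usage sketch from the Python side; shapes and variable names are illustrative, cast_storage to 'row_sparse' is assumed to be available on this branch, and the comments restate the dispatch implied by the code above rather than anything verified here::

    import mxnet as mx
    import numpy as np

    lhs = mx.nd.cast_storage(mx.nd.array(np.random.rand(3, 4)), storage_type='csr')
    rhs_dns_a = mx.nd.array(np.random.rand(4, 2))   # for dot(csr, dns)
    rhs_dns_b = mx.nd.array(np.random.rand(3, 2))   # for dot(csr.T, dns)
    rhs_rsp = mx.nd.cast_storage(mx.nd.array(np.random.rand(4, 2)),
                                 storage_type='row_sparse')

    out1 = mx.nd.dot(lhs, rhs_dns_a)                    # dense output (DotCsrDnsDnsImpl)
    out2 = mx.nd.dot(lhs, rhs_dns_b, transpose_a=True)  # row_sparse output (DotCsrDnsRspImpl, CPU only here)
    out3 = mx.nd.dot(lhs, rhs_rsp)                      # dense output (DotCsrRspDnsImpl)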
kDefaultStorage && lhs_stype == kCSRStorage && - rhs_stype == kRowSparseStorage && outputs[1].storage_type() == kDefaultStorage) { - // dns, csr, rsp => *, dns - DotBackwardCsrRspDns(attrs, ctx, inputs, req, outputs); - } else if (ograd_stype == kDefaultStorage && lhs_stype == kCSRStorage && - rhs_stype == kDefaultStorage && outputs[1].storage_type() == kRowSparseStorage) { - NDArray grad_rhs = outputs[1]; - DotCsrDnsRspImpl(ctx, inputs[1], inputs[2].data(), req[1], !param.transpose_a, &grad_rhs); - } else { - FCompExFallback(attrs, ctx, inputs, req, outputs, DotBackward_, "DotBackward_"); - } -} - -template -void BatchDotForward_(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow::expr; - mshadow::Stream *s = ctx.get_stream(); - const DotParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(outputs[0].type_flag_, inputs[0].type_flag_) - << "Binary function only support input/output with the same type"; - CHECK_EQ(outputs[0].type_flag_, inputs[1].type_flag_) - << "Binary function only support input/output with the same type"; - CHECK_EQ(outputs[0].type_flag_, mshadow::kFloat32) - << "dot only support 32 bit float so far"; - - mshadow::Tensor out = outputs[0].get(s); - mshadow::Tensor mlhs = inputs[0].get(s); - mshadow::Tensor mrhs = inputs[1].get(s); - mshadow::Tensor workspace = - ctx.requested[0].get_space_typed(mshadow::Shape1(3 * out.size(0)), s); - if (kNullOp != req[0]) { - if (param.transpose_a && param.transpose_b) { - mshadow::BatchGEMM(out, mlhs, mrhs, 1.0f, - (kAddTo == req[0]) ? 1.0f : 0.0f, - workspace); - } else if (!param.transpose_a && param.transpose_b) { - mshadow::BatchGEMM(out, mlhs, mrhs, 1.0f, - (kAddTo == req[0]) ? 1.0f : 0.0f, - workspace); - } else if (param.transpose_a && !param.transpose_b) { - mshadow::BatchGEMM(out, mlhs, mrhs, 1.0f, - (kAddTo == req[0]) ? 1.0f : 0.0f, - workspace); - } else { - mshadow::BatchGEMM(out, mlhs, mrhs, 1.0f, - (kAddTo == req[0]) ? 1.0f : 0.0f, - workspace); - } - } -} - -template -void BatchDotBackward_(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow::expr; - mshadow::Stream *s = ctx.get_stream(); - const DotParam& param = nnvm::get(attrs.parsed); - CHECK_NE(req[1], kWriteInplace); - CHECK_NE(req[0], kWriteInplace); - - mshadow::Tensor mout_grad = inputs[0].get(s); - mshadow::Tensor mlhs_data = inputs[1].get(s); - mshadow::Tensor mrhs_data = inputs[2].get(s); - mshadow::Tensor mlhs_grad = outputs[0].get(s); - mshadow::Tensor mrhs_grad = outputs[1].get(s); - mshadow::Tensor workspace = - ctx.requested[0].get_space_typed( - mshadow::Shape2(2, 3 * mout_grad.size(0)), s); - mshadow::Tensor rhs_workspace = workspace[0]; - mshadow::Tensor lhs_workspace = workspace[1]; - if (param.transpose_a && param.transpose_b) { - // Gradient of z = dot(x.T, y.T) - // dy = dot(x, dz).T = dot(dz.T, x.T) - // dx = dot(dz, y).T = dot(y.T, dz.T) - if (kNullOp != req[1]) { - mshadow::BatchGEMM(mrhs_grad, mout_grad, mlhs_data, 1.0f, - (kAddTo == req[1]) ? 1.0f : 0.0f, - rhs_workspace); - } - if (kNullOp != req[0]) { - mshadow::BatchGEMM(mlhs_grad, mrhs_data, mout_grad, 1.0f, - (kAddTo == req[0]) ? 
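Numerically, BatchDotForward_ is a per-batch matrix product of optionally transposed slices, lowered onto mshadow::BatchGEMM for each (transpose_a, transpose_b) combination. A NumPy reference, illustrative only::

    import numpy as np

    def batch_dot_ref(x, y, transpose_a=False, transpose_b=False):
        """out[i] = dot(op_a(x[i]), op_b(y[i])) for 3D inputs of shape (batch, :, :)."""
        a = x.transpose(0, 2, 1) if transpose_a else x
        b = y.transpose(0, 2, 1) if transpose_b else y
        return np.einsum('bij,bjk->bik', a, b)

    x = np.random.rand(4, 3, 5)
    y = np.random.rand(4, 5, 2)
    assert batch_dot_ref(x, y).shape == (4, 3, 2)
    assert np.allclose(batch_dot_ref(x, y)[0], x[0].dot(y[0]))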
1.0f : 0.0f, - lhs_workspace); - } - } else if (!param.transpose_a && param.transpose_b) { - // Gradient of z = dot(x, y.T) - // dy = dot(x.T, dz).T = dot(dz.T, x) - // dx = dot(dz, y) - if (kNullOp != req[1]) { - mshadow::BatchGEMM(mrhs_grad, mout_grad, mlhs_data, 1.0f, - (kAddTo == req[1]) ? 1.0f : 0.0f, - rhs_workspace); - } - if (kNullOp != req[0]) { - mshadow::BatchGEMM(mlhs_grad, mout_grad, mrhs_data, 1.0f, - (kAddTo == req[0]) ? 1.0f : 0.0f, - lhs_workspace); - } - } else if (param.transpose_a && !param.transpose_b) { - // Gradient of z = dot(x.T, y) - // dy = dot(x, dz) - // dx = dot(dz, y.T).T = dot(y, dz.T) - if (kNullOp != req[1]) { - mshadow::BatchGEMM(mrhs_grad, mlhs_data, mout_grad, 1.0f, - (kAddTo == req[1]) ? 1.0f : 0.0f, - rhs_workspace); - } - if (kNullOp != req[0]) { - mshadow::BatchGEMM(mlhs_grad, mrhs_data, mout_grad, 1.0f, - (kAddTo == req[0]) ? 1.0f : 0.0f, - lhs_workspace); - } - } else { - // Gradient of z = dot(x, y) - // dy = dot(x.T, dz) - // dx = dot(dz, y.T) - if (kNullOp != req[1]) { - mshadow::BatchGEMM(mrhs_grad, mlhs_data, mout_grad, 1.0f, - (kAddTo == req[1]) ? 1.0f : 0.0f, - rhs_workspace); - } - if (kNullOp != req[0]) { - mshadow::BatchGEMM(mlhs_grad, mout_grad, mrhs_data, 1.0f, - (kAddTo == req[0]) ? 1.0f : 0.0f, - lhs_workspace); - } - } -} - -inline bool BatchDotShape(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), 2U); - CHECK_EQ(out_attrs->size(), 1U); - const DotParam& param = nnvm::get(attrs.parsed); - TShape& lshape = (*in_attrs)[0]; - TShape& rshape = (*in_attrs)[1]; - if (lshape.ndim() == 3 && rshape.ndim() == 3) { - CHECK(lshape[0] == rshape[0]) - << "batch_dot shape error(batch_size must be equal): " << lshape << " X " << rshape - << " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b; - index_t out_m = param.transpose_a ? lshape[2] : lshape[1]; - index_t lshape_k = param.transpose_a ? lshape[1] : lshape[2]; - index_t out_n = param.transpose_b ? rshape[1] : rshape[2]; - index_t rshape_k = param.transpose_b ? rshape[2] : rshape[1]; - CHECK(lshape_k == rshape_k) - << "batch_dot shape error(shape mismatch): " << lshape << " X " << rshape - << " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b; - SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape3(lshape[0], out_m, out_n)); - } else { - LOG(FATAL) << "batch_dot currently only support 3D*3D array" - << lshape << " v.s. " << rshape; - } - return true; -} - struct SliceParam : public dmlc::Parameter { nnvm::Tuple > begin, end; DMLC_DECLARE_PARAMETER(SliceParam) { diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index 72d8aadbe90a..e6ab9798bef6 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -16,7 +16,6 @@ DMLC_REGISTER_PARAMETER(ClipParam); DMLC_REGISTER_PARAMETER(SimpleCropAssignScalarParam); DMLC_REGISTER_PARAMETER(SliceParam); DMLC_REGISTER_PARAMETER(SliceAxisParam); -DMLC_REGISTER_PARAMETER(DotParam); DMLC_REGISTER_PARAMETER(RepeatParam); DMLC_REGISTER_PARAMETER(TileParam); DMLC_REGISTER_PARAMETER(ReverseParam); @@ -344,106 +343,6 @@ NNVM_REGISTER_OP(_backward_slice_axis) .set_attr("TIsBackward", true) .set_attr("FCompute", SliceAxisGrad_); -NNVM_REGISTER_OP(dot) -.describe(R"doc(Dot product of two arrays. 
- -``dot``'s behavior depends on the input array dimensions: - -- 1-D arrays: inner product of vectors -- 2-D arrays: matrix multiplication -- N-D arrays: a sum product over the last axis of the first input and the first - axis of the second input - - For example, given 3-D ``x`` with shape `(n,m,k)` and ``y`` with shape `(k,r,s)`, the - result array will have shape `(n,m,r,s)`. It is computed by:: - - dot(x,y)[i,j,a,b] = sum(x[i,j,:]*y[:,a,b]) - - Example:: - - x = reshape([0,1,2,3,4,5,6,7], shape=(2,2,2)) - y = reshape([7,6,5,4,3,2,1,0], shape=(2,2,2)) - dot(x,y)[0,0,1,1] = 0 - sum(x[0,0,:]*y[:,1,1]) = 0 -)doc" ADD_FILELINE) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"lhs", "rhs"}; - }) -.set_attr("FInferShape", DotShape) -.set_attr("FInferType", ElemwiseType<2, 1>) -.set_attr("FInferStorageType", DotForwardInferStorageType) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("FCompute", DotForward_) -.set_attr("FComputeEx", DotForwardEx) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_dot"}) -.add_argument("lhs", "NDArray-or-Symbol", "The first input") -.add_argument("rhs", "NDArray-or-Symbol", "The second input") -.add_arguments(DotParam::__FIELDS__()); - -NNVM_REGISTER_OP(_backward_dot) -.set_num_inputs(3) -.set_num_outputs(2) -.set_attr_parser(ParamParser) -.set_attr("TIsBackward", true) -.set_attr("FInferStorageType", DotBackwardInferStorageType) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("FCompute", DotBackward_) -.set_attr("FComputeEx", DotBackwardEx) -.add_arguments(DotParam::__FIELDS__()); - -NNVM_REGISTER_OP(batch_dot) -.describe(R"doc(Batchwise dot product. - -``batch_dot`` is used to compute dot product of ``x`` and ``y`` when ``x`` and -``y`` are data in batch, namely 3D arrays in shape of `(batch_size, :, :)`. - -For example, given ``x`` with shape `(batch_size, n, m)` and ``y`` with shape -`(batch_size, m, k)`, the result array will have shape `(batch_size, n, k)`, -which is computed by:: - - batch_dot(x,y)[i,:,:] = dot(x[i,:,:], y[i,:,:]) - -)doc" ADD_FILELINE) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"lhs", "rhs"}; - }) -.set_attr("FInferShape", BatchDotShape) -.set_attr("FInferType", ElemwiseType<2, 1>) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("FCompute", BatchDotForward_) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_batch_dot"}) -.add_argument("lhs", "NDArray-or-Symbol", "The first input") -.add_argument("rhs", "NDArray-or-Symbol", "The second input") -.add_arguments(DotParam::__FIELDS__()); - -NNVM_REGISTER_OP(_backward_batch_dot) -.set_num_inputs(3) -.set_num_outputs(2) -.set_attr_parser(ParamParser) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("TIsBackward", true) -.set_attr("FCompute", BatchDotBackward_); - NNVM_REGISTER_OP(clip) .describe(R"code(Clips (limits) the values in an array. 
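The N-D example in the dot docstring above can be checked directly: x[0,0,:] = [0,1] and y[:,1,1] = [4,0], so the stated result 0 follows. A NumPy check, illustrative only::

    import numpy as np

    x = np.arange(8).reshape(2, 2, 2)           # [0, 1, 2, ..., 7]
    y = np.arange(7, -1, -1).reshape(2, 2, 2)   # [7, 6, 5, ..., 0]
    out = np.tensordot(x, y, axes=([2], [0]))   # sum over last axis of x, first axis of y
    assert out.shape == (2, 2, 2, 2)
    assert out[0, 0, 1, 1] == np.sum(x[0, 0, :] * y[:, 1, 1]) == 0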
diff --git a/src/operator/tensor/matrix_op.cu b/src/operator/tensor/matrix_op.cu
index 2e1effb9e560..91a6757b962c 100644
--- a/src/operator/tensor/matrix_op.cu
+++ b/src/operator/tensor/matrix_op.cu
@@ -39,21 +39,6 @@ NNVM_REGISTER_OP(slice_axis)
 NNVM_REGISTER_OP(_backward_slice_axis)
 .set_attr("FCompute", SliceAxisGrad_);
 
-NNVM_REGISTER_OP(dot)
-.set_attr("FCompute", DotForward_)
-.set_attr("FComputeEx", DotForwardEx);
-
-NNVM_REGISTER_OP(_backward_dot)
-.set_attr("FCompute", DotBackward_)
-.set_attr("FComputeEx", DotBackwardEx);
-
-
-NNVM_REGISTER_OP(batch_dot)
-.set_attr("FCompute", BatchDotForward_);
-
-NNVM_REGISTER_OP(_backward_batch_dot)
-.set_attr("FCompute", BatchDotBackward_);
-
 NNVM_REGISTER_OP(clip)
 .set_attr("FCompute", Clip);
 
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 4d2debe5f9d2..1fc64a7149ea 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -102,10 +102,10 @@ def test_dns_to_csr(dns_in):
 
 
 def test_sparse_dot():
-    def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs):
+    def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, density=1):
         lhs_dns = rand_ndarray(lhs_shape, 'default')
         lhs_nd = mx.nd.cast_storage(lhs_dns, storage_type='csr')
-        rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=1)
+        rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=density)
         rhs_dns = rhs_nd if rhs_stype == 'default' else rhs_nd.todense()
         out = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs)
         if trans_lhs:
@@ -130,11 +130,13 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs):
                                 grad_req={'lhs': 'null', 'rhs': 'write'},
                                 rtol=1e-3, atol=1e-4)
 
-    lhs_shape = rand_shape_2d()
+    lhs_shape = rand_shape_2d(50, 200)
     test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'default', False)
     test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'default', True)
     test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False)
     test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True)
+    test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False, 0.05)
+    test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True, 0.05)
 
 
 def test_sparse_embedding():
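As the grad_req={'lhs': 'null', 'rhs': 'write'} setting in the test reflects, the backward path never computes the csr/lhs gradient; the rhs gradient is grad_rhs = dot(lhs.T, ograd) for z = dot(lhs, rhs), which is why DotBackwardCsrDnsDns flips transpose_a. A NumPy check of that identity, illustrative only::

    import numpy as np

    lhs = np.random.rand(5, 4)     # stands in for the csr lhs, densified
    rhs = np.random.rand(4, 3)
    ograd = np.random.rand(5, 3)   # gradient flowing into z = dot(lhs, rhs)

    grad_rhs = lhs.T.dot(ograd)    # what the backward computes via dot(csr.T, ograd)

    # Finite-difference check on a single entry of rhs:
    eps, (i, j) = 1e-6, (2, 1)
    rhs_p = rhs.copy()
    rhs_p[i, j] += eps
    num = ((lhs.dot(rhs_p) - lhs.dot(rhs)) * ograd).sum() / eps
    assert np.isclose(num, grad_rhs[i, j], rtol=1e-4)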