Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Bugfix] [Numpy] Add kAddTo and kNullOp to Transpose #16979

Merged
merged 8 commits into from
Dec 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions src/operator/numpy/np_matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,22 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const NumpyTransposeParam& param = nnvm::get<NumpyTransposeParam>(attrs.parsed);
CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace";
if (req[0] == kNullOp) return;
CHECK(req[0] == kWriteTo || req[0] == kAddTo)
<< "Transpose only supports kWriteTo, kNullOp and kAddTo";
mxnet::TShape axes;
if (ndim_is_known(param.axes)) {
mxnet::TShape axes = common::CanonicalizeAxes(param.axes);
TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
axes = common::CanonicalizeAxes(param.axes);
} else {
mxnet::TShape axes(inputs[0].ndim(), -1);
axes = mxnet::TShape(inputs[0].ndim(), -1);
for (int i = 0; i < axes.ndim(); ++i) {
axes[i] = axes.ndim() - 1 - i;
}
TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
}
if (req[0] == kAddTo) {
TransposeImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], axes);
} else {
TransposeImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], axes);
}
}

Expand Down
12 changes: 9 additions & 3 deletions src/operator/numpy/np_matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
*/

#include <vector>
#include <set>
#include "./np_matrix_op-inl.h"
#include "../nn/concat-inl.h"

Expand Down Expand Up @@ -67,8 +68,13 @@ bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs,
mxnet::TShape ret(ndim, -1);

if (ndim_is_known(param.axes)) {
CHECK_EQ(ndim, param.axes.ndim());
CHECK_EQ(ndim, param.axes.ndim())
<< "The number of axes does not match the dimension of the tensor. axes = "
<< param.axes << ", input tensor shape = " << shp;
mxnet::TShape axes = common::CanonicalizeAxes(param.axes);
std::set<dim_t> axes_set(axes.begin(), axes.end());
CHECK_EQ(axes_set.size(), axes.ndim()) << "Repeated axis in transpose. param.axes = "
<< param.axes;
if (ndim_is_known(shp)) {
for (int i = 0; i < ndim; ++i) {
ret[i] = shp[axes[i]];
Expand Down Expand Up @@ -117,9 +123,9 @@ NNVM_REGISTER_OP(_np_transpose)
}
std::ostringstream os;
os << axes;
return MakeNonlossGradNode("transpose", n, ograds, {}, {{"axes", os.str()}});
return MakeNonlossGradNode("_np_transpose", n, ograds, {}, {{"axes", os.str()}});
} else {
return MakeNonlossGradNode("transpose", n, ograds, {},
return MakeNonlossGradNode("_np_transpose", n, ograds, {},
std::unordered_map<std::string, std::string>());
}
})
Expand Down
106 changes: 71 additions & 35 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,10 @@ struct TransposeParam : public dmlc::Parameter<TransposeParam> {
* \param out output tensor
* \param row shape of dim 0 of input
* \param col shape of dim 1 of input
* \tparam DType Data type
* \tparam is_addto
*/
template<typename DType>
template<typename DType, bool is_addto>
MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index_t col) {
// ensure cache line hits and prevent cache miss for any configuration
// L1 cache size to be utilized = 32kb = 2^15
Expand All @@ -282,7 +284,7 @@ MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index
// Block-size - 2^5 v 2^5 (32 v 32) with potential 4 pragma for loop unrolled
// blocksize * blocksize * num_threads = cache_size / dtype_size
// Instead of explicit unroll, let compiler figure out optimal unroll factor
index_t blocksize = 32;
const index_t blocksize = 32;

// collapse 2 parallelizes 2 for loops
// inner 2 for loops aren't parallelized to prevent cache miss
Expand All @@ -299,14 +301,25 @@ MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index
// transpose the block
for (index_t a = j; (a < blocksize + j) && (a < col); ++a) {
for (index_t b = i; (b < blocksize + i) && (b < row); ++b) {
out[a * row + b] = in[b * col + a];
if (!is_addto) {
out[a * row + b] = in[b * col + a];
} else {
out[a * row + b] += in[b * col + a];
}
}
}
}
}
}

template<typename xpu>
inline bool IsIdentityTranspose(const TShape& axes) {
for (dim_t i = 0; i < axes.ndim(); i++) {
if (axes[i] != i) return false;
}
return true;
}

template<typename xpu, bool is_addto = false>
void TransposeImpl(RunContext ctx,
const TBlob& src,
const TBlob& ret,
Expand All @@ -323,62 +336,79 @@ void TransposeImpl(RunContext ctx,
// Example: (0, 2, 3, 1) or (0, 3, 1, 2), but not (0, 2, 1, 3).
if (isPseudo2DTranspose(axes)) {
MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
transpose_pseudo2D<DType>(ret, src, axes, s);
transpose_pseudo2D<DType, is_addto>(ret, src, axes, s);
});
return;
}
#endif
// Special handle the identity case
if (IsIdentityTranspose(axes)) {
MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
Tensor<xpu, 1, DType> in = src.get_with_shape<xpu, 1, DType>(mshadow::Shape1(src.Size()), s);
Tensor<xpu, 1, DType> out = ret.get_with_shape<xpu, 1, DType>(mshadow::Shape1(ret.Size()), s);
if (!is_addto) {
// Use memcpy to accelerate the speed
Copy(out, in, s);
} else {
mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, kAddTo>, xpu>::Launch(
s, ret.Size(), out.dptr_, in.dptr_);
}
});
return;
}
// Handle the general transpose case
MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
switch (axes.ndim()) {
case 0: {
Tensor<xpu, 1, DType> in = src.get_with_shape<xpu, 1, DType>(mshadow::Shape1(1), s);
Tensor<xpu, 1, DType> out = ret.get_with_shape<xpu, 1, DType>(mshadow::Shape1(1), s);
Copy(out, in, s);
break;
}
case 1: {
Tensor<xpu, 1, DType> in = src.get<xpu, 1, DType>(s);
Tensor<xpu, 1, DType> out = ret.get<xpu, 1, DType>(s);
Copy(out, in, s);
break;
}
case 2: {
mshadow::Tensor<xpu, 2, DType> in = src.FlatTo2D<xpu, DType>(s);
mshadow::Tensor<xpu, 2, DType> out = ret.FlatTo2D<xpu, DType>(s);

if (axes[0] == 1 && axes[1] == 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why was this if case removed?
mx.nd.transpose(a,axes=(0,1)) is basically No Transpose (and it is handled in this scenario)

Realized you moved the Copy function above and called it IdentityTranspose
Makes sense now

if (ctx.get_ctx().dev_mask() == cpu::kDevMask) {
Transpose2D<DType>(in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]);
} else {
out = in.T();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR for 2D transpose with GPU is still WIP #16706
Till then we default to the mshadow expression template based implementation of Transpose

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ChaiBapchya We have already had transpose_pseudo2D, which covers the case of 2D Transpose in GPU.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah right..

}
Tensor<xpu, 2, DType> in = src.get<xpu, 2, DType>(s);
Tensor<xpu, 2, DType> out = ret.get<xpu, 2, DType>(s);
if (ctx.get_ctx().dev_mask() == cpu::kDevMask) {
Transpose2D<DType, is_addto>(in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]);
} else {
Copy(out, in, s);
LOG(FATAL) << "Not Implemented. We should never reach here because the 2D case "
"in GPU has been covered by transpose_pseudo2D."
" Report an issue in Github.";
}
break;
}
case 3: {
Tensor<xpu, 3, DType> in = src.get<xpu, 3, DType>(s);
Tensor<xpu, 3, DType> out = ret.get<xpu, 3, DType>(s);
out = transpose(in, axes.get<3>());
if (!is_addto) {
out = transpose(in, axes.get<3>());
} else {
out += transpose(in, axes.get<3>());
}
break;
}
case 4: {
Tensor<xpu, 4, DType> in = src.get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> out = ret.get<xpu, 4, DType>(s);
out = transpose(in, axes.get<4>());
if (!is_addto) {
out = transpose(in, axes.get<4>());
} else {
out += transpose(in, axes.get<4>());
}
break;
}
case 5: {
Tensor<xpu, 5, DType> in = src.get<xpu, 5, DType>(s);
Tensor<xpu, 5, DType> out = ret.get<xpu, 5, DType>(s);
out = transpose(in, axes.get<5>());
if (!is_addto) {
out = transpose(in, axes.get<5>());
} else {
out += transpose(in, axes.get<5>());
}
break;
}
case 6: {
Tensor<xpu, 6, DType> in = src.get<xpu, 6, DType>(s);
Tensor<xpu, 6, DType> out = ret.get<xpu, 6, DType>(s);
out = transpose(in, axes.get<6>());
if (!is_addto) {
out = transpose(in, axes.get<6>());
} else {
out += transpose(in, axes.get<6>());
}
break;
}
default:
Expand All @@ -399,15 +429,21 @@ void Transpose(const nnvm::NodeAttrs& attrs,
return;
}
const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
CHECK_EQ(req[0], kWriteTo) << "Transpose does not support kWriteInplace and kAddTo";
CHECK(req[0] == kWriteTo || req[0] == kAddTo)
<< "Transpose only supports kNullOp, kWriteTo and kAddTo";
mxnet::TShape axes;
if (param.axes.ndim() == 0) {
mxnet::TShape axes(inputs[0].ndim(), -1);
axes = mxnet::TShape(inputs[0].ndim(), -1);
for (int i = 0; i < axes.ndim(); ++i) {
axes[i] = axes.ndim() - 1 - i;
}
TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
} else {
TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], param.axes);
axes = common::CanonicalizeAxes(param.axes);
}
if (req[0] == kAddTo) {
TransposeImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], axes);
} else {
TransposeImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], axes);
}
}

Expand Down
5 changes: 3 additions & 2 deletions src/operator/tensor/matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,11 +283,12 @@ static void TransposeComputeExCPU(const nnvm::NodeAttrs& attrs,
return;
}
const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
CHECK_EQ(req[0], kWriteTo) << "Transpose does not support kWriteInplace and kAddTo";
CHECK(req[0] == kWriteTo || req[0] == kAddTo) <<
"Transpose only supports kNullOp, kWriteTo and kAddTo";
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);

if (SupportMKLDNNTranspose(param, inputs[0])) {
if (SupportMKLDNNTranspose(param, inputs[0]) && req[0] == kWriteTo) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

MKLDNNTransposeForward(attrs, ctx, inputs[0], req[0], outputs[0]);
return;
}
Expand Down
Loading