
[MXNET-882] Support for N-d arrays added to diag op. #12430

Merged
merged 15 commits into from Sep 18, 2018
2 changes: 2 additions & 0 deletions CONTRIBUTORS.md
@@ -178,3 +178,5 @@ List of Contributors
* [Aaron Markham](https://github.com/aaronmarkham)
* [Sam Skalicky](https://github.com/samskalicky)
* [Per Goncalves da Silva](https://github.com/perdasilva)
* [Zhijingcheng Yu](https://github.com/jasonyu1996)
* [Cheng-Che Lee](https://github.com/stu1130)
232 changes: 173 additions & 59 deletions src/operator/tensor/diag_op-inl.h
@@ -21,7 +21,7 @@
* Copyright (c) 2015 by Contributors
* \file diag_op-inl.h
* \brief CPU Implementation of the diag op
- * \author Istvan Fehervari
+ * \author Istvan Fehervari, Zhijingcheng Yu
*/

#ifndef MXNET_OPERATOR_TENSOR_DIAG_OP_INL_H_
@@ -30,33 +30,51 @@
#include <dmlc/parameter.h>
#include <vector>
#include <algorithm>
+ #include <utility>
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"
#include "./broadcast_reduce_op.h"

namespace mxnet {
namespace op {

struct DiagParam : public dmlc::Parameter<DiagParam> {
-   dmlc::optional<int> k;
+   int k;
+   int32_t axis1;
+   int32_t axis2;
DMLC_DECLARE_PARAMETER(DiagParam) {
DMLC_DECLARE_FIELD(k)
-     .set_default(dmlc::optional<int>(0))
-     .describe("Diagonal in question. The default is 0. "
-               "Use k>0 for diagonals above the main diagonal, "
-               "and k<0 for diagonals below the main diagonal. "
-               "If input has shape (S0 S1) k must be between -S0 and S1");
+     .set_default(0)
+     .describe("Diagonal in question. The default is 0. "
+               "Use k>0 for diagonals above the main diagonal, "
+               "and k<0 for diagonals below the main diagonal. "
+               "If input has shape (S0 S1) k must be between -S0 and S1");
+   DMLC_DECLARE_FIELD(axis1)
+     .set_default(0)
+     .describe("The first axis of the sub-arrays of interest. "
+               "Ignored when the input is a 1-D array.");
+   DMLC_DECLARE_FIELD(axis2)
+     .set_default(1)
+     .describe("The second axis of the sub-arrays of interest. "
+               "Ignored when the input is a 1-D array.");
}
};

- inline TShape DiagShapeImpl(const TShape& ishape, const nnvm::dim_t k) {
+ inline TShape DiagShapeImpl(const TShape& ishape, const int k,
+                             const int32_t axis1, const int32_t axis2) {
if (ishape.ndim() == 1) {
auto s = ishape[0] + std::abs(k);
return TShape({s, s});
}

-   auto h = ishape[0];
-   auto w = ishape[1];
+   int32_t x1 = CheckAxis(axis1, ishape.ndim());
+   int32_t x2 = CheckAxis(axis2, ishape.ndim());
+
+   CHECK_NE(x1, x2) << "axis1 and axis2 cannot refer to the same axis " << x1;
+
+   auto h = ishape[x1];
+   auto w = ishape[x2];

if (k > 0) {
w -= k;
@@ -69,7 +87,24 @@ inline TShape DiagShapeImpl(const TShape& ishape, const nnvm::dim_t k) {
s = 0;
}

-   return TShape({s});
+   if (x1 > x2) {
+     std::swap(x1, x2);
+   }
+
+   int32_t n_dim = static_cast<int32_t>(ishape.ndim()) - 1;
+   TShape oshape(n_dim);
+
+   // remove axis1 and axis2 and append the new axis to the end
+   uint32_t idx = 0;
+   for (int32_t i = 0; i <= n_dim; ++i) {
+     if (i != x1 && i != x2) {
+       oshape[idx++] = ishape[i];
+     }
+   }
+
+   oshape[n_dim - 1] = s;
+
+   return oshape;
}
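
To make the new shape rule concrete, the following is a minimal standalone sketch of the same computation, ours rather than the PR's: `diag_shape` is a hypothetical helper, `std::vector<int>` stands in for `TShape`, simple negative-axis handling replaces `CheckAxis`, and the `k`-clamping follows the elided middle of the hunk, which mirrors the original 2-D implementation.

```cpp
#include <algorithm>
#include <cstdlib>
#include <vector>

// Standalone sketch (not MXNet code) of the shape rule in DiagShapeImpl.
std::vector<int> diag_shape(const std::vector<int>& ishape, int k,
                            int axis1, int axis2) {
  const int ndim = static_cast<int>(ishape.size());
  if (ndim == 1) {                      // 1-D input: build an (s, s) matrix
    const int s = ishape[0] + std::abs(k);
    return {s, s};
  }
  const int x1 = axis1 < 0 ? axis1 + ndim : axis1;   // simplified CheckAxis
  const int x2 = axis2 < 0 ? axis2 + ndim : axis2;
  int h = ishape[x1];
  int w = ishape[x2];
  if (k > 0) w -= k;                    // super-diagonals shorten the width
  if (k < 0) h += k;                    // sub-diagonals shorten the height
  const int s = std::max(std::min(h, w), 0);   // diagonal length, clamped at 0
  std::vector<int> oshape;              // drop axis1/axis2, append s last
  for (int i = 0; i < ndim; ++i)
    if (i != x1 && i != x2) oshape.push_back(ishape[i]);
  oshape.push_back(s);
  return oshape;
}
// diag_shape({2, 3, 4, 5}, 0, 0, 2) yields {3, 5, 2}, matching the docstring.
```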

inline bool DiagOpShape(const nnvm::NodeAttrs& attrs,
@@ -79,12 +114,16 @@ inline bool DiagOpShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_attrs->size(), 1U);

const TShape& ishape = (*in_attrs)[0];
-   if (ishape.ndim() == 0) return false;
-   if (ishape.ndim() > 2) LOG(FATAL) << "Input must be 1- or 2-d.";
+   if (ishape.ndim() == 0) {
+     return false;
+   }

const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);

-   TShape oshape = DiagShapeImpl(ishape, param.k.value());
+   TShape oshape = DiagShapeImpl(ishape,
+                                 param.k,
+                                 param.axis1,
+                                 param.axis2);
if (shape_is_none(oshape)) {
LOG(FATAL) << "Diagonal does not exist.";
}
@@ -104,42 +143,144 @@ inline bool DiagOpType(const nnvm::NodeAttrs& attrs,
return (*out_attrs)[0] != -1;
}

- template<int req>
+ template<int ndim, int req, bool back>
struct diag {
template<typename DType>
-   MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
-                                   mshadow::Shape<2> ishape, int k) {
+   MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* a,
+                                   mshadow::Shape<ndim> oshape,
+                                   mshadow::Shape<ndim> ishape,
+                                   index_t stride, index_t offset,
+                                   index_t base) {
using namespace mxnet_op;
-     int j = 0;
-     if (k > 0) {
-       j = ravel(mshadow::Shape2(i, i + k), ishape);
-     } else if (k < 0) {
-       j = ravel(mshadow::Shape2(i - k, i), ishape);
+     index_t idx = i / base;
+     index_t j = ravel(unravel(idx, oshape), ishape) + offset + stride * (i - idx * base);
+     if (back) {
+       KERNEL_ASSIGN(out[j], req, a[i]);
      } else {
-       j = ravel(mshadow::Shape2(i, i), ishape);
+       KERNEL_ASSIGN(out[i], req, a[j]);
      }
-
-     KERNEL_ASSIGN(out[i], req, a[j]);
}
};
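
The whole gather is packed into the single flat-index expression for `j` above. Below is a scalar sketch of that arithmetic, with a hypothetical helper `diag_src_index` in place of mshadow's `ravel`/`unravel`; it is illustrative only, not the PR's code.

```cpp
#include <cstdint>

// Scalar sketch of the kernel's index mapping over collapsed
// (leading, body, trailing) shapes.
int64_t diag_src_index(int64_t i,                // flat index over the output
                       const int64_t oshape[3],  // (oleading, obody, otrailing)
                       const int64_t ishape[3],  // (ileading, ibody, itrailing)
                       int64_t stride,           // one step along the diagonal
                       int64_t offset,           // shift contributed by k
                       int64_t base) {           // diagonal length (last axis)
  const int64_t idx  = i / base;                 // which sub-array cell
  const int64_t step = i - idx * base;           // position along the diagonal
  // unravel idx in the output shape ...
  const int64_t t = idx % oshape[2];
  const int64_t b = (idx / oshape[2]) % oshape[1];
  const int64_t l = idx / (oshape[2] * oshape[1]);
  // ... and re-ravel the same coordinates in the larger input shape, landing
  // at the start of the selected sub-array's diagonal:
  const int64_t start = (l * ishape[1] + b) * ishape[2] + t;
  return start + offset + stride * step;
}
```

With `back = false` the kernel reads `a[j]` and writes `out[i]`; with `back = true` the same mapping runs in reverse to scatter gradients into the input-shaped tensor.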

- template<int req>
+ template<int req, bool back>
struct diag_gen {
template<typename DType>
-   MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
+   MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* a,
mshadow::Shape<2> oshape, int k) {
using namespace mxnet_op;

auto j = unravel(i, oshape);
if (j[1] == (j[0] + k)) {
auto l = j[0] < j[1] ? j[0] : j[1];
-       KERNEL_ASSIGN(out[i], req, a[l]);
-     } else {
+       if (back) {
+         KERNEL_ASSIGN(out[l], req, a[i]);
+       } else {
+         KERNEL_ASSIGN(out[i], req, a[l]);
+       }
+     } else if (!back) {
KERNEL_ASSIGN(out[i], req, static_cast<DType>(0));
}
}
};
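
For 1-D inputs, `diag_gen` materialises the full matrix. The forward direction (`back = false`) amounts to the loop below, a sketch with a fixed `double` element type and an unconditional write; the real kernel is type-switched and honours the output request.

```cpp
#include <algorithm>
#include <cstdint>

// Sketch of diag_gen's forward behaviour: build an (s, s) matrix whose k-th
// diagonal holds the 1-D input a, with every other entry set to zero.
void diag_gen_forward(double* out, const double* a, int64_t s, int64_t k) {
  for (int64_t i = 0; i < s * s; ++i) {
    const int64_t r = i / s;              // unravel i in the (s, s) output
    const int64_t c = i % s;
    if (c == r + k)
      out[i] = a[std::min(r, c)];         // min(r, c) indexes the input vector
    else
      out[i] = 0.0;                       // off-diagonal entries are zeroed
  }
}
```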

template<typename xpu, bool back>
void DiagOpProcess(const TBlob& in_data,
const TBlob& out_data,
const TShape& ishape,
const TShape& oshape,
index_t dsize,
const DiagParam& param,
mshadow::Stream<xpu> *s,
const std::vector<OpReqType>& req) {
using namespace mxnet_op;
using namespace mshadow;
if (ishape.ndim() > 1) {
// input : (leading + i, body + i, trailing)
uint32_t x1 = CheckAxis(param.axis1, ishape.ndim());
uint32_t x2 = CheckAxis(param.axis2, ishape.ndim());

uint32_t idim = ishape.ndim(), odim = oshape.ndim();

uint32_t minx = x1, maxx = x2;
if (minx > maxx) {
std::swap(minx, maxx);
}

// merges contiguous axes that are not separated
// by axis1 or axis2 since they can be directly
// mapped to the output and there is no need
// to distinguish them
// (After this the input will have no more than
// three axes, hence improving the ravel and
// unravel efficiency)

index_t oleading = 1,
obody = 1,
otrailing = 1;

for (uint32_t i = 0; i < minx; ++i) {
oleading *= ishape[i];
}
for (uint32_t i = minx + 1; i < maxx; ++i) {
obody *= ishape[i];
}
for (uint32_t i = maxx + 1; i < idim; ++i) {
otrailing *= ishape[i];
}

index_t ileading = oleading,
ibody = obody * ishape[minx],
itrailing = otrailing * ishape[maxx];

index_t stride1 = itrailing * obody,
stride2 = otrailing;
// stride1 + stride2 is the stride for
// iterating over the diagonal in question

if (x1 == maxx) {
std::swap(stride1, stride2);
}

// the extra index offset introduced by k
index_t offset;
int k = param.k;
if (k > 0) {
offset = stride2 * k;
} else if (k < 0) {
offset = stride1 * -k;
} else {
offset = 0;
}

MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
if (back && req[0] != kAddTo && req[0] != kNullOp) {
out_data.FlatTo1D<xpu, DType>(s) = 0;
}
if (ileading == 1) {
Kernel<diag<2, req_type, back>, xpu>::Launch(s, dsize, out_data.dptr<DType>(),
in_data.dptr<DType>(), Shape2(obody, otrailing),
Shape2(ibody, itrailing),
stride1 + stride2, offset, oshape[odim - 1]);
} else {
Kernel<diag<3, req_type, back>, xpu>::Launch(s, dsize, out_data.dptr<DType>(),
in_data.dptr<DType>(), Shape3(oleading, obody, otrailing),
Shape3(ileading, ibody, itrailing),
stride1 + stride2, offset, oshape[odim - 1]);
}
});
});
} else {
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Kernel<diag_gen<req_type, back>, xpu>::Launch(s, dsize, out_data.dptr<DType>(),
in_data.dptr<DType>(), Shape2(oshape[0], oshape[1]),
param.k);
});
});
}
}
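
A worked example of the axis merging and stride computation above, with values we picked for illustration (they do not come from the PR):

```cpp
// ishape = (2, 3, 4, 5), axis1 = 1, axis2 = 3, k = 0, so minx = 1, maxx = 3.
//
//   oleading  = 2             // product of axes before minx
//   obody     = 4             // product of axes strictly between minx and maxx
//   otrailing = 1             // product of axes after maxx
//   ileading  = 2
//   ibody     = 4 * 3 = 12    // obody * ishape[minx]
//   itrailing = 1 * 5 = 5     // otrailing * ishape[maxx]
//
//   stride1 = itrailing * obody = 20   // flat step along axis1
//   stride2 = otrailing         = 1    // flat step along axis2
//
// Each diagonal element sits stride1 + stride2 = 21 flat positions after the
// previous one, and the output shape is (2, 4, 3) since min(3, 5) = 3.
```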

template<typename xpu>
void DiagOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -159,21 +300,7 @@ void DiagOpForward(const nnvm::NodeAttrs& attrs,
const TShape& oshape = outputs[0].shape_;
const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);

-   if (ishape.ndim() == 2) {
-     MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
-       MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-         Kernel<diag<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
-                 in_data.dptr<DType>(), Shape2(ishape[0], ishape[1]), param.k.value());
-       });
-     });
-   } else {
-     MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
-       MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-         Kernel<diag_gen<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
-                 in_data.dptr<DType>(), Shape2(oshape[0], oshape[1]), param.k.value());
-       });
-     });
-   }
+   DiagOpProcess<xpu, false>(in_data, out_data, ishape, oshape, out_data.Size(), param, s, req);
}

template<typename xpu>
@@ -194,23 +321,10 @@ void DiagOpBackward(const nnvm::NodeAttrs& attrs,
const TShape& oshape = outputs[0].shape_;
const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);

-   if (oshape.ndim() == 2) {
-     MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
-       MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-         Kernel<diag_gen<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
-                 in_data.dptr<DType>(), Shape2(oshape[0], oshape[1]), param.k.value());
-       });
-     });
-   } else {
-     MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
-       MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-         Kernel<diag<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
-                 in_data.dptr<DType>(), Shape2(ishape[0], ishape[1]), param.k.value());
-       });
-     });
-   }
+   DiagOpProcess<xpu, true>(in_data, out_data, oshape, ishape, in_data.Size(), param, s, req);
}
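
The backward pass reuses the same kernels with `back = true` and the shape and size arguments swapped, so the incoming gradient is scattered from the diagonal back into a zero-initialised, input-shaped tensor. Below is a 2-D scalar sketch of that scatter, ours and not the PR's, assuming a plain write request rather than `kAddTo`.

```cpp
#include <cstdint>
#include <cstring>

// Sketch (not MXNet code) of the backward pass for 2-D diagonal extraction:
// ograd holds s gradient values along the k-th diagonal; igrad is the (h, w)
// input gradient, zero-filled first just like the FlatTo1D assignment above.
void diag_backward_2d(double* igrad, const double* ograd,
                      int64_t h, int64_t w, int64_t s, int64_t k) {
  std::memset(igrad, 0, sizeof(double) * h * w);
  for (int64_t m = 0; m < s; ++m) {
    const int64_t r = k >= 0 ? m : m - k;   // row of the m-th diagonal entry
    const int64_t c = k >= 0 ? m + k : m;   // column of the m-th diagonal entry
    igrad[r * w + c] = ograd[m];            // scatter the gradient back
  }
}
```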


} // namespace op
} // namespace mxnet

27 changes: 23 additions & 4 deletions src/operator/tensor/diag_op.cc
@@ -21,7 +21,7 @@
* Copyright (c) 2015 by Contributors
* \file diag_op.cc
* \brief
- * \author Istvan Fehervari
+ * \author Istvan Fehervari, Zhijingcheng Yu
*/

#include "./diag_op-inl.h"
@@ -36,9 +36,13 @@ NNVM_REGISTER_OP(diag)

``diag``'s behavior depends on the input array dimensions:

- - 1-D arrays: constructs a 2-D array with the input as its diagonal, all other elements are zero
- - 2-D arrays: returns elements in the diagonal as a new 1-D array
- - N-D arrays: not supported yet
+ - 1-D arrays: constructs a 2-D array with the input as its diagonal, all other elements are zero.
+ - N-D arrays: extracts the diagonals of the sub-arrays with axes specified by ``axis1`` and ``axis2``.
+   The output shape would be decided by removing the axes numbered ``axis1`` and ``axis2`` from the
+   input shape and appending to the result a new axis with the size of the diagonals in question.
+
+ For example, when the input shape is `(2, 3, 4, 5)`, ``axis1`` and ``axis2`` are 0 and 2
+ respectively and ``k`` is 0, the resulting shape would be `(3, 5, 2)`.

Examples::

@@ -65,6 +69,21 @@ Examples::
[1, 0, 0],
[0, 2, 0]]

x = [[[1, 2],
[3, 4]],

[[5, 6],
[7, 8]]]

diag(x) = [[1, 7],
[2, 8]]

diag(x, k=1) = [[3],
[4]]

diag(x, axis1=-2, axis2=-1) = [[1, 4],
[5, 8]]

)code" ADD_FILELINE)
.set_attr_parser(ParamParser<DiagParam>)
.set_num_inputs(1)