
[MXNET-882] Support for N-d arrays added to diag op. #12430

Merged · 15 commits · Sep 18, 2018
Changes from 4 commits
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
@@ -178,3 +178,4 @@ List of Contributors
* [Aaron Markham](https://github.com/aaronmarkham)
* [Sam Skalicky](https://github.com/samskalicky)
* [Per Goncalves da Silva](https://github.com/perdasilva)
* [Zhijingcheng Yu](https://github.com/jasonyu1996)
189 changes: 139 additions & 50 deletions src/operator/tensor/diag_op-inl.h
@@ -21,7 +21,7 @@
* Copyright (c) 2015 by Contributors
* \file diag_op-inl.h
* \brief CPU Implementation of the diag op
* \author Istvan Fehervari
* \author Istvan Fehervari, Zhijingcheng Yu
*/

#ifndef MXNET_OPERATOR_TENSOR_DIAG_OP_INL_H_
@@ -30,33 +30,51 @@
#include <dmlc/parameter.h>
#include <vector>
#include <algorithm>
#include <utility>
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"
#include "./broadcast_reduce_op.h"

namespace mxnet {
namespace op {

struct DiagParam : public dmlc::Parameter<DiagParam> {
dmlc::optional<int> k;
dmlc::optional<int> axis1;
dmlc::optional<int> axis2;
DMLC_DECLARE_PARAMETER(DiagParam) {
DMLC_DECLARE_FIELD(k)
.set_default(dmlc::optional<int>(0))
.describe("Diagonal in question. The default is 0. "
"Use k>0 for diagonals above the main diagonal, "
"and k<0 for diagonals below the main diagonal. "
"If input has shape (S0 S1) k must be between -S0 and S1");
DMLC_DECLARE_FIELD(axis1)
.set_default(dmlc::optional<int>(0))
.describe("The first axis of the sub-arrays of interest. "
"Ignored when the input is a 1-D array.");
DMLC_DECLARE_FIELD(axis2)
.set_default(dmlc::optional<int>(1))
.describe("The second axis of the sub-arrays of interest. "
"Ignored when the input is a 1-D array.");
}
};
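
A rough usage sketch of the extended front-end call, mirroring the unit tests added below; the behaviour of ``k``, ``axis1`` and ``axis2`` is assumed to match ``offset``/``axis1``/``axis2`` of ``numpy.diagonal``:

    import mxnet as mx
    import numpy as np

    a = mx.nd.arange(24).reshape((2, 3, 4))
    # diagonals of the (axis 0, axis 2) sub-arrays, one step above the main diagonal
    d = mx.nd.diag(data=a, k=1, axis1=0, axis2=2)
    print(d.shape)   # (3, 2): axes 0 and 2 removed, diagonal length appended last
    np.testing.assert_allclose(d.asnumpy(),
                               np.diagonal(a.asnumpy(), offset=1, axis1=0, axis2=2))
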

inline TShape DiagShapeImpl(const TShape& ishape, const nnvm::dim_t k) {
inline TShape DiagShapeImpl(const TShape& ishape, const nnvm::dim_t k,
const nnvm::dim_t axis1, const nnvm::dim_t axis2) {
if (ishape.ndim() == 1) {
auto s = ishape[0] + std::abs(k);
return TShape({s, s});
}

auto h = ishape[0];
auto w = ishape[1];
int x1 = CheckAxis(axis1, ishape.ndim());
int x2 = CheckAxis(axis2, ishape.ndim());

CHECK_NE(x1, x2) << "axis1 and axis2 cannot refer to the same axis " << x1;

auto h = ishape[x1];
auto w = ishape[x2];

if (k > 0) {
w -= k;
@@ -69,7 +87,21 @@ inline TShape DiagShapeImpl(const TShape& ishape, const nnvm::dim_t k) {
s = 0;
}

return TShape({s});
if (x1 > x2) {
std::swap(x1, x2);
}

int n_dim = static_cast<int>(ishape.ndim()) - 1;
TShape oshape(n_dim);

// remove axis1 and axis2 and append the new axis to the end
int idx = 0;
for (int i = 0; i <= n_dim; i ++)
if (i != x1 && i != x2)
oshape[idx ++] = ishape[i];
oshape[n_dim - 1] = s;

return oshape;
}
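
A minimal Python sketch of the same shape rule (the helper name is hypothetical; the axes are assumed already validated and distinct):

    def diag_output_shape(ishape, k=0, axis1=0, axis2=1):
        if len(ishape) == 1:                      # 1-D input -> square output
            s = ishape[0] + abs(k)
            return (s, s)
        x1, x2 = axis1 % len(ishape), axis2 % len(ishape)
        h, w = ishape[x1], ishape[x2]
        if k > 0:
            w -= k
        elif k < 0:
            h += k
        s = max(0, min(h, w))                     # diagonal length, clamped at 0
        rest = tuple(d for i, d in enumerate(ishape) if i not in (x1, x2))
        return rest + (s,)                        # remaining axes first, diagonal last

    print(diag_output_shape((2, 3, 4, 5), k=0, axis1=0, axis2=2))   # (3, 5, 2)
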

inline bool DiagOpShape(const nnvm::NodeAttrs& attrs,
@@ -80,11 +112,13 @@ inline bool DiagOpShape(const nnvm::NodeAttrs& attrs,

const TShape& ishape = (*in_attrs)[0];
if (ishape.ndim() == 0) return false;
if (ishape.ndim() > 2) LOG(FATAL) << "Input must be 1- or 2-d.";

const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);

TShape oshape = DiagShapeImpl(ishape, param.k.value());
TShape oshape = DiagShapeImpl(ishape,
param.k.value(),
param.axis1.value(),
param.axis2.value());
if (shape_is_none(oshape)) {
LOG(FATAL) << "Diagonal does not exist.";
}
@@ -104,26 +138,26 @@ inline bool DiagOpType(const nnvm::NodeAttrs& attrs,
return (*out_attrs)[0] != -1;
}

template<int req>
template<int ndim, int req, bool back>
struct diag {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
mshadow::Shape<2> ishape, int k) {
mshadow::Shape<ndim> oshape,
mshadow::Shape<ndim> ishape,
int stride, int offset,
int base) {
using namespace mxnet_op;
int j = 0;
if (k > 0) {
j = ravel(mshadow::Shape2(i, i + k), ishape);
} else if (k < 0) {
j = ravel(mshadow::Shape2(i - k, i), ishape);
int idx = i / base;
int j = ravel(unravel(idx, oshape), ishape) + offset + stride * (i - idx * base);
if (back) {
KERNEL_ASSIGN(out[j], req, a[i]);
} else {
j = ravel(mshadow::Shape2(i, i), ishape);
KERNEL_ASSIGN(out[i], req, a[j]);
}

KERNEL_ASSIGN(out[i], req, a[j]);
}
};
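
The per-element index mapping above can be restated in numpy terms as a hypothetical sketch (``oshape``/``ishape`` here are the compacted leading/body/trailing shapes the kernel is launched with, ``base`` is the diagonal length, and ``stride``/``offset`` are precomputed by the caller):

    import numpy as np

    def input_index(i, oshape, ishape, stride, offset, base):
        idx = i // base                    # which diagonal this output element belongs to
        t = i - idx * base                 # position along that diagonal
        j = np.ravel_multi_index(np.unravel_index(idx, oshape), ishape)
        return j + offset + stride * t     # flat index into the (compacted) input
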

template<int req>
template<int req, bool back>
struct diag_gen {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
@@ -133,13 +167,94 @@ struct diag_gen {
auto j = unravel(i, oshape);
if (j[1] == (j[0] + k)) {
auto l = j[0] < j[1] ? j[0] : j[1];
KERNEL_ASSIGN(out[i], req, a[l]);
} else {
if (back) {
KERNEL_ASSIGN(out[l], req, a[i]);
} else {
KERNEL_ASSIGN(out[i], req, a[l]);
}
} else if (!back) {
KERNEL_ASSIGN(out[i], req, static_cast<DType>(0));
}
}
};
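
For the 1-D case this kernel reduces to the familiar ``numpy.diag`` behaviour; a quick sketch of the forward/backward pair it implements:

    import numpy as np

    v = np.array([1., 2., 3.])
    out = np.diag(v, k=-1)                 # forward: (4, 4) matrix with v on the sub-diagonal
    grad_v = np.diagonal(out, offset=-1)   # backward of that op: read the diagonal back out
    np.testing.assert_array_equal(grad_v, v)
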

template<typename xpu, bool back>
void DiagOpProcess(const TBlob& in_data,
const TBlob& out_data,
const TShape& ishape,
const TShape& oshape,
int dsize,
const DiagParam& param,
mxnet_op::Stream<xpu> *s,
const std::vector<OpReqType>& req) {
using namespace mxnet_op;
using namespace mshadow;
if (ishape.ndim() > 1) {
// input : (leading + i, body + i, trailing)
int x1 = CheckAxis(param.axis1.value(), ishape.ndim());
int x2 = CheckAxis(param.axis2.value(), ishape.ndim());

int idim = ishape.ndim(), odim = oshape.ndim();

int minx = x1, maxx = x2;
if (minx > maxx)
std::swap(minx, maxx);

int oleading = 1, obody = 1, otrailing = 1;
for (int i = 0; i < minx; i ++)
oleading *= ishape[i];
for (int i = minx + 1; i < maxx; i ++)
obody *= ishape[i];
for (int i = maxx + 1; i < idim; i ++)
otrailing *= ishape[i];


int ileading = oleading,
ibody = obody * ishape[minx],
itrailing = otrailing * ishape[maxx];

int stride1 = itrailing * obody,
stride2 = otrailing;

if (x1 == maxx) {
std::swap(stride1, stride2);
}
int offset, k = param.k.value();
if (k > 0) {
offset = stride2 * k;
} else if (k < 0) {
offset = stride1 * -k;
} else {
offset = 0;
}

MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
if (back && req[0] != kAddTo && req[0] != kNullOp) out_data.FlatTo1D<xpu, DType>(s) = 0;
if (ileading == 1) {
Kernel<diag<2, req_type, back>, xpu>::Launch(s, dsize, out_data.dptr<DType>(),
in_data.dptr<DType>(), Shape2(obody, otrailing),
Shape2(ibody, itrailing),
stride1 + stride2, offset, oshape[odim - 1]);
} else {
Kernel<diag<3, req_type, back>, xpu>::Launch(s, dsize, out_data.dptr<DType>(),
in_data.dptr<DType>(), Shape3(oleading, obody, otrailing),
Shape3(ileading, ibody, itrailing),
stride1 + stride2, offset, oshape[odim - 1]);
}
});
});
} else {
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Kernel<diag_gen<req_type, back>, xpu>::Launch(s, dsize, out_data.dptr<DType>(),
in_data.dptr<DType>(), Shape2(oshape[0], oshape[1]),
param.k.value());
});
});
}
}
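
A worked check of the stride/offset derivation, assuming ``ishape = (2, 3, 4, 5)``, ``axis1 = 0``, ``axis2 = 2`` and ``k = 0`` (so ``oleading`` is 1, the kernel takes the Shape2 path, and ``stride1``/``stride2`` coincide with the C-order strides of axes 0 and 2):

    import numpy as np

    a = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
    stride1 = a.strides[0] // a.itemsize      # 60 == itrailing * obody
    stride2 = a.strides[2] // a.itemsize      # 5  == otrailing
    flat = a.ravel()
    # the first diagonal (all remaining indices 0) holds a[0,0,0,0] and a[1,0,1,0]
    first_diag = flat[[0, stride1 + stride2]]
    np.testing.assert_array_equal(first_diag, np.diagonal(a, 0, 0, 2)[0, 0])
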

template<typename xpu>
void DiagOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -159,21 +274,7 @@ void DiagOpForward(const nnvm::NodeAttrs& attrs,
const TShape& oshape = outputs[0].shape_;
const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);

if (ishape.ndim() == 2) {
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Kernel<diag<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
in_data.dptr<DType>(), Shape2(ishape[0], ishape[1]), param.k.value());
});
});
} else {
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Kernel<diag_gen<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
in_data.dptr<DType>(), Shape2(oshape[0], oshape[1]), param.k.value());
});
});
}
DiagOpProcess<xpu, false>(in_data, out_data, ishape, oshape, out_data.Size(), param, s, req);
}

template<typename xpu>
@@ -194,23 +295,11 @@ void DiagOpBackward(const nnvm::NodeAttrs& attrs,
const TShape& oshape = outputs[0].shape_;
const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);

if (oshape.ndim() == 2) {
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Kernel<diag_gen<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
in_data.dptr<DType>(), Shape2(oshape[0], oshape[1]), param.k.value());
});
});
} else {
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Kernel<diag<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
in_data.dptr<DType>(), Shape2(ishape[0], ishape[1]), param.k.value());
});
});
}

DiagOpProcess<xpu, true>(in_data, out_data, oshape, ishape, in_data.Size(), param, s, req);
}


} // namespace op
} // namespace mxnet

12 changes: 8 additions & 4 deletions src/operator/tensor/diag_op.cc
@@ -21,7 +21,7 @@
* Copyright (c) 2015 by Contributors
* \file diag_op.cc
* \brief
* \author Istvan Fehervari
* \author Istvan Fehervari, Zhijingcheng Yu
*/

#include "./diag_op-inl.h"
@@ -36,9 +36,13 @@ NNVM_REGISTER_OP(diag)

``diag``'s behavior depends on the input array dimensions:

- 1-D arrays: constructs a 2-D array with the input as its diagonal, all other elements are zero
- 2-D arrays: returns elements in the diagonal as a new 1-D array
- N-D arrays: not supported yet
- 1-D arrays: constructs a 2-D array with the input as its diagonal, all other elements are zero.
- N-D arrays: extracts the diagonals of the sub-arrays with axes specified by ``axis1`` and ``axis2``.
The output shape is obtained by removing the axes numbered ``axis1`` and ``axis2`` from the
input shape and appending a new axis at the end whose length is the size of the diagonals in question.

For example, when the input shape is `(2, 3, 4, 5)`, ``axis1`` and ``axis2`` are 0 and 2
respectively, and ``k`` is 0, the resulting shape is `(3, 5, 2)`.
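
A quick numpy check of that example (shape only, assuming the documented equivalence with ``numpy.diagonal``):

    import numpy as np
    print(np.diagonal(np.zeros((2, 3, 4, 5)), offset=0, axis1=0, axis2=2).shape)   # (3, 5, 2)
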

Examples::

48 changes: 46 additions & 2 deletions tests/python/unittest/test_operator.py
@@ -4493,7 +4493,7 @@ def test_invalid_shape():
x = mx.sym.Variable('x')
y = mx.sym.Variable('y')
where_sym = mx.sym.where(condition, x, y)

assert_exception(lambda: where_sym.eval(x=mx.nd.array([[2,3],[4,5],[6,7]]),
y=mx.nd.array([[8,9],[10,11],[12,13]]),
condition=mx.nd.array([1,0])), MXNetError)
@@ -5077,7 +5077,7 @@ def _validate_sample_location(input_rois, input_offset, spatial_scale, pooled_w,
trans_x = input_offset[roi_idx, class_id * 2, part_h, part_w] * trans_std
trans_y = input_offset[roi_idx, class_id * 2 + 1, part_h, part_w] * trans_std
bin_h_start, bin_w_start = ph * bin_size_h + roi_start_h, pw * bin_size_w + roi_start_w

need_check = True
while need_check:
pass_check = True
@@ -6901,6 +6901,50 @@ def test_diag():
diag_sym = mx.sym.diag(data=data, k=-1)
check_numeric_gradient(diag_sym, [a_np])

# Test 4d input
x1 = np.random.randint(3,9)
x2 = np.random.randint(3,9)
x3 = np.random.randint(3,9)
x4 = np.random.randint(3,9)
a_np = np.random.random((x1, x2, x3, x4)).astype(np.float32)
a = mx.nd.array(a_np).astype('float32')

# k = 0, axis1=0, axis2=1
r = mx.nd.diag(data=a, k=0, axis1=0, axis2=1)
assert_almost_equal(r.asnumpy(), np.diagonal(a_np, offset=0, axis1=0, axis2=1))

# k = 1, axis1=1, axis2=0
r = mx.nd.diag(data=a, k=1, axis1=1, axis2=0)
assert_almost_equal(r.asnumpy(), np.diagonal(a_np, offset=1, axis1=1, axis2=0))

# k = -1, axis1=1, axis2=3
r = mx.nd.diag(data=a, k=-1, axis1=1, axis2=3)
assert_almost_equal(r.asnumpy(), np.diagonal(a_np, offset=-1, axis1=1, axis2=3))

# k = 2, axis1=-2, axis2=0
r = mx.nd.diag(data=a, k=2, axis1=-2, axis2=0)
assert_almost_equal(r.asnumpy(), np.diagonal(a_np, offset=2, axis1=-2, axis2=0))

# Test 4d backward, k=0, axis1=3, axis2=0
data = mx.sym.Variable('data')
diag_sym = mx.sym.diag(data=data, k=0, axis1=3, axis2=0)
check_numeric_gradient(diag_sym, [a_np])

# Test 4d backward, k=1, axis1=1, axis2=2
data = mx.sym.Variable('data')
diag_sym = mx.sym.diag(data=data, k=1, axis1=1, axis2=2)
check_numeric_gradient(diag_sym, [a_np])

# Test 4d backward, k=-1, axis1=2, axis2=0
data = mx.sym.Variable('data')
diag_sym = mx.sym.diag(data=data, k=-1, axis1=2, axis2=0)
check_numeric_gradient(diag_sym, [a_np])

# Test 4d backward, k=-2, axis1=1, axis2=-1
data = mx.sym.Variable('data')
diag_sym = mx.sym.diag(data=data, k=-2, axis1=1, axis2=-1)
check_numeric_gradient(diag_sym, [a_np])
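
# A hypothetical extra end-to-end check (not part of this PR): with an all-ones
# head gradient, the input gradient should be 1 exactly on the gathered diagonal
# positions and 0 elsewhere, so its sum equals the output size.
x = mx.nd.array(a_np)
x.attach_grad()
with mx.autograd.record():
    y = mx.nd.diag(data=x, k=1, axis1=1, axis2=2)
y.backward(mx.nd.ones_like(y))
assert x.grad.asnumpy().sum() == y.size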

@with_seed()
def test_depthtospace():
def f(x, blocksize):