Skip to content

Commit

Permalink
add numpy compatible trace (apache#16008)
Browse files Browse the repository at this point in the history
add doc for trace
  • Loading branch information
hzfan authored and larroy committed Sep 28, 2019
1 parent aa3be50 commit 72230ef
Show file tree
Hide file tree
Showing 5 changed files with 527 additions and 10 deletions.
66 changes: 56 additions & 10 deletions python/mxnet/_numpy_op_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,22 +530,22 @@ def _np_roll(a, shift, axis=None):
Parameters
----------
a : ndarray
Input array.
Input array.
shift : int or tuple of ints
The number of places by which elements are shifted. If a tuple,
then `axis` must be a tuple of the same size, and each of the
given axes is shifted by the corresponding number. If an int
while `axis` is a tuple of ints, then the same value is used for
all given axes.
The number of places by which elements are shifted. If a tuple,
then `axis` must be a tuple of the same size, and each of the
given axes is shifted by the corresponding number. If an int
while `axis` is a tuple of ints, then the same value is used for
all given axes.
axis : int or tuple of ints, optional
Axis or axes along which elements are shifted. By default, the
array is flattened before shifting, after which the original
shape is restored.
Axis or axes along which elements are shifted. By default, the
array is flattened before shifting, after which the original
shape is restored.
Returns
-------
res : ndarray
Output array, with the same shape as `a`.
Output array, with the same shape as `a`.
Notes
-----
Expand Down Expand Up @@ -581,5 +581,51 @@ def _np_roll(a, shift, axis=None):
>>> np.roll(x2, -1, axis=1)
array([[1., 2., 3., 4., 0.],
[6., 7., 8., 9., 5.]])
"""


def _np_trace(a, offset=0, axis1=0, axis2=1, out=None):
"""trace(a, offset=0, axis1=0, axis2=1, out=None)
Return the sum along diagonals of the array.
If `a` is 2-D, the sum along its diagonal with the given offset
is returned, i.e., the sum of elements ``a[i,i+offset]`` for all i.
If `a` has more than two dimensions, then the axes specified by axis1 and
axis2 are used to determine the 2-D sub-arrays whose traces are returned.
The shape of the resulting array is the same as that of `a` with `axis1`
and `axis2` removed.
Parameters
----------
a : ndarray
Input array, from which the diagonals are taken.
offset : int, optional
Offset of the diagonal from the main diagonal. Can be both positive
and negative. Defaults to 0.
axis1, axis2 : int, optional
Axes to be used as the first and second axis of the 2-D sub-arrays
from which the diagonals should be taken. Defaults are the first two
axes of `a`.
out : ndarray, optional
Array into which the output is placed. It must be of the right shape
and right type to hold the output.
Returns
-------
sum_along_diagonals : ndarray
If `a` is 2-D, the sum along the diagonal is returned. If `a` has
larger dimensions, then an array of sums along diagonals is returned.
Examples
--------
>>> a = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
>>> np.trace(a)
array(3.)
>>> a = np.arange(8).reshape((2, 2, 2))
>>> np.trace(a)
array([6., 8.])
>>> a = np.arange(24).reshape((2, 2, 2, 3))
>>> np.trace(a).shape
(2, 3)
"""
pass
255 changes: 255 additions & 0 deletions src/operator/numpy/np_trace_op-inl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_trace_op-inl.h
* \brief Function definition of matrix numpy-compatible trace operator
*/

#ifndef MXNET_OPERATOR_NUMPY_NP_TRACE_OP_INL_H_
#define MXNET_OPERATOR_NUMPY_NP_TRACE_OP_INL_H_

#include <dmlc/parameter.h>
#include <mxnet/operator_util.h>
#include <vector>
#include <utility>
#include <algorithm>
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"
#include "../tensor/broadcast_reduce_op.h"

namespace mxnet {
namespace op {

struct NumpyTraceParam: public dmlc::Parameter<NumpyTraceParam> {
int offset, axis1, axis2;
DMLC_DECLARE_PARAMETER(NumpyTraceParam) {
DMLC_DECLARE_FIELD(offset)
.set_default(0)
.describe("Offset of the diagonal from the main diagonal. "
"Can be both positive and negative. Defaults to 0.");
DMLC_DECLARE_FIELD(axis1)
.set_default(0)
.describe("Axes to be used as the first axis of the 2-D sub-arrays "
"from which the diagonals should be taken. Defaults to 0.");
DMLC_DECLARE_FIELD(axis2)
.set_default(1)
.describe("Axes to be used as the second axis of the 2-D sub-arrays "
"from which the diagonals should be taken. Defaults to 1.");
}
};

template<int ndim, int req, bool back>
struct numpy_trace {
template<typename DType>
MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* a,
mshadow::Shape<ndim> oshape,
mshadow::Shape<ndim> ishape,
index_t stride, index_t offset, int dlength) {
using namespace mxnet_op;
using namespace mshadow;
index_t j = ravel(unravel(i, oshape), ishape) + offset;
if (back) {
for (index_t k = 0; k < dlength; ++k) {
KERNEL_ASSIGN(out[j], req, a[i]);
j += stride;
}
} else {
if (req == kWriteTo) {
out[i] = 0;
for (index_t k = 0; k < dlength; ++k) {
out[i] += a[j];
j += stride;
}
} else if (req == kAddTo) {
for (index_t k = 0; k < dlength; ++k) {
out[i] += a[j];
j += stride;
}
}
}
}
};

template<typename xpu, bool back>
void NumpyTraceOpProcess(const TBlob& in_data,
const TBlob& out_data,
const mxnet::TShape& ishape,
const mxnet::TShape& oshape,
index_t dsize,
const NumpyTraceParam& param,
mxnet_op::Stream<xpu> *s,
const std::vector<OpReqType>& req) {
using namespace mxnet_op;
using namespace mshadow;
if (dsize == 0) {
if (back) {
if (out_data.Size() != 0) {
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
if (req_type == kWriteTo) {
out_data.FlatTo1D<xpu, DType>(s) = 0;
}
});
});
}
}
return;
} else if (ishape.Size() == 0) {
if (!back) {
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
if (req_type == kWriteTo) {
out_data.FlatTo1D<xpu, DType>(s) = 0;
}
});
});
}
return;
}
uint32_t x1 = CheckAxis(param.axis1, ishape.ndim());
uint32_t x2 = CheckAxis(param.axis2, ishape.ndim());

uint32_t idim = ishape.ndim();

uint32_t minx = x1, maxx = x2;
if (minx > maxx) {
std::swap(minx, maxx);
}

// merges contiguous axes that are not separated
// by axis1 or axis2 since they can be directly
// mapped to the output and there is no need
// to distinguish them
// (After this the input will have no more than
// three axes, hence improving the rave and
// unravel efficiency)

index_t oleading = 1,
obody = 1,
otrailing = 1;

for (uint32_t i = 0; i < minx; ++i) {
oleading *= ishape[i];
}
for (uint32_t i = minx + 1; i < maxx; ++i) {
obody *= ishape[i];
}
for (uint32_t i = maxx + 1; i < idim; ++i) {
otrailing *= ishape[i];
}

index_t ileading = oleading,
ibody = obody * ishape[minx],
itrailing = otrailing * ishape[maxx];

index_t stride1 = itrailing * obody,
stride2 = otrailing;
// stride1 + stride2 is the stride for
// iterating over the diagonal in question

if (x1 == maxx) {
std::swap(stride1, stride2);
}

// the extra index offset introduced by offset
index_t offset;
if (param.offset > 0) {
offset = stride2 * param.offset;
} else if (param.offset < 0) {
offset = stride1 * -param.offset;
} else {
offset = 0;
}

// number of elements in the offset diagonal
// may be negative
int dlength;
if (param.offset > 0) {
dlength = std::min(ishape[x1], ishape[x2] - param.offset);
} else if (param.offset < 0) {
dlength = std::min(ishape[x1] - (-param.offset), ishape[x2]);
} else {
dlength = std::min(ishape[x1], ishape[x2]);
}

MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
if (back) {
out_data.FlatTo1D<xpu, DType>(s) = 0;
}
Kernel<numpy_trace<3, req_type, back>, xpu>::Launch(s, dsize, out_data.dptr<DType>(),
in_data.dptr<DType>(),
Shape3(oleading, obody, otrailing),
Shape3(ileading, ibody, itrailing),
stride1 + stride2, offset, dlength);
});
});
}

template<typename xpu>
void NumpyTraceOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
using namespace mshadow;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob& in_data = inputs[0];
const TBlob& out_data = outputs[0];
const mxnet::TShape& ishape = inputs[0].shape_;
const mxnet::TShape& oshape = outputs[0].shape_;
const NumpyTraceParam& param = nnvm::get<NumpyTraceParam>(attrs.parsed);

NumpyTraceOpProcess<xpu, false>(in_data, out_data, ishape, oshape,
out_data.Size(), param, s, req);
}

template<typename xpu>
void NumpyTraceOpBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
using namespace mshadow;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
Stream<xpu> *s = ctx.get_stream<xpu>();

const TBlob& in_data = inputs[0];
const TBlob& out_data = outputs[0];
const mxnet::TShape& ishape = inputs[0].shape_;
const mxnet::TShape& oshape = outputs[0].shape_;
const NumpyTraceParam& param = nnvm::get<NumpyTraceParam>(attrs.parsed);

NumpyTraceOpProcess<xpu, true>(in_data, out_data, oshape, ishape,
in_data.Size(), param, s, req);
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_NUMPY_NP_TRACE_OP_INL_H_
Loading

0 comments on commit 72230ef

Please sign in to comment.