Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ok
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiyan66 committed Dec 18, 2019
1 parent 6a67aee commit 6090a41
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 14 deletions.
22 changes: 9 additions & 13 deletions src/operator/numpy/np_broadcast_reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#ifndef MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_
#define MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_

#include <mshadow/tensor.h>
#include <algorithm>
#include <vector>
#include <string>
Expand Down Expand Up @@ -902,12 +901,11 @@ struct median_forward {

template<typename xpu>
void NumpyMedianForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
if (req[0] == kNullOp)
return;
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
if (req[0] == kNullOp) return;

using namespace mxnet;
using namespace mxnet_op;
Expand Down Expand Up @@ -1156,10 +1154,10 @@ void NumpyMedianForward(const nnvm::NodeAttrs& attrs,
element_num)), 0, k), Shape3(0, 2, 1)));
ASSIGN_DISPATCH(ret_indices, req_TopK[1], tcast<index_t>(F<mshadow_op::mod>(transpose(
slice<2>(inplace_reshape(indices,
Shape3(ret_indices.shape_[0],
ret_indices.shape_[2],
element_num)),
0, k), Shape3(0, 2, 1)), element_num)));
Shape3(ret_indices.shape_[0],
ret_indices.shape_[2],
element_num)),
0, k), Shape3(0, 2, 1)), element_num)));
} else {
Tensor<xpu, 2, DType> ret_value =
ret[0].get_with_shape<xpu, 2, DType>(Shape2(batch_size, k), s);
Expand All @@ -1172,8 +1170,6 @@ void NumpyMedianForward(const nnvm::NodeAttrs& attrs,
}
}



MXNET_NDIM_SWITCH(small.ndim()+1, NDim, {
Kernel<median_forward<NDim>, xpu>::Launch(
s, r_shape.Size(), r.dptr<DType>(), a_sort.dptr<DType>(),
Expand Down
1 change: 0 additions & 1 deletion src/operator/numpy/np_broadcast_reduce_op_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,6 @@ NNVM_REGISTER_OP(_npi_median)
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
// .set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

inline bool NumpyMeanType(const nnvm::NodeAttrs& attrs,
Expand Down

0 comments on commit 6090a41

Please sign in to comment.