This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix InferType logic - add backward inference for some ops #16817

Closed
wants to merge 2 commits
3 changes: 2 additions & 1 deletion python/mxnet/symbol/symbol.py
@@ -1003,7 +1003,8 @@ def _infer_type_impl(self, partial, *args, **kwargs):
         else:
             str_keys = []
             for k, v in kwargs.items():
-                v = _numpy.dtype(v).type
+                # if v is None just use that to search in _DTYPE_NP_TO_MX
+                v = _numpy.dtype(v).type if v else None
                 if v in _DTYPE_NP_TO_MX:
                     str_keys.append(k)
                     sdata.append(_DTYPE_NP_TO_MX[v])
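For context on this symbol.py change: numpy.dtype(None) silently resolves to float64, so before this guard a keyword passed as dtype=None looked like a concrete float64 request instead of an unknown type. A minimal Python sketch of the difference; the mapping below is a hypothetical stand-in for MXNet's _DTYPE_NP_TO_MX, which maps None to -1 ("unknown"):

import numpy as np

# Hypothetical stand-in for _DTYPE_NP_TO_MX; None marks an unknown dtype (-1).
DTYPE_NP_TO_MX = {None: -1, np.float32: 0, np.float64: 1}

def to_mx_dtype(v):
    # Old behavior: np.dtype(None) resolves to float64, silently turning
    # "no dtype given" into a concrete float64.
    # New behavior: keep None so it maps to -1 and is left to be inferred.
    v = np.dtype(v).type if v else None
    return DTYPE_NP_TO_MX[v]

assert to_mx_dtype(np.float32) == 0
assert to_mx_dtype(None) == -1            # stays unknown for partial inference
assert np.dtype(None).type is np.float64  # the silent default the guard avoids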
17 changes: 5 additions & 12 deletions src/operator/contrib/deformable_convolution-inl.h
@@ -44,6 +44,7 @@
 #include "../operator_common.h"
 #include "../nn/im2col.h"
 #include "./nn/deformable_im2col.h"
+#include "../elemwise_op_common.h"
 #include "../linalg.h"


@@ -453,18 +454,10 @@ class DeformableConvolutionProp : public OperatorProperty {
                  std::vector<int> *out_type,
                  std::vector<int> *aux_type) const override {
     CHECK_GE(in_type->size(), 1U);
-    int dtype = (*in_type)[0];
-    CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (size_t i = 0; i < in_type->size(); ++i) {
-      if ((*in_type)[i] == -1) {
-        (*in_type)[i] = dtype;
-      } else {
-        UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
-      }
-    }
-    out_type->clear();
-    out_type->push_back(dtype);
Member:
Wow, this would just discard the type inferred by the other operators...

Member (Author):
yep, many operators behave this way. probably more than what we have captured in the issue list :)

Member:
Seems that the logic for type/shape inference in operators really needs to be looked at. I have another issue (and a PR with a fix) coming today about how it is currently possible to bypass the operators' checks on their input and output attributes (and there should be another one about how operators lie in their InferType/InferShape about whether they successfully inferred everything or not).

-    return true;
+    std::string node_name = "deformable_convolution_node";
+    return ElemwiseAttrHelper<int, type_is_none,
+                              type_assign, true,
+                              type_string>(node_name, in_type, out_type, -1);
   }

   OperatorProperty* Copy() const override {
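The same one-line substitution repeats in the files below, so it is worth spelling out what changes. The deleted loop propagated the first input's dtype forward and failed hard (CHECK_NE) whenever that input was still unknown; ElemwiseAttrHelper instead treats all inputs and outputs as one group sharing a dtype, so a type known only at the output, e.g. one inferred from a downstream elemwise_add, flows backward into the inputs. A rough Python model of that behavior, my sketch rather than the actual C++ in elemwise_op_common.h:

UNKNOWN = -1  # mirrors MXNet's "type not set" marker

def elemwise_attr(in_types, out_types):
    # Rough model of ElemwiseAttr/ElemwiseAttrHelper: every input and
    # output of the op must share a single dtype.
    shared = UNKNOWN
    for t in in_types + out_types:       # scan inputs AND outputs
        if t != UNKNOWN:
            if shared != UNKNOWN and t != shared:
                raise TypeError("incompatible dtypes: %s vs %s" % (shared, t))
            shared = t
    if shared == UNKNOWN:
        return False                     # nothing known yet; report failure, don't crash
    for types in (in_types, out_types):  # assign the shared dtype everywhere
        for i in range(len(types)):
            types[i] = shared
    return True

# Backward inference: only the output dtype is known, inputs get filled in.
ins, outs = [UNKNOWN, UNKNOWN], [0]
assert elemwise_attr(ins, outs) and ins == [0, 0]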
17 changes: 5 additions & 12 deletions src/operator/contrib/fft-inl.h
@@ -33,6 +33,7 @@
 #include <utility>
 #include <iostream>
 #include "../operator_common.h"
+#include "../elemwise_op_common.h"
 #include "../mshadow_op.h"

 #if MXNET_USE_CUDA
@@ -256,18 +257,10 @@ class FFTProp : public OperatorProperty {
                  std::vector<int> *out_type,
                  std::vector<int> *aux_type) const override {
     CHECK_GE(in_type->size(), 1);
-    int dtype = (*in_type)[0];
-    CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (size_t i = 0; i < in_type->size(); ++i) {
-      if ((*in_type)[i] == -1) {
-        (*in_type)[i] = dtype;
-      } else {
-        UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
-      }
-    }
-    out_type->clear();
-    out_type->push_back(dtype);
-    return true;
+    std::string node_name = "fft_node";
+    return ElemwiseAttrHelper<int, type_is_none,
+                              type_assign, true,
+                              type_string>(node_name, in_type, out_type, -1);
   }

   OperatorProperty* Copy() const override {
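From the user's side, this FFT change means partial type inference can now fill in the operator's input from a dtype known only downstream. A minimal sketch mirroring the new unit test at the bottom of this diff (assumes an MXNet build with the contrib FFT operator; the variable names are illustrative):

import mxnet as mx
import numpy as np

data = mx.sym.var("data")
fft_res = mx.sym.contrib.fft(data)
other = mx.sym.var("other")
out = mx.sym.elemwise_add(fft_res, other)
# Only "other" is annotated. Before this PR, FFTProp::InferType hit
# CHECK_NE(dtype, -1) because its own input had no dtype yet; now the
# float32 inferred for the sum flows backward through fft into "data".
arg_types, out_types, _ = out.infer_type_partial(data=None, other=np.float32)
print(arg_types, out_types)  # expected: [float32, float32] [float32]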
17 changes: 5 additions & 12 deletions src/operator/contrib/ifft-inl.h
@@ -34,6 +34,7 @@
 #include <utility>
 #include "../operator_common.h"
 #include "../mshadow_op.h"
+#include "../elemwise_op_common.h"

 #if MXNET_USE_CUDA
 #include <cufft.h>
@@ -248,18 +249,10 @@ class IFFTProp : public OperatorProperty {
                  std::vector<int> *out_type,
                  std::vector<int> *aux_type) const override {
     CHECK_GE(in_type->size(), 1);
-    int dtype = (*in_type)[0];
-    CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (size_t i=0; i < in_type->size(); ++i) {
-      if ((*in_type)[i] == -1) {
-        (*in_type)[i] = dtype;
-      } else {
-        UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
-      }
-    }
-    out_type->clear();
-    out_type->push_back(dtype);
-    return true;
+    std::string node_name = "fft_node";
+    return ElemwiseAttrHelper<int, type_is_none,
+                              type_assign, true,
+                              type_string>(node_name, in_type, out_type, -1);
   }

   OperatorProperty* Copy() const override {
16 changes: 4 additions & 12 deletions src/operator/contrib/multi_sum_sq-inl.h
@@ -31,6 +31,7 @@
 #include <mxnet/operator.h>
 #include <vector>
 #include "../operator_common.h"
+#include "../elemwise_op_common.h"

 namespace mxnet {
 namespace op {
@@ -64,18 +65,9 @@ inline bool MultiSumSqType(const NodeAttrs& attrs,
                            std::vector<int>* out_type) {
   const auto& p = dmlc::get<MultiSumSqParam>(attrs.parsed);
   CHECK_EQ(in_type->size(), p.num_arrays);
-  int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
-  for (size_t i = 0; i < in_type->size(); ++i) {
-    if ((*in_type)[i] == -1) {
-      (*in_type)[i] = dtype;
-    } else {
-      UNIFORM_TYPE_CHECK((*in_type)[i], dtype, "array_" + std::to_string(i));
-    }
-  }
-  out_type->clear();
-  out_type->push_back(mshadow::kFloat32);
-  return true;
+  return ElemwiseAttr<int, type_is_none,
+                      type_assign, true,
+                      type_string>(attrs, in_type, out_type, -1);
 }

 template<typename xpu>
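One behavioral detail of this hunk worth flagging (my reading of the diff, not something stated in the PR): the deleted code always reported mshadow::kFloat32 for the output regardless of the input dtype, while ElemwiseAttr assigns the inputs' shared dtype to the output as well. A hedged check, in the style of this PR's tests, that would surface the difference:

import mxnet as mx
import numpy as np

a = mx.sym.var("a")
b = mx.sym.var("b")
out = mx.sym.multi_sum_sq(a, b, num_arrays=2)
arg_types, out_types, _ = out.infer_type(a=np.float16, b=np.float16)
# Old code: out_types == [np.float32] no matter what the inputs were.
# With ElemwiseAttr, the output is expected to follow the inputs (float16).
print(arg_types, out_types)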
17 changes: 5 additions & 12 deletions src/operator/convolution_v1-inl.h
@@ -38,6 +38,7 @@
 #include <string>
 #include <utility>
 #include "./operator_common.h"
+#include "./elemwise_op_common.h"
 #include "./linalg.h"

 namespace mxnet {
@@ -498,18 +499,10 @@ class ConvolutionV1Prop : public OperatorProperty {
                  std::vector<int> *out_type,
                  std::vector<int> *aux_type) const override {
     CHECK_GE(in_type->size(), 1);
-    int dtype = (*in_type)[0];
-    CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (size_t i = 0; i < in_type->size(); ++i) {
-      if ((*in_type)[i] == -1) {
-        (*in_type)[i] = dtype;
-      } else {
-        UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
-      }
-    }
-    out_type->clear();
-    out_type->push_back(dtype);
-    return true;
+    std::string node_name = "convolution_v1_node";
+    return ElemwiseAttrHelper<int, type_is_none,
+                              type_assign, true,
+                              type_string>(node_name, in_type, out_type, -1);
   }

   OperatorProperty* Copy() const override {
16 changes: 4 additions & 12 deletions src/operator/nn/upsampling.cc
@@ -27,6 +27,7 @@
 #include "./upsampling-inl.h"
 #include <nnvm/op_attr_types.h>
 #include "./deconvolution-inl.h"
+#include "../elemwise_op_common.h"

 namespace mxnet {
 namespace op {
@@ -90,18 +91,9 @@ static bool UpSamplingType(const nnvm::NodeAttrs& attrs,
                            std::vector<int> *in_type, std::vector<int> *out_type) {
   const UpSamplingParam& param = nnvm::get<UpSamplingParam>(attrs.parsed);
   CHECK_GE(in_type->size(), 1U);
-  int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
-  for (size_t i = 0; i < in_type->size(); ++i) {
-    if ((*in_type)[i] == -1) {
-      (*in_type)[i] = dtype;
-    } else {
-      UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param)[i]);
-    }
-  }
-  out_type->clear();
-  out_type->push_back(dtype);
-  return true;
+  return ElemwiseAttr<int, type_is_none,
+                      type_assign, true,
+                      type_string>(attrs, in_type, out_type, -1);
 }

 struct UpSamplingGrad {
15 changes: 5 additions & 10 deletions src/operator/sequence_reverse-inl.h
@@ -40,6 +40,7 @@
 #include "./mxnet_op.h"
 #include "./operator_common.h"
 #include "./sequence_op_common.h"
+#include "./elemwise_op_common.h"

 namespace mxnet {
 namespace op {
@@ -243,16 +244,10 @@ class SequenceReverseProp : public OperatorProperty {
   bool InferType(std::vector<int> *in_type, std::vector<int> *out_type,
                  std::vector<int> *aux_type) const override {
     CHECK_GE(in_type->size(), param_.use_sequence_length ? 2U : 1U);
-    int dtype = (*in_type)[0];
-    CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (size_t i = 0; i < in_type->size(); ++i) {
-      if ((*in_type)[i] == -1) {
-        (*in_type)[i] = dtype;
-      }
-    }
-    out_type->clear();
-    out_type->push_back(dtype);
-    return true;
+    std::string node_name = "sequence_reverse_node";
+    return ElemwiseAttrHelper<int, type_is_none,
+                              type_assign, true,
+                              type_string>(node_name, in_type, out_type, -1);
   }

   OperatorProperty *Copy() const override {
89 changes: 89 additions & 0 deletions tests/python/unittest/test_symbol.py
@@ -130,6 +130,95 @@ def test_symbol_infer_type():
     assert out == [np.float32]
     assert aux == []

+    # partial infer type with unknown dtypes
+    data = mx.sym.var("data")
+    data2 = mx.sym.var("data2")
+    out = mx.sym.elemwise_add(data, data2)
+    arg_types, out_types, aux_types = out.infer_type_partial(data=None, data2=np.float32)
+    assert arg_types == [np.float32, np.float32]
+    assert out_types == [np.float32]
+    assert aux_types == []
+
+    def check_infer_type_one_input(op):
+        data = mx.sym.var("data")
+        fft_res = op(data)
+        data3 = mx.sym.var("data3")
+        out = mx.sym.elemwise_add(data3, fft_res)
+        arg_types, out_types, aux_types = out.infer_type_partial(data=None, data3=np.float32)
+        # data should be inferred during backward inference
+        assert arg_types == [np.float32, np.float32]
+        assert out_types == [np.float32]
+        assert aux_types == []
+
+    op_list = [mx.sym.contrib.fft, mx.sym.contrib.ifft, #mx.sym.contrib.DeformableConvolution,
+               mx.sym.SequenceReverse] #mx.sym.multi_sum_sq, mx.sym.Convolution_v1,
+               #mx.sym.UpSampling]
+    for op in op_list:
+        check_infer_type_one_input(op)
+
+
+    def check_infer_type_convolution(op, deformable=False):
+        data = mx.sym.var("data")
+        weight = mx.sym.var("weight")
+        data2 = mx.sym.var("data2")
+        conv_res = op(data, weight, pad=(3, 3), num_filter=64, stride=(2, 2), no_bias=True, kernel=(7, 7))
+        out = mx.sym.elemwise_add(conv_res, data2)
+        arg_types, out_types, aux_types = out.infer_type_partial(data=None, weight=None, data2=np.float32)
+        # data and weight should be inferred during backward inference
+        if deformable:
+            assert arg_types == [np.float32, np.float32, np.float32, np.float32]
+        else:
+            assert arg_types == [np.float32, np.float32, np.float32]
+        assert out_types == [np.float32]
+        assert aux_types == []
+
+    op_list = [mx.sym.Convolution_v1, mx.sym.contrib.DeformableConvolution]
+    for op in op_list:
+        check_infer_type_convolution(op, op == mx.sym.contrib.DeformableConvolution)
+
+    def check_infer_type_two_inputs(op, upsampling=False):
+        data = mx.sym.var("data")
+        weight = mx.sym.var("weight")
+        data2 = mx.sym.var("data2")
+        conv_res = op(data, weight, pad=(3, 3), num_filter=64, stride=(2, 2), no_bias=True, kernel=(7, 7)) if not upsampling \
+            else op(data, weight, sample_type="bilinear", num_args=2, scale=10)
+        out = mx.sym.elemwise_add(conv_res, data2)
+        arg_types, out_types, aux_types = out.infer_type_partial(data=None, weight=None, data2=np.float32)
+        # data and weight should be inferred during backward inference
+        assert arg_types == [np.float32, np.float32, np.float32]
+        assert out_types == [np.float32]
+        assert aux_types == []
+
+    def check_infer_type_upsampling(op):
+        data = mx.sym.var("data")
+        weight = mx.sym.var("weight")
+        data2 = mx.sym.var("data2")
+        upsampling_res = op(data, weight, sample_type="bilinear", num_args=2, scale=10)
+        out = mx.sym.elemwise_add(upsampling_res, data2)
+        arg_types, out_types, aux_types = out.infer_type_partial(data=None, weight=None, data2=np.float32)
+        # data and weight should be inferred during backward inference
+        assert arg_types == [np.float32, np.float32, np.float32]
+        assert out_types == [np.float32]
+        assert aux_types == []
+
+
+    check_infer_type_upsampling(mx.sym.UpSampling)
+
+
+    def check_infer_type_multi_sum_sq(op):
+        data = mx.sym.var("data")
+        weight = mx.sym.var("weight")
+        data2 = mx.sym.var("data2")
+        upsampling_res = op(data, weight, num_arrays=2)
+        out = mx.sym.elemwise_add(upsampling_res, data2)
+        arg_types, out_types, aux_types = out.infer_type_partial(data=None, weight=None, data2=np.float32)
+        # data and weight should be inferred during backward inference
+        assert arg_types == [np.float32, np.float32, np.float32]
+        assert out_types == [np.float32]
+        assert aux_types == []
+
+    check_infer_type_multi_sum_sq(mx.sym.multi_sum_sq)
+

 def test_symbol_infer_shape():
     num_hidden = 128