Commit 6c9ca56

Backporting backward inference from 2.x apache#18348 and apache#18378
Signed-off-by: Serge Panev <[email protected]>
1 parent: 0b5b449

9 files changed: +125 -49 lines

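Both backported PRs apply the same two fixes across these nine files. The type functions (BatchNormWithReLU, BatchNorm, Convolution, Deconvolution, SoftmaxOutput) used to abort via CHECK_NE whenever the input dtype was still unknown; they now fall back to inferring the input type from the output type, and return false only when neither side is known yet. The shape functions gain (or hoist) an mxnet::ndim_is_known guard so an unknown input shape defers inference instead of being dereferenced. A minimal standalone sketch of the new type-inference control flow, assuming the invented helper name InferBatchNormLikeType and using -1 as the "unknown" dtype the way MXNet's type_is_none does:

#include <cstddef>
#include <vector>

// -1 encodes "dtype not known yet", as MXNet's type_is_none checks.
static bool type_is_none(int t) { return t == -1; }

// Invented name: a distilled version of the branch this commit adds to the
// BatchNorm-style type functions. There are n_out outputs; outputs past the
// first carry the parameter dtype (a plausible ParamDType is sketched after
// the first file's diff).
bool InferBatchNormLikeType(std::vector<int>* in_type,
                            std::vector<int>* out_type,
                            std::size_t n_out,
                            int dtype_param) {
  int dtype = (*in_type)[0];
  if (type_is_none(dtype)) {
    // Backward inference: recover the input dtype from the output.
    if (out_type->empty() || type_is_none((*out_type)[0])) {
      return false;  // neither side known: defer to a later inference pass
    }
    dtype = (*out_type)[0];
    (*in_type)[0] = dtype;
  } else {
    // Forward inference: first output mirrors the data dtype, the remaining
    // outputs carry the parameter dtype.
    out_type->clear();
    out_type->push_back(dtype);
    for (std::size_t i = 1; i < n_out; ++i) {
      out_type->push_back(dtype_param);
    }
  }
  return true;
}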
src/operator/contrib/batch_norm_relu.cc  (+28 -13)

@@ -55,6 +55,9 @@ static bool BatchNormWithReLUShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]";
   CHECK_EQ(out_shape->size(), 4U);
   const mxnet::TShape &dshape = in_shape->at(batchnormrelu::kData);
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }

   const size_t channelAxis = static_cast<size_t>(param.axis < 0
       ? static_cast<int>(dshape.ndim()) + param.axis
@@ -63,10 +66,6 @@ static bool BatchNormWithReLUShape(const nnvm::NodeAttrs& attrs,

   const int channelCount = dshape[channelAxis];

-  if (!mxnet::ndim_is_known(dshape)) {
-    return false;
-  }
-
   in_shape->at(batchnormrelu::kGamma) = mxnet::TShape(Shape1(channelCount));
   in_shape->at(batchnormrelu::kBeta) = mxnet::TShape(Shape1(channelCount));
   in_shape->at(batchnormrelu::kInMovingMean) = mxnet::TShape(Shape1(channelCount));  // kMovingMean
@@ -84,14 +83,36 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs,
                                   std::vector<int> *in_type, std::vector<int> *out_type) {
   using namespace mshadow;
   CHECK_GE(in_type->size(), 1U);
-  const int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
+  const size_t n_out = 4;
   // For float16 input type beta, gamma, mean, and average are stored in float32.
   // For other input types, these parameters have the same type as input
   // NOTE: This requirement is from cuDNN (v. 4 and 5)
   int dtype_param;
-  MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+  int dtype = (*in_type)[0];
+
+  if (type_is_none(dtype)) {
+    // Input type is undefined, we try backward inference
+    if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+      // Neither the input nor the output are defined,
+      // types cannot be inferred for this op
+      return false;
+    } else {
+      // Input type is undefined but output type is: backward inference
+      dtype = (*out_type)[0];
+      (*in_type)[0] = dtype;
+      MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+          dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+    }
+  } else {
+    // Input type is defined but output type is not: forward inference
+    MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
       dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+    out_type->clear();
+    out_type->push_back(dtype);
+    for (size_t i = 1; i < n_out; ++i) {
+      out_type->push_back(dtype_param);
+    }
+  }
   std::vector<std::string> args{"data", "gamma", "beta", "mean", "var"};
   CHECK_LE(in_type->size(), args.size());
   for (size_t i = 1; i < in_type->size(); ++i) {
@@ -101,12 +122,6 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs,
       UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]);
     }
   }
-  const size_t n_out = 4;
-  out_type->clear();
-  out_type->push_back(dtype);
-  for (size_t i = 1; i < n_out; ++i) {
-    out_type->push_back(dtype_param);
-  }
   return true;
 }

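As the in-diff comments note, the parameter outputs (gamma, beta, mean, var) are stored in float32 when the data is float16, which is what the MSHADOW_REAL_TYPE_SWITCH_EX dispatch computes into dtype_param. A hedged standalone sketch of that mapping, assuming mshadow's dtype flags (kFloat32 = 0, kFloat64 = 1, kFloat16 = 2; verify against mshadow/base.h):

#include <cstdio>

// Assumed to match mshadow's flags.
enum DTypeFlag { kFloat32 = 0, kFloat64 = 1, kFloat16 = 2 };

// Invented helper: the parameter/accumulation dtype for a given data dtype.
// fp16 statistics are kept in fp32, per the cuDNN requirement quoted above.
int ParamDType(int dtype) {
  return dtype == kFloat16 ? kFloat32 : dtype;
}

int main() {
  std::printf("%d %d\n", ParamDType(kFloat16), ParamDType(kFloat64));  // 0 1
  return 0;
}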
src/operator/nn/batch_norm.cc  (+27 -13)

@@ -365,6 +365,9 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]";
   CHECK_EQ(out_shape->size(), 3U);
   const mxnet::TShape &dshape = in_shape->at(batchnorm::kData);
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }

   const size_t channelAxis = static_cast<size_t>(param.axis < 0
       ? static_cast<int>(dshape.ndim()) + param.axis
@@ -373,10 +376,6 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,

   const int channelCount = dshape[channelAxis];

-  if (!mxnet::ndim_is_known(dshape)) {
-    return false;
-  }
-
   in_shape->at(batchnorm::kGamma) = mxnet::TShape(Shape1(channelCount));
   in_shape->at(batchnorm::kBeta) = mxnet::TShape(Shape1(channelCount));
   in_shape->at(batchnorm::kInMovingMean) = mxnet::TShape(Shape1(channelCount));  // kMovingMean
@@ -394,14 +393,35 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
                           std::vector<int> *in_type, std::vector<int> *out_type) {
   using namespace mshadow;
   CHECK_GE(in_type->size(), 1U);
-  const int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
+  const size_t n_out = 3;
   // For float16 input type beta, gamma, mean, and average are stored in float32.
   // For other input types, these parameters have the same type as input
   // NOTE: This requirement is from cuDNN (v. 4 and 5)
   int dtype_param;
-  MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+  int dtype = (*in_type)[0];
+  if (type_is_none(dtype)) {
+    // Input type is undefined, we try backward inference
+    if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+      // Neither the input nor the output are defined,
+      // types cannot be inferred for this op
+      return false;
+    } else {
+      // Input type is undefined but output type is: backward inference
+      dtype = (*out_type)[0];
+      (*in_type)[0] = dtype;
+      MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+          dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+    }
+  } else {
+    // Input type is defined but output type is not: forward inference
+    MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
       dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+    out_type->clear();
+    out_type->push_back(dtype);
+    for (size_t i = 1; i < n_out; ++i) {
+      out_type->push_back(dtype_param);
+    }
+  }
   std::vector<std::string> args{"data", "gamma", "beta", "mean", "var"};
   CHECK_LE(in_type->size(), args.size());
   for (size_t i = 1; i < in_type->size(); ++i) {
@@ -411,12 +431,6 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
       UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]);
     }
   }
-  const size_t n_out = 3;
-  out_type->clear();
-  out_type->push_back(dtype);
-  for (size_t i = 1; i < n_out; ++i) {
-    out_type->push_back(dtype_param);
-  }
   return true;
 }

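The batch_norm.cc change mirrors the one above, with n_out = 3 instead of 4 (BatchNormWithReLU has one extra output). A hypothetical driver for the two sketches introduced earlier, exercising exactly the case the old CHECK_NE rejected — input dtype unknown, output dtype known:

#include <cassert>
#include <vector>

// Reuses the invented InferBatchNormLikeType and ParamDType sketched above.
int main() {
  std::vector<int> in_type{-1, -1, -1, -1, -1};  // data, gamma, beta, mean, var
  std::vector<int> out_type{0, -1, -1};          // only the output dtype (fp32) known
  bool ok = InferBatchNormLikeType(&in_type, &out_type, 3, ParamDType(0));
  assert(ok);
  assert(in_type[0] == 0);  // recovered backward from out_type[0]
  return 0;
}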
src/operator/nn/convolution.cc  (+10 -3)

@@ -285,16 +285,23 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs,
   const ConvolutionParam& param_ = nnvm::get<ConvolutionParam>(attrs.parsed);
   CHECK_GE(in_type->size(), 1U);
   int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
+  if (type_is_none(dtype)) {
+    if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+      return false;
+    } else {
+      dtype = (*out_type)[0];
+    }
+  } else {
+    out_type->clear();
+    out_type->push_back(dtype);
+  }
   for (size_t i = 0; i < in_type->size(); ++i) {
     if ((*in_type)[i] == -1) {
       (*in_type)[i] = dtype;
     } else {
       UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]);
     }
   }
-  out_type->clear();
-  out_type->push_back(dtype);
   return true;
 }

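Convolution, Deconvolution, and SoftmaxOutput share a simpler single-output form of the same idea. A condensed sketch (helper name invented; UNIFORM_TYPE_CHECK's mismatch diagnostics omitted):

#include <vector>

// Invented helper condensing the Convolution/Deconvolution/SoftmaxOutput
// pattern: one dtype shared by all inputs and the single output.
bool InferSingleOutputType(std::vector<int>* in_type, std::vector<int>* out_type) {
  int dtype = (*in_type)[0];
  if (dtype == -1) {                         // input dtype unknown
    if (out_type->empty() || (*out_type)[0] == -1) {
      return false;                          // nothing known yet: defer
    }
    dtype = (*out_type)[0];                  // backward: take it from the output
  } else {
    out_type->assign(1, dtype);              // forward: the single output gets dtype
  }
  for (int& t : *in_type) {
    if (t == -1) t = dtype;                  // fill the remaining unknown inputs
  }
  return true;
}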
src/operator/nn/deconvolution.cc  (+15 -3)

@@ -332,16 +332,28 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs,
   const DeconvolutionParam& param_ = nnvm::get<DeconvolutionParam>(attrs.parsed);
   CHECK_GE(in_type->size(), 1U);
   int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
+  if (type_is_none(dtype)) {
+    // Input type is undefined, we try backward inference
+    if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+      // Neither the input nor the output are defined,
+      // types cannot be inferred for this op
+      return false;
+    } else {
+      // Input type is undefined but output type is: backward inference
+      dtype = (*out_type)[0];
+    }
+  } else {
+    // Input type is defined but output type is not: forward inference
+    out_type->clear();
+    out_type->push_back(dtype);
+  }
   for (size_t i = 0; i < in_type->size(); ++i) {
     if ((*in_type)[i] == -1) {
       (*in_type)[i] = dtype;
     } else {
       UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]);
     }
   }
-  out_type->clear();
-  out_type->push_back(dtype);
   return true;
 }

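One design point worth flagging in these type functions: in the backward branch, out_type is left untouched, since it already carries the known dtype, while only the forward branch rebuilds it. And returning false is not a hard failure; per the in-diff comments, it signals that types cannot be inferred for this op yet, so the surrounding inference pass can revisit the node once more types become known.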
src/operator/nn/group_norm.cc  (+4 -4)

@@ -39,14 +39,14 @@ static bool GroupNormShape(const nnvm::NodeAttrs& attrs,
   using namespace mshadow;
   CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
   const mxnet::TShape &dshape = in_shape->at(groupnorm::kData);
-  CHECK_GE(dshape.ndim(), 3U);
-  const int num_groups = param.num_groups;
-  CHECK_EQ(dshape[1] % num_groups, 0) << "# of channels must be divisible by # of groups";
-
   if (!mxnet::ndim_is_known(dshape)) {
     return false;
   }

+  CHECK_GE(dshape.ndim(), 3U);
+  const int num_groups = param.num_groups;
+  CHECK_EQ(dshape[1] % num_groups, 0) << "# of channels must be divisible by # of groups";
+
   in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(num_groups));
   in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(num_groups));

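The remaining files (group_norm.cc above, then layer_norm.cc, pooling.cc, and matrix_op-inl.h) all fix the same ordering bug: the shape was consulted via dshape.ndim() or dshape[axis] before the ndim_is_known check, so a partially inferred graph could trip a CHECK or index an unknown shape. A toy sketch of the guarded pattern, with an invented stand-in Shape type (mxnet::TShape itself encodes "unknown" as ndim == -1):

#include <vector>

// Toy stand-in for mxnet::TShape; ndim == -1 means the shape is unknown.
struct Shape {
  int ndim = -1;
  std::vector<long> dims;
};

bool ndim_is_known(const Shape& s) { return s.ndim != -1; }

// Guard first, dereference later: an unknown shape defers inference (returns
// false) instead of feeding -1 into axis arithmetic or CHECK macros.
bool GuardedShapeFn(const Shape& dshape) {
  if (!ndim_is_known(dshape)) {
    return false;
  }
  // From here on it is safe to read dshape.ndim and index dshape.dims.
  return true;
}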
src/operator/nn/layer_norm.cc  (+4 -3)

@@ -43,15 +43,16 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
   using namespace mshadow;
   CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
   const mxnet::TShape &dshape = in_shape->at(layernorm::kData);
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }
+
   int axis = GetRealAxis(param.axis, dshape.ndim());
   CHECK(axis >= 0 && axis < dshape.ndim())
     << "Channel axis out of range: axis=" << param.axis;

   const int channelCount = dshape[axis];

-  if (!mxnet::ndim_is_known(dshape)) {
-    return false;
-  }
   SHAPE_ASSIGN_CHECK(*in_shape,
                      layernorm::kGamma,
                      mxnet::TShape(Shape1(channelCount)));

src/operator/nn/pooling.cc  (+5 -2)

@@ -95,10 +95,14 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
                          mxnet::ShapeVector *out_shape) {
   const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed);
   CHECK_EQ(in_shape->size(), 1U);
+  const mxnet::TShape &dshape = (*in_shape)[0];
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }
   if (param.pool_type == pool_enum::kLpPooling) {
     CHECK(param.p_value.has_value());
   }
-  const mxnet::TShape &dshape = (*in_shape)[0];
+
   if (param.pooling_convention == pool_enum::kSame) {
     CHECK_EQ(dshape.ndim(), 3U)
       << "Pooling: Input data should be 3D in (batch, channel, x)"
@@ -114,7 +118,6 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
       << "Pooling: Input data should be 3D in (batch, channel, x)"
       << " Or 4D in (batch, channel, y, x) "
      << " Or 5D in (batch, channel, d, y, x)";
-  if (!mxnet::ndim_is_known(dshape)) return false;
   int layout = param.GetLayout(dshape.ndim());
   if (param.global_pool) {
     mxnet::TShape oshape = dshape;

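Note that in pooling.cc the guard is hoisted to the top of PoolingShape rather than merely added: the old check sat below the pooling_convention == kSame branch, which already called dshape.ndim(), so those lines ran against a possibly unknown shape. Moving both the dshape binding and the guard above every use closes that window.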
src/operator/softmax_output.cc  (+15 -3)

@@ -66,16 +66,28 @@ static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs,
                               std::vector<int> *out_type) {
   CHECK_EQ(in_type->size(), 2U);
   int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
+  if (type_is_none(dtype)) {
+    // Input type is undefined, we try backward inference
+    if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+      // Neither the input nor the output are defined,
+      // types cannot be inferred for this op
+      return false;
+    } else {
+      // Input type is undefined but output type is: backward inference
+      dtype = (*out_type)[0];
+    }
+  } else {
+    // Input type is defined but output type is not: forward inference
+    out_type->clear();
+    out_type->push_back(dtype);
+  }
   for (size_t i = 0; i < in_type->size(); ++i) {
     if ((*in_type)[i] == -1) {
       (*in_type)[i] = dtype;
     } else {
       UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
     }
   }
-  out_type->clear();
-  out_type->push_back(dtype);
   return true;
 }

src/operator/tensor/matrix_op-inl.h  (+17 -5)

@@ -455,9 +455,9 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_attrs->size(), 1U);
   mxnet::TShape& shp = (*in_attrs)[0];
   mxnet::TShape& out_shp = (*out_attrs)[0];
-  CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
-  if (shp.ndim() == -1 && out_shp.ndim() == -1)
+  if (!mxnet::ndim_is_known(shp) && !mxnet::ndim_is_known(out_shp))
     return false;  // none of the shapes is known
+  CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
   if (out_shp.ndim() >= 0 && shp.ndim() >= 0)
     CHECK_EQ(out_shp.ndim(), shp.ndim());
   mxnet::TShape get(std::max(shp.ndim(), out_shp.ndim()), -1);
@@ -506,12 +506,12 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs,
   const ExpandDimParam& param = nnvm::get<ExpandDimParam>(attrs.parsed);
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 1U);
-  if (!mxnet::ndim_is_known(in_attrs->at(0)) && !mxnet::ndim_is_known(out_attrs->at(0))) {
+  mxnet::TShape& ishape = (*in_attrs)[0];
+  mxnet::TShape& oshape = (*out_attrs)[0];
+  if (!mxnet::ndim_is_known(ishape) && !mxnet::ndim_is_known(oshape)) {
     return false;
   }

-  mxnet::TShape& ishape = (*in_attrs)[0];
-  mxnet::TShape& oshape = (*out_attrs)[0];
   int indim = ishape.ndim();
   bool unknown_ishape = false;
   if (-1 == indim) {
@@ -1434,6 +1434,9 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_attrs->size(), 1U);
   mxnet::TShape& ishape = (*in_attrs)[0];
   mxnet::TShape& from_shape = (*in_attrs)[1];
+  if (!mxnet::ndim_is_known(ishape) || !mxnet::ndim_is_known(from_shape)) {
+    return false;
+  }
   if (param.axes.ndim() == 0) {
     CHECK_EQ(ishape.ndim(), from_shape.ndim())
       << "By default slice_axis performs slice on all axes, but ndim mismatch "
@@ -1727,6 +1730,9 @@ inline bool RepeatOpShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 1U);
   const mxnet::TShape& ishape = (*in_attrs)[0];
+  if (!mxnet::ndim_is_known(ishape)) {
+    return false;
+  }
   int repeats = 0;
   dmlc::optional<int> axisOpt;
   GetRepeatParams(param, ishape, &repeats, &axisOpt);
@@ -2395,6 +2401,9 @@ inline bool DepthToSpaceOpShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape expected_out(4, -1);

   mxnet::TShape& in_shape = in_attrs->at(0);
+  if (!mxnet::ndim_is_known(in_shape)) {
+    return false;
+  }
   int block = param.block_size;
   CHECK_NE(block, 0) << "block_size must be a positive integer value";
   CHECK_NE(in_shape[1], 0) << "Depth dimension:1 cannot be 0";
@@ -2559,6 +2568,9 @@ inline bool SpaceToDepthOpShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);

   mxnet::TShape& in_shape = in_attrs->at(0);
+  if (!mxnet::ndim_is_known(in_shape)) {
+    return false;
+  }
   int block = param.block_size;
   CHECK_NE(block, 0) << "block_size must be a positive integer value";
   CHECK_NE(in_shape[0], 0)
