@@ -455,9 +455,9 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
455
455
CHECK_EQ (out_attrs->size (), 1U );
456
456
mxnet::TShape& shp = (*in_attrs)[0 ];
457
457
mxnet::TShape& out_shp = (*out_attrs)[0 ];
458
- CHECK_LE (shp.ndim (), 6 ) << " Transpose support at most 6 dimensions" ;
459
- if (shp.ndim () == -1 && out_shp.ndim () == -1 )
458
+ if (!mxnet::ndim_is_known (shp) && !mxnet::ndim_is_known (out_shp))
460
459
return false ; // none of the shapes is known
460
+ CHECK_LE (shp.ndim (), 6 ) << " Transpose support at most 6 dimensions" ;
461
461
if (out_shp.ndim () >= 0 && shp.ndim () >= 0 )
462
462
CHECK_EQ (out_shp.ndim (), shp.ndim ());
463
463
mxnet::TShape get (std::max (shp.ndim (), out_shp.ndim ()), -1 );
@@ -506,12 +506,12 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs,
506
506
const ExpandDimParam& param = nnvm::get<ExpandDimParam>(attrs.parsed );
507
507
CHECK_EQ (in_attrs->size (), 1U );
508
508
CHECK_EQ (out_attrs->size (), 1U );
509
- if (!mxnet::ndim_is_known (in_attrs->at (0 )) && !mxnet::ndim_is_known (out_attrs->at (0 ))) {
509
+ mxnet::TShape& ishape = (*in_attrs)[0 ];
510
+ mxnet::TShape& oshape = (*out_attrs)[0 ];
511
+ if (!mxnet::ndim_is_known (ishape) && !mxnet::ndim_is_known (oshape)) {
510
512
return false ;
511
513
}
512
514
513
- mxnet::TShape& ishape = (*in_attrs)[0 ];
514
- mxnet::TShape& oshape = (*out_attrs)[0 ];
515
515
int indim = ishape.ndim ();
516
516
bool unknown_ishape = false ;
517
517
if (-1 == indim) {
@@ -1434,6 +1434,9 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
1434
1434
CHECK_EQ (out_attrs->size (), 1U );
1435
1435
mxnet::TShape& ishape = (*in_attrs)[0 ];
1436
1436
mxnet::TShape& from_shape = (*in_attrs)[1 ];
1437
+ if (!mxnet::ndim_is_known (ishape) || !mxnet::ndim_is_known (from_shape)) {
1438
+ return false ;
1439
+ }
1437
1440
if (param.axes .ndim () == 0 ) {
1438
1441
CHECK_EQ (ishape.ndim (), from_shape.ndim ())
1439
1442
<< " By default slice_axis performs slice on all axes, but ndim mismatch "
@@ -1727,6 +1730,9 @@ inline bool RepeatOpShape(const nnvm::NodeAttrs& attrs,
1727
1730
CHECK_EQ (in_attrs->size (), 1U );
1728
1731
CHECK_EQ (out_attrs->size (), 1U );
1729
1732
const mxnet::TShape& ishape = (*in_attrs)[0 ];
1733
+ if (!mxnet::ndim_is_known (ishape)) {
1734
+ return false ;
1735
+ }
1730
1736
int repeats = 0 ;
1731
1737
dmlc::optional<int > axisOpt;
1732
1738
GetRepeatParams (param, ishape, &repeats, &axisOpt);
@@ -2395,6 +2401,9 @@ inline bool DepthToSpaceOpShape(const nnvm::NodeAttrs& attrs,
2395
2401
mxnet::TShape expected_out (4 , -1 );
2396
2402
2397
2403
mxnet::TShape& in_shape = in_attrs->at (0 );
2404
+ if (!mxnet::ndim_is_known (in_shape)) {
2405
+ return false ;
2406
+ }
2398
2407
int block = param.block_size ;
2399
2408
CHECK_NE (block, 0 ) << " block_size must be a positive integer value" ;
2400
2409
CHECK_NE (in_shape[1 ], 0 ) << " Depth dimension:1 cannot be 0" ;
@@ -2559,6 +2568,9 @@ inline bool SpaceToDepthOpShape(const nnvm::NodeAttrs& attrs,
2559
2568
mxnet::TShape expected_out (in_attrs->at (0 ).ndim (), -1 );
2560
2569
2561
2570
mxnet::TShape& in_shape = in_attrs->at (0 );
2571
+ if (!mxnet::ndim_is_known (in_shape)) {
2572
+ return false ;
2573
+ }
2562
2574
int block = param.block_size ;
2563
2575
CHECK_NE (block, 0 ) << " block_size must be a positive integer value" ;
2564
2576
CHECK_NE (in_shape[0 ], 0 )
0 commit comments