diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py
index acd4f4740b2d..bda71468d9e2 100644
--- a/python/tvm/relay/op/contrib/tensorrt.py
+++ b/python/tvm/relay/op/contrib/tensorrt.py
@@ -341,7 +341,10 @@ def batch_norm_annotate_fn(expr):  # pylint: disable=unused-variable
     if any([x.checked_type.dtype != "float32" for x in args]):
         logger.info("Only float32 inputs are supported for TensorRT.")
         return False
-    if int(attrs.axis) not in (1, 3):
-        logger.info("nn.batch_norm: axis is %d but must be 1 or 3.", int(attrs.axis))
-        return False
+    if len(args[0].checked_type.shape) == 5 and get_tensorrt_version() < (6, 0, 1):
+        logger.info("nn.batch_norm: TensorRT 6.0.1 or higher is required for rank 5 inputs.")
+        return False
+    if len(args[0].checked_type.shape) > 5:
+        logger.info("nn.batch_norm: Input rank must be 5 or less.")
+        return False
     return True
diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc
index c3ff1c45f50e..1e6867b83cff 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc
@@ -386,8 +386,35 @@ class BatchNormOpConverter : public TensorRTOpConverter {
     const int axis = std::stoi(params->node.GetAttr<std::vector<std::string>>("axis")[0]);
     const bool scale = std::stoi(params->node.GetAttr<std::vector<std::string>>("scale")[0]);
     const bool center = std::stoi(params->node.GetAttr<std::vector<std::string>>("center")[0]);
-    ICHECK(axis == 1 || axis == 3);
-    const bool need_transpose = axis == 3;
+    auto input_dims = TrtDimsToVector(input->getDimensions());
+    const size_t min_rank = TRT_HAS_IMPLICIT_BATCH(params) ? 3 : 4;
+    const size_t max_rank = TRT_HAS_IMPLICIT_BATCH(params) ? 4 : 5;
+    ICHECK_LE(input_dims.size(), max_rank);
+    const bool need_reshape = input_dims.size() < min_rank;
+    const bool need_transpose = axis != 1;
+
+    // Reshape if needed.
+    if (need_reshape) {
+      // Add dims of size 1 until rank is min_rank.
+      std::vector<int> new_shape(input_dims);
+      while (new_shape.size() < min_rank) new_shape.insert(new_shape.end(), 1);
+      input = Reshape(params, input, new_shape);
+    }
+
+    // Transpose if needed.
+    const int input_rank_with_batch =
+        input->getDimensions().nbDims + (TRT_HAS_IMPLICIT_BATCH(params) ? 1 : 0);
+    ICHECK(input_rank_with_batch == 4 || input_rank_with_batch == 5);
+    std::vector<int> transpose_order(input_rank_with_batch);
+    if (need_transpose) {
+      // Move axis dim to first dim after batch.
+      for (int i = 0; i < input_rank_with_batch; ++i) {
+        transpose_order[i] = i;
+      }
+      transpose_order[1] = axis;
+      transpose_order[axis] = 1;
+      input = Transpose(params, input, transpose_order);
+    }
 
     void* weight_scale_ptr = new float[gamma.count];
     nvinfer1::Weights weight_scale{nvinfer1::DataType::kFLOAT, weight_scale_ptr, gamma.count};
@@ -414,15 +441,23 @@ class BatchNormOpConverter : public TensorRTOpConverter {
         shift_ptr[i] += beta_ptr[i];
       }
     }
-    if (need_transpose) {
-      input = Transpose(params, input, {0, 3, 1, 2});
-    }
+
+#if TRT_VERSION_GE(6, 0, 1)
+    const int channel_dim = TRT_HAS_IMPLICIT_BATCH(params) ? 0 : 1;
+    nvinfer1::IScaleLayer* scale_layer = params->network->addScaleNd(
+        *input, nvinfer1::ScaleMode::kCHANNEL, weight_shift, weight_scale, power, channel_dim);
+#else
+    ICHECK_EQ(input->getDimensions().nbDims, 3);
     nvinfer1::IScaleLayer* scale_layer = params->network->addScale(
         *input, nvinfer1::ScaleMode::kCHANNEL, weight_shift, weight_scale, power);
+#endif
     ICHECK(scale_layer != nullptr);
     auto output = scale_layer->getOutput(0);
     if (need_transpose) {
-      output = Transpose(params, output, {0, 2, 3, 1});
+      output = Transpose(params, output, transpose_order);
+    }
+    if (need_reshape) {
+      output = Reshape(params, output, input_dims);
     }
     params->outputs.push_back(output);
   }
diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py
index aadfa1303655..9b62ee2c4087 100644
--- a/tests/python/contrib/test_tensorrt.py
+++ b/tests/python/contrib/test_tensorrt.py
@@ -710,6 +710,12 @@ def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5):
 
     run_and_verify_func(get_graph((1, 64, 56, 56), (64,)))
     run_and_verify_func(get_graph((1, 56, 56, 64), (64,), axis=3, epsilon=1.001e-05))
+    run_and_verify_func(get_graph((1, 4, 8, 4), (8,), axis=2))
+    run_and_verify_func(get_graph((1, 8, 4, 4, 4), (8,), axis=1))
+    run_and_verify_func(get_graph((1, 4, 8, 4, 4), (8,), axis=2))
+    run_and_verify_func(get_graph((1, 4, 4, 4, 8), (8,), axis=4))
+    run_and_verify_func(get_graph((1, 8), (8,), axis=1))
+    run_and_verify_func(get_graph((1, 3, 8), (8,), axis=2))
 
 
 def test_unary():
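
For reference, the new coverage can also be exercised outside of `run_and_verify_func`. The snippet below is a minimal sketch, not part of the PR: it builds one of the newly supported graphs (rank 5, `axis=2`) and partitions it for TensorRT. It assumes the `partition_for_tensorrt` API of this era, which returns the partitioned module plus a pass-context config dict; actually compiling and running the result additionally requires a TVM build with TensorRT codegen and runtime enabled.

```python
import tvm
from tvm import relay
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt


def get_graph(x_shape=(1, 4, 8, 4, 4), param_shape=(8,), axis=2, epsilon=1e-5):
    # Mirrors the get_graph helper in tests/python/contrib/test_tensorrt.py,
    # using one of the shape/axis combinations added by this PR.
    x = relay.var("x", shape=x_shape, dtype="float32")
    beta = relay.var("beta", shape=param_shape, dtype="float32")
    gamma = relay.var("gamma", shape=param_shape, dtype="float32")
    moving_mean = relay.var("moving_mean", shape=param_shape, dtype="float32")
    moving_var = relay.var("moving_var", shape=param_shape, dtype="float32")
    out = relay.nn.batch_norm(
        x, gamma, beta, moving_mean, moving_var, axis=axis, epsilon=epsilon
    )[0]
    return relay.Function([x, beta, gamma, moving_mean, moving_var], out)


mod = tvm.IRModule.from_expr(get_graph())
# With this change, batch_norm_annotate_fn accepts any axis and any rank <= 5
# (rank 5 requires TensorRT >= 6.0.1). Previously, any axis outside (1, 3)
# was rejected and the op fell back to TVM.
mod, config = partition_for_tensorrt(mod)
print(mod)  # nn.batch_norm should now appear inside a tensorrt-partitioned function
```

On the converter side, the same trick covers every case: low-rank inputs are padded with size-1 trailing dims up to the minimum rank the scale layer accepts, a non-channel `axis` is swapped with dim 1 (the swap is its own inverse, so `transpose_order` is reused to transpose back), and the original shape is restored at the end.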