[BYOC][TRT] Support batch norm for all ranks <=5, and all axes (#7026)
* [TRT] Support batch norm for all ranks <=5, and all axes

* Add return false

* Fix TRT < 6 build
Trevor Morris authored Dec 8, 2020
1 parent 8ac40fa commit 0095b21
Showing 3 changed files with 53 additions and 6 deletions.
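For background on what the converter is doing (this is context, not part of the diff): at inference time, batch norm folds into a single per-channel scale and shift, which is exactly what a TensorRT scale layer applies, and why the `weight_scale`/`weight_shift` buffers in the C++ diff below suffice. A minimal NumPy sketch of the fold (`fold_batch_norm` is our name, not TVM's):

```python
import numpy as np

def fold_batch_norm(gamma, beta, mean, var, eps=1e-5):
    # y = (x - mean) / sqrt(var + eps) * gamma + beta  ==  x * scale + shift
    scale = gamma / np.sqrt(var + eps)
    shift = beta - mean * scale
    return scale, shift

rng = np.random.default_rng(0)
c = 8
x = rng.standard_normal((1, c, 4, 4)).astype("float32")
gamma = rng.standard_normal(c).astype("float32")
beta = rng.standard_normal(c).astype("float32")
mean = rng.standard_normal(c).astype("float32")
var = rng.random(c).astype("float32")  # variance must be non-negative

scale, shift = fold_batch_norm(gamma, beta, mean, var)
bc = (1, c, 1, 1)  # broadcast over N, H, W; channel axis is 1
ref = (x - mean.reshape(bc)) / np.sqrt(var.reshape(bc) + 1e-5) * gamma.reshape(bc) + beta.reshape(bc)
assert np.allclose(x * scale.reshape(bc) + shift.reshape(bc), ref, atol=1e-5)
```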
6 changes: 6 additions & 0 deletions python/tvm/relay/op/contrib/tensorrt.py
@@ -341,6 +341,12 @@ def batch_norm_annotate_fn(expr):  # pylint: disable=unused-variable
     if any([x.checked_type.dtype != "float32" for x in args]):
         logger.info("Only float32 inputs are supported for TensorRT.")
         return False
+    if len(args[0].checked_type.shape) == 5 and get_tensorrt_version() < (6, 0, 1):
+        logger.info("nn.batch_norm: TensorRT 6.0.1 or higher is required for rank 5 inputs.")
+        return False
+    if len(args[0].checked_type.shape) > 5:
+        logger.info("nn.batch_norm: Input rank must be 5 or less.")
+        return False
     if int(attrs.axis) not in (1, 3):
         logger.info("nn.batch_norm: axis is %d but must be 1 or 3.", int(attrs.axis))
         return False
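These guards compose into a simple predicate. A standalone mirror for illustration (`batch_norm_offloadable` and its arguments are our names; the real check reads `expr.attrs` and `expr.args`):

```python
def batch_norm_offloadable(input_shape, axis, trt_version):
    """Mirrors the shape/version gating in batch_norm_annotate_fn above."""
    if len(input_shape) == 5 and trt_version < (6, 0, 1):
        return False  # rank-5 scale needs addScaleNd, available from TensorRT 6.0.1
    if len(input_shape) > 5:
        return False  # the converter handles rank <= 5 only
    if axis not in (1, 3):
        return False  # the annotator still restricts axis, per the context lines above
    return True

assert batch_norm_offloadable((1, 8, 4, 4, 4), axis=1, trt_version=(6, 0, 1))
assert not batch_norm_offloadable((1, 8, 4, 4, 4), axis=1, trt_version=(5, 1, 5))
assert not batch_norm_offloadable((1, 2, 3, 4, 5, 6), axis=1, trt_version=(7, 0, 0))
```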
47 changes: 41 additions & 6 deletions src/runtime/contrib/tensorrt/tensorrt_ops.cc
@@ -386,8 +386,35 @@ class BatchNormOpConverter : public TensorRTOpConverter {
     const int axis = std::stoi(params->node.GetAttr<std::vector<std::string>>("axis")[0]);
     const bool scale = std::stoi(params->node.GetAttr<std::vector<std::string>>("scale")[0]);
     const bool center = std::stoi(params->node.GetAttr<std::vector<std::string>>("center")[0]);
-    ICHECK(axis == 1 || axis == 3);
-    const bool need_transpose = axis == 3;
+    auto input_dims = TrtDimsToVector(input->getDimensions());
+    const size_t min_rank = TRT_HAS_IMPLICIT_BATCH(params) ? 3 : 4;
+    const size_t max_rank = TRT_HAS_IMPLICIT_BATCH(params) ? 4 : 5;
+    ICHECK_LE(input_dims.size(), max_rank);
+    const bool need_reshape = input_dims.size() < min_rank;
+    const bool need_transpose = axis != 1;
+
+    // Reshape if needed
+    if (need_reshape) {
+      // Add dims of size 1 until rank is min_rank.
+      std::vector<int> new_shape(input_dims);
+      while (new_shape.size() < min_rank) new_shape.insert(new_shape.end(), 1);
+      input = Reshape(params, input, new_shape);
+    }
+
+    // Transpose if needed.
+    const int input_rank_with_batch =
+        input->getDimensions().nbDims + (TRT_HAS_IMPLICIT_BATCH(params) ? 1 : 0);
+    ICHECK(input_rank_with_batch == 4 || input_rank_with_batch == 5);
+    std::vector<int> transpose_order(input_rank_with_batch);
+    if (need_transpose) {
+      // Move axis dim to first dim after batch.
+      for (int i = 0; i < input_rank_with_batch; ++i) {
+        transpose_order[i] = i;
+      }
+      transpose_order[1] = axis;
+      transpose_order[axis] = 1;
+      input = Transpose(params, input, transpose_order);
+    }
 
     void* weight_scale_ptr = new float[gamma.count];
     nvinfer1::Weights weight_scale{nvinfer1::DataType::kFLOAT, weight_scale_ptr, gamma.count};
@@ -414,15 +441,23 @@ class BatchNormOpConverter : public TensorRTOpConverter {
         shift_ptr[i] += beta_ptr[i];
       }
     }
-    if (need_transpose) {
-      input = Transpose(params, input, {0, 3, 1, 2});
-    }
+
+#if TRT_VERSION_GE(6, 0, 1)
+    const int channel_dim = TRT_HAS_IMPLICIT_BATCH(params) ? 0 : 1;
+    nvinfer1::IScaleLayer* scale_layer = params->network->addScaleNd(
+        *input, nvinfer1::ScaleMode::kCHANNEL, weight_shift, weight_scale, power, channel_dim);
+#else
+    ICHECK_EQ(input->getDimensions().nbDims, 3);
     nvinfer1::IScaleLayer* scale_layer = params->network->addScale(
         *input, nvinfer1::ScaleMode::kCHANNEL, weight_shift, weight_scale, power);
+#endif
     ICHECK(scale_layer != nullptr);
     auto output = scale_layer->getOutput(0);
     if (need_transpose) {
-      output = Transpose(params, output, {0, 2, 3, 1});
+      output = Transpose(params, output, transpose_order);
     }
+    if (need_reshape) {
+      output = Reshape(params, output, input_dims);
+    }
     params->outputs.push_back(output);
   }
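The converter's strategy, as a NumPy analogue (ours, not TVM code): pad the rank with trailing size-1 dims, swap the channel axis into position 1 (a swap is its own inverse, which is why `transpose_order` can be reused to undo it), apply the folded per-channel scale/shift, then undo both steps:

```python
import numpy as np

def batch_norm_any_axis(x, scale, shift, axis, min_rank=4):
    orig_shape = x.shape
    # Pad with trailing size-1 dims until the minimum rank (mirrors Reshape).
    while x.ndim < min_rank:
        x = x.reshape(x.shape + (1,))
    # Swap the channel axis into position 1 (mirrors Transpose).
    order = list(range(x.ndim))
    order[1], order[axis] = order[axis], order[1]
    x = x.transpose(order)
    # Per-channel scale/shift, as the TensorRT scale layer applies on dim 1.
    bc = (1, -1) + (1,) * (x.ndim - 2)
    y = x * scale.reshape(bc) + shift.reshape(bc)
    # Undo the transpose (self-inverse), then restore the original shape.
    return y.transpose(order).reshape(orig_shape)

x = np.arange(24, dtype="float32").reshape(1, 3, 8)  # rank 3, channels on axis 2
s = np.linspace(0.5, 2.0, 8).astype("float32")
t = np.zeros(8, dtype="float32")
out = batch_norm_any_axis(x, s, t, axis=2)
assert np.allclose(out, x * s)  # with zero shift, this is broadcasting over the last axis
```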
6 changes: 6 additions & 0 deletions tests/python/contrib/test_tensorrt.py
@@ -710,6 +710,12 @@ def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5):
 
     run_and_verify_func(get_graph((1, 64, 56, 56), (64,)))
     run_and_verify_func(get_graph((1, 56, 56, 64), (64,), axis=3, epsilon=1.001e-05))
+    run_and_verify_func(get_graph((1, 4, 8, 4), (8,), axis=2))
+    run_and_verify_func(get_graph((1, 8, 4, 4, 4), (8,), axis=1))
+    run_and_verify_func(get_graph((1, 4, 8, 4, 4), (8,), axis=2))
+    run_and_verify_func(get_graph((1, 4, 4, 4, 8), (8,), axis=4))
+    run_and_verify_func(get_graph((1, 8), (8,), axis=1))
+    run_and_verify_func(get_graph((1, 3, 8), (8,), axis=2))
 
 
 def test_unary():
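The new cases cover ranks 2 through 5 and non-default axes; note the rank-2 case `(1, 8)`, which exercises the new reshape path by padding up to `(1, 8, 1, 1)` before the scale layer. The `get_graph` helper is defined just above this hunk and is not shown in the diff; a plausible sketch of it, assuming the standard Relay API (the exact body in test_tensorrt.py may differ):

```python
from tvm import relay

def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5):
    # nn.batch_norm over `axis`; all four parameters share `param_shape`.
    x = relay.var("x", shape=x_shape, dtype="float32")
    gamma = relay.var("gamma", shape=param_shape, dtype="float32")
    beta = relay.var("beta", shape=param_shape, dtype="float32")
    moving_mean = relay.var("moving_mean", shape=param_shape, dtype="float32")
    moving_var = relay.var("moving_var", shape=param_shape, dtype="float32")
    # batch_norm returns (out, new_mean, new_var); only the output is used.
    out = relay.nn.batch_norm(
        x, gamma, beta, moving_mean, moving_var, axis=axis, epsilon=epsilon
    )[0]
    f = relay.Function([x, gamma, beta, moving_mean, moving_var], out)
    input_shapes = {
        "x": x_shape, "gamma": param_shape, "beta": param_shape,
        "moving_mean": param_shape, "moving_var": param_shape,
    }
    return f, input_shapes, ["gamma", "beta", "moving_mean", "moving_var"]
```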
