[BYOC][TRT] Support batch norm for all ranks <=5, and all axes #7026

Merged: 3 commits, merged on Dec 8, 2020. The diff below shows changes from all commits.
6 changes: 6 additions & 0 deletions python/tvm/relay/op/contrib/tensorrt.py
@@ -341,6 +341,12 @@ def batch_norm_annotate_fn(expr):  # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if len(args[0].checked_type.shape) == 5 and get_tensorrt_version() < (6, 0, 1):
logger.info("nn.batch_norm: TensorRT 6.0.1 or higher is required for rank 5 inputs.")
return False
if len(args[0].checked_type.shape) > 5:
logger.info("nn.batch_norm: Input rank must be 5 or less.")
return False
if int(attrs.axis) not in (1, 3):
logger.info("nn.batch_norm: axis is %d but must be 1 or 3.", int(attrs.axis))
return False
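
Taken together, the two checks added in this hunk reduce to a small predicate over the input rank and the linked TensorRT version. A standalone sketch for readers skimming the diff (illustrative names only, not TVM API; the axis check from the surrounding context is omitted):

def batch_norm_rank_supported(input_rank, trt_version):
    # Mirrors the two checks added in the hunk above.
    if input_rank == 5 and trt_version < (6, 0, 1):
        return False  # rank-5 inputs need addScaleNd, introduced in TensorRT 6.0.1
    if input_rank > 5:
        return False  # TensorRT scale layers handle at most rank-5 inputs
    return True

assert batch_norm_rank_supported(4, (5, 1, 5))
assert not batch_norm_rank_supported(5, (5, 1, 5))
assert batch_norm_rank_supported(5, (6, 0, 1))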
47 changes: 41 additions & 6 deletions src/runtime/contrib/tensorrt/tensorrt_ops.cc
@@ -386,8 +386,35 @@ class BatchNormOpConverter : public TensorRTOpConverter {
const int axis = std::stoi(params->node.GetAttr<std::vector<std::string>>("axis")[0]);
const bool scale = std::stoi(params->node.GetAttr<std::vector<std::string>>("scale")[0]);
const bool center = std::stoi(params->node.GetAttr<std::vector<std::string>>("center")[0]);
ICHECK(axis == 1 || axis == 3);
const bool need_transpose = axis == 3;
auto input_dims = TrtDimsToVector(input->getDimensions());
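// In implicit-batch mode, TensorRT tensors exclude the batch dim, so the
// rank bounds for the scale layer shift down by one relative to explicit batch.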
const size_t min_rank = TRT_HAS_IMPLICIT_BATCH(params) ? 3 : 4;
const size_t max_rank = TRT_HAS_IMPLICIT_BATCH(params) ? 4 : 5;
ICHECK_LE(input_dims.size(), max_rank);
jroesch (Member) commented on the checks above:

Could you convert these checks to use Diagnostic instead of generating an assertion? We should strive to replace most of these with end-user-readable errors.

trevor-m (Contributor, Author) replied on Dec 7, 2020:

Hi @jroesch, thanks for reviewing! These checks are more for sanity checking, since the annotation functions in python/tvm/relay/op/contrib/tensorrt.py will filter out the unsupported ops before they ever get to this code. I don't expect users to ever see these. Anyway, I can make a separate PR to port all of the ICHECKs to Diagnostics.

const bool need_reshape = input_dims.size() < min_rank;
const bool need_transpose = axis != 1;

// Reshape if needed
if (need_reshape) {
// Add dims of size 1 until rank is min_rank.
std::vector<int> new_shape(input_dims);
while (new_shape.size() < min_rank) new_shape.insert(new_shape.end(), 1);
input = Reshape(params, input, new_shape);
}

// Transpose if needed.
const int input_rank_with_batch =
input->getDimensions().nbDims + (TRT_HAS_IMPLICIT_BATCH(params) ? 1 : 0);
ICHECK(input_rank_with_batch == 4 || input_rank_with_batch == 5);
std::vector<int> transpose_order(input_rank_with_batch);
if (need_transpose) {
// Move axis dim to first dim after batch.
for (int i = 0; i < input_rank_with_batch; ++i) {
transpose_order[i] = i;
}
transpose_order[1] = axis;
transpose_order[axis] = 1;
input = Transpose(params, input, transpose_order);
}

void* weight_scale_ptr = new float[gamma.count];
nvinfer1::Weights weight_scale{nvinfer1::DataType::kFLOAT, weight_scale_ptr, gamma.count};
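
The reshape-then-transpose strategy above is easy to sanity-check outside TensorRT: moving the normalized axis into the channel position, applying a per-channel affine, and then inverting the (self-inverse) swap reproduces batch norm over the original axis. A NumPy sketch of the idea, written in the explicit-batch view where dim 0 is the batch (illustrative only, independent of the converter code):

import numpy as np

def batch_norm_ref(x, gamma, beta, mean, var, axis, eps=1e-5):
    # Reference: broadcast the per-channel statistics along `axis`.
    shape = [1] * x.ndim
    shape[axis] = -1
    g, b, m, v = (a.reshape(shape) for a in (gamma, beta, mean, var))
    return (x - m) / np.sqrt(v + eps) * g + b

def batch_norm_via_swap(x, gamma, beta, mean, var, axis, eps=1e-5):
    # Mimic the converter: swap `axis` with dim 1, normalize, swap back.
    order = list(range(x.ndim))
    order[1], order[axis] = order[axis], order[1]
    y = batch_norm_ref(x.transpose(order), gamma, beta, mean, var, axis=1, eps=eps)
    return y.transpose(order)  # the swap permutation is its own inverse

x = np.random.rand(1, 4, 4, 4, 8).astype("float32")
stats = [np.random.rand(8).astype("float32") for _ in range(4)]
np.testing.assert_allclose(
    batch_norm_via_swap(x, *stats, axis=4), batch_norm_ref(x, *stats, axis=4), rtol=1e-4
)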
@@ -414,15 +441,23 @@ class BatchNormOpConverter : public TensorRTOpConverter {
shift_ptr[i] += beta_ptr[i];
}
}
if (need_transpose) {
input = Transpose(params, input, {0, 3, 1, 2});
}

#if TRT_VERSION_GE(6, 0, 1)
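// addScaleNd generalizes addScale to an explicit channel axis; it only
// exists in TensorRT 6.0.1+, hence the version guard.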
const int channel_dim = TRT_HAS_IMPLICIT_BATCH(params) ? 0 : 1;
nvinfer1::IScaleLayer* scale_layer = params->network->addScaleNd(
*input, nvinfer1::ScaleMode::kCHANNEL, weight_shift, weight_scale, power, channel_dim);
#else
ICHECK_EQ(input->getDimensions().nbDims, 3);
nvinfer1::IScaleLayer* scale_layer = params->network->addScale(
*input, nvinfer1::ScaleMode::kCHANNEL, weight_shift, weight_scale, power);
#endif
ICHECK(scale_layer != nullptr);
auto output = scale_layer->getOutput(0);
if (need_transpose) {
output = Transpose(params, output, {0, 2, 3, 1});
output = Transpose(params, output, transpose_order);
}
if (need_reshape) {
output = Reshape(params, output, input_dims);
}
params->outputs.push_back(output);
}
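
Everything above works because inference-time batch norm folds into a single per-channel affine, which is exactly what addScale/addScaleNd computes: out = x * scale + shift, with scale = gamma / sqrt(var + eps) and shift = beta - mean * scale (the weight_scale/weight_shift buffers above hold these folded values). A quick NumPy check of that identity (illustrative, not converter code):

import numpy as np

eps = 1e-5
x = np.random.rand(1, 8, 4, 4).astype("float32")
gamma, beta, mean, var = (np.random.rand(8).astype("float32") for _ in range(4))

scale = gamma / np.sqrt(var + eps)  # folded multiplier
shift = beta - mean * scale         # folded bias
c = (1, -1, 1, 1)                   # broadcast along channel dim 1
np.testing.assert_allclose(
    x * scale.reshape(c) + shift.reshape(c),
    (x - mean.reshape(c)) / np.sqrt(var.reshape(c) + eps) * gamma.reshape(c)
    + beta.reshape(c),
    rtol=1e-4,
)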
6 changes: 6 additions & 0 deletions tests/python/contrib/test_tensorrt.py
@@ -710,6 +710,12 @@ def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5):

run_and_verify_func(get_graph((1, 64, 56, 56), (64,)))
run_and_verify_func(get_graph((1, 56, 56, 64), (64,), axis=3, epsilon=1.001e-05))
run_and_verify_func(get_graph((1, 4, 8, 4), (8,), axis=2))
run_and_verify_func(get_graph((1, 8, 4, 4, 4), (8,), axis=1))
run_and_verify_func(get_graph((1, 4, 8, 4, 4), (8,), axis=2))
run_and_verify_func(get_graph((1, 4, 4, 4, 8), (8,), axis=4))
run_and_verify_func(get_graph((1, 8), (8,), axis=1))
run_and_verify_func(get_graph((1, 3, 8), (8,), axis=2))
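
The get_graph helper is defined just above this hunk, and its body is elided by the diff. For context, a plausible reconstruction, assuming it builds a bare relay.nn.batch_norm and returns the function plus shape/param metadata in the form run_and_verify_func expects (the real helper's return convention may differ):

from tvm import relay

def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5):
    # Hypothetical reconstruction; names mirror the call sites above.
    x = relay.var("x", shape=x_shape, dtype="float32")
    gamma = relay.var("gamma", shape=param_shape, dtype="float32")
    beta = relay.var("beta", shape=param_shape, dtype="float32")
    mean = relay.var("moving_mean", shape=param_shape, dtype="float32")
    var = relay.var("moving_var", shape=param_shape, dtype="float32")
    # nn.batch_norm returns (out, new_mean, new_var); only `out` is needed.
    out = relay.nn.batch_norm(x, gamma, beta, mean, var, axis=axis, epsilon=epsilon)[0]
    f = relay.Function([x, gamma, beta, mean, var], out)
    shapes = {"x": x_shape, "gamma": param_shape, "beta": param_shape,
              "moving_mean": param_shape, "moving_var": param_shape}
    return f, shapes, ["gamma", "beta", "moving_mean", "moving_var"]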


def test_unary():