Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
zhoutianzi666 authored and Aurelius84 committed Jul 29, 2022
1 parent d2f3904 commit 363458b
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 18 deletions.
100 changes: 84 additions & 16 deletions paddle/fluid/inference/tensorrt/convert/fc_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,23 +333,91 @@ class FcOpConverter : public OpConverter {
if (!engine_->with_dynamic_shape()) {
x_num_col_dims--;
}
PADDLE_ENFORCE_GT(
x_dim.nbDims,
x_num_col_dims,
platform::errors::InvalidArgument(
"Params and input dims mismatch. Paddle-TRT FC "
"converter expects x_dim.nbDims > x_num_col_dims, but "
"x_dim.nbDims : %d, x_num_col_dims : %d.",
x_dim.nbDims,
x_num_col_dims));
// need reshape input before and after fc
auto* reshape_before_fc_layer =
reshape_before_fc(X, x_dim, x_num_col_dims, output_name);
auto* reshape_itensor = reshape_before_fc_layer->getOutput(0);
if (enable_int8 || support_int8) {
engine_->SetTensorDynamicRange(reshape_itensor, in_scale);
// If use tensorrt'oss, the x_dim and x_num_col_dims need change, and can
// not add Shuffle layer in ernie's multihead.
if (x_dim.nbDims == 4 && x_num_col_dims == 1) {
if (enable_int8 || support_int8) {
// add conv1x1 layer
nvinfer1::DimsHW nv_ksize(1, 1);
auto* fc_layer_int8 = TRT_ENGINE_ADD_LAYER(engine_,
Convolution,
*X,
n_output,
nv_ksize,
weight.get(),
bias.get());
if (activation_type == "relu") {
fc_layer_int8->setName(
("ernie_fc_op_int8: Convolution (Output: " + output_name + ")")
.c_str());
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("out_threshold"),
true,
platform::errors::InvalidArgument(
"must have out threshold in fc layers in int8 mode"));
float out_scale = 0;
if (enable_int8) {
out_scale =
PADDLE_GET_CONST(float, op_desc.GetAttr("out_threshold"));
} else {
out_scale = PADDLE_GET_CONST(float, op_desc.GetAttr("Out"));
}
engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0),
out_scale);
nvinfer1::IActivationLayer* relu_layer_int8 =
TRT_ENGINE_ADD_LAYER(engine_,
Activation,
*(fc_layer_int8->getOutput(0)),
nvinfer1::ActivationType::kRELU);
RreplenishLayerAndOutput(relu_layer_int8,
"relu_after_ernie_fc_int8",
{output_name},
test_mode);
} else {
RreplenishLayerAndOutput(fc_layer_int8,
"ernie_fc_op_int8: Convolution",
{output_name},
test_mode);
}
} else {
// add fc layer
auto* fc_layer_float = TRT_ENGINE_ADD_LAYER(
engine_, FullyConnected, *X, n_output, weight.get(), bias.get());
if (activation_type == "relu") {
fc_layer_float->setName(
("ernie_fc_op_float: (Output: " + output_name + ")").c_str());
nvinfer1::IActivationLayer* relu_layer_float =
TRT_ENGINE_ADD_LAYER(engine_,
Activation,
*(fc_layer_float->getOutput(0)),
nvinfer1::ActivationType::kRELU);
RreplenishLayerAndOutput(relu_layer_float,
"relu_after_ernie_fc_float",
{output_name},
test_mode);
} else {
RreplenishLayerAndOutput(
fc_layer_float, "ernie_fc_op_float", {output_name}, test_mode);
}
}
} else { // need reshape input before and after fc
PADDLE_ENFORCE_GT(
x_dim.nbDims,
x_num_col_dims,
platform::errors::InvalidArgument(
"Params and input dims mismatch. Paddle-TRT FC "
"converter expects x_dim.nbDims > x_num_col_dims, but "
"x_dim.nbDims : %d, x_num_col_dims : %d.",
x_dim.nbDims,
x_num_col_dims));
auto* reshape_before_fc_layer =
reshape_before_fc(X, x_dim, x_num_col_dims, output_name);
auto* reshape_itensor = reshape_before_fc_layer->getOutput(0);
if (enable_int8 || support_int8) {
engine_->SetTensorDynamicRange(reshape_itensor, in_scale);
}
regist_fc(reshape_itensor, n_output, weight, bias);
}
regist_fc(reshape_itensor, n_output, weight, bias);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,9 @@ def test(self):

# This is the special case when x_dim.nbDims == 4 and x_num_col_dims == 1.
class TrtConvertFcTest3(TrtLayerAutoScanTest):

# This case triggers a known bug in fc_op.cc, so report the program as invalid (return False).
def is_program_valid(self, program_config: ProgramConfig) -> bool:
    # Always reject the generated program: this configuration
    # (x_dim.nbDims == 4 with x_num_col_dims == 1) hits a known bug in
    # fc_op.cc, so the test must not execute it.  The unreachable
    # `return True` left over from the previous revision is removed.
    return False

def sample_program_configs(self):
self.trt_param.workspace_size = 1073741824
Expand Down

0 comments on commit 363458b

Please sign in to comment.