@@ -85,20 +85,79 @@ TEST(TensorRTEngineInstructionTest, test_tensorrt_engine_instruction) {
       nvinfer1::DataType::kFLOAT, raw_bias, size);
   auto *x = engine->DeclareInput(
       "x", nvinfer1::DataType::kFLOAT, nvinfer1::Dims4{-1, 1, 1, 1});
-  auto *fc_layer = TRT_ENGINE_ADD_LAYER(
-      engine, FullyConnected, *x, size, weight.get(), bias.get());
-  PADDLE_ENFORCE_NOT_NULL(fc_layer,
-                          common::errors::InvalidArgument(
-                              "TRT fully connected layer building failed."));
+  auto *flatten_layer = engine->network()->addShuffle(*x);
+  PADDLE_ENFORCE_NOT_NULL(
+      flatten_layer,
+      common::errors::InvalidArgument(
+          "Unable to build the TensorRT shuffle layer for the input tensor "
+          "'x'. "
+          "This usually indicates the TensorRT network failed to allocate the "
+          "intermediate reshape layer."));
+  flatten_layer->setReshapeDimensions(nvinfer1::Dims2{-1, 1});
+
+  auto *weight_layer = TRT_ENGINE_ADD_LAYER(
+      engine, Constant, nvinfer1::Dims2{1, 1}, weight.get());
+  PADDLE_ENFORCE_NOT_NULL(
+      weight_layer,
+      common::errors::InvalidArgument("TensorRT failed to create the constant "
+                                      "layer for parameter 'weight'. "
+                                      "Please confirm the TensorRT builder "
+                                      "supports constant initialisation "
+                                      "for the provided weight shape."));
+
+  auto *bias_layer =
+      TRT_ENGINE_ADD_LAYER(engine, Constant, nvinfer1::Dims2{1, 1}, bias.get());
+  PADDLE_ENFORCE_NOT_NULL(
+      bias_layer,
+      common::errors::InvalidArgument(
+          "TensorRT failed to create the constant layer for parameter 'bias'. "
+          "Check whether the provided bias data matches the expected shape."));
+
+  auto *matmul_layer = TRT_ENGINE_ADD_LAYER(engine,
+                                            MatrixMultiply,
+                                            *flatten_layer->getOutput(0),
+                                            nvinfer1::MatrixOperation::kNONE,
+                                            *weight_layer->getOutput(0),
+                                            nvinfer1::MatrixOperation::kNONE);
+  PADDLE_ENFORCE_NOT_NULL(
+      matmul_layer,
+      common::errors::InvalidArgument(
+          "TensorRT returned a null matrix-multiply layer while fusing the "
+          "fully-connected op. Verify the network input ranks and TensorRT "
+          "version."));
+
+  auto *add_layer = TRT_ENGINE_ADD_LAYER(engine,
+                                         ElementWise,
+                                         *matmul_layer->getOutput(0),
+                                         *bias_layer->getOutput(0),
+                                         nvinfer1::ElementWiseOperation::kSUM);
+  PADDLE_ENFORCE_NOT_NULL(
+      add_layer,
+      common::errors::InvalidArgument(
+          "TensorRT could not construct the elementwise-add layer for bias "
+          "fusion. Ensure the bias tensor uses broadcastable dimensions."));
 
-  engine->DeclareOutput(fc_layer, 0, "y");
+  auto *reshape_layer = engine->network()->addShuffle(*add_layer->getOutput(0));
+  PADDLE_ENFORCE_NOT_NULL(
+      reshape_layer,
+      common::errors::InvalidArgument(
+          "TensorRT could not emit the final shuffle layer to restore the "
+          "output shape. Confirm the shape tensor and inferred dimensions are "
+          "valid."));
+  reshape_layer->setReshapeDimensions(nvinfer1::Dims4{-1, 1, 1, 1});
+
+  engine->DeclareOutput(reshape_layer, 0, "y");
   std::vector<std::string> input_names = {"x", ""};
   std::vector<std::string> output_names = {"y"};
   std::vector<std::vector<int64_t>> outputs_shape = {{1}};
   std::vector<phi::DataType> outputs_dtype = {phi::DataType::FLOAT32};
   LOG(INFO) << "freeze network";
   engine->FreezeNetwork();
+#if IS_TRT_VERSION_GE(8600)
+  ASSERT_EQ(engine->engine()->getNbIOTensors(), 2);
+#else
   ASSERT_EQ(engine->engine()->getNbBindings(), 2);
+#endif
   nvinfer1::IHostMemory *serialized_engine_data = engine->Serialize();
 
   std::ofstream outFile("engine_serialized_data.bin", std::ios::binary);
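
Note on the change above: the old `FullyConnected` layer is rebuilt from shuffle, constant, matrix-multiply and elementwise layers. As a reference only, a minimal sketch of the same decomposition written directly against `nvinfer1::INetworkDefinition`; the `network`, `input`, `w`, and `b` names are illustrative assumptions, not identifiers from this patch:

```cpp
#include <NvInfer.h>

// Sketch, not part of the patch: the fully-connected op expressed with the raw
// TensorRT layer API. The 1x1 weight/bias shapes mirror the toy test network.
nvinfer1::ITensor *AddFcAsMatMul(nvinfer1::INetworkDefinition *network,
                                 nvinfer1::ITensor *input,
                                 nvinfer1::Weights w,
                                 nvinfer1::Weights b) {
  // Flatten [N, 1, 1, 1] into [N, 1] so it can feed a 2-D matmul.
  auto *flatten = network->addShuffle(*input);
  flatten->setReshapeDimensions(nvinfer1::Dims2{-1, 1});

  // Materialize the weight and bias as constant tensors.
  auto *weight = network->addConstant(nvinfer1::Dims2{1, 1}, w);
  auto *bias = network->addConstant(nvinfer1::Dims2{1, 1}, b);

  // y = x * W + b
  auto *matmul = network->addMatrixMultiply(*flatten->getOutput(0),
                                            nvinfer1::MatrixOperation::kNONE,
                                            *weight->getOutput(0),
                                            nvinfer1::MatrixOperation::kNONE);
  auto *add = network->addElementWise(*matmul->getOutput(0),
                                      *bias->getOutput(0),
                                      nvinfer1::ElementWiseOperation::kSUM);

  // Restore the [N, 1, 1, 1] layout the test declares for the output.
  auto *reshape = network->addShuffle(*add->getOutput(0));
  reshape->setReshapeDimensions(nvinfer1::Dims4{-1, 1, 1, 1});
  return reshape->getOutput(0);
}
```
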
@@ -220,7 +279,10 @@ TEST(TensorRTEngineInstructionTest, test_tensorrt_engine_instruction_dynamic) {
   layer->setInput(1, *shape);
   PADDLE_ENFORCE_NOT_NULL(
       layer,
-      common::errors::InvalidArgument("TRT shuffle layer building failed."));
+      common::errors::InvalidArgument(
+          "TensorRT failed to construct the dynamic shuffle layer that "
+          "consumes the runtime shape tensor. Please check the provided "
+          "shape binding."));
   engine->DeclareOutput(layer, 0, "y");
   engine->FreezeNetwork();
 
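
For readers unfamiliar with the pattern this hunk touches: feeding a tensor into input 1 of an `IShuffleLayer` is how TensorRT reshapes to dimensions that are only known at runtime. A minimal sketch with illustrative names (`network`, `data`, and `shape` are assumptions, not identifiers from the test):

```cpp
#include <NvInfer.h>

// Sketch, not part of the patch: reshape driven by a runtime shape tensor.
// `shape` must be a 1-D INT32 shape tensor declared as a network input.
nvinfer1::ITensor *DynamicReshape(nvinfer1::INetworkDefinition *network,
                                  nvinfer1::ITensor *data,
                                  nvinfer1::ITensor *shape) {
  auto *shuffle = network->addShuffle(*data);
  // Input 1 of an IShuffleLayer supplies the target shape at runtime and
  // takes precedence over setReshapeDimensions().
  shuffle->setInput(1, *shape);
  return shuffle->getOutput(0);
}
```
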
@@ -401,14 +463,19 @@ TEST(PluginTest, test_generic_plugin) {
       creator->createPlugin("pir_generic_plugin", plugin_collection.get());
   PADDLE_ENFORCE_NOT_NULL(
       generic_plugin,
-      common::errors::InvalidArgument("TRT create generic plugin failed."));
+      common::errors::InvalidArgument(
+          "TensorRT plugin registry returned nullptr while creating "
+          "'pir_generic_plugin'. Verify the plugin has been registered before "
+          "building the engine."));
   std::vector<nvinfer1::ITensor *> plugin_inputs;
   plugin_inputs.emplace_back(x);
   auto plugin_layer = engine->network()->addPluginV2(
       plugin_inputs.data(), plugin_inputs.size(), *generic_plugin);
-  PADDLE_ENFORCE_NOT_NULL(plugin_layer,
-                          common::errors::InvalidArgument(
-                              "TRT generic plugin layer building failed."));
+  PADDLE_ENFORCE_NOT_NULL(
+      plugin_layer,
+      common::errors::InvalidArgument(
+          "TensorRT failed to add the generic plugin layer to the network. "
+          "Ensure the plugin inputs match the expected TensorRT types."));
 
   engine->DeclareOutput(plugin_layer, 0, "y");
   std::vector<std::string> input_names = {"x"};
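
As background for this hunk, the plugin object comes from TensorRT's plugin registry. A rough sketch of that flow with placeholder arguments; the `"1"` version string, the function name, and the null-check style are assumptions, not taken from the test:

```cpp
#include <NvInfer.h>

// Sketch, not part of the patch: look up a registered plugin creator and add
// the resulting plugin to a network. The version string is a placeholder.
nvinfer1::IPluginV2Layer *AddGenericPlugin(
    nvinfer1::INetworkDefinition *network,
    nvinfer1::ITensor *x,
    const nvinfer1::PluginFieldCollection *fields) {
  auto *creator =
      getPluginRegistry()->getPluginCreator("pir_generic_plugin", "1");
  if (creator == nullptr) return nullptr;  // plugin was never registered

  nvinfer1::IPluginV2 *plugin =
      creator->createPlugin("pir_generic_plugin", fields);
  if (plugin == nullptr) return nullptr;

  nvinfer1::ITensor *inputs[] = {x};
  return network->addPluginV2(inputs, 1, *plugin);
}
```
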
@@ -417,7 +484,11 @@ TEST(PluginTest, test_generic_plugin) {
   std::vector<phi::DataType> outputs_dtype = {phi::DataType::FLOAT32};
   LOG(INFO) << "freeze network";
   engine->FreezeNetwork();
+#if IS_TRT_VERSION_GE(8600)
+  ASSERT_EQ(engine->engine()->getNbIOTensors(), 2);
+#else
   ASSERT_EQ(engine->engine()->getNbBindings(), 2);
+#endif
   nvinfer1::IHostMemory *serialized_engine_data = engine->Serialize();
   std::ofstream outFile("engine_serialized_data.bin", std::ios::binary);
   outFile.write(static_cast<const char *>(serialized_engine_data->data()),
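
On the `IS_TRT_VERSION_GE(8600)` guards added in this file: TensorRT 8.5+ exposes a name-based I/O tensor API, while older builds only offer the binding-index API, so the assertion has to be version-gated. A sketch of how the two branches enumerate I/O tensors, assuming Paddle's `IS_TRT_VERSION_GE` macro and a raw `ICudaEngine` pointer:

```cpp
#include <NvInfer.h>

// Sketch, not part of the patch: counting (and, on newer TensorRT, naming)
// the engine's I/O tensors under both API generations.
int CountIOTensors(const nvinfer1::ICudaEngine *engine) {
#if IS_TRT_VERSION_GE(8600)
  const int n = engine->getNbIOTensors();
  for (int i = 0; i < n; ++i) {
    const char *name = engine->getIOTensorName(i);
    const bool is_input =
        engine->getTensorIOMode(name) == nvinfer1::TensorIOMode::kINPUT;
    (void)is_input;  // a real caller would bind addresses/shapes per tensor
  }
  return n;
#else
  return engine->getNbBindings();  // legacy binding-index API
#endif
}
```
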