Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,18 @@ struct ConvTransposeAttributes : public ConvAttributes {
" group: ", group);
}

// Bias shape validation (It should be a 1D tensor with size M)
// See https://github.com/microsoft/onnxruntime/issues/26144
if (B != nullptr) {
if (B->Shape().NumDimensions() != 1 || B->Shape()[0] != num_output_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Bias shape is not compatible with number of output channels."
" It should be a 1-D tensor with size num_output_channels(M).",
" Bias: ", B->Shape(),
" num_output_channels: ", num_output_channels);
}
}

TensorShapeVector kernel_shape;
ORT_RETURN_IF_ERROR(ComputeKernelShape(F_Shape, kernel_shape, is_nhwc));

Expand Down
12 changes: 12 additions & 0 deletions onnxruntime/core/providers/cuda/nn/conv_transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,18 @@ Status ConvTranspose<T, Layout>::UpdateState(OpKernelContext* context, bool dyna
" group: ", conv_transpose_attrs_.group);
}

// Bias shape validation (It should be a 1D tensor with size M)
// See https://github.com/microsoft/onnxruntime/issues/26144
if (B != nullptr) {
if (B->Shape().NumDimensions() != 1 || B->Shape()[0] != num_output_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Bias shape is not compatible with number of output channels."
" It should be a 1-D tensor with size num_output_channels(M).",
" Bias: ", B->Shape(),
" num_output_channels: ", num_output_channels);
}
}

TensorShapeVector kernel_shape;
ORT_RETURN_IF_ERROR(conv_transpose_attrs_.ComputeKernelShape(w_shape, kernel_shape, w_in_nhwc));

Expand Down
68 changes: 68 additions & 0 deletions onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,74 @@ TEST(ConvTransposeTest, ConvTranspose_InvalidKernelShape) {
kDmlExecutionProvider}); // TODO: Unskip when fixed #41968513
}

TEST(ConvTransposeTest, ConvTranspose_InvalidBiasShape_1) {
ConvTransposeOpAttributes attrs = {
vector<int64_t>{1, 5}, // kernel_shape
{}, // output_padding
vector<int64_t>{2, 1, 1, 14}, // output_shape
vector<int64_t>{0, 0, 0, 0}, // pads
vector<int64_t>{1, 1}, // strides
vector<int64_t>{1, 1}, // dilations
1, // group
"NOTSET" // auto_pad
};
vector<float> X = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f,
10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f};
vector<int64_t> X_shape = {2, 1, 1, 10};
vector<float> W = {1.0f, 2.0f, 3.0f, 2.0f, 1.0f};
vector<int64_t> W_shape = {1, 1, 1, 5};
vector<float> B = {1.0f, 2.0f}; // invalid bias shape, should be {1}
vector<int64_t> B_shape = {2}; // invalid bias shape, should be {1}
vector<int64_t> Y_shape = {2, 1, 1, 14};
vector<float> expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f,
11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f};
TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape,
OpTester::ExpectResult::kExpectFailure,
// Just ensure that it starts with the expected string.
"Bias shape is not compatible with number of output channels. "
"It should be a 1-D tensor with size num_output_channels(M).",
// The EP exclusions are along the same lines as ConvTranspose_InvalidKernelShape which
// also tests for invalid shapes. It also includes XnnPack which seems to have its own
// way of dealing with incorrectly shaped bias.
{kTensorrtExecutionProvider, kQnnExecutionProvider,
kDmlExecutionProvider, kXnnpackExecutionProvider,
kWebGpuExecutionProvider}); // Remove when https://github.com/microsoft/onnxruntime/issues/27210 is fixed
}

TEST(ConvTransposeTest, ConvTranspose_InvalidBiasShape_2) {
ConvTransposeOpAttributes attrs = {
vector<int64_t>{1, 5}, // kernel_shape
{}, // output_padding
vector<int64_t>{2, 1, 1, 14}, // output_shape
vector<int64_t>{0, 0, 0, 0}, // pads
vector<int64_t>{1, 1}, // strides
vector<int64_t>{1, 1}, // dilations
1, // group
"NOTSET" // auto_pad
};
vector<float> X = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f,
10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f};
vector<int64_t> X_shape = {2, 1, 1, 10};
vector<float> W = {1.0f, 2.0f, 3.0f, 2.0f, 1.0f};
vector<int64_t> W_shape = {1, 1, 1, 5};
vector<float> B = {1.0f, 2.0f};
vector<int64_t> B_shape = {1, 2}; // invalid bias rank (it should be 1-D)
vector<int64_t> Y_shape = {2, 1, 1, 14};
vector<float> expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f,
11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f};
TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape,
OpTester::ExpectResult::kExpectFailure,
// Just ensure that it starts with the expected string.
"Bias shape is not compatible with number of output channels. "
"It should be a 1-D tensor with size num_output_channels(M).",
// The EP exclusions are along the same lines as ConvTranspose_InvalidKernelShape which
// also tests for invalid shapes. It also includes XnnPack which seems to have its own
// way of dealing with incorrectly shaped bias.
{kTensorrtExecutionProvider, kQnnExecutionProvider,
kDmlExecutionProvider, kXnnpackExecutionProvider,
kWebGpuExecutionProvider}); // Remove when https://github.com/microsoft/onnxruntime/issues/27210 is fixed
}

TEST(ConvTransposeTest, ConvTranspose_onnx) {
ConvTransposeOpAttributes attrs = {
vector<int64_t>{3, 3}, // kernel_shape
Expand Down
Loading