Skip to content
Merged
20 changes: 16 additions & 4 deletions onnxruntime/contrib_ops/cpu/maxpool_with_mask.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,22 @@ class MaxpoolWithMask : public OpKernel, public PoolBase {
const TensorShape& x_shape = X->Shape();
const TensorShape& m_shape = M->Shape();
ORT_RETURN_IF_NOT(x_shape.NumDimensions() >= 3, "Input dimension cannot be less than 3.");

// TODO: fix this checker later
// ONNXRUNTIME_RETURN_IF_NOT((x_shape[2] == m_shape[2]) && (x_shape[3] == m_shape[3]), " Input shape and mask shape
// mismatch: ", x_shape, " vs ", m_shape);
ORT_RETURN_IF_NOT(m_shape.NumDimensions() == x_shape.NumDimensions(),
"Mask and input must have the same number of dimensions. Got mask dims: ",
m_shape.NumDimensions(), " input dims: ", x_shape.NumDimensions());
Comment thread
xadupre marked this conversation as resolved.
const bool input_has_nonzero_channels = x_shape[0] > 0 && x_shape[1] > 0;
// Mask N and C dimensions may differ from input (broadcasting via modulo).
// Only require them to be nonzero to prevent division-by-zero in total_mask_channels.
ORT_RETURN_IF_NOT(!input_has_nonzero_channels || (m_shape[0] > 0 && m_shape[1] > 0),
"Mask N and C dimensions must be greater than 0 when input N and C are greater than 0. "
"Got mask N=",
Comment thread
xadupre marked this conversation as resolved.
m_shape[0], " C=", m_shape[1],
" input N=", x_shape[0], " C=", x_shape[1]);
for (size_t i = 2; i < x_shape.NumDimensions(); ++i) {
ORT_RETURN_IF_NOT(m_shape[i] == x_shape[i],
"Mask and input spatial dimensions mismatch at dimension ", i,
": mask=", m_shape[i], " input=", x_shape[i]);
}

TensorShapeVector pads = pool_attrs_.pads;
TensorShapeVector kernel_shape = pool_attrs_.kernel_shape;
Expand Down
81 changes: 81 additions & 0 deletions onnxruntime/test/contrib_ops/maxpool_mask_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,86 @@ TEST(ContribOpTest, MaxPoolWithMask) {
test.Run();
}

TEST(ContribOpTest, MaxPoolWithMask_SpatialDimMismatch) {
OpTester test("MaxpoolWithMask", 1, onnxruntime::kMSDomain);

test.AddAttribute("auto_pad", "");
test.AddAttribute("strides", std::vector<int64_t>{1, 1});
test.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0});
test.AddAttribute("kernel_shape", std::vector<int64_t>{8, 8});

// Input X has shape {1, 1, 8, 8}
std::vector<int64_t> x_dims = {1, 1, 8, 8};
std::vector<float> x_vals(64, 1.0f);

// Mask M has wrong spatial dimensions: {1, 1, 4, 8} instead of {1, 1, 8, 8}
std::vector<int64_t> m_dims = {1, 1, 4, 8};
std::vector<int32_t> m_vals(32, 1);

// Placeholder output shape and values (not validated since we expect failure)
std::vector<int64_t> expected_dims = {1, 1, 1, 1};
std::vector<float> expected_vals = {1.0f};

test.AddInput<float>("X", x_dims, x_vals);
test.AddInput<int32_t>("M", m_dims, m_vals);
test.AddOutput<float>("Y", expected_dims, expected_vals);
test.Run(BaseTester::ExpectResult::kExpectFailure,
"Mask and input spatial dimensions mismatch at dimension 2");
}

TEST(ContribOpTest, MaxPoolWithMask_DimCountMismatch) {
OpTester test("MaxpoolWithMask", 1, onnxruntime::kMSDomain);

test.AddAttribute("auto_pad", "");
test.AddAttribute("strides", std::vector<int64_t>{1, 1});
test.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0});
test.AddAttribute("kernel_shape", std::vector<int64_t>{8, 8});

// Input X has shape {1, 1, 8, 8} (4D)
std::vector<int64_t> x_dims = {1, 1, 8, 8};
std::vector<float> x_vals(64, 1.0f);

// Mask M has wrong number of dimensions: {1, 1, 8} (3D) instead of 4D
std::vector<int64_t> m_dims = {1, 1, 8};
std::vector<int32_t> m_vals(8, 1);

// Placeholder output shape and values (not validated since we expect failure)
std::vector<int64_t> expected_dims = {1, 1, 1, 1};
std::vector<float> expected_vals = {1.0f};

test.AddInput<float>("X", x_dims, x_vals);
test.AddInput<int32_t>("M", m_dims, m_vals);
test.AddOutput<float>("Y", expected_dims, expected_vals);
test.Run(BaseTester::ExpectResult::kExpectFailure,
"Mask and input must have the same number of dimensions");
}

Comment thread
xadupre marked this conversation as resolved.
TEST(ContribOpTest, MaxPoolWithMask_MaskEmptyBatchDim) {
OpTester test("MaxpoolWithMask", 1, onnxruntime::kMSDomain);

test.AddAttribute("auto_pad", "");
test.AddAttribute("strides", std::vector<int64_t>{1, 1});
test.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0});
test.AddAttribute("kernel_shape", std::vector<int64_t>{8, 8});

// Input X has shape {1, 1, 8, 8} (non-empty)
std::vector<int64_t> x_dims = {1, 1, 8, 8};
std::vector<float> x_vals(64, 1.0f);

// Mask M has N=0: should trigger the nonzero N/C guard
std::vector<int64_t> m_dims = {0, 1, 8, 8};
std::vector<int32_t> m_vals; // 0 elements

// Placeholder output shape and values (not validated since we expect failure)
std::vector<int64_t> expected_dims = {1, 1, 1, 1};
std::vector<float> expected_vals = {1.0f};

test.AddInput<float>("X", x_dims, x_vals);
test.AddInput<int32_t>("M", m_dims, m_vals);
test.AddOutput<float>("Y", expected_dims, expected_vals);
test.Run(BaseTester::ExpectResult::kExpectFailure,
"Mask N and C dimensions must be greater than 0");
}

} // namespace test
} // namespace onnxruntime
Loading