Skip to content
Merged
12 changes: 8 additions & 4 deletions onnxruntime/contrib_ops/cpu/maxpool_with_mask.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,14 @@ class MaxpoolWithMask : public OpKernel, public PoolBase {
const TensorShape& x_shape = X->Shape();
const TensorShape& m_shape = M->Shape();
ORT_RETURN_IF_NOT(x_shape.NumDimensions() >= 3, "Input dimension cannot be less than 3.");

// TODO: fix this checker later
// ONNXRUNTIME_RETURN_IF_NOT((x_shape[2] == m_shape[2]) && (x_shape[3] == m_shape[3]), " Input shape and mask shape
// mismatch: ", x_shape, " vs ", m_shape);
ORT_RETURN_IF_NOT(m_shape.NumDimensions() == x_shape.NumDimensions(),
"Mask and input must have the same number of dimensions. Got mask dims: ",
m_shape.NumDimensions(), " input dims: ", x_shape.NumDimensions());
Comment thread
xadupre marked this conversation as resolved.
for (size_t i = 2; i < x_shape.NumDimensions(); ++i) {
ORT_RETURN_IF_NOT(m_shape[i] == x_shape[i],
"Mask and input spatial dimensions mismatch at dimension ", i,
": mask=", m_shape[i], " input=", x_shape[i]);
}

TensorShapeVector pads = pool_attrs_.pads;
TensorShapeVector kernel_shape = pool_attrs_.kernel_shape;
Expand Down
54 changes: 54 additions & 0 deletions onnxruntime/test/contrib_ops/maxpool_mask_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,59 @@ TEST(ContribOpTest, MaxPoolWithMask) {
test.Run();
}

TEST(ContribOpTest, MaxPoolWithMask_SpatialDimMismatch) {
OpTester test("MaxpoolWithMask", 1, onnxruntime::kMSDomain);

test.AddAttribute("auto_pad", "");
test.AddAttribute("strides", std::vector<int64_t>{1, 1});
test.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0});
test.AddAttribute("kernel_shape", std::vector<int64_t>{8, 8});

// Input X has shape {1, 1, 8, 8}
std::vector<int64_t> x_dims = {1, 1, 8, 8};
std::vector<float> x_vals(64, 1.0f);

// Mask M has wrong spatial dimensions: {1, 1, 4, 8} instead of {1, 1, 8, 8}
std::vector<int64_t> m_dims = {1, 1, 4, 8};
std::vector<int32_t> m_vals(32, 1);

// Placeholder output shape and values (not validated since we expect failure)
std::vector<int64_t> expected_dims = {1, 1, 1, 1};
std::vector<float> expected_vals = {1.0f};

test.AddInput<float>("X", x_dims, x_vals);
test.AddInput<int32_t>("M", m_dims, m_vals);
test.AddOutput<float>("Y", expected_dims, expected_vals);
test.Run(BaseTester::ExpectResult::kExpectFailure,
"Mask and input spatial dimensions mismatch at dimension 2");
}

TEST(ContribOpTest, MaxPoolWithMask_DimCountMismatch) {
OpTester test("MaxpoolWithMask", 1, onnxruntime::kMSDomain);

test.AddAttribute("auto_pad", "");
test.AddAttribute("strides", std::vector<int64_t>{1, 1});
test.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0});
test.AddAttribute("kernel_shape", std::vector<int64_t>{8, 8});

// Input X has shape {1, 1, 8, 8} (4D)
std::vector<int64_t> x_dims = {1, 1, 8, 8};
std::vector<float> x_vals(64, 1.0f);

// Mask M has wrong number of dimensions: {1, 1, 8} (3D) instead of 4D
std::vector<int64_t> m_dims = {1, 1, 8};
std::vector<int32_t> m_vals(8, 1);

// Placeholder output shape and values (not validated since we expect failure)
std::vector<int64_t> expected_dims = {1, 1, 1, 1};
std::vector<float> expected_vals = {1.0f};

test.AddInput<float>("X", x_dims, x_vals);
test.AddInput<int32_t>("M", m_dims, m_vals);
test.AddOutput<float>("Y", expected_dims, expected_vals);
test.Run(BaseTester::ExpectResult::kExpectFailure,
"Mask and input must have the same number of dimensions");
}

Comment thread
xadupre marked this conversation as resolved.
} // namespace test
} // namespace onnxruntime
Loading