diff --git a/python/paddle/fluid/tests/unittests/test_bmm_op.py b/python/paddle/fluid/tests/unittests/test_bmm_op.py index cb1b3ded53472..a1c8266842087 100644 --- a/python/paddle/fluid/tests/unittests/test_bmm_op.py +++ b/python/paddle/fluid/tests/unittests/test_bmm_op.py @@ -79,8 +79,10 @@ def test_api_error(self): y_data = np.arange(16, dtype='float32').reshape((2, 4, 2)) y_data_wrong1 = np.arange(16, dtype='float32').reshape((2, 2, 4)) y_data_wrong2 = np.arange(16, dtype='float32').reshape((2, 2, 2, 2)) + y_data_wrong3 = np.arange(24, dtype='float32').reshape((3, 4, 2)) self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong1) self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong2) + self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong3) if __name__ == "__main__": diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 15580b6618e6d..0b7ebee884cf7 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -850,6 +850,10 @@ def bmm(x, y, name=None): raise ValueError( "x's width must be equal with y's height. But received x's shape: {}, y's shape: {}". format(x_shape, y_shape)) + if x_shape[0] != y_shape[0]: + raise ValueError( + "x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {}, y's shape: {}". + format(x_shape, y_shape)) helper = LayerHelper('bmm', **locals()) if in_dygraph_mode(): return core.ops.bmm(x, y)