Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PHI]Add yaml and unittest for bmm op #44625

Merged
merged 2 commits into from
Jul 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/legacy_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,16 @@
kernel :
func : bitwise_xor

# bmm
- api : bmm
args : (Tensor x, Tensor y)
output : Tensor
infer_meta :
func : BmmInferMeta
kernel :
func : bmm
backward : bmm_grad

# box_coder
- api : box_coder
args : (Tensor prior_box, Tensor prior_box_var, Tensor target_box, str code_type, bool box_normalized, int axis, float[] variance)
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,15 @@
kernel :
func : bilinear_tensor_product_grad

- backward_api : bmm_grad
forward : bmm (Tensor x, Tensor y) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : BmmGradInferMeta
kernel :
func : bmm_grad

- backward_api : brelu_grad
forward : brelu (Tensor x, float t_min, float t_max) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float t_min, float t_max)
Expand Down
5 changes: 3 additions & 2 deletions python/paddle/fluid/tests/unittests/test_bmm_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,18 @@ class TestBmmOp(OpTest):

def setUp(self):
self.op_type = "bmm"
self.python_api = paddle.tensor.bmm
X = np.random.random((10, 3, 4)).astype("float64")
Y = np.random.random((10, 4, 5)).astype("float64")
self.inputs = {'X': X, 'Y': Y}
Out = np.matmul(X, Y)
self.outputs = {'Out': Out}

def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)

def test_checkout_grad(self):
self.check_grad(['X', 'Y'], 'Out')
self.check_grad(['X', 'Y'], 'Out', check_eager=True)


class API_TestBmm(unittest.TestCase):
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/tensor/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1521,6 +1521,9 @@ def bmm(x, y, name=None):
"x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {}, y's shape: {}"
.format(x_shape, y_shape))

if in_dygraph_mode():
return _C_ops.final_state_bmm(x, y)

if paddle.in_dynamic_mode():
return _C_ops.bmm(x, y)

Expand Down