Skip to content

Commit

Permalink
[PHI]Add yaml and unittest for bmm op
Browse files Browse the repository at this point in the history
  • Loading branch information
BiynXu committed Jul 26, 2022
1 parent 9bc54c8 commit d52efaa
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 2 deletions.
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/legacy_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,16 @@
kernel :
func : bitwise_xor

# bmm
- api : bmm
args : (Tensor x, Tensor y)
output : Tensor
infer_meta :
func : BmmInferMeta
kernel :
func : bmm
backward : bmm_grad

# brelu
- api : brelu
args : (Tensor x, float t_min, float t_max)
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,15 @@
kernel :
func : bilinear_tensor_product_grad

- backward_api : bmm_grad
forward : bmm (Tensor x, Tensor y) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : BmmGradInferMeta
kernel :
func : bmm_grad

- backward_api : brelu_grad
forward : brelu (Tensor x, float t_min, float t_max) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float t_min, float t_max)
Expand Down
5 changes: 3 additions & 2 deletions python/paddle/fluid/tests/unittests/test_bmm_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,18 @@ class TestBmmOp(OpTest):

def setUp(self):
self.op_type = "bmm"
self.python_api = paddle.tensor.bmm
X = np.random.random((10, 3, 4)).astype("float64")
Y = np.random.random((10, 4, 5)).astype("float64")
self.inputs = {'X': X, 'Y': Y}
Out = np.matmul(X, Y)
self.outputs = {'Out': Out}

def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)

def test_checkout_grad(self):
self.check_grad(['X', 'Y'], 'Out')
self.check_grad(['X', 'Y'], 'Out', check_eager=True)


class API_TestBmm(unittest.TestCase):
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/tensor/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1521,6 +1521,9 @@ def bmm(x, y, name=None):
"x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {}, y's shape: {}"
.format(x_shape, y_shape))

if in_dygraph_mode():
return _C_ops.final_state_bmm(x, y)

if paddle.in_dynamic_mode():
return _C_ops.bmm(x, y)

Expand Down

0 comments on commit d52efaa

Please sign in to comment.