Skip to content

Commit

Permalink
fix layer norm ut (#838)
Browse files Browse the repository at this point in the history
  • Loading branch information
USTCKAY authored Nov 15, 2023
1 parent 8b1e23c commit e1fffca
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
2 changes: 2 additions & 0 deletions backends/mlu/tests/unittests/test_layer_norm_op_mlu.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ def test_with_place(place, shape, begin_norm_axis, use_mkldnn=use_mkldnn):
x_grad, scale_grad, bias_grad = _reference_layer_norm_grad(
x, y_grad, scale, bias, mean, variance, begin_norm_axis
)
mean.shape = x_shape[0:begin_norm_axis]
variance.shape = x_shape[0:begin_norm_axis]

var_dict = locals()
var_dict["y@GRAD"] = y_grad
Expand Down
2 changes: 2 additions & 0 deletions backends/npu/tests/unittests/test_layer_norm_op_npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ def test_with_place(place, shape, begin_norm_axis, use_mkldnn=use_mkldnn):
x_grad, scale_grad, bias_grad = _reference_layer_norm_grad(
x, y_grad, scale, bias, mean, variance, begin_norm_axis
)
mean.shape = x_shape[0:begin_norm_axis]
variance.shape = x_shape[0:begin_norm_axis]

var_dict = locals()
var_dict["y@GRAD"] = y_grad
Expand Down

0 comments on commit e1fffca

Please sign in to comment.