diff --git a/paddle/phi/infermeta/spmd_rules/layer_norm.cc b/paddle/phi/infermeta/spmd_rules/layer_norm.cc
index 1dfe8bf19c296..ab26e8f7c787b 100644
--- a/paddle/phi/infermeta/spmd_rules/layer_norm.cc
+++ b/paddle/phi/infermeta/spmd_rules/layer_norm.cc
@@ -67,24 +67,27 @@ SpmdInfo LayerNormInferSpmd(const DistMetaTensor& x,
   // x[0:begin_norm_axis], only the first axis of x can
   // be sharded
   std::string x_axes(x_ndim, '1');
-  x_axes[0] = alphabet[0];
+  std::string mean_axes(begin_norm_axis, '1');
+  std::string variance_axes(begin_norm_axis, '1');
+  // allow the axes before begin_norm_axis to be sharded
+  for (int i = 0; i < begin_norm_axis; ++i) {
+    x_axes[i] = alphabet[i];
+    mean_axes[i] = alphabet[i];
+    variance_axes[i] = alphabet[i];
+  }
+  // x_axes[0] = alphabet[0];
   std::string scale_axes(1, x_axes[x_ndim - 1]);
   std::string bias_axes(1, x_axes[x_ndim - 1]);

   // get output notation
   std::string out_axes = x_axes;
-  std::string mean_axes(1, '1'), variance_axes(1, '1');
-  if (begin_norm_axis > 0) {
-    mean_axes[0] = out_axes[0];
-    variance_axes[0] = out_axes[0];
-  }

   // Step2: Sharding Propogation
   // Step2.1: merge input sharding
   // As the mean and variance in outputs are `flattened` from
   // x[0:begin_norm_axis], only the first axis can be sharded,
   // the axes 1 to begin_norm_axis-1 are set to be replicated.
-  std::fill(x_dims_mapping.begin() + 1, x_dims_mapping.end(), -1);
+  std::fill(x_dims_mapping.begin() + begin_norm_axis, x_dims_mapping.end(), -1);
   std::unordered_map<std::string, int64_t> axis_to_dim_map =
       ShardingMergeForTensors({{x_axes, x_dims_mapping}});

@@ -199,16 +202,18 @@ SpmdInfo LayerNormInferSpmdReverse(const DistMetaTensor& x,
   // the axes after norm_axis should be replicated,
   // so set their notation to '1'.
   std::string x_axes(x_ndim, '1');
-  x_axes[0] = alphabet[0];
+  std::string mean_axes(begin_norm_axis, '1');
+  std::string variance_axes(begin_norm_axis, '1');
+  // allow the axes before begin_norm_axis to be sharded
+  for (int i = 0; i < begin_norm_axis; ++i) {
+    x_axes[i] = alphabet[i];
+    mean_axes[i] = alphabet[i];
+    variance_axes[i] = alphabet[i];
+  }
+
   std::string scale_axes(1, x_axes[x_ndim - 1]);
   std::string bias_axes(1, x_axes[x_ndim - 1]);
-  std::string out_axes = x_axes;
-  std::string mean_axes(1, '1'), variance_axes(1, '1');
-  if (begin_norm_axis > 0) {
-    mean_axes[0] = out_axes[0];
-    variance_axes[0] = out_axes[0];
-  }

   // Step2: Sharding Propogation
   // For the axes after norm_axis in both input and output tensors,
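Note on the rule change above: mean and variance now carry one axis per un-normalized axis of x, each of those axes inherits x's dims mapping, and everything from `begin_norm_axis` onward stays replicated. A minimal sketch of that mapping logic in plain Python (hypothetical helper, not the Paddle API), checked against the second forward test case further down:

```python
def layer_norm_spmd_sketch(x_dims_mapping, begin_norm_axis):
    """Hypothetical helper mirroring the new rule: returns the propagated
    dims_mapping for (x, out, mean, variance)."""
    # axes at and after begin_norm_axis are forced to be replicated (-1)
    x_merged = x_dims_mapping[:begin_norm_axis] + [-1] * (
        len(x_dims_mapping) - begin_norm_axis
    )
    out = list(x_merged)
    # mean/variance keep x's mapping on the leading (un-normalized) axes
    mean = x_merged[:begin_norm_axis]
    variance = x_merged[:begin_norm_axis]
    return x_merged, out, mean, variance


# mirrors the ijk[1, 0, -1], begin_norm_axis=2 case in the updated tests
assert layer_norm_spmd_sketch([1, 0, -1], 2) == (
    [1, 0, -1],
    [1, 0, -1],
    [1, 0],
    [1, 0],
)
```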
diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc
index d86b25b7ba224..a38e9ca6f9a14 100644
--- a/paddle/phi/infermeta/ternary.cc
+++ b/paddle/phi/infermeta/ternary.cc
@@ -579,7 +579,10 @@ void LayerNormInferMeta(const MetaTensor& x,
                             x_dim.size()));

   auto matrix_dim = phi::flatten_to_2d(x_dim, begin_norm_axis);
-  int left = static_cast<int>(matrix_dim[0]);
+
+  // keep the axis sizes before normalization for the shapes of mean and variance
+  auto before_norm_dims = slice_ddim(x_dim, 0, begin_norm_axis);
+  // int left = static_cast<int>(matrix_dim[0]);
   int right = static_cast<int>(matrix_dim[1]);
   if (scale) {
     PADDLE_ENFORCE_EQ(scale.dims().size(),
@@ -644,11 +647,11 @@ void LayerNormInferMeta(const MetaTensor& x,
                         ? phi::DataType::FLOAT32
                         : x_dtype;
   if (mean) {
-    mean->set_dims({left});
+    mean->set_dims({before_norm_dims});
     mean->set_dtype(param_type);
   }
   if (variance) {
-    variance->set_dims({left});
+    variance->set_dims({before_norm_dims});
     variance->set_dtype(param_type);
   }
 }
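The InferMeta change only affects the static shapes reported for mean and variance: previously a flattened `[prod(x_dim[:begin_norm_axis])]`, now `x_dim[:begin_norm_axis]` itself. A small sketch of both behaviours, assuming plain Python lists for shapes:

```python
from functools import reduce
from operator import mul


def layer_norm_mean_variance_shape(x_shape, begin_norm_axis, flattened=False):
    """Shape of mean/variance for layer_norm over x_shape[begin_norm_axis:].

    flattened=True reproduces the old behaviour ([left]); flattened=False the
    new behaviour (x_shape[:begin_norm_axis]).
    """
    if flattened:
        return [reduce(mul, x_shape[:begin_norm_axis], 1)]
    return list(x_shape[:begin_norm_axis])


assert layer_norm_mean_variance_shape([16, 32, 1024], 2, flattened=True) == [512]
assert layer_norm_mean_variance_shape([16, 32, 1024], 2) == [16, 32]
```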
diff --git a/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc
index 3707d36293048..ddc6359875671 100644
--- a/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc
@@ -52,8 +52,15 @@ void LayerNormGradKernel(const Context& dev_ctx,
   int left = static_cast<int>(matrix_dim[0]);
   int right = static_cast<int>(matrix_dim[1]);
   DDim matrix_shape({left, right});
+  DDim var_shape({left});

   d_y.Resize(matrix_shape);
+  // resize mean and variance to match the shape of the resized d_y for
+  // broadcast (Resize does not modify the underlying data)
+  auto mean_tmp = mean;
+  mean_tmp.Resize(var_shape);
+  auto variance_tmp = variance;
+  variance_tmp.Resize(var_shape);
   funcs::ColwiseSum2D<Context, T> colwise_sum(left, right, dev_ctx);

   DenseTensor x_tmp = x;
@@ -69,11 +76,11 @@ void LayerNormGradKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(&temp_norm);
   // get x_norm
   phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
-      dev_ctx, x_tmp, mean, funcs::SubtractFunctor<T>(), &temp_norm, 0);
+      dev_ctx, x_tmp, mean_tmp, funcs::SubtractFunctor<T>(), &temp_norm, 0);
   phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T>(
       dev_ctx,
       temp_norm,
-      variance,
+      variance_tmp,
       funcs::DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
       &temp_norm,
       0);
@@ -137,7 +144,7 @@ void LayerNormGradKernel(const Context& dev_ctx,
     phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T>(
         dev_ctx,
         *d_x,
-        variance,
+        variance_tmp,
         funcs::DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
         d_x,
         0);
diff --git a/paddle/phi/kernels/cpu/layer_norm_kernel.cc b/paddle/phi/kernels/cpu/layer_norm_kernel.cc
index 2a93d03b4abc1..b15b1554a51c4 100644
--- a/paddle/phi/kernels/cpu/layer_norm_kernel.cc
+++ b/paddle/phi/kernels/cpu/layer_norm_kernel.cc
@@ -50,12 +50,20 @@ void LayerNormKernel(const Context& dev_ctx,
   int left = static_cast<int>(matrix_dim[0]);
   int right = static_cast<int>(matrix_dim[1]);
   DDim matrix_shape({left, right});
+  DDim normalized_shape({left});

   auto x_tmp = x;
   x_tmp.Resize(matrix_shape);
   DenseTensor out;
   out.ShareDataWith(*y);
   out.Resize(matrix_shape);
+  // resize mean and var to match the shape of the resized x_tmp for broadcast
+  DenseTensor mean_tmp;
+  mean_tmp.ShareDataWith(*mean);
+  mean_tmp.Resize(normalized_shape);
+  DenseTensor var_tmp;
+  var_tmp.ShareDataWith(*var);
+  var_tmp.Resize(normalized_shape);

 #if defined(PADDLE_WITH_CUDA) || defined(_WIN32) || defined(__APPLE__) || \
     defined(__OSX__)
@@ -63,23 +71,23 @@ void LayerNormKernel(const Context& dev_ctx,
   funcs::RowwiseMean2D<Context, T> row_mean(left, right, dev_ctx);

   // get mean
-  row_mean(dev_ctx, x_tmp, mean);
+  row_mean(dev_ctx, x_tmp, &mean_tmp);

   // get variance
   phi::funcs::ElementwiseCompute<funcs::SubAndSquareFunctor<T>, T>(
-      dev_ctx, x_tmp, *mean, funcs::SubAndSquareFunctor<T>(), &out, 0);
+      dev_ctx, x_tmp, mean_tmp, funcs::SubAndSquareFunctor<T>(), &out, 0);

-  row_mean(dev_ctx, out, var);
+  row_mean(dev_ctx, out, &var_tmp);

   // get x_norm
   phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
-      dev_ctx, x_tmp, *mean, funcs::SubtractFunctor<T>(), &out, 0);
+      dev_ctx, x_tmp, mean_tmp, funcs::SubtractFunctor<T>(), &out, 0);
   phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T>(
       dev_ctx,
       out,
-      *var,
+      var_tmp,
       funcs::DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
       &out,
       0);
@@ -93,17 +101,17 @@ void LayerNormKernel(const Context& dev_ctx,
         dev_ctx, out, *bias, funcs::AddFunctor<T>(), &out, 1);
   }
 #else
-  PADDLE_ENFORCE_EQ(mean->numel(),
+  PADDLE_ENFORCE_EQ(mean_tmp.numel(),
                     left,
                     phi::errors::InvalidArgument(
                         "mean's length (%d) is not equal with expected (%d).",
-                        mean->numel(),
+                        mean_tmp.numel(),
                         left));
-  PADDLE_ENFORCE_EQ(var->numel(),
+  PADDLE_ENFORCE_EQ(var_tmp.numel(),
                     left,
                     phi::errors::InvalidArgument(
                         "var's length (%d) is not equal with expected (%d).",
-                        var->numel(),
+                        var_tmp.numel(),
                         left));
   if (scale) {
     PADDLE_ENFORCE_EQ(
@@ -128,8 +136,8 @@ void LayerNormKernel(const Context& dev_ctx,
                  .At(right);
   ker(x_tmp.data<T>(),
       out.data<T>(),
-      mean->data<T>(),
-      var->data<T>(),
+      mean_tmp.data<T>(),
+      var_tmp.data<T>(),
       scale ? scale->data<T>() : nullptr,
       bias ? bias->data<T>() : nullptr,
       static_cast<int>(left),
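The kernel edits above (and the analogous ones in the grad kernel) rely on ShareDataWith/Resize being metadata-only, so the now multi-dimensional mean/variance outputs can still be consumed as a flat `[left]` vector when broadcasting against the `[left, right]` matrix view of x. A NumPy analogue of that view trick (illustration only, not kernel code):

```python
import numpy as np

# Resize/ShareDataWith in the kernels are metadata-only: the same buffer is
# seen once with the op's output shape and once with the flat [left] work shape.
mean_out = np.empty((2, 3), dtype=np.float32)  # shape x.shape[:begin_norm_axis]
mean_work = mean_out.reshape(-1)               # flat [left] view used internally
mean_work[:] = 7.0
assert np.shares_memory(mean_out, mean_work)
assert (mean_out == 7.0).all()
```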
diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py
index 9123b98ac2054..76b142d0c9b03 100644
--- a/python/paddle/incubate/autograd/composite_rules.py
+++ b/python/paddle/incubate/autograd/composite_rules.py
@@ -179,8 +179,9 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
         bias = reshape(bias, x.shape[begin_norm_axis:])
         out = out + bias

-    mean_ = reshape(mean_, [-1])
-    variance = reshape(variance, [-1])
+    # keep the mean and variance shapes equal to x.shape[:begin_norm_axis]
+    mean_ = reshape(mean_, x.shape[:begin_norm_axis])
+    variance = reshape(variance, x.shape[:begin_norm_axis])
     if is_amp:
         out = cast(out, dtype)
     return out, mean_, variance
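For reference, a NumPy sketch that mirrors the output contract of layernorm_composite after this change: out keeps x's shape, while mean and variance come back with shape x.shape[:begin_norm_axis] (illustrative only; the helper name and epsilon value are assumptions):

```python
import numpy as np


def layer_norm_reference(x, scale, bias, epsilon, begin_norm_axis):
    """NumPy sketch with the same output shapes as layernorm_composite."""
    axes = tuple(range(begin_norm_axis, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    variance = ((x - mean) ** 2).mean(axis=axes, keepdims=True)
    out = (x - mean) / np.sqrt(variance + epsilon)
    if scale is not None:
        out = out * scale.reshape(x.shape[begin_norm_axis:])
    if bias is not None:
        out = out + bias.reshape(x.shape[begin_norm_axis:])
    # mean/variance are reported with shape x.shape[:begin_norm_axis]
    return (
        out,
        mean.reshape(x.shape[:begin_norm_axis]),
        variance.reshape(x.shape[:begin_norm_axis]),
    )


x = np.random.rand(2, 3, 4).astype('float32')
out, mean, variance = layer_norm_reference(x, None, None, 1e-5, 1)
assert out.shape == (2, 3, 4) and mean.shape == (2,) and variance.shape == (2,)
```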
diff --git a/test/auto_parallel/spmd_rules/test_layer_norm_rule.py b/test/auto_parallel/spmd_rules/test_layer_norm_rule.py
index 9af336fd8d214..6bbf5084e2212 100644
--- a/test/auto_parallel/spmd_rules/test_layer_norm_rule.py
+++ b/test/auto_parallel/spmd_rules/test_layer_norm_rule.py
@@ -15,6 +15,8 @@
 import unittest
 from collections import OrderedDict

+import numpy as np
+
 from paddle.distributed.auto_parallel.static.dist_attribute import (
     DistTensorSpec,
     TensorDistAttr,
@@ -57,7 +59,7 @@ def setUp(self):
     def test_infer_forward(self):
         # ijk[1, -1, -1], k[-1], k[-1] -->
         # ijk[1, -1, -1], k[-1], k[-1], (inputs)
-        # ijk[1, -1, -1], z[1], z[1], z=ij (outputs)
+        # ijk[1, -1, -1], ij[1, -1], ij[1, -1], (outputs)
         # begin_norm_axis=2
         self.x_spec.set_dims_mapping([1, -1, -1])
         self.bias_spec.set_dims_mapping([-1])
@@ -81,12 +83,12 @@ def test_infer_forward(self):
         self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1])
         self.assertEqual(infered_input_dist_attrs[2].dims_mapping, [-1])
         self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, -1])
-        self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [1])
-        self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [1])
+        self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [1, -1])
+        self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [1, -1])

         # ijk[1, 0, -1],k[0],k[0] -->
-        # [1, -1, -1], [-1], [-1] (inputs)
-        # [1, -1, -1], [1], [1] (outputs)
+        # [1, 0, -1], [-1], [-1] (inputs)
+        # [1, 0, -1], [1, 0], [1, 0] (outputs)
         # begin_norm_axis=2
         self.x_spec.set_dims_mapping([1, 0, -1])
         self.scale_spec.set_dims_mapping([0])
@@ -106,16 +108,16 @@ def test_infer_forward(self):
         self.assertEqual(len(infered_input_dist_attrs), 3)
         self.assertEqual(len(infered_output_dist_attrs), 3)
-        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, -1])
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0, -1])
         self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1])
         self.assertEqual(infered_input_dist_attrs[2].dims_mapping, [-1])
-        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, -1])
-        self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [1])
-        self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0, -1])
+        self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [1, 0])
+        self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [1, 0])

         # ijk[0, -1, -1],y[-1],y[1] -->
         # ijk[0, -1, -1],y[-1],y[-1], (inputs)
-        # ijk[0, -1, -1], i[0], i[0], y=jk (outputs)
+        # ijk[0, -1, -1], ij[0], ij[0], y=jk (outputs)
         # begin_norm_axis=1
         self.attrs['begin_norm_axis'] = 1
@@ -124,6 +126,8 @@ def test_infer_forward(self):
         self.x_spec.set_dims_mapping([0, -1, -1])
         self.scale_spec.shape = [x_shape[1] * x_shape[2]]
         self.bias_spec.shape = [x_shape[1] * x_shape[2]]
         self.scale_spec.set_dims_mapping([-1])
         self.bias_spec.set_dims_mapping([1])
+        self.mean_spec.shape = [x_shape[1]]
+        self.var_spec.shape = [x_shape[1]]

         result_dist_attrs = self.rule.infer_forward(
             self.x_spec,
@@ -147,24 +151,21 @@ def test_infer_forward(self):
         self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [0])

     def test_infer_backward(self):
-        import math
-
-        # [1, -1, -1], [1], [1] (outputs) -->
+        # [1, -1, -1], [1, -1], [1, -1] (outputs) -->
         # [1, -1, -1], [-1], [-1], (inputs)
-        # [1, -1, -1], [1], [1] (outputs)
+        # [1, -1, -1], [1, -1], [1, -1] (outputs)
         # begin_norm_axis=2
         self.attrs['begin_norm_axis'] = 2
         self.scale_spec.shape = [1024]
         self.bias_spec.shape = [1024]
-        self.mean_spec.shape = [
-            math.prod(self.x_spec.shape[: self.attrs['begin_norm_axis']])
-        ]
-        self.var_spec.shape = [
-            math.prod(self.x_spec.shape[: self.attrs['begin_norm_axis']])
-        ]
+        self.mean_spec.shape = self.x_spec.shape[
+            : self.attrs['begin_norm_axis']
+        ]
+        self.var_spec.shape = self.x_spec.shape[: self.attrs['begin_norm_axis']]
+
         self.out_spec.set_dims_mapping([1, -1, -1])
-        self.mean_spec.set_dims_mapping([1])
-        self.var_spec.set_dims_mapping([1])
+        self.mean_spec.set_dims_mapping([1, -1])
+        self.var_spec.set_dims_mapping([1, -1])

         result_dist_attrs = self.rule.infer_backward(
             self.x_spec,
@@ -187,29 +188,28 @@ def test_infer_backward(self):
         self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1])
         self.assertEqual(infered_input_dist_attrs[2].dims_mapping, [-1])
         self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, -1])
-        self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [1])
-        self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [1])
+        self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [1, -1])
+        self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [1, -1])

-        # [0, -1, -1], [0], [0] (outputs) -->
+        # [0, -1, -1], [0, -1], [0, -1] (outputs) -->
         # [0, -1, -1], [-1], [-1], (inputs)
-        # [0, -1, -1], [0], [0] (outputs)
+        # [0, -1, -1], [0, -1], [0, -1] (outputs)
         # begin_norm_axis=2
         self.attrs['begin_norm_axis'] = 2
         self.scale_spec.shape = [
-            math.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
+            np.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
         ]
         self.bias_spec.shape = [
-            math.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
+            np.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
         ]
-        self.mean_spec.shape = [
-            math.prod(self.x_spec.shape[: self.attrs['begin_norm_axis']])
+        self.mean_spec.shape = self.x_spec.shape[
+            : self.attrs['begin_norm_axis']
         ]
-        self.var_spec.shape = [
-            math.prod(self.x_spec.shape[: self.attrs['begin_norm_axis']])
-        ]
+        self.var_spec.shape = self.x_spec.shape[: self.attrs['begin_norm_axis']]
+
         self.out_spec.set_dims_mapping([0, -1, -1])
-        self.mean_spec.set_dims_mapping([0])
-        self.var_spec.set_dims_mapping([0])
+        self.mean_spec.set_dims_mapping([0, -1])
+        self.var_spec.set_dims_mapping([0, -1])

         result_dist_attrs = self.rule.infer_backward(
             self.x_spec,
@@ -232,29 +232,28 @@ def test_infer_backward(self):
         self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1])
         self.assertEqual(infered_input_dist_attrs[2].dims_mapping, [-1])
         self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1])
-        self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [0])
-        self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [0])
+        self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [0, -1])
+        self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [0, -1])

-        # [-1, -1, -1], [0], [-1] (outputs) -->
-        # [0, -1, -1], [-1], [-1], (inputs)
-        # [0, -1, -1], [0], [0] (outputs)
+        # [-1, -1, -1], [0, -1], [-1, 1] (outputs) -->
+        # [0, 1, -1], [-1], [-1], (inputs)
+        # [0, 1, -1], [0, 1], [0, 1] (outputs)
         # begin_norm_axis=2
         self.attrs['begin_norm_axis'] = 2
         self.scale_spec.shape = [
-            math.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
+            np.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
         ]
         self.bias_spec.shape = [
-            math.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
+            np.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
         ]
-        self.mean_spec.shape = [
-            math.prod(self.x_spec.shape[: self.attrs['begin_norm_axis']])
+        self.mean_spec.shape = self.x_spec.shape[
+            : self.attrs['begin_norm_axis']
         ]
-        self.var_spec.shape = [
-            math.prod(self.x_spec.shape[: self.attrs['begin_norm_axis']])
-        ]
+        self.var_spec.shape = self.x_spec.shape[: self.attrs['begin_norm_axis']]
+
         self.out_spec.set_dims_mapping([-1, -1, -1])
-        self.mean_spec.set_dims_mapping([0])
-        self.var_spec.set_dims_mapping([-1])
+        self.mean_spec.set_dims_mapping([0, -1])
+        self.var_spec.set_dims_mapping([-1, 1])

         result_dist_attrs = self.rule.infer_backward(
             self.x_spec,
@@ -273,33 +272,32 @@ def test_infer_backward(self):
         self.assertEqual(len(infered_input_dist_attrs), 3)
         self.assertEqual(len(infered_output_dist_attrs), 3)
-        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1])
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
         self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1])
         self.assertEqual(infered_input_dist_attrs[2].dims_mapping, [-1])
-        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1])
-        self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [0])
-        self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [0])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1, -1])
+        self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [0, 1])
+        self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [0, 1])

-        # [-1, 1, -1], [-1], [-1] (outputs) -->
-        # [-1, -1, -1], [-1], [-1], (inputs)
-        # [-1, -1, -1], [-1], [-1] (outputs)
+        # [-1, 1, -1], [-1, -1], [-1, -1] (outputs) -->
+        # [-1, 1, -1], [-1], [-1], (inputs)
+        # [-1, 1, -1], [-1, 1], [-1, 1] (outputs)
         # begin_norm_axis=2
         self.attrs['begin_norm_axis'] = 2
         self.scale_spec.shape = [
-            math.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
+            np.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
         ]
         self.bias_spec.shape = [
-            math.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
+            np.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
         ]
-        self.mean_spec.shape = [
-            math.prod(self.x_spec.shape[: self.attrs['begin_norm_axis']])
+        self.mean_spec.shape = self.x_spec.shape[
+            : self.attrs['begin_norm_axis']
         ]
-        self.var_spec.shape = [
-            math.prod(self.x_spec.shape[: self.attrs['begin_norm_axis']])
-        ]
+        self.var_spec.shape = self.x_spec.shape[: self.attrs['begin_norm_axis']]
+
         self.out_spec.set_dims_mapping([-1, 1, -1])
-        self.mean_spec.set_dims_mapping([-1])
-        self.var_spec.set_dims_mapping([-1])
+        self.mean_spec.set_dims_mapping([-1, -1])
+        self.var_spec.set_dims_mapping([-1, -1])

         result_dist_attrs = self.rule.infer_backward(
             self.x_spec,
@@ -318,33 +316,30 @@ def test_infer_backward(self):
         self.assertEqual(len(infered_input_dist_attrs), 3)
         self.assertEqual(len(infered_output_dist_attrs), 3)
-        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1])
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1])
         self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1])
         self.assertEqual(infered_input_dist_attrs[2].dims_mapping, [-1])
-        self.assertEqual(
-            infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1]
-        )
-        self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [-1])
-        self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [-1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 1, -1])
+        self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [-1, 1])
+        self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [-1, 1])

-        # [1, -1, -1], [0], [-1] (outputs) --> error
+        # [1, -1, -1], [0, -1], [-1, -1] (outputs) --> error
         # begin_norm_axis=2
         self.attrs['begin_norm_axis'] = 2
         self.scale_spec.shape = [
-            math.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
+            np.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
         ]
         self.bias_spec.shape = [
-            math.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
+            np.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
         ]
-        self.mean_spec.shape = [
-            math.prod(self.x_spec.shape[: self.attrs['begin_norm_axis']])
+        self.mean_spec.shape = self.x_spec.shape[
+            : self.attrs['begin_norm_axis']
         ]
-        self.var_spec.shape = [
-            math.prod(self.x_spec.shape[: self.attrs['begin_norm_axis']])
-        ]
+        self.var_spec.shape = self.x_spec.shape[: self.attrs['begin_norm_axis']]
+
         self.out_spec.set_dims_mapping([1, -1, -1])
-        self.mean_spec.set_dims_mapping([0])
-        self.var_spec.set_dims_mapping([-1])
+        self.mean_spec.set_dims_mapping([0, -1])
+        self.var_spec.set_dims_mapping([-1, -1])

         with self.assertRaises(NotImplementedError):
             result_dist_attrs = self.rule.infer_backward(
                 self.x_spec,
@@ -358,26 +353,25 @@ def test_infer_backward(self):
                 self.attrs['begin_norm_axis'],
             )

-        # [-1, 1, -1], [0], [-1] (outputs) -->
-        # [0, -1, -1], [-1], [-1] (inputs)
-        # [0, -1, -1], [0], [0] (outputs)
+        # [-1, 1, -1], [0, -1], [-1, -1] (outputs) -->
+        # [0, 1, -1], [-1], [-1] (inputs)
+        # [0, 1, -1], [0, 1], [0, 1] (outputs)
         # begin_norm_axis=2
         self.attrs['begin_norm_axis'] = 2
         self.scale_spec.shape = [
-            math.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
+            np.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
         ]
         self.bias_spec.shape = [
-            math.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
+            np.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
         ]
-        self.mean_spec.shape = [
-            math.prod(self.x_spec.shape[: self.attrs['begin_norm_axis']])
+        self.mean_spec.shape = self.x_spec.shape[
+            : self.attrs['begin_norm_axis']
         ]
-        self.var_spec.shape = [
-            math.prod(self.x_spec.shape[: self.attrs['begin_norm_axis']])
-        ]
+        self.var_spec.shape = self.x_spec.shape[: self.attrs['begin_norm_axis']]
+
         self.out_spec.set_dims_mapping([-1, 1, -1])
-        self.mean_spec.set_dims_mapping([0])
-        self.var_spec.set_dims_mapping([-1])
+        self.mean_spec.set_dims_mapping([0, -1])
+        self.var_spec.set_dims_mapping([-1, -1])

         result_dist_attrs = self.rule.infer_backward(
             self.x_spec,
@@ -396,33 +390,32 @@ def test_infer_backward(self):
         self.assertEqual(len(infered_input_dist_attrs), 3)
         self.assertEqual(len(infered_output_dist_attrs), 3)
-        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1])
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
         self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1])
         self.assertEqual(infered_input_dist_attrs[2].dims_mapping, [-1])
-        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1])
-        self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [0])
-        self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [0])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1, -1])
+        self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [0, 1])
+        self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [0, 1])

-        # [0, 1, -1], [-1], [-1] (outputs) -->
-        # [0, -1, -1], [-1], [-1] (inputs)
-        # [0, -1, -1], [0], [0] (outputs)
+        # [0, 1, -1], [-1, -1], [-1, -1] (outputs) -->
+        # [0, 1, -1], [-1], [-1] (inputs)
+        # [0, 1, -1], [0, 1], [0, 1] (outputs)
         # begin_norm_axis=2
         self.attrs['begin_norm_axis'] = 2
         self.scale_spec.shape = [
-            math.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
+            np.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
         ]
         self.bias_spec.shape = [
-            math.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
+            np.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
         ]
-        self.mean_spec.shape = [
-            math.prod(self.x_spec.shape[: self.attrs['begin_norm_axis']])
+        self.mean_spec.shape = self.x_spec.shape[
+            : self.attrs['begin_norm_axis']
         ]
-        self.var_spec.shape = [
-            math.prod(self.x_spec.shape[: self.attrs['begin_norm_axis']])
-        ]
+        self.var_spec.shape = self.x_spec.shape[: self.attrs['begin_norm_axis']]
+
         self.out_spec.set_dims_mapping([0, 1, -1])
-        self.mean_spec.set_dims_mapping([-1])
-        self.var_spec.set_dims_mapping([-1])
+        self.mean_spec.set_dims_mapping([-1, -1])
+        self.var_spec.set_dims_mapping([-1, -1])

         result_dist_attrs = self.rule.infer_backward(
             self.x_spec,
@@ -441,33 +434,32 @@ def test_infer_backward(self):
         self.assertEqual(len(infered_input_dist_attrs), 3)
         self.assertEqual(len(infered_output_dist_attrs), 3)
-        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1])
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
         self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1])
         self.assertEqual(infered_input_dist_attrs[2].dims_mapping, [-1])
-        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1])
-        self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [0])
-        self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [0])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1, -1])
+        self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [0, 1])
+        self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [0, 1])

-        # [0, 1, -1], [-1], [-1] (outputs) -->
-        # [0, -1, -1], [-1], [-1], (inputs)
-        # [0, -1, -1], [0], [0] (outputs)
+        # [0, -1, -1], [-1, 1], [-1, -1] (outputs) -->
+        # [0, 1, -1], [-1], [-1], (inputs)
+        # [0, 1, -1], [0, 1], [0, 1] (outputs)
         # begin_norm_axis=1
-        self.attrs['begin_norm_axis'] = 1
+        self.attrs['begin_norm_axis'] = 2
         self.scale_spec.shape = [
-            math.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
+            np.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
         ]
         self.bias_spec.shape = [
-            math.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
+            np.prod(self.x_spec.shape[self.attrs['begin_norm_axis'] :])
         ]
-        self.mean_spec.shape = [
-            math.prod(self.x_spec.shape[: self.attrs['begin_norm_axis']])
+        self.mean_spec.shape = self.x_spec.shape[
+            : self.attrs['begin_norm_axis']
         ]
-        self.var_spec.shape = [
-            math.prod(self.x_spec.shape[: self.attrs['begin_norm_axis']])
-        ]
-        self.out_spec.set_dims_mapping([0, 1, -1])
-        self.mean_spec.set_dims_mapping([-1])
-        self.var_spec.set_dims_mapping([-1])
+        self.var_spec.shape = self.x_spec.shape[: self.attrs['begin_norm_axis']]
+
+        self.out_spec.set_dims_mapping([0, -1, -1])
+        self.mean_spec.set_dims_mapping([-1, 1])
+        self.var_spec.set_dims_mapping([-1, -1])

         result_dist_attrs = self.rule.infer_backward(
             self.x_spec,
@@ -486,12 +478,12 @@ def test_infer_backward(self):
         self.assertEqual(len(infered_input_dist_attrs), 3)
         self.assertEqual(len(infered_output_dist_attrs), 3)
-        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1])
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
         self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1])
         self.assertEqual(infered_input_dist_attrs[2].dims_mapping, [-1])
-        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1])
-        self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [0])
-        self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [0])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1, -1])
+        self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [0, 1])
+        self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [0, 1])


 if __name__ == "__main__":
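A compact way to read the updated infer_backward expectations: the leading begin_norm_axis axes of out, mean, and variance are merged axis by axis (a concrete mesh dim beats -1), and the merged mapping is then written back to x, out, mean, and variance. A toy sketch of that merge for the [-1, -1, -1] / [0, -1] / [-1, 1] case above (illustration only; the real ShardingMergeForTensors also detects conflicting mappings, which is what the NotImplementedError case exercises):

```python
def merge_axis(mappings):
    """Pick the first non-replicated mesh dim; real merging also checks conflicts."""
    for m in mappings:
        if m != -1:
            return m
    return -1


out_map, mean_map, var_map = [-1, -1, -1], [0, -1], [-1, 1]
begin_norm_axis = 2
merged = [
    merge_axis([out_map[i], mean_map[i], var_map[i]])
    for i in range(begin_norm_axis)
]
x_map = merged + [-1] * (len(out_map) - begin_norm_axis)
assert x_map == [0, 1, -1]  # expected x/out mapping in the test
assert merged == [0, 1]     # expected mean/variance mapping in the test
```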
diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc
index eb6d08542b04a..716bb976a267d 100644
--- a/test/cpp/auto_parallel/spmd_rule_test.cc
+++ b/test/cpp/auto_parallel/spmd_rule_test.cc
@@ -362,11 +362,11 @@ TEST(LayerNormSPMDRule, Ctor) {
   check_dim_mapping(infered_dist_attrs.first[1], {-1});
   check_dim_mapping(infered_dist_attrs.first[2], {-1});
   check_dim_mapping(infered_dist_attrs.second[0], {1, -1, -1});
-  check_dim_mapping(infered_dist_attrs.second[1], {1});
-  check_dim_mapping(infered_dist_attrs.second[2], {1});
+  check_dim_mapping(infered_dist_attrs.second[1], {1, -1});
+  check_dim_mapping(infered_dist_attrs.second[2], {1, -1});
   VLOG(4) << "test1 done.";

-  // ijk[1, 0, -1],k[0],k[0] --> ijk[1, -1, -1],z[1],z[1],
+  // ijk[1, 0, -1],k[0],k[0] --> ijk[1, -1, -1],z[1, 0],z[1, 0],
   // begin_norm_axis=2
   begin_norm_axis = 2;
   x_dist_attr.set_dims_mapping({1, 0, -1});
@@ -381,15 +381,15 @@ TEST(LayerNormSPMDRule, Ctor) {
                                    {epsilon, begin_norm_axis});
   infered_dist_attrs = layer_norm_rule.InferForward(ctx);
-  check_dim_mapping(infered_dist_attrs.first[0], {1, -1, -1});
+  check_dim_mapping(infered_dist_attrs.first[0], {1, 0, -1});
   check_dim_mapping(infered_dist_attrs.first[1], {-1});
   check_dim_mapping(infered_dist_attrs.first[2], {-1});
-  check_dim_mapping(infered_dist_attrs.second[0], {1, -1, -1});
-  check_dim_mapping(infered_dist_attrs.second[1], {1});
-  check_dim_mapping(infered_dist_attrs.second[2], {1});
+  check_dim_mapping(infered_dist_attrs.second[0], {1, 0, -1});
+  check_dim_mapping(infered_dist_attrs.second[1], {1, 0});
+  check_dim_mapping(infered_dist_attrs.second[2], {1, 0});
   VLOG(4) << "test2 done.";

-  // ijk[0, -1, -1],y[-1],y[1] --> ijk[0, 1, -1], i[0], i[0], y=jk,
+  // ijk[0, -1, -1],y[-1],y[1] --> ijk[0, -1, -1], i[0], i[0], y=jk,
   // begin_norm_axis=1
   begin_norm_axis = 1;
   x_dist_attr.set_dims_mapping({0, -1, -1});
diff --git a/test/mkldnn/test_layer_norm_mkldnn_op.py b/test/mkldnn/test_layer_norm_mkldnn_op.py
index 4d0d7cce9c1dc..c225469e71cc8 100644
--- a/test/mkldnn/test_layer_norm_mkldnn_op.py
+++ b/test/mkldnn/test_layer_norm_mkldnn_op.py
@@ -49,6 +49,9 @@ def _reference_layer_norm_naive(x, scale, beta, epsilon, begin_norm_axis=1):
         )
     x.shape, output.shape = x_shape, x_shape

+    mean.shape = x_shape[0:begin_norm_axis]
+    var.shape = x_shape[0:begin_norm_axis]
+
     return output, mean, var

diff --git a/test/xpu/test_layer_norm_op_xpu.py b/test/xpu/test_layer_norm_op_xpu.py
index 1b98c4fe081b4..cd2a42f5c1cd1 100644
--- a/test/xpu/test_layer_norm_op_xpu.py
+++ b/test/xpu/test_layer_norm_op_xpu.py
@@ -44,6 +44,8 @@ def ref_layer_norm(x, scale, bias, epsilon, begin_norm_axis=1):
     if bias is not None:
         y = y + bias.reshape([1, right])
     x.shape, y.shape = x_shape, x_shape
+    mean.shape = x_shape[0:begin_norm_axis]
+    variance.shape = x_shape[0:begin_norm_axis]
     return y, mean, variance