diff --git a/paddle/fluid/operators/layer_norm_op_npu.cc b/paddle/fluid/operators/layer_norm_op_npu.cc index 447eda1a8a4c4..95549319cd209 100644 --- a/paddle/fluid/operators/layer_norm_op_npu.cc +++ b/paddle/fluid/operators/layer_norm_op_npu.cc @@ -21,10 +21,36 @@ namespace operators { using Tensor = framework::Tensor; using DDim = framework::DDim; +using DataLayout = framework::DataLayout; + +template +class NormDataType; + +template <> +class NormDataType { + public: + // The scaling param type is float for HALF and FLOAT tensors + using ScalingParamType = const float; + using BatchNormParamType = float; +}; + +template <> +class NormDataType { + public: + using ScalingParamType = const float; + using BatchNormParamType = float; +}; + +template +using NormDataType = NormDataType; +template +using LayerNormParamType = typename NormDataType::BatchNormParamType; + template class LayerNormNPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + using U = LayerNormParamType; const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); const auto epsilon = ctx.Attr("epsilon"); const auto* x = ctx.Input("X"); @@ -43,6 +69,7 @@ class LayerNormNPUKernel : public framework::OpKernel { for (auto i = begin_norm_axis; i < x_dims.size(); ++i) { axes.push_back(x_dims[i]); } + auto place = ctx.GetPlace(); auto stream = ctx.template device_context() @@ -77,16 +104,93 @@ class LayerNormNPUKernel : public framework::OpKernel { } else { const_cast(bias)->Resize(framework::make_ddim(axes)); } + + // cast scale from LayerNormParamType to T if needed + Tensor cast_scale(x->type()); + if (x->type() == framework::proto::VarType::FP16 && + scale->type() == framework::proto::VarType::FP32) { + cast_scale.Resize(scale->dims()); + cast_scale.mutable_data(ctx.GetPlace()); + auto dst_dtype = ConvertToNpuDtype(x->type()); + auto runner_cast_scale = + NpuOpRunner("Cast", {*scale}, {cast_scale}, + {{"dst_type", static_cast(dst_dtype)}}); + runner_cast_scale.Run(stream); + } else { + cast_scale.ShareDataWith(*scale); + } + + // cast bias from LayerNormParamType to T if needed + Tensor cast_bias(x->type()); + if (x->type() == framework::proto::VarType::FP16 && + bias->type() == framework::proto::VarType::FP32) { + cast_bias.Resize(bias->dims()); + cast_bias.mutable_data(ctx.GetPlace()); + auto dst_dtype = ConvertToNpuDtype(x->type()); + auto runner_cast_bias = + NpuOpRunner("Cast", {*bias}, {cast_bias}, + {{"dst_type", static_cast(dst_dtype)}}); + runner_cast_bias.Run(stream); + } else { + cast_bias.ShareDataWith(*bias); + } + y->mutable_data(ctx.GetPlace()); - mean->mutable_data(ctx.GetPlace()); - variance->mutable_data(ctx.GetPlace()); - - auto runner = - NpuOpRunner("LayerNorm", {*x, *scale, *bias}, {*y, *mean, *variance}, - {{"begin_norm_axis", begin_norm_axis}, - {"begin_params_axis", begin_norm_axis}, - {"epsilon", epsilon}}); + + // mean should be of U type + Tensor* tmp_mean = mean; + Tensor cast_mean(x->type()); + if (x->type() == framework::proto::VarType::FP16 && + (scale->type() == framework::proto::VarType::FP32 || + bias->type() == framework::proto::VarType::FP32)) { + cast_mean.Resize(mean->dims()); + cast_mean.mutable_data(ctx.GetPlace()); + tmp_mean = &cast_mean; + mean->mutable_data(ctx.GetPlace()); + } else { + mean->mutable_data(ctx.GetPlace()); + } + + // same for variance + Tensor* tmp_variance = variance; + Tensor cast_variance(x->type()); + if (x->type() == framework::proto::VarType::FP16 && + (scale->type() == framework::proto::VarType::FP32 || + bias->type() == framework::proto::VarType::FP32)) { + cast_variance.Resize(variance->dims()); + cast_variance.mutable_data(ctx.GetPlace()); + tmp_variance = &cast_variance; + variance->mutable_data(ctx.GetPlace()); + } else { + variance->mutable_data(ctx.GetPlace()); + } + + auto runner = NpuOpRunner("LayerNorm", {*x, cast_scale, cast_bias}, + {*y, *tmp_mean, *tmp_variance}, + {{"begin_norm_axis", begin_norm_axis}, + {"begin_params_axis", begin_norm_axis}, + {"epsilon", epsilon}}); runner.Run(stream); + + // cast back from FP16 to FP32 + if (x->type() == framework::proto::VarType::FP16 && + mean->type() == framework::proto::VarType::FP32) { + auto dst_dtype = ConvertToNpuDtype(mean->type()); + auto runner_cast_mean = + NpuOpRunner("Cast", {*tmp_mean}, {*mean}, + {{"dst_type", static_cast(dst_dtype)}}); + runner_cast_mean.Run(stream); + } + // same for variance + if (x->type() == framework::proto::VarType::FP16 && + variance->type() == framework::proto::VarType::FP32) { + auto dst_dtype = ConvertToNpuDtype(variance->type()); + auto runner_cast_variance = + NpuOpRunner("Cast", {*tmp_variance}, {*variance}, + {{"dst_type", static_cast(dst_dtype)}}); + runner_cast_variance.Run(stream); + } + // revert shape of scale and bias // TODO(zhiqiu): better implementation, use tmp tensor to avoid write input // tensor. @@ -99,6 +203,7 @@ template class LayerNormGradNPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + using U = LayerNormParamType; const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); const auto* x = ctx.Input("X"); const auto& x_dims = x->dims(); @@ -156,25 +261,115 @@ class LayerNormGradNPUKernel : public framework::OpKernel { const_cast(scale)->Resize(framework::make_ddim(axes)); } + // cast scale from LayerNormParamType to T if needed + Tensor cast_scale(x->type()); + if (x->type() == framework::proto::VarType::FP16 && + scale->type() == framework::proto::VarType::FP32) { + cast_scale.Resize(scale->dims()); + cast_scale.mutable_data(ctx.GetPlace()); + auto dst_dtype = ConvertToNpuDtype(x->type()); + auto runner_cast_scale = + NpuOpRunner("Cast", {*scale}, {cast_scale}, + {{"dst_type", static_cast(dst_dtype)}}); + runner_cast_scale.Run(stream); + } else { + cast_scale.ShareDataWith(*scale); + } + + // cast mean from LayerNormParamType to T if needed + Tensor cast_mean(x->type()); + if (x->type() == framework::proto::VarType::FP16 && + mean->type() == framework::proto::VarType::FP32) { + cast_mean.Resize(mean->dims()); + cast_mean.mutable_data(ctx.GetPlace()); + auto dst_dtype = ConvertToNpuDtype(x->type()); + auto runner_cast_mean = + NpuOpRunner("Cast", {*mean}, {cast_mean}, + {{"dst_type", static_cast(dst_dtype)}}); + runner_cast_mean.Run(stream); + } else { + cast_mean.ShareDataWith(*mean); + } + + // cast variance from LayerNormParamType to T if needed + Tensor cast_variance(x->type()); + if (x->type() == framework::proto::VarType::FP16 && + variance->type() == framework::proto::VarType::FP32) { + cast_variance.Resize(variance->dims()); + cast_variance.mutable_data(ctx.GetPlace()); + auto dst_dtype = ConvertToNpuDtype(x->type()); + auto runner_cast_variance = + NpuOpRunner("Cast", {*variance}, {cast_variance}, + {{"dst_type", static_cast(dst_dtype)}}); + runner_cast_variance.Run(stream); + } else { + cast_variance.ShareDataWith(*variance); + } + Tensor dx_(dy->type()), dscale_(dy->type()), dbias_(dy->type()); dx = (dx == nullptr) ? &dx_ : dx; dscale = (dscale == nullptr) ? &dscale_ : dscale; dbias = (dbias == nullptr) ? &dbias_ : dbias; + dx->Resize(x->dims()); + dx->mutable_data(ctx.GetPlace()); + dscale->Resize(framework::make_ddim(axes)); - dscale->mutable_data(ctx.GetPlace()); dbias->Resize(framework::make_ddim(axes)); - dbias->mutable_data(ctx.GetPlace()); - dx->Resize(x->dims()); - dx->mutable_data(ctx.GetPlace()); + // dscale should be of U type + Tensor* tmp_dscale = dscale; + Tensor cast_dscale(x->type()); + if (x->type() == framework::proto::VarType::FP16 && + (mean->type() == framework::proto::VarType::FP32 || + variance->type() == framework::proto::VarType::FP32)) { + cast_dscale.Resize(dscale->dims()); + cast_dscale.mutable_data(ctx.GetPlace()); + tmp_dscale = &cast_dscale; + dscale->mutable_data(ctx.GetPlace()); + } else { + dscale->mutable_data(ctx.GetPlace()); + } - auto runner = - NpuOpRunner("LayerNormGrad", {*dy, *x, *variance, *mean, *scale}, - {*dx, *dscale, *dbias}, {}); + // same for dbias + Tensor* tmp_dbias = dbias; + Tensor cast_dbias(x->type()); + if (x->type() == framework::proto::VarType::FP16 && + (mean->type() == framework::proto::VarType::FP32 || + variance->type() == framework::proto::VarType::FP32)) { + cast_dbias.Resize(dbias->dims()); + cast_dbias.mutable_data(ctx.GetPlace()); + tmp_dbias = &cast_dbias; + dbias->mutable_data(ctx.GetPlace()); + } else { + dbias->mutable_data(ctx.GetPlace()); + } + + auto runner = NpuOpRunner("LayerNormGrad", + {*dy, *x, cast_variance, cast_mean, cast_scale}, + {*dx, *tmp_dscale, *tmp_dbias}, {}); runner.Run(stream); + // cast back from FP16 to FP32 + if (x->type() == framework::proto::VarType::FP16 && + dscale->type() == framework::proto::VarType::FP32) { + auto dst_dtype = ConvertToNpuDtype(dscale->type()); + auto runner_cast_dscale = + NpuOpRunner("Cast", {*tmp_dscale}, {*dscale}, + {{"dst_type", static_cast(dst_dtype)}}); + runner_cast_dscale.Run(stream); + } + // same for dbias + if (x->type() == framework::proto::VarType::FP16 && + dbias->type() == framework::proto::VarType::FP32) { + auto dst_dtype = ConvertToNpuDtype(dbias->type()); + auto runner_cast_dbias = + NpuOpRunner("Cast", {*tmp_dbias}, {*dbias}, + {{"dst_type", static_cast(dst_dtype)}}); + runner_cast_dbias.Run(stream); + } + const_cast(mean)->Resize(mean_dims); const_cast(variance)->Resize(mean_dims); const_cast(scale)->Resize(framework::make_ddim({right})); diff --git a/python/paddle/fluid/tests/unittests/npu/test_layer_norm_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_layer_norm_op_npu.py index 243f1e25e7877..d447dfb8d4d03 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_layer_norm_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_layer_norm_op_npu.py @@ -50,9 +50,13 @@ def set_npu(self): def init_dtype(self): self.dtype = np.float32 + self.atol = 1e-4 def __assert_close(self, tensor, np_array, msg, atol=1e-4): - self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) + self.assertTrue( + np.allclose( + np.array(tensor).astype(np_array.dtype), np_array, atol=atol), + msg) def check_forward_backward(self, shape, @@ -72,13 +76,13 @@ def test_with_place(place, scale_shape = [D] np.random.seed(123) - x = np.random.random_sample(x_shape).astype(np.float32) + x = np.random.random_sample(x_shape).astype(self.dtype) scale = np.random.random_sample(scale_shape).astype( np.float32) if has_scale else None bias = np.random.random_sample(scale_shape).astype( np.float32) if has_bias else None y_grad = (np.random.random_sample(x_shape) * - y_grad_scale).astype(np.float32) + y_grad_scale).astype(self.dtype) # reference forward & backward y, mean, variance = _reference_layer_norm_naive( @@ -101,7 +105,7 @@ def test_with_place(place, for name in ground_truth: block.create_var( name=name, - dtype='float32', + dtype=self.dtype, shape=ground_truth[name].shape) inputs = {"X": block.var('x')} fetch_list = [ @@ -152,18 +156,18 @@ def test_with_place(place, for name in ['x', 'scale', 'bias', 'y@GRAD'] }, fetch_list=fetch_list) - self.__assert_close(y, out[0], "y") + self.__assert_close(y, out[0], "y", self.atol) self.__assert_close(mean, out[1], "mean") self.__assert_close(variance, out[2], "variance", 1e-3) self.__assert_close(x_grad, out[3], "x_grad", 1e-2) if has_scale: self.__assert_close(scale_grad, out[fetch_list.index('scale@GRAD')], - "scale_grad", 1e-3) + "scale_grad", 1e-2) if has_bias: self.__assert_close(bias_grad, out[fetch_list.index('bias@GRAD')], - "bias_grad") + "bias_grad", self.atol) test_with_place(self.place, shape, begin_norm_axis) @@ -187,5 +191,13 @@ def test_check_forward_backward_with_scale_and_bias(self): self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=3) +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestLayerNormOpFP16(TestLayerNormOp): + def init_dtype(self): + self.dtype = np.float16 + self.atol = 1e-2 + + if __name__ == '__main__': unittest.main()