Skip to content

Commit

Permalink
[NPU] support mixed precision input for npu layer norm (PaddlePaddle#…
Browse files Browse the repository at this point in the history
…31847)

* support mixed precision input for npu layer norm

* fix layer_norm npu kernel

Co-authored-by: zhiqiu <[email protected]>
  • Loading branch information
2 people authored and frankwhzhang committed Apr 12, 2021
1 parent 5bcfa8a commit 25aa56b
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 22 deletions.
225 changes: 210 additions & 15 deletions paddle/fluid/operators/layer_norm_op_npu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,36 @@ namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;

using DataLayout = framework::DataLayout;

// Trait mapping a tensor element type T to the parameter types used by the
// layer-norm kernels in this file. The primary template is only declared, so
// using the trait with a type that has no specialization below fails to
// compile.
template <typename T>
class NormDataType;

// Parameter types used when the tensor element type is platform::float16.
template <>
class NormDataType<platform::float16> {
 public:
  // The scaling param type is float for HALF and FLOAT tensors
  using ScalingParamType = const float;
  using BatchNormParamType = float;
};

// Parameter types used when the tensor element type is float.
template <>
class NormDataType<float> {
 public:
  using ScalingParamType = const float;
  using BatchNormParamType = float;
};

// NormDataType<T> already names the class template declared above, so no
// alias template of the same name is declared here: re-declaring a class
// template's name as an alias template in the same scope is ill-formed and
// is rejected by conforming compilers.

// Parameter (scale/bias/mean/variance) element type associated with tensors
// of element type T; float for both specializations above.
template <typename T>
using LayerNormParamType = typename NormDataType<T>::BatchNormParamType;

template <typename T>
class LayerNormNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using U = LayerNormParamType<T>;
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
const auto epsilon = ctx.Attr<float>("epsilon");
const auto* x = ctx.Input<Tensor>("X");
Expand All @@ -43,6 +69,7 @@ class LayerNormNPUKernel : public framework::OpKernel<T> {
for (auto i = begin_norm_axis; i < x_dims.size(); ++i) {
axes.push_back(x_dims[i]);
}

auto place = ctx.GetPlace();
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
Expand Down Expand Up @@ -77,16 +104,93 @@ class LayerNormNPUKernel : public framework::OpKernel<T> {
} else {
const_cast<Tensor*>(bias)->Resize(framework::make_ddim(axes));
}

// cast scale from LayerNormParamType to T if needed
Tensor cast_scale(x->type());
if (x->type() == framework::proto::VarType::FP16 &&
scale->type() == framework::proto::VarType::FP32) {
cast_scale.Resize(scale->dims());
cast_scale.mutable_data<T>(ctx.GetPlace());
auto dst_dtype = ConvertToNpuDtype(x->type());
auto runner_cast_scale =
NpuOpRunner("Cast", {*scale}, {cast_scale},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_scale.Run(stream);
} else {
cast_scale.ShareDataWith(*scale);
}

// cast bias from LayerNormParamType to T if needed
Tensor cast_bias(x->type());
if (x->type() == framework::proto::VarType::FP16 &&
bias->type() == framework::proto::VarType::FP32) {
cast_bias.Resize(bias->dims());
cast_bias.mutable_data<T>(ctx.GetPlace());
auto dst_dtype = ConvertToNpuDtype(x->type());
auto runner_cast_bias =
NpuOpRunner("Cast", {*bias}, {cast_bias},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_bias.Run(stream);
} else {
cast_bias.ShareDataWith(*bias);
}

y->mutable_data<T>(ctx.GetPlace());
mean->mutable_data<T>(ctx.GetPlace());
variance->mutable_data<T>(ctx.GetPlace());

auto runner =
NpuOpRunner("LayerNorm", {*x, *scale, *bias}, {*y, *mean, *variance},
{{"begin_norm_axis", begin_norm_axis},
{"begin_params_axis", begin_norm_axis},
{"epsilon", epsilon}});

// mean should be of U type
Tensor* tmp_mean = mean;
Tensor cast_mean(x->type());
if (x->type() == framework::proto::VarType::FP16 &&
(scale->type() == framework::proto::VarType::FP32 ||
bias->type() == framework::proto::VarType::FP32)) {
cast_mean.Resize(mean->dims());
cast_mean.mutable_data<T>(ctx.GetPlace());
tmp_mean = &cast_mean;
mean->mutable_data<U>(ctx.GetPlace());
} else {
mean->mutable_data<T>(ctx.GetPlace());
}

// same for variance
Tensor* tmp_variance = variance;
Tensor cast_variance(x->type());
if (x->type() == framework::proto::VarType::FP16 &&
(scale->type() == framework::proto::VarType::FP32 ||
bias->type() == framework::proto::VarType::FP32)) {
cast_variance.Resize(variance->dims());
cast_variance.mutable_data<T>(ctx.GetPlace());
tmp_variance = &cast_variance;
variance->mutable_data<U>(ctx.GetPlace());
} else {
variance->mutable_data<T>(ctx.GetPlace());
}

auto runner = NpuOpRunner("LayerNorm", {*x, cast_scale, cast_bias},
{*y, *tmp_mean, *tmp_variance},
{{"begin_norm_axis", begin_norm_axis},
{"begin_params_axis", begin_norm_axis},
{"epsilon", epsilon}});
runner.Run(stream);

// cast back from FP16 to FP32
if (x->type() == framework::proto::VarType::FP16 &&
mean->type() == framework::proto::VarType::FP32) {
auto dst_dtype = ConvertToNpuDtype(mean->type());
auto runner_cast_mean =
NpuOpRunner("Cast", {*tmp_mean}, {*mean},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_mean.Run(stream);
}
// same for variance
if (x->type() == framework::proto::VarType::FP16 &&
variance->type() == framework::proto::VarType::FP32) {
auto dst_dtype = ConvertToNpuDtype(variance->type());
auto runner_cast_variance =
NpuOpRunner("Cast", {*tmp_variance}, {*variance},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_variance.Run(stream);
}

// revert shape of scale and bias
// TODO(zhiqiu): better implementation, use tmp tensor to avoid write input
// tensor.
Expand All @@ -99,6 +203,7 @@ template <typename T>
class LayerNormGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using U = LayerNormParamType<T>;
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
const auto* x = ctx.Input<Tensor>("X");
const auto& x_dims = x->dims();
Expand Down Expand Up @@ -156,25 +261,115 @@ class LayerNormGradNPUKernel : public framework::OpKernel<T> {
const_cast<Tensor*>(scale)->Resize(framework::make_ddim(axes));
}

// cast scale from LayerNormParamType to T if needed
Tensor cast_scale(x->type());
if (x->type() == framework::proto::VarType::FP16 &&
scale->type() == framework::proto::VarType::FP32) {
cast_scale.Resize(scale->dims());
cast_scale.mutable_data<T>(ctx.GetPlace());
auto dst_dtype = ConvertToNpuDtype(x->type());
auto runner_cast_scale =
NpuOpRunner("Cast", {*scale}, {cast_scale},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_scale.Run(stream);
} else {
cast_scale.ShareDataWith(*scale);
}

// cast mean from LayerNormParamType to T if needed
Tensor cast_mean(x->type());
if (x->type() == framework::proto::VarType::FP16 &&
mean->type() == framework::proto::VarType::FP32) {
cast_mean.Resize(mean->dims());
cast_mean.mutable_data<T>(ctx.GetPlace());
auto dst_dtype = ConvertToNpuDtype(x->type());
auto runner_cast_mean =
NpuOpRunner("Cast", {*mean}, {cast_mean},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_mean.Run(stream);
} else {
cast_mean.ShareDataWith(*mean);
}

// cast variance from LayerNormParamType to T if needed
Tensor cast_variance(x->type());
if (x->type() == framework::proto::VarType::FP16 &&
variance->type() == framework::proto::VarType::FP32) {
cast_variance.Resize(variance->dims());
cast_variance.mutable_data<T>(ctx.GetPlace());
auto dst_dtype = ConvertToNpuDtype(x->type());
auto runner_cast_variance =
NpuOpRunner("Cast", {*variance}, {cast_variance},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_variance.Run(stream);
} else {
cast_variance.ShareDataWith(*variance);
}

Tensor dx_(dy->type()), dscale_(dy->type()), dbias_(dy->type());
dx = (dx == nullptr) ? &dx_ : dx;
dscale = (dscale == nullptr) ? &dscale_ : dscale;
dbias = (dbias == nullptr) ? &dbias_ : dbias;

dx->Resize(x->dims());
dx->mutable_data<T>(ctx.GetPlace());

dscale->Resize(framework::make_ddim(axes));
dscale->mutable_data<T>(ctx.GetPlace());

dbias->Resize(framework::make_ddim(axes));
dbias->mutable_data<T>(ctx.GetPlace());

dx->Resize(x->dims());
dx->mutable_data<T>(ctx.GetPlace());
// dscale should be of U type
Tensor* tmp_dscale = dscale;
Tensor cast_dscale(x->type());
if (x->type() == framework::proto::VarType::FP16 &&
(mean->type() == framework::proto::VarType::FP32 ||
variance->type() == framework::proto::VarType::FP32)) {
cast_dscale.Resize(dscale->dims());
cast_dscale.mutable_data<T>(ctx.GetPlace());
tmp_dscale = &cast_dscale;
dscale->mutable_data<U>(ctx.GetPlace());
} else {
dscale->mutable_data<T>(ctx.GetPlace());
}

auto runner =
NpuOpRunner("LayerNormGrad", {*dy, *x, *variance, *mean, *scale},
{*dx, *dscale, *dbias}, {});
// same for dbias
Tensor* tmp_dbias = dbias;
Tensor cast_dbias(x->type());
if (x->type() == framework::proto::VarType::FP16 &&
(mean->type() == framework::proto::VarType::FP32 ||
variance->type() == framework::proto::VarType::FP32)) {
cast_dbias.Resize(dbias->dims());
cast_dbias.mutable_data<T>(ctx.GetPlace());
tmp_dbias = &cast_dbias;
dbias->mutable_data<U>(ctx.GetPlace());
} else {
dbias->mutable_data<T>(ctx.GetPlace());
}

auto runner = NpuOpRunner("LayerNormGrad",
{*dy, *x, cast_variance, cast_mean, cast_scale},
{*dx, *tmp_dscale, *tmp_dbias}, {});
runner.Run(stream);

// cast back from FP16 to FP32
if (x->type() == framework::proto::VarType::FP16 &&
dscale->type() == framework::proto::VarType::FP32) {
auto dst_dtype = ConvertToNpuDtype(dscale->type());
auto runner_cast_dscale =
NpuOpRunner("Cast", {*tmp_dscale}, {*dscale},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_dscale.Run(stream);
}
// same for dbias
if (x->type() == framework::proto::VarType::FP16 &&
dbias->type() == framework::proto::VarType::FP32) {
auto dst_dtype = ConvertToNpuDtype(dbias->type());
auto runner_cast_dbias =
NpuOpRunner("Cast", {*tmp_dbias}, {*dbias},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_dbias.Run(stream);
}

const_cast<Tensor*>(mean)->Resize(mean_dims);
const_cast<Tensor*>(variance)->Resize(mean_dims);
const_cast<Tensor*>(scale)->Resize(framework::make_ddim({right}));
Expand Down
26 changes: 19 additions & 7 deletions python/paddle/fluid/tests/unittests/npu/test_layer_norm_op_npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,13 @@ def set_npu(self):

def init_dtype(self):
    """Configure the float32 case: set the element dtype and the absolute
    tolerance stored on the test instance."""
    self.atol = 1e-4
    self.dtype = np.float32

# Assert that `tensor` and `np_array` are element-wise close according to
# np.allclose with the given `atol` (rtol is left at NumPy's default).
# `msg` is forwarded to assertTrue as the failure message.
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
# Compares np.array(tensor) against np_array without any dtype conversion.
# NOTE(review): this text is a rendered diff without +/- markers, so this
# line is presumably the pre-change assertion that the statement below
# replaces -- confirm against the repository before relying on it.
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
# Same comparison, but `tensor` is first cast to np_array's dtype.
self.assertTrue(
np.allclose(
np.array(tensor).astype(np_array.dtype), np_array, atol=atol),
msg)

def check_forward_backward(self,
shape,
Expand All @@ -72,13 +76,13 @@ def test_with_place(place,
scale_shape = [D]

np.random.seed(123)
x = np.random.random_sample(x_shape).astype(np.float32)
x = np.random.random_sample(x_shape).astype(self.dtype)
scale = np.random.random_sample(scale_shape).astype(
np.float32) if has_scale else None
bias = np.random.random_sample(scale_shape).astype(
np.float32) if has_bias else None
y_grad = (np.random.random_sample(x_shape) *
y_grad_scale).astype(np.float32)
y_grad_scale).astype(self.dtype)

# reference forward & backward
y, mean, variance = _reference_layer_norm_naive(
Expand All @@ -101,7 +105,7 @@ def test_with_place(place,
for name in ground_truth:
block.create_var(
name=name,
dtype='float32',
dtype=self.dtype,
shape=ground_truth[name].shape)
inputs = {"X": block.var('x')}
fetch_list = [
Expand Down Expand Up @@ -152,18 +156,18 @@ def test_with_place(place,
for name in ['x', 'scale', 'bias', 'y@GRAD']
},
fetch_list=fetch_list)
self.__assert_close(y, out[0], "y")
self.__assert_close(y, out[0], "y", self.atol)
self.__assert_close(mean, out[1], "mean")
self.__assert_close(variance, out[2], "variance", 1e-3)
self.__assert_close(x_grad, out[3], "x_grad", 1e-2)
if has_scale:
self.__assert_close(scale_grad,
out[fetch_list.index('scale@GRAD')],
"scale_grad", 1e-3)
"scale_grad", 1e-2)
if has_bias:
self.__assert_close(bias_grad,
out[fetch_list.index('bias@GRAD')],
"bias_grad")
"bias_grad", self.atol)

test_with_place(self.place, shape, begin_norm_axis)

Expand All @@ -187,5 +191,13 @@ def test_check_forward_backward_with_scale_and_bias(self):
self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=3)


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestLayerNormOpFP16(TestLayerNormOp):
    """float16 variant of TestLayerNormOp; only the dtype setup is overridden."""

    def init_dtype(self):
        """Use float16 data together with an absolute tolerance of 1e-2."""
        self.atol = 1e-2
        self.dtype = np.float16


if __name__ == '__main__':
unittest.main()

0 comments on commit 25aa56b

Please sign in to comment.