Skip to content

Commit

Permalink
Added BF16 to mean op (#37104)
Browse files Browse the repository at this point in the history
* Added BF16 to mean op

* fix for CI

* fix for CI

* fix for CI
  • Loading branch information
arlesniak authored Nov 15, 2021
1 parent 83eef6d commit df7cc45
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 4 deletions.
8 changes: 6 additions & 2 deletions paddle/fluid/operators/mean_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,11 @@ REGISTER_OPERATOR(mean_grad, ops::MeanGradOp,
ops::MeanGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
mean, ops::MeanKernel<paddle::platform::CPUDeviceContext, float>,
ops::MeanKernel<paddle::platform::CPUDeviceContext, double>);
ops::MeanKernel<paddle::platform::CPUDeviceContext, double>,
ops::MeanKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(
mean_grad, ops::MeanGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MeanGradKernel<paddle::platform::CPUDeviceContext, double>);
ops::MeanGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::MeanGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
3 changes: 2 additions & 1 deletion paddle/pten/kernels/cpu/math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ using complex128 = ::paddle::platform::complex<double>;
// using bfloat16 = ::paddle::platform::bfloat16;

PT_REGISTER_KERNEL("sign", CPU, ANY, pten::Sign, float, double) {}
PT_REGISTER_KERNEL("mean", CPU, ANY, pten::Mean, float, double) {}
PT_REGISTER_KERNEL(
"mean", CPU, ANY, pten::Mean, float, double, paddle::platform::bfloat16) {}
PT_REGISTER_KERNEL("scale",
CPU,
ANY,
Expand Down
13 changes: 13 additions & 0 deletions python/paddle/fluid/tests/book/test_fit_a_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,19 @@
import math
import sys
import os
import struct

paddle.enable_static()


def convert_uint16_to_float(in_list):
    """Reinterpret 16-bit patterns as the upper halves of float32 values.

    Each element is treated as the high 16 bits of an IEEE-754 float32
    (the low 16 mantissa bits are zero-filled), which is how the fetched
    ``uint16`` results are turned back into floats in this file.

    Args:
        in_list: Array-like of integers in ``[0, 0xFFFF]`` (scalars and
            nested sequences are accepted; anything ``numpy.asarray`` takes).

    Returns:
        A ``numpy.float32`` ndarray with the same shape as the input.

    Raises:
        TypeError: If the input holds non-integer values.
        ValueError: If any value does not fit in 16 unsigned bits.
    """
    bits = numpy.asarray(in_list)
    if bits.size:
        # Validate up front: the widening cast below would otherwise wrap
        # out-of-range values or truncate floats silently.
        if bits.dtype.kind not in "biu":
            raise TypeError(
                "convert_uint16_to_float expects integer bit patterns, "
                "got dtype %s" % bits.dtype)
        if bits.min() < 0 or bits.max() > 0xFFFF:
            raise ValueError(
                "convert_uint16_to_float expects values in [0, 65535]")

    # astype always copies, so the in-place shift never touches the caller's
    # data.  ravel() yields a 1-D C-ordered array, matching the element order
    # of the original ``.flat`` traversal.
    widened = bits.astype(numpy.uint32).ravel()
    widened <<= 16
    # uint32 and float32 share an item size, so view() is a pure bit
    # reinterpretation independent of host byte order.
    return widened.view(numpy.float32).reshape(bits.shape)


def train(use_cuda, save_dirname, is_local, use_bf16, pure_bf16):
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
Expand Down Expand Up @@ -84,6 +93,8 @@ def train_loop(main_program):
avg_loss_value, = exe.run(main_program,
feed=feeder.feed(data),
fetch_list=[avg_cost])
if avg_loss_value.dtype == numpy.uint16:
avg_loss_value = convert_uint16_to_float(avg_loss_value)
if avg_loss_value[0] < 10.0:
if save_dirname is not None:
paddle.static.save_inference_model(
Expand Down Expand Up @@ -154,6 +165,8 @@ def infer(use_cuda, save_dirname=None, use_bf16=False):
results = exe.run(inference_program,
feed={feed_target_names[0]: numpy.array(test_feat)},
fetch_list=fetch_targets)
if results[0].dtype == numpy.uint16:
results[0] = convert_uint16_to_float(results[0])
print("infer shape: ", results[0].shape)
print("infer results: ", results[0])
print("ground truth: ", test_label)
Expand Down
16 changes: 15 additions & 1 deletion python/paddle/fluid/tests/unittests/test_mean_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import unittest
import numpy as np
from op_test import OpTest
from op_test import OpTest, OpTestTool
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
Expand Down Expand Up @@ -76,6 +76,20 @@ def test_checkout_grad(self):
place, ['X'], 'Out', max_relative_error=0.8)


@OpTestTool.skip_if_not_cpu_bf16()
class TestBF16MeanOp(TestMeanOp):
    """Run the mean-op checks from ``TestMeanOp`` with a uint16 dtype on CPU.

    The class is gated by ``OpTestTool.skip_if_not_cpu_bf16()``.
    NOTE(review): np.uint16 is presumably the carrier dtype for bfloat16
    bit patterns in the OpTest framework -- confirm against op_test.
    """

    def init_dtype_type(self):
        # Override the dtype used by the inherited test setup.
        self.dtype = np.uint16

    def test_check_output(self):
        """Check the forward output on a CPU place in static mode."""
        paddle.enable_static()
        cpu_place = core.CPUPlace()
        self.check_output_with_place(cpu_place)

    def test_checkout_grad(self):
        """Check the gradient of 'Out' w.r.t. 'X' on a CPU place."""
        self.check_grad_with_place(core.CPUPlace(), ['X'], 'Out')


def ref_reduce_mean(x, axis=None, keepdim=False, reduce_all=False):
if isinstance(axis, list):
axis = tuple(axis)
Expand Down

0 comments on commit df7cc45

Please sign in to comment.