Added nearest interp v2 BF16 FWD kernel #39490

Merged: 15 commits, Feb 24, 2022
2 changes: 1 addition & 1 deletion paddle/fluid/operators/interpolate_op.cc
@@ -328,7 +328,7 @@ class InterpolateOp : public framework::OperatorWithKernel {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
auto interp_method = ctx.Attr<std::string>("interp_method");
const auto& interp_method = ctx.Attr<std::string>("interp_method");
// TODO(danqing): support other interp_method
if (this->CanMKLDNNBeUsed(ctx, data_type) &&
(interp_method == "nearest" || interp_method == "bilinear")) {
2 changes: 1 addition & 1 deletion paddle/fluid/operators/interpolate_v2_op.cc
@@ -414,7 +414,7 @@ class InterpolateV2Op : public framework::OperatorWithKernel {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
auto interp_method = ctx.Attr<std::string>("interp_method");
const auto& interp_method = ctx.Attr<std::string>("interp_method");
// TODO(danqing): support other interp_method
if (this->CanMKLDNNBeUsed(ctx, data_type) &&
(interp_method == "nearest" || interp_method == "bilinear")) {
34 changes: 14 additions & 20 deletions paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc
@@ -53,17 +53,13 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
std::vector<int> ComputeOutputShape(
const framework::ExecutionContext& ctx) const {
const auto* x = ctx.Input<Tensor>("X");
auto in_dims = x->dims();
const bool is_channel_last = false; // In mkldnn kernel, always use NCHW

framework::DDim in_dhw_dims;
if (is_channel_last) { // NDHWC, NHWC, NWC
in_dhw_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
} else { // NCDHW, NCHW, NCW
in_dhw_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
}
const auto& in_dims = x->dims();

const framework::DDim in_dhw_dims =
phi::slice_ddim(in_dims, 2, in_dims.size());

std::vector<int> out_dims;
out_dims.reserve(5);
if (in_dhw_dims.size() == 1) {
out_dims.push_back(ctx.Attr<int>("out_w"));
} else if (in_dhw_dims.size() == 2) {
@@ -125,12 +121,8 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
"out_d, out_h, out_w of Op(interpolate) "
"should be greater than 0."));

out_dims.insert(out_dims.begin(), in_dims[0]);
if (is_channel_last) {
out_dims.push_back(in_dims[in_dims.size() - 1]);
} else {
out_dims.insert(out_dims.begin() + 1, in_dims[1]);
}
const std::vector<int64_t> nc_dims = {in_dims[0], in_dims[1]};
out_dims.insert(out_dims.begin(), nc_dims.begin(), nc_dims.end());
return out_dims;
}
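A note on the hunk above: with the dead is_channel_last branch gone, the kernel always treats the input as channel-first, takes the spatial dims from index 2 onward, and prepends [N, C] to the computed sizes. A rough sketch of that shape logic in plain Python (names are made up, only the fixed out_h/out_w case is shown, and the scale/OutSize paths are omitted):

def compute_output_shape_nchw(in_shape, out_hw):
    # in_shape is (N, C, H, W); spatial dims start at index 2 (channel-first),
    # and [N, C] is prepended to the requested output sizes.
    n, c = int(in_shape[0]), int(in_shape[1])
    return [n, c] + [int(v) for v in out_hw]

assert compute_output_shape_nchw((2, 3, 8, 8), (16, 16)) == [2, 3, 16, 16]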

@@ -143,12 +135,12 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
const auto* x = ctx.Input<Tensor>("X");
auto* z = ctx.Output<Tensor>("Out");

auto interp_method = ctx.Attr<std::string>("interp_method");
dnnl::algorithm algo = (interp_method == "nearest")
? dnnl::algorithm::resampling_nearest
: dnnl::algorithm::resampling_linear;
const auto interp_method = ctx.Attr<std::string>("interp_method");
const dnnl::algorithm algo = (interp_method == "nearest")
? dnnl::algorithm::resampling_nearest
: dnnl::algorithm::resampling_linear;

auto out_dims_vec = ComputeOutputShape(ctx);
const auto out_dims_vec = ComputeOutputShape(ctx);
framework::DDim dim_out = phi::make_ddim(out_dims_vec);
z->Resize(dim_out);

@@ -162,6 +154,7 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();

resampling_prim->execute(astream, args);
astream.wait();

@@ -184,6 +177,7 @@ REGISTER_OP_KERNEL(bilinear_interp, MKLDNN, ::paddle::platform::CPUPlace,

REGISTER_OP_KERNEL(nearest_interp_v2, MKLDNN, ::paddle::platform::CPUPlace,
ops::InterpolateMKLDNNKernel<float>,
ops::InterpolateMKLDNNKernel<paddle::platform::bfloat16>,
ops::InterpolateMKLDNNKernel<int8_t>,
ops::InterpolateMKLDNNKernel<uint8_t>);
REGISTER_OP_KERNEL(bilinear_interp_v2, MKLDNN, ::paddle::platform::CPUPlace,
@@ -16,7 +16,7 @@

import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci


@@ -59,6 +59,7 @@ def nearest_neighbor_interp_mkldnn_np(X,


@skip_check_grad_ci(reason="Haven not implement interpolate grad kernel.")
@OpTestTool.skip_if_not_cpu_bf16()
class TestNearestInterpV2MKLDNNOp(OpTest):
def init_test_case(self):
pass
@@ -84,7 +85,7 @@ def setUp(self):
self.init_test_case()
self.init_data_type()

if self.dtype == np.float32:
if self.dtype == np.float32 or self.dtype == np.uint16:
input_np = np.random.random(self.input_shape).astype(self.dtype)
else:
init_low, init_high = (-5, 5) if self.dtype == np.int8 else (0, 10)
@@ -126,6 +127,9 @@ def setUp(self):
if isinstance(self.scale, float):
self.scale = [self.scale]

if self.dtype == np.uint16:
input_np = convert_float_to_uint16(input_np)

self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
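For the BF16 variant, the hunks above build the input in float32 and then store it as uint16 bit patterns via convert_float_to_uint16. As a rough illustration only (a truncating sketch with a made-up name; the real OpTest helper may also round), bfloat16 is simply the upper half of an IEEE float32:

import numpy as np

def float32_to_bf16_bits(x):
    # Reinterpret float32 as uint32 and keep the upper 16 bits, which is the
    # bfloat16 layout (1 sign, 8 exponent, 7 mantissa bits).
    x = np.ascontiguousarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

# 1.0f is 0x3F800000, so its bfloat16 bit pattern is 0x3F80.
assert float32_to_bf16_bits(np.array([1.0], dtype=np.float32))[0] == 0x3F80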
@@ -191,6 +195,10 @@ class TestFp32Case(parent):
def init_data_type(self):
self.dtype = np.float32

class TestBf16Case(parent):
def init_data_type(self):
self.dtype = np.uint16

class TestInt8Case(parent):
def init_data_type(self):
self.dtype = np.int8
@@ -199,12 +207,14 @@ class TestUint8Case(parent):
def init_data_type(self):
self.dtype = np.uint8

TestFp32Case.__name__ = parent.__name__
TestInt8Case.__name__ = parent.__name__
TestUint8Case.__name__ = parent.__name__
globals()[parent.__name__] = TestFp32Case
globals()[parent.__name__] = TestInt8Case
globals()[parent.__name__] = TestUint8Case
TestFp32Case.__name__ = "{0}_{1}".format(parent.__name__, "FP32")
TestBf16Case.__name__ = "{0}_{1}".format(parent.__name__, "BF16")
TestInt8Case.__name__ = "{0}_{1}".format(parent.__name__, "INT8")
TestUint8Case.__name__ = "{0}_{1}".format(parent.__name__, "UINT8")
globals()[TestFp32Case.__name__] = TestFp32Case
globals()[TestBf16Case.__name__] = TestBf16Case
globals()[TestInt8Case.__name__] = TestInt8Case
globals()[TestUint8Case.__name__] = TestUint8Case


create_test_class(TestNearestInterpV2MKLDNNOp)
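create_test_class derives one subclass per data type, gives each a unique __name__, and registers it in globals() so unittest discovery can find it; the previous code reused parent.__name__ for every variant, so later assignments overwrote the earlier test classes and in effect only the last variant was registered. A stripped-down sketch of the same pattern, with made-up names:

import unittest
import numpy as np

def create_dtype_variants(parent):
    # Minimal, hypothetical version of the create_test_class pattern above.
    for tag, dtype in (("FP32", np.float32), ("BF16", np.uint16)):
        variant = type("{0}_{1}".format(parent.__name__, tag), (parent,),
                       {"dtype": dtype})
        # Register under the unique name; reusing parent.__name__ here would
        # overwrite earlier variants, leaving only the last one to run.
        globals()[variant.__name__] = variant

class MyInterpCase(unittest.TestCase):
    dtype = np.float32

    def test_dtype(self):
        self.assertIn(self.dtype, (np.float32, np.uint16))

create_dtype_variants(MyInterpCase)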