Added nearest interp v2 BF16 FWD kernel (#39490)
* added nearest interp v2 bf16

* disabled bilinear interp nhwc test

* added skipping UT for gpu

* added NHWC support

* removed unnecessary statements

* minor change

* CI fix

* added appropriate changes to interpolate_v1

* fix after review

* minor change

* minor change

* revert unwanted deletions

* CI fix
jakpiase authored Feb 24, 2022
1 parent 1abfc8d commit 2ec943a
Showing 4 changed files with 34 additions and 30 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/interpolate_op.cc
@@ -328,7 +328,7 @@ class InterpolateOp : public framework::OperatorWithKernel {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
-auto interp_method = ctx.Attr<std::string>("interp_method");
+const auto& interp_method = ctx.Attr<std::string>("interp_method");
// TODO(danqing): support other interp_method
if (this->CanMKLDNNBeUsed(ctx, data_type) &&
(interp_method == "nearest" || interp_method == "bilinear")) {
2 changes: 1 addition & 1 deletion paddle/fluid/operators/interpolate_v2_op.cc
@@ -414,7 +414,7 @@ class InterpolateV2Op : public framework::OperatorWithKernel {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
-auto interp_method = ctx.Attr<std::string>("interp_method");
+const auto& interp_method = ctx.Attr<std::string>("interp_method");
// TODO(danqing): support other interp_method
if (this->CanMKLDNNBeUsed(ctx, data_type) &&
(interp_method == "nearest" || interp_method == "bilinear")) {
34 changes: 14 additions & 20 deletions paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc
@@ -53,17 +53,13 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
std::vector<int> ComputeOutputShape(
const framework::ExecutionContext& ctx) const {
const auto* x = ctx.Input<Tensor>("X");
-auto in_dims = x->dims();
-const bool is_channel_last = false;  // In mkldnn kernel, always use NCHW
-
-framework::DDim in_dhw_dims;
-if (is_channel_last) {  // NDHWC, NHWC, NWC
-  in_dhw_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
-} else {  // NCDHW, NCHW, NCW
-  in_dhw_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
-}
+const auto& in_dims = x->dims();
+
+const framework::DDim in_dhw_dims =
+    phi::slice_ddim(in_dims, 2, in_dims.size());

std::vector<int> out_dims;
+out_dims.reserve(5);
if (in_dhw_dims.size() == 1) {
out_dims.push_back(ctx.Attr<int>("out_w"));
} else if (in_dhw_dims.size() == 2) {
@@ -125,12 +121,8 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
"out_d, out_h, out_w of Op(interpolate) "
"should be greater than 0."));

-out_dims.insert(out_dims.begin(), in_dims[0]);
-if (is_channel_last) {
-  out_dims.push_back(in_dims[in_dims.size() - 1]);
-} else {
-  out_dims.insert(out_dims.begin() + 1, in_dims[1]);
-}
+const std::vector<int64_t> nc_dims = {in_dims[0], in_dims[1]};
+out_dims.insert(out_dims.begin(), nc_dims.begin(), nc_dims.end());
return out_dims;
}

@@ -143,12 +135,12 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
const auto* x = ctx.Input<Tensor>("X");
auto* z = ctx.Output<Tensor>("Out");

-auto interp_method = ctx.Attr<std::string>("interp_method");
-dnnl::algorithm algo = (interp_method == "nearest")
-    ? dnnl::algorithm::resampling_nearest
-    : dnnl::algorithm::resampling_linear;
+const auto interp_method = ctx.Attr<std::string>("interp_method");
+const dnnl::algorithm algo = (interp_method == "nearest")
+    ? dnnl::algorithm::resampling_nearest
+    : dnnl::algorithm::resampling_linear;

-auto out_dims_vec = ComputeOutputShape(ctx);
+const auto out_dims_vec = ComputeOutputShape(ctx);
framework::DDim dim_out = phi::make_ddim(out_dims_vec);
z->Resize(dim_out);

@@ -162,6 +154,7 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();

resampling_prim->execute(astream, args);
astream.wait();

@@ -184,6 +177,7 @@ REGISTER_OP_KERNEL(bilinear_interp, MKLDNN, ::paddle::platform::CPUPlace,

REGISTER_OP_KERNEL(nearest_interp_v2, MKLDNN, ::paddle::platform::CPUPlace,
ops::InterpolateMKLDNNKernel<float>,
+ops::InterpolateMKLDNNKernel<paddle::platform::bfloat16>,
ops::InterpolateMKLDNNKernel<int8_t>,
ops::InterpolateMKLDNNKernel<uint8_t>);
REGISTER_OP_KERNEL(bilinear_interp_v2, MKLDNN, ::paddle::platform::CPUPlace,
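A note on the ComputeOutputShape change above: the oneDNN kernel now always treats its input as channel-first (NCW/NCHW/NCDHW), so the spatial extent is simply dims[2:] and the batch and channel sizes are prepended to the resolved output dims. A minimal Python sketch of that shape logic (illustrative only, not the kernel code; compute_output_shape is a hypothetical name):

def compute_output_shape(in_dims, out_spatial):
    # in_dims: channel-first shape, e.g. [N, C, H, W]; out_spatial: resolved [out_h, out_w]
    assert len(out_spatial) == len(in_dims) - 2
    assert all(d > 0 for d in out_spatial), \
        "out_d, out_h and out_w of interpolate should be greater than 0"
    n, c = in_dims[0], in_dims[1]
    return [n, c] + list(out_spatial)

print(compute_output_shape([8, 3, 32, 32], [64, 64]))  # [8, 3, 64, 64]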
26 changes: 18 additions & 8 deletions (MKLDNN nearest_interp_v2 Python unit test; file path not shown in the diff)
@@ -16,7 +16,7 @@

import unittest
import numpy as np
-from paddle.fluid.tests.unittests.op_test import OpTest
+from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci


@@ -59,6 +59,7 @@ def nearest_neighbor_interp_mkldnn_np(X,


@skip_check_grad_ci(reason="Haven not implement interpolate grad kernel.")
+@OpTestTool.skip_if_not_cpu_bf16()
class TestNearestInterpV2MKLDNNOp(OpTest):
def init_test_case(self):
pass
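The tests check the op against a NumPy reference (nearest_neighbor_interp_mkldnn_np, named in the hunk header above). A minimal sketch of nearest-neighbor resizing for NCHW data, using a simple floor(index * in/out) mapping; this is illustrative only and not Paddle's exact helper:

import numpy as np

def nearest_resize_nchw(x, out_h, out_w):
    # pick the source row/column for every output position
    n, c, in_h, in_w = x.shape
    rows = (np.arange(out_h) * (in_h / out_h)).astype(np.int64)
    cols = (np.arange(out_w) * (in_w / out_w)).astype(np.int64)
    return x[:, :, rows[:, None], cols[None, :]]

x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
print(nearest_resize_nchw(x, 2, 2).shape)  # (1, 1, 2, 2)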
@@ -84,7 +85,7 @@ def setUp(self):
self.init_test_case()
self.init_data_type()

-if self.dtype == np.float32:
+if self.dtype == np.float32 or self.dtype == np.uint16:
input_np = np.random.random(self.input_shape).astype(self.dtype)
else:
init_low, init_high = (-5, 5) if self.dtype == np.int8 else (0, 10)
@@ -126,6 +127,9 @@ def setUp(self):
if isinstance(self.scale, float):
self.scale = [self.scale]

+if self.dtype == np.uint16:
+    input_np = convert_float_to_uint16(input_np)

self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
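For the bf16 case above, the input is generated as float32 and then packed by convert_float_to_uint16, since OpTest carries bfloat16 tensors as uint16 arrays. The idea is sketched below under the assumption of simple truncation (the real helper may round): bfloat16 is the upper half of the IEEE-754 float32 bit pattern.

import numpy as np

def float32_to_bf16_bits(x):
    # keep the top 16 bits of each float32 value
    x = np.ascontiguousarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

print(float32_to_bf16_bits(np.array([1.0, 0.5, -2.25], dtype=np.float32)))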
Expand Down Expand Up @@ -191,6 +195,10 @@ class TestFp32Case(parent):
def init_data_type(self):
self.dtype = np.float32

+class TestBf16Case(parent):
+def init_data_type(self):
+self.dtype = np.uint16

class TestInt8Case(parent):
def init_data_type(self):
self.dtype = np.int8
@@ -199,12 +207,14 @@ class TestUint8Case(parent):
def init_data_type(self):
self.dtype = np.uint8

-TestFp32Case.__name__ = parent.__name__
-TestInt8Case.__name__ = parent.__name__
-TestUint8Case.__name__ = parent.__name__
-globals()[parent.__name__] = TestFp32Case
-globals()[parent.__name__] = TestInt8Case
-globals()[parent.__name__] = TestUint8Case
TestFp32Case.__name__ = "{0}_{1}".format(parent.__name__, "FP32")
TestBf16Case.__name__ = "{0}_{1}".format(parent.__name__, "BF16")
TestInt8Case.__name__ = "{0}_{1}".format(parent.__name__, "INT8")
TestUint8Case.__name__ = "{0}_{1}".format(parent.__name__, "UINT8")
globals()[TestFp32Case.__name__] = TestFp32Case
globals()[TestBf16Case.__name__] = TestBf16Case
globals()[TestInt8Case.__name__] = TestInt8Case
globals()[TestUint8Case.__name__] = TestUint8Case


create_test_class(TestNearestInterpV2MKLDNNOp)
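The renaming above matters because the generated classes are published through globals(): giving each dtype variant a unique name (e.g. TestNearestInterpV2MKLDNNOp_BF16) gives each one its own entry, so unittest collects all of them, whereas reusing the parent's name would overwrite earlier variants and leave only the last one. A simplified sketch of the pattern (TestBase and the helper below are hypothetical, not the test file itself):

class TestBase:
    def init_data_type(self):
        self.dtype = "float32"

def create_test_class(parent):
    # one subclass per dtype, each registered under a unique name
    for suffix, dtype in [("FP32", "float32"), ("BF16", "uint16"),
                          ("INT8", "int8"), ("UINT8", "uint8")]:
        cls = type("{0}_{1}".format(parent.__name__, suffix), (parent,),
                   {"init_data_type": lambda self, d=dtype: setattr(self, "dtype", d)})
        globals()[cls.__name__] = cls

create_test_class(TestBase)
print(sorted(n for n in globals() if n.startswith("TestBase_")))
# ['TestBase_BF16', 'TestBase_FP32', 'TestBase_INT8', 'TestBase_UINT8']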
