13 changes: 13 additions & 0 deletions python/tvm/contrib/dnnl.py
@@ -113,6 +113,17 @@ def dnnl_conv2d(
     else:
         dilation_h, dilation_w = dilation
 
+    pre_cast = False
+    post_cast = False
+    if src.dtype == "float32":
+        pre_cast = True
+    elif src.dtype == "bfloat16":
+        pre_cast = False
+    if out_dtype == "float32":
+        post_cast = True
+    elif out_dtype == "bfloat16":
+        post_cast = False
+
     if channel_last:
         batch, in_height, in_width, _ = src.shape
         kernel_h, kernel_w, _, num_filter = weights.shape
@@ -150,6 +161,8 @@ def dnnl_conv2d(
             stride[1],
             groups,
             channel_last,
+            pre_cast,
+            post_cast,
         ),
         name="C",
         dtype=out_dtype,
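[Reviewer note, not part of the patch] The two flags encode how the wrapper reads the dtypes: pre_cast is set when the source arrives as float32 (so it must be cast down to bf16 on the way in), post_cast when a float32 output is requested (so the result must be cast back up on the way out). A minimal sketch of that mapping, assuming only the two dtypes the patch handles; cast_flags is a hypothetical name, not a function in the codebase:

def cast_flags(src_dtype, out_dtype):
    # f32 in -> cast down on entry; f32 out -> cast back up on exit. Assumes only
    # "float32" and "bfloat16" reach this point, matching the branches in dnnl_conv2d.
    return src_dtype == "float32", out_dtype == "float32"

assert cast_flags("float32", "float32") == (True, True)      # legacy all-f32 path
assert cast_flags("bfloat16", "bfloat16") == (False, False)  # new end-to-end bf16 path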
51 changes: 32 additions & 19 deletions src/runtime/contrib/dnnl/dnnl.cc
@@ -81,7 +81,7 @@ inline void read_from_dnnl_memory(void* handle, const memory& mem) {
 void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, int p_N_, int p_C_,
                         int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph0_, int p_Pw0_, int p_Ph1_,
                         int p_Pw1_, int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_, primitive_attr attr,
-                        bool channel_last) {
+                        bool channel_last, bool pre_cast, bool post_cast) {
   using tag = memory::format_tag;
   using dt = memory::data_type;
   engine eng(engine::kind::cpu, 0);
@@ -98,20 +98,31 @@ void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, in
   memory::dims conv2d_padding1 = {p_Ph1_, p_Pw1_};
 
   auto user_src_memory =
-      memory({{conv2d_src_tz}, dt::f32, channel_last ? tag::nhwc : tag::nchw}, eng, data);
-  auto user_weights_memory =
-      memory({{conv2d_weights_tz}, dt::f32, channel_last ? tag::hwio : tag::oihw}, eng, weights);
+      memory({{conv2d_src_tz}, pre_cast ? dt::f32 : dt::bf16, channel_last ? tag::nhwc : tag::nchw},
+             eng, data);
+  auto user_weights_memory = memory({{conv2d_weights_tz},
+                                     (pre_cast && post_cast) ? dt::f32 : dt::bf16,
+                                     channel_last ? tag::hwio : tag::oihw},
+                                    eng, weights);
   if (p_G_ > 1)
-    user_weights_memory = memory(
-        {{conv2d_weights_tz}, dt::f32, channel_last ? tag::ghwio : tag::goihw}, eng, weights);
-  auto conv2d_user_bias_memory = memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, bias);
-  auto user_dst_memory =
-      memory({{conv2d_dst_tz}, dt::f32, channel_last ? tag::nhwc : tag::nchw}, eng, out);
-
-  auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any);
-  auto conv2d_bias_md = memory::desc({conv2d_bias_tz}, dt::f32, tag::any);
-  auto conv2d_weights_md = memory::desc({conv2d_weights_tz}, dt::f32, tag::any);
-  auto conv2d_dst_md = memory::desc({conv2d_dst_tz}, dt::f32, tag::any);
+    user_weights_memory = memory({{conv2d_weights_tz},
+                                  (pre_cast && post_cast) ? dt::f32 : dt::bf16,
+                                  channel_last ? tag::ghwio : tag::goihw},
+                                 eng, weights);
+  auto conv2d_user_bias_memory =
+      memory({{conv2d_bias_tz}, (pre_cast && post_cast) ? dt::f32 : dt::bf16, tag::x}, eng, bias);
+  auto user_dst_memory = memory(
+      {{conv2d_dst_tz}, post_cast ? dt::f32 : dt::bf16, channel_last ? tag::nhwc : tag::nchw}, eng,
+      out);
+
+  auto conv2d_src_md =
+      memory::desc({conv2d_src_tz}, (pre_cast && post_cast) ? dt::f32 : dt::bf16, tag::any);
+  auto conv2d_bias_md =
+      memory::desc({conv2d_bias_tz}, (pre_cast && post_cast) ? dt::f32 : dt::bf16, tag::any);
+  auto conv2d_weights_md =
+      memory::desc({conv2d_weights_tz}, (pre_cast && post_cast) ? dt::f32 : dt::bf16, tag::any);
+  auto conv2d_dst_md =
+      memory::desc({conv2d_dst_tz}, (pre_cast && post_cast) ? dt::f32 : dt::bf16, tag::any);
 
   auto conv2d_desc = convolution_forward::desc(
       prop_kind::forward_inference, algorithm::convolution_direct, conv2d_src_md, conv2d_weights_md,
@@ -161,8 +172,8 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, i
   primitive_attr attr;
   std::vector<float> bias(p_O_, 0);
   return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_,
-                            p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr,
-                            false);
+                            p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr, false,
+                            true, true);
 }
 
 primitive_attr create_attr_with_relu_post_op() {
@@ -182,7 +193,7 @@ extern "C" void dnnl_fused_conv2d_relu(float* data, float* weights, float* out,
   std::vector<float> bias(p_O_, 0);
   return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_,
                             p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_,
-                            create_attr_with_relu_post_op(), false);
+                            create_attr_with_relu_post_op(), false, true, true);
 }
 
 extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias, float* out,
@@ -192,7 +203,7 @@ extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float*
                                             int p_Sw_) {
   return dnnl_conv2d_common(data, weights, bias, out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, p_Ph0_,
                             p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_,
-                            create_attr_with_relu_post_op(), false);
+                            create_attr_with_relu_post_op(), false, true, true);
 }
 
 extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, int p_O_) {
@@ -345,6 +356,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.dnnl.conv2d").set_body([](TVMArgs args, TVMRetV
   int p_Ph0_ = args[3], p_Pw0_ = args[4], p_Ph1_ = args[5], p_Pw1_ = args[6], p_Sh_ = args[7],
       p_Sw_ = args[8], p_G_ = args[9];
   bool channel_last = args[10];
+  bool pre_cast = args[11];
+  bool post_cast = args[12];
 
   int p_N_ = input->shape[0], p_C_ = input->shape[1], p_H_ = input->shape[2],
       p_W_ = input->shape[3], p_O_ = output->shape[1], p_Kh_ = weights->shape[2],
@@ -365,7 +378,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.dnnl.conv2d").set_body([](TVMArgs args, TVMRetV
   return dnnl_conv2d_common(static_cast<float*>(input->data), static_cast<float*>(weights->data),
                             bias.data(), static_cast<float*>(output->data), p_N_, p_C_, p_H_, p_W_,
                             p_O_, p_G_, p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_,
-                            attr, channel_last);
+                            attr, channel_last, pre_cast, post_cast);
 });
 
 }  // namespace contrib
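[Reviewer note, not part of the patch] The selection logic in dnnl_conv2d_common is easy to misread: src follows pre_cast, dst follows post_cast, but weights, bias, and all four memory descriptors stay f32 only when both flags are set; any bf16 endpoint makes the primitive itself compute in bf16. The extern "C" wrappers and the f32 call sites keep their old behaviour by passing true, true. A hypothetical model of the choice, with "f32"/"bf16" standing in for dnnl::memory::data_type::{f32, bf16}:

def memory_dtypes(pre_cast, post_cast):
    # Illustrative sketch only; mirrors the ternaries in dnnl_conv2d_common above.
    both_f32 = pre_cast and post_cast
    return {
        "src": "f32" if pre_cast else "bf16",      # user source buffer
        "weights": "f32" if both_f32 else "bf16",  # same rule covers bias and the
        "bias": "f32" if both_f32 else "bf16",     # conv2d_*_md descriptors
        "dst": "f32" if post_cast else "bf16",     # user destination buffer
    }

assert memory_dtypes(True, True)["weights"] == "f32"    # pure f32 call sites
assert memory_dtypes(False, False)["dst"] == "bf16"     # bf16 end to end
assert memory_dtypes(True, False)["weights"] == "bf16"  # mixed: primitive runs in bf16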
140 changes: 88 additions & 52 deletions tests/python/relay/test_op_level2.py
@@ -2003,6 +2003,22 @@ def test_conv2d_rocm_sdot4():
     np.testing.assert_equal(out, ref)
 
 
+def np_float2tvm_bf16(arr):
+    """Convert a numpy array of float to a TVM array
+    of bf16"""
+    orig = arr.view("<u4")
+    bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
+    nparr = np.right_shift(orig + bias, 16).astype("uint16")
+    return tvm.nd.empty(nparr.shape, "bfloat16").copyfrom(nparr)
+
+
+def np_bf162np_float(arr):
+    """Convert a numpy array of bf16 (uint16) to a numpy array
+    of float"""
+    u32 = np.left_shift(arr.astype("uint32"), 16)
+    return u32.view("<f4")
+
+
 @tvm.testing.requires_x86
 def test_conv2d_nchw_dnnl():
     if not tvm.get_global_func("tvm.contrib.dnnl.conv2d", allow_missing=True):
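[Reviewer note, not part of the test file] The two helpers above implement the usual bf16 bit trick: keep the top 16 bits of the IEEE float32 pattern, rounding to nearest even on the way down. A quick round-trip check; it assumes NDArray.numpy() exposes a "bfloat16" array as its uint16 payload, which is what the tests below also rely on:

import numpy as np
import tvm

x = np.array([1.0, 3.14159265, 1e-3], dtype="float32")
bf16 = np_float2tvm_bf16(x)            # TVM NDArray with dtype "bfloat16"
back = np_bf162np_float(bf16.numpy())  # uint16 payload -> float32
# bf16 keeps about 8 effective bits of mantissa, so the relative error of a
# round trip stays under roughly 0.4%.
np.testing.assert_allclose(back, x, rtol=1e-2)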
@@ -2016,39 +2032,49 @@ def test_conv2d_nchw_dnnl():
     padding = (1, 1)
     strides = (1, 1)
 
-    data = relay.var("data", shape=d_shape, dtype="float32")
-    weight = relay.var("weight", shape=w_shape, dtype="float32")
-    out_channel = w_shape[0]
-    conv2d = relay.nn.conv2d(
-        data=data,
-        weight=weight,
-        kernel_size=w_shape[2:],
-        channels=out_channel,
-        padding=padding,
-        strides=strides,
-        out_dtype="float32",
-    )
+    def get_subgraph(dtype):
+        data = relay.var("data", shape=d_shape, dtype=dtype)
+        weight = relay.var("weight", shape=w_shape, dtype=dtype)
+        out_channel = w_shape[0]
+        conv2d = relay.nn.conv2d(
+            data=data,
+            weight=weight,
+            kernel_size=w_shape[2:],
+            channels=out_channel,
+            padding=padding,
+            strides=strides,
+            out_dtype=dtype,
+        )
+        return conv2d
 
-    mod = tvm.IRModule.from_expr(conv2d)
+    for t in ["float32", "bfloat16"]:
+        mod = tvm.IRModule.from_expr(get_subgraph(t))
 
-    data_np = np.random.uniform(1, 10, d_shape).astype("float32")
-    weight_np = np.random.uniform(1, 10, size=w_shape).astype("float32")
+        data_np = np.random.uniform(1, 10, d_shape).astype("float32")
+        weight_np = np.random.uniform(1, 10, size=w_shape).astype("float32")
+        ref = tvm.topi.testing.conv2d_nchw_python(data_np, weight_np, strides, padding)
 
-    target = "llvm -mcpu=skylake-avx512 -libs=dnnl"
-    with tvm.transform.PassContext(opt_level=3):
-        lib = relay.build(mod, target=target, params={"weight": weight_np})
+        if t == "bfloat16":
+            data_np = np_float2tvm_bf16(data_np)
+            weight_np = np_float2tvm_bf16(weight_np)
 
-    dev = tvm.device(target, 0)
-    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+        target = "llvm -mcpu=skylake-avx512 -libs=dnnl"
+        with tvm.transform.PassContext(opt_level=3):
+            lib = relay.build(mod, target=target, params={"weight": weight_np})
 
-    runtime.set_input("data", data_np)
-    runtime.run()
+        dev = tvm.device(target, 0)
+        runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
 
-    out = runtime.get_output(0).numpy()
+        runtime.set_input("data", data_np)
+        runtime.run()
 
-    ref = tvm.topi.testing.conv2d_nchw_python(data_np, weight_np, strides, padding)
+        out = runtime.get_output(0).numpy()
 
-    np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
+        if t == "bfloat16":
+            out = np_bf162np_float(out)
+            np.testing.assert_allclose(out, ref, rtol=1e-2)
+        else:
+            np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
 
 
 @tvm.testing.requires_x86
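[Reviewer note] The looser bf16 tolerance is expected: ref is computed from the unrounded float32 inputs, while the DNNL run consumes data and weights already rounded to bf16, so rounding alone puts the comparison far outside the f32 tolerance of 1e-5. A back-of-the-envelope bound on the input rounding (illustrative arithmetic only):

# Each operand rounds with relative error at most 2**-8 (8 effective mantissa
# bits), and both data and weights are rounded, so each product term carries
# roughly 2 * 2**-8 of relative error before accumulation.
per_term = 2 * 2.0 ** -8
print(per_term)  # 0.0078125 -- the same order as the rtol=1e-2 used above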
@@ -2064,41 +2090,51 @@ def test_conv2d_nhwc_dnnl():
     padding = (1, 1)
     strides = (1, 1)
 
-    data = relay.var("data", shape=d_shape, dtype="float32")
-    weight = relay.var("weight", shape=w_shape, dtype="float32")
-    out_channel = w_shape[3]
-    conv2d = relay.nn.conv2d(
-        data=data,
-        weight=weight,
-        kernel_size=w_shape[:2],
-        channels=out_channel,
-        padding=padding,
-        strides=strides,
-        out_dtype="float32",
-        data_layout="NHWC",
-        kernel_layout="HWIO",
-    )
+    def get_subgraph(dtype):
+        data = relay.var("data", shape=d_shape, dtype=dtype)
+        weight = relay.var("weight", shape=w_shape, dtype=dtype)
+        out_channel = w_shape[3]
+        conv2d = relay.nn.conv2d(
+            data=data,
+            weight=weight,
+            kernel_size=w_shape[:2],
+            channels=out_channel,
+            padding=padding,
+            strides=strides,
+            out_dtype=dtype,
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        return conv2d
 
-    mod = tvm.IRModule.from_expr(conv2d)
+    for t in ["float32", "bfloat16"]:
+        mod = tvm.IRModule.from_expr(get_subgraph(t))
 
-    data_np = np.random.uniform(1, 10, d_shape).astype("float32")
-    weight_np = np.random.uniform(1, 10, size=w_shape).astype("float32")
+        data_np = np.random.uniform(1, 10, d_shape).astype("float32")
+        weight_np = np.random.uniform(1, 10, size=w_shape).astype("float32")
+        ref = tvm.topi.testing.conv2d_nhwc_python(data_np, weight_np, strides, padding)
 
-    target = "llvm -mcpu=skylake-avx512 -libs=dnnl"
-    with tvm.transform.PassContext(opt_level=3):
-        lib = relay.build(mod, target=target, params={"weight": weight_np})
+        if t == "bfloat16":
+            data_np = np_float2tvm_bf16(data_np)
+            weight_np = np_float2tvm_bf16(weight_np)
 
-    dev = tvm.device(target, 0)
-    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+        target = "llvm -mcpu=skylake-avx512 -libs=dnnl"
+        with tvm.transform.PassContext(opt_level=3):
+            lib = relay.build(mod, target=target, params={"weight": weight_np})
 
-    runtime.set_input("data", data_np)
-    runtime.run()
+        dev = tvm.device(target, 0)
+        runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
 
-    out = runtime.get_output(0).numpy()
+        runtime.set_input("data", data_np)
+        runtime.run()
 
-    ref = tvm.topi.testing.conv2d_nhwc_python(data_np, weight_np, strides, padding)
+        out = runtime.get_output(0).numpy()
 
-    np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
+        if t == "bfloat16":
+            out = np_bf162np_float(out)
+            np.testing.assert_allclose(out, ref, rtol=1e-2)
+        else:
+            np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
 
 
 if __name__ == "__main__":