[TOPI][x86] Legalize - Support int8xint8 convolution to use VNNI inst #4196

Merged · 1 commit · Oct 25, 2019
65 changes: 37 additions & 28 deletions tests/python/relay/test_op_level2.py
@@ -546,9 +546,11 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):

n, h, w, ch, cw = 1, 64, 64, 3, 3
if data_layout == 'NCHW':
x = relay.var("x", relay.TensorType((n, ic, h, w), input_dtype))
data_shape = (n, ic, h, w)
x = relay.var("x", relay.TensorType(data_shape, input_dtype))
elif data_layout == 'NHWC':
x = relay.var("x", relay.TensorType((n, h, w, ic), input_dtype))
data_shape = (n, h, w, ic)
x = relay.var("x", relay.TensorType(data_shape, input_dtype))
else:
raise ValueError('Not supported')

@@ -559,20 +561,22 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):
else:
raise ValueError('Not supported')

w = relay.var("w", relay.TensorType(kernel_shape, weight_dtype))
y = relay.nn.conv2d(x, w,
weight = relay.var("weight", relay.TensorType(kernel_shape, weight_dtype))
y = relay.nn.conv2d(x, weight,
kernel_size=(ch, cw),
channels=oc,
padding=(1, 1),
dilation=(1, 1),
data_layout=data_layout,
kernel_layout=kernel_layout,
out_dtype=output_dtype)
func = relay.Function([x, w], y)
func = relay.Function([x, weight], y)
wdata = np.random.rand(*kernel_shape) * 10
parameters = {"w": tvm.nd.array(wdata.astype(weight_dtype))}
parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))}

with relay.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=parameters)

assembly = lib.get_source("asm")
return assembly

@@ -589,58 +593,63 @@ def _has_fast_int8_instructions(asm, target):
llvm_version = tvm.codegen.llvm_version_major()
for target in targets:
if llvm_version >= 8:
fast_int8_dtypes = ('uint8', 'int8', 'int32')
dtypes = ('uint8', 'int8', 'int32')
# Sweep the input channels to check int8 robustness
# Input channels should be a multiple of 4 internally.
for ic in [1, 4, 6]:
asm = _compile(ic=ic, oc=32, target=target, data_layout="NCHW",
asm = _compile(ic=ic, oc=16, target=target, data_layout="NCHW",
kernel_layout='OIHW',
dtypes=fast_int8_dtypes)
dtypes=dtypes)
assert _has_fast_int8_instructions(asm, target)

for ic in [1, 4, 6]:
asm = _compile(ic=ic, oc=32, target=target, data_layout="NHWC",
asm = _compile(ic=ic, oc=16, target=target, data_layout="NHWC",
kernel_layout='HWIO',
dtypes=fast_int8_dtypes)
dtypes=dtypes)
assert _has_fast_int8_instructions(asm, target)


# Sweep the output channels to check int8 robustness
# Output channels should be a multiple of 16 internally.
for oc in [4, 16, 20]:
asm = _compile(ic=16, oc=oc, target=target, data_layout="NCHW",
asm = _compile(ic=8, oc=oc, target=target, data_layout="NCHW",
kernel_layout='OIHW',
dtypes=fast_int8_dtypes)
dtypes=dtypes)
assert _has_fast_int8_instructions(asm, target)

for oc in [4, 16, 20]:
asm = _compile(ic=16, oc=oc, target=target, data_layout="NHWC",
asm = _compile(ic=8, oc=oc, target=target, data_layout="NHWC",
kernel_layout='HWIO',
dtypes=fast_int8_dtypes)
dtypes=dtypes)
assert _has_fast_int8_instructions(asm, target)

# Check that both non-divisible oc and ic work
asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW',
dtypes=fast_int8_dtypes)
dtypes=dtypes)
assert _has_fast_int8_instructions(asm, target)

asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
dtypes=fast_int8_dtypes)
dtypes=dtypes)
assert _has_fast_int8_instructions(asm, target)

# Ensure that code is generated when datatypes are not HW supported.
dtypes = ('int8', 'int8', 'int32')
asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
# Check that int8 x int8 goes through legalization so that fast instructions can be picked up.
for target in targets:
if llvm_version >= 8:
dtypes = (('int8', 'int8', 'int32'))
# Check that both non-divisible oc and ic work
asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW',
dtypes=dtypes)
# Check that the intrinsic is not present in the assembly.
assert not _has_fast_int8_instructions(asm, target)
assert _has_fast_int8_instructions(asm, target)

# Ensure that code is generated when datatypes are not HW supported.
dtypes = ('uint8', 'uint8', 'int32')
asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
dtypes=dtypes)
# Check that the intrinsic is not present in the assembly.
assert not _has_fast_int8_instructions(asm, target)
assert _has_fast_int8_instructions(asm, target)

# Ensure that code is generated when datatypes are not HW supported.
dtypes = ('uint8', 'uint8', 'int32')
asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
dtypes=dtypes)
# Check that the intrinsic is not present in the assembly.
assert not _has_fast_int8_instructions(asm, target)

# Check that a vectorized instruction is generated for older Intel
# generations, because we default to NCHWc layout.
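The test decides whether the fast int8 path was taken by scanning the generated assembly (the `_has_fast_int8_instructions` helper used above). A minimal sketch of that kind of check is shown below; the instruction names are assumptions about the targets involved (VNNI-capable CPUs such as Cascade Lake emit `vpdpbusd`, while Skylake-AVX512 falls back to a `vpmaddubsw`-based sequence) rather than a copy of the helper in this file.

```python
def _check_fast_int8(asm, target):
    """Rough stand-in for the test's _has_fast_int8_instructions helper."""
    if 'cascadelake' in target:
        # AVX-512 VNNI 8-bit dot product instruction
        return 'vpdpbusd' in asm
    # Pre-VNNI AVX-512 targets use a vpmaddubsw/vpmaddwd sequence
    return 'pmaddubs' in asm
```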
72 changes: 60 additions & 12 deletions topi/python/topi/x86/conv2d_alter_op.py
@@ -192,24 +192,72 @@ def _conv2d_legalize(attrs, inputs, arg_types):
The legalized expr
"""

# Dilation not supported yet. Return None if dilation is not (1, 1)
dilation = attrs.get_int_tuple("dilation")
if not (dilation[0] == 1 and dilation[1] == 1):
return None

# Collect the input tensors.
data_tensor, kernel_tensor = arg_types[0], arg_types[1]
data_dtype = data_tensor.dtype
kernel_dtype = kernel_tensor.dtype

# Collect the output tensor.
output_tensor = arg_types[2]

# Collect the input exprs.
data, kernel = inputs

# Get the conv attrs
new_attrs = {k: attrs[k] for k in attrs.keys()}

is_int8_inputs = False
# If both the inputs are int8, we can add 128 to make the input dtype uint8, and then adjust the
# output. This will help pick up Intel VNNI instructions.
# Original --> C = A (conv) B
# A and B are int8
# C = (A + 128 - 128) (conv) B
# C = (A' conv B) - 128 (conv) B
# where A' = A + 128
# and 128 (conv) B is basically a reduce on CRS axis for weights.
if data_tensor.dtype == 'int8' and kernel_tensor.dtype == 'int8':
is_int8_inputs = True
padding = attrs.get_int_tuple("padding")

if attrs['data_layout'] == 'NHWC' and attrs['kernel_layout'] == 'HWIO':
adjust_shift = relay.sum(relay.cast(kernel, dtype='int32'), axis=(0, 1, 2))
pad_width = ((0, 0), (padding[0], padding[0]), (padding[1], padding[1]), (0, 0))
elif attrs['data_layout'] == 'NCHW' and attrs['kernel_layout'] == 'OIHW':
pad_width = ((0, 0), (0, 0), (padding[0], padding[0]), (padding[1], padding[1]))
adjust_shift = relay.sum(relay.cast(kernel, dtype='int32'), axis=(1, 2, 3))
adjust_shift = relay.expand_dims(adjust_shift, axis=1, num_newaxis=2)
else:
return None

data = relay.cast(data, 'int32')
data = relay.add(data, relay.const(128, 'int32'))
data = relay.cast(data, 'uint8')

# Do external padding as pad value has to be 128.
if not (padding[0] == 0 and padding[1] == 0):
data = relay.nn.pad(data, pad_width=pad_width, pad_value=128)
new_attrs['padding'] = (0, 0)

# The data type is now shifted to uint8
data_dtype = 'uint8'

# Multiply 128 to adjust shift.
adjust_shift = relay.multiply(adjust_shift, relay.const(128, 'int32'))

# Legalize if the datatypes are suitable for fast Int8 instructions. Int8 instructions require
# input channel to be a multiple of 4 and output channels to be a multiple of 16. For input
# channels, we pad both the inputs and weights input channels. For output channels, we pad the
# weight and stride_slice the output.
if _is_int8_hw_support(data_tensor.dtype, kernel_tensor.dtype):
if _is_int8_hw_support(data_dtype, kernel_dtype):
# Flags to remember if the expr is modified
ic_modified = False
oc_modified = False

# Collect the input exprs.
data, kernel = inputs

# Find the value of input and output channel.
in_channel = -1
out_channel = -1
@@ -250,16 +298,16 @@ def _conv2d_legalize(attrs, inputs, arg_types):
else:
return None

if not (ic_modified or oc_modified):
return None

if ic_modified and not oc_modified:
return relay.nn.conv2d(data, kernel, **attrs)

if oc_modified:
new_attrs = {k: attrs[k] for k in attrs.keys()}
new_attrs['channels'] = new_out_channel
out = tvm.relay.nn.conv2d(data, kernel, **new_attrs)
original_out_shape = [x.value for x in output_tensor.shape]
return relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape)
out = relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape)
else:
out = relay.nn.conv2d(data, kernel, **new_attrs)

if is_int8_inputs:
out = relay.subtract(out, adjust_shift)

return out
return None
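The arithmetic behind the int8-to-uint8 shift in `_conv2d_legalize` can be checked numerically. The sketch below is only an illustration and is not code from the PR: it uses a plain matrix product in place of a convolution, with the reduction axis standing in for the C*R*S axis of the weights.

```python
import numpy as np

# Identity used by the legalization:
#   A (conv) B == (A + 128) (conv) B - 128 * sum(B over the reduction axes)
rng = np.random.default_rng(0)
A = rng.integers(-128, 128, size=(4, 16)).astype(np.int32)  # int8-range activations
B = rng.integers(-128, 128, size=(16, 8)).astype(np.int32)  # int8-range weights

original = A @ B                      # int8 x int8 result
shifted = (A + 128) @ B               # uint8 x int8 result
adjust = 128 * B.sum(axis=0)          # the "128 (conv) B" reduction term

assert np.array_equal(original, shifted - adjust)
```

This is also why the legalized graph pads the data explicitly with a pad value of 128: after the shift, 128 represents the original zero, so the conv's implicit zero padding would no longer be correct.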
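The channel constraints mentioned in the comments (input channels a multiple of 4, output channels a multiple of 16) come down to rounding up and trimming afterwards. The helper below is hypothetical and only illustrates the rounding; the PR itself pads the tensors with relay ops and uses `strided_slice` to drop the extra output channels, as in the diff above.

```python
def _round_up(value, multiple):
    """Hypothetical helper: round value up to the next multiple."""
    remainder = value % multiple
    return value if remainder == 0 else value + (multiple - remainder)

# ic=17 is padded to 20 (multiple of 4); oc=29 is padded to 32 (multiple of 16),
# and the conv output is later sliced back to 29 channels.
assert _round_up(17, 4) == 20
assert _round_up(29, 16) == 32
```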