Skip to content

Commit

Permalink
[TOPI][x86] Legalize - Support int8xint8 convolution to use VNNI inst…
Browse files Browse the repository at this point in the history
…ructions.
  • Loading branch information
anijain2305 committed Oct 24, 2019
1 parent fdb01cb commit b84b26e
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 40 deletions.
65 changes: 37 additions & 28 deletions tests/python/relay/test_op_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,9 +546,11 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):

n, h, w, ch, cw = 1, 64, 64, 3, 3
if data_layout == 'NCHW':
x = relay.var("x", relay.TensorType((n, ic, h, w), input_dtype))
data_shape = (n, ic, h, w)
x = relay.var("x", relay.TensorType(data_shape, input_dtype))
elif data_layout == 'NHWC':
x = relay.var("x", relay.TensorType((n, h, w, ic), input_dtype))
data_shape = (n, h, w, ic)
x = relay.var("x", relay.TensorType(data_shape, input_dtype))
else:
raise ValueError('Not supported')

Expand All @@ -559,20 +561,22 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):
else:
raise ValueError('Not supported')

w = relay.var("w", relay.TensorType(kernel_shape, weight_dtype))
y = relay.nn.conv2d(x, w,
weight = relay.var("weight", relay.TensorType(kernel_shape, weight_dtype))
y = relay.nn.conv2d(x, weight,
kernel_size=(ch, cw),
channels=oc,
padding=(1, 1),
dilation=(1, 1),
data_layout=data_layout,
kernel_layout=kernel_layout,
out_dtype=output_dtype)
func = relay.Function([x, w], y)
func = relay.Function([x, weight], y)
wdata = np.random.rand(*kernel_shape) * 10
parameters = {"w": tvm.nd.array(wdata.astype(weight_dtype))}
parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))}

with relay.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=parameters)

assembly = lib.get_source("asm")
return assembly

Expand All @@ -589,58 +593,63 @@ def _has_fast_int8_instructions(asm, target):
llvm_version = tvm.codegen.llvm_version_major()
for target in targets:
if llvm_version >= 8:
fast_int8_dtypes = ('uint8', 'int8', 'int32')
dtypes = ('uint8', 'int8', 'int32')
# Sweep the input channels to check int8 robustness
# Input channels should be a multiple of 4 internally.
for ic in [1, 4, 6]:
asm = _compile(ic=ic, oc=32, target=target, data_layout="NCHW",
asm = _compile(ic=ic, oc=16, target=target, data_layout="NCHW",
kernel_layout='OIHW',
dtypes=fast_int8_dtypes)
dtypes=dtypes)
assert _has_fast_int8_instructions(asm, target)

for ic in [1, 4, 6]:
asm = _compile(ic=ic, oc=32, target=target, data_layout="NHWC",
asm = _compile(ic=ic, oc=16, target=target, data_layout="NHWC",
kernel_layout='HWIO',
dtypes=fast_int8_dtypes)
dtypes=dtypes)
assert _has_fast_int8_instructions(asm, target)


# Sweep the output channels to check int8 robustness
# Output channels should be a multiple of 16 internally.
for oc in [4, 16, 20]:
asm = _compile(ic=16, oc=oc, target=target, data_layout="NCHW",
asm = _compile(ic=8, oc=oc, target=target, data_layout="NCHW",
kernel_layout='OIHW',
dtypes=fast_int8_dtypes)
dtypes=dtypes)
assert _has_fast_int8_instructions(asm, target)

for oc in [4, 16, 20]:
asm = _compile(ic=16, oc=oc, target=target, data_layout="NHWC",
asm = _compile(ic=8, oc=oc, target=target, data_layout="NHWC",
kernel_layout='HWIO',
dtypes=fast_int8_dtypes)
dtypes=dtypes)
assert _has_fast_int8_instructions(asm, target)

# Check that both non-divisible oc and ic work
asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW',
dtypes=fast_int8_dtypes)
dtypes=dtypes)
assert _has_fast_int8_instructions(asm, target)

asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
dtypes=fast_int8_dtypes)
dtypes=dtypes)
assert _has_fast_int8_instructions(asm, target)

# Ensure that code is generated when datatypes are not HW supported.
dtypes = ('int8', 'int8', 'int32')
asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
# Check that int8 x int8 goes through legalization so that fast instructions can be picked up.
for target in targets:
if llvm_version >= 8:
dtypes = (('int8', 'int8', 'int32'))
# Check that both non-divisible oc and ic work
asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW',
dtypes=dtypes)
# Check that intrinisic is not present in the assembly.
assert not _has_fast_int8_instructions(asm, target)
assert _has_fast_int8_instructions(asm, target)

# Ensure that code is generated when datatypes are not HW supported.
dtypes = ('uint8', 'uint8', 'int32')
asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
dtypes=dtypes)
# Check that intrinisic is not present in the assembly.
assert not _has_fast_int8_instructions(asm, target)
assert _has_fast_int8_instructions(asm, target)

# Ensure that code is generated when datatypes are not HW supported.
dtypes = ('uint8', 'uint8', 'int32')
asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
dtypes=dtypes)
# Check that intrinisic is not present in the assembly.
assert not _has_fast_int8_instructions(asm, target)

# Check that a vectorized instruction is generated for older Intel
# generations, because we default to NCHWc layout.
Expand Down
72 changes: 60 additions & 12 deletions topi/python/topi/x86/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,24 +192,72 @@ def _conv2d_legalize(attrs, inputs, arg_types):
The legalized expr
"""

# Dilation not supported yet. Return None if dilation is not (1, 1)
dilation = attrs.get_int_tuple("dilation")
if not (dilation[0] == 1 and dilation[1] == 1):
return None

# Collect the input tensors.
data_tensor, kernel_tensor = arg_types[0], arg_types[1]
data_dtype = data_tensor.dtype
kernel_dtype = kernel_tensor.dtype

# Collect the output tensor.
output_tensor = arg_types[2]

# Collect the input exprs.
data, kernel = inputs

# Get the conv attrs
new_attrs = {k: attrs[k] for k in attrs.keys()}

is_int8_inputs = False
# If both the inputs are int8, we can add 128 to make the input dtype uint8, and then adjust the
# output. This will help picking up Intel VNNI instructions.
# Original --> C = A (conv) B
# A and B are int8
# C = (A + 128 - 128) (conv) B
# C = (A' conv B) - 128 (conv) B
# where A' = A + 128
# and 128 (conv) B is basically a reduce on CRS axis for weights.
if data_tensor.dtype == 'int8' and kernel_tensor.dtype == 'int8':
is_int8_inputs = True
padding = attrs.get_int_tuple("padding")

if attrs['data_layout'] == 'NHWC' and attrs['kernel_layout'] == 'HWIO':
adjust_shift = relay.sum(relay.cast(kernel, dtype='int32'), axis=(0, 1, 2))
pad_width = ((0, 0), (padding[0], padding[0]), (padding[1], padding[1]), (0, 0))
elif attrs['data_layout'] == 'NCHW' and attrs['kernel_layout'] == 'OIHW':
pad_width = ((0, 0), (0, 0), (padding[0], padding[0]), (padding[1], padding[1]))
adjust_shift = relay.sum(relay.cast(kernel, dtype='int32'), axis=(1, 2, 3))
adjust_shift = relay.expand_dims(adjust_shift, axis=1, num_newaxis=2)
else:
return None

data = relay.cast(data, 'int32')
data = relay.add(data, relay.const(128, 'int32'))
data = relay.cast(data, 'uint8')

# Do external padding as pad value has to be 128.
if not (padding[0] == 0 and padding[1] == 0):
data = relay.nn.pad(data, pad_width=pad_width, pad_value=128)
new_attrs['padding'] = (0, 0)

# The data type is now shifted to uint8
data_dtype = 'uint8'

# Multiply 128 to adjust shift.
adjust_shift = relay.multiply(adjust_shift, relay.const(128, 'int32'))

# Legalize if the datatypes are suitable for fast Int8 instructions. Int8 instructions require
# input channel to be a multiple of 4 and output channels to be a multiple of 16. For input
# channels, we pad both the inputs and weights input channels. For output channels, we pad the
# weight and stride_slice the output.
if _is_int8_hw_support(data_tensor.dtype, kernel_tensor.dtype):
if _is_int8_hw_support(data_dtype, kernel_dtype):
# Flags to remember if the expr is modified
ic_modified = False
oc_modified = False

# Collect the input exprs.
data, kernel = inputs

# Find the value of input and output channel.
in_channel = -1
out_channel = -1
Expand Down Expand Up @@ -250,16 +298,16 @@ def _conv2d_legalize(attrs, inputs, arg_types):
else:
return None

if not (ic_modified or oc_modified):
return None

if ic_modified and not oc_modified:
return relay.nn.conv2d(data, kernel, **attrs)

if oc_modified:
new_attrs = {k: attrs[k] for k in attrs.keys()}
new_attrs['channels'] = new_out_channel
out = tvm.relay.nn.conv2d(data, kernel, **new_attrs)
original_out_shape = [x.value for x in output_tensor.shape]
return relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape)
out = relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape)
else:
out = relay.nn.conv2d(data, kernel, **new_attrs)

if is_int8_inputs:
out = relay.subtract(out, adjust_shift)

return out
return None

0 comments on commit b84b26e

Please sign in to comment.