Skip to content

Commit b3b0aa9

Browse files
committed
[vulkan] Add integer dot product (4xint8, 4xuint8) tensorization for the vulkan SPIR-V target. Currently only autotvm path is supported.
Prerequisites for compilation: (1) Use VulkanSDK 1.2.198 release with SPIR-V integer dot product support (2) set(USE_SPIRV_KHR_INTEGER_DOT_PRODUCT ON) in config.cmake and build (3) Use a driver that supports the VK_KHR_shader_integer_dot_product extension. The compiled binary can only be run on hardware that supports the relevant ISA. This work is tested on AMD RDNA2 families (e.g., Rembrandt and RX6800). To compile on a device that supports this extension, use target: vulkan -from_device=0. To compile on a device that supports int8 but does not support this extension, add "-supports_integer_dot_product=1" or "-mattr=+dotprod" to the target string. To support pre-released Vulkan and SPIR-V extensions, we need the SPIR-V tool and header file from the Khronos github; use the option USE_KHRONOS_SPIRV in config.cmake.
1 parent 16e9491 commit b3b0aa9

File tree

24 files changed

+403
-160
lines changed

24 files changed

+403
-160
lines changed

CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@ endif()
2626
tvm_option(USE_CUDA "Build with CUDA" OFF)
2727
tvm_option(USE_OPENCL "Build with OpenCL" OFF)
2828
tvm_option(USE_VULKAN "Build with Vulkan" OFF)
29+
30+
# Whether to use spirv-tools.and SPIRV-Headers from Khronos github or gitlab.
31+
#
32+
# Possible values:
33+
# - OFF: not to use
34+
# - /path/to/install: path to your khronis spirv-tools and SPIRV-Headers installation directory
35+
#
36+
tvm_option(USE_KHRONOS_SPIRV "Whether to use spirv-tools.and SPIRV-Headers from Khronos github or gitlab" OFF)
37+
tvm_option(USE_SPIRV_KHR_INTEGER_DOT_PRODUCT "whether enable SPIRV_KHR_DOT_PRODUCT" OFF)
2938
tvm_option(USE_METAL "Build with Metal" OFF)
3039
tvm_option(USE_ROCM "Build with ROCM" OFF)
3140
tvm_option(ROCM_PATH "The path to rocm" /opt/rocm)

cmake/config.cmake

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,18 @@ set(USE_METAL OFF)
8181
# - /path/to/vulkan-sdk: use specific path to vulkan-sdk
8282
set(USE_VULKAN OFF)
8383

84+
85+
# Whether to use spirv-tools.and SPIRV-Headers from Khronos github or gitlab.
86+
#
87+
# Possible values:
88+
# - OFF: not to use
89+
# - /path/to/install: path to your khronis spirv-tools and SPIRV-Headers installation directory
90+
#
91+
set(USE_KHRONOS_SPIRV OFF)
92+
93+
# whether enable SPIRV_KHR_DOT_PRODUCT
94+
set(USE_SPIRV_KHR_INTEGER_DOT_PRODUCT OFF)
95+
8496
# Whether enable OpenGL runtime
8597
set(USE_OPENGL OFF)
8698

cmake/modules/Vulkan.cmake

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,16 @@
1616
# under the License.
1717

1818
# Be compatible with older version of CMake
19-
find_vulkan(${USE_VULKAN})
19+
find_vulkan(${USE_VULKAN} ${USE_KHRONOS_SPIRV})
2020

2121
if(USE_VULKAN)
2222
if(NOT Vulkan_FOUND)
2323
message(FATAL_ERROR "Cannot find Vulkan, USE_VULKAN=" ${USE_VULKAN})
2424
endif()
25+
if (USE_SPIRV_KHR_INTEGER_DOT_PRODUCT)
26+
add_definitions(-DTVM_SPIRV_KHR_INTEGER_DOT_PRODUCT=1)
27+
message(STATUS "Enable SPIRV_KHR_INTEGER_DOT_PRODUCT")
28+
endif()
2529
include_directories(SYSTEM ${Vulkan_INCLUDE_DIRS})
2630
message(STATUS "Build with Vulkan support")
2731
tvm_file_glob(GLOB RUNTIME_VULKAN_SRCS src/runtime/vulkan/*.cc)

cmake/utils/FindVulkan.cmake

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
# - Vulkan_SPIRV_TOOLS_LIBRARY
3333
#
3434

35-
macro(find_vulkan use_vulkan)
35+
macro(find_vulkan use_vulkan use_khronos_spirv)
3636
set(__use_vulkan ${use_vulkan})
3737
if(IS_DIRECTORY ${__use_vulkan})
3838
set(__vulkan_sdk ${__use_vulkan})
@@ -43,6 +43,15 @@ macro(find_vulkan use_vulkan)
4343
set(__vulkan_sdk "")
4444
endif()
4545

46+
47+
if(IS_DIRECTORY ${use_khronos_spirv})
48+
set(__use_khronos_spirv ${use_khronos_spirv})
49+
message(STATUS "Custom khronos spirv PATH=" ${__use_khronos_spirv})
50+
else()
51+
set(__use_khronos_spirv "")
52+
endif()
53+
54+
4655
if(__vulkan_sdk)
4756
set(Vulkan_INCLUDE_DIRS ${__vulkan_sdk}/include)
4857
find_library(Vulkan_LIBRARY NAMES vulkan vulkan-1 PATHS ${__vulkan_sdk}/lib)
@@ -61,11 +70,18 @@ macro(find_vulkan use_vulkan)
6170

6271
if(Vulkan_FOUND)
6372
get_filename_component(VULKAN_LIBRARY_PATH ${Vulkan_LIBRARY} DIRECTORY)
64-
find_library(Vulkan_SPIRV_TOOLS_LIBRARY SPIRV-Tools
65-
HINTS ${VULKAN_LIBRARY_PATH} ${VULKAN_LIBRARY_PATH}/spirv-tools ${VULKAN_SDK}/lib)
73+
if (WIN32)
74+
find_library(Vulkan_SPIRV_TOOLS_LIBRARY SPIRV-Tools
75+
HINTS ${__use_khronos_spirv}/spirv-tools/lib ${VULKAN_LIBRARY_PATH} ${VULKAN_LIBRARY_PATH}/spirv-tools ${VULKAN_SDK}/lib)
76+
find_path(_libspirv libspirv.h HINTS ${__use_khronos_spirv}/spirv-tools/include ${Vulkan_INCLUDE_DIRS} PATH_SUFFIXES vulkan spirv-tools)
77+
find_path(_spirv spirv.hpp HINTS ${__use_khronos_spirv}/SPIRV-Headers/include ${Vulkan_INCLUDE_DIRS} PATH_SUFFIXES vulkan SPIRV spirv/unified1 spirv-headers)
78+
else()
79+
find_library(Vulkan_SPIRV_TOOLS_LIBRARY SPIRV-Tools
80+
HINTS ${__use_khronos_spirv}/lib ${VULKAN_LIBRARY_PATH} ${VULKAN_LIBRARY_PATH}/spirv-tools ${VULKAN_SDK}/lib)
81+
find_path(_libspirv libspirv.h HINTS ${__use_khronos_spirv}/include ${Vulkan_INCLUDE_DIRS} PATH_SUFFIXES vulkan spirv-tools)
82+
find_path(_spirv spirv.hpp HINTS ${__use_khronos_spirv}/include ${Vulkan_INCLUDE_DIRS} PATH_SUFFIXES vulkan SPIRV spirv/unified1 spirv-headers)
83+
endif()
6684

67-
find_path(_libspirv libspirv.h HINTS ${Vulkan_INCLUDE_DIRS} PATH_SUFFIXES vulkan spirv-tools)
68-
find_path(_spirv spirv.hpp HINTS ${Vulkan_INCLUDE_DIRS} PATH_SUFFIXES vulkan SPIRV spirv/unified1 spirv-headers)
6985
find_path(_glsl_std GLSL.std.450.h HINTS ${Vulkan_INCLUDE_DIRS} PATH_SUFFIXES vulkan SPIRV spirv/unified1 spirv-headers)
7086
list(APPEND Vulkan_INCLUDE_DIRS ${_libspirv} ${_spirv} ${_glsl_std})
7187
message(STATUS "Vulkan_INCLUDE_DIRS=" ${Vulkan_INCLUDE_DIRS})

python/tvm/relay/op/strategy/cuda.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
145145
if layout == "NCHW":
146146
assert kernel_layout == "OIHW"
147147
if (
148-
target.kind.name == "cuda"
148+
(target.kind.name in ["cuda", "vulkan"])
149149
and data.dtype in ("int8", "uint8")
150150
and kernel.dtype in ("int8", "uint8")
151151
):
@@ -296,7 +296,11 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
296296
"Unsupported shape for conv2d HWNC.\
297297
Need to satisfy tensor core schedule."
298298
)
299-
elif target.kind.name == "cuda" and layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
299+
elif (
300+
(target.kind.name in ["cuda", "vulkan"])
301+
and layout == "NCHW4c"
302+
and data.dtype in ["int8", "uint8"]
303+
):
300304
assert kernel_layout == "OIHW4o4i"
301305
strategy.add_implementation(
302306
wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
@@ -372,7 +376,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
372376
ic_chunk = in_channels // 4
373377

374378
if (
375-
target.kind.name == "cuda"
379+
(target.kind.name in ["cuda", "vulkan"])
376380
and data.dtype in ["int8", "uint8"]
377381
and kernel.dtype in ["int8", "uint8"]
378382
and channels % groups == 0

python/tvm/relay/qnn/op/legalizations.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,12 @@ def is_aarch64_arm():
387387
return "aarch64" in target.attrs.get("mtriple", "")
388388

389389

390+
def is_vulkan():
391+
"""Checks whether we are compiling for a vulkan/spirv target."""
392+
target = tvm.target.Target.current(allow_none=False)
393+
return "vulkan" in target.keys
394+
395+
390396
########################
391397
# ARM CPU legalizations.
392398
########################
@@ -438,17 +444,23 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
438444

439445

440446
#####################
441-
# CUDA legalizations.
447+
# CUDA and vulkan legalizations.
442448
#####################
443449

444450

445-
@qnn_conv2d_legalize.register("cuda")
451+
@qnn_conv2d_legalize.register(["cuda", "gpu"])
446452
def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
453+
if is_vulkan():
454+
# prefers the dtypes to be same. Mixed type is not yet supported.
455+
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
447456
# CUDA prefers both datatypes to be int8.
448457
return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.conv2d)
449458

450459

451-
@qnn_dense_legalize.register("cuda")
460+
@qnn_dense_legalize.register(["cuda", "gpu"])
452461
def _qnn_dense_legalize_cuda(attrs, inputs, types):
462+
if is_vulkan():
463+
# prefers the dtypes to be same. Mixed type is not yet supported.
464+
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
453465
# CUDA prefers both datatypes to be the int8.
454466
return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.dense)

python/tvm/target/target.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,12 @@ def mattr(self):
188188
"""Returns the mattr from the target if it exists."""
189189
return list(self.attrs.get("mattr", []))
190190

191+
@property
192+
def supports_integer_dot_product(self):
193+
if self.attrs.get("supports_integer_dot_product", []):
194+
return bool(self.attrs["supports_integer_dot_product"])
195+
return False
196+
191197
@property
192198
def libs(self):
193199
return list(self.attrs.get("libs", []))

python/tvm/topi/cuda/conv2d_alter_op.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,18 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
8383
cfg = dispatch_ctx.query(target, workload)
8484
if cfg.is_fallback: # if is fallback, clear query cache and return None
8585
autotvm.task.clear_fallback_cache(target, workload)
86-
return None
86+
do_new_layout = False
87+
if "vulkan" in target.keys:
88+
do_new_layout = "+dotprod" in target.mattr or target.supports_integer_dot_product
89+
if not do_new_layout:
90+
return None
8791

8892
topi_tmpl = workload[0]
8993
if topi_tmpl == "conv2d_NCHWc_int8.cuda":
9094
assert data_layout == "NCHW" and kernel_layout == "OIHW"
9195
N, CI, H, W = get_const_tuple(data.shape)
9296
CO, _, KH, KW = get_const_tuple(kernel.shape)
93-
97+
assert CO % 4 == 0, "Number of output channels should be multiple of 4"
9498
new_layout = "NCHW4c"
9599
new_attrs["channels"] = CO
96100
new_attrs["data_layout"] = new_layout
@@ -324,7 +328,7 @@ def _pad_conv2d_NHWC(db, di, do, data, kernel, out_channel, new_attrs, output_te
324328
return out
325329

326330

327-
@conv2d_legalize.register("cuda")
331+
@conv2d_legalize.register(["cuda", "gpu"])
328332
def _conv2d_legalize(attrs, inputs, arg_types):
329333
"""Legalizes Conv2D op.
330334

python/tvm/topi/cuda/conv2d_int8.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,15 @@ def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_
153153
kh = te.reduce_axis((0, kernel_h), name="kh")
154154
kw = te.reduce_axis((0, kernel_w), name="kw")
155155

156+
packed_kernel_dtype = packed_kernel.dtype
157+
packed_dtype = "int32" if packed_kernel_dtype == "int8" else "uint32"
156158
conv = te.compute(
157159
oshape,
158160
lambda n, oc_chunk, oh, ow, oc_block: te.sum(
159161
pad_data[
160162
n, icc, oh * stride_h + kh * dilation_h, ow * stride_w + kw * dilation_w, icb
161-
].astype("int32")
162-
* packed_kernel[oc_chunk, icc, kh, kw, oc_block, icb].astype("int32"),
163+
].astype(packed_dtype)
164+
* packed_kernel[oc_chunk, icc, kh, kw, oc_block, icb].astype(packed_dtype),
163165
axis=[icc, kh, kw, icb],
164166
),
165167
)
@@ -188,9 +190,6 @@ def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_
188190
return output
189191

190192

191-
_dp4a = dp4a("shared", "shared", "local")
192-
193-
194193
@autotvm.register_topi_schedule("conv2d_NCHWc_int8.cuda")
195194
def schedule_conv2d_NCHWc_int8(cfg, outs):
196195
"""Schedule conv2d int8 NCHWc template"""
@@ -311,7 +310,14 @@ def _schedule_conv2d_NCHWc_int8(cfg, s, output):
311310
cfg["reorder_inner"].apply(s, conv, [rci, ryi, rxi])
312311

313312
_, rc_block = s[conv].split(rc_block, factor=4)
314-
s[conv].tensorize(rc_block, _dp4a)
313+
target = tvm.target.Target.current(allow_none=False)
314+
do_tensorize = True
315+
if "vulkan" in target.keys:
316+
do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product
317+
318+
if do_tensorize:
319+
dtypes = (pad_data.dtype, packed_kernel.dtype)
320+
s[conv].tensorize(rc_block, dp4a("shared", "shared", "local", dtypes))
315321

316322
cache_loc = [rco, ryo, rxo][cfg["reorder_inner"].perm[-1]]
317323
s[AA].compute_at(s[conv], cache_loc)

python/tvm/topi/cuda/dense.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# pylint: disable=invalid-name, unused-argument
1818
"""Schedule for dense operator"""
1919
import logging
20+
import tvm
2021
from tvm import te
2122
import tvm.autotvm as autotvm
2223
from tvm.contrib import cublas
@@ -133,9 +134,6 @@ def _callback(op):
133134
return s
134135

135136

136-
_dp4a = dp4a("shared", "shared", "local")
137-
138-
139137
def _schedule_dense_int8(cfg, s, output):
140138
data, weight = s[output].op.input_tensors
141139
if len(weight.op.input_tensors) == 1 and weight.op.input_tensors[0] == data:
@@ -173,7 +171,14 @@ def _schedule_dense_int8(cfg, s, output):
173171
ko = CC.op.reduce_axis[0]
174172
ko, ki = s[CC].split(ko, factor=4)
175173
ko, kt = cfg["tile_k"].apply(s, CC, ko)
176-
s[CC].tensorize(ki, _dp4a)
174+
target = tvm.target.Target.current(allow_none=False)
175+
if (
176+
"vulkan" not in target.keys
177+
or "+dotprod" in target.mattr
178+
or target.supports_integer_dot_product
179+
):
180+
dtypes = (data.dtype, weight.dtype)
181+
s[CC].tensorize(ki, dp4a("shared", "shared", "local", dtypes))
177182
by, vy, ty, yi = cfg["tile_y"].apply(s, output, n)
178183
bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
179184

0 commit comments

Comments
 (0)