forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Infra to use tvm write op kernels (apache#15550)
* intra to use tvm write op kernels * add cmake support for tvm op * fix header lint * cleanup USE_TVM_OP logic in Makefile * add doc, cmake def, etc. * fix doc * test rand shape * add with_seed to test * improve err msg. add #if
- Loading branch information
Showing
26 changed files
with
874 additions
and
9 deletions.
There are no files selected for viewing
Submodule dlpack
updated
6 files
+63 −3 | CMakeLists.txt | |
+2 −0 | apps/from_numpy/Makefile | |
+107 −0 | apps/from_numpy/main.py | |
+52 −0 | apps/from_numpy/numpy_dlpack.c | |
+4 −0 | cmake/template/Config.cmake.in | |
+39 −7 | include/dlpack/dlpack.h |
Submodule dmlc-core
updated
6 files
+1 −1 | CMakeLists.txt | |
+5 −0 | include/dmlc/common.h | |
+2 −2 | include/dmlc/logging.h | |
+8 −3 | include/dmlc/threadediter.h | |
+75 −6 | test/unittest/unittest_threaditer_exc_handling.cc | |
+1 −1 | tracker/dmlc_tracker/kubernetes.py |
Submodule tvm
updated
from 21935d to afd4b3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
message(STATUS "Prepare external packages for TVM...")
# Run the preparation script at configure time and capture its result so that
# a failure is reported here instead of being silently ignored.
execute_process(
  COMMAND "${CMAKE_CURRENT_SOURCE_DIR}/contrib/tvmop/prepare_tvm.sh"
  RESULT_VARIABLE TVM_PREPARE_RESULT
)
# RESULT_VARIABLE holds the exit code, or an error string if the script could
# not be started; both cases compare unequal to 0 and trigger the warning.
if(NOT TVM_PREPARE_RESULT EQUAL 0)
  message(WARNING
    "contrib/tvmop/prepare_tvm.sh did not finish successfully "
    "(result: ${TVM_PREPARE_RESULT}); external packages for TVM may be missing.")
endif()
|
||
# Whether enable ROCM runtime | ||
# | ||
# Possible values: | ||
# - ON: enable ROCM with cmake's auto search | ||
# - OFF: disable ROCM | ||
# - /path/to/rocm: use specific path to rocm | ||
set(USE_ROCM OFF) | ||
|
||
# Whether enable SDAccel runtime | ||
set(USE_SDACCEL OFF) | ||
|
||
# Whether enable Intel FPGA SDK for OpenCL (AOCL) runtime | ||
set(USE_AOCL OFF) | ||
|
||
# Whether enable OpenCL runtime | ||
set(USE_OPENCL OFF) | ||
|
||
# Whether enable Metal runtime | ||
set(USE_METAL OFF) | ||
|
||
# Whether enable Vulkan runtime | ||
# | ||
# Possible values: | ||
# - ON: enable Vulkan with cmake's auto search | ||
# - OFF: disable vulkan | ||
# - /path/to/vulkan-sdk: use specific path to vulkan-sdk | ||
set(USE_VULKAN OFF) | ||
|
||
# Whether enable OpenGL runtime | ||
set(USE_OPENGL OFF) | ||
|
||
# Whether to enable SGX runtime | ||
# | ||
# Possible values for USE_SGX: | ||
# - /path/to/sgxsdk: path to Intel SGX SDK | ||
# - OFF: disable SGX | ||
# | ||
# SGX_MODE := HW|SIM | ||
set(USE_SGX OFF) | ||
set(SGX_MODE "SIM") | ||
set(RUST_SGX_SDK "/path/to/rust-sgx-sdk") | ||
|
||
# Whether enable RPC runtime | ||
set(USE_RPC ON) | ||
|
||
# Whether embed stackvm into the runtime | ||
set(USE_STACKVM_RUNTIME OFF) | ||
|
||
# Whether enable tiny embedded graph runtime. | ||
set(USE_GRAPH_RUNTIME ON) | ||
|
||
# Whether enable additional graph debug functions | ||
set(USE_GRAPH_RUNTIME_DEBUG OFF) | ||
|
||
# Whether build with LLVM support | ||
# Requires LLVM version >= 4.0 | ||
# | ||
# Possible values: | ||
# - ON: enable llvm with cmake's find search | ||
# - OFF: disable llvm | ||
# - /path/to/llvm-config: use a specific LLVM when multiple llvm-dev installations are available. | ||
set(USE_LLVM "${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/build/llvm/bin/llvm-config") | ||
|
||
#--------------------------------------------- | ||
# Contrib libraries | ||
#--------------------------------------------- | ||
# Which BLAS library to use, choices: none, openblas, mkl, atlas, apple | ||
set(USE_BLAS none) | ||
|
||
# /path/to/mkl: mkl root path when use mkl blas library | ||
# set(USE_MKL_PATH /opt/intel/mkl) for UNIX | ||
# set(USE_MKL_PATH ../IntelSWTools/compilers_and_libraries_2018/windows/mkl) for WIN32 | ||
set(USE_MKL_PATH none) | ||
|
||
# Whether use contrib.random in runtime | ||
set(USE_RANDOM OFF) | ||
|
||
# Whether use NNPack | ||
set(USE_NNPACK OFF) | ||
|
||
# Whether use CuDNN
# CuDNN stays enabled only when it was requested together with CUDA and
# detect_cuDNN() reports HAVE_CUDNN; in every other case it is switched off.
if(NOT (USE_CUDNN AND USE_CUDA))
  set(USE_CUDNN OFF)
else()
  detect_cuDNN()
  if(NOT HAVE_CUDNN)
    set(USE_CUDNN OFF)
  else()
    set(USE_CUDNN ON)
  endif()
endif()
|
||
# Whether use cuBLAS | ||
set(USE_CUBLAS OFF) | ||
|
||
# Whether use MIOpen | ||
set(USE_MIOPEN OFF) | ||
|
||
# Whether use MPS | ||
set(USE_MPS OFF) | ||
|
||
# Whether use rocBlas | ||
set(USE_ROCBLAS OFF) | ||
|
||
# Whether use contrib sort | ||
set(USE_SORT OFF) | ||
|
||
# Build ANTLR parser for Relay text format | ||
set(USE_ANTLR OFF) | ||
|
||
# Build TSIM for VTA | ||
set(USE_VTA_TSIM OFF) | ||
|
||
# Whether use Relay debug mode | ||
set(USE_RELAY_DEBUG OFF) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# coding: utf-8 | ||
from .opdef import defop | ||
from .utils import AllTypes, RealTypes | ||
|
||
from . import basic |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# coding: utf-8 | ||
from . import ufunc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# coding: utf-8 | ||
import tvm | ||
from .. import defop, AllTypes | ||
|
||
def compute_add(dtype, ndim):
    """Describe an element-wise addition C = A + B and create its schedule.

    Parameters
    ----------
    dtype : str
        Data type given to the placeholders A and B.
    ndim : int
        Number of dimensions of A, B and C.

    Returns
    -------
    tuple
        (s, A, B, C): the schedule created for C.op, the two input
        placeholders and the output tensor.
    """
    def symbolic_shape():
        # A fresh list of tvm.var() per tensor: A, B and C do not share
        # their shape variables.
        return [tvm.var() for _ in range(ndim)]

    A = tvm.placeholder(symbolic_shape(), name='A', dtype=dtype)
    B = tvm.placeholder(symbolic_shape(), name='B', dtype=dtype)
    C = tvm.compute(symbolic_shape(),
                    lambda *index: A[index] + B[index],
                    name='C')
    s = tvm.create_schedule(C.op)
    return s, A, B, C
|
||
@defop(name="vadd", target="cpu", auto_broadcast=True,
       dtype=AllTypes, ndim=list(range(1, 6)))
def vadd(dtype, ndim):
    """CPU schedule for element-wise addition.

    All axes of the output are fused into one axis, which is then marked
    parallel.

    Returns
    -------
    tuple
        (s, [A, B, C]): the schedule and the tensors it operates on.
    """
    s, A, B, C = compute_add(dtype, ndim)
    # Collapse every output axis into one and parallelize over it.
    fused = s[C].fuse(*C.op.axis)
    s[C].parallel(fused)

    return s, [A, B, C]
|
||
@defop(name="cuda_vadd", target="cuda", auto_broadcast=True,
       dtype=["float32", "float64"], ndim=list(range(1, 6)))
def vadd_gpu(dtype, ndim):
    """CUDA schedule for element-wise addition.

    All axes of the output are fused, the fused axis is split by a factor
    of 64, and the outer/inner parts are bound to blockIdx.x/threadIdx.x.

    Returns
    -------
    tuple
        (s, [A, B, C]): the schedule and the tensors it operates on.
    """
    # compute_add already returns a fresh schedule for C.op, so it is used
    # directly (as vadd does) instead of creating a second one.
    s, A, B, C = compute_add(dtype, ndim)
    axes = list(C.op.axis)
    fused = s[C].fuse(*axes)
    bx, tx = s[C].split(fused, factor=64)
    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
    return s, [A, B, C]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# coding: utf-8 | ||
"""TVM Operator compile entry point""" | ||
import tvm | ||
|
||
import os | ||
import argparse | ||
from tvmop.opdef import __OP_DEF__ | ||
|
||
def get_target(device):
    """Map an operator device name to the TVM target string.

    Parameters
    ----------
    device : str
        One of "cpu", "cuda" or "gpu".

    Returns
    -------
    str
        "llvm" for "cpu"; "cuda" for "cuda" and "gpu".

    Raises
    ------
    AssertionError
        If device is not one of the names listed above.
    """
    if device == "cpu":
        return "llvm"
    if device in ("cuda", "gpu"):
        return "cuda"
    # Raise explicitly: an `assert False` statement is removed when Python
    # runs with -O, which would let this function silently return None.
    raise AssertionError("Unknown device " + device)
|
||
|
||
if __name__ == "__main__":
    import sys
    sys.path.append(os.path.dirname(sys.path[0]))

    # Command line: -o <path> is the file the compiled library is exported to.
    parser = argparse.ArgumentParser(description="Generate tvm operators")
    parser.add_argument("-o", action="store", required=True, dest="target_path",
                        help="Target path which stores compiled library")
    cli_args = parser.parse_args()

    # Lowered functions, collected separately for the llvm and cuda targets.
    llvm_funcs = []
    cuda_funcs = []

    # TODO: attach instruction features to the library, e.g., avx-512, etc.
    for op in __OP_DEF__:
        for schedule, tensors in op.invoke_all():
            # Skip operators whose target is not enabled in this TVM build.
            if not tvm.module.enabled(get_target(op.target)):
                continue
            lowered = tvm.lower(schedule, tensors,
                                name=op.get_op_name(tensors),
                                binds=op.get_binds(tensors))
            # Anything that is not "cpu" goes to the cuda list.
            if op.target == "cpu":
                llvm_funcs.append(lowered)
            else:
                cuda_funcs.append(lowered)

    # The llvm entry is always present; the cuda entry only when non-empty.
    build_inputs = {get_target("cpu"): llvm_funcs}
    if cuda_funcs:
        build_inputs[get_target("cuda")] = cuda_funcs
    library = tvm.build(build_inputs, name="tvmop")
    library.export_library(cli_args.target_path)
Oops, something went wrong.