
Commit a61e237

[MetaSchedule][Runtime] Enhance Runner RandomFill
1 parent 1b8f3b5 commit a61e237

File tree

13 files changed: +357 -64 lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -89,6 +89,7 @@ tvm_option(USE_CUDNN "Build with cuDNN" OFF)
 tvm_option(USE_CUBLAS "Build with cuBLAS" OFF)
 tvm_option(USE_CUTLASS "Build with CUTLASS" OFF)
 tvm_option(USE_THRUST "Build with Thrust" OFF)
+tvm_option(USE_CURAND "Build with cuRAND" OFF)
 tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF)
 tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF)
 tvm_option(USE_SORT "Build with sort support" ON)

cmake/config.cmake

Lines changed: 3 additions & 0 deletions

@@ -296,6 +296,9 @@ set(USE_VTA_FPGA OFF)
 # Whether use Thrust
 set(USE_THRUST OFF)
 
+# Whether use cuRAND
+set(USE_CURAND OFF)
+
 # Whether to build the TensorFlow TVMDSOOp module
 set(USE_TF_TVMDSOOP OFF)

cmake/modules/CUDA.cmake

Lines changed: 12 additions & 0 deletions

@@ -69,6 +69,18 @@ if(USE_CUDA)
     list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC})
   endif(USE_THRUST)
 
+  if(USE_CURAND)
+    message(STATUS "Build with cuRAND support")
+    message(STATUS "${CUDA_CURAND_LIBRARY}")
+    cmake_minimum_required(VERSION 3.13) # to compile CUDA code
+    enable_language(CUDA)
+    tvm_file_glob(GLOB CONTRIB_CURAND_SRC_CC src/runtime/contrib/curand/*.cc)
+    tvm_file_glob(GLOB CONTRIB_CURAND_SRC_CU src/runtime/contrib/curand/*.cu)
+    list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CURAND_LIBRARY})
+    list(APPEND RUNTIME_SRCS ${CONTRIB_CURAND_SRC_CC})
+    list(APPEND RUNTIME_SRCS ${CONTRIB_CURAND_SRC_CU})
+  endif(USE_CURAND)
+
   if(USE_GRAPH_EXECUTOR_CUDA_GRAPH)
     if(NOT USE_GRAPH_EXECUTOR)
       message(FATAL_ERROR "CUDA Graph is only supported by graph executor, please set USE_GRAPH_EXECUTOR=ON")
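
Once TVM is rebuilt with USE_CUDA=ON and USE_CURAND=ON, the sources globbed above are compiled into the runtime and register a packed function named "runtime.contrib.curand.RandomFill" (see the new source file further down). A quick, hedged way to confirm the build picked it up from Python, assuming a standard TVM installation:

    import tvm

    # Returns None when TVM was built without USE_CURAND; a PackedFunc otherwise.
    f = tvm.get_global_func("runtime.contrib.curand.RandomFill", allow_missing=True)
    print("cuRAND RandomFill available:", f is not None)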

cmake/modules/LibInfo.cmake

Lines changed: 1 addition & 0 deletions

@@ -111,6 +111,7 @@ function(add_lib_info src_file)
     TVM_INFO_USE_TFLITE="${USE_TFLITE}"
     TVM_INFO_USE_THREADS="${USE_THREADS}"
     TVM_INFO_USE_THRUST="${USE_THRUST}"
+    TVM_INFO_USE_CURAND="${USE_CURAND}"
     TVM_INFO_USE_VITIS_AI="${USE_VITIS_AI}"
     TVM_INFO_USE_VULKAN="${USE_VULKAN}"
     TVM_INFO_USE_CLML="${USE_CLML}"
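
With TVM_INFO_USE_CURAND recorded by add_lib_info, the flag should also surface through tvm.support.libinfo(). A small sketch, assuming the key follows the existing USE_* naming convention of the returned dictionary:

    import tvm.support

    info = tvm.support.libinfo()
    # Expected to print "ON" or "OFF" depending on how config.cmake was set.
    print("USE_CURAND:", info.get("USE_CURAND", "not recorded"))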

cmake/utils/FindCUDA.cmake

Lines changed: 5 additions & 0 deletions

@@ -85,6 +85,10 @@ macro(find_cuda use_cuda use_cudnn)
                  PATHS ${CUDA_TOOLKIT_ROOT_DIR}
                  PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib targets/x86_64-linux/lib/stubs lib64/stubs lib/x86_64-linux-gnu
                  NO_DEFAULT_PATH)
+    find_library(CUDA_CURAND_LIBRARY curand
+                 ${CUDA_TOOLKIT_ROOT_DIR}/lib64
+                 ${CUDA_TOOLKIT_ROOT_DIR}/lib
+                 NO_DEFAULT_PATH)
     find_library(CUDA_CUBLAS_LIBRARY cublas
                  ${CUDA_TOOLKIT_ROOT_DIR}/lib64
                  ${CUDA_TOOLKIT_ROOT_DIR}/lib
@@ -134,6 +138,7 @@ macro(find_cuda use_cuda use_cudnn)
     message(STATUS "Found CUDA_CUDNN_INCLUDE_DIRS=" ${CUDA_CUDNN_INCLUDE_DIRS})
     message(STATUS "Found CUDA_CUDNN_LIBRARY=" ${CUDA_CUDNN_LIBRARY})
     message(STATUS "Found CUDA_CUBLAS_LIBRARY=" ${CUDA_CUBLAS_LIBRARY})
+    message(STATUS "Found CUDA_CURAND_LIBRARY=" ${CUDA_CURAND_LIBRARY})
     message(STATUS "Found CUDA_CUBLASLT_LIBRARY=" ${CUDA_CUBLASLT_LIBRARY})
   endif(CUDA_FOUND)
 endmacro(find_cuda)

python/tvm/meta_schedule/runner/local_runner.py

Lines changed: 24 additions & 24 deletions

@@ -23,17 +23,17 @@
 
 from ...contrib.popen_pool import PopenPoolExecutor
 from ...runtime import Device, Module
+from ..profiler import Profiler
 from ..utils import derived_object, get_global_func_with_default_on_worker
 from .config import EvaluatorConfig
-from .runner import PyRunner, RunnerFuture, RunnerInput, RunnerResult, PyRunnerFuture
+from .runner import PyRunner, PyRunnerFuture, RunnerFuture, RunnerInput, RunnerResult
 from .utils import (
-    T_ARGUMENT_LIST,
     T_ARG_INFO_JSON_OBJ_LIST,
+    T_ARGUMENT_LIST,
     alloc_argument_common,
     run_evaluator_common,
 )
-
 
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 
 
@@ -137,26 +137,29 @@ def resource_handler():
             yield
         finally:
             # Final step. Always clean up
-            f_cleanup()
+            with Profiler.timeit("LocalRunner/cleanup"):
+                f_cleanup()
 
     with resource_handler():
         # Step 1: create the local runtime module
-        rt_mod = tvm.runtime.load_module(artifact_path)
-        # Step 2: create the local device
-        device = tvm.runtime.device(dev_type=device_type, dev_id=0)
-        # Step 3: Allocate input arguments
-        repeated_args: List[T_ARGUMENT_LIST] = f_alloc_argument(
-            device,
-            args_info,
-            alloc_repeat,
-        )
-        # Step 4: Run time_evaluator
-        costs: List[float] = f_run_evaluator(
-            rt_mod,
-            device,
-            evaluator_config,
-            repeated_args,
-        )
+        with Profiler.timeit("LocalRunner/load_module"):
+            rt_mod = tvm.runtime.load_module(artifact_path)
+        # Step 2: Allocate input arguments
+        with Profiler.timeit("LocalRunner/alloc_argument"):
+            device = tvm.runtime.device(dev_type=device_type, dev_id=0)
+            repeated_args: List[T_ARGUMENT_LIST] = f_alloc_argument(
+                device,
+                args_info,
+                alloc_repeat,
+            )
+        # Step 3: Run time_evaluator
+        with Profiler.timeit("LocalRunner/run_evaluator"):
+            costs: List[float] = f_run_evaluator(
+                rt_mod,
+                device,
+                evaluator_config,
+                repeated_args,
+            )
         return costs
 
 
@@ -313,9 +316,6 @@ def _check(
         get_global_func_with_default_on_worker(name=f_alloc_argument, default=None)
         get_global_func_with_default_on_worker(name=f_run_evaluator, default=None)
         get_global_func_with_default_on_worker(name=f_cleanup, default=None)
-        get_global_func_with_default_on_worker(
-            name="tvm.contrib.random.random_fill", default=None
-        )
 
     value = self.pool.submit(
         _check,
@@ -348,7 +348,7 @@ def default_alloc_argument(
         The allocation args
     """
     f_random_fill = get_global_func_with_default_on_worker(
-        name="tvm.contrib.random.random_fill", default=None
+        name="tvm.contrib.random.random_fill_for_measure", default=None
     )
     return alloc_argument_common(f_random_fill, device, args_info, alloc_repeat)

python/tvm/meta_schedule/runner/rpc_runner.py

Lines changed: 26 additions & 20 deletions

@@ -25,6 +25,7 @@
 from tvm.rpc import RPCSession
 from tvm.runtime import Device, Module
 
+from ..profiler import Profiler
 from ..utils import (
     cpu_count,
     derived_object,
@@ -378,31 +379,36 @@ def resource_handler():
             yield
         finally:
             # Final step. Always clean up
-            f_cleanup(session, remote_path)
+            with Profiler.timeit("RPCRunner/cleanup"):
+                f_cleanup(session, remote_path)
 
     with resource_handler():
         # Step 1. Create session
-        session = f_create_session(rpc_config)
-        device = session.device(dev_type=device_type, dev_id=0)
+        with Profiler.timeit("RPCRunner/create_session"):
+            session = f_create_session(rpc_config)
+            device = session.device(dev_type=device_type, dev_id=0)
         # Step 2. Upload the module
-        _, remote_path = osp.split(artifact_path)
-        local_path: str = artifact_path
-        rt_mod: Module = f_upload_module(session, local_path, remote_path)
+        with Profiler.timeit("RPCRunner/upload_module"):
+            _, remote_path = osp.split(artifact_path)
+            local_path: str = artifact_path
+            rt_mod: Module = f_upload_module(session, local_path, remote_path)
         # Step 3: Allocate input arguments
-        repeated_args: List[T_ARGUMENT_LIST] = f_alloc_argument(
-            session,
-            device,
-            args_info,
-            alloc_repeat,
-        )
+        with Profiler.timeit("RPCRunner/alloc_argument"):
+            repeated_args: List[T_ARGUMENT_LIST] = f_alloc_argument(
+                session,
+                device,
+                args_info,
+                alloc_repeat,
+            )
         # Step 4: Run time_evaluator
-        costs: List[float] = f_run_evaluator(
-            session,
-            rt_mod,
-            device,
-            evaluator_config,
-            repeated_args,
-        )
+        with Profiler.timeit("LocalRunner/run_evaluator"):
+            costs: List[float] = f_run_evaluator(
+                session,
+                rt_mod,
+                device,
+                evaluator_config,
+                repeated_args,
+            )
         return costs
 
 
@@ -474,7 +480,7 @@ def default_alloc_argument(
     """
     f_random_fill = get_global_func_on_rpc_session(
         session,
-        "tvm.contrib.random.random_fill",
+        "tvm.contrib.random.random_fill_for_measure",
        "Please make sure 'USE_RANDOM' is turned ON in the config.cmake on the RPC server.",
    )
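
Because the RPC path resolves the packed function on the remote device, the RPC server's TVM runtime must be built with the matching flags as well. A hedged sketch of how one might verify this over a direct RPC connection (host, port, and the error handling are assumptions, not part of this commit):

    import tvm.rpc

    session = tvm.rpc.connect("127.0.0.1", 9090)  # hypothetical RPC server address
    try:
        f = session.get_function("tvm.contrib.random.random_fill_for_measure")
        print("remote random_fill_for_measure found:", f is not None)
    except Exception as err:
        print("not registered on the server; rebuild it with USE_RANDOM/USE_CURAND:", err)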

src/runtime/contrib/curand/ (new file)

Lines changed: 104 additions & 0 deletions

@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <curand.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/registry.h>
+
+#include "../../cuda/cuda_common.h"
+#include "./helper_cuda_kernels.h"
+
+namespace tvm {
+namespace runtime {
+namespace curand {
+
+#define TVM_CURAND_CALL(func)                                      \
+  {                                                                \
+    curandStatus_t e = (func);                                     \
+    ICHECK(e == CURAND_STATUS_SUCCESS) << "cuRAND error: " << e;   \
+  }
+
+class CURandGenerator {
+ public:
+  CURandGenerator() { TVM_CURAND_CALL(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT)); }
+  ~CURandGenerator() { TVM_CURAND_CALL(curandDestroyGenerator(gen)); }
+
+  void Generate32bit(void* ptr, int64_t n) {
+    TVM_CURAND_CALL(curandGenerateNormal(gen, static_cast<float*>(ptr), n, 0.0f, 5.0f));
+    cudaDeviceSynchronize();
+  }
+
+  void Generate64bit(void* ptr, int64_t n) {
+    TVM_CURAND_CALL(curandGenerateNormalDouble(gen, static_cast<double*>(ptr), n, 0.0f, 5.0f));
+  }
+
+  curandGenerator_t gen;
+};
+
+DeviceAPI* GetCUDADeviceAPI() {
+  const PackedFunc* get_cuda_api = runtime::Registry::Get("device_api.cuda");
+  ICHECK(get_cuda_api) << "ValueError: TVM is not built with USE_CUDA=ON";
+  void* ret = (*get_cuda_api)();
+  runtime::DeviceAPI* cuda_api = static_cast<runtime::DeviceAPI*>(ret);
+  return cuda_api;
+}
+
+int64_t GetTensorSize(DLTensor* tensor) {
+  int64_t tensor_size = 1;
+  for (int i = 0; i < tensor->ndim; ++i) {
+    tensor_size *= tensor->shape[i];
+  }
+  return tensor_size;
+}
+
+struct DeferredFunc {
+ public:
+  DeferredFunc(std::function<void()> func) : func_(func) {}
+  ~DeferredFunc() { func_(); }
+
+ private:
+  std::function<void()> func_;
+};
+
+void RandomFill(DLTensor* tensor) {
+  static DeviceAPI* cuda_api = GetCUDADeviceAPI();
+  CHECK(tensor->device.device_type == DLDeviceType::kDLCUDA)
+      << "ValueError: cuRAND only works on CUDA devices";
+  if (tensor->dtype.code == DLDataTypeCode::kDLFloat && tensor->dtype.bits == 16) {
+    int64_t tensor_size = GetTensorSize(tensor);
+    void* data = cuda_api->AllocWorkspace(tensor->device, tensor_size * sizeof(float));
+    {
+      DeferredFunc defer([data, tensor]() { cuda_api->FreeWorkspace(tensor->device, data); });
+      CURandGenerator().Generate32bit(data, GetTensorSize(tensor));
+      ConvertFp32toFp16(/*src=*/data, /*dst=*/tensor->data, /*num=*/tensor_size);
+    }
+  } else if (tensor->dtype.code == DLDataTypeCode::kDLFloat && tensor->dtype.bits == 32) {
+    CURandGenerator().Generate32bit(tensor->data, GetTensorSize(tensor));
+  } else if (tensor->dtype.code == DLDataTypeCode::kDLFloat && tensor->dtype.bits == 64) {
+    CURandGenerator().Generate64bit(tensor->data, GetTensorSize(tensor));
+  } else {
+    LOG(FATAL) << "ValueError: Unsupported dtype: " << tensor->dtype;
+  }
+  TVMSynchronize(tensor->device.device_type, tensor->device.device_type, nullptr);
+}
+
+TVM_REGISTER_GLOBAL("runtime.contrib.curand.RandomFill").set_body_typed(RandomFill);
+
+}  // namespace curand
+}  // namespace runtime
+}  // namespace tvm
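
The registered packed function can be exercised directly from Python once TVM is built with USE_CUDA and USE_CURAND. A small usage sketch (the device index and tensor shape are arbitrary choices; the function name and generator parameters are taken from the source above):

    import tvm

    f_fill = tvm.get_global_func("runtime.contrib.curand.RandomFill", allow_missing=True)
    assert f_fill is not None, "TVM was not built with USE_CURAND=ON"

    dev = tvm.cuda(0)
    arr = tvm.nd.empty((1024, 1024), dtype="float32", device=dev)
    f_fill(arr)  # fills the CUDA tensor with normally distributed values

    host = arr.numpy()
    print(host.mean(), host.std())  # roughly mean 0.0 and stddev 5.0 per the generator above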

src/runtime/contrib/curand/ (new CUDA helper file)

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <cuda_fp16.h>
+
+#include "./helper_cuda_kernels.h"
+
+namespace tvm {
+namespace runtime {
+namespace curand {
+
+__global__ void KernelFp32ToFp16(const float* src, half* dst, int num) {
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  if (idx < num) {
+    dst[idx] = src[idx];
+  }
+}
+
+void ConvertFp32toFp16(const void* _src, void* _dst, int64_t num) {
+  const float* src = static_cast<const float*>(_src);
+  half* dst = static_cast<half*>(_dst);
+  KernelFp32ToFp16<<<(num + 255) / 256, 256>>>(src, dst, num);
+}
+
+}  // namespace curand
+}  // namespace runtime
+}  // namespace tvm
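
This kernel converts one element per thread, launched as 256-thread blocks with a ceil-divided grid, and it is what the fp16 branch of RandomFill above uses after generating an fp32 scratch buffer. A hedged end-to-end sketch of that path from Python (device index and shape are arbitrary):

    import tvm

    f_fill = tvm.get_global_func("runtime.contrib.curand.RandomFill", allow_missing=True)
    assert f_fill is not None, "TVM was not built with USE_CURAND=ON"

    dev = tvm.cuda(0)
    arr = tvm.nd.empty((2048,), dtype="float16", device=dev)
    f_fill(arr)  # generates fp32 values with cuRAND, then converts to fp16 on-device

    # Grid size mirrors the kernel launch: ceil(num / 256) blocks of 256 threads.
    num = 2048
    print("blocks:", (num + 255) // 256)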
