-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[Profiler] Allow user to flush L2 cache in time_evalutor function for profiling CUDA kernels
#13726
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 9 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
177efd9
flush_l2
yzh119 43580b5
not necessarily sm_86
yzh119 dd32136
fix lint and test
yzh119 f1237c4
fix
yzh119 530fabf
revert profiling
yzh119 d650695
number=1
yzh119 d56c8b8
use parametrize
yzh119 486288c
use (void**)
yzh119 5fd7717
use reinterpret_cast for lint
yzh119 fcb0f3f
refactor and add license
yzh119 284b125
empty line for lint
yzh119 71eb46d
header order
yzh119 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, | ||
| * software distributed under the License is distributed on an | ||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| * KIND, either express or implied. See the License for the | ||
| * specific language governing permissions and limitations | ||
| * under the License. | ||
| */ | ||
| // Acknowledgement: l2flush struct in nvbench project. | ||
| // Reference: | ||
| // https://github.com/NVIDIA/nvbench/blob/1a13a2e724b8aa8aee27649ac6878babb63862a6/nvbench/detail/l2flush.cuh | ||
| #include <cuda.h> | ||
| #include <cuda_runtime.h> | ||
| #include <dmlc/thread_local.h> | ||
| #include <tvm/runtime/device_api.h> | ||
| #include <tvm/runtime/registry.h> | ||
|
|
||
| #include "cuda_common.h" | ||
|
|
||
| namespace tvm { | ||
|
|
||
| namespace runtime { | ||
|
|
||
| class L2Flush { | ||
| public: | ||
| L2Flush() : initialized_(false), l2_size_(0), l2_buffer_(nullptr) {} | ||
|
|
||
| ~L2Flush() { | ||
| if (l2_size_ > 0) { | ||
echuraev marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| CUDA_CALL(cudaFree(l2_buffer_)); | ||
| } | ||
| } | ||
|
|
||
| void Flush() { | ||
| if (!initialized_) { | ||
echuraev marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| // initialize l2_buffer_ and l2_size_ | ||
| initialized_ = true; | ||
| int device_id; | ||
| CUDA_CALL(cudaGetDevice(&device_id)); | ||
| CUDA_CALL(cudaDeviceGetAttribute(&l2_size_, cudaDevAttrL2CacheSize, device_id)); | ||
| if (l2_size_ > 0) { | ||
| CUDA_CALL(cudaMalloc(reinterpret_cast<void**>(&l2_buffer_), l2_size_)); | ||
| } | ||
| } | ||
| cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream; | ||
| if (l2_size_ > 0) { | ||
| CUDA_CALL(cudaMemsetAsync(l2_buffer_, 0, l2_size_, stream)); | ||
| } | ||
| } | ||
|
|
||
| static L2Flush* ThreadLocal(); | ||
|
|
||
| private: | ||
| bool initialized_ = false; | ||
| int l2_size_; | ||
| int* l2_buffer_; | ||
| }; | ||
|
|
||
| typedef dmlc::ThreadLocalStore<L2Flush> L2FlushStore; | ||
|
|
||
| L2Flush* L2Flush::ThreadLocal() { return L2FlushStore::Get(); } | ||
|
|
||
| TVM_REGISTER_GLOBAL("l2_cache_flush_cuda").set_body([](TVMArgs args, TVMRetValue* rv) { | ||
| ICHECK(L2Flush::ThreadLocal() != nullptr) << "L2Flush::ThreadLocal do not exist."; | ||
| L2Flush::ThreadLocal()->Flush(); | ||
| }); | ||
|
|
||
| } // namespace runtime | ||
| } // namespace tvm | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
|
|
||
| import tvm | ||
| from tvm import te | ||
| from tvm.script import tir as T | ||
| import tvm.testing | ||
| import numpy as np | ||
| import pytest | ||
|
|
||
|
|
||
| @T.prim_func | ||
| def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: | ||
| A = T.match_buffer(a, [128, 128]) | ||
| B = T.match_buffer(b, [128, 128]) | ||
| C = T.match_buffer(c, [128, 128]) | ||
| for i, j, k in T.grid(128, 128, 128): | ||
| with T.block("matmul"): | ||
| vi, vj, vk = T.axis.remap("SSR", [i, j, k]) | ||
| with T.init(): | ||
| C[vi, vj] = 0.0 | ||
| C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] | ||
|
|
||
|
|
||
| @tvm.testing.requires_cuda | ||
| @pytest.mark.parametrize("f_preproc", ["", "l2_cache_flush_cuda"]) | ||
| def test_time_evalutor_with_preproc(f_preproc: str): | ||
| mod = tvm.IRModule.from_expr(matmul) | ||
| sch = tvm.tir.Schedule(mod) | ||
| blk = sch.get_block("matmul") | ||
| i, j, k = sch.get_loops(blk) | ||
| sch.bind(i, "blockIdx.x") | ||
| sch.bind(j, "threadIdx.x") | ||
| f = tvm.build(sch.mod["main"], target="cuda") | ||
| dev = tvm.cuda(0) | ||
| evaluator = f.time_evaluator(f.entry_name, dev, repeat=1000, number=1, f_preproc=f_preproc) | ||
|
|
||
| a = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev) | ||
| b = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev) | ||
| c = tvm.nd.array(np.zeros((128, 128)).astype("float32"), device=dev) | ||
| args = [a, b, c] | ||
| print("Evaluator (f_preproc={}):\t{:.5f}ms".format(f_preproc, evaluator(*args).mean * 1000)) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| test_time_evalutor_with_preproc("l2_cache_flush_cuda") |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.