Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/pass/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,29 @@ class GPUCodeVerifier : public IRVisitor {
// record the number of threads in a block
std::string name = var.get()->name_hint;
if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z") {
size_t length = static_cast<size_t>(extent->value);
if (!visited_threads_.count(name)) {
visited_threads_.insert(name);
size_t length = static_cast<size_t>(extent->value);
thread_per_block_ *= length;

if (name == "threadIdx.x") {
valid_ &= length <= max_thread_x_;
thread_x_extent_ = length;
} else if (name == "threadIdx.y") {
valid_ &= length <= max_thread_y_;
thread_y_extent_ = length;
} else if (name == "threadIdx.z") {
valid_ &= length <= max_thread_z_;
thread_z_extent_ = length;
}
} else {
// the thread should be bound to axes with the same length
if (name == "threadIdx.x") {
valid_ &= length == thread_x_extent_;
} else if (name == "threadIdx.y") {
valid_ &= length == thread_y_extent_;
} else if (name == "threadIdx.z") {
valid_ &= length == thread_z_extent_;
}
}
}
Expand All @@ -111,6 +123,8 @@ class GPUCodeVerifier : public IRVisitor {
std::unordered_set<const tvm::Variable *> visited_shared_buffers_;
std::unordered_set<std::string> visited_threads_;

size_t thread_x_extent_, thread_y_extent_, thread_z_extent_;

size_t local_memory_per_block_;
size_t shared_memory_per_block_;
size_t thread_per_block_;
Expand Down
24 changes: 24 additions & 0 deletions tests/python/unittest/test_pass_verify_gpu_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,32 @@ def test_multiple_kernels():
tvm.build(s, [A, C], target)
assert valid[0]

def test_wrong_bind():
    # Verify that the GPU code verifier rejects a schedule where one thread
    # axis (threadIdx.x) is bound to two loop axes of different extents.
    n = 1024

    A = tvm.placeholder((n, n - 1), name='A')
    B = tvm.compute((n, n - 1), lambda i, j: A[i, j])

    s = tvm.create_schedule([B.op])

    # Bind threadIdx.x to both axes: their lengths (n and n - 1) differ,
    # so the verifier must flag the schedule as invalid.
    for axis in s[B].op.axis:
        s[B].bind(axis, tvm.thread_axis("threadIdx.x"))

    for target in ['opencl', 'cuda']:
        # Skip backends that are not available on this machine.
        if not tvm.context(target).exist:
            continue

        valid = [None]
        verify = (2, get_verify_pass(valid, max_threads_per_block=n * n))
        with tvm.build_config(**{"add_lower_pass": [verify]}):
            tvm.build(s, [A, B], target)
        assert not valid[0]


if __name__ == "__main__":
test_local_memory()
test_shared_memory()
test_num_thread()
test_multiple_kernels()
test_wrong_bind()