Skip to content

Commit 9613385

Browse files
authored
[Codegen][ROCm] Mismatched Dtype of Workgroup/Workitem (#15777)
This PR fixes a ROCm codegen error that the dtype of `@llvm.amdgcn.workgroup.id*` and `@llvm.amdgcn.workitem.id.*` are always i32 when generating LLVM IR, even if it's marked as T.int64 in TIR. An example that triggers this issue: ```python @T.prim_func def encode_kernel(A: T.handle("float16", "global"), max_abs_value: T.handle("float16", "global"), v: T.int64): T.func_attr({"calling_conv": 2, "target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, "keys": ["rocm", "gpu"], "kind": "rocm", "max_num_threads": 256, "max_shared_memory_per_block": 65536, "max_threads_per_block": 1024, "mcpu": "gfx1100", "mtriple": "amdgcn-amd-amdhsa-hcc", "tag": "", "thread_warp_size": 32}), "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "threadIdx.x"], "tir.noalias": T.bool(True)}) A_1 = T.decl_buffer((v * T.int64(8192),), "float16", data=A) max_abs_value_1 = T.decl_buffer((T.min(v, (v * T.int64(256) + T.int64(65535)) // T.int64(65536) * T.int64(256)) * T.int64(256),), "float16", data=max_abs_value) blockIdx_x = T.launch_thread("blockIdx.x", T.int64(256)) threadIdx_x = T.launch_thread("threadIdx.x", T.int64(256)) for i_j_fused_0, k in T.grid(T.shift_right(v + T.int64(255), T.int64(8)), T.int64(32)): if i_j_fused_0 * T.int64(256) + blockIdx_x - v < T.int64(0): if k == T.int64(0): max_abs_value_1[i_j_fused_0 * T.int64(65536) + blockIdx_x * T.int64(256) + threadIdx_x] = T.float16(-65504) max_abs_value_1[i_j_fused_0 * T.int64(65536) + blockIdx_x * T.int64(256) + threadIdx_x] = T.max(max_abs_value_1[i_j_fused_0 * T.int64(65536) + blockIdx_x * T.int64(256) + threadIdx_x], T.call_pure_extern("float16", "__ocml_fabs_f16", A_1[i_j_fused_0 * T.int64(2097152) + blockIdx_x * T.int64(8192) + threadIdx_x * T.int64(32) + k])) ```
1 parent d60fb72 commit 9613385

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

src/target/llvm/codegen_amdgpu.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,8 @@ class CodeGenAMDGPU : public CodeGenLLVM {
187187
}
188188
}
189189
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id);
190-
return builder_->CreateCall(f, {});
190+
llvm::Value* result = builder_->CreateCall(f, {});
191+
return this->CreateCast(DataType::Int(32), iv->var->dtype, result);
191192
}
192193

193194
llvm::Value* CreateStorageSync(const CallNode* op) final {

src/target/llvm/intrin_rule_rocm.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) {
8989
index = self + delta;
9090
index = Select((self & (width - 1)) + delta >= width, self, index);
9191
}
92-
PrimExpr res = Call(var.dtype(), builtin::call_pure_extern(),
92+
PrimExpr res = Call(DataType::Int(32), builtin::call_pure_extern(),
9393
{StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var});
9494
return res;
9595
}

0 commit comments

Comments
 (0)