Commit 9613385
authored
[Codegen][ROCm] Mismatched Dtype of Workgroup/Workitem (#15777)
This PR fixes a ROCm codegen error that the dtype of
`@llvm.amdgcn.workgroup.id*` and `@llvm.amdgcn.workitem.id.*` are
always i32 when generating LLVM IR, even if it's marked as T.int64 in
TIR.
An example that triggers this issue:
```python
@T.prim_func
def encode_kernel(A: T.handle("float16", "global"), max_abs_value: T.handle("float16", "global"), v: T.int64):
T.func_attr({"calling_conv": 2, "target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, "keys": ["rocm", "gpu"], "kind": "rocm", "max_num_threads": 256, "max_shared_memory_per_block": 65536, "max_threads_per_block": 1024, "mcpu": "gfx1100", "mtriple": "amdgcn-amd-amdhsa-hcc", "tag": "", "thread_warp_size": 32}), "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "threadIdx.x"], "tir.noalias": T.bool(True)})
A_1 = T.decl_buffer((v * T.int64(8192),), "float16", data=A)
max_abs_value_1 = T.decl_buffer((T.min(v, (v * T.int64(256) + T.int64(65535)) // T.int64(65536) * T.int64(256)) * T.int64(256),), "float16", data=max_abs_value)
blockIdx_x = T.launch_thread("blockIdx.x", T.int64(256))
threadIdx_x = T.launch_thread("threadIdx.x", T.int64(256))
for i_j_fused_0, k in T.grid(T.shift_right(v + T.int64(255), T.int64(8)), T.int64(32)):
if i_j_fused_0 * T.int64(256) + blockIdx_x - v < T.int64(0):
if k == T.int64(0):
max_abs_value_1[i_j_fused_0 * T.int64(65536) + blockIdx_x * T.int64(256) + threadIdx_x] = T.float16(-65504)
max_abs_value_1[i_j_fused_0 * T.int64(65536) + blockIdx_x * T.int64(256) + threadIdx_x] = T.max(max_abs_value_1[i_j_fused_0 * T.int64(65536) + blockIdx_x * T.int64(256) + threadIdx_x], T.call_pure_extern("float16", "__ocml_fabs_f16", A_1[i_j_fused_0 * T.int64(2097152) + blockIdx_x * T.int64(8192) + threadIdx_x * T.int64(32) + k]))
```1 parent d60fb72 commit 9613385
2 files changed
+3
-2
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
187 | 187 | | |
188 | 188 | | |
189 | 189 | | |
190 | | - | |
| 190 | + | |
| 191 | + | |
191 | 192 | | |
192 | 193 | | |
193 | 194 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
89 | 89 | | |
90 | 90 | | |
91 | 91 | | |
92 | | - | |
| 92 | + | |
93 | 93 | | |
94 | 94 | | |
95 | 95 | | |
| |||
0 commit comments