
Commit cb37bfe

[Refactor] Refactor barrier management (#744)
* Introduce Barrier
* Enhance CUDA kernel with new barrier management and post-processing support
  - Added a new CUDA kernel implementation in `example_mla_decode.py` for improved performance with shared memory barriers.
  - Refactored barrier handling in `codegen_cuda.cc` and `codegen_hip.cc` to utilize a more flexible mbarrier structure.
  - Updated intrinsic definitions from `ptx_stmatirx` to `ptx_stmatrix` across multiple files for consistency.
  - Introduced additional print statements for debugging in the lowering phase of the TileLang engine.
  - Enhanced the overall structure and readability of the codebase.
* Remove unused barrier handling code in the CUDA and HIP code generators to streamline the implementation. This change enhances code clarity and reduces complexity in the barrier management logic.
* Enhance barrier management in TileLang
  - Introduced a new intrinsic `allocate_barrier` for dynamic barrier allocation in the TileLang framework.
  - Updated CUDA code generation to support the new barrier structure, allowing for improved synchronization in shared memory.
  - Refactored existing barrier handling logic to accommodate the new intrinsic and streamline code.
  - Added print statements for debugging purposes in various examples and the lowering phase of the TileLang engine.
  - Removed deprecated memory scope handling code to enhance clarity and maintainability.
* lint fix
* lint fix
* Remove the `allocate_barrier` intrinsic and related code from TileLang to streamline barrier management. This includes updates to CUDA code generation and the removal of associated Python wrappers, enhancing code clarity and maintainability.
* Refactor logging in JITKernel to improve kernel compilation tracking
  - Removed an unused import of `torch.backends` in the example file.
  - Introduced logging for kernel compilation in `JITKernel`, replacing print statements with structured logging for better traceability and debugging.
  - Added an assertion to ensure the presence of the `global_symbol` attribute in the kernel function.
* Refactor dequantization tests and update barrier function
  - Removed the test for `example_dequant_gemm_bf16_fp4_hopper_serial` to streamline the testing suite.
  - Updated the `mbarrier_cp_async_arrive` function to support both pointer and non-pointer types, enhancing flexibility in barrier management (see the sketch below the change statistics).
* Update CI configuration to increase pytest parallelism from 4 to 8 threads for improved test execution speed.
* Fix typos in rasterization parameters and update the import path for the cached module
  - Corrected the spelling of `enable_rasteration` to `enable_rasterization` in the matmul function and its usage.
  - Updated the import statement for the `cached` module to reflect its new path in the cache submodule.
  - Added the `StridedTensor` import in the language module for enhanced tensor functionality.
* Update ci.yml
1 parent eccdfe1 commit cb37bfe

24 files changed: 421 additions, 365 deletions
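Note: the commit message above states that `mbarrier_cp_async_arrive` now supports both pointer and non-pointer types. Below is a minimal C++ sketch of how one entry point can serve both argument forms; the helper `ptx_cp_async_mbarrier_arrive` and the exact signatures are illustrative assumptions, not code from this repository.

```cpp
// Hedged sketch only: one way to accept either a pointer to an mbarrier object
// or a raw shared-memory address, as the commit message describes for
// mbarrier_cp_async_arrive. The helper below is a hypothetical stand-in for the
// underlying cp.async.mbarrier.arrive instruction.
#include <cstdint>
#include <type_traits>

inline void ptx_cp_async_mbarrier_arrive(uint64_t smem_addr) {
  (void)smem_addr;  // placeholder: real code would emit the PTX instruction here
}

template <typename BarrierT>
inline void mbarrier_cp_async_arrive(BarrierT barrier) {
  if constexpr (std::is_pointer_v<BarrierT>) {
    // Pointer form: the caller passed the address of a barrier object.
    ptx_cp_async_mbarrier_arrive(reinterpret_cast<uintptr_t>(barrier));
  } else {
    // Address form: the caller already supplies an integer shared-memory address.
    ptx_cp_async_mbarrier_arrive(static_cast<uint64_t>(barrier));
  }
}
```

In device code this would typically be a `__device__` function backed by inline PTX; the host-compilable version here only demonstrates how overload resolution handles both argument forms.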

examples/dequantize_gemm/test_example_dequantize_gemm.py

Lines changed: 0 additions & 7 deletions
@@ -2,7 +2,6 @@
 
 import example_dequant_gemv_fp16xint4
 import example_dequant_gemm_fp4_hopper
-import example_dequant_gemm_bf16_fp4_hopper_serial
 
 
 @tilelang.testing.requires_cuda
@@ -16,11 +15,5 @@ def test_example_dequant_gemm_fp4_hopper():
     example_dequant_gemm_fp4_hopper.main()
 
 
-@tilelang.testing.requires_cuda
-@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
-def test_example_dequant_gemm_bf16_fp4_hopper_serial():
-    example_dequant_gemm_bf16_fp4_hopper_serial.main()
-
-
 if __name__ == "__main__":
     tilelang.testing.main()

examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 import torch
-import torch.backends
 from tilelang import tvm as tvm
 import tilelang.testing
 from tvm import DataType

examples/warp_specialize/example_warp_specialize_flashmla.py

Lines changed: 1 addition & 0 deletions
@@ -391,6 +391,7 @@ def main(batch=1, heads=128, kv_heads=1, kv_ctx=8192, dim=512, pe_dim=64):
     num_split = 1
 
     kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
+    print(kernel.get_kernel_source())
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
     profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
     latency = profiler.do_bench(warmup=500)

examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py

Lines changed: 0 additions & 1 deletion
@@ -66,7 +66,6 @@ def main():
 
     # Run the kernel through the Profiler
     c = jit_kernel(a, b)
-
     # Reference multiplication using PyTorch
     ref_c = a @ b
 

src/op/builtin.cc

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ TIR_DEFINE_TL_BUILTIN(ptx_ldmatirx)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_TL_BUILTIN(ptx_stmatirx)
+TIR_DEFINE_TL_BUILTIN(ptx_stmatrix)
     .set_num_inputs(-1)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
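Note: this rename has to match the declaration change in `src/op/builtin.h` shown in the next diff. A condensed C++ sketch of the declare/register pair follows, with comments on what the registration arguments mean; it assumes the TVM headers and the `TIR_DEFINE_TL_BUILTIN` macro already used by these files and is not verbatim source.

```cpp
// Condensed sketch of the declare/register pattern this rename touches.

// builtin.h: TVM_DLL exports the accessor so code outside the shared library
// can obtain a reference to the op.
TVM_DLL const Op &ptx_stmatrix();

// builtin.cc: register the op. set_num_inputs(-1) marks a variable-length
// argument list, and kOpaque tells TVM the call has opaque side effects, so
// it cannot be reordered or removed by generic optimizations.
TIR_DEFINE_TL_BUILTIN(ptx_stmatrix)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
```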

src/op/builtin.h

Lines changed: 21 additions & 21 deletions
@@ -62,7 +62,7 @@ static constexpr const char *kDynamicAlignment = "tl.dynamic_alignment";
  * swizzle, l2_promotion, oob_fill)
  *
  */
-const Op &create_tma_descriptor();
+TVM_DLL const Op &create_tma_descriptor();
 
 /*!
  * \brief tvm intrinsics for TMADescriptor creation for image to column load
@@ -73,23 +73,23 @@ const Op &create_tma_descriptor();
  * l2_promotion, oob_fill)
  *
  */
-const Op &create_tma_im2col_descriptor();
+TVM_DLL const Op &create_tma_im2col_descriptor();
 
 /*!
  * \brief Create a list of mbarrier with num_threads
  *
  * create_list_of_mbarrier(num_threads0, num_threads1, ...)
  *
  */
-const Op &create_list_of_mbarrier();
+TVM_DLL const Op &create_list_of_mbarrier();
 
 /*!
  * \brief Get the mbarrier with barrier_id
  *
  * int64_t* GetMBarrier(barrier_id)
  *
  */
-const Op &get_mbarrier();
+TVM_DLL const Op &get_mbarrier();
 
 /*!
  * \brief tvm intrinsics for loading data from global tensor descriptor to
@@ -98,7 +98,7 @@ const Op &get_mbarrier();
  * tma_load(descriptor, mbarrier, smem_data, coord_0, coord_1, ...)
  *
  */
-const Op &tma_load();
+TVM_DLL const Op &tma_load();
 
 /*!
  * \brief tvm intrinsics for loading image from global tensor to columns in
@@ -108,7 +108,7 @@ const Op &tma_load();
  * image_offset, ...)
  *
  */
-const Op &tma_load_im2col();
+TVM_DLL const Op &tma_load_im2col();
 
 /*!
  * \brief tvm intrinsics for storing data from shared memory to global tensor
@@ -117,119 +117,119 @@ const Op &tma_load_im2col();
  * tma_store(descriptor, smem_data, coord_0, coord_1, ...)
  *
  */
-const Op &tma_store();
+TVM_DLL const Op &tma_store();
 
 /*!
  * \brief tvm intrinsics for mbarrier wait with parity bit
  *
  * mbarrier_wait_parity(mbarrier, parity)
  *
  */
-const Op &mbarrier_wait_parity();
+TVM_DLL const Op &mbarrier_wait_parity();
 
 /*!
  * \brief tvm intrinsics for mbarrier expect tx
  *
  * mbarrier_expect_tx(mbarrier, transaction_bytes)
  *
  */
-const Op &mbarrier_expect_tx();
+TVM_DLL const Op &mbarrier_expect_tx();
 
 /*!
  * \brief tvm intrinsics for ldmatrix
  *
  * ptx_ldmatirx(transposed, num, shared_addr, local_addr)
  *
  */
-const Op &ptx_ldmatirx();
+TVM_DLL const Op &ptx_ldmatirx();
 
 /*!
  * \brief tvm intrinsics for stmatrix
  *
  * ptx_ldmatirx(transposed, num, shared_addr, int32_values...)
  *
  */
-const Op &ptx_stmatirx();
+TVM_DLL const Op &ptx_stmatrix();
 
 /*!
  * \brief Pack two b16 value into a b32 value
  *
  * int32 pack_b16(b16_value, b16_value)
  *
  */
-const Op &pack_b16();
+TVM_DLL const Op &pack_b16();
 
 /*!
  * \brief Similar to __syncthreads(), but can be used to sync partial threads
  *
  * sync_thread_partial(num_partial_threads or mbarrier)
  *
  */
-const Op &sync_thread_partial();
+TVM_DLL const Op &sync_thread_partial();
 
 /*!
  * \brief Issue a shared memory fence for async operations
  *
  * FenceProxyAsync()
 *
  */
-const Op &fence_proxy_async();
+TVM_DLL const Op &fence_proxy_async();
 
 /*!
  * \brief Indicate arrival of warp issuing TMA_STORE
  *
  * tma_store_arrive()
  *
  */
-const Op &tma_store_arrive();
+TVM_DLL const Op &tma_store_arrive();
 
 /*!
  * \brief Wait for TMA_STORE to finish
  *
  * tma_store_wait()
  *
  */
-const Op &tma_store_wait();
+TVM_DLL const Op &tma_store_wait();
 
 /*!
  * \brief Set reg hint for warp-specialized branched
  *
  * SetMaxNRegInc(num_reg, is_inc)
  *
  */
-const Op &set_max_nreg();
+TVM_DLL const Op &set_max_nreg();
 
 /*!
  * \brief No set reg hint for warp-specialized branched
  *
  * no_set_max_nreg()
  *
  */
-const Op &no_set_max_nreg();
+TVM_DLL const Op &no_set_max_nreg();
 
 /*!
  * \brief Wait the previous wgmma to finish
  *
  * wait_wgmma(num_mma)
  *
  */
-const Op &wait_wgmma();
+TVM_DLL const Op &wait_wgmma();
 
 /*!
  * \brief Synchronize all threads in a grid
  *
  * sync_grid()
  *
  */
-const Op &sync_grid();
+TVM_DLL const Op &sync_grid();
 
 /*!
  * \brief tvm intrinsic for loop continue
  *
  * loop_break()
  *
  */
-const Op &loop_break();
+TVM_DLL const Op &loop_break();
 
 /*!
  * \brief tvm intrinsic for amd matrix core mfma instructions.
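Note: the only change in this header is prefixing each accessor with `TVM_DLL`, TVM's symbol export/visibility macro. A simplified sketch of how such a macro is commonly defined follows; TVM's actual definition lives in its runtime headers and may differ in detail.

```cpp
// Simplified sketch of a TVM_DLL-style export macro (not TVM's verbatim definition).
#ifndef TVM_DLL
#ifdef _WIN32
#ifdef TVM_EXPORTS
#define TVM_DLL __declspec(dllexport)  // building the library: export the symbol
#else
#define TVM_DLL __declspec(dllimport)  // consuming the library: import the symbol
#endif
#else
#define TVM_DLL __attribute__((visibility("default")))  // keep the symbol visible
#endif
#endif
```

The practical effect is that these `const Op &...()` accessors stay linkable from outside the shared library even when the build otherwise hides symbols by default.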

src/op/elem.cc

Lines changed: 1 addition & 1 deletion
@@ -302,7 +302,7 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
     num = 2;
 
   Array<PrimExpr> args;
-  const Op &op = is_ldmatrix ? tl::ptx_ldmatirx() : tl::ptx_stmatirx();
+  const Op &op = is_ldmatrix ? tl::ptx_ldmatirx() : tl::ptx_stmatrix();
   args.push_back(static_cast<int>(is_transposed));
   args.push_back(num);
 
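Note: the hunk above only selects the op and begins collecting its arguments. A hedged C++ sketch of how an op plus an argument list typically becomes a TIR statement in TVM-based lowering code follows; the function name is illustrative and not taken from `elem.cc`.

```cpp
// Hedged sketch: wrap an intrinsic call in Evaluate so it appears as a
// statement in the lowered body. Illustrative only, not verbatim from elem.cc.
#include <tvm/ir/op.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>

using namespace tvm;
using namespace tvm::tir;

Stmt MakeMatrixCopyCall(const Op &op, Array<PrimExpr> args) {
  // ptx_ldmatirx / ptx_stmatrix are registered as opaque ops with a
  // variable-length argument list, so the call carries the collected
  // arguments through to code generation unchanged.
  return Evaluate(Call(DataType::Handle(), op, args));
}
```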
