diff --git a/src/tl_templates/cuda/debug.h b/src/tl_templates/cuda/debug.h
index 40d364bc9..c832a5e5e 100644
--- a/src/tl_templates/cuda/debug.h
+++ b/src/tl_templates/cuda/debug.h
@@ -126,3 +126,10 @@ TL_DEVICE void device_assert_with_msg(bool cond, const char *msg) {
     assert(0);
   }
 }
+
+// Message-only debug print: emits msg plus the caller's block/thread indices.
+TL_DEVICE void debug_print_msg(const char *msg) {
+  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d)\n", msg,
+         blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
+         threadIdx.z);
+}
diff --git a/src/tl_templates/hip/debug.h b/src/tl_templates/hip/debug.h
index 7eb3736c2..309b8fd99 100644
--- a/src/tl_templates/hip/debug.h
+++ b/src/tl_templates/hip/debug.h
@@ -99,3 +99,10 @@ __device__ void debug_print_buffer_value(const char *msg, const char *buf_name,
                                          int index, T var) {
   PrintTraits::print_buffer(msg, buf_name, index, var);
 }
+
+// Message-only debug print: emits msg plus the caller's block/thread indices.
+inline __device__ void debug_print_msg(const char *msg) {
+  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d)\n", msg,
+         blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
+         threadIdx.z);
+}
diff --git a/testing/python/debug/test_tilelang_debug_print.py b/testing/python/debug/test_tilelang_debug_print.py
index 23c0f4d92..3025ddfb9 100644
--- a/testing/python/debug/test_tilelang_debug_print.py
+++ b/testing/python/debug/test_tilelang_debug_print.py
@@ -94,7 +94,7 @@ def test_debug_print_register_files():
     debug_print_register_files(16, 16)
 
 
-def debug_print_msg(M=16, N=16):
+def debug_print_msg(M=16, N=16, msg_only=False):
     dtype = T.float16
 
     @T.prim_func
@@ -102,7 +102,10 @@ def program(Q: T.Tensor((M, N), dtype)):
         with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
             tid = T.get_thread_binding()
             if tid == 0:
-                T.print(bx + by + bz, msg="hello world")
+                if msg_only:
+                    T.print(msg="hello world")
+                else:
+                    T.print(bx + by + bz, msg="hello world")
 
     jit_kernel = tilelang.compile(program)
     profiler = jit_kernel.get_profiler()
@@ -110,7 +113,8 @@ def program(Q: T.Tensor((M, N), dtype)):
 
 
 def test_debug_print_msg():
-    debug_print_msg(16, 16)
+    debug_print_msg(16, 16, msg_only=True)
+    debug_print_msg(16, 16, msg_only=False)
 
 
 if __name__ == "__main__":
diff --git a/tilelang/language/print_op.py b/tilelang/language/print_op.py
index c8c7be81f..f1149c1e6 100644
--- a/tilelang/language/print_op.py
+++ b/tilelang/language/print_op.py
@@ -98,6 +98,16 @@ def print_fragment_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Bu
             tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, smem[coords])
 
 
+@macro
+def print_msg(msg: str) -> tir.PrimExpr:
+    """
+    Prints a message string.
+    """
+    assert isinstance(msg, str), "msg must be a string"
+    assert msg != "", "msg must not be empty"
+    tir.call_extern("handle", "debug_print_msg", msg)
+
+
 @macro
 def print_local_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr:
     """
@@ -150,7 +160,7 @@ def device_assert(condition: tir.PrimExpr, msg: str = "", no_stack_info=False):
         T.call_intrin("void", tir.op.Op.get("tl.device_assert_with_msg"), condition, get_stack_str(msg, stacklevel=2))
 
 
-def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> tir.PrimExpr:
+def print(obj: Any = None, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> tir.PrimExpr:
     """
     A generic print function that handles both TIR buffers and primitive expressions.
 
@@ -158,7 +168,7 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
     - If the input is a TIR primitive expression, it prints its value directly.
 
     Parameters:
-        obj (Any): The object to print. It can be either a tir.Buffer or tir.PrimExpr.
+        obj (Any): The object to print. It can be either a tir.Buffer, tir.PrimExpr, or None (for msg-only print).
         msg (str): An optional message to include in the print statement.
         warp_group_id (int): The warp group id to print.
         warp_id (int): The warp id to print.
@@ -227,6 +237,9 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
         # Directly print primitive expressions.
         return print_var(obj, msg)
 
+    elif obj is None:
+        return print_msg(msg)
+
     else:
         # Unsupported object type.
-        raise ValueError(f"Unexpected type: {type(obj)}. Supported types are tir.Buffer and tir.PrimExpr.")
+        raise ValueError(f"Unexpected type: {type(obj)}. Supported types are tir.Buffer, tir.PrimExpr, and None.")