Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/tl_templates/cuda/debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,10 @@ TL_DEVICE void device_assert_with_msg(bool cond, const char *msg) {
assert(0);
}
}

// Specialization for msg-only debug print
__device__ void debug_print_msg(const char *msg) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d)\n", msg,
blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z);
Comment on lines +130 to +134
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cat -n src/tl_templates/cuda/debug.h | sed -n '120,140p'

Repository: tile-ai/tilelang

Length of output: 647


🏁 Script executed:

rg -n "define TL_DEVICE" src/tl_templates/

Repository: tile-ai/tilelang

Length of output: 464


🏁 Script executed:

cat -n src/tl_templates/cuda/debug.h | head -50

Repository: tile-ai/tilelang

Length of output: 2770


🏁 Script executed:

find . -name "*.h" -o -name "*.hip" | xargs grep -l "debug.h\|debug_print" | head -20

Repository: tile-ai/tilelang

Length of output: 122


🏁 Script executed:

cat -n src/tl_templates/hip/debug.h | sed -n '120,140p'

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

rg -n "debug_print_msg" src/tl_templates/

Repository: tile-ai/tilelang

Length of output: 228


🏁 Script executed:

cat -n src/tl_templates/hip/debug.h | sed -n '95,115p'

Repository: tile-ai/tilelang

Length of output: 716


🏁 Script executed:

cat -n src/tl_templates/hip/debug.h | sed -n '1,20p'

Repository: tile-ai/tilelang

Length of output: 1014


🏁 Script executed:

rg -n "includes.*debug.h" src/

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

git ls-files | xargs grep -l "debug.h" | head -20

Repository: tile-ai/tilelang

Length of output: 299


🏁 Script executed:

cat -n src/target/codegen_cuda.cc | grep -A 5 -B 5 "debug.h"

Repository: tile-ai/tilelang

Length of output: 624


🏁 Script executed:

cat -n src/target/codegen_hip.cc | grep -A 5 -B 5 "debug.h"

Repository: tile-ai/tilelang

Length of output: 613


🏁 Script executed:

cd src/tl_templates && find . -name "*.h" | xargs grep -l "__device__.*{" | head -10

Repository: tile-ai/tilelang

Length of output: 189


🏁 Script executed:

rg -n "static __device__|inline.*__device__|__device__.*inline" src/tl_templates/ | head -20

Repository: tile-ai/tilelang

Length of output: 2313


Use inline or static inline with debug_print_msg for consistency and to avoid potential device function linkage issues.

The function debug_print_msg is defined as a non-inline, non-template __device__ function in a header file. Other device utility functions in the same file (device_assert, device_assert_with_msg) use TL_DEVICE which includes __forceinline__, but debug_print_msg lacks this. This inconsistency and the lack of inline annotation can lead to multiple definition issues when the header is included across multiple compilation units. Apply the same inlining pattern as adjacent functions.

🔧 Proposed fix
-__device__ void debug_print_msg(const char *msg) {
+static __device__ inline void debug_print_msg(const char *msg) {
   printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d)\n", msg,
          blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
          threadIdx.z);
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// Specialization for msg-only debug print
__device__ void debug_print_msg(const char *msg) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d)\n", msg,
blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z);
// Specialization for msg-only debug print
static __device__ inline void debug_print_msg(const char *msg) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d)\n", msg,
blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z);
}
🤖 Prompt for AI Agents
In `@src/tl_templates/cuda/debug.h` around lines 130 - 134, The header-defined
device function debug_print_msg should be marked inline like the other utilities
to avoid multiple-definition/linkage issues; change its declaration to use the
same inlining macro (e.g., TL_DEVICE or static inline with
__device__/__forceinline__) as device_assert/device_assert_with_msg so the
function is emitted inline across compilation units while keeping the same
signature debug_print_msg(const char *msg).

}
7 changes: 7 additions & 0 deletions src/tl_templates/hip/debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,10 @@ __device__ void debug_print_buffer_value(const char *msg, const char *buf_name,
int index, T var) {
PrintTraits<T>::print_buffer(msg, buf_name, index, var);
}

// Message-only debug print: emits `msg` together with the calling thread's
// block and thread coordinates. No value is printed.
//
// Declared `inline` because this is a non-template function defined in a
// header; without it, every translation unit that includes the header emits
// its own external definition of the same symbol.
inline __device__ void debug_print_msg(const char *msg) {
  // `%d` expects an `int`, so convert each index component explicitly rather
  // than passing it through the variadic call unchanged.
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d)\n", msg,
         static_cast<int>(blockIdx.x), static_cast<int>(blockIdx.y),
         static_cast<int>(blockIdx.z), static_cast<int>(threadIdx.x),
         static_cast<int>(threadIdx.y), static_cast<int>(threadIdx.z));
}
10 changes: 7 additions & 3 deletions testing/python/debug/test_tilelang_debug_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,23 +94,27 @@ def test_debug_print_register_files():
debug_print_register_files(16, 16)


def debug_print_msg(M=16, N=16):
def debug_print_msg(M=16, N=16, msg_only=False):
dtype = T.float16

@T.prim_func
def program(Q: T.Tensor((M, N), dtype)):
with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
tid = T.get_thread_binding()
if tid == 0:
T.print(bx + by + bz, msg="hello world")
if msg_only:
T.print(msg="hello world")
Comment on lines 100 to +106
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Silence unused Q to keep Ruff ARG001 green.

Ruff reports Q unused in this program definition. If lint is enforced, this will fail. Consider a no-op use to preserve buffer naming.

🔧 Proposed fix
     `@T.prim_func`
     def program(Q: T.Tensor((M, N), dtype)):
+        _ = Q
         with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
             tid = T.get_thread_binding()
             if tid == 0:
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@T.prim_func
def program(Q: T.Tensor((M, N), dtype)):
with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
tid = T.get_thread_binding()
if tid == 0:
T.print(bx + by + bz, msg="hello world")
if msg_only:
T.print(msg="hello world")
`@T.prim_func`
def program(Q: T.Tensor((M, N), dtype)):
_ = Q
with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
tid = T.get_thread_binding()
if tid == 0:
if msg_only:
T.print(msg="hello world")
🧰 Tools
🪛 Ruff (0.14.14)

101-101: Unused function argument: Q

(ARG001)

🤖 Prompt for AI Agents
In `@testing/python/debug/test_tilelang_debug_print.py` around lines 100 - 106,
The parameter Q in the T.prim_func named program is unused and triggers Ruff
ARG001; add a no-op use of Q inside program (e.g., a harmless read or evaluate
of Q or its shape) so the symbol is referenced without changing behavior—place
this no-op before or inside the Kernel block in function program to preserve
buffer naming and silence the lint.

else:
T.print(bx + by + bz, msg="hello world")

jit_kernel = tilelang.compile(program)
profiler = jit_kernel.get_profiler()
profiler.run_once()


def test_debug_print_msg():
debug_print_msg(16, 16)
debug_print_msg(16, 16, msg_only=True)
debug_print_msg(16, 16, msg_only=False)


if __name__ == "__main__":
Expand Down
19 changes: 16 additions & 3 deletions tilelang/language/print_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,16 @@ def print_fragment_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Bu
tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, smem[coords])


@macro
def print_msg(msg: str) -> tir.PrimExpr:
    """
    Print a message string on its own, with no accompanying value.

    Emits an extern call to ``debug_print_msg`` with ``msg`` as its only
    argument.

    Parameters:
        msg (str): The message to print. Must be a non-empty string.
    """
    # NOTE(review): the signature is annotated as returning tir.PrimExpr, but the
    # body contains no return statement -- confirm what the macro wrapper yields.
    # Reject non-string and empty messages before emitting the call.
    assert isinstance(msg, str), "msg must be a string"
    assert msg != "", "msg must not be empty"
    # Call shape is ("handle", <extern function name>, <arguments...>).
    tir.call_extern("handle", "debug_print_msg", msg)


@macro
def print_local_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr:
"""
Expand Down Expand Up @@ -150,15 +160,15 @@ def device_assert(condition: tir.PrimExpr, msg: str = "", no_stack_info=False):
T.call_intrin("void", tir.op.Op.get("tl.device_assert_with_msg"), condition, get_stack_str(msg, stacklevel=2))


def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> tir.PrimExpr:
def print(obj: Any = None, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> tir.PrimExpr:
"""
A generic print function that handles both TIR buffers and primitive expressions.

- If the input is a TIR buffer, it prints its values, but only on the first thread (tx=0, ty=0, tz=0).
- If the input is a TIR primitive expression, it prints its value directly.

Parameters:
obj (Any): The object to print. It can be either a tir.Buffer or tir.PrimExpr.
obj (Any): The object to print. It can be either a tir.Buffer, tir.PrimExpr, or None (for msg-only print).
msg (str): An optional message to include in the print statement.
warp_group_id (int): The warp group id to print.
warp_id (int): The warp id to print.
Expand Down Expand Up @@ -227,6 +237,9 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
# Directly print primitive expressions.
return print_var(obj, msg)

elif obj is None:
return print_msg(msg)

else:
# Unsupported object type.
raise ValueError(f"Unexpected type: {type(obj)}. Supported types are tir.Buffer and tir.PrimExpr.")
raise ValueError(f"Unexpected type: {type(obj)}. Supported types are tir.Buffer, tir.PrimExpr, and None.")
Loading