Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/target/codegen_c_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,10 @@ ::tvm::ffi::Module BuildTileLangCHost(::tvm::IRModule mod,
}

std::string code = cg.Finish();
if (const auto f =
ffi::Function::GetGlobal("tilelang_callback_c_host_postproc")) {
code = (*f)(code, target).cast<std::string>();
}
return ::tvm::codegen::CSourceModuleCreate(code, "c", cg.GetFunctionNames());
}

Expand Down
28 changes: 19 additions & 9 deletions src/transform/arg_binder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -599,11 +599,14 @@ void ArgBinder::BindDLTensors(
break;
}

// The "real" runtime shape value read from DLTensor
PrimExpr shape_val =
// The "real" runtime shape value read from DLTensor.
// Guard the load with `is_null` to avoid dereferencing NULL handles.
PrimExpr raw_shape_val =
cast(buffer->shape[k].dtype(),
BufferLoad(buf_shape,
{IntImm(DataType::Int(32), static_cast<int>(k))}));
PrimExpr shape_val = tvm::if_then_else(
Not(is_null), raw_shape_val, make_const(raw_shape_val.dtype(), 0));

// Check if this dimension is a symbolic variable
if (const VarNode *v = buffer->shape[k].as<VarNode>()) {
Expand Down Expand Up @@ -658,8 +661,8 @@ void ArgBinder::BindDLTensors(
}
Buffer src_shape_buf = it_buf->second;

// Construct the shape load
PrimExpr src_shape_val =
// Construct the shape load and guard it if the source may be NULL
PrimExpr src_raw_shape_val =
cast(buffer->shape[k].dtype(),
BufferLoad(src_shape_buf,
{IntImm(DataType::Int(32),
Expand All @@ -671,18 +674,25 @@ void ArgBinder::BindDLTensors(
if (is_first_source) {
// Base case: use this shape value directly (we know at least
// one is non-null from assert)
cascaded_value = src_shape_val;
if (src_is_used) {
cascaded_value = src_raw_shape_val;
} else {
Var src_is_null = is_null_map[src.buf_name];
cascaded_value = tvm::if_then_else(
Not(src_is_null), src_raw_shape_val,
make_const(src_raw_shape_val.dtype(), 0));
}
is_first_source = false;
} else {
// if !is_null then use this shape, else use previous cascaded
// value But if buffer is used (non-nullable), always use its
// shape
if (src_is_used) {
cascaded_value = src_shape_val;
cascaded_value = src_raw_shape_val;
} else {
Var src_is_null = is_null_map[src.buf_name];
cascaded_value = tvm::if_then_else(
Not(src_is_null), src_shape_val, cascaded_value);
Not(src_is_null), src_raw_shape_val, cascaded_value);
}
}
}
Expand All @@ -694,8 +704,8 @@ void ArgBinder::BindDLTensors(
init_nest_.emplace_back(
LetStmt(v_arg, cascaded_value, Evaluate(0)));
} else {
// Single source or no special handling needed, use the original
// nullable binding
// Single source or no special handling needed, use nullable
// binding. When the only source is NULL, bind m to 0 safely.
BindNullable(buffer->shape[k], shape_val, shape_element_name(k),
true, is_null);
}
Expand Down
31 changes: 31 additions & 0 deletions testing/python/transform/test_nullable_buffer_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,36 @@ def test_kernel(
return True


def test_nullable_single_source_shape():
    """Test binding a symbolic shape var whose only source buffer is nullable.

    Regression test: previously, binding ``m`` from ``x.shape[0]`` segfaulted
    when ``x`` was passed as ``None`` (the NULL DLTensor handle was
    dereferenced unconditionally). With the guarded shape load, passing
    ``None`` binds ``m`` to 0 and the kernel runs as a no-op.
    """

    @tilelang.jit
    def get_kernel():
        # Symbolic dimension resolved at call time from the argument's shape.
        m = T.dynamic("m")

        @T.prim_func
        def sample_kernel(x: T.Tensor[(m,), T.int32]):
            with T.Kernel(1, threads=1):
                tx = T.get_thread_binding()
                if tx == 0:
                    # Print m so the bound value is observable at runtime.
                    T.print(m)

        return sample_kernel

    m = 16
    kernel = get_kernel()

    # Provide a valid tensor: should run.
    # NOTE(review): built as float32 randn then cast to int32 to match the
    # declared T.int32 parameter dtype.
    x = torch.randn((m,), device="cuda", dtype=torch.float32).to(torch.int32)
    kernel(x)

    # Passing None should not segfault; m binds to 0 and kernel is a no-op.
    kernel(None)


if __name__ == "__main__":
tilelang.testing.main()
2 changes: 1 addition & 1 deletion tilelang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _load_tile_lang_lib():
from .autotuner import autotune # noqa: F401
from .transform import PassConfigKey # noqa: F401

from .engine import lower, register_cuda_postproc, register_hip_postproc # noqa: F401
from .engine import lower, register_cuda_postproc, register_hip_postproc, register_c_postproc # noqa: F401

from .math import * # noqa: F403

Expand Down
6 changes: 5 additions & 1 deletion tilelang/engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from .lower import lower, is_device_call # noqa: F401
from .param import KernelParam # noqa: F401
from .callback import register_cuda_postproc, register_hip_postproc # noqa: F401
from .callback import (
register_cuda_postproc, # noqa: F401
register_hip_postproc, # noqa: F401
register_c_postproc, # noqa: F401
)
48 changes: 48 additions & 0 deletions tilelang/engine/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@ def register_hip_postproc(func: Callable[[str, Target], str], override: bool = T
tvm_ffi.register_global_func("tilelang_callback_hip_postproc", f=func, override=override)


def register_c_postproc(func: Callable[[str, Target], str], override: bool = True):
    """Install a callback that rewrites generated C host code.

    TileLang invokes this hook on the emitted C host source right before it
    is packaged into a CSourceModule, giving the caller a chance to inspect
    or rewrite the code string.

    Args:
        func: Callback receiving the generated code (str) and the compilation
            target (Target); must return the final code string.
        override: Replace any previously registered callback. Defaults to True.
    """
    callback_name = "tilelang_callback_c_host_postproc"
    tvm_ffi.register_global_func(callback_name, f=func, override=override)


def register_cuda_postproc_callback(func: Callable | bool = None, override: bool = True):
"""Decorator for registering CUDA post-processing callback function.

Expand Down Expand Up @@ -90,3 +105,36 @@ def _register(fn: Callable[[str, Target], str]):
return _register

raise TypeError("Invalid decorator usage")


def register_c_postproc_callback(func: Callable | bool = None, override: bool = True):
    """Decorator form of :func:`register_c_postproc` for C host codegen.

    Supports all three usual spellings:
        @register_c_postproc_callback
        def func(code, target): ...

        @register_c_postproc_callback()
        def func(code, target): ...

        @register_c_postproc_callback(override=False)
        def func(code, target): ...

    Args:
        func: Either the function being decorated, or a boolean that acts as
            the override flag when used with parentheses.
        override: Whether to replace an existing callback. Defaults to True.
    """
    # Bare form: the decorated function arrives directly as `func`.
    if callable(func):
        register_c_postproc(func, override)
        return func

    # Parenthesized form: `func` must be absent or a boolean override flag.
    if not (func is None or isinstance(func, bool)):
        raise TypeError("Invalid decorator usage")

    effective_override = override if func is None else func

    def _decorator(fn: Callable[[str, Target], str]):
        register_c_postproc(fn, effective_override)
        return fn

    return _decorator
Loading