diff --git a/src/target/codegen_c_host.cc b/src/target/codegen_c_host.cc index 4f1a70cef..2c1873467 100644 --- a/src/target/codegen_c_host.cc +++ b/src/target/codegen_c_host.cc @@ -495,6 +495,10 @@ ::tvm::ffi::Module BuildTileLangCHost(::tvm::IRModule mod, } std::string code = cg.Finish(); + if (const auto f = + ffi::Function::GetGlobal("tilelang_callback_c_host_postproc")) { + code = (*f)(code, target).cast<std::string>(); + } return ::tvm::codegen::CSourceModuleCreate(code, "c", cg.GetFunctionNames()); } diff --git a/src/transform/arg_binder.cc b/src/transform/arg_binder.cc index 0e94f3061..73500e176 100644 --- a/src/transform/arg_binder.cc +++ b/src/transform/arg_binder.cc @@ -599,11 +599,14 @@ void ArgBinder::BindDLTensors( break; } - // The "real" runtime shape value read from DLTensor - PrimExpr shape_val = + // The "real" runtime shape value read from DLTensor. + // Guard the load with `is_null` to avoid dereferencing NULL handles. + PrimExpr raw_shape_val = cast(buffer->shape[k].dtype(), BufferLoad(buf_shape, {IntImm(DataType::Int(32), static_cast<int64_t>(k))})); + PrimExpr shape_val = tvm::if_then_else( + Not(is_null), raw_shape_val, make_const(raw_shape_val.dtype(), 0)); // Check if this dimension is a symbolic variable if (const VarNode *v = buffer->shape[k].as<VarNode>()) { @@ -658,8 +661,8 @@ } Buffer src_shape_buf = it_buf->second; - // Construct the shape load - PrimExpr src_shape_val = + // Construct the shape load and guard it if the source may be NULL + PrimExpr src_raw_shape_val = cast(buffer->shape[k].dtype(), BufferLoad(src_shape_buf, {IntImm(DataType::Int(32), @@ -671,18 +674,25 @@ if (is_first_source) { // Base case: use this shape value directly (we know at least // one is non-null from assert) - cascaded_value = src_shape_val; + if (src_is_used) { + cascaded_value = src_raw_shape_val; + } else { + Var src_is_null = is_null_map[src.buf_name]; + cascaded_value = tvm::if_then_else( + Not(src_is_null),
src_raw_shape_val, + make_const(src_raw_shape_val.dtype(), 0)); + } is_first_source = false; } else { // if !is_null then use this shape, else use previous cascaded // value But if buffer is used (non-nullable), always use its // shape if (src_is_used) { - cascaded_value = src_shape_val; + cascaded_value = src_raw_shape_val; } else { Var src_is_null = is_null_map[src.buf_name]; cascaded_value = tvm::if_then_else( - Not(src_is_null), src_shape_val, cascaded_value); + Not(src_is_null), src_raw_shape_val, cascaded_value); } } } @@ -694,8 +704,8 @@ void ArgBinder::BindDLTensors( init_nest_.emplace_back( LetStmt(v_arg, cascaded_value, Evaluate(0))); } else { - // Single source or no special handling needed, use the original - // nullable binding + // Single source or no special handling needed, use nullable + // binding. When the only source is NULL, bind m to 0 safely. BindNullable(buffer->shape[k], shape_val, shape_element_name(k), true, is_null); } diff --git a/testing/python/transform/test_nullable_buffer_params.py b/testing/python/transform/test_nullable_buffer_params.py index 5bbde254b..e02c8125a 100644 --- a/testing/python/transform/test_nullable_buffer_params.py +++ b/testing/python/transform/test_nullable_buffer_params.py @@ -69,5 +69,36 @@ def test_kernel( return True +def test_nullable_single_source_shape(): + """Test that a single buffer with a symbolic shape var must be non-null. + + This guards against the previous segfault when binding m from x.shape[0] + with x == None. 
+ """ + + @tilelang.jit + def get_kernel(): + m = T.dynamic("m") + + @T.prim_func + def sample_kernel(x: T.Tensor[(m,), T.int32]): + with T.Kernel(1, threads=1): + tx = T.get_thread_binding() + if tx == 0: + T.print(m) + + return sample_kernel + + m = 16 + kernel = get_kernel() + + # Provide a valid tensor: should run + x = torch.randn((m,), device="cuda", dtype=torch.float32).to(torch.int32) + kernel(x) + + # Passing None should not segfault; m binds to 0 and kernel is a no-op + kernel(None) + + if __name__ == "__main__": tilelang.testing.main() diff --git a/tilelang/__init__.py b/tilelang/__init__.py index 87176b209..b5a436697 100644 --- a/tilelang/__init__.py +++ b/tilelang/__init__.py @@ -145,7 +145,7 @@ def _load_tile_lang_lib(): from .autotuner import autotune # noqa: F401 from .transform import PassConfigKey # noqa: F401 -from .engine import lower, register_cuda_postproc, register_hip_postproc # noqa: F401 +from .engine import lower, register_cuda_postproc, register_hip_postproc, register_c_postproc # noqa: F401 from .math import * # noqa: F403 diff --git a/tilelang/engine/__init__.py b/tilelang/engine/__init__.py index 476b40a35..b7cd7eb23 100644 --- a/tilelang/engine/__init__.py +++ b/tilelang/engine/__init__.py @@ -1,3 +1,7 @@ from .lower import lower, is_device_call # noqa: F401 from .param import KernelParam # noqa: F401 -from .callback import register_cuda_postproc, register_hip_postproc # noqa: F401 +from .callback import ( + register_cuda_postproc, # noqa: F401 + register_hip_postproc, # noqa: F401 + register_c_postproc, # noqa: F401 +) diff --git a/tilelang/engine/callback.py b/tilelang/engine/callback.py index 05fafe9db..d65f1eb2b 100644 --- a/tilelang/engine/callback.py +++ b/tilelang/engine/callback.py @@ -26,6 +26,21 @@ def register_hip_postproc(func: Callable[[str, Target], str], override: bool = T tvm_ffi.register_global_func("tilelang_callback_hip_postproc", f=func, override=override) +def register_c_postproc(func: Callable[[str, Target], 
str], override: bool = True): + """Register a post-processing function for C host code generation. + + This callback intercepts C host code emitted by TileLang just before it + is wrapped into a CSourceModule. It should take the generated code string + and the `Target` as inputs, and return the (possibly) modified code. + + Args: + func: A callable that takes generated code (str) and target (Target) as input, + and returns the processed code (str). + override: Whether to override existing registered function. Defaults to True. + """ + tvm_ffi.register_global_func("tilelang_callback_c_host_postproc", f=func, override=override) + + def register_cuda_postproc_callback(func: Callable | bool = None, override: bool = True): """Decorator for registering CUDA post-processing callback function. @@ -90,3 +105,36 @@ def _register(fn: Callable[[str, Target], str]): return _register raise TypeError("Invalid decorator usage") + + +def register_c_postproc_callback(func: Callable | bool = None, override: bool = True): + """Decorator for registering C host post-processing callback function. + + Can be used with or without parentheses: + @register_c_postproc_callback + def func(code, target): ... + + @register_c_postproc_callback() + def func(code, target): ... + + @register_c_postproc_callback(override=False) + def func(code, target): ... + + Args: + func: The function to be decorated or a boolean override flag + override: Whether to override existing registered function. Defaults to True. + """ + if callable(func): + register_c_postproc(func, override) + return func + + if func is None or isinstance(func, bool): + _override = func if isinstance(func, bool) else override + + def _register(fn: Callable[[str, Target], str]): + register_c_postproc(fn, _override) + return fn + + return _register + + raise TypeError("Invalid decorator usage")