[Host] Provide post process to customize host code and enhance nullable check#1562

Merged
LeiWang1999 merged 2 commits into tile-ai:main from LeiWang1999:host_1229
Dec 29, 2025

LeiWang1999 (Member) commented on Dec 29, 2025

This pull request introduces enhanced support for post-processing C host code generation in TileLang, improves the safety of buffer shape binding for nullable tensors, and adds corresponding tests. The most important changes are grouped below:

C Host Code Generation Post-Processing

  • Added a new callback registration function, register_c_postproc, and its decorator variant, register_c_postproc_callback, to allow users to intercept and modify C host code emitted by TileLang before it is wrapped into a CSourceModule. This includes updates to tilelang/engine/callback.py, tilelang/engine/__init__.py, and tilelang/__init__.py to expose the new API.
  • Integrated the C host post-processing callback into the TVM codegen pipeline: after generating code in BuildTileLangCHost, the callback is invoked if registered, allowing custom modification of the emitted code.
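
As a rough illustration of the hook described above, the sketch below defines a post-processing callback that prepends a comment banner to the generated C source. It assumes the callback receives the generated code and the target and returns the code to use, as the review walkthrough further down this thread describes. The name `add_banner` and the banner text are hypothetical and do not come from the PR.

```python
def add_banner(code: str, target=None) -> str:
    """Hypothetical post-processing callback: prepend a banner comment.

    Assumed contract (not confirmed by the PR text beyond the walkthrough):
    receives the generated C host source and the target, and returns the
    source that should be wrapped into the CSourceModule.
    """
    banner = "// post-processed by add_banner (illustrative)\n"
    # Avoid stacking banners if the callback sees already-processed code.
    if code.startswith(banner):
        return code
    return banner + code


generated = "int main(void) { return 0; }\n"
processed = add_banner(generated)
print(processed, end="")
```

With TileLang installed, a callback like this would presumably be registered with `tilelang.register_c_postproc(add_banner)`; passing the callback as the first positional argument is an assumption based on the existing CUDA and HIP registration functions.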

Buffer Shape Binding Safety

  • Improved the logic in ArgBinder::BindDLTensors to safely handle nullable buffers: shape loads are now guarded to avoid dereferencing null handles, and symbolic shape variables bind to zero if the source buffer is null, preventing segfaults.
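
The guarded load can be pictured with a plain-Python analogy. This is only a model of the behavior described in the bullet, not the TIR that ArgBinder::BindDLTensors emits; `bind_shape_dim` is a hypothetical name, and `None` stands in for a null tensor handle.

```python
from typing import Optional, Sequence


def bind_shape_dim(shape: Optional[Sequence[int]], axis: int) -> int:
    """Model of a guarded shape load for a nullable buffer.

    `shape` stands in for the tensor handle's shape array; None models a
    null handle (a hypothetical simplification for illustration).
    """
    # Guard first: only read the shape when the handle is non-null.
    if shape is None:
        return 0
    return int(shape[axis])


print(bind_shape_dim((128, 64), 0))  # 128
print(bind_shape_dim(None, 0))       # 0
```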

Testing

  • Added a new test, test_nullable_single_source_shape, to ensure that binding a symbolic shape variable from a nullable buffer does not cause segfaults and safely binds to zero when the buffer is None.

Summary by CodeRabbit

  • New Features

    • Introduced C code post-processing callback registration system, enabling custom transformation of generated C code
    • Enhanced NULL safety handling for nullable buffer parameters with dynamic shape dimensions
  • Tests

    • Added test coverage for NULL inputs with symbolic shape variables to prevent runtime errors

…e buffer handling

- Introduced a new callback for C host code generation that allows for post-processing of generated code before it is wrapped into a CSourceModule.
- Enhanced the ArgBinder to log shape variable sources and ensure safe binding of symbolic shape variables, preventing potential segmentation faults when dealing with nullable buffers.
- Added a regression test to verify that a single buffer with a symbolic shape variable must be non-null, ensuring robustness against null inputs.
… generation and ArgBinder

- Adjusted indentation in the C host code generation to enhance clarity.
- Reformatted comments and code structure in ArgBinder for better readability, ensuring consistent style throughout the file.
- Minor whitespace adjustments to maintain code consistency.
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

coderabbitai bot (Contributor) commented on Dec 29, 2025

Walkthrough

The PR introduces a C host post-processing callback mechanism and adds runtime null-safety guards for shape variable binding. It exposes a new public API for registering C post-processing functions and includes a test validating null-safe shape binding for single-buffer nullable scenarios.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| C Post-Processing Callback Registration: tilelang/engine/callback.py, tilelang/engine/__init__.py, tilelang/__init__.py | Added register_c_postproc() and the register_c_postproc_callback() decorator to register C host post-processing functions via TVM global functions. Exported the new symbols through the module hierarchy. |
| C Code Generation Enhancement: src/target/codegen_c_host.cc | Introduced a conditional post-processing hook in BuildTileLangCHost that looks up and invokes a registered callback (tilelang_callback_c_host_postproc) after code generation, allowing external modification of the generated C code. |
| Shape Binding Runtime Guards: src/transform/arg_binder.cc | Added null-safety guards for shape-related loads using conditional expressions (Not(is_null)) that fall back to zero when source buffers are null, applied across multiple binding paths (single-source, multi-source, cascaded). |
| Shape Binding Tests: testing/python/transform/test_nullable_buffer_params.py | Added test_nullable_single_source_shape() to verify null-safe behavior when binding symbolic shapes from nullable single-buffer inputs. |

Sequence Diagram(s)

sequenceDiagram
    participant CodeGen as BuildTileLangCHost
    participant Registry as TVM Global<br/>Function Registry
    participant Callback as User Callback<br/>Function
    
    CodeGen->>CodeGen: Generate C code
    CodeGen->>Registry: Lookup tilelang_callback_c_host_postproc
    alt Callback registered
        Registry-->>CodeGen: Return callback function
        CodeGen->>Callback: Invoke(generated_code, target)
        Callback-->>CodeGen: Return modified code
        CodeGen->>CodeGen: Use post-processed code
    else Callback not found
        Registry-->>CodeGen: Return null/not found
        CodeGen->>CodeGen: Use original code (fallback)
    end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~22 minutes

Poem

🐰 With guards of null and callbacks bright,
The shapes now bind both safe and right,
C code hopping through post-process lanes,
No segfaults scattered 'cross the plains! 🐿️

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 66.67%, which is insufficient. The required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately captures the two main objectives: adding post-processing customization for host code and enhancing nullable buffer checks. |

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
tilelang/engine/callback.py (1)

110-140: Decorator implementation is functionally correct.

The decorator properly handles all three invocation patterns and mirrors the existing CUDA/HIP decorator implementations, ensuring API consistency.

📝 Optional type hint improvement

For stricter PEP 484 compliance, consider using Optional explicitly:

-def register_c_postproc_callback(func: Callable | bool = None, override: bool = True):
+def register_c_postproc_callback(func: Optional[Callable] | bool = None, override: bool = True):

Add the import at the top:

 from __future__ import annotations
-from typing import Callable
+from typing import Callable, Optional

This makes the nullable parameter more explicit, though the current implementation works correctly.

📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 193eff1 and 902ed76.

📒 Files selected for processing (6)
  • src/target/codegen_c_host.cc
  • src/transform/arg_binder.cc
  • testing/python/transform/test_nullable_buffer_params.py
  • tilelang/__init__.py
  • tilelang/engine/__init__.py
  • tilelang/engine/callback.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • src/target/codegen_c_host.cc
  • tilelang/engine/__init__.py
  • tilelang/__init__.py
  • tilelang/engine/callback.py
🧬 Code graph analysis (3)
testing/python/transform/test_nullable_buffer_params.py (1)
tilelang/language/symbolics.py (1)
  • dynamic (10-21)
tilelang/engine/__init__.py (1)
tilelang/engine/callback.py (3)
  • register_cuda_postproc (7-15)
  • register_hip_postproc (18-26)
  • register_c_postproc (29-41)
tilelang/__init__.py (1)
tilelang/engine/callback.py (3)
  • register_cuda_postproc (7-15)
  • register_hip_postproc (18-26)
  • register_c_postproc (29-41)
🪛 Ruff (0.14.10)
testing/python/transform/test_nullable_buffer_params.py

  • 84-84: Unused function argument: x (ARG001)

tilelang/engine/__init__.py

  • 4-4: Unused noqa directive (non-enabled: F401). Remove unused noqa directive (RUF100)
  • 5-5: Unused noqa directive (non-enabled: F401). Remove unused noqa directive (RUF100)
  • 6-6: Unused noqa directive (non-enabled: F401). Remove unused noqa directive (RUF100)

tilelang/__init__.py

  • 148-148: Unused noqa directive (non-enabled: F401). Remove unused noqa directive (RUF100)

tilelang/engine/callback.py

  • 110-110: PEP 484 prohibits implicit Optional. Convert to Optional[T] (RUF013)
  • 140-140: Avoid specifying long messages outside the exception class (TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (8)
tilelang/__init__.py (1)

148-148: LGTM! Consistent API extension for C host post-processing.

The addition of register_c_postproc follows the established pattern for CUDA and HIP post-processing hooks, providing a uniform interface for code generation customization across different backends.

Note: The static analysis warning about the noqa directive is a false positive—F401 suppression is necessary for re-exported symbols.

tilelang/engine/__init__.py (1)

3-7: LGTM! Proper module-level API exposure.

The multi-line import correctly groups all post-processing registration functions together and exposes register_c_postproc alongside the existing CUDA and HIP variants.

Note: The static analysis warnings about noqa directives are false positives—these symbols are intentionally re-exported.

src/transform/arg_binder.cc (3)

602-609: Excellent null-safety improvement for shape loading.

The guarded load prevents segmentation faults by checking is_null before dereferencing the buffer handle. When the buffer is null, the shape dimension safely defaults to zero, allowing downstream code to handle the null case gracefully.


664-696: Well-designed cascaded binding for multi-source symbolic shapes.

The logic correctly builds a cascaded if_then_else expression in reverse order, binding the symbolic variable to the shape from the first non-null buffer. The handling distinguishes between:

  • Used buffers (non-nullable): directly use their raw shape value
  • Nullable buffers: guard with is_null checks and cascade fallback

This ensures correctness when multiple buffers share a symbolic dimension, relying on the earlier assertion (lines 644-645) that at least one source is non-null.
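
The reverse-order cascade can be sketched in plain Python. This models only the resulting value (the shape from the first non-null source) and evaluates eagerly instead of building an if_then_else expression; it also does not model the distinction between used and nullable buffers. `bind_from_sources` and the `Source` alias are hypothetical names.

```python
from typing import List, Optional, Sequence, Tuple

# Each source is (shape_or_None, axis); None models a null buffer handle.
Source = Tuple[Optional[Sequence[int]], int]


def bind_from_sources(sources: List[Source]) -> int:
    """Bind a symbolic dim to the shape of the first non-null source."""
    # Mirrors the precondition that at least one source is non-null.
    assert any(shape is not None for shape, _ in sources), "all sources are null"

    # Walk back to front so the first source ends up as the outermost choice:
    # value = s0 if not null0 else (s1 if not null1 else (... else 0))
    value = 0
    for shape, axis in reversed(sources):
        value = value if shape is None else int(shape[axis])
    return value


print(bind_from_sources([(None, 0), ((4, 8), 1), ((16,), 0)]))  # 8
```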


707-708: Clear documentation of single-source nullable behavior.

The updated comment accurately reflects the safer binding semantics: when the only source buffer is null, the symbolic shape variable binds to zero, preventing segfaults while allowing the kernel to execute (likely as a no-op for zero-sized dimensions).

tilelang/engine/callback.py (1)

29-41: LGTM! Consistent implementation of C host post-processing registration.

The function follows the established pattern for CUDA and HIP post-processing, registering the global function tilelang_callback_c_host_postproc that will be invoked by the code generator.

src/target/codegen_c_host.cc (1)

498-501: Perfect integration point for post-processing hook.

The hook is invoked at the optimal stage—after code generation but before wrapping into CSourceModule. The optional pattern (only invoking if registered) ensures backward compatibility, and the implementation correctly passes both the generated code and target to the callback.

testing/python/transform/test_nullable_buffer_params.py (1)

72-101: Excellent regression test for nullable single-source shape binding.

The test correctly validates the null-safety improvements in arg_binder.cc:

  1. First invocation with a valid tensor confirms normal operation
  2. Second invocation with None ensures no segfault occurs and that the symbolic variable m safely binds to 0 (as implemented in the guarded shape loading)

Note: The static analysis warning about unused parameter x at line 84 is expected—the parameter must be declared to test the nullable buffer path, even though the kernel body doesn't use it.

LeiWang1999 merged commit 27db71f into tile-ai:main on Dec 29, 2025
7 checks passed