diff --git a/.gitignore b/.gitignore
index 752f6cb76..730398dfc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -108,3 +108,6 @@ cmake-build-*/
 
 # pre-commit cache
 .pre-commit-cache/*
+
+# host checks logs
+maint/host_checks/logs/*
diff --git a/docs/compiler_internals/tensor_checks.md b/docs/compiler_internals/tensor_checks.md
new file mode 100644
index 000000000..b4d2a0b3c
--- /dev/null
+++ b/docs/compiler_internals/tensor_checks.md
@@ -0,0 +1,387 @@
+# Tensor Checks (Host-Side Auto-Validation)
+
+This page explains the host-side checks that TileLang automatically inserts into the generated host stub for kernels. When you pass `torch.Tensor` or any DLPack-compatible object to a TileLang kernel, the host stub validates argument count, pointer kinds, dtype, shape, strides, device, and more — so you don’t need to handwrite Python checks. This keeps the ABI stable and significantly reduces Python overhead compared to doing equivalent checks in Python or via pybind.
+
+## Why Host-Side Checks
+- ABI stability: the entry point is based on TVM FFI + DLPack and accepts tensors and scalars uniformly.
+- Lower overhead: shifting checks from Python into C reduces interpreter/property-access costs; the call overhead is lower than pybind-based approaches.
+- Focused error reporting: assertions are raised close to the call site with precise “which field failed” messages.
+
+## How To Inspect Host Source
+You can inspect the auto-generated host source (with all checks and the final device-kernel call) for debugging:
+
+```python
+print(matmul_relu_kernel.get_host_source())
+```
+
+---
+
+## What The Host Checks
+
+### 1) Argument count and pointer kind
+- `num_args` must match the number of formal parameters; otherwise the kernel returns `-1` with an error message.
+- Each argument’s FFI type must be a pointer kind (for DLTensor/handle) or a valid scalar type; otherwise you’ll see errors like `Expect arg[i] to be pointer` or a scalar type error.
+
+### 2) Tensor checks (per tensor, after nullability decision)
+- Nullability
+  - If the tensor is “statically reachable/used” by the function body, the handle must be non-NULL; otherwise: `xxx is expected to have non-NULL pointer`.
+  - If an input tensor is not used by the function (statically unreachable), NULL is allowed; other field checks are executed only when `handle != NULL`.
+- Rank (`ndim`)
+  - Runtime `ndim` must equal the compile-time rank.
+- Data type (`dtype`)
+  - Match the triple `(code, bits, lanes)` with tolerance:
+    - `float8_e4m3`: accept `e4m3`, `e4m3fn`, `e4m3fnuz`.
+    - `float8_e5m2`: accept `e5m2`, `e5m2fnuz`.
+    - `bool`: accept `int8/uint8` with `bits=8` (same lanes), `kDLBool(code=6, bits=1 or 8)`, and any `bitwidth=1` (lanes must match).
+  - For packed-bit dtypes (e.g., `Int(1)`, `Int(4)`, `UInt(4)`), strict dtype checking is skipped.
+- Shape
+  - Each runtime dimension is bound to the compile-time shape (constants or symbols) and checked for consistency.
+  - Linear equations among symbolic dims can be solved on the fly (when there’s only one unknown at a given check point), enabling cross-tensor constraints.
+- Strides
+  - If `buffer_type = AutoBroadcast`: allow `strides == NULL` and derive strides from `shape`. If explicit `strides` is present, bind to compile-time constraints and check for equality.
+  - Otherwise: check per-dimension; if `strides == NULL`, derive from `shape` and compare (e.g., contiguous: `strides[-1] == 1`, `strides[-2] == shape[-1]`); see the sketch after this list.
+- `byte_offset`
+  - Must be 0 (non-zero raises an error) to keep addressing simple and aligned.
+- Device info
+  - Assert `device_type == target backend` (CUDA/ROCM/Metal/OneAPI/WebGPU/CPU, etc.). Error messages include a DLPack code legend.
+  - When multiple tensors participate, assert that `device_id` matches across them.
+- Data pointer
+  - Must be non-NULL when the tensor is required to be non-null by the nullability rule.
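+
+For intuition, here is a minimal sketch (illustrative Python written for this page, not TileLang source) of the contiguous-strides rule from the list above: when `strides == NULL`, the host derives row-major strides from `shape` and compares them with the actual layout.
+
+```python
+def contiguous_strides(shape):
+    """Row-major strides implied by `shape` (what the host assumes when strides are absent)."""
+    strides = [1] * len(shape)
+    for i in range(len(shape) - 2, -1, -1):
+        strides[i] = strides[i + 1] * shape[i + 1]
+    return strides
+
+# A contiguous [4, 8] tensor satisfies the check ...
+assert contiguous_strides([4, 8]) == [8, 1]
+# ... while its transposed view (shape [8, 4], strides [1, 8]) does not.
+assert contiguous_strides([8, 4]) != [1, 8]
+```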
+
+### 3) Scalar checks
+- `T.int*` family: require integer; error: `Expect arg[i] to be int`.
+- `T.bool`: require boolean; error: `Expect arg[i] to be boolean`.
+
+---
+
+## Shapes and Symbolic Equations: Linear Solving
+When shapes are symbolic, the host binds and (when possible) solves linear relations at runtime (only one unknown per check point). Example:
+
+```python
+@T.prim_func
+def main(
+    A: T.Tensor((m,), dtype),
+    B: T.Tensor((m + n,), dtype),
+    C: T.Tensor((n * k,), dtype),
+):
+    ...
+```
+
+This enables enforcing cross-tensor relationships like `len(B) == m + n` and `len(C) == n * k` at runtime.
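+
+To make the solving order concrete, here is a plain-Python illustration (not the generated host code) of how the bindings above can be resolved one unknown at a time, assuming `k` is already bound elsewhere (e.g., from another argument):
+
+```python
+def check_symbolic_shapes(len_A, len_B, len_C, k):
+    m = len_A          # bind m from A: len(A) == m
+    n = len_B - m      # solve the single unknown in len(B) == m + n
+    assert n >= 0, "len(B) == m + n violated"
+    assert len_C == n * k, "len(C) == n * k violated"
+    return m, n
+
+assert check_symbolic_shapes(len_A=4, len_B=10, len_C=12, k=2) == (4, 6)
+```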
+
+---
+
+## Nullability Rules and Examples
+Which tensors may be NULL?
+
+- Rule: If an input tensor is not used by the function under static analysis (i.e., the access is statically unreachable), it is considered nullable; otherwise it must be non-NULL.
+- Examples:
+
+1) Must be non-NULL (used)
+```python
+@T.prim_func
+def main(A: T.Tensor((M, K), dtype)):
+    A[0] = 1
+```
+Passing `None` raises: `main.A_handle is expected to have non-NULL pointer`.
+
+2) Still must be non-NULL (constant-true branch)
+```python
+some_cond: bool = True
+@T.prim_func
+def main(A: T.Tensor((M, K), dtype)):
+    if some_cond:
+        A[0] = 1
+```
+
+3) Nullable (constant-false branch, statically unreachable)
+```python
+some_cond: bool = False
+@T.prim_func
+def main(A: T.Tensor((M, K), dtype)):
+    if some_cond:
+        A[0] = 1
+```
+
+4) Must be non-NULL (runtime condition)
+```python
+@T.prim_func
+def main(A: T.Tensor((M, K), dtype), some_cond: T.bool):
+    if some_cond:
+        A[0] = 1
+```
+Since `some_cond` is only known at runtime, static analysis cannot prove `A` is unused; `A` is thus non-nullable.
+
+---
+
+## Device Type Codes (DLPack)
+Supported and referenced device codes in error messages: `1=CPU, 2=CUDA, 7=Vulkan, 8=Metal, 10=ROCM, 14=OneAPI, 15=WebGPU`.
+Kernels assert that `device_type` matches the target backend, and require `device_id` consistency across tensors.
+
+---
+
+## Common Error Examples (What you’ll see)
+- Argument count mismatch (num_args)
+  - Trigger: missing/extra argument
+  - Error: `<fn>: num_args should be N; expected: <...>, got: N`
+
+- Pointer-typed argument expected
+  - Trigger: scalar passed where a tensor is expected
+  - Error: `<fn>: Expect arg[i] to be pointer`
+
+- Rank (ndim) mismatch
+  - Trigger: runtime rank differs from compile-time rank
+  - Error: `<fn>.<arg>.ndim is expected to equal R, but got mismatched ndim`
+
+- Dtype mismatch
+  - Trigger: dtype not equal to the compiled dtype and not within the tolerance set
+  - Error: `<fn>.<arg>.dtype is expected to be <dtype>, but got incompatible dtype`
+
+- Shape constraint violation
+  - Trigger: a dimension doesn’t match a constant/symbol binding
+  - Error: `Argument <fn>.<arg>.shape[i] has an unsatisfied constraint: ... == <expected>`
+
+- Strides check failed (e.g., non-contiguous layout)
+  - Trigger: transposed/sliced tensors that violate expected strides
+  - Error: `Argument <fn>.<arg>.strides[j] has an unsatisfied constraint: ... == <expected>`
+
+- Device type mismatch
+  - Trigger: calling a CUDA kernel with CPU tensors, etc.
+  - Error: `<fn>.<arg>.device_type mismatch [expected: <code> (<backend>)] ...`
+
+- Device id mismatch
+  - Trigger: mixing tensors from different GPUs
+  - Error: `Argument <fn>.<arg>.device_id has an unsatisfied constraint: ... == ...`
+
+- NULL data pointer
+  - Trigger: tensor required to be non-null has a NULL data pointer
+  - Error: `<fn>.<arg> is expected to have non-NULL data pointer, but got NULL`
+
+- Scalar type mismatch
+  - Trigger: passing float to `T.int32`, or non-boolean to `T.bool`
+  - Error: `<fn>: Expect arg[i] to be int/boolean`
+
+---
+
+## Troubleshooting Tips
+- Print the host source: `print(fn.get_host_source())` to see the exact assertion and expected vs. actual fields.
+- Fix strides: call `.contiguous()` for non-contiguous tensors, or avoid generating transposed/sliced layouts that break assumptions.
+- Align devices: ensure all participating tensors share the same `device_type` and `device_id`.
+- Align dtype: use `.to()` or construct tensors with the correct dtype; pay attention to `float8` and `bool` tolerance.
+- Dynamic shapes: ensure cross-tensor linear relations can be uniquely determined at the check point (only one unknown at a time).
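+
+A compact way to apply these tips before calling a kernel (a convenience helper written for this guide, not a TileLang API):
+
+```python
+import torch
+
+def preflight(*tensors, dtype=torch.float16, device="cuda:0"):
+    """Normalize inputs so the host checks pass: one device, expected dtype, contiguous layout."""
+    fixed = []
+    for t in tensors:
+        t = t.to(device=device, dtype=dtype)  # align device_type/device_id and dtype
+        if not t.is_contiguous():
+            t = t.contiguous()                # satisfy the strides check
+        fixed.append(t)
+    return fixed
+
+# A, B, C = preflight(A, B, C)
+# matmul_relu_kernel(A, B, C)
+```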
+
+---
+
+## FAQ
+- Can I disable the checks?
+  - Not recommended and usually not supported. Checks are done on the host to preserve ABI stability and fail early close to the device call.
+- Is the overhead noticeable?
+  - The checks are lightweight (branches and field reads); the dominating cost remains the Python→C boundary crossing, so the call is still cheaper overall than doing equivalent checks in Python.
+
+---
+
+## Reference Example (Matmul + ReLU)
+
+```python
+@T.prim_func
+def matmul_relu_kernel(
+    A: T.Tensor((M, K), dtype),
+    B: T.Tensor((K, N), dtype),
+    C: T.Tensor((M, N), dtype),
+):
+    # Initialize Kernel Context
+    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
+        A_shared = T.alloc_shared((block_M, block_K), dtype)
+        B_shared = T.alloc_shared((block_K, block_N), dtype)
+        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+        T.clear(C_local)
+        for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
+            T.copy(A[by * block_M, ko * block_K], A_shared)
+            T.copy(B[ko * block_K, bx * block_N], B_shared)
+            T.gemm(A_shared, B_shared, C_local)
+        T.copy(C_local, C[by * block_M, bx * block_N])
+
+# For debugging, print the host source
+print(matmul_relu_kernel.get_host_source())
+```
+
+The host will insert all checks described above for this example.
+
+---
+
+## Quick Error Reference (Short List)
+- Argument count
+  - Trigger: missing/extra args; Error: `num_args should be N; expected: <...>, got: N`.
+- Pointer kind
+  - Trigger: scalar passed to tensor arg; Error: `Expect arg[i] to be pointer`.
+- Rank (ndim)
+  - Trigger: runtime rank != compile-time; Error: `ndim ... expected to equal R`.
+- Dtype
+  - Trigger: mismatch and not tolerated; Error: `dtype ... expected to be <dtype>`.
+- Shape
+  - Trigger: constant/symbol binding violated; Error: `shape[i] ... == <expected>`.
+- Strides
+  - Trigger: layout mismatch; Error: `strides[j] ... == <expected>`.
+- Device type
+  - Trigger: wrong backend device; Error: `device_type mismatch [expected: ...]`.
+- Device id
+  - Trigger: tensors on different GPUs; Error: `device_id ... == ...`.
+- Data pointer
+  - Trigger: required non-NULL but NULL; Error: `non-NULL data pointer`.
+- Scalar types
+  - Trigger: wrong scalar type; Error: `Expect arg[i] to be int/boolean`.
+
+---
+
+## Host Error Troubleshooting (Minimal Repros)
+
+Below are minimal repro snippets for common host-side errors, assuming a CUDA-targeted kernel like `matmul_relu_kernel` with:
+
+```python
+# Convention:
+# A: float16 [M, K]
+# B: float16 [K, N]
+# C: float16 [M, N]
+# Target: CUDA (device_type=2)
+fn = matmul_relu_kernel  # your compiled function
+M = N = K = 1024
+```
+
+Adjust dtype/device if your kernel differs.
+
+### 0. Tip: print the host source
+```python
+print(fn.get_host_source())
+```
+
+### 1. num_args mismatch
+```python
+import torch
+
+A = torch.empty((M, K), device='cuda', dtype=torch.float16)
+B = torch.empty((K, N), device='cuda', dtype=torch.float16)
+# Missing C
+fn(A, B)
+```
+Expected: `<fn>: num_args should be 3; expected: <...>, got: 3`.
+
+Fix: pass all arguments per the signature.
+
+### 2. Expect pointer (tensor) but got scalar
+```python
+import torch
+
+B = torch.empty((K, N), device='cuda', dtype=torch.float16)
+C = torch.empty((M, N), device='cuda', dtype=torch.float16)
+fn(1, B, C)
+```
+Expected: `<fn>: Expect arg[0] to be pointer`.
+
+Fix: pass a DLPack-compatible tensor (e.g., torch.Tensor).
+
+### 3. ndim mismatch
+```python
+import torch
+
+A = torch.empty((M, K, 1), device='cuda', dtype=torch.float16)  # rank=3
+B = torch.empty((K, N), device='cuda', dtype=torch.float16)
+C = torch.empty((M, N), device='cuda', dtype=torch.float16)
+fn(A, B, C)
+```
+Expected: `<fn>.A_handle.ndim is expected to equal 2, but got mismatched ndim`.
+
+Fix: ensure runtime rank equals compiled rank.
+
+### 4. dtype mismatch
+```python
+import torch
+
+A = torch.empty((M, K), device='cuda', dtype=torch.float32)  # should be float16
+B = torch.empty((K, N), device='cuda', dtype=torch.float16)
+C = torch.empty((M, N), device='cuda', dtype=torch.float16)
+fn(A, B, C)
+```
+Expected: `<fn>.A_handle.dtype is expected to be float16, but got incompatible dtype`.
+
+Fix: `A = A.to(torch.float16)` or create with the correct dtype.
+
+### 5. Shape constant/symbol mismatch
+```python
+import torch
+
+A = torch.empty((M, K + 1), device='cuda', dtype=torch.float16)  # K mismatched
+B = torch.empty((K, N), device='cuda', dtype=torch.float16)
+C = torch.empty((M, N), device='cuda', dtype=torch.float16)
+fn(A, B, C)
+```
+Expected: `Argument <fn>.A_handle.shape[i] has an unsatisfied constraint: ... == <expected>`.
+
+Fix: satisfy linear constraints and constants across tensors.
+
+### 6. Strides check failure (non-contiguous)
+```python
+import torch
+
+A = torch.empty((M, K), device='cuda', dtype=torch.float16)
+A_nc = A.t()  # transpose -> non-contiguous
+B = torch.empty((K, N), device='cuda', dtype=torch.float16)
+C = torch.empty((M, N), device='cuda', dtype=torch.float16)
+fn(A_nc, B, C)
+```
+Expected: `Argument <fn>.A_handle.strides[1] has an unsatisfied constraint: ... == 1`.
+
+Fix: pass `A_nc.contiguous()` or align the layout expectation in the kernel.
+
+### 7. device_type mismatch
+```python
+import torch
+
+A = torch.empty((M, K), device='cpu', dtype=torch.float16)
+B = torch.empty((K, N), device='cpu', dtype=torch.float16)
+C = torch.empty((M, N), device='cpu', dtype=torch.float16)
+fn(A, B, C)  # CUDA-targeted kernel
+```
+Expected: `<fn>.A_handle.device_type mismatch [expected: 2 (cuda)] ...`.
+
+Fix: move tensors to the CUDA device.
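+
+For completeness, the fix in plain PyTorch (nothing TileLang-specific) is to move every participating tensor to the target device before the call:
+
+```python
+import torch
+
+device = torch.device("cuda:0")
+A, B, C = (t.to(device) for t in (A, B, C))
+fn(A, B, C)
+```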
+
+### 8. device_id mismatch (multi-GPU)
+```python
+import torch
+
+A = torch.empty((M, K), device='cuda:0', dtype=torch.float16)
+B = torch.empty((K, N), device='cuda:1', dtype=torch.float16)
+C = torch.empty((M, N), device='cuda:0', dtype=torch.float16)
+fn(A, B, C)
+```
+Expected: `Argument <fn>.B_handle.device_id has an unsatisfied constraint: ... == ...`.
+
+Fix: place all tensors on the same GPU (e.g., `cuda:0`).
+
+### 9. NULL data pointer (advanced)
+This usually comes from hand-constructed DLTensor/NDArray, or external frameworks passing unallocated/freed storage. Regular `torch.Tensor` allocations rarely hit this.
+
+Expected: `<fn>.<arg> is expected to have non-NULL data pointer, but got NULL`.
+
+Fix: ensure valid underlying storage; in PyTorch scenarios, avoid constructing tensors from invalid external handles.
+
+### 10. Scalar type mismatch (int / bool)
+```python
+import tilelang.language as T
+
+@T.prim_func
+def scalar_check(x: T.int32, flag: T.bool()):
+    T.evaluate(0)
+
+scalar_check(1.0, True)  # x is float -> Expect arg[0] to be int
+scalar_check(1, 2.5)  # flag is float -> Expect arg[1] to be boolean
+```
+
+Fix: pass correct scalar types, e.g., `scalar_check(1, True)`.
+
+---
+
+## Closing Notes
+- Cross-check “shape / strides / device / dtype” against the kernel signature to localize issues efficiently.
+- For complex symbolic relations, print the host source to confirm binding/solving order, then adjust runtime shapes/layouts accordingly.
+
diff --git a/docs/index.md b/docs/index.md
index 5d9a158f8..9f7947766 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -42,6 +42,7 @@ deeplearning_operators/deepseek_mla
 
 compiler_internals/letstmt_inline
 compiler_internals/inject_fence_proxy
+compiler_internals/tensor_checks
 :::
 
 :::{toctree}
diff --git a/examples/quickstart.py b/examples/quickstart.py
index 46a39e0d9..39ad348b5 100644
--- a/examples/quickstart.py
+++ b/examples/quickstart.py
@@ -77,7 +77,7 @@ def matmul_relu_kernel(
 print("Kernel output matches PyTorch reference.")
 
 # 4. Retrieve and inspect the generated CUDA source (optional)
-# cuda_source = jit_kernel.get_kernel_source()
+# cuda_source = matmul_relu_kernel.get_kernel_source()
 # print("Generated CUDA kernel:\n", cuda_source)
 
 # 5.Profile latency with kernel
diff --git a/maint/host_checks/01_num_args_mismatch.py b/maint/host_checks/01_num_args_mismatch.py
new file mode 100644
index 000000000..8ba366463
--- /dev/null
+++ b/maint/host_checks/01_num_args_mismatch.py
@@ -0,0 +1,21 @@
+"""Reproduce: Argument count mismatch.
+
+Note: The adapter-level wrapper expects only inputs (A, B) because C is marked as output.
+Calling with the wrong number of inputs raises a ValueError before host entry.
+"""
+import torch
+from common import build_matmul_kernel
+
+
+def main():
+    M = N = K = 256
+    fn = build_matmul_kernel(M, N, K, target="cuda")
+
+    a = torch.empty((M, K), device="cuda", dtype=torch.float16)
+    # Missing b
+    # Expected: ValueError with message about expected vs. actual inputs
+    fn(a)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/maint/host_checks/02_pointer_type_error.py b/maint/host_checks/02_pointer_type_error.py
new file mode 100644
index 000000000..fd3585405
--- /dev/null
+++ b/maint/host_checks/02_pointer_type_error.py
@@ -0,0 +1,22 @@
+"""Reproduce: Pointer-type argument expected but scalar provided.
+
+We pass an integer for A; wrapper forwards it to the host where a pointer is expected.
+Expected: error like "Expect buffer A_handle to be pointer or tensor" (exact name depends on kernel param).
+""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 256 + fn = build_matmul_kernel(M, N, K, target="cuda") + + # Wrong type for A (int instead of tensor) + a = 1 + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/03_ndim_mismatch.py b/maint/host_checks/03_ndim_mismatch.py new file mode 100644 index 000000000..994ce23e8 --- /dev/null +++ b/maint/host_checks/03_ndim_mismatch.py @@ -0,0 +1,19 @@ +"""Reproduce: ndim (rank) mismatch for A. +""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 128 + fn = build_matmul_kernel(M, N, K, target="cuda") + + # A has rank 3 instead of 2 + a = torch.empty((M, K, 1), device="cuda", dtype=torch.float16) + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/04_dtype_mismatch.py b/maint/host_checks/04_dtype_mismatch.py new file mode 100644 index 000000000..6e6a0503e --- /dev/null +++ b/maint/host_checks/04_dtype_mismatch.py @@ -0,0 +1,19 @@ +"""Reproduce: dtype mismatch for A (float32 vs expected float16). +""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 128 + fn = build_matmul_kernel(M, N, K, target="cuda") + print(fn.get_host_source()) + + a = torch.empty((M, K), device="cuda", dtype=torch.float32) # should be float16 + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/05_shape_mismatch.py b/maint/host_checks/05_shape_mismatch.py new file mode 100644 index 000000000..8b41ae36a --- /dev/null +++ b/maint/host_checks/05_shape_mismatch.py @@ -0,0 +1,19 @@ +"""Reproduce: shape constant/symbol mismatch on A. +""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 128 + fn = build_matmul_kernel(M, N, K, target="cuda") + + # A's second dimension is wrong (K+1 instead of K) + a = torch.empty((M, K + 1), device="cuda", dtype=torch.float16) + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/06_strides_mismatch.py b/maint/host_checks/06_strides_mismatch.py new file mode 100644 index 000000000..477d200bc --- /dev/null +++ b/maint/host_checks/06_strides_mismatch.py @@ -0,0 +1,19 @@ +"""Reproduce: strides check failure (non-contiguous A via transpose). +""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 128 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = torch.empty((M, K), device="cuda", dtype=torch.float16) + a_nc = a.t() # non-contiguous after transpose + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a_nc, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/07_device_type_mismatch.py b/maint/host_checks/07_device_type_mismatch.py new file mode 100644 index 000000000..67cb7718c --- /dev/null +++ b/maint/host_checks/07_device_type_mismatch.py @@ -0,0 +1,18 @@ +"""Reproduce: device_type mismatch by passing CPU tensors to a CUDA kernel. 
+""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 64 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = torch.empty((M, K), device="cpu", dtype=torch.float16) + b = torch.empty((K, N), device="cpu", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/08_device_id_mismatch.py b/maint/host_checks/08_device_id_mismatch.py new file mode 100644 index 000000000..649109661 --- /dev/null +++ b/maint/host_checks/08_device_id_mismatch.py @@ -0,0 +1,25 @@ +"""Reproduce: device_id mismatch (requires >=2 CUDA devices). +""" +import torch +from common import build_matmul_kernel + + +def main(): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + if torch.cuda.device_count() < 2: + print("[SKIP] Need at least 2 CUDA devices to reproduce device_id mismatch.") + return + + M = N = K = 64 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = torch.empty((M, K), device="cuda:0", dtype=torch.float16) + b = torch.empty((K, N), device="cuda:1", dtype=torch.float16) + # Output device is derived by the adapter; mismatch occurs in host checks + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/09_null_data_pointer.py b/maint/host_checks/09_null_data_pointer.py new file mode 100644 index 000000000..00bac67dd --- /dev/null +++ b/maint/host_checks/09_null_data_pointer.py @@ -0,0 +1,25 @@ +"""Reproduce: NULL data pointer (advanced). + +Passing None for a tensor argument will be forwarded through the adapter. Depending on +FFI handling, this commonly triggers a pointer-type assertion (e.g., "Expect buffer to be pointer or tensor") +or a host-side non-NULL pointer check. + +Note: Constructing a true DLTensor with NULL data in PyTorch is not typical; this script +demonstrates passing None, which still reproduces the intended class of failure. +""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 64 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = None # attempt to pass a null-like pointer + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/10_scalar_type_mismatch.py b/maint/host_checks/10_scalar_type_mismatch.py new file mode 100644 index 000000000..f1fcba274 --- /dev/null +++ b/maint/host_checks/10_scalar_type_mismatch.py @@ -0,0 +1,15 @@ +"""Reproduce: scalar parameter type mismatch (int/bool). +""" +from common import build_scalar_check_kernel + + +def main(): + fn = build_scalar_check_kernel(target="cuda") + + # Wrong types + fn(1.0, True) # x should be int -> Expect arg[0] to be int + fn(1, 2.5) # flag should be bool -> Expect arg[1] to be boolean + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/README.md b/maint/host_checks/README.md new file mode 100644 index 000000000..ac23d6fd2 --- /dev/null +++ b/maint/host_checks/README.md @@ -0,0 +1,21 @@ +# Host-Side Check Repro Scripts + +This folder contains standalone scripts that deliberately trigger host-side (and adapter-side) validation errors described in `docs/compiler_internals/tensor_checks.md`. Each script can be run directly and will reproduce the corresponding error with a minimal example. + +Prerequisites +- CUDA-capable environment (most scripts compile a CUDA-targeted kernel) +- Python packages: torch, tilelang + +Usage +- Run any script, e.g.: + - `python 01_num_args_mismatch.py` + - `python 02_pointer_type_error.py` + - ... 
+
+- Or run all at once with a summary:
+  - `python run_all.py`
+  - Logs per test are saved under `logs/` as `