-
Notifications
You must be signed in to change notification settings - Fork 330
[Feature] Tilelang JITv2: Low Overhead and Syntax Sugars #1003
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
dba18e8
tilelang jit v2
kurisu6912 209490c
Merge branch 'main' into jit-v2
kurisu6912 0b56bd6
fix lint error
kurisu6912 e7f4cd9
fix typos
kurisu6912 51c415a
add torch.dtype, add var naming
kurisu6912 fb8dd3d
add macro, add ptr and make_tensor
kurisu6912 26b6e65
many update
kurisu6912 032c978
add support for augassign
kurisu6912 dee2b75
fix lint error
kurisu6912 df8c21e
Merge branch 'main' into jit-v2
kurisu6912 407b694
fix compile error reports
kurisu6912 f7a4e0d
update
kurisu6912 5832618
update
kurisu6912 dee6d7c
fix lint error
kurisu6912 f7acf8a
update
kurisu6912 1dc0d0f
fix lint error
kurisu6912 0c90b0e
add jitv2 example
kurisu6912 25c7d1c
[Lint]: [pre-commit.ci] auto fixes [...]
pre-commit-ci[bot] File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,151 @@ | ||
| from tilelang import tvm | ||
| from tvm import ir | ||
| import torch | ||
| import ctypes | ||
| from typing import Any | ||
|
|
||
|
|
||
| class VoidPtr: | ||
| ... | ||
|
|
||
|
|
||
| AnyDType = ir.Type | str | type | torch.dtype | tvm.DataType | ||
|
|
||
| _dtype_torch2tvm = { | ||
| # special types should placed in the first | ||
| float: "float32", | ||
| int: "int32", | ||
| torch.long: "int64", | ||
| torch.half: "half", | ||
| # other dtypes | ||
| torch.bool: "bool", | ||
| torch.int8: "int8", | ||
| torch.int16: "int16", | ||
| torch.int32: "int32", | ||
| torch.int64: "int64", | ||
| torch.uint8: "uint8", | ||
| torch.uint16: "uint16", | ||
| torch.uint32: "uint32", | ||
| torch.uint64: "uint64", | ||
| torch.bfloat16: "bfloat16", | ||
| torch.float16: "float16", | ||
| torch.float32: "float32", | ||
| torch.float64: "float64", | ||
| torch.float8_e4m3fn: "float8_e4m3fn", | ||
| torch.float8_e4m3fnuz: "float8_e4m3fnuz", | ||
| torch.float8_e5m2: "float8_e5m2", | ||
| torch.float8_e5m2fnuz: "float8_e5m2fnuz", | ||
| torch.float8_e8m0fnu: "float8_e8m0fnu", | ||
| } | ||
|
|
||
| _dtype_tvm2torch = {tvm.DataType(v): k for k, v in _dtype_torch2tvm.items()} | ||
|
|
||
| _dtype_tvm2ctype = { | ||
| tvm.DataType("bool"): | ||
| ctypes.c_bool, | ||
| tvm.DataType("int8"): | ||
| ctypes.c_int8, | ||
| tvm.DataType("int16"): | ||
| ctypes.c_int16, | ||
| tvm.DataType("int32"): | ||
| ctypes.c_int32, | ||
| tvm.DataType("int64"): | ||
| ctypes.c_int64, | ||
| tvm.DataType("uint8"): | ||
| ctypes.c_uint8, | ||
| tvm.DataType("uint16"): | ||
| ctypes.c_uint16, | ||
| tvm.DataType("uint32"): | ||
| ctypes.c_uint32, | ||
| tvm.DataType("uint64"): | ||
| ctypes.c_uint64, | ||
| # tvm.DataType("float16"): ctypes.c_uint16, | ||
| # tvm.DataType("bfloat16"): ctypes.c_uint16, | ||
| tvm.DataType("float32"): | ||
| ctypes.c_float, | ||
| tvm.DataType("float64"): | ||
| ctypes.c_double, | ||
| # tvm.DataType("float8_e4m3fn"): ctypes.c_uint8, | ||
| # tvm.DataType("float8_e4m3fnuz"): ctypes.c_uint8, | ||
| # tvm.DataType("float8_e5m2"): ctypes.c_uint8, | ||
| # tvm.DataType("float8_e5m2fnuz"): ctypes.c_uint8, | ||
| # tvm.DataType("float8_e8m0fnu"): ctypes.c_uint8, | ||
| tvm.DataType("handle"): | ||
| ctypes.c_void_p, | ||
| } | ||
|
|
||
| _dtype_tvm2cffi = { | ||
| tvm.DataType("bool"): | ||
| "bool", | ||
| tvm.DataType("int8"): | ||
| "char", | ||
| tvm.DataType("int16"): | ||
| "short", | ||
| tvm.DataType("int32"): | ||
| "int", | ||
| tvm.DataType("int64"): | ||
| "long long", | ||
| tvm.DataType("uint8"): | ||
| "unsigned char", | ||
| tvm.DataType("uint16"): | ||
| "unsigned short", | ||
| tvm.DataType("uint32"): | ||
| "unsigned int", | ||
| tvm.DataType("uint64"): | ||
| "unsigned long long", | ||
| tvm.DataType("float32"): | ||
| "float", | ||
| tvm.DataType("float64"): | ||
| "double", | ||
| # tvm.DataType("float16"): 'uint16_t', | ||
| # tvm.DataType("bfloat16"): 'uint16_t', | ||
| # tvm.DataType("float8_e4m3fn"): 'uint8_t', | ||
| # tvm.DataType("float8_e4m3fnuz"): 'uint8_t', | ||
| # tvm.DataType("float8_e5m2"): ctypes.c_uint8, | ||
| # tvm.DataType("float8_e5m2fnuz"): ctypes.c_uint8, | ||
| # tvm.DataType("float8_e8m0fnu"): ctypes.c_uint8, | ||
| tvm.DataType("handle"): | ||
| "long", | ||
| } | ||
|
|
||
|
|
||
| def get_tvm_dtype(ty: AnyDType) -> tvm.DataType: | ||
| if ty is None: | ||
| return ty | ||
| if ty == VoidPtr: | ||
| return get_tvm_ptr_type() | ||
| if isinstance(ty, (ir.Type, tvm.DataType)): | ||
| return ty | ||
| if isinstance(ty, str): | ||
| return tvm.DataType(ty) | ||
| return tvm.DataType(_dtype_torch2tvm[ty]) | ||
|
|
||
|
|
||
| def get_tvm_dtype_str(ty: AnyDType) -> str: | ||
| if isinstance(ty, str): | ||
| return ty | ||
| return _dtype_torch2tvm[ty] | ||
|
|
||
|
|
||
| def get_torch_dtype(ty: AnyDType) -> torch.dtype: | ||
| if isinstance(ty, torch.dtype): | ||
| return ty | ||
| if isinstance(ty, str): | ||
| ty = tvm.DataType(ty) | ||
| return _dtype_tvm2torch[ty] | ||
|
|
||
|
|
||
| def get_ctypes_dtype(ty: AnyDType) -> Any: | ||
| ty = get_tvm_dtype(ty) | ||
| return _dtype_tvm2ctype[ty] | ||
|
|
||
|
|
||
| def get_cffi_dtype(ty: AnyDType) -> str: | ||
| ty = get_tvm_dtype(ty) | ||
| return _dtype_tvm2cffi[ty] | ||
|
|
||
|
|
||
| def get_tvm_ptr_type(ty: ir.Type | str | type | torch.dtype = "void", | ||
| scope: str = "global") -> ir.PointerType: | ||
| ty = get_tvm_dtype(ty) | ||
| return ir.PointerType(ir.PrimType(ty), scope) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Handle
tvm.DataType(and other AnyDType cases) before mapping.get_tvm_dtype_strcurrently only returns early forstrinputs. For every other allowedAnyDType(e.g., thetvm.DataTypeobjects produced byget_tvm_dtype,ir.PrimType, or the newVoidPtrsentinel), this code immediately indexes_dtype_torch2tvm. Passing any of those values now raisesKeyError, which breaks call sites such ascall_intrin(get_tvm_dtype(torch.float32), ...)and any of the intrinsic wrappers updated in this PR. Please normalize non-string inputs first—for example, accepttvm.DataType/ir.PrimTypeby returningstr(ty)(orty.dtype), detectVoidPtr, and only fall back to_dtype_torch2tvmfor Python/torch dtypes.A minimal patch might look like:
(Feel free to tailor the exact pointer handling, but the key is to cover all members of
AnyDTypewithout throwing.)📝 Committable suggestion
🤖 Prompt for AI Agents