-
Notifications
You must be signed in to change notification settings - Fork 450
[BugFix] Complete vectorized loading for common dtypes #1536
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -137,6 +137,49 @@ TL_DEVICE int4_t make_int4(short x0, short x1, short y0, short y1, short z0, | |
| return result; | ||
| } | ||
|
|
||
| // Pack four char values. | ||
| TL_DEVICE unsigned int make_uint(unsigned char x0, unsigned char x1, | ||
| unsigned char x2, unsigned char x3) { | ||
| return (x3 << 24) | (x2 << 16) | (x1 << 8) | x0; | ||
|
Comment on lines
+141
to
+143
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The shifts in Useful? React with 👍 / 👎. |
||
| } | ||
|
|
||
| // Pack eight char values. | ||
| TL_DEVICE uint2 make_uint2(unsigned char x0, unsigned char x1, unsigned char x2, | ||
| unsigned char x3, unsigned char y0, unsigned char y1, | ||
| unsigned char y2, unsigned char y3) { | ||
| uint2 result; | ||
| result.x = make_uint(x0, x1, x2, x3); | ||
| result.y = make_uint(y0, y1, y2, y3); | ||
| return result; | ||
| } | ||
|
|
||
| // Pack sixteen char values. | ||
| TL_DEVICE uint4 make_uint4(unsigned char x0, unsigned char x1, unsigned char x2, | ||
| unsigned char x3, unsigned char y0, unsigned char y1, | ||
| unsigned char y2, unsigned char y3, unsigned char z0, | ||
| unsigned char z1, unsigned char z2, unsigned char z3, | ||
| unsigned char w0, unsigned char w1, unsigned char w2, | ||
| unsigned char w3) { | ||
| uint4 result; | ||
| result.x = make_uint(x0, x1, x2, x3); | ||
| result.y = make_uint(y0, y1, y2, y3); | ||
| result.z = make_uint(z0, z1, z2, z3); | ||
| result.w = make_uint(w0, w1, w2, w3); | ||
| return result; | ||
| } | ||
|
|
||
| TL_DEVICE uint4 make_uint4(unsigned short x0, unsigned short x1, | ||
| unsigned short y0, unsigned short y1, | ||
| unsigned short z0, unsigned short z1, | ||
| unsigned short w0, unsigned short w1) { | ||
| uint4 result; | ||
| *((ushort2 *)&result.x) = make_ushort2(x0, x1); | ||
| *((ushort2 *)&result.y) = make_ushort2(y0, y1); | ||
| *((ushort2 *)&result.z) = make_ushort2(z0, z1); | ||
| *((ushort2 *)&result.w) = make_ushort2(w0, w1); | ||
| return result; | ||
| } | ||
|
|
||
| // Pack eight int values. | ||
| TL_DEVICE longlong4 make_longlong4(int x0, int x1, int y0, int y1, int z0, | ||
| int z1, int w0, int w1) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| import torch | ||
| import tilelang.testing | ||
| import tilelang.language as T | ||
| import pytest | ||
|
|
||
|
|
||
| @tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True}) | ||
|
|
@@ -112,5 +113,39 @@ def test_vectorize_invariant_index(): | |
| run_vectorize_invariant_index(N, M * 7, 14) | ||
|
|
||
|
|
||
# Kernel factory for the dtype sweep below: the returned prim_func takes a
# 64-element tensor of `dtype` and, inside a vectorized loop of extent
# `vec_num`, stores `idx + 1` (cast to `dtype`) into element `idx`.
@tilelang.jit
def vectorize_test_all_dtypes(dtype, vec_num):
    @T.prim_func
    def main(A: T.Tensor[(64,), dtype]):
        # One block of 256 threads; the loop is marked vectorized.
        with T.Kernel(1, threads=256):
            for idx in T.vectorized(vec_num):
                A[idx] = T.cast(idx + 1, dtype)

    return main
|
|
||
|
|
||
@pytest.mark.parametrize(
    "dtype",
    [
        torch.uint8,
        torch.uint16,
        torch.uint32,
        torch.uint64,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.float8_e4m3fn,
        torch.float8_e5m2,
        torch.float8_e8m0fnu,
    ],
)
@pytest.mark.parametrize("vec_num", [1, 2, 4, 8])
def test_vectorize_all_dtypes(dtype, vec_num):
    """Compile and run the vectorized store kernel, then check what it wrote.

    The kernel stores ``i + 1`` into ``A[i]`` for ``i`` in ``range(vec_num)``,
    so the first ``vec_num`` elements must read back as ``1 .. vec_num``.
    """
    x = torch.empty((64,), dtype=dtype, device="cuda")
    kernel = vectorize_test_all_dtypes(dtype, vec_num)
    kernel(x)

    # Compare on the host in float64: it represents 1..8 exactly and avoids
    # relying on arange/comparison kernels for the unsigned and float8 dtypes.
    actual = x[:vec_num].cpu().to(torch.float64)
    expected = torch.arange(1, vec_num + 1, dtype=torch.float64)

    if dtype == torch.float8_e8m0fnu:
        # e8m0 has no mantissa bits, so only powers of two are exactly
        # representable; the rounding of 3, 5, 6 and 7 is not pinned here.
        as_int = torch.arange(1, vec_num + 1, dtype=torch.int64)
        exact = (as_int & (as_int - 1)) == 0
        actual = actual[exact]
        expected = expected[exact]

    torch.testing.assert_close(actual, expected, rtol=0, atol=0)
|
Comment on lines
+144
to
+147
|
||
|
|
||
|
Comment on lines
+116
to
+148
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Add output verification to the vectorization test. The test currently only verifies that the kernel executes without crashing; it does not validate the correctness of the vectorized memory operations. This significantly reduces the test's effectiveness.

🔎 Proposed enhancement to verify output correctness:

 @pytest.mark.parametrize("vec_num", [1, 2, 4, 8])
 def test_vectorize_all_dtypes(dtype, vec_num):
     x = torch.empty((64,), dtype=dtype, device="cuda")
     kernel = vectorize_test_all_dtypes(dtype, vec_num)
     kernel(x)
+
+    # Verify the kernel wrote the expected values
+    expected = torch.arange(1, vec_num + 1, dtype=dtype, device="cuda")
+    torch.testing.assert_close(x[:vec_num], expected, atol=0, rtol=0)

This ensures the vectorized writes are actually working correctly for all data types and vectorization widths.

🤖 Prompt for AI Agents |
||
|
|
||
# When executed as a script, hand control to tilelang's test entry point.
if __name__ == "__main__":
    tilelang.testing.main()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This TODO comment indicates that the commented-out CHECK is temporarily removed, which creates a safety issue. While the LegalizeVectorizedLoop pass may handle oversized vectorized loops, removing runtime checks without proper validation could allow invalid code to pass through. The TODO should either be addressed in this PR or tracked in a separate issue with specific bit-width constraints (128/256 bits) documented.