diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 9b20a4b6d..5a4243471 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -2830,8 +2830,14 @@ void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) { void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) { int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value); - CHECK_LE(lanes, 4) << "Translate Ramp Node " << tvm::ffi::GetRef<PrimExpr>(op) - << " with " << lanes << " lanes is not allowed."; + // TODO(chaofan): Comment the ramp lanes limit for now since we have + // LegalizeVectorizedLoop to automatically legalize vectorized loop whose + // width exceeds the limit. But we should add check here for safety in the + // future. The check should be aligned to certain bit width like 128bits or + // 256bits. + + // CHECK_LE(lanes, 8) << "Translate Ramp Node " << tvm::ffi::GetRef<PrimExpr>(op) + // << "error: " << lanes << " exceeds max ramp lanes 8."; os << "(make_"; PrintType(op->dtype, os); os << "("; diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h index bf2a5100b..7418b236f 100644 --- a/src/tl_templates/cuda/common.h +++ b/src/tl_templates/cuda/common.h @@ -137,6 +137,49 @@ TL_DEVICE int4_t make_int4(short x0, short x1, short y0, short y1, short z0, return result; } +// Pack four char values. +TL_DEVICE unsigned int make_uint(unsigned char x0, unsigned char x1, + unsigned char x2, unsigned char x3) { + return (x3 << 24) | (x2 << 16) | (x1 << 8) | x0; +} + +// Pack eight char values. +TL_DEVICE uint2 make_uint2(unsigned char x0, unsigned char x1, unsigned char x2, + unsigned char x3, unsigned char y0, unsigned char y1, + unsigned char y2, unsigned char y3) { + uint2 result; + result.x = make_uint(x0, x1, x2, x3); + result.y = make_uint(y0, y1, y2, y3); + return result; +} + +// Pack sixteen char values. 
+TL_DEVICE uint4 make_uint4(unsigned char x0, unsigned char x1, unsigned char x2, + unsigned char x3, unsigned char y0, unsigned char y1, + unsigned char y2, unsigned char y3, unsigned char z0, + unsigned char z1, unsigned char z2, unsigned char z3, + unsigned char w0, unsigned char w1, unsigned char w2, + unsigned char w3) { + uint4 result; + result.x = make_uint(x0, x1, x2, x3); + result.y = make_uint(y0, y1, y2, y3); + result.z = make_uint(z0, z1, z2, z3); + result.w = make_uint(w0, w1, w2, w3); + return result; +} + +TL_DEVICE uint4 make_uint4(unsigned short x0, unsigned short x1, + unsigned short y0, unsigned short y1, + unsigned short z0, unsigned short z1, + unsigned short w0, unsigned short w1) { + uint4 result; + *((ushort2 *)&result.x) = make_ushort2(x0, x1); + *((ushort2 *)&result.y) = make_ushort2(y0, y1); + *((ushort2 *)&result.z) = make_ushort2(z0, z1); + *((ushort2 *)&result.w) = make_ushort2(w0, w1); + return result; +} + // Pack eight int values. TL_DEVICE longlong4 make_longlong4(int x0, int x1, int y0, int y1, int z0, int z1, int w0, int w1) { diff --git a/testing/python/language/test_tilelang_language_vectorize.py b/testing/python/language/test_tilelang_language_vectorize.py index 75360bb19..7462aa81b 100644 --- a/testing/python/language/test_tilelang_language_vectorize.py +++ b/testing/python/language/test_tilelang_language_vectorize.py @@ -1,6 +1,7 @@ import torch import tilelang.testing import tilelang.language as T +import pytest @tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True}) @@ -112,5 +113,39 @@ def test_vectorize_invariant_index(): run_vectorize_invariant_index(N, M * 7, 14) +@tilelang.jit +def vectorize_test_all_dtypes(dtype, vec_num): + @T.prim_func + def main(A: T.Tensor[(64,), dtype]): + with T.Kernel(1, threads=256): + for i in T.vectorized(vec_num): + A[i] = T.cast(i + 1, dtype) + + return main + + +@pytest.mark.parametrize( + "dtype", + [ + torch.uint8, + torch.uint16, + torch.uint32, + 
torch.uint64, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e8m0fnu, + ], +) +@pytest.mark.parametrize("vec_num", [1, 2, 4, 8]) +def test_vectorize_all_dtypes(dtype, vec_num): + x = torch.empty((64,), dtype=dtype, device="cuda") + kernel = vectorize_test_all_dtypes(dtype, vec_num) + kernel(x) + + if __name__ == "__main__": tilelang.testing.main()