From 554480d15b530c2e4c381480980394e43a753bf5 Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Wed, 31 Dec 2025 16:27:15 +0800 Subject: [PATCH 1/3] add more curand operations & support vectorization --- src/op/builtin.cc | 22 +++++- src/op/builtin.h | 4 + src/target/codegen_cuda.cc | 73 +++++++++++++++++-- src/target/codegen_cuda.h | 3 +- src/transform/loop_vectorize.cc | 3 +- .../language/test_tilelang_language_rand.py | 48 ++++++++++-- tilelang/language/__init__.py | 4 + tilelang/language/random.py | 52 ++++++++++++- 8 files changed, 190 insertions(+), 19 deletions(-) diff --git a/src/op/builtin.cc b/src/op/builtin.cc index a0ee8acd8..77116397b 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -101,12 +101,32 @@ TIR_DEFINE_TL_BUILTIN(ieee_frsqrt) TIR_DEFINE_TL_BUILTIN(ieee_fdiv).set_num_inputs(3).set_attr( "TCallEffectKind", Integer(CallEffectKind::kPure)); -TIR_DEFINE_TL_BUILTIN(rng_init).set_num_inputs(3).set_attr( +TIR_DEFINE_TL_BUILTIN(rng_init).set_num_inputs(4).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); TIR_DEFINE_TL_BUILTIN(rng_rand).set_num_inputs(0).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(rng_rand_uniform) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(rng_rand_uniform_double) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(rng_rand_normal) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(rng_rand_normal_double) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier) .set_num_inputs(-1) .set_attr("TCallEffectKind", diff --git a/src/op/builtin.h b/src/op/builtin.h index 0e39e9ad4..b79b49211 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -131,6 +131,10 @@ TVM_DLL const Op &ieee_fdiv(); // 
random op TVM_DLL const Op &rng_init(); TVM_DLL const Op &rng_rand(); +TVM_DLL const Op &rng_rand_uniform(); +TVM_DLL const Op &rng_rand_uniform_double(); +TVM_DLL const Op &rng_rand_normal(); +TVM_DLL const Op &rng_rand_normal_double(); /*! * \brief tvm intrinsics for TMADescriptor creation for tiled load diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 5a4243471..5b22435ba 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -2651,18 +2651,35 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { << PrintExpr(op->args[1]) << ")"; } else if (op->op.same_as(tl::rng_init())) { this->need_curand_kernel_h_ = true; - this->curand_philox_state = name_supply_->FreshName("__philox_state"); + this->curand_random_generator_state = + name_supply_->FreshName("__random_generator_state"); + this->curand_random_generator_state_type = + op->args[3].as()->value; this->PrintIndent(); - this->stream << "curandStatePhilox4_32_10_t " << this->curand_philox_state - << ";\n"; + this->stream << op->args[3].as()->value << " " + << this->curand_random_generator_state << ";\n"; this->PrintIndent(); this->stream << "curand_init(" << PrintExpr(op->args[0]) << ", " << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2]) - << ", &" << this->curand_philox_state << ");\n"; + << ", &" << this->curand_random_generator_state << ");\n"; // Store state_var for later use by rng_rand } else if (op->op.same_as(tl::rng_rand())) { this->need_curand_kernel_h_ = true; - os << "curand(&" << this->curand_philox_state << ")"; + os << "curand(&" << this->curand_random_generator_state << ")"; + } else if (op->op.same_as(tl::rng_rand_uniform())) { + this->need_curand_kernel_h_ = true; + os << "curand_uniform(&" << this->curand_random_generator_state << ")"; + } else if (op->op.same_as(tl::rng_rand_uniform_double())) { + this->need_curand_kernel_h_ = true; + os << "curand_uniform_double(&" << this->curand_random_generator_state + << ")"; 
+ } else if (op->op.same_as(tl::rng_rand_normal())) { + this->need_curand_kernel_h_ = true; + os << "curand_normal(&" << this->curand_random_generator_state << ")"; + } else if (op->op.same_as(tl::rng_rand_normal_double())) { + this->need_curand_kernel_h_ = true; + os << "curand_normal_double(&" << this->curand_random_generator_state + << ")"; } else if (op->op.same_as(tl::warp_reduce_sum())) { os << "tl::warp_reduce_sum(" << PrintExpr(op->args[0]) << ")"; } else if (op->op.same_as(tl::warp_reduce_max())) { @@ -3112,6 +3129,52 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op, } } + if (auto call = op->value.as()) { + if (this->curand_random_generator_state_type == + "curandStatePhilox4_32_10_t") { + if (call->op.same_as(tl::rng_rand()) && lanes == 4) { + os << "curand4(&" << this->curand_random_generator_state << ")"; + return; + } + if (call->op.same_as(tl::rng_rand_uniform()) && lanes == 4) { + os << "curand_uniform4(&" << this->curand_random_generator_state << ")"; + return; + } + if (call->op.same_as(tl::rng_rand_uniform_double()) && lanes == 2) { + os << "curand_uniform2_double(&" << this->curand_random_generator_state + << ")"; + return; + } + if (call->op.same_as(tl::rng_rand_normal()) && lanes == 4) { + os << "curand_normal4(&" << this->curand_random_generator_state << ")"; + return; + } + if (call->op.same_as(tl::rng_rand_normal()) && lanes == 2) { + os << "curand_normal2(&" << this->curand_random_generator_state << ")"; + return; + } + if (call->op.same_as(tl::rng_rand_normal_double()) && lanes == 2) { + os << "curand_normal2_double(&" << this->curand_random_generator_state + << ")"; + return; + } + } else if (this->curand_random_generator_state_type == + "curandStateMRG32k3a_t") { + if (call->op.same_as(tl::rng_rand_normal_double()) && lanes == 2) { + os << "curand_normal2_double(&" << this->curand_random_generator_state + << ")"; + return; + } + } else if (this->curand_random_generator_state_type == + "curandStateXORWOW_t") { + if 
(call->op.same_as(tl::rng_rand_normal_double()) && lanes == 2) { + os << "curand_normal2_double(&" << this->curand_random_generator_state + << ")"; + return; + } + } + } + std::string v = PrintExpr(op->value); os << "make_"; PrintType(op->dtype, os); diff --git a/src/target/codegen_cuda.h b/src/target/codegen_cuda.h index 9cf460213..b46fa1dd0 100644 --- a/src/target/codegen_cuda.h +++ b/src/target/codegen_cuda.h @@ -89,7 +89,8 @@ class CodeGenTileLangCUDA final : public CodeGenC { // Global barrier expected node. std::string vid_global_barrier_expect_; // Global curand state - std::string curand_philox_state; + std::string curand_random_generator_state; + std::string curand_random_generator_state_type; // whether enable fp16 bool enable_fp16_{false}; diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index 7a446731f..b25ed87c8 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -152,8 +152,7 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { } else if (node->op == builtin::call_extern()) { // do not vectorize extern calls vector_size_ = 1; - } else if (node->op.same_as(tl::rng_rand()) || - node->op.same_as(tl::rng_init())) { + } else if (node->op.same_as(tl::rng_init())) { // do not vectorize random operation vector_size_ = 1; } diff --git a/testing/python/language/test_tilelang_language_rand.py b/testing/python/language/test_tilelang_language_rand.py index daf51dbb7..c5e8f1c53 100644 --- a/testing/python/language/test_tilelang_language_rand.py +++ b/testing/python/language/test_tilelang_language_rand.py @@ -6,31 +6,63 @@ @tilelang.jit -def tilelang_rand_1d(M=1024, seed=42): +def tilelang_rand_1d(M=1024, seed=42, generator="curandStatePhilox4_32_10_t"): num_per_thread = 128 threads = 1 blk_M = num_per_thread * threads @T.prim_func - def rand_kernel(A: T.Tensor((M,), "uint32")): + def rand_kernel( + A: T.Tensor((M,), "uint32"), + B: T.Tensor((M,), "float32"), + C: T.Tensor((M,), "float64"), 
+ D: T.Tensor((M,), "float32"), + E: T.Tensor((M,), "float64"), + ): with T.Kernel(T.ceildiv(M, threads * num_per_thread), threads=threads) as bx: tx = T.get_thread_binding() - T.rng_init(seed, 0, bx * blk_M + tx * num_per_thread) + T.rng_init(seed, 0, bx * blk_M + tx * num_per_thread, generator=generator) for i, j in T.Parallel(threads, num_per_thread): offsets = (bx * threads + i) * num_per_thread idx = offsets + j if idx < M: A[idx] = T.rng_rand() + for i, j in T.Parallel(threads, num_per_thread): + offsets = (bx * threads + i) * num_per_thread + idx = offsets + j + if idx < M: + B[idx] = T.rng_rand_uniform() + for i, j in T.Parallel(threads, num_per_thread): + offsets = (bx * threads + i) * num_per_thread + idx = offsets + j + if idx < M: + C[idx] = T.rng_rand_uniform_double() + for i, j in T.Parallel(threads, num_per_thread): + offsets = (bx * threads + i) * num_per_thread + idx = offsets + j + if idx < M: + D[idx] = T.rng_rand_normal() + for i, j in T.Parallel(threads, num_per_thread): + offsets = (bx * threads + i) * num_per_thread + idx = offsets + j + if idx < M: + E[idx] = T.rng_rand_normal_double() return rand_kernel @tilelang.testing.requires_cuda -@pytest.mark.parametrize("M, seed", [(1024, 42), (512, 123), (128, 0)]) -def test_rand_1d(M, seed): - kernel = tilelang_rand_1d(M, seed) - tilelang_result = torch.empty(M, dtype=torch.uint32, device="cuda") - kernel(tilelang_result) +@pytest.mark.parametrize( + "M, seed, generator", [(1024, 42, "curandStateMRG32k3a_t"), (512, 123, "curandStatePhilox4_32_10_t"), (128, 0, "curandStateXORWOW_t")] +) +def test_rand_1d(M, seed, generator): + kernel = tilelang_rand_1d(M, seed, generator) + A = torch.empty(M, dtype=torch.uint32, device="cuda") + B = torch.empty(M, dtype=torch.float32, device="cuda") + C = torch.empty(M, dtype=torch.float64, device="cuda") + D = torch.empty(M, dtype=torch.float32, device="cuda") + E = torch.empty(M, dtype=torch.float64, device="cuda") + kernel(A, B, C, D, E) if __name__ == 
"__main__": diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index aa92cadd9..d4f1df91a 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -105,6 +105,10 @@ from .random import ( rng_init, # noqa: F401 rng_rand, # noqa: F401 + rng_rand_uniform, # noqa: F401 + rng_rand_uniform_double, # noqa: F401 + rng_rand_normal, # noqa: F401 + rng_rand_normal_double, # noqa: F401 ) diff --git a/tilelang/language/random.py b/tilelang/language/random.py index a76625be2..549c60575 100644 --- a/tilelang/language/random.py +++ b/tilelang/language/random.py @@ -3,7 +3,7 @@ # https://docs.nvidia.com/cuda/curand/device-api-overview.html#device-api-overview -def rng_init(seed, seq=None, off=0): +def rng_init(seed, seq=None, off=0, generator="curandStatePhilox4_32_10_t"): """Initialize CUDA curand random number generator state Parameters @@ -14,12 +14,16 @@ def rng_init(seed, seq=None, off=0): Sequence number for parallel random number generation. off : PrimExpr Offset number for parallel random number generation. + generator : StringImm + Set random generator. + See https://docs.nvidia.com/cuda/curand/group__DEVICE.html Returns ------- state : PrimExpr The random number generator state handle. """ + assert generator in ["curandStateMRG32k3a_t", "curandStatePhilox4_32_10_t", "curandStateXORWOW_t"] seed = tir.convert(seed) if seq is None: bx = T.get_block_binding() @@ -30,7 +34,7 @@ def rng_init(seed, seq=None, off=0): else: seq = tir.convert(seq) off = tir.convert(off) - return tir.call_intrin("void", tir.op.Op.get("tl.rng_init"), seed, seq, off) + return tir.call_intrin("void", tir.op.Op.get("tl.rng_init"), seed, seq, off, generator) def rng_rand(): @@ -42,3 +46,47 @@ def rng_rand(): A 32-bit unsigned random integer. 
""" return tir.call_intrin("uint32", tir.op.Op.get("tl.rng_rand")) + + +def rng_rand_uniform(): + """Generate a uniformly distributed float + + Returns + ------- + random_value : PrimExpr + A 32-bit uniformly distributed float. + """ + return tir.call_intrin("float32", tir.op.Op.get("tl.rng_rand_uniform")) + + +def rng_rand_uniform_double(): + """Generate a uniformly distributed double + + Returns + ------- + random_value : PrimExpr + A 32-bit uniformly distributed double. + """ + return tir.call_intrin("float64", tir.op.Op.get("tl.rng_rand_uniform_double")) + + +def rng_rand_normal(): + """Generate a normally distributed float + + Returns + ------- + random_value : PrimExpr + A 32-bit normally distributed float. + """ + return tir.call_intrin("float32", tir.op.Op.get("tl.rng_rand_normal")) + + +def rng_rand_normal_double(): + """Generate a normally distributed double + + Returns + ------- + random_value : PrimExpr + A 32-bit normally distributed double. + """ + return tir.call_intrin("float64", tir.op.Op.get("tl.rng_rand_normal_double")) From e548feaae16c1e967f9242f1e334e9686a9576b5 Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Wed, 31 Dec 2025 16:33:23 +0800 Subject: [PATCH 2/3] fix typo about bitwidth --- tilelang/language/random.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tilelang/language/random.py b/tilelang/language/random.py index 549c60575..586384a04 100644 --- a/tilelang/language/random.py +++ b/tilelang/language/random.py @@ -65,7 +65,7 @@ def rng_rand_uniform_double(): Returns ------- random_value : PrimExpr - A 32-bit uniformly distributed double. + A 64-bit uniformly distributed double. """ return tir.call_intrin("float64", tir.op.Op.get("tl.rng_rand_uniform_double")) @@ -87,6 +87,6 @@ def rng_rand_normal_double(): Returns ------- random_value : PrimExpr - A 32-bit normally distributed double. + A 64-bit normally distributed double. 
""" return tir.call_intrin("float64", tir.op.Op.get("tl.rng_rand_normal_double")) From 5a97e20132d12d5e31f3f252352e9bf3bf9cc808 Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Sun, 4 Jan 2026 11:01:15 +0800 Subject: [PATCH 3/3] merge random float into a single operation --- src/op/builtin.cc | 19 +--- src/op/builtin.h | 5 +- src/target/codegen_cuda.cc | 92 +++++++++---------- .../language/test_tilelang_language_rand.py | 9 +- tilelang/language/__init__.py | 5 +- tilelang/language/random.py | 48 +++------- 6 files changed, 67 insertions(+), 111 deletions(-) diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 77116397b..bee1c6f3c 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -107,23 +107,8 @@ TIR_DEFINE_TL_BUILTIN(rng_init).set_num_inputs(4).set_attr( TIR_DEFINE_TL_BUILTIN(rng_rand).set_num_inputs(0).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(rng_rand_uniform) - .set_num_inputs(0) - .set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); - -TIR_DEFINE_TL_BUILTIN(rng_rand_uniform_double) - .set_num_inputs(0) - .set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); - -TIR_DEFINE_TL_BUILTIN(rng_rand_normal) - .set_num_inputs(0) - .set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); - -TIR_DEFINE_TL_BUILTIN(rng_rand_normal_double) - .set_num_inputs(0) +TIR_DEFINE_TL_BUILTIN(rng_rand_float) + .set_num_inputs(1) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/builtin.h b/src/op/builtin.h index b79b49211..2a23db71f 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -131,10 +131,7 @@ TVM_DLL const Op &ieee_fdiv(); // random op TVM_DLL const Op &rng_init(); TVM_DLL const Op &rng_rand(); -TVM_DLL const Op &rng_rand_uniform(); -TVM_DLL const Op &rng_rand_uniform_double(); -TVM_DLL const Op &rng_rand_normal(); -TVM_DLL const Op &rng_rand_normal_double(); +TVM_DLL const Op &rng_rand_float(); /*! 
* \brief tvm intrinsics for TMADescriptor creation for tiled load diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 5b22435ba..87350acfd 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -2666,20 +2666,13 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { } else if (op->op.same_as(tl::rng_rand())) { this->need_curand_kernel_h_ = true; os << "curand(&" << this->curand_random_generator_state << ")"; - } else if (op->op.same_as(tl::rng_rand_uniform())) { + } else if (op->op.same_as(tl::rng_rand_float())) { this->need_curand_kernel_h_ = true; - os << "curand_uniform(&" << this->curand_random_generator_state << ")"; - } else if (op->op.same_as(tl::rng_rand_uniform_double())) { - this->need_curand_kernel_h_ = true; - os << "curand_uniform_double(&" << this->curand_random_generator_state - << ")"; - } else if (op->op.same_as(tl::rng_rand_normal())) { - this->need_curand_kernel_h_ = true; - os << "curand_normal(&" << this->curand_random_generator_state << ")"; - } else if (op->op.same_as(tl::rng_rand_normal_double())) { - this->need_curand_kernel_h_ = true; - os << "curand_normal_double(&" << this->curand_random_generator_state - << ")"; + os << "curand_" << op->args[0].as()->value; + if (op->dtype.bits() == 64) { + os << "_double"; + } + os << "(&" << this->curand_random_generator_state << ")"; } else if (op->op.same_as(tl::warp_reduce_sum())) { os << "tl::warp_reduce_sum(" << PrintExpr(op->args[0]) << ")"; } else if (op->op.same_as(tl::warp_reduce_max())) { @@ -3136,41 +3129,48 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op, os << "curand4(&" << this->curand_random_generator_state << ")"; return; } - if (call->op.same_as(tl::rng_rand_uniform()) && lanes == 4) { - os << "curand_uniform4(&" << this->curand_random_generator_state << ")"; - return; - } - if (call->op.same_as(tl::rng_rand_uniform_double()) && lanes == 2) { - os << "curand_uniform2_double(&" << 
this->curand_random_generator_state - << ")"; - return; - } - if (call->op.same_as(tl::rng_rand_normal()) && lanes == 4) { - os << "curand_normal4(&" << this->curand_random_generator_state << ")"; - return; - } - if (call->op.same_as(tl::rng_rand_normal()) && lanes == 2) { - os << "curand_normal2(&" << this->curand_random_generator_state << ")"; - return; - } - if (call->op.same_as(tl::rng_rand_normal_double()) && lanes == 2) { - os << "curand_normal2_double(&" << this->curand_random_generator_state - << ")"; - return; - } - } else if (this->curand_random_generator_state_type == - "curandStateMRG32k3a_t") { - if (call->op.same_as(tl::rng_rand_normal_double()) && lanes == 2) { - os << "curand_normal2_double(&" << this->curand_random_generator_state - << ")"; - return; + if (call->op.same_as(tl::rng_rand_float())) { + int bits = call->dtype.bits(); + std::string dist = call->args[0].as()->value; + if (bits == 32) { + if (lanes == 4) { + os << "curand_" << dist << "4(&" + << this->curand_random_generator_state << ")"; + return; + } else if (lanes == 2 && dist == "normal") { + os << "curand_normal2(&" << this->curand_random_generator_state + << ")"; + return; + } + + } else { + if (lanes == 2) { + os << "curand_" << dist << "2_double(&" + << this->curand_random_generator_state << ")"; + return; + } + } } } else if (this->curand_random_generator_state_type == - "curandStateXORWOW_t") { - if (call->op.same_as(tl::rng_rand_normal_double()) && lanes == 2) { - os << "curand_normal2_double(&" << this->curand_random_generator_state - << ")"; - return; + "curandStateMRG32k3a_t" || + this->curand_random_generator_state_type == + "curandStateXORWOW_t") { + if (call->op.same_as(tl::rng_rand_float())) { + int bits = call->dtype.bits(); + std::string dist = call->args[0].as()->value; + if (bits == 32) { + if (lanes == 2 && dist == "normal") { + os << "curand_normal2(&" << this->curand_random_generator_state + << ")"; + return; + } + } else { + if (lanes == 2 && dist == "normal") { 
+ os << "curand_normal2_double(&" + << this->curand_random_generator_state << ")"; + return; + } + } } } } diff --git a/testing/python/language/test_tilelang_language_rand.py b/testing/python/language/test_tilelang_language_rand.py index c5e8f1c53..d179d5478 100644 --- a/testing/python/language/test_tilelang_language_rand.py +++ b/testing/python/language/test_tilelang_language_rand.py @@ -31,22 +31,22 @@ def rand_kernel( offsets = (bx * threads + i) * num_per_thread idx = offsets + j if idx < M: - B[idx] = T.rng_rand_uniform() + B[idx] = T.rng_rand_float() for i, j in T.Parallel(threads, num_per_thread): offsets = (bx * threads + i) * num_per_thread idx = offsets + j if idx < M: - C[idx] = T.rng_rand_uniform_double() + C[idx] = T.rng_rand_float(bit=64) for i, j in T.Parallel(threads, num_per_thread): offsets = (bx * threads + i) * num_per_thread idx = offsets + j if idx < M: - D[idx] = T.rng_rand_normal() + D[idx] = T.rng_rand_float(dist="normal") for i, j in T.Parallel(threads, num_per_thread): offsets = (bx * threads + i) * num_per_thread idx = offsets + j if idx < M: - E[idx] = T.rng_rand_normal_double() + E[idx] = T.rng_rand_float(bit=64, dist="normal") return rand_kernel @@ -67,3 +67,4 @@ def test_rand_1d(M, seed, generator): if __name__ == "__main__": tilelang.testing.main() + # test_rand_1d(1024, 42, "curandStateMRG32k3a_t") diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index d4f1df91a..eb94cf16f 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -105,10 +105,7 @@ from .random import ( rng_init, # noqa: F401 rng_rand, # noqa: F401 - rng_rand_uniform, # noqa: F401 - rng_rand_uniform_double, # noqa: F401 - rng_rand_normal, # noqa: F401 - rng_rand_normal_double, # noqa: F401 + rng_rand_float, # noqa: F401 ) diff --git a/tilelang/language/random.py b/tilelang/language/random.py index 586384a04..d59433891 100644 --- a/tilelang/language/random.py +++ b/tilelang/language/random.py @@ -48,45 +48,21 @@ def 
rng_rand():
     return tir.call_intrin("uint32", tir.op.Op.get("tl.rng_rand"))
 
 
-def rng_rand_uniform():
-    """Generate a uniformly distributed float
-
-    Returns
-    -------
-    random_value : PrimExpr
-        A 32-bit uniformly distributed float.
-    """
-    return tir.call_intrin("float32", tir.op.Op.get("tl.rng_rand_uniform"))
-
-
-def rng_rand_uniform_double():
-    """Generate a uniformly distributed double
-
-    Returns
-    -------
-    random_value : PrimExpr
-        A 64-bit uniformly distributed double.
-    """
-    return tir.call_intrin("float64", tir.op.Op.get("tl.rng_rand_uniform_double"))
-
-
-def rng_rand_normal():
-    """Generate a normally distributed float
+def rng_rand_float(bit=32, dist="uniform"):
+    """Generate a random float
 
-    Returns
-    -------
-    random_value : PrimExpr
-        A 32-bit normally distributed float.
-    """
-    return tir.call_intrin("float32", tir.op.Op.get("tl.rng_rand_normal"))
-
-
-def rng_rand_normal_double():
-    """Generate a normally distributed double
+    Parameters
+    ----------
+    bit : int
+        Bit width of the generated float; must be 32 or 64.
+    dist : str
+        Distribution to sample from; either "uniform" or "normal".
 
     Returns
     -------
     random_value : PrimExpr
-        A 64-bit normally distributed double.
+        A random float.
     """
-    return tir.call_intrin("float64", tir.op.Op.get("tl.rng_rand_normal_double"))
+    assert bit in [32, 64]
+    assert dist in ["uniform", "normal"]
+    return tir.call_intrin("float" + str(bit), tir.op.Op.get("tl.rng_rand_float"), dist)