diff --git a/src/op/builtin.cc b/src/op/builtin.cc
index a0ee8acd8..bee1c6f3c 100644
--- a/src/op/builtin.cc
+++ b/src/op/builtin.cc
@@ -101,12 +101,17 @@ TIR_DEFINE_TL_BUILTIN(ieee_frsqrt)
 TIR_DEFINE_TL_BUILTIN(ieee_fdiv).set_num_inputs(3).set_attr<TCallEffectKind>(
     "TCallEffectKind", Integer(CallEffectKind::kPure));
 
-TIR_DEFINE_TL_BUILTIN(rng_init).set_num_inputs(3).set_attr<TCallEffectKind>(
+TIR_DEFINE_TL_BUILTIN(rng_init).set_num_inputs(4).set_attr<TCallEffectKind>(
     "TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
 TIR_DEFINE_TL_BUILTIN(rng_rand).set_num_inputs(0).set_attr<TCallEffectKind>(
     "TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_TL_BUILTIN(rng_rand_float)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier)
     .set_num_inputs(-1)
     .set_attr<TCallEffectKind>("TCallEffectKind",
diff --git a/src/op/builtin.h b/src/op/builtin.h
index 0e39e9ad4..2a23db71f 100644
--- a/src/op/builtin.h
+++ b/src/op/builtin.h
@@ -131,6 +131,7 @@ TVM_DLL const Op &ieee_fdiv();
 // random op
 TVM_DLL const Op &rng_init();
 TVM_DLL const Op &rng_rand();
+TVM_DLL const Op &rng_rand_float();
 
 /*!
  * \brief tvm intrinsics for TMADescriptor creation for tiled load
diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc
index 5a4243471..87350acfd 100644
--- a/src/target/codegen_cuda.cc
+++ b/src/target/codegen_cuda.cc
@@ -2651,18 +2651,28 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
        << PrintExpr(op->args[1]) << ")";
   } else if (op->op.same_as(tl::rng_init())) {
     this->need_curand_kernel_h_ = true;
-    this->curand_philox_state = name_supply_->FreshName("__philox_state");
+    this->curand_random_generator_state =
+        name_supply_->FreshName("__random_generator_state");
+    this->curand_random_generator_state_type =
+        op->args[3].as<StringImmNode>()->value;
     this->PrintIndent();
-    this->stream << "curandStatePhilox4_32_10_t " << this->curand_philox_state
-                 << ";\n";
+    this->stream << op->args[3].as<StringImmNode>()->value << " "
+                 << this->curand_random_generator_state << ";\n";
     this->PrintIndent();
     this->stream << "curand_init(" << PrintExpr(op->args[0]) << ", "
                  << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2])
-                 << ", &" << this->curand_philox_state << ");\n";
+                 << ", &" << this->curand_random_generator_state << ");\n";
     // Store state_var for later use by rng_rand
   } else if (op->op.same_as(tl::rng_rand())) {
     this->need_curand_kernel_h_ = true;
-    os << "curand(&" << this->curand_philox_state << ")";
+    os << "curand(&" << this->curand_random_generator_state << ")";
+  } else if (op->op.same_as(tl::rng_rand_float())) {
+    this->need_curand_kernel_h_ = true;
+    os << "curand_" << op->args[0].as<StringImmNode>()->value;
+    if (op->dtype.bits() == 64) {
+      os << "_double";
+    }
+    os << "(&" << this->curand_random_generator_state << ")";
   } else if (op->op.same_as(tl::warp_reduce_sum())) {
     os << "tl::warp_reduce_sum(" << PrintExpr(op->args[0]) << ")";
   } else if (op->op.same_as(tl::warp_reduce_max())) {
@@ -3112,6 +3122,59 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
     }
   }
 
+  if (auto call = op->value.as<CallNode>()) {
+    if (this->curand_random_generator_state_type ==
+        "curandStatePhilox4_32_10_t") {
+      if (call->op.same_as(tl::rng_rand()) && lanes == 4) {
+        os << "curand4(&" << this->curand_random_generator_state << ")";
+        return;
+      }
+      if (call->op.same_as(tl::rng_rand_float())) {
+        int bits = call->dtype.bits();
+        std::string dist = call->args[0].as<StringImmNode>()->value;
+        if (bits == 32) {
+          if (lanes == 4) {
+            os << "curand_" << dist << "4(&"
+               << this->curand_random_generator_state << ")";
+            return;
+          } else if (lanes == 2 && dist == "normal") {
+            os << "curand_normal2(&" << this->curand_random_generator_state
+               << ")";
+            return;
+          }
+
+        } else {
+          if (lanes == 2) {
+            os << "curand_" << dist << "2_double(&"
+               << this->curand_random_generator_state << ")";
+            return;
+          }
+        }
+      }
+    } else if (this->curand_random_generator_state_type ==
+                   "curandStateMRG32k3a_t" ||
+               this->curand_random_generator_state_type ==
+                   "curandStateXORWOW_t") {
+      if (call->op.same_as(tl::rng_rand_float())) {
+        int bits = call->dtype.bits();
+        std::string dist = call->args[0].as<StringImmNode>()->value;
+        if (bits == 32) {
+          if (lanes == 2 && dist == "normal") {
+            os << "curand_normal2(&" << this->curand_random_generator_state
+               << ")";
+            return;
+          }
+        } else {
+          if (lanes == 2 && dist == "normal") {
+            os << "curand_normal2_double(&"
+               << this->curand_random_generator_state << ")";
+            return;
+          }
+        }
+      }
+    }
+  }
+
   std::string v = PrintExpr(op->value);
   os << "make_";
   PrintType(op->dtype, os);
diff --git a/src/target/codegen_cuda.h b/src/target/codegen_cuda.h
index 9cf460213..b46fa1dd0 100644
--- a/src/target/codegen_cuda.h
+++ b/src/target/codegen_cuda.h
@@ -89,7 +89,8 @@ class CodeGenTileLangCUDA final : public CodeGenC {
   // Global barrier expected node.
   std::string vid_global_barrier_expect_;
   // Global curand state
-  std::string curand_philox_state;
+  std::string curand_random_generator_state;
+  std::string curand_random_generator_state_type;
 
   // whether enable fp16
   bool enable_fp16_{false};
diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc
index 7a446731f..b25ed87c8 100644
--- a/src/transform/loop_vectorize.cc
+++ b/src/transform/loop_vectorize.cc
@@ -152,8 +152,7 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer {
     } else if (node->op == builtin::call_extern()) {
       // do not vectorize extern calls
       vector_size_ = 1;
-    } else if (node->op.same_as(tl::rng_rand()) ||
-               node->op.same_as(tl::rng_init())) {
+    } else if (node->op.same_as(tl::rng_init())) {
       // do not vectorize random operation
       vector_size_ = 1;
     }
diff --git a/testing/python/language/test_tilelang_language_rand.py b/testing/python/language/test_tilelang_language_rand.py
index daf51dbb7..d179d5478 100644
--- a/testing/python/language/test_tilelang_language_rand.py
+++ b/testing/python/language/test_tilelang_language_rand.py
@@ -6,32 +6,65 @@
 
 
 @tilelang.jit
-def tilelang_rand_1d(M=1024, seed=42):
+def tilelang_rand_1d(M=1024, seed=42, generator="curandStatePhilox4_32_10_t"):
     num_per_thread = 128
     threads = 1
     blk_M = num_per_thread * threads
 
     @T.prim_func
-    def rand_kernel(A: T.Tensor((M,), "uint32")):
+    def rand_kernel(
+            A: T.Tensor((M,), "uint32"),
+            B: T.Tensor((M,), "float32"),
+            C: T.Tensor((M,), "float64"),
+            D: T.Tensor((M,), "float32"),
+            E: T.Tensor((M,), "float64"),
+    ):
         with T.Kernel(T.ceildiv(M, threads * num_per_thread), threads=threads) as bx:
             tx = T.get_thread_binding()
-            T.rng_init(seed, 0, bx * blk_M + tx * num_per_thread)
+            T.rng_init(seed, 0, bx * blk_M + tx * num_per_thread, generator=generator)
             for i, j in T.Parallel(threads, num_per_thread):
                 offsets = (bx * threads + i) * num_per_thread
                 idx = offsets + j
                 if idx < M:
                     A[idx] = T.rng_rand()
+            for i, j in T.Parallel(threads, num_per_thread):
+                offsets = (bx * threads + i) * num_per_thread
+                idx = offsets + j
+                if idx < M:
+                    B[idx] = T.rng_rand_float()
+            for i, j in T.Parallel(threads, num_per_thread):
+                offsets = (bx * threads + i) * num_per_thread
+                idx = offsets + j
+                if idx < M:
+                    C[idx] = T.rng_rand_float(bit=64)
+            for i, j in T.Parallel(threads, num_per_thread):
+                offsets = (bx * threads + i) * num_per_thread
+                idx = offsets + j
+                if idx < M:
+                    D[idx] = T.rng_rand_float(dist="normal")
+            for i, j in T.Parallel(threads, num_per_thread):
+                offsets = (bx * threads + i) * num_per_thread
+                idx = offsets + j
+                if idx < M:
+                    E[idx] = T.rng_rand_float(bit=64, dist="normal")
 
     return rand_kernel
 
 
 @tilelang.testing.requires_cuda
-@pytest.mark.parametrize("M, seed", [(1024, 42), (512, 123), (128, 0)])
-def test_rand_1d(M, seed):
-    kernel = tilelang_rand_1d(M, seed)
-    tilelang_result = torch.empty(M, dtype=torch.uint32, device="cuda")
-    kernel(tilelang_result)
+@pytest.mark.parametrize(
+    "M, seed, generator",
+    [(1024, 42, "curandStateMRG32k3a_t"), (512, 123, "curandStatePhilox4_32_10_t"), (128, 0, "curandStateXORWOW_t")])
+def test_rand_1d(M, seed, generator):
+    kernel = tilelang_rand_1d(M, seed, generator)
+    A = torch.empty(M, dtype=torch.uint32, device="cuda")
+    B = torch.empty(M, dtype=torch.float32, device="cuda")
+    C = torch.empty(M, dtype=torch.float64, device="cuda")
+    D = torch.empty(M, dtype=torch.float32, device="cuda")
+    E = torch.empty(M, dtype=torch.float64, device="cuda")
+    kernel(A, B, C, D, E)
 
 
 if __name__ == "__main__":
     tilelang.testing.main()
+    # test_rand_1d(1024, 42, "curandStateMRG32k3a_t")
diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py
index aa92cadd9..eb94cf16f 100644
--- a/tilelang/language/__init__.py
+++ b/tilelang/language/__init__.py
@@ -105,6 +105,7 @@
 from .random import (
     rng_init,  # noqa: F401
     rng_rand,  # noqa: F401
+    rng_rand_float,  # noqa: F401
 )
diff --git a/tilelang/language/random.py b/tilelang/language/random.py
index a76625be2..d59433891 100644
--- a/tilelang/language/random.py
+++ b/tilelang/language/random.py
@@ -3,7 +3,7 @@
 # https://docs.nvidia.com/cuda/curand/device-api-overview.html#device-api-overview
 
-def rng_init(seed, seq=None, off=0):
+def rng_init(seed, seq=None, off=0, generator="curandStatePhilox4_32_10_t"):
     """Initialize CUDA curand random number generator state
 
     Parameters
     ----------
@@ -14,12 +14,16 @@
         Sequence number for parallel random number generation.
     off : PrimExpr
         Offset number for parallel random number generation.
+    generator : str
+        Name of the curand generator state type to use.
+        See https://docs.nvidia.com/cuda/curand/group__DEVICE.html
 
     Returns
     -------
     state : PrimExpr
         The random number generator state handle.
     """
+    assert generator in ["curandStateMRG32k3a_t", "curandStatePhilox4_32_10_t", "curandStateXORWOW_t"]
     seed = tir.convert(seed)
     if seq is None:
         bx = T.get_block_binding()
@@ -30,15 +34,35 @@
     else:
         seq = tir.convert(seq)
     off = tir.convert(off)
-    return tir.call_intrin("void", tir.op.Op.get("tl.rng_init"), seed, seq, off)
+    return tir.call_intrin("void", tir.op.Op.get("tl.rng_init"), seed, seq, off, generator)
 
 
 def rng_rand():
     """Generate a random 32-bit unsigned integer
 
     Returns
     -------
     random_value : PrimExpr
         A 32-bit unsigned random integer.
     """
     return tir.call_intrin("uint32", tir.op.Op.get("tl.rng_rand"))
+
+
+def rng_rand_float(bit=32, dist="uniform"):
+    """Generate a random floating-point value.
+
+    Parameters
+    ----------
+    bit : int
+        Bit width of the result; must be 32 or 64.
+    dist : str
+        Sampling distribution; either "uniform" or "normal".
+
+    Returns
+    -------
+    random_value : PrimExpr
+        A random float of the requested width and distribution.
+    """
+    assert bit in [32, 64]
+    assert dist in ["uniform", "normal"]
+    return tir.call_intrin("float" + str(bit), tir.op.Op.get("tl.rng_rand_float"), dist)
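For quick reference outside the patch, here is a minimal usage sketch of the API surface this change adds, mirroring the updated test. It is illustrative only: the function name `rand_normal_1d`, the tensor name `out`, and the launch shape are invented for the example, and a CUDA device is assumed.

```python
import torch

import tilelang
import tilelang.language as T


@tilelang.jit
def rand_normal_1d(M=1024, seed=7, generator="curandStatePhilox4_32_10_t"):
    # Illustrative launch shape: one thread fills 128 contiguous elements.
    num_per_thread = 128
    threads = 1

    @T.prim_func
    def kernel(out: T.Tensor((M,), "float32")):
        with T.Kernel(T.ceildiv(M, threads * num_per_thread), threads=threads) as bx:
            tx = T.get_thread_binding()
            # One curand state per thread: seed, sequence, offset, plus the
            # generator selector introduced by this patch.
            T.rng_init(seed, 0, bx * threads * num_per_thread + tx * num_per_thread,
                       generator=generator)
            for i, j in T.Parallel(threads, num_per_thread):
                idx = (bx * threads + i) * num_per_thread + j
                if idx < M:
                    # New intrinsic: a float32 normal variate; bit=64 and
                    # dist="uniform" select the other supported variants.
                    out[idx] = T.rng_rand_float(dist="normal")

    return kernel


out = torch.empty(1024, dtype=torch.float32, device="cuda")
rand_normal_1d()(out)
```

The generator-specific branches in the `BroadcastNode` visitor track what the cuRAND device API actually offers: the four-wide draws (`curand4`, `curand_uniform4`, `curand_normal4`) are defined only for the Philox state, while the paired normal draws (`curand_normal2`, `curand_normal2_double`) also exist for XORWOW and MRG32k3a, so only those forms are emitted as vectorized calls for the non-Philox generators.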