diff --git a/src/op/builtin.cc b/src/op/builtin.cc
index 2983495b5..1e065ed02 100644
--- a/src/op/builtin.cc
+++ b/src/op/builtin.cc
@@ -102,6 +102,12 @@ TIR_DEFINE_TL_BUILTIN(ieee_frsqrt)
 TIR_DEFINE_TL_BUILTIN(ieee_fdiv).set_num_inputs(3).set_attr<TCallEffectKind>(
     "TCallEffectKind", Integer(CallEffectKind::kPure));
 
+TIR_DEFINE_TL_BUILTIN(rng_init).set_num_inputs(3).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(rng_rand).set_num_inputs(0).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier)
     .set_num_inputs(-1)
     .set_attr<TCallEffectKind>("TCallEffectKind",
diff --git a/src/op/builtin.h b/src/op/builtin.h
index 7b071e842..9606176d8 100644
--- a/src/op/builtin.h
+++ b/src/op/builtin.h
@@ -143,6 +143,10 @@ TVM_DLL const Op &ieee_frsqrt();
 // ieee_fdiv(x, y, rounding_mode) - IEEE-compliant division
 TVM_DLL const Op &ieee_fdiv();
 
+// rng_init(seed, sequence, offset) / rng_rand() - cuRAND-backed RNG
+TVM_DLL const Op &rng_init();
+TVM_DLL const Op &rng_rand();
+
 /*!
  * \brief tvm intrinsics for TMADescriptor creation for tiled load
  *
diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc
index 011855fbf..c63517ab0 100644
--- a/src/target/codegen_cuda.cc
+++ b/src/target/codegen_cuda.cc
@@ -297,6 +297,10 @@ std::string CodeGenTileLangCUDA::Finish() {
     decl_stream << "#include \n";
   }
 
+  if (need_curand_kernel_h_) {
+    decl_stream << "#include <curand_kernel.h>\n";
+  }
+
   decl_stream << "#include \n";
   if (enable_sparse_gemm_) {
     decl_stream << "#include \n";
@@ -2740,6 +2744,20 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     std::string func_name = math_func(op->dtype, "fdiv", rounding_mode);
     os << func_name << "(" << PrintExpr(op->args[0]) << ", "
        << PrintExpr(op->args[1]) << ")";
+  } else if (op->op.same_as(tl::rng_init())) {
+    this->need_curand_kernel_h_ = true;
+    this->curand_philox_state_ = name_supply_->FreshName("__philox_state");
+    this->PrintIndent();
+    this->stream << "curandStatePhilox4_32_10_t "
+                 << this->curand_philox_state_ << ";\n";
+    this->PrintIndent();
+    this->stream << "curand_init(" << PrintExpr(op->args[0]) << ", "
+                 << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2])
+                 << ", &" << this->curand_philox_state_ << ");\n";
+    // curand_philox_state_ keeps the variable name for rng_rand.
+  } else if (op->op.same_as(tl::rng_rand())) {
+    this->need_curand_kernel_h_ = true;
+    os << "curand(&" << this->curand_philox_state_ << ")";
   } else if (op->op.same_as(tl::warp_reduce_sum())) {
     os << "tl::warp_reduce_sum(" << PrintExpr(op->args[0]) << ")";
   } else if (op->op.same_as(tl::warp_reduce_max())) {
diff --git a/src/target/codegen_cuda.h b/src/target/codegen_cuda.h
index 45fe5e2a0..9cf460213 100644
--- a/src/target/codegen_cuda.h
+++ b/src/target/codegen_cuda.h
@@ -88,6 +88,8 @@ class CodeGenTileLangCUDA final : public CodeGenC {
   std::string vid_global_barrier_state_;
   // Global barrier expected node.
   std::string vid_global_barrier_expect_;
+  // Name of the generated curand Philox state variable
+  std::string curand_philox_state_;
   // whether enable fp16
   bool enable_fp16_{false};
 
@@ -123,6 +125,8 @@ class CodeGenTileLangCUDA final : public CodeGenC {
   bool need_cast_smem_ptr_to_int_{false};
   // whether need cooperative_groups.h
   bool need_cooperative_groups_{false};
+  // whether need curand_kernel.h
+  bool need_curand_kernel_h_{false};
   // Op attribute map
   OpAttrMap<bool> op_need_warp_shuffle_ =
       Op::GetAttrMap<bool>("cuda.need_warp_shuffle");
diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc
index 337312851..082fec2cb 100644
--- a/src/transform/layout_inference.cc
+++ b/src/transform/layout_inference.cc
@@ -1190,6 +1190,7 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
     });
 
     if ((has_non_local || has_cast_operations) && !has_reducer) {
+      DLOG(INFO) << "Trying to vectorize loop";
       for_node = VectorizeLoop(for_node, saved_analyzer.get());
     }
 
diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc
index 72b93b78c..7a446731f 100644
--- a/src/transform/loop_vectorize.cc
+++ b/src/transform/loop_vectorize.cc
@@ -152,6 +152,10 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer {
     } else if (node->op == builtin::call_extern()) {
       // do not vectorize extern calls
       vector_size_ = 1;
+    } else if (node->op.same_as(tl::rng_rand()) ||
+               node->op.same_as(tl::rng_init())) {
+      // do not vectorize random operations
+      vector_size_ = 1;
     }
     return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
   }
diff --git a/testing/python/language/test_rand.py b/testing/python/language/test_rand.py
new file mode 100644
index 000000000..5e25cc3bf
--- /dev/null
+++ b/testing/python/language/test_rand.py
@@ -0,0 +1,35 @@
+import tilelang
+import tilelang.language as T  # noqa: N812
+import torch
+import triton
+import triton.language as tl
+
+
+@tilelang.jit
+def tilelang_rand_1d(M=1024, seed=42):
+    blk_M = 128
+    num_threads = 128
+
+    @T.prim_func
+    def rand_kernel(A: T.Tensor((M,), "uint32")):
+        with T.Kernel(M // blk_M, threads=num_threads) as bx:
+            T.rng_init(seed)
+            for i in T.Parallel(blk_M):
+                A[bx * blk_M + i] = T.rng_rand()
+
+    return rand_kernel
+
+
+@triton.jit
+def triton_rand_1d(X, M, seed):
+    pid = tl.program_id(0)
+    offset = pid * M + tl.arange(0, M)
+    rand = tl.randint(seed, offset)
+    tl.store(X + offset, rand, mask=offset < M)
+
+
+if __name__ == "__main__":
+    M = 1024
+    kernel = tilelang_rand_1d()
+    x = torch.empty(M, dtype=torch.uint32, device="cuda")
+    kernel(x)
diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py
index 9a6354e96..e3067a23c 100644
--- a/tilelang/language/__init__.py
+++ b/tilelang/language/__init__.py
@@ -110,6 +110,11 @@
     annotate_l2_hit_ratio,
 )
 
+from .random import (
+    rng_init,  # noqa: F401
+    rng_rand,  # noqa: F401
+)
+
 
 def import_source(source: str | None = None):
     # source is the source code to be imported
diff --git a/tilelang/language/random.py b/tilelang/language/random.py
new file mode 100644
index 000000000..a76625be2
--- /dev/null
+++ b/tilelang/language/random.py
@@ -0,0 +1,44 @@
+from tvm import tir
+import tilelang.language as T
+
+
+# https://docs.nvidia.com/cuda/curand/device-api-overview.html#device-api-overview
+def rng_init(seed, seq=None, off=0):
+    """Initialize the CUDA curand (Philox) random number generator state.
+
+    Parameters
+    ----------
+    seed : PrimExpr
+        Random seed value.
+    seq : PrimExpr, optional
+        Sequence number; defaults to a per-thread id from the block/thread bindings.
+    off : PrimExpr, optional
+        Offset into the chosen sequence (default 0).
+
+    Returns
+    -------
+    call : PrimExpr
+        A void call that initializes the implicit per-thread state used by rng_rand().
+    """
+    seed = tir.convert(seed)
+    if seq is None:
+        bx = T.get_block_binding()
+        ex = T.kernel.get_thread_extent()
+        tx = T.get_thread_binding()
+        seq_id = tx + bx * ex
+        seq = tir.convert(seq_id)
+    else:
+        seq = tir.convert(seq)
+    off = tir.convert(off)
+    return tir.call_intrin("void", tir.op.Op.get("tl.rng_init"), seed, seq, off)
+
+
+def rng_rand():
+    """Generate a 32-bit unsigned random integer.
+
+    Returns
+    -------
+    random_value : PrimExpr
+        A 32-bit unsigned random integer drawn from the state set up by rng_init().
+    """
+    return tir.call_intrin("uint32", tir.op.Op.get("tl.rng_rand"))
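
A quick reproducibility sketch to go with the test above: with a fixed seed and the default thread/block-derived sequence ids, two launches should produce identical uint32 streams. This is illustrative only, not part of the patch; the factory name rand_kernel_factory and the view(torch.int32) comparison (used to stay on well-supported torch dtype paths) are assumptions.

# Illustrative sketch, not part of the patch: same seed + same sequence layout
# should reproduce the same uint32 stream.
import tilelang
import tilelang.language as T  # noqa: N812
import torch


@tilelang.jit
def rand_kernel_factory(M=1024, seed=42):
    blk_M = 128

    @T.prim_func
    def rand_kernel(A: T.Tensor((M,), "uint32")):
        with T.Kernel(M // blk_M, threads=128) as bx:
            T.rng_init(seed)
            for i in T.Parallel(blk_M):
                A[bx * blk_M + i] = T.rng_rand()

    return rand_kernel


if __name__ == "__main__":
    M = 1024
    a = torch.empty(M, dtype=torch.uint32, device="cuda")
    b = torch.empty(M, dtype=torch.uint32, device="cuda")
    rand_kernel_factory(M, 42)(a)
    rand_kernel_factory(M, 42)(b)
    # Reinterpret as int32 before comparing; uint32 has limited op coverage.
    assert torch.equal(a.view(torch.int32), b.view(torch.int32))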
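
The patch only exposes the raw uint32 stream; a common follow-up is scaling it into uniform float32 values in [0, 1). A minimal sketch of that post-processing, assuming T.Cast is available in the TileLang frontend and that multiplying by 2**-32 is an acceptable mapping; the names tilelang_uniform_1d and uniform_kernel are hypothetical.

# Illustrative sketch, not part of the patch: map T.rng_rand() output onto [0, 1).
import tilelang
import tilelang.language as T  # noqa: N812


@tilelang.jit
def tilelang_uniform_1d(M=1024, seed=42):
    blk_M = 128

    @T.prim_func
    def uniform_kernel(A: T.Tensor((M,), "float32")):
        with T.Kernel(M // blk_M, threads=128) as bx:
            T.rng_init(seed)
            for i in T.Parallel(blk_M):
                # Scale the full uint32 range onto the unit interval (2**-32).
                A[bx * blk_M + i] = T.Cast("float32", T.rng_rand()) * (1.0 / 4294967296.0)

    return uniform_kernel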