6 changes: 6 additions & 0 deletions src/op/builtin.cc
@@ -102,6 +102,12 @@ TIR_DEFINE_TL_BUILTIN(ieee_frsqrt)
TIR_DEFINE_TL_BUILTIN(ieee_fdiv).set_num_inputs(3).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kPure));

TIR_DEFINE_TL_BUILTIN(rng_init).set_num_inputs(3).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(rng_rand).set_num_inputs(0).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
4 changes: 4 additions & 0 deletions src/op/builtin.h
@@ -143,6 +143,10 @@ TVM_DLL const Op &ieee_frsqrt();
// ieee_fdiv(x, y, rounding_mode) - IEEE-compliant division
TVM_DLL const Op &ieee_fdiv();

// rng_init(seed, seq, offset) - initialize per-thread curand Philox RNG state
TVM_DLL const Op &rng_init();
// rng_rand() - draw a 32-bit unsigned random integer from that state
TVM_DLL const Op &rng_rand();

/*!
* \brief tvm intrinsics for TMADescriptor creation for tiled load
*
18 changes: 18 additions & 0 deletions src/target/codegen_cuda.cc
@@ -297,6 +297,10 @@ std::string CodeGenTileLangCUDA::Finish() {
decl_stream << "#include <cooperative_groups.h>\n";
}

if (need_curand_kernel_h_) {
decl_stream << "#include <curand_kernel.h>\n";
}

decl_stream << "#include <tl_templates/cuda/gemm.h>\n";
if (enable_sparse_gemm_) {
decl_stream << "#include <tl_templates/cuda/gemm_sp.h>\n";
@@ -2740,6 +2744,20 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
std::string func_name = math_func(op->dtype, "fdiv", rounding_mode);
os << func_name << "(" << PrintExpr(op->args[0]) << ", "
<< PrintExpr(op->args[1]) << ")";
} else if (op->op.same_as(tl::rng_init())) {
this->need_curand_kernel_h_ = true;
// Declare a fresh per-thread Philox state and remember its name so that
// later rng_rand() calls can reference it.
this->curand_philox_state_ = name_supply_->FreshName("__philox_state");
this->PrintIndent();
this->stream << "curandStatePhilox4_32_10_t " << this->curand_philox_state_
<< ";\n";
this->PrintIndent();
this->stream << "curand_init(" << PrintExpr(op->args[0]) << ", "
<< PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2])
<< ", &" << this->curand_philox_state_ << ");\n";
} else if (op->op.same_as(tl::rng_rand())) {
this->need_curand_kernel_h_ = true;
os << "curand(&" << this->curand_philox_state_ << ")";
} else if (op->op.same_as(tl::warp_reduce_sum())) {
os << "tl::warp_reduce_sum(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::warp_reduce_max())) {
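For orientation, here is a minimal sketch of the CUDA these two branches would emit for a kernel that calls rng_init once and then rng_rand. The kernel shape and the state name are illustrative; the real name comes from name_supply_->FreshName("__philox_state"):

#include <curand_kernel.h>

__global__ void rand_kernel(unsigned int *A, unsigned long long seed) {
  // Emitted once by the rng_init branch: declare and seed the state.
  curandStatePhilox4_32_10_t __philox_state;
  curand_init(seed, threadIdx.x + blockIdx.x * blockDim.x, 0, &__philox_state);
  // Each rng_rand() call expands inline to curand(&state).
  A[blockIdx.x * blockDim.x + threadIdx.x] = curand(&__philox_state);
}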
4 changes: 4 additions & 0 deletions src/target/codegen_cuda.h
@@ -88,6 +88,8 @@ class CodeGenTileLangCUDA final : public CodeGenC {
std::string vid_global_barrier_state_;
// Global barrier expected node.
std::string vid_global_barrier_expect_;
// Name of the per-thread curand Philox state variable emitted by rng_init
std::string curand_philox_state_;

// whether enable fp16
bool enable_fp16_{false};
@@ -123,6 +125,8 @@ class CodeGenTileLangCUDA final : public CodeGenC {
bool need_cast_smem_ptr_to_int_{false};
// whether need cooperative_groups.h
bool need_cooperative_groups_{false};
// whether need curand_kernel.h
bool need_curand_kernel_h_{false};
// Op attribute map
OpAttrMap<bool> op_need_warp_shuffle_ =
Op::GetAttrMap<bool>("cuda.need_warp_shuffle");
1 change: 1 addition & 0 deletions src/transform/layout_inference.cc
@@ -1190,6 +1190,7 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
});

if ((has_non_local || has_cast_operations) && !has_reducer) {
DLOG(INFO) << "Trying to vectorize loop";
for_node = VectorizeLoop(for_node, saved_analyzer.get());
}

4 changes: 4 additions & 0 deletions src/transform/loop_vectorize.cc
@@ -152,6 +152,10 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer {
} else if (node->op == builtin::call_extern()) {
// do not vectorize extern calls
vector_size_ = 1;
} else if (node->op.same_as(tl::rng_rand()) ||
node->op.same_as(tl::rng_init())) {
// do not vectorize stateful RNG calls: each one mutates the per-thread curand state
vector_size_ = 1;
}
return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
}
35 changes: 35 additions & 0 deletions testing/python/language/test_rand.py
@@ -0,0 +1,35 @@
import tilelang
import tilelang.language as T # noqa: N812
import torch
import triton
import triton.language as tl


@tilelang.jit
def tilelang_rand_1d(M=1024, seed=42):
    blk_M = 128
    num_threads = 128

    @T.prim_func
    def rand_kernel(A: T.Tensor((M,), "uint32")):
        with T.Kernel(M // blk_M, threads=num_threads) as bx:
            T.rng_init(seed)
            for i in T.Parallel(blk_M):
                A[bx * blk_M + i] = T.rng_rand()

    return rand_kernel


@triton.jit
def triton_rand_1d(X, M: tl.constexpr, seed):
    # M must be constexpr: tl.arange requires compile-time bounds.
    pid = tl.program_id(0)
    offset = pid * M + tl.arange(0, M)
    rand = tl.randint(seed, offset)
    tl.store(X + offset, rand, mask=offset < M)


if __name__ == "__main__":
    M = 1024
    kernel = tilelang_rand_1d()
    x = torch.empty(M, dtype=torch.uint32, device="cuda")
    kernel(x)
    # Exercise the Triton reference as well; tl.randint and cuRAND Philox
    # are different generators, so this is a smoke test, not a bitwise
    # comparison.
    y = torch.empty(M, dtype=torch.int32, device="cuda")
    triton_rand_1d[(1,)](y, M, 42)
5 changes: 5 additions & 0 deletions tilelang/language/__init__.py
@@ -110,6 +110,11 @@
    annotate_l2_hit_ratio,
)

from .random import (
    rng_init,  # noqa: F401
    rng_rand,  # noqa: F401
)


def import_source(source: str | None = None):
    # source is the source code to be imported
44 changes: 44 additions & 0 deletions tilelang/language/random.py
@@ -0,0 +1,44 @@
from tvm import tir
import tilelang.language as T


# https://docs.nvidia.com/cuda/curand/device-api-overview.html#device-api-overview
def rng_init(seed, seq=None, off=0):
    """Initialize the CUDA curand (Philox) random number generator state.

    Parameters
    ----------
    seed : PrimExpr
        Random seed value.
    seq : PrimExpr, optional
        Subsequence number selecting an independent random stream.
        Defaults to the global thread index, so each thread gets its own
        uncorrelated stream.
    off : PrimExpr
        Offset into the selected stream.

    Returns
    -------
    call : PrimExpr
        A void intrinsic call; the generator state itself is declared and
        tracked by the CUDA codegen, not returned as a handle.
    """
    seed = tir.convert(seed)
    if seq is None:
        # Default subsequence: the global thread index, tx + bx * extent.
        bx = T.get_block_binding()
        ex = T.kernel.get_thread_extent()
        tx = T.get_thread_binding()
        seq = tir.convert(tx + bx * ex)
    else:
        seq = tir.convert(seq)
    off = tir.convert(off)
    return tir.call_intrin("void", tir.op.Op.get("tl.rng_init"), seed, seq, off)


def rng_rand():
    """Generate a 32-bit unsigned random integer.

    Returns
    -------
    random_value : PrimExpr
        A 32-bit unsigned random integer drawn from the state set up by
        rng_init.
    """
    return tir.call_intrin("uint32", tir.op.Op.get("tl.rng_rand"))
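As background for the default seq above (a sketch based on the cuRAND device API docs linked in random.py, not code from this PR): curand_init places each (seed, subsequence) pair on an independent Philox stream, so seeding every thread with the same seed but its own global thread index as the subsequence yields uncorrelated per-thread streams:

#include <curand_kernel.h>

__global__ void per_thread_streams(unsigned int *out, unsigned long long seed) {
  int gid = threadIdx.x + blockIdx.x * blockDim.x;
  curandStatePhilox4_32_10_t state;
  // Same seed in every thread; the subsequence (gid) separates the streams,
  // and the offset (0) is the starting position within the stream.
  curand_init(seed, gid, 0, &state);
  out[gid] = curand(&state);  // 32-bit unsigned random integer
}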