7 changes: 6 additions & 1 deletion src/op/builtin.cc
@@ -101,12 +101,17 @@ TIR_DEFINE_TL_BUILTIN(ieee_frsqrt)
 TIR_DEFINE_TL_BUILTIN(ieee_fdiv).set_num_inputs(3).set_attr<TCallEffectKind>(
     "TCallEffectKind", Integer(CallEffectKind::kPure));
 
-TIR_DEFINE_TL_BUILTIN(rng_init).set_num_inputs(3).set_attr<TCallEffectKind>(
+TIR_DEFINE_TL_BUILTIN(rng_init).set_num_inputs(4).set_attr<TCallEffectKind>(
     "TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
 TIR_DEFINE_TL_BUILTIN(rng_rand).set_num_inputs(0).set_attr<TCallEffectKind>(
     "TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_TL_BUILTIN(rng_rand_float)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier)
     .set_num_inputs(-1)
     .set_attr<TCallEffectKind>("TCallEffectKind",
1 change: 1 addition & 0 deletions src/op/builtin.h
@@ -131,6 +131,7 @@ TVM_DLL const Op &ieee_fdiv();
 // random op
 TVM_DLL const Op &rng_init();
 TVM_DLL const Op &rng_rand();
+TVM_DLL const Op &rng_rand_float();
 
 /*!
  * \brief tvm intrinsics for TMADescriptor creation for tiled load
73 changes: 68 additions & 5 deletions src/target/codegen_cuda.cc
@@ -2651,18 +2651,28 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
        << PrintExpr(op->args[1]) << ")";
   } else if (op->op.same_as(tl::rng_init())) {
     this->need_curand_kernel_h_ = true;
-    this->curand_philox_state = name_supply_->FreshName("__philox_state");
+    this->curand_random_generator_state =
+        name_supply_->FreshName("__random_generator_state");
+    this->curand_random_generator_state_type =
+        op->args[3].as<StringImmNode>()->value;
     this->PrintIndent();
-    this->stream << "curandStatePhilox4_32_10_t " << this->curand_philox_state
-                 << ";\n";
+    this->stream << op->args[3].as<StringImmNode>()->value << " "
+                 << this->curand_random_generator_state << ";\n";
     this->PrintIndent();
     this->stream << "curand_init(" << PrintExpr(op->args[0]) << ", "
                  << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2])
-                 << ", &" << this->curand_philox_state << ");\n";
+                 << ", &" << this->curand_random_generator_state << ");\n";
     // Store state_var for later use by rng_rand
   } else if (op->op.same_as(tl::rng_rand())) {
     this->need_curand_kernel_h_ = true;
-    os << "curand(&" << this->curand_philox_state << ")";
+    os << "curand(&" << this->curand_random_generator_state << ")";
+  } else if (op->op.same_as(tl::rng_rand_float())) {
+    this->need_curand_kernel_h_ = true;
+    os << "curand_" << op->args[0].as<StringImmNode>()->value;
+    if (op->dtype.bits() == 64) {
+      os << "_double";
+    }
+    os << "(&" << this->curand_random_generator_state << ")";
   } else if (op->op.same_as(tl::warp_reduce_sum())) {
     os << "tl::warp_reduce_sum(" << PrintExpr(op->args[0]) << ")";
   } else if (op->op.same_as(tl::warp_reduce_max())) {
@@ -3112,6 +3122,59 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
     }
   }
 
+  if (auto call = op->value.as<CallNode>()) {
+    if (this->curand_random_generator_state_type ==
+        "curandStatePhilox4_32_10_t") {
+      if (call->op.same_as(tl::rng_rand()) && lanes == 4) {
+        os << "curand4(&" << this->curand_random_generator_state << ")";
+        return;
+      }
+      if (call->op.same_as(tl::rng_rand_float())) {
+        int bits = call->dtype.bits();
+        std::string dist = call->args[0].as<StringImmNode>()->value;
+        if (bits == 32) {
+          if (lanes == 4) {
+            os << "curand_" << dist << "4(&"
+               << this->curand_random_generator_state << ")";
+            return;
+          } else if (lanes == 2 && dist == "normal") {
+            os << "curand_normal2(&" << this->curand_random_generator_state
+               << ")";
+            return;
+          }
+
+        } else {
+          if (lanes == 2) {
+            os << "curand_" << dist << "2_double(&"
+               << this->curand_random_generator_state << ")";
+            return;
+          }
+        }
+      }
+    } else if (this->curand_random_generator_state_type ==
+                   "curandStateMRG32k3a_t" ||
+               this->curand_random_generator_state_type ==
+                   "curandStateXORWOW_t") {
+      if (call->op.same_as(tl::rng_rand_float())) {
+        int bits = call->dtype.bits();
+        std::string dist = call->args[0].as<StringImmNode>()->value;
+        if (bits == 32) {
+          if (lanes == 2 && dist == "normal") {
+            os << "curand_normal2(&" << this->curand_random_generator_state
+               << ")";
+            return;
+          }
+        } else {
+          if (lanes == 2 && dist == "normal") {
+            os << "curand_normal2_double(&"
+               << this->curand_random_generator_state << ")";
+            return;
+          }
+        }
+      }
+    }
+  }
+
   std::string v = PrintExpr(op->value);
   os << "make_";
   PrintType(op->dtype, os);
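For orientation, the sketch below shows roughly what device code the new branches are expected to emit for T.rng_init / T.rng_rand / T.rng_rand_float (scalar paths) and for the vectorized Philox path handled in the BroadcastNode visitor. This is an illustrative sketch, not the actual generated output: the kernel name, launch shape, buffer names, and the __rng_state identifier are made up here (the real codegen draws a fresh name via name_supply_->FreshName), and the seed/sequence/offset values are arbitrary.

#include <cstdio>
#include <curand_kernel.h>

__global__ void rand_sketch(unsigned int *a, float *b, double *c, float4 *d) {
  // T.rng_init(seed, seq, off, generator="curandStatePhilox4_32_10_t")
  // declares a state of the requested generator type and calls curand_init.
  curandStatePhilox4_32_10_t __rng_state;
  curand_init(/*seed=*/42ULL, /*subsequence=*/blockIdx.x,
              /*offset=*/threadIdx.x, &__rng_state);

  int i = blockIdx.x * blockDim.x + threadIdx.x;
  a[i] = curand(&__rng_state);                // T.rng_rand()            -> uint32
  b[i] = curand_uniform(&__rng_state);        // T.rng_rand_float()      -> float32
  c[i] = curand_normal_double(&__rng_state);  // T.rng_rand_float(bit=64, dist="normal")
  // When loop_vectorize keeps a 4-lane rng_rand_float and the generator is
  // Philox, the BroadcastNode path emits the packed variant instead:
  d[i] = curand_uniform4(&__rng_state);
}

int main() {
  unsigned int *a;
  float *b;
  double *c;
  float4 *d;
  cudaMalloc(&a, 128 * sizeof(unsigned int));
  cudaMalloc(&b, 128 * sizeof(float));
  cudaMalloc(&c, 128 * sizeof(double));
  cudaMalloc(&d, 128 * sizeof(float4));
  rand_sketch<<<1, 128>>>(a, b, c, d);
  cudaDeviceSynchronize();
  printf("rand_sketch finished\n");
  return 0;
}

The scalar names mirror how the codegen assembles the call: "curand_" + dist, with "_double" appended for 64-bit results; per the broadcast branch above, MRG32k3a and XORWOW states only get the packed normal2 / normal2_double forms, while Philox also gets curand4 and the 4-lane uniform/normal variants.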
3 changes: 2 additions & 1 deletion src/target/codegen_cuda.h
@@ -89,7 +89,8 @@ class CodeGenTileLangCUDA final : public CodeGenC {
   // Global barrier expected node.
   std::string vid_global_barrier_expect_;
   // Global curand state
-  std::string curand_philox_state;
+  std::string curand_random_generator_state;
+  std::string curand_random_generator_state_type;
 
   // whether enable fp16
   bool enable_fp16_{false};
3 changes: 1 addition & 2 deletions src/transform/loop_vectorize.cc
@@ -152,8 +152,7 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer {
     } else if (node->op == builtin::call_extern()) {
       // do not vectorize extern calls
       vector_size_ = 1;
-    } else if (node->op.same_as(tl::rng_rand()) ||
-               node->op.same_as(tl::rng_init())) {
+    } else if (node->op.same_as(tl::rng_init())) {
       // do not vectorize random operation
       vector_size_ = 1;
     }
49 changes: 41 additions & 8 deletions testing/python/language/test_tilelang_language_rand.py
@@ -6,32 +6,65 @@
 
 
 @tilelang.jit
-def tilelang_rand_1d(M=1024, seed=42):
+def tilelang_rand_1d(M=1024, seed=42, generator="curandStatePhilox4_32_10_t"):
     num_per_thread = 128
     threads = 1
     blk_M = num_per_thread * threads
 
     @T.prim_func
-    def rand_kernel(A: T.Tensor((M,), "uint32")):
+    def rand_kernel(
+        A: T.Tensor((M,), "uint32"),
+        B: T.Tensor((M,), "float32"),
+        C: T.Tensor((M,), "float64"),
+        D: T.Tensor((M,), "float32"),
+        E: T.Tensor((M,), "float64"),
+    ):
         with T.Kernel(T.ceildiv(M, threads * num_per_thread), threads=threads) as bx:
             tx = T.get_thread_binding()
-            T.rng_init(seed, 0, bx * blk_M + tx * num_per_thread)
+            T.rng_init(seed, 0, bx * blk_M + tx * num_per_thread, generator=generator)
             for i, j in T.Parallel(threads, num_per_thread):
                 offsets = (bx * threads + i) * num_per_thread
                 idx = offsets + j
                 if idx < M:
                     A[idx] = T.rng_rand()
+            for i, j in T.Parallel(threads, num_per_thread):
+                offsets = (bx * threads + i) * num_per_thread
+                idx = offsets + j
+                if idx < M:
+                    B[idx] = T.rng_rand_float()
+            for i, j in T.Parallel(threads, num_per_thread):
+                offsets = (bx * threads + i) * num_per_thread
+                idx = offsets + j
+                if idx < M:
+                    C[idx] = T.rng_rand_float(bit=64)
+            for i, j in T.Parallel(threads, num_per_thread):
+                offsets = (bx * threads + i) * num_per_thread
+                idx = offsets + j
+                if idx < M:
+                    D[idx] = T.rng_rand_float(dist="normal")
+            for i, j in T.Parallel(threads, num_per_thread):
+                offsets = (bx * threads + i) * num_per_thread
+                idx = offsets + j
+                if idx < M:
+                    E[idx] = T.rng_rand_float(bit=64, dist="normal")
 
     return rand_kernel
 
 
 @tilelang.testing.requires_cuda
-@pytest.mark.parametrize("M, seed", [(1024, 42), (512, 123), (128, 0)])
-def test_rand_1d(M, seed):
-    kernel = tilelang_rand_1d(M, seed)
-    tilelang_result = torch.empty(M, dtype=torch.uint32, device="cuda")
-    kernel(tilelang_result)
+@pytest.mark.parametrize(
+    "M, seed, generator", [(1024, 42, "curandStateMRG32k3a_t"), (512, 123, "curandStatePhilox4_32_10_t"), (128, 0, "curandStateXORWOW_t")]
+)
+def test_rand_1d(M, seed, generator):
+    kernel = tilelang_rand_1d(M, seed, generator)
+    A = torch.empty(M, dtype=torch.uint32, device="cuda")
+    B = torch.empty(M, dtype=torch.float32, device="cuda")
+    C = torch.empty(M, dtype=torch.float64, device="cuda")
+    D = torch.empty(M, dtype=torch.float32, device="cuda")
+    E = torch.empty(M, dtype=torch.float64, device="cuda")
+    kernel(A, B, C, D, E)
 
 
 if __name__ == "__main__":
     tilelang.testing.main()
+    # test_rand_1d(1024, 42, "curandStateMRG32k3a_t")
1 change: 1 addition & 0 deletions tilelang/language/__init__.py
@@ -105,6 +105,7 @@
 from .random import (
     rng_init, # noqa: F401
     rng_rand, # noqa: F401
+    rng_rand_float, # noqa: F401
 )
 
 
28 changes: 26 additions & 2 deletions tilelang/language/random.py
@@ -3,7 +3,7 @@
 
 
 # https://docs.nvidia.com/cuda/curand/device-api-overview.html#device-api-overview
-def rng_init(seed, seq=None, off=0):
+def rng_init(seed, seq=None, off=0, generator="curandStatePhilox4_32_10_t"):
     """Initialize CUDA curand random number generator state
 
     Parameters
@@ -14,12 +14,16 @@ def rng_init(seed, seq=None, off=0):
         Sequence number for parallel random number generation.
     off : PrimExpr
         Offset number for parallel random number generation.
+    generator : StringImm
+        Set random generator.
+        See https://docs.nvidia.com/cuda/curand/group__DEVICE.html
Comment on lines +17 to +19 (inline review comment from a Contributor):

⚠️ Potential issue | 🟡 Minor

Docstring type annotation is inaccurate.

The parameter type is listed as StringImm, but the function accepts a plain Python str. Consider updating to str for clarity.

🔎 Proposed fix

-    generator : StringImm
-        Set random generator.
-        See https://docs.nvidia.com/cuda/curand/group__DEVICE.html
+    generator : str
+        The CURAND generator type to use.
+        See https://docs.nvidia.com/cuda/curand/group__DEVICE.html

     Returns
     -------
     state : PrimExpr
         The random number generator state handle.
     """
+    assert generator in ["curandStateMRG32k3a_t", "curandStatePhilox4_32_10_t", "curandStateXORWOW_t"]
     seed = tir.convert(seed)
     if seq is None:
         bx = T.get_block_binding()
@@ -30,7 +34,7 @@ def rng_init(seed, seq=None, off=0):
     else:
         seq = tir.convert(seq)
     off = tir.convert(off)
-    return tir.call_intrin("void", tir.op.Op.get("tl.rng_init"), seed, seq, off)
+    return tir.call_intrin("void", tir.op.Op.get("tl.rng_init"), seed, seq, off, generator)
 
 
 def rng_rand():
@@ -42,3 +46,23 @@ def rng_rand():
         A 32-bit unsigned random integer.
     """
     return tir.call_intrin("uint32", tir.op.Op.get("tl.rng_rand"))
+
+
+def rng_rand_float(bit=32, dist="uniform"):
+    """Generate a random float
+
+    Parameters
+    ----------
+    bit : int = [32, 64]
+        Bitwidth of random float.
+    dist : StringImm = ["uniform", "normal"]
+        Random distribution.
+
+    Returns
+    -------
+    random_value : PrimExpr
+        A random float.
+    """
+    assert bit in [32, 64]
+    assert dist in ["uniform", "normal"]
+    return tir.call_intrin("float" + str(bit), tir.op.Op.get("tl.rng_rand_float"), dist)