-
Notifications
You must be signed in to change notification settings - Fork 446
[Language]Adds a random number generation capability through curand_kernel #1461
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
6247540
add curand.{curand_init, curand}
silentCoder-dev 9e52af1
run format.sh
silentCoder-dev 7b82270
add default value for curand_init & add test for curand
silentCoder-dev 4624c2d
Update testing/python/language/test_rand.py
silentCoder-dev 568ad20
remove unused library
silentCoder-dev c277a43
enable tilelang cache for testing
silentCoder-dev 5afaff7
run format.sh
silentCoder-dev c435759
Revert "run format.sh"
silentCoder-dev 28116d0
Revert "enable tilelang cache for testing"
silentCoder-dev 64c8554
Revert "remove unused library"
silentCoder-dev 2976ade
run format.sh
silentCoder-dev 8417356
ensure FreshName for __philox_state
silentCoder-dev c0e2141
ensure FreshName for __philox_state
silentCoder-dev 611ef11
change the return type of T.rng_init
silentCoder-dev File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| import tilelang | ||
| import tilelang.language as T # noqa: N812 | ||
| import torch | ||
| import triton | ||
| import triton.language as tl | ||
|
|
||
|
|
||
| @tilelang.jit | ||
| def tilelang_rand_1d(M=1024, seed=42): | ||
| blk_M = 128 | ||
| num_threads = 128 | ||
|
|
||
| @T.prim_func | ||
| def rand_kernel(A: T.Tensor((M,), "uint32")): | ||
| with T.Kernel(M // blk_M, threads=num_threads) as bx: | ||
| T.rng_init(seed) | ||
| for i in T.Parallel(blk_M): | ||
| A[bx * blk_M + i] = T.rng_rand() | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| return rand_kernel | ||
|
|
||
|
|
||
| @triton.jit | ||
| def triton_rand_1d(X, M, seed): | ||
| pid = tl.program_id(0) | ||
| offset = pid * M + tl.arange(0, M) | ||
| rand = tl.randint(seed, offset) | ||
| tl.store(X + offset, rand, mask=offset < M) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| M = 1024 | ||
| kernel = tilelang_rand_1d() | ||
| x = torch.empty(M, dtype=torch.uint32, device="cuda") | ||
| kernel(x) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| from tvm import tir | ||
| import tilelang.language as T | ||
|
|
||
|
|
||
| # https://docs.nvidia.com/cuda/curand/device-api-overview.html#device-api-overview | ||
| def rng_init(seed, seq=None, off=0): | ||
| """Initialize CUDA curand random number generator state | ||
|
|
||
| Parameters | ||
| ---------- | ||
| seed : PrimExpr | ||
| Random seed value. | ||
| seq : PrimExpr | ||
| Sequence number for parallel random number generation. | ||
| off : PrimExpr | ||
| Offset number for parallel random number generation. | ||
|
|
||
| Returns | ||
| ------- | ||
| state : PrimExpr | ||
| The random number generator state handle. | ||
| """ | ||
| seed = tir.convert(seed) | ||
| if seq is None: | ||
| bx = T.get_block_binding() | ||
| ex = T.kernel.get_thread_extent() | ||
| tx = T.get_thread_binding() | ||
| id = tx + bx * ex | ||
| seq = tir.convert(id) | ||
| else: | ||
| seq = tir.convert(seq) | ||
| off = tir.convert(off) | ||
| return tir.call_intrin("void", tir.op.Op.get("tl.rng_init"), seed, seq, off) | ||
|
|
||
|
|
||
| def rng_rand(): | ||
| """Generate a 32-bit unsigned random integer | ||
|
|
||
| Returns | ||
| ------- | ||
| random_value : PrimExpr | ||
| A 32-bit unsigned random integer. | ||
| """ | ||
| return tir.call_intrin("uint32", tir.op.Op.get("tl.rng_rand")) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.