forked from triton-lang/triton
[CPU] Support flexible active driver + update vector-add tutorial #11
Merged: minjang merged 3 commits into triton-lang:main from minjang:update-tutorial-with-flexible-driver on May 31, 2024.
Changes from all commits (3 commits) to the vector-add tutorial:

```diff
@@ -23,6 +23,8 @@
 import triton
 import triton.language as tl
 
+BLOCK_SIZE = 1024
+
 
 @triton.jit
 def add_kernel(x_ptr,  # *Pointer* to first input vector.
```
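Hoisting `BLOCK_SIZE` to a module-level constant lets the kernel launch and the plot title further down share a single value. As a hedged sketch of how the constant feeds the launch (the `grid` lambda and launch line follow the upstream tutorial; it assumes the tutorial's `add_kernel` is in scope):

```python
import torch
import triton

BLOCK_SIZE = 1024  # module-level constant, also referenced by the plot title

x = torch.rand(98432, device='cpu')
y = torch.rand(98432, device='cpu')
output = torch.empty_like(x)
n_elements = output.numel()

# The launch grid rounds n_elements up to a whole number of blocks,
# exactly as the upstream tutorial does inside add().
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=BLOCK_SIZE)
```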
```diff
@@ -57,10 +59,10 @@ def add_kernel(x_ptr,  # *Pointer* to first input vector.
 # and (2) enqueue the above kernel with appropriate grid/block sizes:
 
 
-def add(x: torch.Tensor, y: torch.Tensor):
+def add(x: torch.Tensor, y: torch.Tensor, is_cpu):
     # We need to preallocate the output.
     output = torch.empty_like(x)
-    assert x.is_cuda and y.is_cuda and output.is_cuda
+    assert x.is_cpu == is_cpu and y.is_cpu == is_cpu and output.is_cpu == is_cpu
     n_elements = output.numel()
     # The SPMD launch grid denotes the number of kernel instances that run in parallel.
     # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
```
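For readers skimming the diff, a short hedged sketch of how the wrapper is exercised under the new signature (it assumes the tutorial's `add` is in scope and uses this fork's driver API, which appears in the hunks below):

```python
import torch
import triton

# Select the CPU driver before launching (API added by this fork).
triton.runtime.driver.set_active_to_cpu()

x = torch.rand(98432, device='cpu')
y = torch.rand(98432, device='cpu')
# is_cpu=True makes the wrapper assert that x, y, and the output
# are all CPU tensors before enqueueing the kernel.
out = add(x, y, is_cpu=True)
```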
```diff
@@ -78,17 +80,37 @@ def add(x: torch.Tensor, y: torch.Tensor):
 
 # %%
 # We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:
 
 torch.manual_seed(0)
 size = 98432
-x = torch.rand(size, device='cuda')
-y = torch.rand(size, device='cuda')
-output_torch = x + y
-output_triton = add(x, y)
-print(output_torch)
-print(output_triton)
-print(f'The maximum difference between torch and triton is '
-      f'{torch.max(torch.abs(output_torch - output_triton))}')
+triton.runtime.driver.set_active_to_cpu()
+x = torch.rand(size, device='cpu')
+y = torch.rand(size, device='cpu')
+output_torch_cpu = x + y
+output_triton_cpu = add(x, y, is_cpu=True)
+print(output_torch_cpu)
+print(output_triton_cpu)
+print(f'The maximum difference between torch-cpu and triton-cpu is '
+      f'{torch.max(torch.abs(output_torch_cpu - output_triton_cpu))}')
+
+LINE_VALS = ['triton-cpu', 'torch-cpu']
+LINE_NAMES = ['TritonCPU', 'TorchCPU']
+LINE_STYLES = [('blue', '-'), ('green', '-')]
+
+if triton.runtime.driver.get_active_gpus():
+    triton.runtime.driver.set_active_to_gpu()
+    x = x.to('cuda')
+    y = y.to('cuda')
+    output_torch_gpu = x + y
+    output_triton_gpu = add(x, y, is_cpu=False)
+    print(output_torch_gpu)
+    print(output_triton_gpu)
+    print(f'The maximum difference between torch-gpu and triton-gpu is '
+          f'{torch.max(torch.abs(output_torch_gpu - output_triton_gpu))}')
+
+    LINE_VALS += ['triton-gpu', 'torch-gpu']
+    LINE_NAMES += ['TritonGPU', 'TorchGPU']
+    LINE_STYLES += [('yellow', '-'), ('red', '-')]
 
 # %%
 # Seems like we're good to go!
```
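The shape of this hunk — always run the CPU pass, then opportunistically repeat on GPU — is the core of the flexible-active-driver change. A minimal sketch of the pattern, using only driver calls that appear in this diff (`check` is a hypothetical stand-in for the correctness block above):

```python
import triton

def check(device):
    # Hypothetical stand-in for the correctness block above: build
    # tensors on `device`, run add(), and compare against torch.
    ...

triton.runtime.driver.set_active_to_cpu()  # the CPU pass always runs
check('cpu')

# get_active_gpus() is falsy on a machine with no usable GPU, so the
# GPU pass (and its plot lines) is skipped cleanly.
if triton.runtime.driver.get_active_gpus():
    triton.runtime.driver.set_active_to_gpu()
    check('cuda')
```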
```diff
@@ -108,21 +130,34 @@ def add(x: torch.Tensor, y: torch.Tensor):
         x_vals=[2**i for i in range(12, 28, 1)],  # Different possible values for `x_name`.
         x_log=True,  # x axis is logarithmic.
         line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
-        line_vals=['triton', 'torch'],  # Possible values for `line_arg`.
-        line_names=['Triton', 'Torch'],  # Label name for the lines.
-        styles=[('blue', '-'), ('green', '-')],  # Line styles.
+        line_vals=LINE_VALS,  # Possible values for `line_arg`.
+        line_names=LINE_NAMES,  # Label name for the lines.
+        styles=LINE_STYLES,  # Line styles.
         ylabel='GB/s',  # Label name for the y-axis.
-        plot_name='vector-add-performance',  # Name for the plot. Used also as a file name for saving the plot.
+        plot_name=
+        # Name for the plot. Used also as a file name for saving the plot.
+        f'vector-add-performance (BLOCK_SIZE={BLOCK_SIZE})',
         args={},  # Values for function arguments not in `x_names` and `y_name`.
     ))
 def benchmark(size, provider):
-    x = torch.rand(size, device='cuda', dtype=torch.float32)
-    y = torch.rand(size, device='cuda', dtype=torch.float32)
+    device = 'cpu' if 'cpu' in provider else 'cuda'
+    x = torch.rand(size, device=device, dtype=torch.float32)
+    y = torch.rand(size, device=device, dtype=torch.float32)
+
+    if device == 'cpu':
+        triton.runtime.driver.set_active_to_cpu()
+    else:
+        triton.runtime.driver.set_active_to_gpu()
+
     quantiles = [0.5, 0.2, 0.8]
-    if provider == 'torch':
+    if provider == 'torch-gpu':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
-    if provider == 'triton':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)
+    elif provider == 'triton-gpu':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, False), quantiles=quantiles)
+    elif provider == 'torch-cpu':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles, is_cpu=True)
+    elif provider == 'triton-cpu':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, True), quantiles=quantiles, is_cpu=True)
     gbps = lambda ms: 12 * size / ms * 1e-6
     return gbps(ms), gbps(max_ms), gbps(min_ms)
```

Review comment on the provider branches — minjang (Collaborator, Author): Tested on a computer without GPU. And it turned out
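The `gbps` lambda converts time to effective bandwidth: each element moves 12 bytes (two float32 reads plus one float32 write), so GB/s = 12 · size / (ms · 10^6). To actually produce the plot, the upstream tutorial ends by invoking the decorated benchmark; a hedged sketch of that call, since the closing lines are not part of this diff:

```python
# Sweeps `size` over x_vals and draws one line per provider in LINE_VALS;
# print_data also dumps the raw numbers to stdout.
benchmark.run(print_data=True, show_plots=True)
```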
Review comment: This was from my scratch work. Unused; removing.