
Conversation

int3 (Contributor) commented Aug 9, 2024

This makes the autotuner device-agnostic. Instead of having to know about the existence of e.g. do_bench_cudagraph, it can let the callers decide which backend-specific benchmarking function to use.

See discussion in #4417.
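
To make the change concrete, here is a minimal sketch of the resulting API, assuming the `do_bench` keyword on `triton.autotune` introduced by this PR (note that forwarding it through the decorator was only fixed later, in #5992 below); `bench_fn` and `add_kernel` are illustrative names:

```python
import torch
import triton
import triton.language as tl
from triton.testing import do_bench, do_bench_cudagraph

# The caller, not the autotuner, decides how to benchmark: CUDA-graph
# timing on a CUDA device, the plain benchmarker otherwise.
bench_fn = do_bench_cudagraph if torch.cuda.is_available() else do_bench

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 256}, num_warps=8),
    ],
    key=["n_elements"],
    do_bench=bench_fn,  # backend-specific benchmarking, chosen by the caller
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)
```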

int3 (Contributor, Author) commented Aug 9, 2024

This depends on #4392 landing first; otherwise, cudagraph benchmarking will not work.

int3 (Contributor, Author) commented Aug 22, 2024

Bump -- just rebased

int3 (Contributor, Author) commented Sep 5, 2024

Rebased

@Jokeren changed the title from "Make autotuner take do_bench as a parameter" to "[AUTOTUNER] Make autotuner take do_bench as a parameter" on Oct 7, 2024
@Jokeren merged commit ab07e54 into triton-lang:main on Oct 7, 2024
sfzhu93 pushed a commit to sfzhu93/triton that referenced this pull request Oct 11, 2024
…g#4496)

This makes the autotuner device-agnostic. Instead of having to know
about the existence of e.g. do_bench_cudagraph, it can let the callers
decide which backend-specific benchmarking function to use.

See discussion in triton-lang#4417.

---------

Co-authored-by: Keren Zhou <[email protected]>
Review comment on removed lines 91 to 94:
self.num_warmups = warmup
self.num_reps = rep
import torch
self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()
Contributor

Hi @Jokeren, @int3,

Fields `self.num_warmups`, `self.num_reps`, and `self.use_cuda_graph` are used by PyTorch to find out which parameters the autotuner was called with:

https://github.com/pytorch/pytorch/blame/5141ade8e30c64e873e14dcc8de233da45d15025/torch/_higher_order_ops/triton_kernel_wrap.py#L829

Can they be left in place until the corresponding parameters are removed from the `__init__` signature?
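
For reference, here is a minimal sketch of the compromise this question suggests (keeping the deprecated fields alive alongside the new parameter), with a simplified signature rather than the merged code:

```python
import torch

class Autotuner:
    # Illustrative sketch of only the relevant __init__ fragment.
    def __init__(self, warmup=25, rep=100, use_cuda_graph=False, do_bench=None):
        # Deprecated fields, retained so external tools such as PyTorch's
        # triton_kernel_wrap can still read the parameters the autotuner
        # was constructed with.
        self.num_warmups = warmup
        self.num_reps = rep
        self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()
        # New: caller-supplied benchmarking function.
        self.do_bench = do_bench
```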

Contributor

@int3 is driving the effort. It's up to him. I'm OK either way.

minjang added a commit that referenced this pull request Oct 23, 2024
…4974)

This is a quick follow-up to the recent autotuner/testing changes in #4496. This PR moves the empty-cache creation into the driver code to make it more device-independent.
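
A hedged sketch of the idea, assuming the driver hook is called `get_empty_cache_for_benchmark` (the name is an inference from this PR's description; see `triton/runtime/driver.py` for the real API):

```python
from triton.runtime import driver

def bench_with_cold_cache(kernel_call):
    # Device-specific allocation now lives behind the active driver, so this
    # code no longer needs to import torch or know about CUDA.
    cache = driver.active.get_empty_cache_for_benchmark()  # assumed hook name
    cache.zero_()  # write the buffer to evict kernel data from the L2 cache
    kernel_call()
```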
Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024
…g#4496)
Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024
…riton-lang#4974)
guacamoleo pushed a commit to guacamoleo/triton that referenced this pull request Nov 14, 2024
…g#4496)
guacamoleo pushed a commit to guacamoleo/triton that referenced this pull request Nov 14, 2024
…riton-lang#4974)
bertmaher pushed a commit to bertmaher/triton that referenced this pull request Dec 10, 2024
…g#4496)
bertmaher pushed a commit to bertmaher/triton that referenced this pull request Dec 10, 2024
…riton-lang#4974)
ThomasRaoux pushed a commit that referenced this pull request Feb 22, 2025
#5992)


# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a comment.

- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [x] This PR does not need a test because `Previous PR has introduced a test`.

- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
### Description
Related PR: #4496
In the `autotune` decorator, the `do_bench` parameter was not forwarded to the `Autotuner` constructor, so a caller-supplied `do_bench` was silently ignored in favor of the default. This PR fixes the issue and ensures that the `do_bench` parameter is passed through correctly.

This way, the `do_bench` parameter can be used in place of the deprecated `use_cuda_graph` parameter.
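
A self-contained sketch of the bug class being fixed, with simplified names rather than the actual diff: the decorator accepted `do_bench` but never forwarded it, so the `Autotuner` always fell back to its default.

```python
def default_bench(kernel_call, **kwargs):
    # Stand-in for triton.testing.do_bench.
    kernel_call()
    return 0.0

class Autotuner:
    def __init__(self, fn, configs, key, do_bench=None):
        self.fn, self.configs, self.key = fn, configs, key
        self.do_bench = do_bench if do_bench is not None else default_bench

def autotune(configs, key, do_bench=None):
    def decorator(fn):
        # The bug: `do_bench` was accepted above but dropped here, so the
        # Autotuner silently used its default. The fix forwards it:
        return Autotuner(fn, configs, key, do_bench=do_bench)
    return decorator
```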
loislo pushed a commit to openxla/triton that referenced this pull request Mar 4, 2025
triton-lang#5992)
loislo pushed a commit to openxla/triton that referenced this pull request Mar 4, 2025
triton-lang#5992)
loislo pushed a commit to openxla/triton that referenced this pull request Mar 4, 2025
triton-lang#5992)
Jokeren pushed a commit that referenced this pull request Apr 24, 2025

# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a comment.

- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.

- Select one of the following.
  - [ ] I have not added any `lit` tests.
  - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)

Fixes #6150.

When running on a CPU host, `triton.autotune()` throws an error:
```
RuntimeError: 0 active drivers ([]). There should only be one.
```
This issue was introduced by #4496, which forces the caller to specify `do_bench`. But this may not be easy in a large codebase.

Default `do_bench` to `triton.testing.do_bench` when there's no GPU. Add a unit test.
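
A hedged sketch of that fallback; `resolve_do_bench` is an illustrative helper, and `get_benchmarker` is an assumption about the driver API rather than a confirmed signature:

```python
def resolve_do_bench(user_do_bench):
    if user_do_bench is not None:
        return user_do_bench  # the caller made an explicit choice
    try:
        from triton.runtime import driver
        # On a CPU-only host, touching the active driver raises
        # "RuntimeError: 0 active drivers ([]). There should only be one."
        return driver.active.get_benchmarker()  # assumed driver hook
    except RuntimeError:
        from triton.testing import do_bench
        return do_bench  # CPU-safe default
```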
FindHao pushed a commit to FindHao/triton that referenced this pull request Apr 30, 2025
zhzhcookie added a commit to flagos-ai/FlagGems that referenced this pull request Jul 1, 2025
…n 3.1 (#726)

Update default parameters of LibTuner to adapt to versions after Triton 3.1.
This makes the autotuner device-agnostic. Instead of having to know about the existence of e.g. do_bench_cudagraph, it can let the callers decide which backend-specific benchmarking function to use.
See discussion in triton-lang/triton#4496.
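
An illustrative sketch of such an adaptation (names assumed; this is not FlagGems' actual code), selecting autotuner keyword arguments by installed Triton version:

```python
import triton

def tuner_kwargs():
    major, minor = (int(x) for x in triton.__version__.split(".")[:2])
    if (major, minor) > (3, 1):
        from triton.testing import do_bench
        return {"do_bench": do_bench}  # post-3.1 interface
    # Legacy interface, deprecated after this PR.
    return {"warmup": 25, "rep": 100, "use_cuda_graph": False}
```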
liuyunqi20 pushed a commit to flagos-ai/flagtree that referenced this pull request Oct 21, 2025
…4974)

meta-codesync bot pushed a commit to facebookexperimental/triton that referenced this pull request Nov 5, 2025
Summary:
- Let pytest collect and run everything under the folder directly for dense output
- Skip the AMD test when not on an AMD GPU

`third_party/tlx/run_all.sh` now skips `third_party/tlx/tutorials/amd-gemm-pipelined.py` on NVIDIA GPUs, as tested locally:
```
% third_party/tlx/run_all.sh
Hello! (Facebook-only)
Need to build triton in this script? {y|n}n
Run all LITs? {y|n}n
Run core Triton python unit tests? {y|n}n
Run all TLX unit tests? {y|n}n
Run TLX tutorial kernels (correctness|performance|no)? {c|p|n}
c
Verifying correctness of TLX tutorial kernels
============================================================================================ test session starts ============================================================================================
platform linux -- Python 3.11.13, pytest-8.3.4, pluggy-1.5.0
rootdir: /data/users/pchen7e4/triton
configfile: pyproject.toml
plugins: xdist-3.7.0, forked-1.6.0, typeguard-4.3.0
collected 17 items

third_party/tlx/tutorials/amd-gemm-pipelined.py s                                                                                                                                                     [  5%]
third_party/tlx/tutorials/blackwell-fa-ws-persistent_test.py .                                                                                                                                        [ 11%]
third_party/tlx/tutorials/blackwell-fa-ws-pipelined-persistent_test.py .                                                                                                                              [ 17%]
third_party/tlx/tutorials/blackwell-fa-ws-pipelined_test.py .                                                                                                                                         [ 23%]
third_party/tlx/tutorials/blackwell-fa-ws_test.py .                                                                                                                                                   [ 29%]
third_party/tlx/tutorials/blackwell-gemm-clc.py .                                                                                                                                                     [ 35%]
third_party/tlx/tutorials/blackwell-gemm-pipelined.py .                                                                                                                                               [ 41%]
third_party/tlx/tutorials/blackwell-gemm-ws.py .                                                                                                                                                      [ 47%]
third_party/tlx/tutorials/blackwell-grouped-gemm.py .                                                                                                                                                 [ 52%]
third_party/tlx/tutorials/hopper-fa-ws-pipelined-pingpong_test.py s                                                                                                                                   [ 58%]
third_party/tlx/tutorials/hopper-fa-ws-pipelined_test.py s                                                                                                                                            [ 64%]
third_party/tlx/tutorials/hopper-fa-ws_test.py s                                                                                                                                                      [ 70%]
third_party/tlx/tutorials/hopper-gemm-pipelined_test.py s                                                                                                                                             [ 76%]
third_party/tlx/tutorials/hopper-gemm-ws_test.py s                                                                                                                                                    [ 82%]
third_party/tlx/tutorials/hopper-persistent-gemm-ws-cooperative.py s                                                                                                                                  [ 88%]
third_party/tlx/tutorials/hopper-persistent-gemm-ws-pingpong.py s                                                                                                                                     [ 94%]
third_party/tlx/tutorials/vector-add2.py .                                                                                                                                                            [100%]

============================================================================================= warnings summary ==============================================================================================
python/triton/runtime/autotuner.py:99
python/triton/runtime/autotuner.py:99
python/triton/runtime/autotuner.py:99
  /data/users/pchen7e4/triton/python/triton/runtime/autotuner.py:99: DeprecationWarning: warmup, rep, and use_cuda_graph parameters are deprecated. See triton-lang/triton#4496 for details.
    warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "

third_party/tlx/tutorials/blackwell-fa-ws-pipelined-persistent_test.py::test_op[triton-fp16-bwd-128-1024-16-8]
  /data/users/pchen7e4/miniconda3/lib/python3.11/site-packages/torch/autograd/graph.py:824: UserWarning: Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context... (Triggered internally at /pytorch/aten/src/ATen/cuda/CublasHandlePool.cpp:181.)
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================================= 9 passed, 8 skipped, 4 warnings in 8.85s =====================
```

Pull Request resolved: #635

Reviewed By: htyu

Differential Revision: D86236535

Pulled By: pchen7e2

fbshipit-source-id: d17e708c39172e01351ec599cb927738236fbf87