Add support for setting a dot product "algorithm" for lax.dot_general. #23574

Closed

@copybara-service copybara-service bot commented Sep 11, 2024

Add support for setting a dot product "algorithm" for lax.dot_general.

The StableHLO spec has a new "algorithm" parameter that specifies the algorithm used to execute a matrix multiplication, which lets users tune the trade-off between numerical precision and computational cost. Historically, in JAX, the precision and preferred_element_type parameters have been used to expose some level of control, but their behavior is platform dependent and not sufficiently flexible for performance use cases. This change adds a new "algorithm" parameter to dot_general to add support for the new explicit API.

This parameter can be a member of the SupportedDotAlgorithm Enum to use an algorithm that is known to be supported on at least some hardware. Otherwise, it can be specified using the DotAlgorithm data structure which exposes the full generality of the StableHLO spec.

Transposition is supported using the transpose_algorithm argument.
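
For context, here is a minimal sketch of the historical controls mentioned above, `precision` and `preferred_element_type`, as they already exist on `lax.dot_general`. It does not use the new "algorithm" parameter, which requires a newer jaxlib. The shapes and variable names are illustrative assumptions, and how `precision` is honored depends on the platform.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative inputs (assumed shapes): a 4x8 by 8x2 matrix product in bfloat16.
lhs = jnp.ones((4, 8), dtype=jnp.bfloat16)
rhs = jnp.ones((8, 2), dtype=jnp.bfloat16)

# Contract axis 1 of lhs with axis 0 of rhs; no batch dimensions.
dims = (((1,), (0,)), ((), ()))

# Default behavior: the output dtype follows the input dtypes.
out_default = lax.dot_general(lhs, rhs, dims)

# Historical knobs: request a precision level and a wider output dtype.
out_f32 = lax.dot_general(
    lhs,
    rhs,
    dims,
    precision=lax.Precision.HIGHEST,
    preferred_element_type=jnp.float32,
)

print(out_default.dtype, out_f32.dtype)
```

Neither knob names the accumulation algorithm itself, which is the gap that the explicit "algorithm" setting is meant to close.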

@copybara-service copybara-service bot force-pushed the test_670551032 branch 6 times, most recently from 8f5dd81 to e7f6a6c Compare September 12, 2024 19:36
@copybara-service copybara-service bot changed the title Add support for setting the dot product "algorithm" for lax.dot_general. Add support for setting a dot product "algorithm" for lax.dot_general. Sep 12, 2024
@copybara-service copybara-service bot force-pushed the test_670551032 branch 4 times, most recently from 90198fc to d6bea09 Compare September 16, 2024 16:21
dfm added a commit to dfm/jax that referenced this pull request Sep 16, 2024
jax-ml#16721 added a condition to lower
calls to `jnp.dot` with scalar inputs to `lax.mul` instead of
`lax.dot_general`. AFAICT, jax-ml#16826
fixed the issue that this was solving, so this condition should no
longer be necessary. Removing this condition simplifies the addition of
new arguments to `dot` and `dot_general`, including the `algorithm`
parameter that I am currently working on in
jax-ml#23574, so now seemed like a good time
to remove it!
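
One way to check which primitive `jnp.dot` lowers to on a given JAX version is to inspect the traced jaxpr. This is a sketch; the helper name `jaxpr_text` is hypothetical, and the scalar case prints a different primitive depending on whether the installed version includes the change described above.

```python
import jax
import jax.numpy as jnp


def jaxpr_text(fn, *args):
    """Return the pretty-printed jaxpr traced from fn(*args). Hypothetical helper."""
    return str(jax.make_jaxpr(fn)(*args))


# Scalar inputs: older versions emit a multiply here, newer ones emit dot_general.
scalar_text = jaxpr_text(jnp.dot, 2.0, 3.0)

# Matrix inputs always go through dot_general.
matrix_text = jaxpr_text(jnp.dot, jnp.ones((2, 3)), jnp.ones((3, 4)))

scalar_result = jnp.dot(2.0, 3.0)

print(scalar_text)
print(matrix_text)
```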
keshavb96 added a commit to keshavb96/jax that referenced this pull request Sep 18, 2024
commit 093c6e9
Merge: e1a77ee d0cb318
Author: Keshav <[email protected]>
Date:   Wed Sep 18 13:37:44 2024 -0700

    Merge remote-tracking branch 'upstream/main' into disable_remat_pass

commit e1a77ee
Author: Keshav <[email protected]>
Date:   Wed Sep 18 13:35:37 2024 -0700

    minor changes

commit d0cb318
Merge: b51c653 bef36c4
Author: jax authors <[email protected]>
Date:   Wed Sep 18 13:34:11 2024 -0700

    Merge pull request jax-ml#23736 from hawkinsp:changelog

    PiperOrigin-RevId: 676111400

commit b51c653
Merge: dbc03cf 57a4b76
Author: jax authors <[email protected]>
Date:   Wed Sep 18 13:33:05 2024 -0700

    Merge pull request jax-ml#23737 from jakevdp:digitize-doc

    PiperOrigin-RevId: 676111220

commit dbc03cf
Author: Dan Foreman-Mackey <[email protected]>
Date:   Wed Sep 18 12:39:58 2024 -0700

    Re-land jax-ml#23261 with appropriate compatibility checks.

    PiperOrigin-RevId: 676092618

commit b164d67
Merge: cd04d0f 541b3a3
Author: jax authors <[email protected]>
Date:   Wed Sep 18 12:05:03 2024 -0700

    Merge pull request jax-ml#23247 from kaixih:sliding_window_attn

    PiperOrigin-RevId: 676079831

commit 57a4b76
Author: Jake VanderPlas <[email protected]>
Date:   Wed Sep 18 11:59:00 2024 -0700

    Improve documentation for jnp.digitize

commit bef36c4
Author: Peter Hawkins <[email protected]>
Date:   Wed Sep 18 18:57:03 2024 +0000

    Add Python 3.13 wheels to changelog.

commit cd04d0f
Merge: 016c499 c756d9b
Author: jax authors <[email protected]>
Date:   Wed Sep 18 10:00:03 2024 -0700

    Merge pull request jax-ml#23726 from hawkinsp:debug

    PiperOrigin-RevId: 676030839

commit 016c499
Author: Sergei Lebedev <[email protected]>
Date:   Wed Sep 18 09:56:44 2024 -0700

    Removed leftover usages of GPUGridSpec from Pallas Mosaic GPU tests

    PiperOrigin-RevId: 676029854

commit 9dd363d
Author: Luke Baumann <[email protected]>
Date:   Wed Sep 18 09:28:25 2024 -0700

    Export `jax.lib.xla_extension.ifrt_programs`.

    PiperOrigin-RevId: 676020419

commit e27f1e9
Author: jax authors <[email protected]>
Date:   Wed Sep 18 09:03:55 2024 -0700

    Change Python version 3.13.0rc2 to 3.13.0-rc.2.

    The value is taken from [the versions manifest](https://raw.githubusercontent.com/actions/python-versions/main/versions-manifest.json).

    PiperOrigin-RevId: 676012255

commit 442e863
Author: Sergei Lebedev <[email protected]>
Date:   Wed Sep 18 08:56:49 2024 -0700

    Added a missing branch to `mgpu.FragmentedArray.astype`

    Previously, an unsupported cast produced a `NameError` instead.

    PiperOrigin-RevId: 676010161

commit 6236b8f
Merge: 826843a 1cc9661
Author: jax authors <[email protected]>
Date:   Wed Sep 18 08:57:38 2024 -0700

    Merge pull request jax-ml#23667 from dfm:always-lower-jnp-dot-to-dot-general

    PiperOrigin-RevId: 676010154

commit c756d9b
Author: Peter Hawkins <[email protected]>
Date:   Wed Sep 18 15:44:45 2024 +0000

    Fix error in debugger tests that is showing up in CI.

    I'm unsure why this started happening now, but sometimes we get an
    invalid offset for a frame. Be tolerant of that case.

commit 826843a
Merge: c191bbc 922e652
Author: jax authors <[email protected]>
Date:   Wed Sep 18 08:42:39 2024 -0700

    Merge pull request jax-ml#23723 from hawkinsp:setuptools

    PiperOrigin-RevId: 676005613

commit c191bbc
Author: Yash Katariya <[email protected]>
Date:   Wed Sep 18 08:40:30 2024 -0700

    Make `debug.print` work with static args. Fixes: jax-ml#23600

    PiperOrigin-RevId: 676005582

commit 1cc9661
Author: Dan Foreman-Mackey <[email protected]>
Date:   Mon Sep 16 14:18:29 2024 -0400

    Unconditionally lower jnp.dot to lax.dot_general.

    jax-ml#16721 added a condition to lower
    calls to `jnp.dot` with scalar inputs to `lax.mul` instead of
    `lax.dot_general`. AFAICT, jax-ml#16826
    fixed the issue that this was solving, so this condition should no
    longer be necessary. Removing this condition simplifies the addition of
    new arguments to `dot` and `dot_general`, including the `algorithm`
    parameter that I am currently working on in
    jax-ml#23574, so now seemed like a good time
    to remove it!

commit 922e652
Author: Peter Hawkins <[email protected]>
Date:   Wed Sep 18 15:17:49 2024 +0000

    Replace plat-name with plat_name.

    The former seems to elicit a deprecation warning from setuptools
    recently.

commit 69ba060
Author: Dan Foreman-Mackey <[email protected]>
Date:   Wed Sep 18 07:40:58 2024 -0700

    Reverts e15ec1e

    PiperOrigin-RevId: 675987338

commit 44a7f04
Merge: 0a29696 2834c13
Author: jax authors <[email protected]>
Date:   Wed Sep 18 07:31:00 2024 -0700

    Merge pull request jax-ml#23708 from jakevdp:sort-complex

    PiperOrigin-RevId: 675983957

commit 0a29696
Merge: e15ec1e 73c38cb
Author: jax authors <[email protected]>
Date:   Wed Sep 18 07:08:24 2024 -0700

    Merge pull request jax-ml#23698 from dfm:dev-clang-warning

    PiperOrigin-RevId: 675977448

commit 2834c13
Author: Jake VanderPlas <[email protected]>
Date:   Tue Sep 17 15:32:25 2024 -0700

    jnp.sort_complex: fix output for N-dimensional inputs

commit 73c38cb
Author: Dan Foreman-Mackey <[email protected]>
Date:   Tue Sep 17 14:00:21 2024 -0400

    Add a note to the developer docs making it clear that clang is the only
    toolchain that is actively supported for source compilation.

    As discussed in jax-ml#23687

commit e15ec1e
Merge: 48d8fce 3f2bc9b
Author: jax authors <[email protected]>
Date:   Wed Sep 18 06:56:28 2024 -0700

    Merge pull request jax-ml#23261 from joaospinto:stablehlo.tan

    PiperOrigin-RevId: 675973798

commit 48d8fce
Merge: 4e6f690 2714469
Author: jax authors <[email protected]>
Date:   Wed Sep 18 06:54:28 2024 -0700

    Merge pull request jax-ml#23563 from rajasekharporeddy:testbranch1

    PiperOrigin-RevId: 675973225

commit 4e6f690
Merge: b7c91e9 611ad63
Author: jax authors <[email protected]>
Date:   Wed Sep 18 06:35:15 2024 -0700

    Merge pull request jax-ml#23653 from apaszke:torchsaic

    PiperOrigin-RevId: 675967844

commit b7c91e9
Author: Sergei Lebedev <[email protected]>
Date:   Wed Sep 18 06:22:14 2024 -0700

    Lookup `shape` and `dtype` directly on `state.AbstractRef` instead of going through `inner_aval`

    This is just a cleanup. No behavior changes are expected.

    PiperOrigin-RevId: 675964703

commit 611ad63
Author: Adam Paszke <[email protected]>
Date:   Fri Sep 6 16:09:58 2024 +0000

    Add basic PyTorch integration for Mosaic GPU

    We have already had most of the relevant pieces and we only needed
    to connect them together. The most sensitive change is perhaps that
    I needed to expose one more symbol from the XLA GPU plugin, but I don't
    think it should be a problem.

commit e903369
Author: Sergei Lebedev <[email protected]>
Date:   Wed Sep 18 05:25:37 2024 -0700

    Pulled `scratch_shapes` into `GridSpec`

    It is supported by Mosaic TPU and Mosaic GPU and unsupported by Triton.

    PiperOrigin-RevId: 675950199

commit 2714469
Author: rajasekharporeddy <[email protected]>
Date:   Wed Sep 18 17:06:28 2024 +0530

    Deprecate passing NdArrays with ndim != 1 and non-arraylike inputs to jnp.trim_zeros

commit b904599
Author: Sergei Lebedev <[email protected]>
Date:   Wed Sep 18 04:23:25 2024 -0700

    `pl.debug_print` no longer restricts values to be scalars

    This allows printing arrays on Triton and soon on Mosaic GPU.

    PiperOrigin-RevId: 675935666

commit 988ed2b
Author: jax authors <[email protected]>
Date:   Tue Sep 17 21:09:26 2024 -0700

    Add support for SMEM windows in Pallas custom pipeline.

    PiperOrigin-RevId: 675822640

commit f79d85b
Merge: 1b74cfd cc28d63
Author: Keshav <[email protected]>
Date:   Tue Sep 17 18:58:33 2024 -0700

    Merge remote-tracking branch 'upstream/main' into disable_remat_pass

commit cc28d63
Merge: 8bcdb12 9d3762b
Author: jax authors <[email protected]>
Date:   Tue Sep 17 17:36:36 2024 -0700

    Merge pull request jax-ml#23682 from sharadmv:pallas-async-docs

    PiperOrigin-RevId: 675770723

commit 1b74cfd
Author: Keshav <[email protected]>
Date:   Tue Sep 17 17:23:30 2024 -0700

    disable remat hlo pass by default

commit 8bcdb12
Author: jax authors <[email protected]>
Date:   Tue Sep 17 16:50:55 2024 -0700

    Add CI jobs for python 3.13.0rc2.

    PiperOrigin-RevId: 675758096

commit 8b5b717
Author: Yash Katariya <[email protected]>
Date:   Tue Sep 17 16:39:55 2024 -0700

    Fix jaxpr equation context propagation in jaxpr equations when `inline=True`.

    PiperOrigin-RevId: 675754808

commit 86fe463
Author: Parker Schuh <[email protected]>
Date:   Tue Sep 17 16:10:41 2024 -0700

    [Take 2] Generalize global jit cpp cache keys so we can add more keys than the current donate_argnums.

    This allows us to get more cache hits globally. For example:

    Before:

    jax.jit(f, out_shardings=s)(arr)
    jax.jit(f, out_shardings=s)(arr)  # cpp cache miss
    After:

    jax.jit(f, out_shardings=s)(arr)
    jax.jit(f, out_shardings=s)(arr)  # cpp cache hit

    Reverts b615266

    PiperOrigin-RevId: 675746175

commit e92a599
Author: Christos Perivolaropoulos <[email protected]>
Date:   Tue Sep 17 15:26:42 2024 -0700

    [mosaic_gpu] Better error message for misaligned tma_transpose with dtype.

    PiperOrigin-RevId: 675731295

commit 7864648
Merge: 3f2c58b 83a7555
Author: jax authors <[email protected]>
Date:   Tue Sep 17 15:12:50 2024 -0700

    Merge pull request jax-ml#23679 from selamw1:docstring_sort_complex

    PiperOrigin-RevId: 675726527

commit 83a7555
Author: selamw1 <[email protected]>
Date:   Mon Sep 16 16:47:52 2024 -0700

    docstring_sort_complex_added

    input_array_modified

commit 9d3762b
Author: Sharad Vikram <[email protected]>
Date:   Mon Sep 16 19:18:22 2024 -0700

    [Pallas] Add design note for async ops on TPU

commit 3f2bc9b
Author: Joao Sousa-Pinto <[email protected]>
Date:   Mon Aug 26 17:25:16 2024 -0700

    Lower tan to StableHLO instead of CHLO.

    Fixes jax-ml#23259

commit 541b3a3
Author: kaixih <[email protected]>
Date:   Mon Aug 26 17:32:38 2024 +0000

    New feature
@copybara-service copybara-service bot force-pushed the test_670551032 branch 2 times, most recently from 90b56a1 to a552387 Compare September 20, 2024 15:38
rajasekharporeddy pushed a commit to rajasekharporeddy/jax that referenced this pull request Sep 20, 2024
@dfm
Collaborator

dfm commented Sep 23, 2024

The doctests are failing for the new docstring because this new feature requires jaxlib > 0.4.33, which hasn't been released yet. @jakevdp do you have any suggestions for a reasonable approach for including these examples, but "xfailing" when the jaxlib version is too old?

@copybara-service copybara-service bot force-pushed the test_670551032 branch 2 times, most recently from 6aac0fe to 6a68403 Compare September 25, 2024 13:04
@copybara-service copybara-service bot closed this Sep 25, 2024
@copybara-service copybara-service bot deleted the test_670551032 branch September 25, 2024 13:17
copybara-service bot pushed a commit to google-deepmind/kfac-jax that referenced this pull request Sep 25, 2024
…tuning parameters for dot_general that will be included in the next JAX release.

More information about the change to JAX can be found at jax-ml/jax#23574 and jax-ml/jax#23797.

PiperOrigin-RevId: 674292750
copybara-service bot pushed a commit to google-deepmind/kfac-jax that referenced this pull request Sep 25, 2024
…tuning parameters for dot_general that will be included in the next JAX release.

More information about the change to JAX can be found at jax-ml/jax#23574 and jax-ml/jax#23797.

PiperOrigin-RevId: 678674241
copybara-service bot pushed a commit that referenced this pull request Oct 2, 2024
In #23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used for dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.

The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user-friendly to include automatic casts in all cases where a specific algorithm is requested.

Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.

To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile-time error.)

With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.

Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.

One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short-term approach.
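
As a rough illustration of the `custom_vjp` route mentioned above for customizing the backward pass, the sketch below uses different settings for the forward and backward matrix products. It is an assumption-laden example: `lax.Precision` values stand in for algorithm values so that the code runs on current releases, and the names `matmul`, `matmul_fwd`, and `matmul_bwd` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Plain matrix multiplication: contract lhs axis 1 with rhs axis 0.
DIMS = (((1,), (0,)), ((), ()))


@jax.custom_vjp
def matmul(lhs, rhs):
    # Forward pass: stand-in setting (an algorithm value would go here).
    return lax.dot_general(lhs, rhs, DIMS, precision=lax.Precision.DEFAULT)


def matmul_fwd(lhs, rhs):
    # Save both operands as residuals for the backward pass.
    return matmul(lhs, rhs), (lhs, rhs)


def matmul_bwd(residuals, g):
    lhs, rhs = residuals
    # Backward pass: a different stand-in setting than the forward pass.
    # grad wrt lhs is g @ rhs.T (contract axis 1 of g with axis 1 of rhs).
    grad_lhs = lax.dot_general(
        g, rhs, (((1,), (1,)), ((), ())), precision=lax.Precision.HIGHEST
    )
    # grad wrt rhs is lhs.T @ g (contract axis 0 of lhs with axis 0 of g).
    grad_rhs = lax.dot_general(
        lhs, g, (((0,), (0,)), ((), ())), precision=lax.Precision.HIGHEST
    )
    return grad_lhs, grad_rhs


matmul.defvjp(matmul_fwd, matmul_bwd)

a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(3, 4)
grad_a, grad_b = jax.grad(lambda x, y: matmul(x, y).sum(), argnums=(0, 1))(a, b)
ref_a, ref_b = jax.grad(lambda x, y: (x @ y).sum(), argnums=(0, 1))(a, b)
print(grad_a.shape, grad_b.shape)
```

The gradients match those of the plain `@` operator here; only the settings used for the backward matrix products differ.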

PiperOrigin-RevId: 681543555
copybara-service bot pushed a commit that referenced this pull request Oct 3, 2024
In #23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.

The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.

Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.

To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)

With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.

Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.

One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.

PiperOrigin-RevId: 681543555
copybara-service bot pushed a commit that referenced this pull request Oct 3, 2024
In #23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.

The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.

Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.

To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)

With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.

Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.

One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.

PiperOrigin-RevId: 681543555
copybara-service bot pushed a commit that referenced this pull request Oct 3, 2024
In #23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.

The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.

Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.

To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)

With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.

Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.

One minor downside of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short-term approach.

PiperOrigin-RevId: 681543555
copybara-service bot pushed a commit that referenced this pull request Oct 3, 2024
PiperOrigin-RevId: 681543555
copybara-service bot pushed a commit that referenced this pull request Oct 3, 2024
PiperOrigin-RevId: 681543555
copybara-service bot pushed a commit that referenced this pull request Oct 3, 2024
PiperOrigin-RevId: 681543555
copybara-service bot pushed a commit that referenced this pull request Oct 3, 2024
PiperOrigin-RevId: 681543555
copybara-service bot pushed a commit that referenced this pull request Oct 3, 2024
PiperOrigin-RevId: 681543555
copybara-service bot pushed a commit that referenced this pull request Oct 7, 2024
PiperOrigin-RevId: 681543555
copybara-service bot pushed a commit that referenced this pull request Oct 7, 2024
PiperOrigin-RevId: 681543555
copybara-service bot pushed a commit that referenced this pull request Oct 7, 2024
PiperOrigin-RevId: 681543555
copybara-service bot pushed a commit that referenced this pull request Oct 7, 2024
PiperOrigin-RevId: 681543555
copybara-service bot pushed a commit that referenced this pull request Oct 7, 2024
PiperOrigin-RevId: 683302687
Successfully merging this pull request may close these issues.

Add support for the new "dot algorithm" spec for more explicit control of dot product numerics