Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Causal mask suboptimal HLO simplification #9867

Closed
wants to merge 1 commit into from

Conversation

apivovarov
Copy link
Contributor

@apivovarov apivovarov commented Feb 23, 2024

Below is a common JAX code issue encountered when generating a Causal mask. The user either neglected to specify dtype=bool in ones() or mistakenly applied .astype(bool) to the result of tril() instead of to ones(). Consequently, the mask will be converted from f32 to bool, resulting in suboptimal HLO.

mask = jnp.tril(jnp.ones((seq_len, seq_len)))
res = jnp.where(mask, x, -jnp.inf)

# it will be lowered to the following suboptimal HLO
%cmp0 = pred compare(s32, s32, direction=GE)
%sel0 = f32 select(%cmp0, ones, zeros)
%cmp1 = pred compare(%sel0, zeros, direction=NE)

# which can be simplified to just
%cmp0 = pred compare(s32, s32, direction=GE)

Simplification:
Ne(select(Ge(a, b), ones, zeros), zeros) -> Ge(a, b)

Discussion: #9709

@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Feb 23, 2024
@kokoro-team kokoro-team removed kokoro:force-run Forces CI to rerun labels Feb 23, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Feb 23, 2024
Imported from GitHub PR openxla/xla#9867

Below is a common JAX code issue encountered when generating a Causal mask. The user either neglected to specify `dtype=bool` in `ones()` or mistakenly applied `.astype(bool)` to the result of `tril()` instead of to `ones()`. Consequently, the mask will be converted from f32 to bool, resulting in suboptimal HLO.

```python
mask = jnp.tril(jnp.ones((seq_len, seq_len)))
res = jnp.where(mask, x, -jnp.inf)

# it will be lowered to the following suboptimal HLO
%cmp0 = pred compare(s32, s32, direction=GE)
%sel0 = f32 select(%cmp0, ones, zeros)
%cmp1 = pred compare(%sel0, zeros, direction=NE)

# which can be simplified to just
%cmp0 = pred compare(s32, s32, direction=GE)

```
Simplification:
`Ne(select(Ge(a, b), ones, zeros), zeros) -> Ge(a, b)`

Discussion: openxla/xla#9709
Copybara import of the project:

--
fcd42cd9822429beeba8664c1f67af7246f74b04 by Alexander Pivovarov <[email protected]>:

Causal mask suboptimal HLO simplification

Merging this change closes #9867

PiperOrigin-RevId: 609676233
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants