Causal Mask HLO for jnp.tril(jnp.ones()) can be simplified #19905

Answered by jakevdp
apivovarov asked this question in General
Rather than jnp.tril(jnp.ones((seq_len, seq_len))).astype('bool'), you might try writing jnp.tri(seq_len, dtype=bool), which is much more direct.

Edit: tri also generates the desired HLO (see openxla/xla#9709 (comment))
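
For anyone who wants to check this locally, here is a minimal sketch comparing the two ways of building the mask. The helper names (mask_tril, mask_tri) and seq_len = 4 are made up for illustration and are not from the original question. The last lines dump the lowered module text so the two versions can be compared by eye; the exact output depends on your JAX/XLA version, so none is shown here.

```python
import jax
import jax.numpy as jnp

seq_len = 4  # hypothetical size, chosen only for illustration


def mask_tril(n):
    # Original approach: materialize a float matrix of ones,
    # keep the lower triangle, then cast to bool.
    return jnp.tril(jnp.ones((n, n))).astype('bool')


def mask_tri(n):
    # Suggested approach: build the lower-triangular bool mask directly.
    return jnp.tri(n, dtype=bool)


a = mask_tril(seq_len)
b = mask_tri(seq_len)

# Both produce the same causal mask (True on and below the diagonal).
same = bool(jnp.array_equal(a, b))
print("masks equal:", same)

# Lower each version and grab the module text for a manual comparison.
# n determines the output shape, so it has to be a static argument.
text_tril = jax.jit(mask_tril, static_argnums=0).lower(seq_len).as_text()
text_tri = jax.jit(mask_tri, static_argnums=0).lower(seq_len).as_text()
print(text_tri)
```

Note that lower(...).as_text() shows the module before XLA's own optimization passes; to see what the compiler finally emits you would need to inspect the compiled artifact instead.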

Answer selected by apivovarov