Replies: 1 comment
Seems to be a bug:

```python
import math

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.gpu.attention import mha, mha_reference

key = jax.random.key(0)
base = jax.random.normal(key, (3, 4092, 24, 32))
q, k, v = jnp.split(base, 3, axis=0)  # three arrays of shape (1, 4092, 24, 32)
for i in range(10):
    h0 = mha(q, k, v, None, sm_scale=1 / math.sqrt(q.shape[-1]))
    h1 = mha_reference(q, k, v, None, sm_scale=1 / math.sqrt(q.shape[-1]))
    print(jnp.mean(jnp.abs(h0 - h1)))
```

Each iteration prints the mean absolute difference between `mha` and `mha_reference`, and the values are clearly nonzero.
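To check whether the divisibility of `seq_len` by `block_q` really is the trigger, the same comparison can be rerun with `seq_len = 4096`, a multiple of 128 (assumed here to be the default `block_q`). This is a minimal sketch under that assumption; if divisibility is the cause, the two implementations should agree to within floating-point tolerance:

```python
import math

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.gpu.attention import mha, mha_reference

key = jax.random.key(0)
# Same setup as above, but with seq_len = 4096 (divisible by the assumed
# default block_q = 128) instead of 4092.
base = jax.random.normal(key, (3, 4096, 24, 32))
q, k, v = jnp.split(base, 3, axis=0)

h0 = mha(q, k, v, None, sm_scale=1 / math.sqrt(q.shape[-1]))
h1 = mha_reference(q, k, v, None, sm_scale=1 / math.sqrt(q.shape[-1]))
# Expected to be near zero if the mismatch only occurs for
# sequence lengths that are not a multiple of block_q.
print(jnp.mean(jnp.abs(h0 - h1)))
```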
Incorrect results of `pallas.ops.gpu.attention.mha` when `seq_len` is not divisible by `block_q`. Is this expected behavior, or is it a bug? Example code
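As an interim workaround (a sketch, not taken from the thread; it assumes `mha` exposes `block_q`/`block_k` keyword arguments and that `mha_reference` accepts the same positional arguments), a hypothetical wrapper could route sequence lengths that are not a multiple of the block size to the reference implementation:

```python
import math

from jax.experimental.pallas.ops.gpu.attention import mha, mha_reference


def attention(q, k, v, segment_ids=None, block_q=128, block_k=128):
    """Use the Pallas kernel only when seq_len divides evenly into the blocks."""
    seq_len = q.shape[1]  # inputs are (batch, seq_len, num_heads, head_dim)
    sm_scale = 1 / math.sqrt(q.shape[-1])
    if seq_len % block_q or seq_len % block_k:
        # Slower, but avoids the incorrect results reported above.
        return mha_reference(q, k, v, segment_ids, sm_scale=sm_scale)
    return mha(q, k, v, segment_ids, sm_scale=sm_scale,
               block_q=block_q, block_k=block_k)
```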