fix dbias in bwd_batcher
kaixih committed Sep 18, 2024
1 parent d0cb318 commit b87a668
Showing 2 changed files with 53 additions and 1 deletion.
10 changes: 9 additions & 1 deletion jax/_src/cudnn/fused_attention_stablehlo.py
@@ -676,6 +676,11 @@ def _dot_product_attention_bwd_batcher(
   *_, S, _, _ = key.shape
   B = math.prod(Bs)
   has_bias, has_dbias = variadic_args
+  # Reset the has_dbias if the combined batch size is not 1, because cuDNN only
+  # supports dbias with a single batch. In this case, an all-zero dbias will be
+  # appended instead.
+  if B > 1:
+    variadic_args = (has_bias, False)
   original_query_shape = query.shape
   original_key_shape = key.shape
   original_value_shape = value.shape
@@ -708,7 +713,10 @@ def _dot_product_attention_bwd_batcher(
   grads[2] = jnp.reshape(grads[2], original_value_shape)
   if has_dbias:
     assert has_bias
-    grads[3] = jnp.reshape(grads[3], original_bias_shape)
+    if variadic_args[1]:
+      grads[3] = jnp.reshape(grads[3], original_bias_shape)
+    else:
+      grads.append(jnp.zeros(original_bias_shape, bias.dtype))
     out_bdims += (batch_dims[3],)
   return grads, out_bdims

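Note: the change above amounts to a small fallback pattern: when the collapsed batch size rules out a cuDNN-computed dbias, the batcher still hands back a gradient of the expected shape by appending zeros. A minimal standalone sketch of that pattern (illustrative only; the function name and arguments below are invented for this note, not the actual JAX internals):

import math
import jax.numpy as jnp

def sketch_dbias_fallback(grads, bias, batch_sizes, original_bias_shape):
  # Mirror `B = math.prod(Bs)` above: vmapped leading batch dims are collapsed
  # into a single batch dimension before the cuDNN kernel is invoked.
  combined_batch = math.prod(batch_sizes)
  if combined_batch == 1:
    # cuDNN produced a real dbias; restore the original (unbatched) bias shape.
    grads[3] = jnp.reshape(grads[3], original_bias_shape)
  else:
    # cuDNN skipped dbias, so append an all-zero placeholder with the bias's
    # shape and dtype, keeping the gradient structure consistent for callers.
    grads.append(jnp.zeros(original_bias_shape, bias.dtype))
  return grads
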
44 changes: 44 additions & 0 deletions tests/nn_test.py
@@ -167,6 +167,50 @@ def testDotProductAttentionMask(self, mask_mode):
     self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02)
     self.assertAllClose(dbias_ref, dbias_ans, rtol=.03, atol=.03)
 
+  @parameterized.product(
+      batch_size=[1, 16],
+  )
+  def testDotProductAttentionBiasGradient(self, batch_size):
+    if not _is_required_cudnn_version_satisfied(8904):
+      raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
+
+    dtype = jnp.bfloat16
+    B, S, N, H = batch_size, 512, 16, 48
+    keys = random.split(random.PRNGKey(0), 2)
+    x = random.normal(keys[0], (B, S, N, H), dtype)
+    bias = random.normal(keys[1], (B, N, S, S), dtype=dtype)
+    mask = jnp.ones((1, 1, S), dtype=jnp.bool_)
+
+    def attention(x, bias, mask):
+      return jax.nn.dot_product_attention(
+          query=x,
+          key=x,
+          value=x,
+          bias=bias,
+          mask=mask,
+          is_causal=False,
+          implementation="cudnn",
+      )
+    attention_vmap = jax.vmap(attention, in_axes=(0, 0, None))
+    @jax.jit
+    def attention_vjp_ref(x, bias, mask):
+      _, f_vjp = jax.vjp(attention, x, bias, mask)
+      return f_vjp(x)
+    @jax.jit
+    def attention_vjp_ans(x, bias, mask):
+      _, f_vjp = jax.vjp(attention_vmap, x, bias, mask)
+      return f_vjp(x)
+
+    _, dbias_ref, _ = attention_vjp_ref(x, bias, mask)
+    _, dbias_ans, _ = attention_vjp_ans(x, bias, mask)
+    self.assertAllClose(dbias_ans, dbias_ref)
+    # cuDNN only supports the bias gradient (dbias) with a single batch size.
+    # For larger batch sizes, dbias is returned as an all-zero placeholder.
+    # TODO(kaixih@nvidia): remove this when cuDNN improves.
+    if batch_size != 1:
+      self.assertTrue(jnp.all(dbias_ans == 0.))
+
+
   @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testSoftplusGrad(self):
     check_grads(nn.softplus, (1e-8,), order=4,
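
For reference, a user-facing sketch of how the fixed behavior surfaces through jax.grad rather than jax.vjp. This is illustrative only: the shapes, the sum-based loss, and the variable names are arbitrary choices for this note, and running it requires a cuDNN-capable GPU with a sufficiently new cuDNN (the test above gates on 8904).

import jax
import jax.numpy as jnp
from jax import random

B, S, N, H = 4, 256, 4, 32
k1, k2 = random.split(random.PRNGKey(0))
x = random.normal(k1, (B, S, N, H), jnp.bfloat16)
bias = random.normal(k2, (B, N, S, S), dtype=jnp.bfloat16)
mask = jnp.ones((1, 1, S), dtype=jnp.bool_)

def attention(x, bias, mask):
  return jax.nn.dot_product_attention(
      query=x, key=x, value=x, bias=bias, mask=mask,
      implementation="cudnn")

attention_vmap = jax.vmap(attention, in_axes=(0, 0, None))

def loss(x, bias, mask):
  # Reduce to a scalar in float32 so jax.grad can be applied directly.
  return jnp.sum(attention_vmap(x, bias, mask).astype(jnp.float32))

# With a combined batch size > 1, the bias gradient is the all-zero placeholder
# appended by the batcher rather than a cuDNN-computed dbias.
dbias = jax.grad(loss, argnums=1)(x, bias, mask)
print(bool(jnp.all(dbias == 0)))  # expected: True when B > 1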
