[NVIDIA] Throw an error for unsupported bias shape when cudnn backend is used in nn.dot_product_attention #23740

Merged: 3 commits, merged on Oct 2, 2024
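For context, here is a minimal repro-style sketch of the behavior this PR adds. It is not part of the diff and assumes a cuDNN-capable GPU plus a JAX build that includes this change: requesting a bias gradient through the cuDNN backend with a bias batch dimension larger than 1 now raises a ValueError up front instead of running into an unsupported cuDNN configuration.

import jax
import jax.numpy as jnp

# Shapes follow the convention used in the tests below: (B, S, N, H) for
# query/key/value and (B, N, S, S) for the bias. B > 1 triggers the new check.
B, S, N, H = 2, 128, 4, 32
q = jnp.zeros((B, S, N, H), jnp.bfloat16)
bias = jnp.zeros((B, N, S, S), jnp.bfloat16)

def loss(bias):
  out = jax.nn.dot_product_attention(
      q, q, q, bias=bias, implementation='cudnn')
  return out.sum()

try:
  jax.grad(loss)(bias)
except ValueError as e:
  print(e)  # "cuDNN only supports bias gradient when the batch size is 1 ..."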
68 changes: 68 additions & 0 deletions jax/_src/nn/functions.py
@@ -19,6 +19,7 @@
from collections.abc import Sequence
from functools import partial
import operator
import math
import numpy as np
from typing import Any, Literal
import warnings
@@ -34,6 +35,8 @@
from jax._src.core import AxisName
from jax._src.cudnn.fused_attention_stablehlo import (
    dot_product_attention as cudnn_dot_product_attention, MaskType)
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.numpy import util as numpy_util
from jax._src.typing import Array, ArrayLike, DType
from jax._src.ops.special import logsumexp as _logsumexp
@@ -900,6 +903,68 @@ def _reshape_to_grouped(t):
  encoded = jnp.reshape(encoded, (B, T, N, H))
  return encoded

def bias_fwd_rule(a, query_head_num):
  return bias_fwd_p.bind(a, query_head_num), a
def bias_bwd_rule(query_head_num, res, g):
  a = res
  if a.shape[0] > 1 or a.shape[-3] != query_head_num:
    raise ValueError("cuDNN only supports bias gradient when the batch size is "
                     f"1 and the head number matches the query, but got "
                     f"B={a.shape[0]}, N={a.shape[-3]}.")
  return (bias_bwd_p.bind(g, a, query_head_num),)

# This function uses two custom primitives, `bias_fwd` and `bias_bwd`, to work
# around a cuDNN issue where bias gradients are only supported when the batch
# size is 1 and the number of heads matches the query.
# TODO(kaixih@nvidia): Remove this workaround once cuDNN resolves the issue.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def check_valid_bias_batch(x, query_head_num):
  output, _ = bias_fwd_rule(x, query_head_num)
  return output
check_valid_bias_batch.defvjp(bias_fwd_rule, bias_bwd_rule)

bias_fwd_p = core.Primitive('bias_fwd')
bias_fwd_p.multiple_results = False
bias_bwd_p = core.Primitive('bias_bwd')
bias_bwd_p.multiple_results = False

def bias_fwd_impl(a, query_head_num):
  return a
def bias_bwd_impl(g, a, query_head_num):
  return g
bias_fwd_p.def_impl(bias_fwd_impl)
bias_bwd_p.def_impl(bias_bwd_impl)

def bias_fwd_abstract_eval(a, query_head_num):
  return core.ShapedArray(a.shape, a.dtype)
def bias_bwd_abstract_eval(g, a, query_head_num):
  return core.ShapedArray(g.shape, g.dtype)
bias_fwd_p.def_abstract_eval(bias_fwd_abstract_eval)
bias_bwd_p.def_abstract_eval(bias_bwd_abstract_eval)

def bias_fwd_lowering(ctx, a, query_head_num):
  return [a]
def bias_bwd_lowering(ctx, g, a, query_head_num):
  return [g]
mlir.register_lowering(bias_fwd_p, bias_fwd_lowering)
mlir.register_lowering(bias_bwd_p, bias_bwd_lowering)

def bias_fwd_batch_rule(batched_args, batch_dims):
  x, query_head_num = batched_args
  a = batch_dims[0]
  output, _ = bias_fwd_rule(x, query_head_num)
  return output, a
def bias_bwd_batch_rule(batched_args, batch_dims):
  g, x, query_head_num = batched_args
  b = batch_dims[0]
  *Bs, _, _, _ = x.shape
  B = math.prod(Bs)
  x = jnp.reshape(x, (B,) + x.shape[-3:])
  output, = bias_bwd_rule(query_head_num, x, g)
  return output, b
batching.primitive_batchers[bias_fwd_p] = bias_fwd_batch_rule
batching.primitive_batchers[bias_bwd_p] = bias_bwd_batch_rule

def dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
@@ -1032,6 +1097,9 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
          local_window_size=local_window_size,
      )
    case 'cudnn':
      if bias is not None:
        bias = check_valid_bias_batch(bias, query_arr.shape[-2])
        bias = jnp.asarray(bias)
      use_padding = (
          query_seq_lengths is not None or key_value_seq_lengths is not None
      )
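The functions.py change above leans on jax.custom_vjp with nondiff_argnums to defer the shape check until a bias cotangent is actually requested: the forward primitive is an identity, and the error lives in the backward rule. The following standalone sketch shows the same pattern in isolation; the names (checked_identity, max_batch) are hypothetical and not part of the PR.

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def checked_identity(x, max_batch):
  return x

def _fwd(x, max_batch):
  # Save x as the residual so the backward rule can inspect its shape.
  return x, x

def _bwd(max_batch, res, g):
  x = res
  if x.shape[0] > max_batch:
    raise ValueError(f"batch size {x.shape[0]} exceeds the supported {max_batch}")
  return (g,)  # one cotangent, for the single differentiable argument x

checked_identity.defvjp(_fwd, _bwd)

x = jnp.ones((2, 3))
checked_identity(x, 1)                                 # forward pass: fine
# jax.grad(lambda x: checked_identity(x, 1).sum())(x)  # raises ValueError

The forward pass stays a no-op, much like bias_fwd above, so users who never differentiate the bias are unaffected.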
3 changes: 3 additions & 0 deletions jax/experimental/jax2tf/jax2tf.py
@@ -1553,6 +1553,9 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
"bitcast",
"repeat",
"roll",
# temporary pending cudnn fix, see https://github.com/jax-ml/jax/pull/23740
"bias_fwd",
"bias_bwd",
]

tf_impl[random_internal.random_clone_p] = lambda x: x
59 changes: 59 additions & 0 deletions tests/nn_test.py
@@ -50,6 +50,8 @@ def _check_cudnn_backend(fn, *args, **kwargs):
  hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
  return '__cudnn$fmha' in hlo

_cudnn_dbias_error = 'cuDNN only supports bias gradient'

@jtu.with_config(jax_legacy_prng_key="allow",
                 jax_numpy_dtype_promotion="standard")
class NNFunctionsTest(jtu.JaxTestCase):
@@ -167,6 +169,63 @@ def testDotProductAttentionMask(self, mask_mode):
    self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02)
    self.assertAllClose(dbias_ref, dbias_ans, rtol=.03, atol=.03)

  @parameterized.product(
      batch_size=[1, 16],
      use_vmap=[False, True],
  )
  def testDotProductAttentionBiasGradient(self, batch_size, use_vmap):
    if not _is_required_cudnn_version_satisfied(8904):
      raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")

    dtype = jnp.bfloat16
    B, S, N, H = batch_size, 128, 4, 32
    keys = random.split(random.PRNGKey(0), 2)
    x = random.normal(keys[0], (B, S, N, H), dtype)
    bias = random.normal(keys[1], (B, N, S, S), dtype=dtype)
    mask = jnp.ones((1, 1, S), dtype=jnp.bool_)

    def attention(x, bias, mask, impl):
      return jax.nn.dot_product_attention(
          query=x,
          key=x,
          value=x,
          bias=bias,
          mask=mask,
          is_causal=False,
          implementation=impl,
      )
    attn_ref = partial(attention, impl=None)
    attn_ans = partial(attention, impl='cudnn')
    if use_vmap:
      attn_batched_ref = jax.vmap(attn_ref, in_axes=(0, 0, None))
      attn_batched_ans = jax.vmap(attn_ans, in_axes=(0, 0, None))
    else:
      attn_batched_ref = attn_ref
      attn_batched_ans = attn_ans

    fwd_ref = jax.jit(attn_batched_ref)
    fwd_ans = jax.jit(attn_batched_ans)
    y_ref = fwd_ref(x, bias, mask)
    y_ans = fwd_ans(x, bias, mask)
    self.assertAllClose(y_ref, y_ans)

    @jax.jit
    def bwd_ref(x, bias, mask):
      _, f_vjp = jax.vjp(attn_ref, x, bias, mask)
      return f_vjp(x)
    @jax.jit
    def bwd_ans(x, bias, mask):
      _, f_vjp = jax.vjp(attn_ans, x, bias, mask)
      return f_vjp(x)

    if batch_size != 1:
      with self.assertRaisesRegex(ValueError, _cudnn_dbias_error):
        _, dbias_ans, _ = bwd_ans(x, bias, mask)
    else:
      _, dbias_ref, _ = bwd_ref(x, bias, mask)
      _, dbias_ans, _ = bwd_ans(x, bias, mask)
      self.assertAllClose(dbias_ans, dbias_ref, rtol=.03, atol=.03)

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testSoftplusGrad(self):
    check_grads(nn.softplus, (1e-8,), order=4,
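To run just the new test locally (assuming a GPU build of JAX with the cuDNN version required by the skip check above), an invocation along the lines of python -m pytest tests/nn_test.py -k testDotProductAttentionBiasGradient should select it; the batch_size=16 cases exercise the error path, while the batch_size=1 cases compare the cuDNN bias gradient against the XLA reference.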