-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NVIDIA] Support vmap usage of jax.nn.dot_product_attention
#22830
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -853,9 +853,9 @@ def dot_product_attention( | |
query: ArrayLike, | ||
key: ArrayLike, | ||
value: ArrayLike, | ||
*, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why this change? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems you can only specify the batch dims for the positional args in jax.vmap. For the keyward arguments, vmap will always use leading dim for the batch. |
||
bias: ArrayLike | None = None, | ||
mask: ArrayLike | None = None, | ||
*, | ||
scale: float | None = None, | ||
is_causal: bool = False, | ||
implementation: Literal['xla', 'cudnn'] | None = None) -> Array: | ||
|
@@ -882,20 +882,20 @@ def dot_product_attention( | |
G = number of groups, which equals to N // K | ||
|
||
Args: | ||
query: query array; shape :code:`(BTNH)` | ||
key: key array: shape :code:`(BSKH)`. When `K` equals `N`, multi-headed | ||
attention (MHA: https://arxiv.org/abs/1706.03762) is performed. Otherwise, | ||
grouped query attention (GQA: https://arxiv.org/abs/2305.13245) is performed | ||
if `N` is a multiple of `K`, and multi-query attention (MQA: | ||
https://arxiv.org/abs/1911.02150) is performed if `K == 1` (a special case | ||
of GQA). | ||
query: query array; shape :code:`(BTNH|TNH)` | ||
key: key array: shape :code:`(BSKH|SKH)`. When `K` equals `N`, multi-headed | ||
attention (MHA https://arxiv.org/abs/1706.03762) is performed. Otherwise, | ||
grouped query attention (GQA https://arxiv.org/abs/2305.13245) is | ||
performed if `N` is a multiple of `K`, and multi-query attention (MQA | ||
https://arxiv.org/abs/1911.02150) is performed if `K == 1` (a special case | ||
of GQA). | ||
value: value array, should have the same shape as the `key` array. | ||
bias: optional, bias array to be added to logits; The shape must be 4D and | ||
be broadcastable to :code:`(BNTS)`. | ||
be broadcastable to :code:`(BNTS|NTS)`. | ||
mask: optional, mask array used to filter out logits. It is a boolean mask | ||
where `True` indicates the element should take part in attention. For an | ||
additive mask, users should pass it to `bias`. The shape must be 4D and be | ||
broadcastable to :code:`(BNTS)`. | ||
broadcastable to :code:`(BNTS|NTS)`. | ||
scale: scale for the logits. If None, the scale will be set to 1 divided by | ||
the square root of query's head dimension (i.e. H). | ||
is_causal: If true, causal attention will be applied. Note, some | ||
|
@@ -912,19 +912,27 @@ def dot_product_attention( | |
Returns: | ||
An array of the attention output with the same shape as :code:`query`. | ||
""" | ||
output_shape = jnp.asarray(query).shape | ||
def _ensure_4d(t): | ||
t = jnp.asarray(t) | ||
dims_to_add = 4 - t.ndim | ||
if dims_to_add > 0: | ||
return jnp.expand_dims(t, axis=tuple(range(dims_to_add))) | ||
return t | ||
|
||
query = _ensure_4d(query) | ||
key = _ensure_4d(key) | ||
value = _ensure_4d(value) | ||
bias = _ensure_4d(bias) if bias is not None else None | ||
mask = _ensure_4d(mask) if mask is not None else None | ||
|
||
def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None: | ||
if t.ndim != len(shape): | ||
raise ValueError(f"{name} ndim should be {len(shape)}, but got {t.ndim}") | ||
for i in range(t.ndim): | ||
if shape[i] != -1 and t.shape[i] != shape[i]: | ||
raise ValueError(f"{name} shape should be {shape}: but got {t.shape}") | ||
|
||
query = jnp.asarray(query) | ||
key = jnp.asarray(key) | ||
value = jnp.asarray(value) | ||
bias = bias if bias is None else jnp.asarray(bias) | ||
mask = mask if mask is None else jnp.asarray(mask) | ||
|
||
B, S, K, H = key.shape | ||
_check_has_shape(value, [B, S, K, H], 'value') | ||
_check_has_shape(query, [B, -1, -1, H], 'query') | ||
|
@@ -944,19 +952,21 @@ def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None: | |
|
||
match implementation: | ||
case 'xla': | ||
return _dot_product_attention_xla( | ||
out = _dot_product_attention_xla( | ||
query, key, value, bias, mask, is_causal=is_causal, scale=scale_val, | ||
) | ||
case 'cudnn': | ||
mask_type = MaskType.CAUSAL if is_causal else MaskType.NO_MASK | ||
return cudnn_dot_product_attention( | ||
out = cudnn_dot_product_attention( | ||
query, key, value, bias, mask, scale=scale_val, mask_type=mask_type | ||
) | ||
case None: | ||
# TODO(kaixih@nvidia) Defaults to XLA for now. Will automatically select | ||
# best backend. | ||
return _dot_product_attention_xla( | ||
out = _dot_product_attention_xla( | ||
query, key, value, bias, mask, is_causal=is_causal, scale=scale_val, | ||
) | ||
case _: | ||
raise ValueError(f"Unsupported implementation option: {implementation}") | ||
|
||
return jnp.reshape(out, output_shape) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,10 +57,11 @@ class NNFunctionsTest(jtu.JaxTestCase): | |
use_bias=[False, True], | ||
causal_mode=[None, 'is_causal', 'is_mask'], | ||
group_num=[1, 2, 4], | ||
use_vmap=[False, True], | ||
impl=['xla', 'cudnn'], | ||
) | ||
def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode, | ||
group_num, impl): | ||
group_num, use_vmap, impl): | ||
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(): | ||
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") | ||
if impl == 'cudnn' and dtype == jnp.float32: | ||
|
@@ -84,26 +85,29 @@ def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode, | |
sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl) | ||
|
||
if impl == 'cudnn': | ||
lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias=bias, mask=causal_mask) | ||
lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias, causal_mask) | ||
hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo')) | ||
self.assertIn('__cudnn$fmha', hlo) | ||
|
||
if use_vmap: | ||
sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0) | ||
K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K | ||
V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V | ||
out_ref = sdpa_ref(Q, K_ref, V_ref, bias=bias, mask=causal_mask) | ||
out_ref = sdpa_ref(Q, K_ref, V_ref, bias, causal_mask) | ||
|
||
out_ans = sdpa_ans(Q, K, V, bias=bias, mask=causal_mask) | ||
out_ans = sdpa_ans(Q, K, V, bias, causal_mask) | ||
self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01) | ||
|
||
@parameterized.product( | ||
dtype=[jnp.float32, jnp.bfloat16, jnp.float16], | ||
use_bias=[False, True], | ||
causal_mode=[None, 'is_causal', 'is_mask'], | ||
group_num=[1, 2, 4], | ||
use_vmap=[False, True], | ||
impl=['xla', 'cudnn'], | ||
) | ||
def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode, | ||
group_num, impl): | ||
group_num, use_vmap, impl): | ||
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(): | ||
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") | ||
if impl == 'cudnn' and dtype == jnp.float32: | ||
|
@@ -127,16 +131,16 @@ def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode, | |
K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K | ||
V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V | ||
sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None) | ||
fn_ref = lambda q, k, v, b, m: sdpa_ref(q, k, v, bias=b, mask=m) | ||
_, sdpa_vjp_ref = jax.vjp(fn_ref, Q, K_ref, V_ref, bias, causal_mask) | ||
_, sdpa_vjp_ref = jax.vjp(sdpa_ref, Q, K_ref, V_ref, bias, causal_mask) | ||
dQ_ref, dK_ref, dV_ref, dbias_ref, _ = sdpa_vjp_ref(grad) | ||
if G != 1: | ||
dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3) | ||
dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3) | ||
|
||
sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl) | ||
fn_ans = lambda q, k, v, b, m: sdpa_ans(q, k, v, bias=b, mask=m) | ||
_, sdpa_vjp_ans = jax.vjp(fn_ans, Q, K, V, bias, causal_mask) | ||
if use_vmap: | ||
sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think your current implementation will fail if vmapped more than once, since it requires either a 3D or a 4D array. Wdyt about handling the N>4 case by collapsing any extra leading dimensions into There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why this doesn't work? I have tried this to mimic the 5D tensor and it works fine. Or do I miss your point?
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't we assert that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it will fail if we directly pass in the 5D tensor. The behavior now is to support (1) 4D tensor for those who want to use the batch-aware API (2) 3D tensor for those wants to use the API in the context of vmap, meaning if users have 5D tensor, they need to use vmap as shown above. |
||
_, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, causal_mask) | ||
dQ_ans, dK_ans, dV_ans, dbias_ans, _ = sdpa_vjp_ans(grad) | ||
|
||
if impl == 'cudnn': | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about
output_shape
since you only use it foroutput
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.