[NVIDIA] Support vmap usage of jax.nn.dot_product_attention #22830

Merged · 2 commits · Aug 7, 2024
24 changes: 14 additions & 10 deletions jax/_src/cudnn/fused_attention_stablehlo.py
@@ -618,11 +618,12 @@ def _dot_product_attention_fwd_batcher(
*_, S, _, _ = key.shape
B = math.prod(Bs)
has_bias, _ = variadic_args
original_shape = query.shape
Collaborator: How about output_shape since you only use it for output?

Contributor Author: Done.

# reshape to 4D shape
query = jnp.reshape(query, (B,) + query.shape[-3:])
key = jnp.reshape(key, (B,) + key.shape[-3:])
value = jnp.reshape(value, (B,) + key.shape[-3:])
if has_bias:
if has_bias and batch_dims[3] is not None:
bias = jnp.reshape(bias, (B, N, T, S))
if has_padding(mask_type):
q_seqlen = jnp.reshape(q_seqlen, (B, ))
@@ -635,7 +636,7 @@ def _dot_product_attention_fwd_batcher(

# reshape to original shape
output = outputs[0]
output = jnp.reshape(output, query.shape)
output = jnp.reshape(output, original_shape)
if is_training:
activation = outputs[1]
activation = jnp.reshape(activation, (*Bs, N, T))
@@ -660,11 +661,15 @@ def _dot_product_attention_bwd_batcher(
*_, S, _, _ = key.shape
B = math.prod(Bs)
has_bias, has_dbias = variadic_args
original_query_shape = query.shape
original_key_shape = key.shape
original_value_shape = value.shape
original_bias_shape = bias.shape if has_bias else None
# reshape to 4D shape
query = jnp.reshape(query, (B,) + query.shape[-3:])
key = jnp.reshape(key, (B,) + key.shape[-3:])
value = jnp.reshape(value, (B,) + key.shape[-3:])
if has_bias:
if has_bias and batch_dims[3] is not None:
bias = jnp.reshape(bias, (B, N, T, S))
if has_padding(mask_type):
q_seqlen = jnp.reshape(q_seqlen, (B, ))
@@ -681,15 +686,14 @@ def _dot_product_attention_bwd_batcher(
mask_type=mask_type, layout=layout,
)

grad_query, grad_key, grad_value = grads[:3]
# reshape to original shape
grad_query = jnp.reshape(grad_query, query.shape)
grad_key = jnp.reshape(grad_key, key.shape)
grad_value = jnp.reshape(grad_value, value.shape)
grads[0] = jnp.reshape(grads[0], original_query_shape)
grads[1] = jnp.reshape(grads[1], original_key_shape)
grads[2] = jnp.reshape(grads[2], original_value_shape)
if has_dbias:
grad_bias = grads[3]
grad_bias = jnp.reshape(grad_bias, bias.shape)
return grads + (grad_bias,), out_bdims + (query_bdim,)
assert has_bias
grads[3] = jnp.reshape(grads[3], original_bias_shape)
out_bdims += (batch_dims[3],)
return grads, out_bdims

# custom partitioning
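
Note: the batcher changes above follow one pattern: collapse every vmapped leading batch dimension into a single B, run the rank-4 fused kernel, and restore the original shape afterwards. A minimal sketch of that pattern, with a hypothetical kernel_4d standing in for the cuDNN call (an illustration, not the PR's code):

    import math
    import jax.numpy as jnp

    def batched_call(kernel_4d, query):
        # kernel_4d: hypothetical stand-in for the rank-4-only fused attention call.
        original_shape = query.shape             # e.g. (*Bs, T, N, H) after one or more vmaps
        B = math.prod(query.shape[:-3])          # collapse all leading batch dims into one
        out = kernel_4d(jnp.reshape(query, (B,) + query.shape[-3:]))
        return jnp.reshape(out, original_shape)  # hand the caller back its vmapped shape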
48 changes: 29 additions & 19 deletions jax/_src/nn/functions.py
@@ -853,9 +853,9 @@ def dot_product_attention(
query: ArrayLike,
key: ArrayLike,
value: ArrayLike,
*,
Collaborator: Why this change?

Contributor Author: It seems you can only specify the batch dims for the positional args in jax.vmap. For keyword arguments, vmap will always use the leading dim for the batch.
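
A minimal sketch of that point, using a toy function rather than the attention API: in_axes only covers positional arguments, so an argument that should stay unbatched (None) has to be positional.

    import jax
    import jax.numpy as jnp

    def scaled(x, scale):            # toy example; both arguments positional
        return x * scale

    xs = jnp.ones((8, 3))            # batched along axis 0
    scale = jnp.arange(3.0)          # shared across the batch, not batched

    # in_axes can mark the positional scale as unmapped (None). If scale were
    # keyword-only, vmap would always map it over its leading axis instead.
    out = jax.vmap(scaled, in_axes=(0, None))(xs, scale)
    print(out.shape)                 # (8, 3)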

bias: ArrayLike | None = None,
mask: ArrayLike | None = None,
*,
scale: float | None = None,
is_causal: bool = False,
implementation: Literal['xla', 'cudnn'] | None = None) -> Array:
@@ -882,20 +882,20 @@ def dot_product_attention(
G = number of groups, which equals to N // K

Args:
query: query array; shape :code:`(BTNH)`
key: key array: shape :code:`(BSKH)`. When `K` equals `N`, multi-headed
attention (MHA: https://arxiv.org/abs/1706.03762) is performed. Otherwise,
grouped query attention (GQA: https://arxiv.org/abs/2305.13245) is performed
if `N` is a multiple of `K`, and multi-query attention (MQA:
https://arxiv.org/abs/1911.02150) is performed if `K == 1` (a special case
of GQA).
query: query array; shape :code:`(BTNH|TNH)`
key: key array: shape :code:`(BSKH|SKH)`. When `K` equals `N`, multi-headed
attention (MHA https://arxiv.org/abs/1706.03762) is performed. Otherwise,
grouped query attention (GQA https://arxiv.org/abs/2305.13245) is
performed if `N` is a multiple of `K`, and multi-query attention (MQA
https://arxiv.org/abs/1911.02150) is performed if `K == 1` (a special case
of GQA).
value: value array, should have the same shape as the `key` array.
bias: optional, bias array to be added to logits; The shape must be 4D and
be broadcastable to :code:`(BNTS)`.
be broadcastable to :code:`(BNTS|NTS)`.
mask: optional, mask array used to filter out logits. It is a boolean mask
where `True` indicates the element should take part in attention. For an
additive mask, users should pass it to `bias`. The shape must be 4D and be
broadcastable to :code:`(BNTS)`.
broadcastable to :code:`(BNTS|NTS)`.
scale: scale for the logits. If None, the scale will be set to 1 divided by
the square root of query's head dimension (i.e. H).
is_causal: If true, causal attention will be applied. Note, some
@@ -912,19 +912,27 @@ def dot_product_attention(
Returns:
An array of the attention output with the same shape as :code:`query`.
"""
output_shape = jnp.asarray(query).shape
def _ensure_4d(t):
t = jnp.asarray(t)
dims_to_add = 4 - t.ndim
if dims_to_add > 0:
return jnp.expand_dims(t, axis=tuple(range(dims_to_add)))
return t

query = _ensure_4d(query)
key = _ensure_4d(key)
value = _ensure_4d(value)
bias = _ensure_4d(bias) if bias is not None else None
mask = _ensure_4d(mask) if mask is not None else None

def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None:
if t.ndim != len(shape):
raise ValueError(f"{name} ndim should be {len(shape)}, but got {t.ndim}")
for i in range(t.ndim):
if shape[i] != -1 and t.shape[i] != shape[i]:
raise ValueError(f"{name} shape should be {shape}: but got {t.shape}")

query = jnp.asarray(query)
key = jnp.asarray(key)
value = jnp.asarray(value)
bias = bias if bias is None else jnp.asarray(bias)
mask = mask if mask is None else jnp.asarray(mask)

B, S, K, H = key.shape
_check_has_shape(value, [B, S, K, H], 'value')
_check_has_shape(query, [B, -1, -1, H], 'query')
@@ -944,19 +952,21 @@ def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None:

match implementation:
case 'xla':
return _dot_product_attention_xla(
out = _dot_product_attention_xla(
query, key, value, bias, mask, is_causal=is_causal, scale=scale_val,
)
case 'cudnn':
mask_type = MaskType.CAUSAL if is_causal else MaskType.NO_MASK
return cudnn_dot_product_attention(
out = cudnn_dot_product_attention(
query, key, value, bias, mask, scale=scale_val, mask_type=mask_type
)
case None:
# TODO(kaixih@nvidia) Defaults to XLA for now. Will automatically select
# best backend.
return _dot_product_attention_xla(
out = _dot_product_attention_xla(
query, key, value, bias, mask, is_causal=is_causal, scale=scale_val,
)
case _:
raise ValueError(f"Unsupported implementation option: {implementation}")

return jnp.reshape(out, output_shape)
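
As a usage sketch of what the change above enables (rank-3 per-example inputs under jax.vmap, with an unbatched bias passed positionally); shapes and values here are illustrative, not taken from the PR:

    import jax
    import jax.numpy as jnp
    from jax import random

    B, T, S, N, H = 2, 8, 8, 4, 16
    key = random.PRNGKey(0)
    q = random.normal(key, (B, T, N, H), jnp.float16)
    k = random.normal(key, (B, S, N, H), jnp.float16)
    v = random.normal(key, (B, S, N, H), jnp.float16)
    bias = random.normal(key, (1, N, T, S), jnp.float16)

    # Map over the leading batch dim of q/k/v; bias and mask stay unbatched (None),
    # which is why they must be positional (see the review discussion above).
    attn = jax.vmap(jax.nn.dot_product_attention,
                    in_axes=(0, 0, 0, None, None), out_axes=0)
    out = attn(q, k, v, bias, None)   # each inner call sees (T, N, H) / (S, N, H)
    print(out.shape)                  # (2, 8, 4, 16)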
22 changes: 13 additions & 9 deletions tests/nn_test.py
@@ -57,10 +57,11 @@ class NNFunctionsTest(jtu.JaxTestCase):
use_bias=[False, True],
causal_mode=[None, 'is_causal', 'is_mask'],
group_num=[1, 2, 4],
use_vmap=[False, True],
impl=['xla', 'cudnn'],
)
def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode,
group_num, impl):
group_num, use_vmap, impl):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
if impl == 'cudnn' and dtype == jnp.float32:
@@ -84,26 +85,29 @@ def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode,
sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)

if impl == 'cudnn':
lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias=bias, mask=causal_mask)
lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias, causal_mask)
hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
self.assertIn('__cudnn$fmha', hlo)

if use_vmap:
sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)
K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K
V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V
out_ref = sdpa_ref(Q, K_ref, V_ref, bias=bias, mask=causal_mask)
out_ref = sdpa_ref(Q, K_ref, V_ref, bias, causal_mask)

out_ans = sdpa_ans(Q, K, V, bias=bias, mask=causal_mask)
out_ans = sdpa_ans(Q, K, V, bias, causal_mask)
self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)

@parameterized.product(
dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
use_bias=[False, True],
causal_mode=[None, 'is_causal', 'is_mask'],
group_num=[1, 2, 4],
use_vmap=[False, True],
impl=['xla', 'cudnn'],
)
def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode,
group_num, impl):
group_num, use_vmap, impl):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
if impl == 'cudnn' and dtype == jnp.float32:
@@ -127,16 +131,16 @@ def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode,
K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K
V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V
sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
fn_ref = lambda q, k, v, b, m: sdpa_ref(q, k, v, bias=b, mask=m)
_, sdpa_vjp_ref = jax.vjp(fn_ref, Q, K_ref, V_ref, bias, causal_mask)
_, sdpa_vjp_ref = jax.vjp(sdpa_ref, Q, K_ref, V_ref, bias, causal_mask)
dQ_ref, dK_ref, dV_ref, dbias_ref, _ = sdpa_vjp_ref(grad)
if G != 1:
dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3)
dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3)

sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)
fn_ans = lambda q, k, v, b, m: sdpa_ans(q, k, v, bias=b, mask=m)
_, sdpa_vjp_ans = jax.vjp(fn_ans, Q, K, V, bias, causal_mask)
if use_vmap:
sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)
Collaborator: I think your current implementation will fail if vmapped more than once, since it requires either a 3D or a 4D array.

Wdyt about handling the N>4 case by collapsing any extra leading dimensions into B? @sbodenstein does this make sense?

Contributor Author: Why doesn't this work? I have tried the following to mimic the 5D tensor case and it works fine. Or am I missing your point?

    Q = random.normal(keys[0], (B, B, T, N, H), dtype)                          
    K = random.normal(keys[1], (B, B, S, N // G, H), dtype)                     
    V = random.normal(keys[2], (B, B, S, N // G, H), dtype)                     
    if use_bias:                                                                
      bias = random.normal(keys[3], (1, N, T, S), dtype)                        
    else:                                                                       
      bias = None                                                               
                                                                                
    is_causal = causal_mode == 'is_causal'                                      
    causal_mask = _get_causal_mask(T, S) if causal_mode == 'is_mask' else None  
                                                                                
    sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)          
    sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)                                
                                                                                
    if use_vmap:                                                                
      sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)  
      sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)  
    K_ref = (jnp.repeat(K, G, axis=2) if G != 1 else K).reshape(B*B, S, N, H)   
    V_ref = (jnp.repeat(V, G, axis=2) if G != 1 else V).reshape(B*B, S, N, H)   
    Q_ref = Q.reshape(B*B, T, N, H)                                             
    out_ref = sdpa_ref(Q_ref, K_ref, V_ref, bias, causal_mask).reshape(B,B,T,N,H)
                                                                                
    out_ans = sdpa_ans(Q, K, V, bias, causal_mask)                              
    self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)   

Collaborator: Don't we assert that ndim is 3 or 4 now? I would expect this to fail given a 5D input.

Contributor Author: Yes, it will fail if we directly pass in a 5D tensor. The behavior now is to support (1) 4D tensors for those who want to use the batch-aware API, and (2) 3D tensors for those who want to use the API in the context of vmap; if users have a 5D tensor, they need to use vmap as shown above.
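
A standalone sketch of that recipe for rank-5 inputs (nest vmap twice so the innermost call sees the supported rank-3 layout; shapes are illustrative):

    import jax
    import jax.numpy as jnp

    B1, B2, T, S, N, H = 2, 3, 8, 8, 4, 16
    q = jnp.zeros((B1, B2, T, N, H), jnp.float16)
    k = jnp.zeros((B1, B2, S, N, H), jnp.float16)
    v = jnp.zeros((B1, B2, S, N, H), jnp.float16)

    # Each vmap strips one leading batch dim, so the innermost call receives
    # the supported rank-3 (T, N, H) / (S, N, H) inputs.
    attn = jax.nn.dot_product_attention
    attn = jax.vmap(attn, in_axes=(0, 0, 0, None, None), out_axes=0)
    attn = jax.vmap(attn, in_axes=(0, 0, 0, None, None), out_axes=0)
    out = attn(q, k, v, None, None)
    print(out.shape)   # (2, 3, 8, 4, 16)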

_, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, causal_mask)
dQ_ans, dK_ans, dV_ans, dbias_ans, _ = sdpa_vjp_ans(grad)

if impl == 'cudnn':