
[NVIDIA] Support vmap usage of jax.nn.dot_product_attention #22830

Merged
merged 2 commits on Aug 7, 2024

Conversation

kaixih (Contributor) commented Aug 1, 2024

To address request #22760, this PR supports non-batched inputs.

@@ -853,9 +853,9 @@ def dot_product_attention(
query: ArrayLike,
key: ArrayLike,
value: ArrayLike,
*,
Collaborator:

Why this change?

Contributor Author:

It seems you can only specify the batch dims for the positional args in jax.vmap. For keyword arguments, vmap always uses the leading dim as the batch.
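A minimal sketch of this constraint (the function and shapes below are illustrative, not the PR's code): `in_axes` only applies to positional arguments, so an operand that should not be batched must be passed positionally and marked with `None`.

```python
import jax
import jax.numpy as jnp

def attn_like(q, bias):
    # Stand-in for an attention-style op: q is batched, bias is shared.
    return q + bias

qs = jnp.ones((4, 2, 3))   # batched over axis 0
bias = jnp.ones((2, 3))    # shared across the whole batch

# in_axes controls batching for positional args only; None marks
# bias as unbatched. A keyword arg could not be annotated this way.
out = jax.vmap(attn_like, in_axes=(0, None))(qs, bias)
assert out.shape == (4, 2, 3)
```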

@@ -618,11 +618,12 @@ def _dot_product_attention_fwd_batcher(
*_, S, _, _ = key.shape
B = math.prod(Bs)
has_bias, _ = variadic_args
original_shape = query.shape
Collaborator:

How about output_shape since you only use it for output?

Contributor Author:

Done.
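For context, the batcher's `B = math.prod(Bs)` collapses all leading batch dims into one. A sketch of that step (shapes and names here are assumptions, not the PR's exact code):

```python
import math
import jax.numpy as jnp

# All leading batch dims Bs fold into a single dim B = prod(Bs),
# leaving a canonical 4D (B, T, N, H) layout for the kernel.
query = jnp.ones((2, 3, 5, 2, 4))   # (B1, B2, T, N, H)
*Bs, T, N, H = query.shape
B = math.prod(Bs)
flat = query.reshape((B, T, N, H))
assert flat.shape == (6, 5, 2, 4)
```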

  if t is None:
    return t
  t = jnp.asarray(t)
  return t[None, ...] if t.ndim == 3 else t
Collaborator:

Does it make sense to assert that t.ndim is 4 if it's not 3?

Contributor Author:

Done.

@@ -912,19 +912,25 @@ def dot_product_attention(
Returns:
An array of the attention output with the same shape as :code:`query`.
"""
original_shape = jnp.asarray(query).shape
def _preprocess_array(t):
if t is None:
Collaborator:

Nit: I would personally move this out, since only bias and mask can be None.

Contributor Author:

Done.
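Moving the None check to the call site would look roughly like this (a sketch under the assumption that only bias and mask can be None; names are illustrative):

```python
import jax.numpy as jnp

def _ensure_4d(t):
    # Helper no longer needs to handle None itself.
    t = jnp.asarray(t)
    return t[None, ...] if t.ndim == 3 else t

# Only bias and mask may be None, so the check lives at the call site.
bias = None
mask = jnp.ones((5, 2, 4))
bias = _ensure_4d(bias) if bias is not None else None
mask = _ensure_4d(mask) if mask is not None else None
assert bias is None
assert mask.shape == (1, 5, 2, 4)
```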

@@ -912,19 +912,25 @@ def dot_product_attention(
Returns:
An array of the attention output with the same shape as :code:`query`.
"""
original_shape = jnp.asarray(query).shape
def _preprocess_array(t):
Collaborator:

Nit: how about _ensure_4d?

Contributor Author:

Done.

query, key, value, bias, mask, is_causal=is_causal, scale=scale_val,
)
case _:
raise ValueError(f"Unsupported implementation option: {implementation}")

return jnp.reshape(out, original_shape)
Collaborator:

I guess you really just squeeze here, so you can do that instead of reshaping?

Contributor Author:

I use reshape instead of squeeze because I want to make sure that if users make a batch-aware call with (1, T, N, H) inputs, they still get results of the same shape rather than the squeezed shape (T, N, H).
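The difference can be seen with illustrative shapes (a batch-aware call with B == 1):

```python
import jax.numpy as jnp

out = jnp.ones((1, 5, 2, 4))        # (B, T, N, H) with B == 1
original_shape = out.shape

# squeeze would silently drop the size-1 batch dim the caller passed in;
# reshaping back to original_shape preserves it exactly.
assert jnp.squeeze(out).shape == (5, 2, 4)
assert jnp.reshape(out, original_shape).shape == (1, 5, 2, 4)
```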

fn_ans = lambda q, k, v, b, m: sdpa_ans(q, k, v, bias=b, mask=m)
_, sdpa_vjp_ans = jax.vjp(fn_ans, Q, K, V, bias, causal_mask)
if use_vmap:
sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)
Collaborator:

I think your current implementation will fail if vmapped more than once, since it requires either a 3D or a 4D array.

Wdyt about handling the N>4 case by collapsing any extra leading dimensions into B? @sbodenstein does this make sense?

Contributor Author:

Why doesn't this work? I tried the following to mimic a 5D tensor and it works fine. Or am I missing your point?

    Q = random.normal(keys[0], (B, B, T, N, H), dtype)                          
    K = random.normal(keys[1], (B, B, S, N // G, H), dtype)                     
    V = random.normal(keys[2], (B, B, S, N // G, H), dtype)                     
    if use_bias:                                                                
      bias = random.normal(keys[3], (1, N, T, S), dtype)                        
    else:                                                                       
      bias = None                                                               
                                                                                
    is_causal = causal_mode == 'is_causal'                                      
    causal_mask = _get_causal_mask(T, S) if causal_mode == 'is_mask' else None  
                                                                                
    sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)          
    sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)                                
                                                                                
    if use_vmap:                                                                
      sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)  
      sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)  
    K_ref = (jnp.repeat(K, G, axis=2) if G != 1 else K).reshape(B*B, S, N, H)   
    V_ref = (jnp.repeat(V, G, axis=2) if G != 1 else V).reshape(B*B, S, N, H)   
    Q_ref = Q.reshape(B*B, T, N, H)                                             
    out_ref = sdpa_ref(Q_ref, K_ref, V_ref, bias, causal_mask).reshape(B,B,T,N,H)
                                                                                
    out_ans = sdpa_ans(Q, K, V, bias, causal_mask)                              
    self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)   

Collaborator:

Don't we assert that ndim is 3 or 4 now? I would expect this to fail given a 5D input.

Contributor Author:

Yes, it will fail if we directly pass in a 5D tensor. The current behavior supports (1) 4D tensors for those who want to use the batch-aware API, and (2) 3D tensors for those who want to use the API in the context of vmap. That means users with a 5D tensor need to use nested vmap, as shown above.
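The nested-vmap pattern can be reduced to a minimal sketch (f3d is a stand-in, not the real attention function): a function restricted to 3D per-example slices consumes a 5D tensor through two vmap levels, each stripping one leading batch axis.

```python
import jax
import jax.numpy as jnp

def f3d(q):
    # Stand-in for the 3D (vmap-context) entry point of the API.
    assert q.ndim == 3              # (T, N, H) per-example slice
    return q * 2.0

# One vmap level per extra batch dim: the outer vmap strips B1,
# the inner strips B2, so f3d only ever sees a 3D slice.
q5 = jnp.ones((2, 3, 5, 2, 4))      # (B1, B2, T, N, H)
out = jax.vmap(jax.vmap(f3d))(q5)
assert out.shape == q5.shape
```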

kaixih (Contributor Author) commented Aug 7, 2024

Gentle ping @superbobry

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Aug 7, 2024
@copybara-service copybara-service bot merged commit cce7250 into jax-ml:main Aug 7, 2024
8 checks passed
@jakevdp jakevdp mentioned this pull request Aug 7, 2024