cuDNN attention segmentation fault #23701
Comments
I just created a PR to fix/work around this long-standing cuDNN limitation with dbias. cuDNN only supports dbias when the batch size is 1. For batch sizes greater than 1, the cuDNN SDPA API returns an all-zero dbias tensor (which isn't ideal, but that's the current behavior). When using vmap, the API only sees the singleton batch dimension before the vmap is applied, causing it to mistakenly set has_dbias to True. This PR resolves the issue by resetting has_dbias to False and returning an all-zero dbias tensor, as in the non-vmap version. To summarize, the behavior is:
Also, @Cjkkkk for comments on the dbias support from cuDNN.
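For anyone following along, here is a minimal sketch of why a shape check made inside a vmapped function can be fooled. It is not the cuDNN code path; `record_shape` and `seen_shapes` are hypothetical names used only for illustration. The point is that under `jax.vmap` the function sees the per-example shape, so a leading dimension of 1 looks like a singleton batch even though the mapped axis adds a real batch of 4.

```python
import jax
import jax.numpy as jnp

seen_shapes = []  # hypothetical log of the shapes observed at trace time


def record_shape(bias):
    # Under vmap, bias.shape excludes the mapped axis, so a check such as
    # "bias.shape[0] == 1" passes here even though the full input is batched.
    seen_shapes.append(bias.shape)
    return bias.sum()


# Four mapped examples, each with a bias of shape (1, N, T, S) = (1, 2, 8, 8).
bias = jnp.zeros((4, 1, 2, 8, 8))
out = jax.vmap(record_shape)(bias)

print(seen_shapes[0])  # per-example shape seen inside the function
print(out.shape)       # one result per mapped example
```

The shape recorded inside the function has a leading dimension of 1, while the output carries the mapped axis of size 4, which matches the mismatch described in the comment above.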
Having d_bias be zeroes when there is a batch dimension is definitely wrong behaviour: it is a silent failure that would be extremely difficult for users to debug when their training curves just don't look right. I think we should fail in this case instead, until cuDNN supports it.
We actually had some internal discussions earlier. I think the dilemma is this: it seems that some models don't require |
I think that this should either work correctly or throw an error. It should definitely not segfault or give incorrect gradients (even with a warning). The latter is just too dangerous for users, who expect JAX APIs to do what they say they will do and who could otherwise waste massive compute on runs with a major bug. Would you agree @hawkinsp?
I do. Wrong outputs aren't ok, because they are the kind of thing that makes people lose trust in a library.
Sure, I will essentially move this logic to the public API so that it throws an error for the cuDNN execution path.
By the way, do you think we should apply this error-throwing behavior to the public API or the cuDNN API? Perhaps it should only be applied to the public API, allowing power users who are certain they don't need d_bias to use the private cuDNN API.
Private APIs ( |
I have pushed a new change to the PR to check the validity of the bias shape. It will throw an error only when the bias shape is invalid and bprop is applied. Also, the original segfault issue is fixed in a separate PR. @sbodenstein can you take a look?
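A rough sketch of what such a check could look like, for readers who don't want to open the PR. This is not the code from the PR: `check_bias_for_backward` is a hypothetical helper, and it assumes a bias laid out as (B, N, T, S) and that "invalid" means a batch dimension other than 1, based on the cuDNN limitation described earlier in this thread.

```python
from typing import Sequence


def check_bias_for_backward(bias_shape: Sequence[int], needs_grad: bool) -> None:
    """Hypothetical check: reject a bias gradient request for a batched bias.

    Assumes bias_shape is (B, N, T, S). The forward pass is always allowed;
    only the combination "batched bias + backward pass" raises.
    """
    if len(bias_shape) != 4:
        raise ValueError(f"expected a 4D bias shape, got {tuple(bias_shape)}")
    batch = bias_shape[0]
    if needs_grad and batch != 1:
        raise NotImplementedError(
            "bias gradient is only supported when the bias batch dimension "
            f"is 1, got batch size {batch}"
        )


# Forward-only use with a batched bias passes the check.
check_bias_for_backward((4, 2, 8, 8), needs_grad=False)
# Backward use with a singleton batch also passes.
check_bias_for_backward((1, 2, 8, 8), needs_grad=True)
```

Raising only when a gradient is requested keeps inference and models that never differentiate the bias working, while the silent all-zero gradient case becomes a loud error.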
Description
Run
produces
System info (python version, jaxlib version, accelerator, etc.)