cuDNN attention segmentation fault #23701

Open
sbodenstein opened this issue Sep 17, 2024 · 11 comments
Labels
bug (Something isn't working) · NVIDIA GPU (Issues specific to NVIDIA GPUs)

Comments

@sbodenstein
Contributor

Description

Run

import jax
import jax.numpy as jnp

dtype = jnp.bfloat16
batch_dim = 16
x_shape = (batch_dim, 512, 16, 48)
bias_shape = (batch_dim, 16, 512, 512)
mask_shape = (1, 1, 512)

x = jnp.ones(x_shape, dtype=dtype)
bias = jnp.ones(bias_shape, dtype=dtype)
mask = jnp.ones(mask_shape, dtype=jnp.bool_)

def attention(x, bias, mask):
  return jax.nn.dot_product_attention(
      query=x,
      key=x,
      value=x,
      bias=bias,
      mask=mask,
      is_causal=False,
      implementation="cudnn",
  )

attention = jax.vmap(attention, in_axes=(0, 0, None))

@jax.jit
def attention_vjp(x, bias, mask):
  _, f_vjp = jax.vjp(attention, x, bias, mask)
  return f_vjp(x)

attention_vjp(x, bias, mask)

produces

*** SIGSEGV (@0x7f5260c00500), see go/stacktraces#s15 received by PID 9235 (TID 9235) on cpu 115; stack trace: ***
PC: @     0x5617bac25898  (unknown)  memcpy_prefetch
    @     0x5617cb88d7b1       1888  FailureSignalHandler()
    @     0x7f638a212e80  1678969448  (unknown)
    @     0x5617bac25898          8  memcpy_prefetch
    @     0x5617aebc89ee        144  libcudnn_engines_runtime_compiled_static_1b8578937b0ce03a7e4b20267fe9d0337678c78c
    @     0x5617aebc95b6        416  libcudnn_engines_runtime_compiled_static_9c37820fea3a2d0a85881f91ee33655956ce9c93
    @     0x5617aee7e795      66928  cudnn::backend::execute()
    @     0x5617aee7fc3f      65632  cudnnBackendExecute
    @     0x5617c65ffe99        448  cudnn_frontend::detail::execute()
    @     0x5617c713e1b8        320  cudnn_frontend::ICudnn::execute_cudnn_plan_with_uid()
    @     0x5617c70926c3        224  cudnn_frontend::graph::Graph::execute()
    @     0x5617c70917fd        496  stream_executor::gpu::CudnnGraph::Execute()
    @     0x5617c83de8bb        192  stream_executor::gpu::GpuCommandBuffer::Trace()
    @     0x5617c191dda7        112  stream_executor::TraceCommandBufferFactory::Create()
    @     0x5617c18f3224        320  xla::gpu::TracedCommandBuffer::GetOrTraceCommandBuffer()
    @     0x5617c18f367e        160  xla::gpu::TracedCommandBufferCmd::AddTracedCommandBuffer()
    @     0x5617c18fc0d8        192  xla::gpu::CuDnnCmd::Record()
    @     0x5617c18f2365        160  xla::gpu::CommandBufferCmdSequence::Record()
    @     0x5617c18eeab2        480  xla::gpu::CommandBufferThunk::Initialize()
    @     0x5617c19242ec         48  xla::gpu::SequentialThunk::Initialize()
    @     0x5617c18e203b       1216  xla::gpu::(anonymous namespace)::ExecuteThunks()
    @     0x5617c18ded4c       2288  xla::gpu::GpuExecutable::ExecuteAsyncOnStreamImpl()
    @     0x5617c18ddc4f         80  xla::gpu::GpuExecutable::ExecuteAsyncOnStream()
    @     0x5617c76260db       1056  xla::Executable::ExecuteAsyncOnStreamWrapper()
    @     0x5617c0015714       2640  xla::LocalExecutable::RunAsync()
    @     0x5617c0015cb5        288  xla::LocalExecutable::RunAsync()
    @     0x5617bffd5a7a       2640  xla::PjRtStreamExecutorLoadedExecutable::EnqueueExecution()
    @     0x5617bffd7e08       3344  xla::PjRtStreamExecutorLoadedExecutable::ExecuteHelper()
    @     0x5617bffda79a        512  xla::PjRtStreamExecutorLoadedExecutable::Execute()
    @     0x5617c014ad4c        960  xla::ifrt::PjRtLoadedExecutable::Execute()
    @     0x5617bff8e709        832  xla::(anonymous namespace)::ExecuteShardedOnLocalDevicesInternal<>()
    @     0x5617bff90524         96  xla::PyLoadedExecutable::ExecuteSharded()
    @     0x7f6379b6c205        256  xla::ValueOrThrowWrapper<>::operator()()
    @     0x7f6379b6bf6d        288  nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @     0x5617c01e8ee5        368  nanobind::detail::nb_func_vectorcall_complex()
    @     0x5617b93b0fb5         80  PyObject_Vectorcall
    @     0x5617b93a7043        400  _PyEval_EvalFrameDefault
    @     0x5617b93a2e4d        160  _PyEval_Vector
    @     0x5617b93a97df        400  _PyEval_EvalFrameDefault
    @     0x5617b93a2e4d        160  _PyEval_Vector
    @     0x5617cbc0d70e         80  _PyObject_FastCallDictTstate
    @     0x5617cbc0de3b        128  _PyObject_Call_Prepend
    @     0x5617cbc7594f         64  slot_tp_call
    @     0x5617b93b1b96         80  PyObject_Call
    @     0x5617b93a97df        400  _PyEval_EvalFrameDefault
    @     0x5617b93a2e4d        160  _PyEval_Vector
    @     0x5617b93b1c5f         80  PyObject_Call
    @     0x5617b93a97df        400  _PyEval_EvalFrameDefault
    @     0x5617b93a2e4d        160  _PyEval_Vector
    @     0x5617b93b0fb5         80  PyObject_Vectorcall
    @     0x5617bcdb753d       1824  jax::(anonymous namespace)::PjitFunction::Call()
    @     0x5617bcdb66b4        224  PjitFunction_tp_vectorcall
    @     0x5617b93a97df        400  _PyEval_EvalFrameDefault
    @     0x5617b93a2e4d        160  _PyEval_Vector
    @     0x5617b93b1c5f         80  PyObject_Call
    @     0x5617b93a97df        400  _PyEval_EvalFrameDefault
    @     0x5617b93a2e4d        160  _PyEval_Vector
    @     0x5617cbc0fb1e        128  method_vectorcall
    @     0x5617b93b1c5f         80  PyObject_Call
    @     0x5617b93a97df        400  _PyEval_EvalFrameDefault
    @     0x5617b93a2e4d        160  _PyEval_Vector
    @     0x5617b93a97df        400  _PyEval_EvalFrameDefault
    @     0x5617b93a2e4d        160  _PyEval_Vector
    @     0x5617b93a97df        400  _PyEval_EvalFrameDefault
    @     0x5617b93a2e4d        160  _PyEval_Vector
    @     0x5617b93b0fb5         80  PyObject_Vectorcall
    @     0x5617bcdb753d       1824  jax::(anonymous namespace)::PjitFunction::Call()
    @     0x5617bcdb66b4        224  PjitFunction_tp_vectorcall
    @     0x5617b93b0fb5         80  PyObject_Vectorcall
    @     0x5617b93a7043        400  _PyEval_EvalFrameDefault
    @     0x5617b93a2e4d        160  _PyEval_Vector
    @     0x5617b93b0fb5         80  PyObject_Vectorcall
    @     0x5617c01de634         80  nanobind::detail::obj_vectorcall()
    @     0x7f6372066d63         48  benchmark::internal::LambdaBenchmark<>::Run()
    @     0x5617bad98de1         96  benchmark::internal::BenchmarkInstance::Run()
    @     0x5617bada059a        256  benchmark::internal::(anonymous namespace)::RunInThread()
    @     0x5617bada01c8        112  benchmark::internal::BenchmarkRunner::DoNIterations()
    @     0x5617bada1401        672  benchmark::internal::BenchmarkRunner::DoOneRepetition()
    @     0x5617bad935c5       6288  benchmark::RunSpecifiedBenchmarks()
    @     0x5617bad928a6         48  benchmark::RunSpecifiedBenchmarks()
    @     0x7f6372066da9         16  nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @     0x5617c01e93d4        192  nanobind::detail::nb_func_vectorcall_simple()
    @     0x5617b93b0fb5         80  PyObject_Vectorcall
    @     0x5617b93a7043        400  _PyEval_EvalFrameDefault
    @     0x5617b93a2e4d        160  _PyEval_Vector
    @     0x5617cbd19aea        112  PyEval_EvalCode
    @     0x5617cbd16089        112  builtin_exec
    @     0x5617cbc55a59         64  cfunction_vectorcall_FASTCALL_KEYWORDS
    @     0x5617b93b0fb5         80  PyObject_Vectorcall
    @     0x5617b93a7043        400  _PyEval_EvalFrameDefault
    @     0x5617b93a2e4d        160  _PyEval_Vector
    @     0x5617cbc0dff1         96  _PyObject_CallFunctionVa
    @     0x5617cbc0e3ec        240  PyObject_CallMethod
    @     0x5617cb8e9d52        832  devtools::python_launcher::Launcher_Main()
    @     0x7f638a0663d4        192  __libc_start_main
    @     0x5617abcd7d5a  (unknown)  _start
E0917 11:28:52.340728    9235 coredump_hook.cc:280] RAW: Remote crash gathering disabled for unit test. Enable that with --remote_coredump_enabled_for_unit_test flag
Fatal Python error: Segmentation fault

Current thread 0x00007f638a000000 (most recent call first):
  File "/build/work/80eac4b44c5f7f6a8112b8791876b518bce1/google3/runfiles/google3/third_party/py/jax/_src/interpreters/pxla.py", line 1285 in __call__
  File "/build/work/80eac4b44c5f7f6a8112b8791876b518bce1/google3/runfiles/google3/third_party/py/jax/_src/profiler.py", line 333 in wrapper
  File "/build/work/80eac4b44c5f7f6a8112b8791876b518bce1/google3/runfiles/google3/third_party/py/jax/_src/pjit.py", line 1648 in _pjit_call_impl_python
  File "/build/work/80eac4b44c5f7f6a8112b8791876b518bce1/google3/runfiles/google3/third_party/py/jax/_src/pjit.py", line 1694 in call_impl_cache_miss
  File "/build/work/80eac4b44c5f7f6a8112b8791876b518bce1/google3/runfiles/google3/third_party/py/jax/_src/pjit.py", line 1712 in _pjit_call_impl
  File "/build/work/80eac4b44c5f7f6a8112b8791876b518bce1/google3/runfiles/google3/third_party/py/jax/_src/core.py", line 949 in process_primitive
  File "/build/work/80eac4b44c5f7f6a8112b8791876b518bce1/google3/runfiles/google3/third_party/py/jax/_src/core.py", line 443 in bind_with_trace
  File "/build/work/80eac4b44c5f7f6a8112b8791876b518bce1/google3/runfiles/google3/third_party/py/jax/_src/core.py", line 2777 in bind
  File "/build/work/80eac4b44c5f7f6a8112b8791876b518bce1/google3/runfiles/google3/third_party/py/jax/_src/pjit.py", line 189 in _python_pjit_helper
  File "/build/work/80eac4b44c5f7f6a8112b8791876b518bce1/google3/runfiles/google3/third_party/py/jax/_src/pjit.py", line 331 in cache_miss
  File "/build/work/80eac4b44c5f7f6a8112b8791876b518bce1/google3/runfiles/google3/third_party/py/jax/_src/traceback_util.py", line 180 in reraise_with_filtered_traceback
  File "/build/work/80eac4b44c5f7f6a8112b8791876b518bce1/google3/runfiles/google3/experimental/users/sbodenstein/benchmarks/microbenchmarks/cudnn_attention.py", line 59 in matmul_benchmark
  File "/build/work/80eac4b44c5f7f6a8112b8791876b518bce1/google3/runfiles/google3/third_party/benchmark/bindings/python/google_benchmark/__init__.py", line 130 in _run_benchmarks
  File "/build/work/80eac4b44c5f7f6a8112b8791876b518bce1/google3/runfiles/google3/third_party/py/absl/app.py", line 404 in _run_main
  File "/build/work/80eac4b44c5f7f6a8112b8791876b518bce1/google3/runfiles/google3/third_party/py/absl/app.py", line 484 in run
  File "/build/work/80eac4b44c5f7f6a8112b8791876b518bce1/google3/runfiles/google3/third_party/benchmark/bindings/python/google_benchmark/__init__.py", line 134 in main
  File "/build/work/80eac4b44c5f7f6a8112b8791876b518bce1/google3/runfiles/google3/experimental/users/sbodenstein/benchmarks/microbenchmarks/cudnn_attention.py", line 68 in <module>
  File "<embedded module '_launcher'>", line 35 in _run_code_in_main
  File "<embedded module '_launcher'>", line 153 in run_filename_as_main

Extension modules: jax.jaxlib.cpu_feature_guard, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, zstandard.backend_c, google3.base.python.clif.cpp_flag, google3.base.python.sysinfo, google3.base.python.user_name, google3.base.python.clif.logging, google3.third_party.pybind11_abseil.ok_status_singleton, google3.base.python.clif.googleinit, google3.base.python.absl_faulthandler, google3.net.util.python._bnsutil, google3.third_party.upb.python._message, google3.monitoring.streamz.exporter._monitoring_streamz_exporter_pywrapmetricops, google3.util.task.python._util_task_python_pywrap__task, google3.base.python.clif.tracecontext_local, google3.perftools.tracing.public.python._tracecontext_util, google3.perftools.tracing.public.python.trace_url_builder, google3.learning.debug.error_log.python.logwriters_registered_check, google3.logs.writer.python._logs_writer_python_pywraplogwriter, google3.tech.env.client.python.envelope_client, google3.learning.deepmind.analytics.lamon.python.lamon, google3.learning.deepmind.python.gpu_monitor.clif.gpu_monitor (total: 34)
E0917 11:28:52.631628    9235 process_state.cc:805] RAW: Raising signal 11 with default behavior
/build/work/80eac4b44c5f7f6a8112b8791876b518bce1/google3/tools/test/remote-runtest.sh: line 263:  9235 Segmentation fault      "$@"
-- 2024-09-17 11:28:52 PDT Forge runner: Test killed by SIGSEGV while running on glcpl1.prod.google.com

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.34
jaxlib: 0.4.34
numpy:  1.26.3
python: 3.11.8 
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1

$ nvidia-smi
Tue Sep 17 11:31:20 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:08:00.0 Off |                    0 |
| N/A   30C    P0             106W / 700W |  61309MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
sbodenstein added the bug (Something isn't working) label on Sep 17, 2024
@sbodenstein
Contributor Author

@kaixih

@kaixih
Contributor

kaixih commented Sep 18, 2024

I just created a PR to fix/work around this long-standing cuDNN limitation with dbias.

cuDNN only supports dbias when the batch size is 1 (here). For batch sizes greater than 1, the cuDNN SDPA API returns an all-zero dbias tensor (not ideal, but that is the current behavior). When using vmap, the API only sees the per-example bias before vmap is applied, so it detects a singleton batch and mistakenly sets has_dbias to True, which leads to this segfault.

This PR resolves the issue by resetting has_dbias to False and returning an all-zero dbias tensor, as in the non-vmap case.

To summarize, the behavior is:

no vmap:
  bias (1, ...) => has_dbias=True  => OK
  bias (B, ...) => has_dbias=False => OK, but dbias is all-zero

vmap:
  bias (1, ...) => has_dbias=True  => OK
  bias (B, ...) => has_dbias=True  => segfault (before the PR), which the PR changes to:
  bias (B, ...) => has_dbias=False => OK, but dbias is all-zero
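
For concreteness, here is a minimal sketch of the batch-size rule above (the helper name cudnn_supports_dbias is mine, not an actual JAX or cuDNN API; illustrative only):

def cudnn_supports_dbias(bias_shape, vmap_batch_size=1):
  # cuDNN's SDPA backward pass can only produce a real dbias when the
  # effective bias batch size is 1. Under vmap the primitive only sees the
  # per-example bias shape, so the vmapped batch size has to be accounted for
  # explicitly; before the fix it was not, and has_dbias was wrongly True.
  leading = bias_shape[0] if len(bias_shape) == 4 else 1
  return vmap_batch_size * leading == 1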

@kaixih
Contributor

kaixih commented Sep 18, 2024

Also, cc @Cjkkkk for comments on cuDNN's dbias support.

@sbodenstein
Contributor Author

Having d_bias be zeros when there is a batch dimension is definitely wrong behaviour: it is a silent failure that would be extremely difficult for users to debug when their training curves just don't look right. I think we should fail in this case until cuDNN supports it.
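
To make the silent failure concrete, here is a rough check (a sketch on the reporter's shapes, not part of any fix) comparing the bias gradient from the cuDNN path with the XLA path; on affected versions the cuDNN gradient comes back all zero while the XLA one does not:

import jax
import jax.numpy as jnp

kq, kb = jax.random.split(jax.random.key(0))
x = jax.random.normal(kq, (2, 512, 16, 48), dtype=jnp.bfloat16)
bias = jax.random.normal(kb, (2, 16, 512, 512), dtype=jnp.bfloat16)

def bias_grad(impl):
  def loss(b):
    out = jax.nn.dot_product_attention(x, x, x, bias=b, implementation=impl)
    return out.sum()
  return jax.grad(loss)(bias)

# Expect a nonzero value for "xla" and (incorrectly) zero for "cudnn":
print(jnp.abs(bias_grad("xla")).sum(), jnp.abs(bias_grad("cudnn")).sum())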

@kaixih
Contributor

kaixih commented Sep 19, 2024

We actually had some internal discussions about this earlier. I think the dilemma is this: some models don't require d_bias but still want to benefit from cuDNN flash attention, and if we simply throw an error whenever the bias has a batch size other than 1, they might complain. Ideally, if we could detect whether d_bias is needed, we could decide whether to error out or proceed, but no such mechanism exists in JAX. Instead, we currently return a silent all-zero dbias when it isn't supported (which is very bad, I know). Do you think issuing a warning would help? Or should we just throw an error whenever cuDNN is used with a bias whose batch size is larger than 1?

@sbodenstein
Contributor Author

sbodenstein commented Sep 19, 2024

I think this should either work correctly or throw an error. It should definitely not segfault or give incorrect gradients (even with a warning). The latter is just too dangerous: users expect JAX APIs to do what they say, and could otherwise waste massive amounts of compute on runs with a major bug. Would you agree, @hawkinsp?

@hawkinsp
Collaborator

I do. Wrong outputs aren't ok, because they are the kind of thing that makes people lose trust in a library.

@kaixih
Contributor

kaixih commented Sep 19, 2024

Sure, I will essentially move this logic to the public API and throw an error on the cuDNN execution path.

@kaixih
Contributor

kaixih commented Sep 19, 2024

By the way, do you think we should apply this error-throwing behavior to the public API or the cuDNN API? Perhaps it should only be applied to the public API, allowing power users who are certain they don't need d_bias to use the private cuDNN API.

@hawkinsp
Collaborator

Private APIs (jax._src, etc.) can do whatever you like. If a private API breaks, you get to keep both pieces.

@kaixih
Contributor

kaixih commented Sep 20, 2024

I have pushed a new change to the PR that checks the validity of the bias shape. It will throw an error only when the bias is invalid and bprop is applied. Also, the original segfault issue is fixed in a separate PR. @sbodenstein, can you take a look?
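
For anyone following along, here is a rough sketch of the kind of guard described above (this is not the actual PR code; check_cudnn_bias_for_bprop is a hypothetical name), raising only when a batched bias would need a gradient on the backward pass:

def check_cudnn_bias_for_bprop(bias_shape):
  # Hypothetical guard, illustration only: cuDNN cannot compute dbias when the
  # bias batch size is greater than 1, so raise instead of silently returning
  # an all-zero gradient.
  bias_batch = bias_shape[0] if len(bias_shape) == 4 else 1
  if bias_batch > 1:
    raise ValueError(
        "cudnn dot_product_attention cannot compute a bias gradient for a "
        f"bias with batch size {bias_batch} > 1; use implementation='xla' or "
        "a bias with a singleton batch dimension.")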
