Skip to content

Commit

Permalink
Fix overflow error in GPU batched linear algebra kernels.
Browse files Browse the repository at this point in the history
As reported in #24843, our LU decomposition on GPU hits overflow errors when the batch size approaches int32 max. This was caused by an issue in how we were constructing the batched pointers used by cuBLAS.

PiperOrigin-RevId: 695694648
  • Loading branch information
dfm authored and Google-ML-Automation committed Nov 12, 2024
1 parent 9bb6366 commit 21e98b5
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
11 changes: 7 additions & 4 deletions jaxlib/gpu/make_batch_pointers.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "jaxlib/gpu/make_batch_pointers.h"

#include <algorithm>
#include <cstdint>

#include "jaxlib/gpu/vendor.h"

Expand All @@ -24,17 +25,19 @@ namespace JAX_GPU_NAMESPACE {

namespace {
__global__ void MakeBatchPointersAsyncKernel(char* buffer_in, void** buffer_out,
int batch, int batch_elem_size) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < batch;
int64_t batch,
int64_t batch_elem_size) {
for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < batch;
idx += blockDim.x * gridDim.x) {
buffer_out[idx] = buffer_in + idx * batch_elem_size;
}
}
} // namespace

void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in,
void* buffer_out, int batch, int batch_elem_size) {
const int block_dim = 128;
void* buffer_out, int64_t batch,
int64_t batch_elem_size) {
const std::size_t block_dim = 128;
const std::size_t grid_dim =
std::min<std::size_t>(1024, (batch + block_dim - 1) / block_dim);
MakeBatchPointersAsyncKernel<<<grid_dim, block_dim, 0, stream>>>(
Expand Down
5 changes: 4 additions & 1 deletion jaxlib/gpu/make_batch_pointers.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@ limitations under the License.
#ifndef JAXLIB_GPU_MAKE_BATCH_POINTERS_H_
#define JAXLIB_GPU_MAKE_BATCH_POINTERS_H_

#include <cstdint>

#include "jaxlib/gpu/vendor.h"

namespace jax {
namespace JAX_GPU_NAMESPACE {

void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in,
void* buffer_out, int batch, int batch_elem_size);
void* buffer_out, int64_t batch,
int64_t batch_elem_size);

} // namespace JAX_GPU_NAMESPACE
} // namespace jax
Expand Down
8 changes: 8 additions & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,14 @@ def testLuBatching(self, shape, dtype):
self.assertAllClose(ls, actual_ls, rtol=5e-6)
self.assertAllClose(us, actual_us)

@jtu.skip_on_devices("cpu", "tpu")
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testBatchedLuOverflow(self):
# see https://github.com/jax-ml/jax/issues/24843
x = self.rng().standard_normal((1500000, 20, 20)).astype(np.float32)
lu, _, _ = lax.linalg.lu(x)
self.assertTrue(jnp.all(lu.std(axis=[1, 2]) > 0.9))

@jtu.skip_on_devices("cpu", "tpu")
@jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument")
Expand Down

0 comments on commit 21e98b5

Please sign in to comment.