Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion jax/_src/pallas/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,12 @@ def dot(a, b, trans_a: bool = False, trans_b: bool = False,
precision = lax.Precision.HIGH if allow_tf32 else lax.Precision.HIGHEST

dtype = jnp.promote_types(_handle_small(a.dtype), _handle_small(b.dtype))
out_dtype = jnp.int32 if jnp.issubdtype(dtype, jnp.integer) else jnp.float32
if jnp.issubdtype(dtype, jnp.integer):
out_dtype = jnp.int32
elif dtype == jnp.float64:
out_dtype = jnp.float64
else:
out_dtype = jnp.float32
return lax.dot_general(
a,
b,
Expand Down
32 changes: 26 additions & 6 deletions jax/_src/pallas/triton/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2305,26 +2305,46 @@ def _dot_general_lowering(
input_precision = None

acc_dtype = out_aval.dtype
if acc_dtype != jnp.int32 and acc_dtype != jnp.float16:
if acc_dtype not in (jnp.int32, jnp.float16, jnp.float64):
acc_dtype = jnp.float32
else:
raise NotImplementedError(f"Unsupported dot precision: {precision}.")

a_type = ir.RankedTensorType(a.type)
b_type = ir.RankedTensorType(b.type)
if len(a_type.shape) != len(b_type.shape) != 2:
if len(a_type.shape) != 2 or len(b_type.shape) != 2:
raise ValueError("a and b must be 2D, but got:"
f" {a_type.shape} and {b_type.shape}")
if min(*b_type.shape) < 16:
raise ValueError("all dimensions of b must be >= 16 ")

m, k = a_type.shape
_, n = b_type.shape
if a_type.element_type == ir.F64Type.get():
# Triton's MMAv2 fp64 path uses the m8n8k4 PTX instruction but aggregates
# it with NumRegisters={m:2, n:1, k:4}, producing an effective m16n8k16
# per-warp tile. Blocks smaller than these minimums cause repM/repN/repK
# to round to zero, corrupting the ValueTable and segfaulting the compiler.
# M >= 16 (2 × instrM=8)
# N >= 8 (1 × instrN=8)
# K >= 16 (4 × instrK=4)
errors = []
if m < 16:
errors.append(f"M={m} < 16")
if n < 8:
errors.append(f"N={n} < 8")
if k < 16:
errors.append(f"K={k} < 16")
if errors:
raise ValueError(
f"float64 dot requires M>=16, N>=8, K>=16 per warp tile "
f"(Triton MMAv2 m8n8k4 layout); got {', '.join(errors)}"
)

if a_type.element_type != b_type.element_type:
raise ValueError(
"a and b must have the same element type, but got:"
f" {a_type.element_type} and {b_type.element_type}"
)

m, _ = a_type.shape
_, n = b_type.shape
assert acc_dtype is not None
acc = _zeros(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype)))

Expand Down
78 changes: 76 additions & 2 deletions tests/pallas/triton_pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,19 @@ def setUp(self):
if not self.INTERPRET:
self.skipTest("On CPU the test works only in interpret mode")
elif jtu.test_device_matches(["gpu"]):
is_sm80_test = any(
getattr(self, "_testMethodName", "").startswith(prefix)
for prefix in (
"test_dot_f32_small_dimensions",
"test_dot_fp64_valid_dimensions",
"test_dot_fp64_invalid_dimensions",
)
)
min_compute = "8.0" if is_sm80_test else "9.0"
Comment on lines +50 to +58
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This method of checking test names in setUp to determine compute capability requirements is a bit fragile and can make test maintenance harder. If test methods are renamed, this logic will break silently.

A more robust approach would be to move the capability check into the test methods themselves, or use a decorator. For example, you could define a helper method:

def _require_compute_capability(self, min_version_str):
  """Skip the current test on CUDA devices below ``min_version_str``.

  Does nothing when the test device is not CUDA, or when the CUDA compute
  capability is at least ``min_version_str`` (e.g. ``"8.0"``).
  """
  # Only query the compute capability once we know the device is CUDA.
  if not jtu.test_device_matches(['cuda']):
    return
  if jtu.is_cuda_compute_capability_at_least(min_version_str):
    return
  self.skipTest(f"Requires CUDA compute capability >= {min_version_str}")

And then call it at the beginning of each relevant test:

def test_dot_f32_small_dimensions(self):
  self._require_compute_capability("8.0")
  # ...

This would make the requirements for each test more explicit and avoid coupling the test logic to test names.

Copy link
Copy Markdown
Contributor Author

@mwichro mwichro Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gemini-code-assist I agree that this would be a better approach, but doing it properly would require restructuring all the tests in that file in the same way. That change would go well beyond the scope of this PR, so I would rather do it in a separate PR.


if (jtu.test_device_matches(["cuda"]) and
not jtu.is_cuda_compute_capability_at_least("9.0")):
self.skipTest("Only works on GPU with capability >= sm90")
not jtu.is_cuda_compute_capability_at_least(min_compute)):
self.skipTest(f"Only works on GPU with capability >= sm{min_compute.replace('.', '')}")
if plgpu is None:
self.skipTest("plgpu not available on this platform")
else:
Expand Down Expand Up @@ -485,6 +495,70 @@ def dot_kernel(x_ref, y_ref, o_ref):
"Unsigned integer dtype.*not supported"):
dot_kernel(x, y)

def test_dot_f32_small_dimensions(self):
  """Runs a float32 ``pl.dot`` on an (8, 16) x (16, 8) block with one warp.

  Both operands are all ones, so every output element must equal ``k``.
  """
  dtype = jnp.float32
  m, k, n = 8, 16, 8

  def dot_kernel(x_ref, y_ref, o_ref):
    # Load both full blocks, multiply them, and store the full result.
    o_ref[()] = pl.dot(x_ref[()], y_ref[()])

  run_dot = self.pallas_call(
      dot_kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), dtype),
      compiler_params=plgpu.CompilerParams(num_warps=1),
  )

  lhs = jnp.ones((m, k), dtype=dtype)
  rhs = jnp.ones((k, n), dtype=dtype)
  actual = run_dot(lhs, rhs)
  np.testing.assert_allclose(actual, jnp.full((m, n), k, dtype=dtype))

def test_dot_fp64_valid_dimensions(self):
  """Runs a float64 ``pl.dot`` with M=16, K=16, N=8 and checks the result.

  The kernel output is compared against ``jnp.dot`` at
  ``lax.Precision.HIGHEST``. The test is skipped when x64 is disabled.
  """
  if not jax.config.jax_enable_x64:
    self.skipTest("x64 is disabled")

  dtype = jnp.float64
  m, k, n = 16, 16, 8

  def dot_kernel(x_ref, y_ref, o_ref):
    # Load both full blocks, multiply them, and store the full result.
    o_ref[()] = pl.dot(x_ref[()], y_ref[()])

  run_dot = self.pallas_call(
      dot_kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), dtype),
      compiler_params=plgpu.CompilerParams(num_warps=1),
  )

  lhs = jnp.arange(m * k).reshape(m, k).astype(dtype)
  rhs = jnp.arange(k * n).reshape(k, n).astype(dtype)

  actual = run_dot(lhs, rhs)
  reference = jnp.dot(lhs, rhs, precision=lax.Precision.HIGHEST)
  np.testing.assert_allclose(actual, reference, atol=1e-5, rtol=1e-5)
Comment thread
mwichro marked this conversation as resolved.

def test_dot_fp64_invalid_dimensions(self):
  """Checks that too-small float64 ``pl.dot`` blocks raise ``ValueError``.

  Each case lowers one of M, N, or K below the bound named in its expected
  message and runs as its own subtest. The test is skipped when x64 is
  disabled.
  """
  if not jax.config.jax_enable_x64:
    self.skipTest("x64 is disabled")

  dtype = jnp.float64
  # (m, k, n, regex expected in the ValueError message)
  bad_shapes = (
      (8, 16, 16, "M=8 < 16"),
      (16, 16, 4, "N=4 < 8"),
      (16, 8, 16, "K=8 < 16"),
  )

  for m, k, n, err_msg in bad_shapes:
    with self.subTest(f"m={m},k={k},n={n}"):

      def dot_kernel(x_ref, y_ref, o_ref):
        o_ref[()] = pl.dot(x_ref[()], y_ref[()])

      run_dot = self.pallas_call(
          dot_kernel,
          out_shape=jax.ShapeDtypeStruct((m, n), dtype),
          compiler_params=plgpu.CompilerParams(num_warps=1),
      )

      lhs = jnp.arange(m * k).reshape(m, k).astype(dtype)
      rhs = jnp.arange(k * n).reshape(k, n).astype(dtype)

      with self.assertRaisesRegex(ValueError, err_msg):
        run_dot(lhs, rhs)


@functools.partial(
jax.jit, static_argnames=["bm", "bn", "gm", "bk", "interpret", "debug"]
Expand Down
Loading