diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 72c8e181ef17..9325cbfc1327 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -2608,7 +2608,7 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
   left_stack_dim = lbd.stacked_axis if type(lbd) is RaggedAxis else lbd
   right_stack_dim = rbd.stacked_axis if type(rbd) is RaggedAxis else rbd
   new_dimension_numbers, result_stack_dim = _dot_general_batch_dim_nums(
-      (lhs.ndim, rhs.ndim), (left_stack_dim, right_stack_dim),
+      (np.ndim(lhs), np.ndim(rhs)), (left_stack_dim, right_stack_dim),
       dimension_numbers)
   # TODO Should probably check that any ragged dimensions have corresponding
   # sizes, because otherwise the dot product is technically undefined.
@@ -2619,12 +2619,12 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
     lhs = batching.mask_ragged_axes(lhs, _get_sum_identity, lbd)
     lhs_shape = batching.bdim_as_shape(lbd, lhs.shape)
   else:
-    lhs_shape = lhs.shape
+    lhs_shape = np.shape(lhs)
   if type(rbd) is RaggedAxis:
     rhs = batching.mask_ragged_axes(rhs, _get_sum_identity, rbd)
     rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
   else:
-    rhs_shape = rhs.shape
+    rhs_shape = np.shape(rhs)
   batched_out = dot_general(lhs, rhs, new_dimension_numbers,
                             precision=precision,
                             preferred_element_type=preferred_element_type)
diff --git a/tests/lax_test.py b/tests/lax_test.py
index 54212693ad5e..661b07c290e1 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -2664,6 +2664,14 @@ def testDynamicSliceU8Index(self):
     np.testing.assert_equal(
         np.array(lax.dynamic_slice(x, np.uint8([128]), (1,))), [128])
 
+  def test_dot_general_batching_python_builtin_arg(self):
+    # https://github.com/google/jax/issues/16805
+    @jax.remat
+    def f(x):
+      return jax.lax.dot_general(x, x, (([], []), ([], [])))
+
+    jax.hessian(f)(1.0)  # don't crash
+
 class LazyConstantTest(jtu.JaxTestCase):
 
   def _Check(self, make_const, expected):
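
Note (not part of the patch): a minimal sketch of the failure mode this fixes,
assuming only NumPy itself. Under composed transforms such as jax.hessian of a
jax.remat-wrapped function, an operand can reach the batch rule as a plain
Python scalar, which has no .ndim or .shape attribute; the free functions
np.ndim and np.shape accept anything array-like, including builtin scalars:

    import numpy as np

    x = 1.0              # a Python builtin float, not an ndarray
    print(np.ndim(x))    # 0  -- np.ndim handles plain scalars
    print(np.shape(x))   # () -- np.shape handles plain scalars
    # x.ndim             # AttributeError: 'float' object has no attribute 'ndim'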