Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make _dot_general_batch_rule handle python builtin numeric types #16826

Merged
merged 1 commit into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2608,7 +2608,7 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
left_stack_dim = lbd.stacked_axis if type(lbd) is RaggedAxis else lbd
right_stack_dim = rbd.stacked_axis if type(rbd) is RaggedAxis else rbd
new_dimension_numbers, result_stack_dim = _dot_general_batch_dim_nums(
(lhs.ndim, rhs.ndim), (left_stack_dim, right_stack_dim),
(np.ndim(lhs), np.ndim(rhs)), (left_stack_dim, right_stack_dim),
dimension_numbers)
# TODO Should probably check that any ragged dimensions have corresponding
# sizes, because otherwise the dot product is technically undefined.
Expand All @@ -2619,12 +2619,12 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
lhs = batching.mask_ragged_axes(lhs, _get_sum_identity, lbd)
lhs_shape = batching.bdim_as_shape(lbd, lhs.shape)
else:
lhs_shape = lhs.shape
lhs_shape = np.shape(lhs)
if type(rbd) is RaggedAxis:
rhs = batching.mask_ragged_axes(rhs, _get_sum_identity, rbd)
rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
else:
rhs_shape = rhs.shape
rhs_shape = np.shape(rhs)
batched_out = dot_general(lhs, rhs, new_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type)
Expand Down
8 changes: 8 additions & 0 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2664,6 +2664,14 @@ def testDynamicSliceU8Index(self):
np.testing.assert_equal(
np.array(lax.dynamic_slice(x, np.uint8([128]), (1,))), [128])

def test_dot_general_batching_python_builtin_arg(self):
# https://github.com/google/jax/issues/16805
@jax.remat
def f(x):
return jax.lax.dot_general(x, x, (([], []), ([], [])))

jax.hessian(f)(1.0) # don't crash


class LazyConstantTest(jtu.JaxTestCase):
def _Check(self, make_const, expected):
Expand Down