Skip to content

Commit

Permalink
Reverts 093b92b
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 655114622
  • Loading branch information
gnecula authored and jax authors committed Jul 23, 2024
1 parent f0792b2 commit 459b83c
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 16 deletions.
2 changes: 1 addition & 1 deletion jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1541,7 +1541,7 @@ def concrete_or_error(force: Any, val: Any, context=""):

def concrete_dim_or_error(val: Any, context=""):
"""Like concrete_or_error(operator.index), allowing symbolic dimensions."""
if is_dim(val):
if is_symbolic_dim(val):
return val
else:
return concrete_or_error(operator.index, val, context=context)
Expand Down
6 changes: 4 additions & 2 deletions jax/_src/export/shape_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,8 +886,10 @@ def _divmod(self, divisor: DimSize) -> tuple[DimSize, int]:
if config.enable_checks.value:
v1 = divisor * quotient
v2 = v1 + remainder
assert self == v2, (self, v2, type(self), type(v2))
assert self == divisor * quotient + remainder, (self, divisor, quotient, remainder)
assert self == _ensure_poly(v2, "check", self.scope), (
self, v2, type(self), type(v2))
assert self == _ensure_poly(divisor * quotient + remainder, "test", self.scope), (
self, divisor, quotient, remainder)
return quotient, remainder
except InconclusiveDimensionOperation:
return (_DimExpr._from_operation(_DimFactor.FLOORDIV, self, divisor,
Expand Down
44 changes: 31 additions & 13 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2765,13 +2765,14 @@ def _pad_wrap(array: Array, pad_width: PadValue[int]) -> Array:
_check_no_padding(pad_width[i], "wrap")
continue
size = array.shape[i]
repeats, (left_remainder, right_remainder) = np.divmod(pad_width[i], size)
total_repeats = repeats.sum() + 1
left_repeats, left_remainder = divmod(pad_width[i][0], size)
right_repeats, right_remainder = divmod(pad_width[i][1], size)
total_repeats = left_repeats + right_repeats + 1
parts = []
if left_remainder:
if left_remainder > 0:
parts += [lax.slice_in_dim(array, size - left_remainder, size, axis=i)]
parts += total_repeats * [array]
if right_remainder:
if right_remainder > 0:
parts += [lax.slice_in_dim(array, 0, right_remainder, axis=i)]
array = lax.concatenate(parts, dimension=i)
return array
Expand All @@ -2787,32 +2788,49 @@ def _pad_symmetric_or_reflect(array: Array, pad_width: PadValue[int],
_check_no_padding(pad_width[i], mode)
continue

n = array.shape[i]
offset = 1 if (mode == "reflect" and n > 1) else 0
axis_size = array.shape[i]

def build_padding(array, padding, before):
if before:
edge = lax.slice_in_dim(array, 0, 1, axis=i)
else:
edge = lax.slice_in_dim(array, -1, None, axis=i)

# Try to give nicer error messages for unsupported shape polymorphic uses
shape_poly_error_msg = lambda: (
"Shape polymorphism is supported for jnp.pad with 'reflect' or "
"'symmetric' padding mode only when it is possible to determine "
f"at lowering time that the axis size (= {axis_size}) is larger than 1 "
f"and larger or equal than the padding length (= {padding}). "
f"Error while handling {'left' if before else 'right'} padding on axis {i}.")
try:
      # We check that we can determine all comparisons.
offset = 1 if (mode == "reflect" and axis_size > 1) else 0
has_poly_dim = not core.is_constant_shape((axis_size, padding))
# For shape polymorphism, ensure the loop below ends after 1 iteration
if has_poly_dim and not (axis_size > 1 and axis_size - offset >= padding):
raise ValueError(shape_poly_error_msg())
except core.InconclusiveDimensionOperation as e:
raise ValueError(shape_poly_error_msg()) from e

while padding > 0:
curr_pad = min(padding, n - offset)
curr_pad = min(padding, axis_size - offset)
padding -= curr_pad
if has_poly_dim: assert padding == 0

if before:
start = offset
stop = offset + curr_pad
else:
start = -(curr_pad + offset)
stop = None if (mode == "symmetric" or n == 1) else -1
stop = None if (mode == "symmetric" or axis_size == 1) else -1

x = lax.slice_in_dim(array, start, stop, axis=i)
x = flip(x, axis=i)

if reflect_type == 'odd':
x = 2 * edge - x
if n > 1:
if axis_size > 1:
if before:
edge = lax.slice_in_dim(x, 0, 1, axis=i)
else:
Expand Down Expand Up @@ -4308,7 +4326,7 @@ def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: bool = False,
dtype: DTypeLike | None = None,
axis: int = 0) -> Array | tuple[Array, Array]:
num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.linspace")
num = core.concrete_dim_or_error(num, "'num' argument of jnp.linspace")
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace")
return _linspace(start, stop, num, endpoint, retstep, dtype, axis)

Expand Down Expand Up @@ -4337,13 +4355,13 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
bounds_shape.insert(axis, 1)
div = (num - 1) if endpoint else num
if num > 1:
delta: Array = lax.convert_element_type(stop - start, computation_dtype) / div
delta: Array = lax.convert_element_type(stop - start, computation_dtype) / array(div, dtype=computation_dtype)
iota_shape = [1,] * len(bounds_shape)
iota_shape[axis] = div
# This approach recovers the endpoints with float32 arithmetic,
# but can lead to rounding errors for integer outputs.
real_dtype = finfo(computation_dtype).dtype
step = reshape(lax.iota(real_dtype, div), iota_shape) / div
step = reshape(lax.iota(real_dtype, div), iota_shape) / array(div, real_dtype)
step = step.astype(computation_dtype)
out = (reshape(broadcast_start, bounds_shape) * (1 - step) +
reshape(broadcast_stop, bounds_shape) * step)
Expand All @@ -4355,7 +4373,7 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
elif num == 1:
delta = asarray(nan if endpoint else stop - start, dtype=computation_dtype)
out = reshape(broadcast_start, bounds_shape)
else: # num == 0 degenerate case, match numpy behavior
else: # num == 0 degenerate case, match numpy behavior
empty_shape = list(lax.broadcast_shapes(shape(start), shape(stop)))
empty_shape.insert(axis, 0)
delta = asarray(nan, dtype=computation_dtype)
Expand Down
9 changes: 9 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5239,6 +5239,15 @@ def testLinspaceEndpoints(self, dtype):
out = jnp.linspace(*endpoints, 10, dtype=dtype)
self.assertAllClose(out[np.array([0, -1])], endpoints, rtol=0, atol=0)

def testLinspaceArrayNum(self):
"""Regression test for Issue #22405."""
rng = jtu.rand_default(self.rng())
endpoints = rng((2,), np.float32)
# The num parameter is an np.array.
out = jnp.linspace(*endpoints, np.array(10, dtype=np.int32),
dtype=np.float32)
self.assertAllClose(out[np.array([0, -1])], endpoints, rtol=0, atol=0)

@jtu.sample_product(
start_shape=[(), (2,), (2, 2)],
stop_shape=[(), (2,), (2, 2)],
Expand Down
62 changes: 62 additions & 0 deletions tests/shape_poly_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,7 @@ def test_int_results(self):
(3 * a * a * b + 2 * b * b * a, a * b, 3 * a + 2 * b, 0),
(a * a - b * b, a + b, a - b, 0),
(256 * a * b, 32, 8 * a * b, 0),
(0, b, 0, 0),
(a, b, "floordiv(a, b)", "mod(a, b)"),
(3 * a, 2, "floordiv(3*a, 2)", "mod(3*a, 2)"),
(2 * a * b + b * b, a + b, "floordiv(2*a*b + b^2, b + a)", "mod(2*a*b + b^2, b + a)"),
Expand Down Expand Up @@ -2532,6 +2533,15 @@ def test_vmap_error(self):
lambda x: x + lax.iota(_f32, x.shape[0]),
arg_descriptors=[RandArg((3,), _f32)],
polymorphic_shapes=["b, ..."]),
PolyHarness("linspace", "",
lambda x: jnp.linspace(0, x.shape[0], 4),
arg_descriptors=[RandArg((30,), _f32)],
polymorphic_shapes=["b, ..."]),
PolyHarness("linspace", "num_poly",
lambda x: jnp.linspace(0, 100, x.shape[0]),
arg_descriptors=[RandArg((30,), _f32)],
polymorphic_shapes=["b, ..."],
symbolic_constraints=["b >= 2"]),
PolyHarness("matmul", "0",
jnp.matmul,
arg_descriptors=[RandArg((7, 8, 4), _f32), RandArg((7, 4, 5), _f32)],
Expand Down Expand Up @@ -2613,6 +2623,58 @@ def test_vmap_error(self):
mode="edge"),
arg_descriptors=[RandArg((3, 5), _f32)],
polymorphic_shapes=["b, ..."]),
PolyHarness("jnp.pad", "mode=maximum",
lambda x: jnp.pad(x, [[x.shape[0], 0], [x.shape[1], 1]],
mode="maximum"),
arg_descriptors=[RandArg((3, 5), _f32)],
polymorphic_shapes=["b, ..."]),
PolyHarness("jnp.pad", "mode=maximum_stat_length=b",
lambda x: jnp.pad(x, [[x.shape[0], 0], [x.shape[1], 1]],
mode="maximum", stat_length=((x.shape[0] // 2, 2), (2, 2))),
arg_descriptors=[RandArg((3, 5), _f32)],
polymorphic_shapes=["b, ..."],
symbolic_constraints=["b >= 2"]),
PolyHarness("jnp.pad", "mode=linear_ramp",
lambda x: jnp.pad(x, [[x.shape[0], 0], [x.shape[1], 1]],
mode="linear_ramp"),
arg_descriptors=[RandArg((3, 5), _f32)],
polymorphic_shapes=["b, ..."],
symbolic_constraints=["b >= 2"]),
PolyHarness("jnp.pad", "mode=reflect_odd",
lambda x: jnp.pad(x, [[x.shape[0] - 1, 0], [x.shape[1], 1]],
mode="reflect", reflect_type="odd"),
arg_descriptors=[RandArg((3, 5), _f32)],
polymorphic_shapes=["b, ..."],
symbolic_constraints=["b >= 2"]),
PolyHarness("jnp.pad", "mode=reflect_odd_error",
lambda x: jnp.pad(x, [[x.shape[0] - 1, 0], [x.shape[1], 1]],
mode="reflect", reflect_type="odd"),
arg_descriptors=[RandArg((3, 5), _f32)],
polymorphic_shapes=["b, ..."],
expect_error=(ValueError, "Shape polymorphism is supported for jnp.pad")),
PolyHarness("jnp.pad", "mode=reflect_even",
lambda x: jnp.pad(x, [[x.shape[0] - 1, 0], [x.shape[1], 1]],
mode="reflect", reflect_type="even"),
arg_descriptors=[RandArg((3, 5), _f32)],
polymorphic_shapes=["b, ..."],
symbolic_constraints=["b >= 2"]),
PolyHarness("jnp.pad", "mode=symmetric_odd",
lambda x: jnp.pad(x, [[x.shape[0], 0], [x.shape[1], 1]],
mode="symmetric", reflect_type="odd"),
arg_descriptors=[RandArg((3, 5), _f32)],
polymorphic_shapes=["b, ..."],
symbolic_constraints=["b >= 2"]),
PolyHarness("jnp.pad", "mode=symmetric_even",
lambda x: jnp.pad(x, [[x.shape[0], 0], [x.shape[1], 1]],
mode="symmetric", reflect_type="even"),
arg_descriptors=[RandArg((3, 5), _f32)],
polymorphic_shapes=["b, ..."],
symbolic_constraints=["b >= 2"]),
PolyHarness("jnp.pad", "mode=wrap",
lambda x: jnp.pad(x, [[x.shape[0], 0], [x.shape[1], 1]],
mode="wrap"),
arg_descriptors=[RandArg((3, 5), _f32)],
polymorphic_shapes=["b, ..."]),
PolyHarness("percentile", "axis=None",
lambda x: jnp.percentile(x, 50, axis=None),
arg_descriptors=[RandArg((3, 5), _f32)],
Expand Down

0 comments on commit 459b83c

Please sign in to comment.