Skip to content

Commit 1c6b0a9

Browse files
Merge pull request #24465 from jakevdp:fix-mypy
PiperOrigin-RevId: 688632024
2 parents 9a2dd19 + 8498502 commit 1c6b0a9

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

Diff for: jax/_src/pjit.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1568,21 +1568,21 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
15681568
'Please see the jax.Array migration guide for more information '
15691569
'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. '
15701570
f'Got arg shape: {arg.shape}, arg value: {arg}')
1571-
if not is_unspecified(arg_s):
1571+
if not isinstance(arg_s, UnspecifiedValue):
15721572
# jax.jit does not allow resharding across different memory kinds even
15731573
# if the argument is uncommitted. Use jax.device_put for those cases,
15741574
# either outside or inside jax.jit.
15751575
if pjit_in_s.memory_kind != arg_s.memory_kind: # type: ignore
15761576
raise ValueError(
15771577
'Memory kinds passed to jax.jit does not match memory kind on the'
15781578
f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, ' # type: ignore
1579-
f'arg memory kind: {arg_s.memory_kind} for ' # pytype: disable=attribute-error
1579+
f'arg memory kind: {arg_s.memory_kind} for '
15801580
f'arg shape: {shaped_abstractify(arg).str_short()}')
15811581
if (committed and
15821582
not isinstance(arg_s, PmapSharding) and
15831583
not op_shardings.are_op_shardings_equal(
15841584
pjit_in_s._to_xla_hlo_sharding(arg.ndim), # type: ignore
1585-
arg_s._to_xla_hlo_sharding(arg.ndim))): # type: ignore
1585+
arg_s._to_xla_hlo_sharding(arg.ndim))):
15861586
raise ValueError('Sharding passed to pjit does not match the sharding '
15871587
'on the respective arg. '
15881588
f'Got pjit sharding: {pjit_in_s},\n'

0 commit comments

Comments
 (0)