@@ -1568,21 +1568,21 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
             'Please see the jax.Array migration guide for more information '
             'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. '
             f'Got arg shape: {arg.shape}, arg value: {arg}')
-      if not is_unspecified(arg_s):
+      if not isinstance(arg_s, UnspecifiedValue):
         # jax.jit does not allow resharding across different memory kinds even
         # if the argument is uncommitted. Use jax.device_put for those cases,
         # either outside or inside jax.jit.
         if pjit_in_s.memory_kind != arg_s.memory_kind:  # type: ignore
           raise ValueError(
               'Memory kinds passed to jax.jit does not match memory kind on the'
               f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, '  # type: ignore
-              f'arg memory kind: {arg_s.memory_kind} for '  # pytype: disable=attribute-error
+              f'arg memory kind: {arg_s.memory_kind} for '
               f'arg shape: {shaped_abstractify(arg).str_short()}')
         if (committed and
             not isinstance(arg_s, PmapSharding) and
             not op_shardings.are_op_shardings_equal(
                 pjit_in_s._to_xla_hlo_sharding(arg.ndim),  # type: ignore
-                arg_s._to_xla_hlo_sharding(arg.ndim))):  # type: ignore
+                arg_s._to_xla_hlo_sharding(arg.ndim))):
           raise ValueError('Sharding passed to pjit does not match the sharding '
                            'on the respective arg. '
                            f'Got pjit sharding: {pjit_in_s},\n'
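The pattern behind this hunk: an opaque helper like is_unspecified(arg_s) tells a static type checker nothing about arg_s, while a direct isinstance(arg_s, UnspecifiedValue) check narrows its type, which is presumably why the trailing '# pytype: disable=attribute-error' and '# type: ignore' comments on the subsequent arg_s attribute accesses can be dropped. A minimal sketch of that narrowing, using hypothetical stand-in classes rather than JAX's real ones:

# Hypothetical stand-ins, not JAX's actual classes.
class UnspecifiedValue:
  """Sentinel playing the role of JAX's UnspecifiedValue."""

class Sharding:
  """Concrete sharding stand-in with the attribute accessed after narrowing."""
  memory_kind: str = 'device'

def describe(arg_s: Sharding | UnspecifiedValue) -> str:
  if not isinstance(arg_s, UnspecifiedValue):
    # After the isinstance check, pytype/mypy narrow arg_s to Sharding,
    # so this attribute access needs no suppression comment.
    return arg_s.memory_kind
  return 'unspecified'

print(describe(Sharding()))          # -> device
print(describe(UnspecifiedValue()))  # -> unspecified

A boolean helper can achieve the same narrowing only if it is declared as a TypeGuard; a plain 'def is_unspecified(x) -> bool' cannot, so the inline isinstance check is the simpler fix.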