Skip to content

Commit

Permalink
Cleanup: fix unused imports & mark exported names
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Oct 17, 2024
1 parent 6a00055 commit de3191f
Show file tree
Hide file tree
Showing 56 changed files with 314 additions and 364 deletions.
6 changes: 3 additions & 3 deletions jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@

from jax._src.api import effects_barrier as effects_barrier
from jax._src.api import block_until_ready as block_until_ready
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
from jax._src.api import clear_backends as _deprecated_clear_backends
from jax._src.api import clear_caches as clear_caches
Expand Down Expand Up @@ -122,7 +122,7 @@
from jax._src.xla_bridge import process_index as process_index
from jax._src.xla_bridge import process_indices as process_indices
from jax._src.callback import pure_callback as pure_callback
from jax._src.ad_checkpoint import checkpoint_wrapper as remat
from jax._src.ad_checkpoint import checkpoint_wrapper as remat # noqa: F401
from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
from jax._src.api import value_and_grad as value_and_grad
from jax._src.api import vjp as vjp
Expand Down Expand Up @@ -249,6 +249,6 @@
del _deprecation_getattr
del _typing

import jax.lib # TODO(phawkins): remove this export.
import jax.lib # TODO(phawkins): remove this export. # noqa: F401

# trailer
14 changes: 7 additions & 7 deletions jax/_src/clusters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .cluster import ClusterEnv
from .cluster import ClusterEnv as ClusterEnv

# Order of declaration of the cluster environments
# will dictate the order in which they will be checked.
# Therefore, if multiple environments are available and
# the user did not explicitly provide the arguments
# to :func:`jax.distributed.initialize`, the first
# available one from the list will be picked.
from .ompi_cluster import OmpiCluster
from .slurm_cluster import SlurmCluster
from .mpi4py_cluster import Mpi4pyCluster
from .cloud_tpu_cluster import GkeTpuCluster
from .cloud_tpu_cluster import GceTpuCluster
from .k8s_cluster import K8sCluster
from .ompi_cluster import OmpiCluster as OmpiCluster
from .slurm_cluster import SlurmCluster as SlurmCluster
from .mpi4py_cluster import Mpi4pyCluster as Mpi4pyCluster
from .cloud_tpu_cluster import GkeTpuCluster as GkeTpuCluster
from .cloud_tpu_cluster import GceTpuCluster as GceTpuCluster
from .k8s_cluster import K8sCluster as K8sCluster
2 changes: 1 addition & 1 deletion jax/_src/cudnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .fusion import cudnn_fusion
from .fusion import cudnn_fusion as cudnn_fusion
2 changes: 1 addition & 1 deletion jax/_src/debugger/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.debugger.core import breakpoint
from jax._src.debugger.core import breakpoint as breakpoint
from jax._src.debugger import cli_debugger
from jax._src.debugger import colab_debugger
from jax._src.debugger import web_debugger
Expand Down
61 changes: 42 additions & 19 deletions jax/_src/lax/control_flow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,48 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the control flow primitives."""
from jax._src.lax.control_flow.loops import (associative_scan, cummax, cummax_p,
cummin, cummin_p, cumlogsumexp,
cumlogsumexp_p, cumprod,
cumprod_p, cumsum, cumsum_p,
cumred_reduce_window_impl,
fori_loop, map,
scan, scan_bind, scan_p,
_scan_impl, while_loop, while_p)
from jax._src.lax.control_flow.conditionals import (cond, cond_p, switch,
platform_dependent)
from jax._src.lax.control_flow.solves import (custom_linear_solve, custom_root,
_custom_linear_solve_impl,
linear_solve_p)

from jax._src.lax.control_flow.loops import (
associative_scan as associative_scan,
cummax as cummax,
cummax_p as cummax_p,
cummin as cummin,
cummin_p as cummin_p,
cumlogsumexp as cumlogsumexp,
cumlogsumexp_p as cumlogsumexp_p,
cumprod as cumprod,
cumprod_p as cumprod_p,
cumsum as cumsum,
cumsum_p as cumsum_p,
cumred_reduce_window_impl as cumred_reduce_window_impl,
fori_loop as fori_loop,
map as map,
scan as scan,
scan_bind as scan_bind,
scan_p as scan_p,
_scan_impl as _scan_impl,
while_loop as while_loop,
while_p as while_p,
)
from jax._src.lax.control_flow.conditionals import (
cond as cond,
cond_p as cond_p,
switch as switch,
platform_dependent as platform_dependent,
)
from jax._src.lax.control_flow.solves import (
custom_linear_solve as custom_linear_solve,
custom_root as custom_root,
_custom_linear_solve_impl as _custom_linear_solve_impl,
linear_solve_p as linear_solve_p,
)
# Private utilities used elsewhere in JAX
# TODO(sharadmv): lift them into a more common place
from jax._src.lax.control_flow.common import (_initial_style_open_jaxpr,
_initial_style_jaxpr,
_initial_style_jaxprs_with_common_consts,
_check_tree_and_avals)
from jax._src.lax.control_flow.common import (
_initial_style_open_jaxpr as _initial_style_open_jaxpr,
_initial_style_jaxpr as _initial_style_jaxpr,
_initial_style_jaxprs_with_common_consts as _initial_style_jaxprs_with_common_consts,
_check_tree_and_avals as _check_tree_and_avals,

)
# TODO(mattjj): fix dependent library which expects optimization_barrier_p here
from jax._src.lax.lax import optimization_barrier_p
from jax._src.lax.lax import optimization_barrier_p as optimization_barrier_p
25 changes: 12 additions & 13 deletions jax/_src/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import os
import pathlib
import re
from typing import Any

try:
import jaxlib as jaxlib
Expand Down Expand Up @@ -84,9 +83,9 @@ def _parse_version(v: str) -> tuple[int, ...]:
import jaxlib.cpu_feature_guard as cpu_feature_guard
cpu_feature_guard.check_cpu_features()

import jaxlib.utils as utils
import jaxlib.utils as utils # noqa: F401
import jaxlib.xla_client as xla_client
import jaxlib.lapack as lapack
import jaxlib.lapack as lapack # noqa: F401

xla_extension = xla_client._xla
pytree = xla_client._xla.pytree
Expand All @@ -99,29 +98,29 @@ def _xla_gc_callback(*args):
gc.callbacks.append(_xla_gc_callback)

try:
import jaxlib.cuda._versions as cuda_versions # pytype: disable=import-error
import jaxlib.cuda._versions as cuda_versions # pytype: disable=import-error # noqa: F401
except ImportError:
try:
import jax_cuda12_plugin._versions as cuda_versions # pytype: disable=import-error
import jax_cuda12_plugin._versions as cuda_versions # pytype: disable=import-error # noqa: F401
except ImportError:
cuda_versions = None

import jaxlib.gpu_solver as gpu_solver # pytype: disable=import-error
import jaxlib.gpu_sparse as gpu_sparse # pytype: disable=import-error
import jaxlib.gpu_prng as gpu_prng # pytype: disable=import-error
import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error
import jaxlib.hlo_helpers as hlo_helpers # pytype: disable=import-error
import jaxlib.gpu_solver as gpu_solver # pytype: disable=import-error # noqa: F401
import jaxlib.gpu_sparse as gpu_sparse # pytype: disable=import-error # noqa: F401
import jaxlib.gpu_prng as gpu_prng # pytype: disable=import-error # noqa: F401
import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error # noqa: F401
import jaxlib.hlo_helpers as hlo_helpers # pytype: disable=import-error # noqa: F401

# Jaxlib code is split between the Jax and the Tensorflow repositories.
# Only for the internal usage of the JAX developers, we expose a version
# number that can be used to perform changes without breaking the main
# branch on the Jax github.
xla_extension_version: int = getattr(xla_client, '_version', 0)

import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error
import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error
import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error # noqa: F401
import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error # noqa: F401

import jaxlib.mosaic.python.tpu as tpu # pytype: disable=import-error
import jaxlib.mosaic.python.tpu as tpu # pytype: disable=import-error # noqa: F401

# Version number for MLIR:Python APIs, provided by jaxlib.
mlir_api_version = xla_client.mlir_api_version
Expand Down
20 changes: 10 additions & 10 deletions jax/_src/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.
"""Module for state."""
from jax._src.state.types import (
AbstractRef,
AccumEffect,
ReadEffect,
RefEffect,
StateEffect,
Transform,
TransformedRef,
WriteEffect,
get_ref_state_effects,
shaped_array_ref,
AbstractRef as AbstractRef,
AccumEffect as AccumEffect,
ReadEffect as ReadEffect,
RefEffect as RefEffect,
StateEffect as StateEffect,
Transform as Transform,
TransformedRef as TransformedRef,
WriteEffect as WriteEffect,
get_ref_state_effects as get_ref_state_effects,
shaped_array_ref as shaped_array_ref,
)
16 changes: 8 additions & 8 deletions jax/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.

from jax._src.ad_checkpoint import (
checkpoint,
checkpoint_policies,
checkpoint_name,
print_saved_residuals,
remat,
checkpoint as checkpoint,
checkpoint_policies as checkpoint_policies,
checkpoint_name as checkpoint_name,
print_saved_residuals as print_saved_residuals,
remat as remat,
)
from jax._src.interpreters.partial_eval import (
Recompute,
Saveable,
Offloadable,
Recompute as Recompute,
Saveable as Saveable,
Offloadable as Offloadable,
)
16 changes: 8 additions & 8 deletions jax/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@


from jax._src.api_util import (
argnums_partial,
donation_vector,
flatten_axes,
flatten_fun,
flatten_fun_nokwargs,
rebase_donate_argnums,
safe_map,
shaped_abstractify,
argnums_partial as argnums_partial,
donation_vector as donation_vector,
flatten_axes as flatten_axes,
flatten_fun as flatten_fun,
flatten_fun_nokwargs as flatten_fun_nokwargs,
rebase_donate_argnums as rebase_donate_argnums,
safe_map as safe_map,
shaped_abstractify as shaped_abstractify,
)
2 changes: 1 addition & 1 deletion jax/cloud_tpu_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from jax._src.cloud_tpu_init import cloud_tpu_init
from jax._src.cloud_tpu_init import cloud_tpu_init as cloud_tpu_init
8 changes: 4 additions & 4 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
Literal as Literal,
MainTrace as MainTrace,
MapPrimitive as MapPrimitive,
nonempty_axis_env as nonempty_axis_env_DO_NOT_USE,
nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, # noqa: F401
OpaqueTraceState as OpaqueTraceState,
NameGatheringSubst as NameGatheringSubst,
OutDBIdx as OutDBIdx,
Expand All @@ -58,9 +58,9 @@
TraceStack as TraceStack,
TraceState as TraceState,
Tracer as Tracer,
unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE,
unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE,
unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE,
unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, # noqa: F401
unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE, # noqa: F401
unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE, # noqa: F401
UnshapedArray as UnshapedArray,
Value as Value,
Var as Var,
Expand Down
6 changes: 3 additions & 3 deletions jax/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570

from jax._src.custom_derivatives import (
_initial_style_jaxpr,
_sum_tangents,
_zeros_like_pytree,
_initial_style_jaxpr, # noqa: F401
_sum_tangents, # noqa: F401
_zeros_like_pytree, # noqa: F401
closure_convert as closure_convert,
custom_gradient as custom_gradient,
custom_jvp as custom_jvp,
Expand Down
6 changes: 3 additions & 3 deletions jax/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
from jax._src.dtypes import (
bfloat16 as bfloat16,
canonicalize_dtype as canonicalize_dtype,
finfo, # TODO(phawkins): switch callers to jnp.finfo?
finfo, # TODO(phawkins): switch callers to jnp.finfo? # noqa: F401
float0 as float0,
iinfo, # TODO(phawkins): switch callers to jnp.iinfo?
issubdtype, # TODO(phawkins): switch callers to jnp.issubdtype?
iinfo, # TODO(phawkins): switch callers to jnp.iinfo? # noqa: F401
issubdtype, # TODO(phawkins): switch callers to jnp.issubdtype? # noqa: F401
extended as extended,
prng_key as prng_key,
result_type as result_type,
Expand Down
4 changes: 1 addition & 3 deletions jax/experimental/array_serialization/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,9 @@

import jax
from jax._src import array
from jax._src import config
from jax._src import distributed
from jax._src import sharding
from jax._src import sharding_impls
from jax._src.layout import Layout, DeviceLocalLayout as DLL
from jax._src.layout import Layout
from jax._src import typing
from jax._src import util
from jax._src.lib import xla_extension as xe
Expand Down
1 change: 0 additions & 1 deletion jax/experimental/array_serialization/serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src import array
from jax._src import xla_bridge as xb
from jax.sharding import NamedSharding, GSPMDSharding, SingleDeviceSharding
from jax.sharding import PartitionSpec as P
from jax.experimental.array_serialization import serialization
Expand Down
1 change: 0 additions & 1 deletion jax/experimental/jax2tf/examples/mnist_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import warnings
from absl import flags

import flax
from flax import linen as nn

import jax
Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/jax2tf/tests/cross_compilation_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
import numpy.random as npr

import jax # Must import before TF
from jax.experimental import jax2tf # Defines needed flags
from jax._src import test_util # Defines needed flags
from jax.experimental import jax2tf # Defines needed flags # noqa: F401
from jax._src import test_util # Defines needed flags # noqa: F401

jax.config.parse_flags_with_absl()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

from absl.testing import absltest

import jax
from jax._src import config
from jax._src import test_util as jtu

Expand Down
1 change: 0 additions & 1 deletion jax/experimental/jax2tf/tests/primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
from jax._src import test_util as jtu
from jax.experimental import jax2tf
from jax.interpreters import mlir
from jax._src.interpreters import xla

import numpy as np
import tensorflow as tf
Expand Down
1 change: 0 additions & 1 deletion jax/experimental/jax2tf/tests/shape_poly_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from jax._src.export import shape_poly
from jax._src.lax import lax as lax_internal
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lib import xla_client
import numpy as np

from jax.experimental.jax2tf.tests import tf_test_util
Expand Down
2 changes: 0 additions & 2 deletions jax/experimental/jax2tf/tests/sharding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import contextlib
from functools import partial
import logging
import math
import os
import re
from typing import Any
import unittest
Expand Down
Loading

0 comments on commit de3191f

Please sign in to comment.