From d225d313ab322046a27856c4ff1716bbe631dbda Mon Sep 17 00:00:00 2001
From: Gagandeep Singh
Date: Fri, 19 Sep 2025 12:29:31 +0530
Subject: [PATCH 1/9] LINT: Enable ruff imports for rllib/algorithms

Signed-off-by: Gagandeep Singh
---
 pyproject.toml | 17 +++++-
 rllib/algorithms/__init__.py | 3 +-
 rllib/algorithms/algorithm.py | 55 +++++++++----------
 rllib/algorithms/algorithm_config.py | 17 +++---
 rllib/algorithms/appo/appo.py | 6 +-
 rllib/algorithms/appo/appo_tf_policy.py | 22 ++++----
 rllib/algorithms/appo/appo_torch_policy.py | 17 +++---
 .../algorithms/appo/default_appo_rl_module.py | 5 +-
 rllib/algorithms/appo/tests/test_appo.py | 2 +-
 .../appo/tests/test_appo_learner.py | 6 +-
 .../appo/torch/appo_torch_learner.py | 4 +-
 rllib/algorithms/appo/utils.py | 3 +-
 rllib/algorithms/bc/__init__.py | 2 +-
 rllib/algorithms/bc/bc_catalog.py | 2 +-
 rllib/algorithms/bc/tests/test_bc.py | 7 ++-
 rllib/algorithms/callbacks.py | 1 -
 rllib/algorithms/cql/cql.py | 14 ++---
 rllib/algorithms/cql/cql_tf_policy.py | 21 +++----
 rllib/algorithms/cql/cql_torch_policy.py | 23 ++++----
 .../cql/tests/test_cql_old_api_stack.py | 5 +-
 .../algorithms/cql/torch/cql_torch_learner.py | 8 +--
 .../cql/torch/default_cql_torch_rl_module.py | 5 +-
 rllib/algorithms/dqn/default_dqn_rl_module.py | 5 +-
 .../dqn/distributional_q_tf_model.py | 1 +
 rllib/algorithms/dqn/dqn.py | 25 +++++----
 rllib/algorithms/dqn/dqn_catalog.py | 6 +-
 rllib/algorithms/dqn/dqn_learner.py | 3 +-
 rllib/algorithms/dqn/dqn_torch_model.py | 2 +
 rllib/algorithms/dqn/dqn_torch_policy.py | 7 ++-
 rllib/algorithms/dqn/tests/test_dqn.py | 3 +-
 .../dqn/torch/default_dqn_torch_rl_module.py | 11 ++--
 .../algorithms/dqn/torch/dqn_torch_learner.py | 11 ++--
 rllib/algorithms/dreamerv3/dreamerv3.py | 5 +-
 .../algorithms/dreamerv3/dreamerv3_catalog.py | 4 +-
 .../algorithms/dreamerv3/dreamerv3_learner.py | 2 +-
 .../dreamerv3/dreamerv3_rl_module.py | 9 ++-
 .../dreamerv3/tests/test_dreamerv3.py | 5 +-
 .../torch/dreamerv3_torch_learner.py | 2 +-
 .../dreamerv3/torch/models/actor_network.py | 1 -
 .../models/components/dynamics_predictor.py | 2 +-
 .../models/components/reward_predictor.py | 3 +-
 .../torch/models/components/sequence_model.py | 2 +-
 .../dreamerv3/torch/models/critic_network.py | 4 +-
 .../dreamerv3/torch/models/world_model.py | 6 +-
 rllib/algorithms/dreamerv3/utils/debugging.py | 3 +-
 rllib/algorithms/impala/__init__.py | 2 +-
 rllib/algorithms/impala/impala.py | 7 +--
 rllib/algorithms/impala/impala_learner.py | 4 +-
 rllib/algorithms/impala/impala_tf_policy.py | 13 +++--
 .../algorithms/impala/impala_torch_policy.py | 7 ++-
 rllib/algorithms/impala/tests/test_impala.py | 3 +-
 .../impala/tests/test_vtrace_old_api_stack.py | 8 ++-
 .../algorithms/impala/tests/test_vtrace_v2.py | 15 ++---
 .../impala/torch/impala_torch_learner.py | 2 +-
 .../impala/torch/vtrace_torch_v2.py | 1 +
 rllib/algorithms/iql/default_iql_rl_module.py | 2 +-
 rllib/algorithms/iql/iql_learner.py | 2 +-
 .../iql/torch/default_iql_torch_rl_module.py | 5 +-
 .../algorithms/iql/torch/iql_torch_learner.py | 8 +--
 rllib/algorithms/marwil/marwil.py | 4 +-
 rllib/algorithms/marwil/marwil_learner.py | 2 +-
 rllib/algorithms/marwil/marwil_tf_policy.py | 4 +-
 rllib/algorithms/marwil/tests/test_marwil.py | 10 ++--
 .../marwil/tests/test_marwil_rl_module.py | 7 ++-
 rllib/algorithms/mock.py | 2 +-
 rllib/algorithms/ppo/__init__.py | 2 +-
 rllib/algorithms/ppo/default_ppo_rl_module.py | 2 +-
 rllib/algorithms/ppo/ppo.py | 10 ++--
 rllib/algorithms/ppo/ppo_catalog.py | 4 +-
 rllib/algorithms/ppo/ppo_learner.py | 2 +-
 rllib/algorithms/ppo/tests/test_ppo.py | 3 +-
 .../algorithms/ppo/tests/test_ppo_learner.py | 6 +-
 .../ppo/tests/test_ppo_rl_module.py | 4 +-
 .../algorithms/ppo/torch/ppo_torch_learner.py | 4 +-
 rllib/algorithms/sac/default_sac_rl_module.py | 2 +-
 rllib/algorithms/sac/sac.py | 2 +-
 rllib/algorithms/sac/sac_catalog.py | 17 +++---
 rllib/algorithms/sac/sac_learner.py | 4 +-
 rllib/algorithms/sac/sac_tf_model.py | 7 ++-
 rllib/algorithms/sac/sac_tf_policy.py | 11 ++--
 rllib/algorithms/sac/sac_torch_model.py | 7 ++-
 rllib/algorithms/sac/sac_torch_policy.py | 19 ++++---
 rllib/algorithms/sac/tests/test_sac.py | 10 ++--
 .../sac/torch/default_sac_torch_rl_module.py | 13 +++--
 .../algorithms/sac/torch/sac_torch_learner.py | 15 ++---
 rllib/algorithms/tests/test_algorithm.py | 14 +++--
 .../algorithms/tests/test_algorithm_config.py | 10 ++--
 .../tests/test_algorithm_export_checkpoint.py | 6 +-
 .../tests/test_algorithm_imports.py | 3 +-
 .../tests/test_algorithm_rl_module_restore.py | 17 +++---
 ...gorithm_save_load_checkpoint_connectors.py | 3 +-
 ..._algorithm_save_load_checkpoint_learner.py | 2 +-
 .../tests/test_env_runner_failures.py | 7 ++-
 rllib/algorithms/tests/test_node_failures.py | 4 +-
 rllib/algorithms/tests/test_registry.py | 7 ++-
 95 files changed, 370 insertions(+), 322 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index c8c2219e451f..4a654d7f98f7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,7 +65,22 @@ afterray = ["psutil", "setproctitle"]
 "python/ray/__init__.py" = ["I"]
 "python/ray/dag/__init__.py" = ["I"]
 "python/ray/air/__init__.py" = ["I"]
-"rllib/*" = ["I"]
+"rllib/__init__.py" = ["I"]
+"rllib/benchmarks/*" = ["I"]
+"rllib/connectors/*" = ["I"]
+"rllib/evaluation/*" = ["I"]
+"rllib/models/*" = ["I"]
+"rllib/utils/*" = ["I"]
+# "rllib/algorithms/*" = ["I"]
+"rllib/core/*" = ["I"]
+"rllib/examples/*" = ["I"]
+"rllib/offline/*" = ["I"]
+"rllib/tests/*" = ["I"]
+"rllib/callbacks/*" = ["I"]
+"rllib/env/*" = ["I"]
+"rllib/execution/*" = ["I"]
+"rllib/policy/*" = ["I"]
+"rllib/tuned_examples/*" = ["I"]
 "release/*" = ["I"]
 
 # TODO(matthewdeng): Remove this line
diff --git a/rllib/algorithms/__init__.py b/rllib/algorithms/__init__.py
index fdc21775e119..f7e0696a0d32 100644
--- a/rllib/algorithms/__init__.py
+++ b/rllib/algorithms/__init__.py
@@ -6,15 +6,14 @@
 from ray.rllib.algorithms.dqn.dqn import DQN, DQNConfig
 from ray.rllib.algorithms.impala.impala import (
     IMPALA,
-    IMPALAConfig,
     Impala,
+    IMPALAConfig,
     ImpalaConfig,
 )
 from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig
 from ray.rllib.algorithms.ppo.ppo import PPO, PPOConfig
 from ray.rllib.algorithms.sac.sac import SAC, SACConfig
-
 
 __all__ = [
     "Algorithm",
     "AlgorithmConfig",
diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py
index 5b3f56c8d96d..d599e0ebea60 100644
--- a/rllib/algorithms/algorithm.py
+++ b/rllib/algorithms/algorithm.py
@@ -1,22 +1,19 @@
-from collections import defaultdict
 import concurrent
 import copy
-from datetime import datetime
 import functools
-import gymnasium as gym
 import importlib
 import importlib.metadata
 import json
 import logging
-import numpy as np
 import os
-from packaging import version
 import pathlib
-import pyarrow.fs
 import re
 import tempfile
 import time
+from collections import defaultdict
+from datetime import datetime
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Collection,
@@ -27,26 +24,32 @@
     Set,
     Tuple,
     Type,
-    TYPE_CHECKING,
     Union,
 )
+import gymnasium as gym
+import numpy as np
+import pyarrow.fs
 import tree  # pip install dm_tree
+from packaging import version import ray -from ray.tune.result import TRAINING_ITERATION +import ray.cloudpickle as pickle +from ray._common.deprecation import ( + DEPRECATED_VALUE, + Deprecated, + deprecation_warning, +) from ray._common.usage.usage_lib import TagKey, record_extra_usage_tag from ray.actor import ActorHandle -from ray.tune import Checkpoint -import ray.cloudpickle as pickle from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.algorithms.registry import ALGORITHMS_CLASS_TO_NAME as ALL_ALGORITHMS from ray.rllib.algorithms.utils import ( AggregatorActor, _get_env_runner_bundles, - _get_offline_eval_runner_bundles, _get_learner_bundles, _get_main_process_bundle, + _get_offline_eval_runner_bundles, ) from ray.rllib.callbacks.utils import make_callback from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector @@ -63,11 +66,11 @@ DEFAULT_MODULE_ID, ) from ray.rllib.core.columns import Columns +from ray.rllib.core.rl_module import validate_module_id from ray.rllib.core.rl_module.multi_rl_module import ( MultiRLModule, MultiRLModuleSpec, ) -from ray.rllib.core.rl_module import validate_module_id from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec from ray.rllib.env import INPUT_ENV_SPACES from ray.rllib.env.env_context import EnvContext @@ -81,39 +84,34 @@ from ray.rllib.execution.rollout_ops import synchronous_parallel_sample from ray.rllib.offline import get_dataset_and_shards from ray.rllib.offline.estimators import ( - OffPolicyEstimator, - ImportanceSampling, - WeightedImportanceSampling, DirectMethod, DoublyRobust, + ImportanceSampling, + OffPolicyEstimator, + WeightedImportanceSampling, ) from ray.rllib.offline.offline_evaluator import OfflineEvaluator from ray.rllib.policy.policy import Policy, PolicySpec from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch -from ray.rllib.utils import deep_update, FilterManager, force_list +from ray.rllib.utils import FilterManager, deep_update, force_list from ray.rllib.utils.actor_manager import FaultTolerantActorManager, RemoteCallResults from ray.rllib.utils.annotations import ( DeveloperAPI, ExperimentalAPI, OldAPIStack, - override, OverrideToImplementCustomLogic, OverrideToImplementCustomLogic_CallToSuperRecommended, PublicAPI, + override, ) from ray.rllib.utils.checkpoints import ( - Checkpointable, CHECKPOINT_VERSION, CHECKPOINT_VERSION_LEARNER_AND_ENV_RUNNER, + Checkpointable, get_checkpoint_info, try_import_msgpack, ) from ray.rllib.utils.debug import update_global_seed_if_necessary -from ray._common.deprecation import ( - DEPRECATED_VALUE, - Deprecated, - deprecation_warning, -) from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.from_config import from_config @@ -136,9 +134,9 @@ NUM_AGENT_STEPS_TRAINED, NUM_AGENT_STEPS_TRAINED_LIFETIME, NUM_ENV_STEPS_SAMPLED, + NUM_ENV_STEPS_SAMPLED_FOR_EVALUATION_THIS_ITER, NUM_ENV_STEPS_SAMPLED_LIFETIME, NUM_ENV_STEPS_SAMPLED_THIS_ITER, - NUM_ENV_STEPS_SAMPLED_FOR_EVALUATION_THIS_ITER, NUM_ENV_STEPS_TRAINED, NUM_ENV_STEPS_TRAINED_LIFETIME, NUM_EPISODES, @@ -149,19 +147,19 @@ RESTORE_ENV_RUNNERS_TIMER, RESTORE_EVAL_ENV_RUNNERS_TIMER, RESTORE_OFFLINE_EVAL_RUNNERS_TIMER, + STEPS_TRAINED_THIS_ITER_COUNTER, SYNCH_ENV_CONNECTOR_STATES_TIMER, SYNCH_EVAL_ENV_CONNECTOR_STATES_TIMER, SYNCH_WORKER_WEIGHTS_TIMER, TIMERS, TRAINING_ITERATION_TIMER, TRAINING_STEP_TIMER, - STEPS_TRAINED_THIS_ITER_COUNTER, ) from 
ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.metrics.metrics_logger import MetricsLogger from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer, ReplayBuffer from ray.rllib.utils.runners.runner_group import RunnerGroup -from ray.rllib.utils.serialization import deserialize_type, NOT_SERIALIZABLE +from ray.rllib.utils.serialization import NOT_SERIALIZABLE, deserialize_type from ray.rllib.utils.spaces import space_utils from ray.rllib.utils.typing import ( AgentConnectorDataType, @@ -184,15 +182,16 @@ TensorType, ) from ray.train.constants import DEFAULT_STORAGE_PATH +from ray.tune import Checkpoint from ray.tune.execution.placement_groups import PlacementGroupFactory from ray.tune.experiment.trial import ExportFormat from ray.tune.logger import Logger, UnifiedLogger -from ray.tune.registry import ENV_CREATOR, _global_registry +from ray.tune.registry import ENV_CREATOR, _global_registry, get_trainable_cls from ray.tune.resources import Resources +from ray.tune.result import TRAINING_ITERATION from ray.tune.trainable import Trainable from ray.util import log_once from ray.util.timer import _Timer -from ray.tune.registry import get_trainable_cls if TYPE_CHECKING: from ray.rllib.core.learner.learner_group import LearnerGroup diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index f58470da92ed..132e231563c1 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -1,10 +1,11 @@ import copy import dataclasses -from enum import Enum import logging import math import sys +from enum import Enum from typing import ( + TYPE_CHECKING, Any, Callable, Collection, @@ -13,7 +14,6 @@ Optional, Tuple, Type, - TYPE_CHECKING, Union, ) @@ -22,6 +22,11 @@ from packaging import version import ray +from ray._common.deprecation import ( + DEPRECATED_VALUE, + Deprecated, + deprecation_warning, +) from ray.rllib.callbacks.callbacks import RLlibCallback from ray.rllib.connectors.connector_v2 import ConnectorV2 from ray.rllib.core import DEFAULT_MODULE_ID @@ -33,7 +38,7 @@ from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec from ray.rllib.core.rl_module.rl_module import RLModuleSpec -from ray.rllib.env import INPUT_ENV_SPACES, INPUT_ENV_SINGLE_SPACES +from ray.rllib.env import INPUT_ENV_SINGLE_SPACES, INPUT_ENV_SPACES from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.env.wrappers.atari_wrappers import is_atari from ray.rllib.evaluation.collectors.sample_collector import SampleCollector @@ -48,11 +53,6 @@ OldAPIStack, OverrideToImplementCustomLogic_CallToSuperRecommended, ) -from ray._common.deprecation import ( - DEPRECATED_VALUE, - Deprecated, - deprecation_warning, -) from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.from_config import NotProvided, from_config from ray.rllib.utils.schedules.scheduler import Scheduler @@ -83,7 +83,6 @@ from ray.util import log_once from ray.util.placement_group import PlacementGroup - if TYPE_CHECKING: from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.core.learner import Learner diff --git a/rllib/algorithms/appo/appo.py b/rllib/algorithms/appo/appo.py index c3bc4c0031bb..75dd6b000e3c 100644 --- a/rllib/algorithms/appo/appo.py +++ b/rllib/algorithms/appo/appo.py @@ -10,22 +10,22 @@ https://arxiv.org/pdf/1912.00167 """ -from typing import Optional, Type import logging +from typing import 
Optional, Type +from ray._common.deprecation import DEPRECATED_VALUE, deprecation_warning from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided from ray.rllib.algorithms.impala.impala import IMPALA, IMPALAConfig from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.policy.policy import Policy from ray.rllib.utils.annotations import override -from ray._common.deprecation import DEPRECATED_VALUE, deprecation_warning from ray.rllib.utils.metrics import ( LAST_TARGET_UPDATE_TS, + LEARNER_STATS_KEY, NUM_AGENT_STEPS_SAMPLED, NUM_ENV_STEPS_SAMPLED, NUM_TARGET_UPDATES, ) -from ray.rllib.utils.metrics import LEARNER_STATS_KEY logger = logging.getLogger(__name__) diff --git a/rllib/algorithms/appo/appo_tf_policy.py b/rllib/algorithms/appo/appo_tf_policy.py index 4af36f099df9..eab4bfefeb2e 100644 --- a/rllib/algorithms/appo/appo_tf_policy.py +++ b/rllib/algorithms/appo/appo_tf_policy.py @@ -5,37 +5,37 @@ Keep in sync with changes to VTraceTFPolicy. """ -import numpy as np import logging -import gymnasium as gym from typing import Dict, List, Optional, Type, Union +import gymnasium as gym +import numpy as np + from ray.rllib.algorithms.appo.utils import make_appo_models from ray.rllib.algorithms.impala import vtrace_tf as vtrace from ray.rllib.algorithms.impala.impala_tf_policy import ( - _make_time_major, VTraceClipGradients, VTraceOptimizer, + _make_time_major, ) from ray.rllib.evaluation.postprocessing import ( + Postprocessing, compute_bootstrap_value, compute_gae_for_sample_batch, - Postprocessing, ) -from ray.rllib.models.tf.tf_action_dist import Categorical -from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.tf_action_dist import Categorical, TFActionDistribution from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2 from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_mixins import ( EntropyCoeffSchedule, - LearningRateSchedule, - KLCoeffMixin, - ValueNetworkMixin, GradStatsMixin, + KLCoeffMixin, + LearningRateSchedule, TargetNetworkMixin, + ValueNetworkMixin, ) -from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.models.tf.tf_action_dist import TFActionDistribution from ray.rllib.utils.annotations import ( override, ) diff --git a/rllib/algorithms/appo/appo_torch_policy.py b/rllib/algorithms/appo/appo_torch_policy.py index 1d28138c8c25..f150c6761cac 100644 --- a/rllib/algorithms/appo/appo_torch_policy.py +++ b/rllib/algorithms/appo/appo_torch_policy.py @@ -5,37 +5,38 @@ Keep in sync with changes to VTraceTFPolicy. 
""" -import gymnasium as gym -import numpy as np import logging from typing import Any, Dict, List, Optional, Type, Union +import gymnasium as gym +import numpy as np + import ray -from ray.rllib.algorithms.appo.utils import make_appo_models import ray.rllib.algorithms.impala.vtrace_torch as vtrace +from ray.rllib.algorithms.appo.utils import make_appo_models from ray.rllib.algorithms.impala.impala_torch_policy import ( - make_time_major, VTraceOptimizer, + make_time_major, ) from ray.rllib.evaluation.postprocessing import ( + Postprocessing, compute_bootstrap_value, compute_gae_for_sample_batch, - Postprocessing, ) from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import ( - TorchDistributionWrapper, TorchCategorical, + TorchDistributionWrapper, ) from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_mixins import ( EntropyCoeffSchedule, - LearningRateSchedule, KLCoeffMixin, - ValueNetworkMixin, + LearningRateSchedule, TargetNetworkMixin, + ValueNetworkMixin, ) from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 from ray.rllib.utils.annotations import override diff --git a/rllib/algorithms/appo/default_appo_rl_module.py b/rllib/algorithms/appo/default_appo_rl_module.py index 9152ac43d9d0..e6eb13d23bf1 100644 --- a/rllib/algorithms/appo/default_appo_rl_module.py +++ b/rllib/algorithms/appo/default_appo_rl_module.py @@ -8,12 +8,11 @@ TARGET_NETWORK_ACTION_DIST_INPUTS, TargetNetworkAPI, ) -from ray.rllib.utils.typing import NetworkType - from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic_CallToSuperRecommended, + override, ) +from ray.rllib.utils.typing import NetworkType from ray.util.annotations import DeveloperAPI diff --git a/rllib/algorithms/appo/tests/test_appo.py b/rllib/algorithms/appo/tests/test_appo.py index 6986eb1d2146..d6271f575104 100644 --- a/rllib/algorithms/appo/tests/test_appo.py +++ b/rllib/algorithms/appo/tests/test_appo.py @@ -11,9 +11,9 @@ NUM_ENV_STEPS_SAMPLED_LIFETIME, ) from ray.rllib.utils.test_utils import ( + check_compute_single_action, check_train_results, check_train_results_new_api_stack, - check_compute_single_action, ) diff --git a/rllib/algorithms/appo/tests/test_appo_learner.py b/rllib/algorithms/appo/tests/test_appo_learner.py index bd8cbffc10eb..92f1df9f8608 100644 --- a/rllib/algorithms/appo/tests/test_appo_learner.py +++ b/rllib/algorithms/appo/tests/test_appo_learner.py @@ -1,6 +1,6 @@ import unittest -import numpy as np +import numpy as np import tree # pip install dm_tree import ray @@ -13,7 +13,6 @@ from ray.rllib.utils.metrics import LEARNER_RESULTS from ray.rllib.utils.torch_utils import convert_to_torch_tensor - frag_length = 50 FAKE_BATCH = { @@ -119,7 +118,8 @@ def test_kl_coeff_changes(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/appo/torch/appo_torch_learner.py b/rllib/algorithms/appo/torch/appo_torch_learner.py index 62a4198952ec..9e3bbfca3b92 100644 --- a/rllib/algorithms/appo/torch/appo_torch_learner.py +++ b/rllib/algorithms/appo/torch/appo_torch_learner.py @@ -12,9 +12,9 @@ from typing import Dict from ray.rllib.algorithms.appo.appo import ( - APPOConfig, LEARNER_RESULTS_CURR_KL_COEFF_KEY, LEARNER_RESULTS_KL_KEY, + APPOConfig, ) from ray.rllib.algorithms.appo.appo_learner import APPOLearner from 
ray.rllib.algorithms.impala.torch.impala_torch_learner import IMPALATorchLearner @@ -23,7 +23,7 @@ vtrace_torch, ) from ray.rllib.core.columns import Columns -from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY +from ray.rllib.core.learner.learner import ENTROPY_KEY, POLICY_LOSS_KEY, VF_LOSS_KEY from ray.rllib.core.rl_module.apis import ( TARGET_NETWORK_ACTION_DIST_INPUTS, TargetNetworkAPI, diff --git a/rllib/algorithms/appo/utils.py b/rllib/algorithms/appo/utils.py index 8c66f080c165..052115630b39 100644 --- a/rllib/algorithms/appo/utils.py +++ b/rllib/algorithms/appo/utils.py @@ -3,9 +3,9 @@ Luo et al. 2020 https://arxiv.org/pdf/1912.00167 """ -from collections import deque import threading import time +from collections import deque import numpy as np @@ -13,7 +13,6 @@ from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.utils.annotations import OldAPIStack - POLICY_SCOPE = "func" TARGET_POLICY_SCOPE = "target_func" diff --git a/rllib/algorithms/bc/__init__.py b/rllib/algorithms/bc/__init__.py index 0bf454356c60..ac3749f7a57f 100644 --- a/rllib/algorithms/bc/__init__.py +++ b/rllib/algorithms/bc/__init__.py @@ -1,4 +1,4 @@ -from ray.rllib.algorithms.bc.bc import BCConfig, BC +from ray.rllib.algorithms.bc.bc import BC, BCConfig __all__ = [ "BC", diff --git a/rllib/algorithms/bc/bc_catalog.py b/rllib/algorithms/bc/bc_catalog.py index 1ac0e935266b..54a01ddd649c 100644 --- a/rllib/algorithms/bc/bc_catalog.py +++ b/rllib/algorithms/bc/bc_catalog.py @@ -2,9 +2,9 @@ import gymnasium as gym from ray.rllib.algorithms.ppo.ppo_catalog import _check_if_diag_gaussian +from ray.rllib.core.models.base import Model from ray.rllib.core.models.catalog import Catalog from ray.rllib.core.models.configs import FreeLogStdMLPHeadConfig, MLPHeadConfig -from ray.rllib.core.models.base import Model from ray.rllib.utils.annotations import OverrideToImplementCustomLogic diff --git a/rllib/algorithms/bc/tests/test_bc.py b/rllib/algorithms/bc/tests/test_bc.py index d3bbf371dad2..edec3c3422ed 100644 --- a/rllib/algorithms/bc/tests/test_bc.py +++ b/rllib/algorithms/bc/tests/test_bc.py @@ -1,7 +1,7 @@ -from pathlib import Path import unittest -import ray +from pathlib import Path +import ray from ray.rllib.algorithms.bc import BCConfig from ray.rllib.utils.metrics import ( ENV_RUNNER_RESULTS, @@ -88,7 +88,8 @@ def test_bc_compilation_and_learning_from_offline_file(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/callbacks.py b/rllib/algorithms/callbacks.py index 49e59d0c6a3e..9330e66335d7 100644 --- a/rllib/algorithms/callbacks.py +++ b/rllib/algorithms/callbacks.py @@ -2,7 +2,6 @@ from ray.rllib.callbacks.callbacks import RLlibCallback from ray.rllib.callbacks.utils import _make_multi_callbacks - # Backward compatibility DefaultCallbacks = RLlibCallback make_multi_callbacks = _make_multi_callbacks diff --git a/rllib/algorithms/cql/cql.py b/rllib/algorithms/cql/cql.py index 6d3b95cad746..987afc5bbff5 100644 --- a/rllib/algorithms/cql/cql.py +++ b/rllib/algorithms/cql/cql.py @@ -1,6 +1,10 @@ import logging from typing import Optional, Type, Union +from ray._common.deprecation import ( + DEPRECATED_VALUE, + deprecation_warning, +) from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy @@ -25,24 +29,20 @@ ) from 
ray.rllib.policy.policy import Policy from ray.rllib.utils.annotations import OldAPIStack, override -from ray._common.deprecation import ( - DEPRECATED_VALUE, - deprecation_warning, -) from ray.rllib.utils.framework import try_import_tf, try_import_tfp from ray.rllib.utils.metrics import ( + LAST_TARGET_UPDATE_TS, LEARNER_RESULTS, LEARNER_UPDATE_TIMER, - LAST_TARGET_UPDATE_TS, NUM_AGENT_STEPS_SAMPLED, NUM_AGENT_STEPS_TRAINED, NUM_ENV_STEPS_SAMPLED, NUM_ENV_STEPS_TRAINED, NUM_TARGET_UPDATES, OFFLINE_SAMPLING_TIMER, - TARGET_NET_UPDATE_TIMER, - SYNCH_WORKER_WEIGHTS_TIMER, SAMPLE_TIMER, + SYNCH_WORKER_WEIGHTS_TIMER, + TARGET_NET_UPDATE_TIMER, TIMERS, ) from ray.rllib.utils.typing import ResultDict, RLModuleSpecType diff --git a/rllib/algorithms/cql/cql_tf_policy.py b/rllib/algorithms/cql/cql_tf_policy.py index 0bfc871f328d..ae6c4f8d4fef 100644 --- a/rllib/algorithms/cql/cql_tf_policy.py +++ b/rllib/algorithms/cql/cql_tf_policy.py @@ -1,40 +1,41 @@ """ TensorFlow policy class used for CQL. """ +import logging from functools import partial -import numpy as np +from typing import Dict, List, Type, Union + import gymnasium as gym -import logging +import numpy as np import tree -from typing import Dict, List, Type, Union import ray from ray.rllib.algorithms.sac.sac_tf_policy import ( + ActorCriticOptimizerMixin as SACActorCriticOptimizerMixin, + ComputeTDErrorMixin, + _get_dist_class, apply_gradients as sac_apply_gradients, + build_sac_model, compute_and_clip_gradients as sac_compute_and_clip_gradients, get_distribution_inputs_and_class, - _get_dist_class, - build_sac_model, postprocess_trajectory, setup_late_mixins, stats, validate_spaces, - ActorCriticOptimizerMixin as SACActorCriticOptimizerMixin, - ComputeTDErrorMixin, ) from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_action_dist import TFActionDistribution -from ray.rllib.policy.tf_mixins import TargetNetworkMixin -from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_mixins import TargetNetworkMixin +from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.utils.exploration.random import Random from ray.rllib.utils.framework import get_variable, try_import_tf, try_import_tfp from ray.rllib.utils.typing import ( + AlgorithmConfigDict, LocalOptimizer, ModelGradients, TensorType, - AlgorithmConfigDict, ) tf1, tf, tfv = try_import_tf() diff --git a/rllib/algorithms/cql/cql_torch_policy.py b/rllib/algorithms/cql/cql_torch_policy.py index 2f67c8d642bb..a7fab43bda61 100644 --- a/rllib/algorithms/cql/cql_torch_policy.py +++ b/rllib/algorithms/cql/cql_torch_policy.py @@ -1,40 +1,41 @@ """ PyTorch policy class used for CQL. 
""" -import numpy as np -import gymnasium as gym import logging -import tree from typing import Dict, List, Tuple, Type, Union +import gymnasium as gym +import numpy as np +import tree + import ray from ray.rllib.algorithms.sac.sac_tf_policy import ( postprocess_trajectory, validate_spaces, ) from ray.rllib.algorithms.sac.sac_torch_policy import ( + ComputeTDErrorMixin, _get_dist_class, - stats, + action_distribution_fn, build_sac_model_and_action_dist, optimizer_fn, - ComputeTDErrorMixin, setup_late_mixins, - action_distribution_fn, + stats, ) -from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.policy.policy_template import build_policy_class +from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.policy.policy import Policy -from ray.rllib.policy.torch_mixins import TargetNetworkMixin +from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.torch_mixins import TargetNetworkMixin from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY -from ray.rllib.utils.typing import LocalOptimizer, TensorType, AlgorithmConfigDict from ray.rllib.utils.torch_utils import ( apply_grad_clipping, - convert_to_torch_tensor, concat_multi_gpu_td_errors, + convert_to_torch_tensor, ) +from ray.rllib.utils.typing import AlgorithmConfigDict, LocalOptimizer, TensorType torch, nn = try_import_torch() F = nn.functional diff --git a/rllib/algorithms/cql/tests/test_cql_old_api_stack.py b/rllib/algorithms/cql/tests/test_cql_old_api_stack.py index 1321741253a8..c2d3686da71c 100644 --- a/rllib/algorithms/cql/tests/test_cql_old_api_stack.py +++ b/rllib/algorithms/cql/tests/test_cql_old_api_stack.py @@ -1,6 +1,6 @@ -from pathlib import Path import os import unittest +from pathlib import Path import ray from ray.rllib.algorithms import cql @@ -121,7 +121,8 @@ def test_cql_compilation(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/cql/torch/cql_torch_learner.py b/rllib/algorithms/cql/torch/cql_torch_learner.py index e9f6897d3c83..4c04fb5de873 100644 --- a/rllib/algorithms/cql/torch/cql_torch_learner.py +++ b/rllib/algorithms/cql/torch/cql_torch_learner.py @@ -1,27 +1,27 @@ from typing import Dict -from ray.tune.result import TRAINING_ITERATION +from ray.rllib.algorithms.cql.cql import CQLConfig from ray.rllib.algorithms.sac.sac_learner import ( LOGPS_KEY, QF_LOSS_KEY, - QF_MEAN_KEY, QF_MAX_KEY, + QF_MEAN_KEY, QF_MIN_KEY, QF_PREDS, QF_TWIN_LOSS_KEY, QF_TWIN_PREDS, TD_ERROR_MEAN_KEY, ) -from ray.rllib.algorithms.cql.cql import CQLConfig from ray.rllib.algorithms.sac.torch.sac_torch_learner import SACTorchLearner from ray.rllib.core.columns import Columns from ray.rllib.core.learner.learner import ( POLICY_LOSS_KEY, ) from ray.rllib.utils.annotations import override -from ray.rllib.utils.metrics import ALL_MODULES from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics import ALL_MODULES from ray.rllib.utils.typing import ModuleID, ParamDict, TensorType +from ray.tune.result import TRAINING_ITERATION torch, nn = try_import_torch() diff --git a/rllib/algorithms/cql/torch/default_cql_torch_rl_module.py b/rllib/algorithms/cql/torch/default_cql_torch_rl_module.py index 32e90815710e..1c2e7a7a2301 100644 --- 
a/rllib/algorithms/cql/torch/default_cql_torch_rl_module.py +++ b/rllib/algorithms/cql/torch/default_cql_torch_rl_module.py @@ -1,11 +1,12 @@ -import tree from typing import Any, Dict, Optional +import tree + +from ray.rllib.algorithms.sac.sac_catalog import SACCatalog from ray.rllib.algorithms.sac.sac_learner import ( QF_PREDS, QF_TWIN_PREDS, ) -from ray.rllib.algorithms.sac.sac_catalog import SACCatalog from ray.rllib.algorithms.sac.torch.default_sac_torch_rl_module import ( DefaultSACTorchRLModule, ) diff --git a/rllib/algorithms/dqn/default_dqn_rl_module.py b/rllib/algorithms/dqn/default_dqn_rl_module.py index 78f9fe2c2e60..b4062ead7adf 100644 --- a/rllib/algorithms/dqn/default_dqn_rl_module.py +++ b/rllib/algorithms/dqn/default_dqn_rl_module.py @@ -3,17 +3,16 @@ from ray.rllib.core.learner.utils import make_target_network from ray.rllib.core.models.base import Encoder, Model -from ray.rllib.core.rl_module.apis import QNetAPI, InferenceOnlyAPI, TargetNetworkAPI +from ray.rllib.core.rl_module.apis import InferenceOnlyAPI, QNetAPI, TargetNetworkAPI from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic, + override, ) from ray.rllib.utils.schedules.scheduler import Scheduler from ray.rllib.utils.typing import NetworkType, TensorType from ray.util.annotations import DeveloperAPI - QF_PREDS = "qf_preds" ATOMS = "atoms" QF_LOGITS = "qf_logits" diff --git a/rllib/algorithms/dqn/distributional_q_tf_model.py b/rllib/algorithms/dqn/distributional_q_tf_model.py index a4dd63f587b7..421f5716d2b7 100644 --- a/rllib/algorithms/dqn/distributional_q_tf_model.py +++ b/rllib/algorithms/dqn/distributional_q_tf_model.py @@ -3,6 +3,7 @@ from typing import List import gymnasium as gym + from ray.rllib.models.tf.layers import NoisyLayer from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.utils.annotations import OldAPIStack diff --git a/rllib/algorithms/dqn/dqn.py b/rllib/algorithms/dqn/dqn.py index 6bc65698a56b..cf0bd69c0714 100644 --- a/rllib/algorithms/dqn/dqn.py +++ b/rllib/algorithms/dqn/dqn.py @@ -9,11 +9,13 @@ https://docs.ray.io/en/master/rllib-algorithms.html#deep-q-networks-dqn-rainbow-parametric-dqn """ # noqa: E501 -from collections import defaultdict import logging +from collections import defaultdict from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + import numpy as np +from ray._common.deprecation import DEPRECATED_VALUE from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy @@ -23,21 +25,14 @@ from ray.rllib.execution.rollout_ops import ( synchronous_parallel_sample, ) -from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.execution.train_ops import ( - train_one_step, multi_gpu_train_one_step, + train_one_step, ) from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils import deep_update from ray.rllib.utils.annotations import override -from ray.rllib.utils.numpy import convert_to_numpy -from ray.rllib.utils.replay_buffers.utils import ( - update_priorities_in_episode_replay_buffer, - update_priorities_in_replay_buffer, - validate_buffer_config, -) -from ray.rllib.utils.typing import ResultDict from ray.rllib.utils.metrics import ( ALL_MODULES, ENV_RUNNER_RESULTS, @@ -59,10 +54,16 @@ TD_ERROR_KEY, TIMERS, ) -from ray._common.deprecation import 
DEPRECATED_VALUE -from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.replay_buffers.utils import ( + sample_min_n_steps_from_buffer, + update_priorities_in_episode_replay_buffer, + update_priorities_in_replay_buffer, + validate_buffer_config, +) from ray.rllib.utils.typing import ( LearningRateOrSchedule, + ResultDict, RLModuleSpecType, SampleBatchType, ) diff --git a/rllib/algorithms/dqn/dqn_catalog.py b/rllib/algorithms/dqn/dqn_catalog.py index f98dc5429c3a..32c7cf1c063f 100644 --- a/rllib/algorithms/dqn/dqn_catalog.py +++ b/rllib/algorithms/dqn/dqn_catalog.py @@ -1,13 +1,13 @@ import gymnasium as gym -from ray.rllib.core.models.catalog import Catalog +from ray.rllib.core.distribution.torch.torch_distribution import TorchCategorical from ray.rllib.core.models.base import Model +from ray.rllib.core.models.catalog import Catalog from ray.rllib.core.models.configs import MLPHeadConfig -from ray.rllib.core.distribution.torch.torch_distribution import TorchCategorical from ray.rllib.utils.annotations import ( ExperimentalAPI, - override, OverrideToImplementCustomLogic, + override, ) diff --git a/rllib/algorithms/dqn/dqn_learner.py b/rllib/algorithms/dqn/dqn_learner.py index b55385eaf939..64bc51969a75 100644 --- a/rllib/algorithms/dqn/dqn_learner.py +++ b/rllib/algorithms/dqn/dqn_learner.py @@ -12,8 +12,8 @@ from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic_CallToSuperRecommended, + override, ) from ray.rllib.utils.metrics import ( LAST_TARGET_UPDATE_TS, @@ -22,7 +22,6 @@ ) from ray.rllib.utils.typing import ModuleID, ShouldModuleBeUpdatedFn - # Now, this is double defined: In `SACRLModule` and here. I would keep it here # or push it into the `Learner` as these are recurring keys in RL. 
ATOMS = "atoms" diff --git a/rllib/algorithms/dqn/dqn_torch_model.py b/rllib/algorithms/dqn/dqn_torch_model.py index 03c109878f73..4cb93bb63967 100644 --- a/rllib/algorithms/dqn/dqn_torch_model.py +++ b/rllib/algorithms/dqn/dqn_torch_model.py @@ -1,7 +1,9 @@ """PyTorch model for DQN""" from typing import Sequence + import gymnasium as gym + from ray.rllib.models.torch.misc import SlimFC from ray.rllib.models.torch.modules.noisy_layer import NoisyLayer from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 diff --git a/rllib/algorithms/dqn/dqn_torch_policy.py b/rllib/algorithms/dqn/dqn_torch_policy.py index 3229e379c730..fead64a5bc11 100644 --- a/rllib/algorithms/dqn/dqn_torch_policy.py +++ b/rllib/algorithms/dqn/dqn_torch_policy.py @@ -3,6 +3,7 @@ from typing import Dict, List, Tuple import gymnasium as gym + import ray from ray.rllib.algorithms.dqn.dqn_tf_policy import ( PRIO_WEIGHTS, @@ -14,8 +15,8 @@ from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import ( - get_torch_categorical_class_with_temperature, TorchDistributionWrapper, + get_torch_categorical_class_with_temperature, ) from ray.rllib.policy.policy import Policy from ray.rllib.policy.policy_template import build_policy_class @@ -29,15 +30,15 @@ from ray.rllib.utils.exploration.parameter_noise import ParameterNoise from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_utils import ( + FLOAT_MIN, apply_grad_clipping, concat_multi_gpu_td_errors, - FLOAT_MIN, huber_loss, l2_loss, reduce_mean_ignore_inf, softmax_cross_entropy_with_logits, ) -from ray.rllib.utils.typing import TensorType, AlgorithmConfigDict +from ray.rllib.utils.typing import AlgorithmConfigDict, TensorType torch, nn = try_import_torch() F = None diff --git a/rllib/algorithms/dqn/tests/test_dqn.py b/rllib/algorithms/dqn/tests/test_dqn.py index 238daefdb2f5..9805ce181d04 100644 --- a/rllib/algorithms/dqn/tests/test_dqn.py +++ b/rllib/algorithms/dqn/tests/test_dqn.py @@ -47,7 +47,8 @@ def test_dqn_compilation(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py b/rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py index 968ebe2da68d..b1a07226d5c7 100644 --- a/rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py +++ b/rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py @@ -1,8 +1,8 @@ -import tree from typing import Dict, Union +import tree + from ray.rllib.algorithms.dqn.default_dqn_rl_module import ( - DefaultDQNRLModule, ATOMS, QF_LOGITS, QF_NEXT_PREDS, @@ -10,16 +10,17 @@ QF_PROBS, QF_TARGET_NEXT_PREDS, QF_TARGET_NEXT_PROBS, + DefaultDQNRLModule, ) from ray.rllib.algorithms.dqn.dqn_catalog import DQNCatalog from ray.rllib.core.columns import Columns -from ray.rllib.core.models.base import Encoder, ENCODER_OUT, Model +from ray.rllib.core.models.base import ENCODER_OUT, Encoder, Model from ray.rllib.core.rl_module.apis.q_net_api import QNetAPI -from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.typing import TensorType, TensorStructType +from ray.rllib.utils.typing import TensorStructType, TensorType from 
ray.util.annotations import DeveloperAPI torch, nn = try_import_torch() diff --git a/rllib/algorithms/dqn/torch/dqn_torch_learner.py b/rllib/algorithms/dqn/torch/dqn_torch_learner.py index 4289bcc7cdcf..3e77529bc130 100644 --- a/rllib/algorithms/dqn/torch/dqn_torch_learner.py +++ b/rllib/algorithms/dqn/torch/dqn_torch_learner.py @@ -3,18 +3,18 @@ from ray.rllib.algorithms.dqn.dqn import DQNConfig from ray.rllib.algorithms.dqn.dqn_learner import ( ATOMS, - DQNLearner, - QF_LOSS_KEY, QF_LOGITS, - QF_MEAN_KEY, + QF_LOSS_KEY, QF_MAX_KEY, + QF_MEAN_KEY, QF_MIN_KEY, QF_NEXT_PREDS, - QF_TARGET_NEXT_PREDS, - QF_TARGET_NEXT_PROBS, QF_PREDS, QF_PROBS, + QF_TARGET_NEXT_PREDS, + QF_TARGET_NEXT_PROBS, TD_ERROR_MEAN_KEY, + DQNLearner, ) from ray.rllib.core.columns import Columns from ray.rllib.core.learner.torch.torch_learner import TorchLearner @@ -23,7 +23,6 @@ from ray.rllib.utils.metrics import TD_ERROR_KEY from ray.rllib.utils.typing import ModuleID, TensorType - torch, nn = try_import_torch() diff --git a/rllib/algorithms/dreamerv3/dreamerv3.py b/rllib/algorithms/dreamerv3/dreamerv3.py index 63784d3e09a7..44a677d8ac9f 100644 --- a/rllib/algorithms/dreamerv3/dreamerv3.py +++ b/rllib/algorithms/dreamerv3/dreamerv3.py @@ -32,8 +32,7 @@ from ray.rllib.execution.rollout_ops import synchronous_parallel_sample from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils import deep_update -from ray.rllib.utils.annotations import override, PublicAPI -from ray.rllib.utils.numpy import one_hot +from ray.rllib.utils.annotations import PublicAPI, override from ray.rllib.utils.metrics import ( ENV_RUNNER_RESULTS, LEARN_ON_BATCH_TIMER, @@ -47,10 +46,10 @@ SYNCH_WORKER_WEIGHTS_TIMER, TIMERS, ) +from ray.rllib.utils.numpy import one_hot from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer from ray.rllib.utils.typing import LearningRateOrSchedule - logger = logging.getLogger(__name__) diff --git a/rllib/algorithms/dreamerv3/dreamerv3_catalog.py b/rllib/algorithms/dreamerv3/dreamerv3_catalog.py index a2cca266ec64..ce16b747ec4d 100644 --- a/rllib/algorithms/dreamerv3/dreamerv3_catalog.py +++ b/rllib/algorithms/dreamerv3/dreamerv3_catalog.py @@ -4,11 +4,11 @@ from ray.rllib.algorithms.dreamerv3.utils import ( do_symlog_obs, get_gru_units, - get_num_z_classes, get_num_z_categoricals, + get_num_z_classes, ) -from ray.rllib.core.models.catalog import Catalog from ray.rllib.core.models.base import Encoder, Model +from ray.rllib.core.models.catalog import Catalog from ray.rllib.utils import override diff --git a/rllib/algorithms/dreamerv3/dreamerv3_learner.py b/rllib/algorithms/dreamerv3/dreamerv3_learner.py index 2bd634ca76e8..b2c0cf27cb22 100644 --- a/rllib/algorithms/dreamerv3/dreamerv3_learner.py +++ b/rllib/algorithms/dreamerv3/dreamerv3_learner.py @@ -9,8 +9,8 @@ """ from ray.rllib.core.learner.learner import Learner from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic_CallToSuperRecommended, + override, ) diff --git a/rllib/algorithms/dreamerv3/dreamerv3_rl_module.py b/rllib/algorithms/dreamerv3/dreamerv3_rl_module.py index 20e0b8140b16..5cf8f4884a97 100644 --- a/rllib/algorithms/dreamerv3/dreamerv3_rl_module.py +++ b/rllib/algorithms/dreamerv3/dreamerv3_rl_module.py @@ -5,21 +5,20 @@ import abc from typing import Dict +from ray.rllib.algorithms.dreamerv3.torch.models.actor_network import ActorNetwork +from ray.rllib.algorithms.dreamerv3.torch.models.critic_network import CriticNetwork +from 
ray.rllib.algorithms.dreamerv3.torch.models.dreamer_model import DreamerModel +from ray.rllib.algorithms.dreamerv3.torch.models.world_model import WorldModel from ray.rllib.algorithms.dreamerv3.utils import ( do_symlog_obs, get_gru_units, get_num_z_categoricals, get_num_z_classes, ) -from ray.rllib.algorithms.dreamerv3.torch.models.actor_network import ActorNetwork -from ray.rllib.algorithms.dreamerv3.torch.models.critic_network import CriticNetwork -from ray.rllib.algorithms.dreamerv3.torch.models.dreamer_model import DreamerModel -from ray.rllib.algorithms.dreamerv3.torch.models.world_model import WorldModel from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.utils.annotations import override from ray.util.annotations import DeveloperAPI - ACTIONS_ONE_HOT = "actions_one_hot" diff --git a/rllib/algorithms/dreamerv3/tests/test_dreamerv3.py b/rllib/algorithms/dreamerv3/tests/test_dreamerv3.py index a3936253a80b..096fdf7d7fa6 100644 --- a/rllib/algorithms/dreamerv3/tests/test_dreamerv3.py +++ b/rllib/algorithms/dreamerv3/tests/test_dreamerv3.py @@ -19,6 +19,7 @@ import tree # pip install dm_tree import ray +from ray import tune from ray.rllib.algorithms.dreamerv3 import dreamerv3 from ray.rllib.connectors.env_to_module import FlattenObservations from ray.rllib.core import DEFAULT_MODULE_ID @@ -27,7 +28,6 @@ from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.numpy import one_hot from ray.rllib.utils.test_utils import check -from ray import tune torch, nn = try_import_torch() @@ -317,7 +317,8 @@ def test_dreamerv3_dreamer_model_sizes(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/dreamerv3/torch/dreamerv3_torch_learner.py b/rllib/algorithms/dreamerv3/torch/dreamerv3_torch_learner.py index f33b9d2deb57..8d9c0ec4ea04 100644 --- a/rllib/algorithms/dreamerv3/torch/dreamerv3_torch_learner.py +++ b/rllib/algorithms/dreamerv3/torch/dreamerv3_torch_learner.py @@ -19,7 +19,7 @@ from ray.rllib.core.learner.torch.torch_learner import TorchLearner from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.torch_utils import symlog, two_hot, clip_gradients +from ray.rllib.utils.torch_utils import clip_gradients, symlog, two_hot from ray.rllib.utils.typing import ModuleID, TensorType torch, nn = try_import_torch() diff --git a/rllib/algorithms/dreamerv3/torch/models/actor_network.py b/rllib/algorithms/dreamerv3/torch/models/actor_network.py index 8a02a41bd9bf..8dc90f4bdf9d 100644 --- a/rllib/algorithms/dreamerv3/torch/models/actor_network.py +++ b/rllib/algorithms/dreamerv3/torch/models/actor_network.py @@ -9,7 +9,6 @@ from ray.rllib.algorithms.dreamerv3.torch.models.components.mlp import MLP from ray.rllib.utils.framework import try_import_torch - torch, nn = try_import_torch() diff --git a/rllib/algorithms/dreamerv3/torch/models/components/dynamics_predictor.py b/rllib/algorithms/dreamerv3/torch/models/components/dynamics_predictor.py index 14e8a39c829a..64df56079bda 100644 --- a/rllib/algorithms/dreamerv3/torch/models/components/dynamics_predictor.py +++ b/rllib/algorithms/dreamerv3/torch/models/components/dynamics_predictor.py @@ -5,10 +5,10 @@ """ from typing import Optional -from ray.rllib.algorithms.dreamerv3.torch.models.components.mlp import MLP from ray.rllib.algorithms.dreamerv3.torch.models.components import ( representation_layer, ) +from 
ray.rllib.algorithms.dreamerv3.torch.models.components.mlp import MLP from ray.rllib.algorithms.dreamerv3.utils import get_dense_hidden_units from ray.rllib.utils.framework import try_import_torch diff --git a/rllib/algorithms/dreamerv3/torch/models/components/reward_predictor.py b/rllib/algorithms/dreamerv3/torch/models/components/reward_predictor.py index 2733dd2cc132..98f5920f5890 100644 --- a/rllib/algorithms/dreamerv3/torch/models/components/reward_predictor.py +++ b/rllib/algorithms/dreamerv3/torch/models/components/reward_predictor.py @@ -3,12 +3,11 @@ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap https://arxiv.org/pdf/2301.04104v1.pdf """ -from ray.rllib.algorithms.dreamerv3.torch.models.components.mlp import MLP from ray.rllib.algorithms.dreamerv3.torch.models.components import ( reward_predictor_layer, ) +from ray.rllib.algorithms.dreamerv3.torch.models.components.mlp import MLP from ray.rllib.algorithms.dreamerv3.utils import get_dense_hidden_units - from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() diff --git a/rllib/algorithms/dreamerv3/torch/models/components/sequence_model.py b/rllib/algorithms/dreamerv3/torch/models/components/sequence_model.py index 1fbd695d54cc..38934a016aa6 100644 --- a/rllib/algorithms/dreamerv3/torch/models/components/sequence_model.py +++ b/rllib/algorithms/dreamerv3/torch/models/components/sequence_model.py @@ -11,7 +11,7 @@ dreamerv3_normal_initializer, ) from ray.rllib.algorithms.dreamerv3.torch.models.components.mlp import MLP -from ray.rllib.algorithms.dreamerv3.utils import get_gru_units, get_dense_hidden_units +from ray.rllib.algorithms.dreamerv3.utils import get_dense_hidden_units, get_gru_units from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() diff --git a/rllib/algorithms/dreamerv3/torch/models/critic_network.py b/rllib/algorithms/dreamerv3/torch/models/critic_network.py index d4b5798eb55a..f4b4fb956778 100644 --- a/rllib/algorithms/dreamerv3/torch/models/critic_network.py +++ b/rllib/algorithms/dreamerv3/torch/models/critic_network.py @@ -3,11 +3,11 @@ D. Hafner, J. Pasukonis, J. Ba, T. 
Lillicrap https://arxiv.org/pdf/2301.04104v1.pdf """ -from ray.rllib.algorithms.dreamerv3.utils import get_dense_hidden_units -from ray.rllib.algorithms.dreamerv3.torch.models.components.mlp import MLP from ray.rllib.algorithms.dreamerv3.torch.models.components import ( reward_predictor_layer, ) +from ray.rllib.algorithms.dreamerv3.torch.models.components.mlp import MLP +from ray.rllib.algorithms.dreamerv3.utils import get_dense_hidden_units from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() diff --git a/rllib/algorithms/dreamerv3/torch/models/world_model.py b/rllib/algorithms/dreamerv3/torch/models/world_model.py index c8851ea8dd71..5e2b4de597f3 100644 --- a/rllib/algorithms/dreamerv3/torch/models/world_model.py +++ b/rllib/algorithms/dreamerv3/torch/models/world_model.py @@ -9,6 +9,9 @@ import numpy as np import tree # pip install dm_tree +from ray.rllib.algorithms.dreamerv3.torch.models.components import ( + representation_layer, +) from ray.rllib.algorithms.dreamerv3.torch.models.components.continue_predictor import ( ContinuePredictor, ) @@ -16,9 +19,6 @@ DynamicsPredictor, ) from ray.rllib.algorithms.dreamerv3.torch.models.components.mlp import MLP -from ray.rllib.algorithms.dreamerv3.torch.models.components import ( - representation_layer, -) from ray.rllib.algorithms.dreamerv3.torch.models.components.reward_predictor import ( RewardPredictor, ) diff --git a/rllib/algorithms/dreamerv3/utils/debugging.py b/rllib/algorithms/dreamerv3/utils/debugging.py index d69281713a38..a99d2923d4ad 100644 --- a/rllib/algorithms/dreamerv3/utils/debugging.py +++ b/rllib/algorithms/dreamerv3/utils/debugging.py @@ -1,8 +1,7 @@ import gymnasium as gym import numpy as np -from PIL import Image, ImageDraw - from gymnasium.envs.classic_control.cartpole import CartPoleEnv +from PIL import Image, ImageDraw from ray.rllib.utils.framework import try_import_torch diff --git a/rllib/algorithms/impala/__init__.py b/rllib/algorithms/impala/__init__.py index 913c1b77198e..f81a5666eb0d 100644 --- a/rllib/algorithms/impala/__init__.py +++ b/rllib/algorithms/impala/__init__.py @@ -1,7 +1,7 @@ from ray.rllib.algorithms.impala.impala import ( IMPALA, - IMPALAConfig, Impala, + IMPALAConfig, ImpalaConfig, ) from ray.rllib.algorithms.impala.impala_tf_policy import ( diff --git a/rllib/algorithms/impala/impala.py b/rllib/algorithms/impala/impala.py index c183b9e2f653..47af681125fd 100644 --- a/rllib/algorithms/impala/impala.py +++ b/rllib/algorithms/impala/impala.py @@ -5,6 +5,7 @@ import ray from ray import ObjectRef +from ray._common.deprecation import DEPRECATED_VALUE, deprecation_warning from ray.rllib import SampleBatch from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided @@ -21,7 +22,6 @@ from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import concat_samples from ray.rllib.utils.annotations import OldAPIStack, override -from ray._common.deprecation import DEPRECATED_VALUE, deprecation_warning from ray.rllib.utils.metrics import ( AGGREGATOR_ACTOR_RESULTS, ALL_MODULES, @@ -30,8 +30,8 @@ LEARNER_RESULTS, LEARNER_UPDATE_TIMER, MEAN_NUM_EPISODE_LISTS_RECEIVED, - MEAN_NUM_LEARNER_RESULTS_RECEIVED, MEAN_NUM_LEARNER_GROUP_UPDATE_CALLED, + MEAN_NUM_LEARNER_RESULTS_RECEIVED, NUM_AGENT_STEPS_SAMPLED, NUM_AGENT_STEPS_TRAINED, NUM_ENV_STEPS_SAMPLED, @@ -40,8 +40,8 @@ NUM_ENV_STEPS_TRAINED_LIFETIME, NUM_SYNCH_WORKER_WEIGHTS, NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS, - 
SYNCH_WORKER_WEIGHTS_TIMER, SAMPLE_TIMER, + SYNCH_WORKER_WEIGHTS_TIMER, TIMERS, ) from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder @@ -55,7 +55,6 @@ SampleBatchType, ) - logger = logging.getLogger(__name__) diff --git a/rllib/algorithms/impala/impala_learner.py b/rllib/algorithms/impala/impala_learner.py index 95ef5a947623..6ce2768cfdd7 100644 --- a/rllib/algorithms/impala/impala_learner.py +++ b/rllib/algorithms/impala/impala_learner.py @@ -1,7 +1,7 @@ -from collections import deque import queue import threading import time +from collections import deque from typing import Any, Dict, Union import ray @@ -12,8 +12,8 @@ from ray.rllib.core.learner.training_data import TrainingData from ray.rllib.core.rl_module.apis import ValueFunctionAPI from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic_CallToSuperRecommended, + override, ) from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict diff --git a/rllib/algorithms/impala/impala_tf_policy.py b/rllib/algorithms/impala/impala_tf_policy.py index a1f74f48f8ce..94ee60e20260 100644 --- a/rllib/algorithms/impala/impala_tf_policy.py +++ b/rllib/algorithms/impala/impala_tf_policy.py @@ -2,11 +2,12 @@ Keep in sync with changes to A3CTFPolicy and VtraceSurrogatePolicy.""" -import numpy as np import logging -import gymnasium as gym from typing import Dict, List, Optional, Type, Union +import gymnasium as gym +import numpy as np + from ray.rllib.algorithms.impala import vtrace_tf as vtrace from ray.rllib.evaluation.postprocessing import compute_bootstrap_value from ray.rllib.models.modelv2 import ModelV2 @@ -14,12 +15,16 @@ from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2 from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.tf_mixins import LearningRateSchedule, EntropyCoeffSchedule +from ray.rllib.policy.tf_mixins import ( + EntropyCoeffSchedule, + GradStatsMixin, + LearningRateSchedule, + ValueNetworkMixin, +) from ray.rllib.utils import force_list from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.tf_utils import explained_variance -from ray.rllib.policy.tf_mixins import GradStatsMixin, ValueNetworkMixin from ray.rllib.utils.typing import ( LocalOptimizer, ModelGradients, diff --git a/rllib/algorithms/impala/impala_torch_policy.py b/rllib/algorithms/impala/impala_torch_policy.py index c174149f7c60..ee58654cab7b 100644 --- a/rllib/algorithms/impala/impala_torch_policy.py +++ b/rllib/algorithms/impala/impala_torch_policy.py @@ -1,12 +1,13 @@ -import gymnasium as gym import logging -import numpy as np from typing import Dict, List, Optional, Type, Union +import gymnasium as gym +import numpy as np + import ray from ray.rllib.evaluation.postprocessing import compute_bootstrap_value -from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import TorchCategorical from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_mixins import ( diff --git a/rllib/algorithms/impala/tests/test_impala.py b/rllib/algorithms/impala/tests/test_impala.py index 868062f019ea..be5ee0eccfb9 100644 --- a/rllib/algorithms/impala/tests/test_impala.py +++ b/rllib/algorithms/impala/tests/test_impala.py @@ -64,7 +64,8 @@ def 
get_lr(result): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/impala/tests/test_vtrace_old_api_stack.py b/rllib/algorithms/impala/tests/test_vtrace_old_api_stack.py index 303797f2f947..d538a032eecb 100644 --- a/rllib/algorithms/impala/tests/test_vtrace_old_api_stack.py +++ b/rllib/algorithms/impala/tests/test_vtrace_old_api_stack.py @@ -20,10 +20,11 @@ by Espeholt, Soyer, Munos et al. """ -from gymnasium.spaces import Box -import numpy as np import unittest +import numpy as np +from gymnasium.spaces import Box + from ray.rllib.algorithms.impala import vtrace_torch as vtrace_torch from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.numpy import softmax @@ -282,7 +283,8 @@ def test_inconsistent_rank_inputs_for_importance_weights(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/impala/tests/test_vtrace_v2.py b/rllib/algorithms/impala/tests/test_vtrace_v2.py index 79387104d968..fda785d3df90 100644 --- a/rllib/algorithms/impala/tests/test_vtrace_v2.py +++ b/rllib/algorithms/impala/tests/test_vtrace_v2.py @@ -1,19 +1,19 @@ import unittest -import numpy as np +import numpy as np from gymnasium.spaces import Box, Discrete -from ray.rllib.algorithms.impala.torch.vtrace_torch_v2 import ( - vtrace_torch, - make_time_major, -) from ray.rllib.algorithms.impala.tests.test_vtrace_old_api_stack import ( _ground_truth_vtrace_calculation, ) -from ray.rllib.utils.torch_utils import convert_to_torch_tensor +from ray.rllib.algorithms.impala.torch.vtrace_torch_v2 import ( + make_time_major, + vtrace_torch, +) from ray.rllib.core.distribution.torch.torch_distribution import TorchCategorical from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.test_utils import check +from ray.rllib.utils.torch_utils import convert_to_torch_tensor torch, _ = try_import_torch() @@ -147,7 +147,8 @@ def test_vtrace_torch(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/impala/torch/impala_torch_learner.py b/rllib/algorithms/impala/torch/impala_torch_learner.py index 256e3b48fb79..f52250d16a1c 100644 --- a/rllib/algorithms/impala/torch/impala_torch_learner.py +++ b/rllib/algorithms/impala/torch/impala_torch_learner.py @@ -3,8 +3,8 @@ from ray.rllib.algorithms.impala.impala import IMPALAConfig from ray.rllib.algorithms.impala.impala_learner import IMPALALearner from ray.rllib.algorithms.impala.torch.vtrace_torch_v2 import ( - vtrace_torch, make_time_major, + vtrace_torch, ) from ray.rllib.core.columns import Columns from ray.rllib.core.learner.learner import ENTROPY_KEY diff --git a/rllib/algorithms/impala/torch/vtrace_torch_v2.py b/rllib/algorithms/impala/torch/vtrace_torch_v2.py index 48231be9d7d5..bf4c4fa99373 100644 --- a/rllib/algorithms/impala/torch/vtrace_torch_v2.py +++ b/rllib/algorithms/impala/torch/vtrace_torch_v2.py @@ -1,4 +1,5 @@ from typing import List, Union + from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() diff --git a/rllib/algorithms/iql/default_iql_rl_module.py b/rllib/algorithms/iql/default_iql_rl_module.py index e6e3b2279ac5..95596bd8b91d 100644 --- a/rllib/algorithms/iql/default_iql_rl_module.py +++ b/rllib/algorithms/iql/default_iql_rl_module.py @@ -2,8 +2,8 @@ from ray.rllib.core.models.configs import MLPHeadConfig from 
ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic_CallToSuperRecommended, + override, ) diff --git a/rllib/algorithms/iql/iql_learner.py b/rllib/algorithms/iql/iql_learner.py index 5821f2ccb5e0..ef2e2e83e15b 100644 --- a/rllib/algorithms/iql/iql_learner.py +++ b/rllib/algorithms/iql/iql_learner.py @@ -2,8 +2,8 @@ from ray.rllib.algorithms.dqn.dqn_learner import DQNLearner from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic_CallToSuperRecommended, + override, ) from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict from ray.rllib.utils.typing import ModuleID, TensorType diff --git a/rllib/algorithms/iql/torch/default_iql_torch_rl_module.py b/rllib/algorithms/iql/torch/default_iql_torch_rl_module.py index 00d7fc821e49..318dcc207533 100644 --- a/rllib/algorithms/iql/torch/default_iql_torch_rl_module.py +++ b/rllib/algorithms/iql/torch/default_iql_torch_rl_module.py @@ -1,8 +1,9 @@ -import gymnasium as gym from typing import Any, Dict, Optional +import gymnasium as gym + from ray.rllib.algorithms.iql.default_iql_rl_module import DefaultIQLRLModule -from ray.rllib.algorithms.iql.iql_learner import VF_PREDS_NEXT, QF_TARGET_PREDS +from ray.rllib.algorithms.iql.iql_learner import QF_TARGET_PREDS, VF_PREDS_NEXT from ray.rllib.algorithms.sac.torch.default_sac_torch_rl_module import ( DefaultSACTorchRLModule, ) diff --git a/rllib/algorithms/iql/torch/iql_torch_learner.py b/rllib/algorithms/iql/torch/iql_torch_learner.py index 85dc68e86fb2..54a4fd263caa 100644 --- a/rllib/algorithms/iql/torch/iql_torch_learner.py +++ b/rllib/algorithms/iql/torch/iql_torch_learner.py @@ -1,14 +1,14 @@ from typing import Dict from ray.rllib.algorithms.algorithm_config import AlgorithmConfig -from ray.rllib.algorithms.dqn.dqn_learner import QF_PREDS, QF_LOSS_KEY +from ray.rllib.algorithms.dqn.dqn_learner import QF_LOSS_KEY, QF_PREDS from ray.rllib.algorithms.iql.iql_learner import ( - IQLLearner, QF_TARGET_PREDS, - VF_PREDS_NEXT, VF_LOSS, + VF_PREDS_NEXT, + IQLLearner, ) -from ray.rllib.algorithms.sac.sac_learner import QF_TWIN_PREDS, QF_TWIN_LOSS_KEY +from ray.rllib.algorithms.sac.sac_learner import QF_TWIN_LOSS_KEY, QF_TWIN_PREDS from ray.rllib.core import ALL_MODULES from ray.rllib.core.columns import Columns from ray.rllib.core.learner.learner import ( diff --git a/rllib/algorithms/marwil/marwil.py b/rllib/algorithms/marwil/marwil.py index 7dfb1e1dcde6..5ba2cb6dd69f 100644 --- a/rllib/algorithms/marwil/marwil.py +++ b/rllib/algorithms/marwil/marwil.py @@ -1,11 +1,12 @@ from typing import Callable, Optional, Type, Union +from ray._common.deprecation import deprecation_warning from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided from ray.rllib.connectors.learner import ( + AddNextObservationsFromEpisodesToTrainBatch, AddObservationsFromEpisodesToBatch, AddOneTsToEpisodesAndTruncate, - AddNextObservationsFromEpisodesToTrainBatch, GeneralAdvantageEstimation, ) from ray.rllib.core.learner.learner import Learner @@ -20,7 +21,6 @@ ) from ray.rllib.policy.policy import Policy from ray.rllib.utils.annotations import OldAPIStack, override -from ray._common.deprecation import deprecation_warning from ray.rllib.utils.metrics import ( LEARNER_RESULTS, LEARNER_UPDATE_TIMER, diff --git a/rllib/algorithms/marwil/marwil_learner.py b/rllib/algorithms/marwil/marwil_learner.py index 363e6a84a309..b98b0d090f66 
100644 --- a/rllib/algorithms/marwil/marwil_learner.py +++ b/rllib/algorithms/marwil/marwil_learner.py @@ -1,7 +1,7 @@ from typing import Dict, Optional -from ray.rllib.core.rl_module.apis import ValueFunctionAPI from ray.rllib.core.learner.learner import Learner +from ray.rllib.core.rl_module.apis import ValueFunctionAPI from ray.rllib.utils.annotations import override from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict from ray.rllib.utils.typing import ModuleID, ShouldModuleBeUpdatedFn, TensorType diff --git a/rllib/algorithms/marwil/marwil_tf_policy.py b/rllib/algorithms/marwil/marwil_tf_policy.py index 5f75a8424c76..2dbdb6a0efd2 100644 --- a/rllib/algorithms/marwil/marwil_tf_policy.py +++ b/rllib/algorithms/marwil/marwil_tf_policy.py @@ -1,7 +1,7 @@ import logging from typing import Any, Dict, List, Optional, Type, Union -from ray.rllib.evaluation.postprocessing import compute_advantages, Postprocessing +from ray.rllib.evaluation.postprocessing import Postprocessing, compute_advantages from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_action_dist import TFActionDistribution @@ -14,7 +14,7 @@ compute_gradients, ) from ray.rllib.utils.annotations import override -from ray.rllib.utils.framework import try_import_tf, get_variable +from ray.rllib.utils.framework import get_variable, try_import_tf from ray.rllib.utils.tf_utils import explained_variance from ray.rllib.utils.typing import ( LocalOptimizer, diff --git a/rllib/algorithms/marwil/tests/test_marwil.py b/rllib/algorithms/marwil/tests/test_marwil.py index 3bfc8ff30231..5f39cf9752c0 100644 --- a/rllib/algorithms/marwil/tests/test_marwil.py +++ b/rllib/algorithms/marwil/tests/test_marwil.py @@ -1,11 +1,12 @@ +import unittest +from pathlib import Path + import gymnasium as gym import numpy as np -from pathlib import Path -import unittest import ray import ray.rllib.algorithms.marwil as marwil -from ray.rllib.core import DEFAULT_MODULE_ID, COMPONENT_RL_MODULE +from ray.rllib.core import COMPONENT_RL_MODULE, DEFAULT_MODULE_ID from ray.rllib.core.columns import Columns from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY from ray.rllib.env import INPUT_ENV_SPACES @@ -232,7 +233,8 @@ def possibly_masked_mean(data_): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/marwil/tests/test_marwil_rl_module.py b/rllib/algorithms/marwil/tests/test_marwil_rl_module.py index 683180d0609a..8ea50e5be7f3 100644 --- a/rllib/algorithms/marwil/tests/test_marwil_rl_module.py +++ b/rllib/algorithms/marwil/tests/test_marwil_rl_module.py @@ -1,9 +1,9 @@ import itertools import unittest -import ray - from pathlib import Path +import ray + class TestMARWIL(unittest.TestCase): @classmethod @@ -31,7 +31,8 @@ def test_rollouts(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/mock.py b/rllib/algorithms/mock.py index 25707cf1677b..ba2ac262af21 100644 --- a/rllib/algorithms/mock.py +++ b/rllib/algorithms/mock.py @@ -4,9 +4,9 @@ import numpy as np -from ray.tune import result as tune_result from ray.rllib.algorithms.algorithm import Algorithm, AlgorithmConfig from ray.rllib.utils.annotations import override +from ray.tune import result as tune_result class _MockTrainer(Algorithm): diff --git a/rllib/algorithms/ppo/__init__.py b/rllib/algorithms/ppo/__init__.py index 
a02982e64a53..9ed907f5dd1e 100644 --- a/rllib/algorithms/ppo/__init__.py +++ b/rllib/algorithms/ppo/__init__.py @@ -1,4 +1,4 @@ -from ray.rllib.algorithms.ppo.ppo import PPOConfig, PPO +from ray.rllib.algorithms.ppo.ppo import PPO, PPOConfig from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy, PPOTF2Policy from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy diff --git a/rllib/algorithms/ppo/default_ppo_rl_module.py b/rllib/algorithms/ppo/default_ppo_rl_module.py index 1216eeef0d75..5ac176452f36 100644 --- a/rllib/algorithms/ppo/default_ppo_rl_module.py +++ b/rllib/algorithms/ppo/default_ppo_rl_module.py @@ -5,8 +5,8 @@ from ray.rllib.core.rl_module.apis import InferenceOnlyAPI, ValueFunctionAPI from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic_CallToSuperRecommended, + override, ) from ray.util.annotations import DeveloperAPI diff --git a/rllib/algorithms/ppo/ppo.py b/rllib/algorithms/ppo/ppo.py index 7ffa74477928..439243b86dda 100644 --- a/rllib/algorithms/ppo/ppo.py +++ b/rllib/algorithms/ppo/ppo.py @@ -10,8 +10,9 @@ """ import logging -from typing import Any, Dict, List, Optional, Type, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union +from ray._common.deprecation import DEPRECATED_VALUE from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided from ray.rllib.core.rl_module.rl_module import RLModuleSpec @@ -20,13 +21,13 @@ synchronous_parallel_sample, ) from ray.rllib.execution.train_ops import ( - train_one_step, multi_gpu_train_one_step, + train_one_step, ) from ray.rllib.policy.policy import Policy from ray.rllib.utils.annotations import OldAPIStack, override -from ray._common.deprecation import DEPRECATED_VALUE from ray.rllib.utils.metrics import ( + ALL_MODULES, ENV_RUNNER_RESULTS, ENV_RUNNER_SAMPLING_TIMER, LEARNER_RESULTS, @@ -34,10 +35,9 @@ NUM_AGENT_STEPS_SAMPLED, NUM_ENV_STEPS_SAMPLED, NUM_ENV_STEPS_SAMPLED_LIFETIME, - SYNCH_WORKER_WEIGHTS_TIMER, SAMPLE_TIMER, + SYNCH_WORKER_WEIGHTS_TIMER, TIMERS, - ALL_MODULES, ) from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.schedules.scheduler import Scheduler diff --git a/rllib/algorithms/ppo/ppo_catalog.py b/rllib/algorithms/ppo/ppo_catalog.py index fb11efea17ba..e88b761427a2 100644 --- a/rllib/algorithms/ppo/ppo_catalog.py +++ b/rllib/algorithms/ppo/ppo_catalog.py @@ -1,13 +1,13 @@ # __sphinx_doc_begin__ import gymnasium as gym +from ray.rllib.core.models.base import ActorCriticEncoder, Encoder, Model from ray.rllib.core.models.catalog import Catalog from ray.rllib.core.models.configs import ( ActorCriticEncoderConfig, - MLPHeadConfig, FreeLogStdMLPHeadConfig, + MLPHeadConfig, ) -from ray.rllib.core.models.base import Encoder, ActorCriticEncoder, Model from ray.rllib.utils import override from ray.rllib.utils.annotations import OverrideToImplementCustomLogic diff --git a/rllib/algorithms/ppo/ppo_learner.py b/rllib/algorithms/ppo/ppo_learner.py index b6d3953a8a45..ef16f71c98bb 100644 --- a/rllib/algorithms/ppo/ppo_learner.py +++ b/rllib/algorithms/ppo/ppo_learner.py @@ -13,8 +13,8 @@ from ray.rllib.core.learner.learner import Learner from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic_CallToSuperRecommended, + override, ) from 
ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict from ray.rllib.utils.metrics import ( diff --git a/rllib/algorithms/ppo/tests/test_ppo.py b/rllib/algorithms/ppo/tests/test_ppo.py index a0544a566944..6531d2d8f5cf 100644 --- a/rllib/algorithms/ppo/tests/test_ppo.py +++ b/rllib/algorithms/ppo/tests/test_ppo.py @@ -161,7 +161,8 @@ def get_value(log_std_var=log_std_var): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/ppo/tests/test_ppo_learner.py b/rllib/algorithms/ppo/tests/test_ppo_learner.py index 1d5f83639bb9..825b1411b948 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_learner.py +++ b/rllib/algorithms/ppo/tests/test_ppo_learner.py @@ -1,5 +1,5 @@ -import unittest import tempfile +import unittest import gymnasium as gym import numpy as np @@ -13,7 +13,6 @@ from ray.rllib.utils.test_utils import check from ray.tune.registry import register_env - # Fake CartPole episode of n time steps. FAKE_BATCH = { Columns.OBS: np.array( @@ -136,7 +135,8 @@ def test_kl_coeff_changes(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index c7d786d3d1a5..a8d5999a586d 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -17,7 +17,6 @@ from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.torch_utils import convert_to_torch_tensor - torch, nn = try_import_torch() @@ -186,7 +185,8 @@ def test_forward_train(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/ppo/torch/ppo_torch_learner.py b/rllib/algorithms/ppo/torch/ppo_torch_learner.py index 190ecbf106c1..4e7a806f98ab 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_learner.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_learner.py @@ -4,15 +4,15 @@ import numpy as np from ray.rllib.algorithms.ppo.ppo import ( - LEARNER_RESULTS_KL_KEY, LEARNER_RESULTS_CURR_KL_COEFF_KEY, + LEARNER_RESULTS_KL_KEY, LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY, LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY, PPOConfig, ) from ray.rllib.algorithms.ppo.ppo_learner import PPOLearner from ray.rllib.core.columns import Columns -from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY +from ray.rllib.core.learner.learner import ENTROPY_KEY, POLICY_LOSS_KEY, VF_LOSS_KEY from ray.rllib.core.learner.torch.torch_learner import TorchLearner from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.utils.annotations import override diff --git a/rllib/algorithms/sac/default_sac_rl_module.py b/rllib/algorithms/sac/default_sac_rl_module.py index 3d01e5ed5ccc..76f02a1e4c7f 100644 --- a/rllib/algorithms/sac/default_sac_rl_module.py +++ b/rllib/algorithms/sac/default_sac_rl_module.py @@ -6,8 +6,8 @@ from ray.rllib.core.rl_module.apis import InferenceOnlyAPI, QNetAPI, TargetNetworkAPI from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic, + override, ) from ray.rllib.utils.typing import NetworkType from ray.util.annotations import DeveloperAPI diff --git a/rllib/algorithms/sac/sac.py b/rllib/algorithms/sac/sac.py index 6a0c2375153a..75751a360bad 100644 --- a/rllib/algorithms/sac/sac.py +++ b/rllib/algorithms/sac/sac.py @@ 
-1,6 +1,7 @@ import logging from typing import Any, Dict, Optional, Tuple, Type, Union +from ray._common.deprecation import DEPRECATED_VALUE, deprecation_warning from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided from ray.rllib.algorithms.dqn.dqn import DQN from ray.rllib.algorithms.sac.sac_tf_policy import SACTFPolicy @@ -15,7 +16,6 @@ from ray.rllib.policy.policy import Policy from ray.rllib.utils import deep_update from ray.rllib.utils.annotations import override -from ray._common.deprecation import DEPRECATED_VALUE, deprecation_warning from ray.rllib.utils.framework import try_import_tf, try_import_tfp from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer from ray.rllib.utils.typing import LearningRateOrSchedule, RLModuleSpecType diff --git a/rllib/algorithms/sac/sac_catalog.py b/rllib/algorithms/sac/sac_catalog.py index c60bfd77992f..2ebe4470af18 100644 --- a/rllib/algorithms/sac/sac_catalog.py +++ b/rllib/algorithms/sac/sac_catalog.py @@ -1,23 +1,24 @@ +from typing import Callable + import gymnasium as gym import numpy as np -from typing import Callable # TODO (simon): Store this function somewhere more central as many # algorithms will use it. from ray.rllib.algorithms.ppo.ppo_catalog import _check_if_diag_gaussian +from ray.rllib.core.distribution.distribution import Distribution +from ray.rllib.core.distribution.torch.torch_distribution import ( + TorchCategorical, + TorchSquashedGaussian, +) +from ray.rllib.core.models.base import Encoder, Model from ray.rllib.core.models.catalog import Catalog from ray.rllib.core.models.configs import ( FreeLogStdMLPHeadConfig, MLPEncoderConfig, MLPHeadConfig, ) -from ray.rllib.core.models.base import Encoder, Model -from ray.rllib.core.distribution.torch.torch_distribution import ( - TorchSquashedGaussian, - TorchCategorical, -) -from ray.rllib.utils.annotations import override, OverrideToImplementCustomLogic -from ray.rllib.core.distribution.distribution import Distribution +from ray.rllib.utils.annotations import OverrideToImplementCustomLogic, override # TODO (simon): Check, if we can directly derive from DQNCatalog. 
diff --git a/rllib/algorithms/sac/sac_learner.py b/rllib/algorithms/sac/sac_learner.py index f4108943ad04..abbf082b1ca1 100644 --- a/rllib/algorithms/sac/sac_learner.py +++ b/rllib/algorithms/sac/sac_learner.py @@ -1,7 +1,7 @@ -import numpy as np - from typing import Dict +import numpy as np + from ray.rllib.algorithms.dqn.dqn_learner import DQNLearner from ray.rllib.core.learner.learner import Learner from ray.rllib.utils.annotations import override diff --git a/rllib/algorithms/sac/sac_tf_model.py b/rllib/algorithms/sac/sac_tf_model.py index 7302a25fcccf..e3b3479ff684 100644 --- a/rllib/algorithms/sac/sac_tf_model.py +++ b/rllib/algorithms/sac/sac_tf_model.py @@ -1,15 +1,16 @@ +from typing import Dict, List, Optional + import gymnasium as gym -from gymnasium.spaces import Box, Discrete import numpy as np import tree # pip install dm_tree -from typing import Dict, List, Optional +from gymnasium.spaces import Box, Discrete from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.spaces.simplex import Simplex -from ray.rllib.utils.typing import ModelConfigDict, TensorType, TensorStructType +from ray.rllib.utils.typing import ModelConfigDict, TensorStructType, TensorType tf1, tf, tfv = try_import_tf() diff --git a/rllib/algorithms/sac/sac_tf_policy.py b/rllib/algorithms/sac/sac_tf_policy.py index 2ce3184c70d9..e4322518e46a 100644 --- a/rllib/algorithms/sac/sac_tf_policy.py +++ b/rllib/algorithms/sac/sac_tf_policy.py @@ -3,16 +3,17 @@ """ import copy -import gymnasium as gym -from gymnasium.spaces import Box, Discrete -from functools import partial import logging +from functools import partial from typing import Dict, List, Optional, Tuple, Type, Union +import gymnasium as gym +from gymnasium.spaces import Box, Discrete + import ray from ray.rllib.algorithms.dqn.dqn_tf_policy import ( - postprocess_nstep_and_prio, PRIO_WEIGHTS, + postprocess_nstep_and_prio, ) from ray.rllib.algorithms.sac.sac_tf_model import SACTFModel from ray.rllib.algorithms.sac.sac_torch_model import SACTorchModel @@ -36,10 +37,10 @@ from ray.rllib.utils.tf_utils import huber_loss, make_tf_callable from ray.rllib.utils.typing import ( AgentID, + AlgorithmConfigDict, LocalOptimizer, ModelGradients, TensorType, - AlgorithmConfigDict, ) tf1, tf, tfv = try_import_tf() diff --git a/rllib/algorithms/sac/sac_torch_model.py b/rllib/algorithms/sac/sac_torch_model.py index 00219fd95b8a..8c2fcd5b530c 100644 --- a/rllib/algorithms/sac/sac_torch_model.py +++ b/rllib/algorithms/sac/sac_torch_model.py @@ -1,15 +1,16 @@ +from typing import Dict, List, Optional + import gymnasium as gym -from gymnasium.spaces import Box, Discrete import numpy as np import tree # pip install dm_tree -from typing import Dict, List, Optional +from gymnasium.spaces import Box, Discrete from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.spaces.simplex import Simplex -from ray.rllib.utils.typing import ModelConfigDict, TensorType, TensorStructType +from ray.rllib.utils.typing import ModelConfigDict, TensorStructType, TensorType torch, nn = try_import_torch() diff --git a/rllib/algorithms/sac/sac_torch_policy.py b/rllib/algorithms/sac/sac_torch_policy.py index cef30f465f5d..b105b856ed0b 100644 --- 
a/rllib/algorithms/sac/sac_torch_policy.py +++ b/rllib/algorithms/sac/sac_torch_policy.py @@ -2,45 +2,46 @@ PyTorch policy class used for SAC. """ -import gymnasium as gym -from gymnasium.spaces import Box, Discrete import logging -import tree # pip install dm_tree from typing import Dict, List, Optional, Tuple, Type, Union +import gymnasium as gym +import tree # pip install dm_tree +from gymnasium.spaces import Box, Discrete + import ray +from ray.rllib.algorithms.dqn.dqn_tf_policy import PRIO_WEIGHTS from ray.rllib.algorithms.sac.sac_tf_policy import ( build_sac_model, postprocess_trajectory, validate_spaces, ) -from ray.rllib.algorithms.dqn.dqn_tf_policy import PRIO_WEIGHTS from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import ( + TorchBeta, TorchCategorical, - TorchDistributionWrapper, + TorchDiagGaussian, TorchDirichlet, + TorchDistributionWrapper, TorchSquashedGaussian, - TorchDiagGaussian, - TorchBeta, ) from ray.rllib.policy.policy import Policy from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.torch_mixins import TargetNetworkMixin from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.spaces.simplex import Simplex -from ray.rllib.policy.torch_mixins import TargetNetworkMixin from ray.rllib.utils.torch_utils import ( apply_grad_clipping, concat_multi_gpu_td_errors, huber_loss, ) from ray.rllib.utils.typing import ( + AlgorithmConfigDict, LocalOptimizer, ModelInputDict, TensorType, - AlgorithmConfigDict, ) torch, nn = try_import_torch() diff --git a/rllib/algorithms/sac/tests/test_sac.py b/rllib/algorithms/sac/tests/test_sac.py index be49d2fd5f81..4d03ba92db63 100644 --- a/rllib/algorithms/sac/tests/test_sac.py +++ b/rllib/algorithms/sac/tests/test_sac.py @@ -1,16 +1,17 @@ +import unittest + import gymnasium as gym -from gymnasium.spaces import Box, Dict, Discrete, Tuple import numpy as np -import unittest +from gymnasium.spaces import Box, Dict, Discrete, Tuple import ray +from ray import tune from ray.rllib.algorithms import sac from ray.rllib.connectors.env_to_module.flatten_observations import FlattenObservations from ray.rllib.examples.envs.classes.random_env import RandomEnv from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.spaces.simplex import Simplex from ray.rllib.utils.test_utils import check_train_results_new_api_stack -from ray import tune tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -179,7 +180,8 @@ def step(self, action): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/sac/torch/default_sac_torch_rl_module.py b/rllib/algorithms/sac/torch/default_sac_torch_rl_module.py index 0612dce7c391..09a0f6091ab1 100644 --- a/rllib/algorithms/sac/torch/default_sac_torch_rl_module.py +++ b/rllib/algorithms/sac/torch/default_sac_torch_rl_module.py @@ -1,17 +1,18 @@ -import gymnasium as gym from typing import Any, Dict +import gymnasium as gym + from ray.rllib.algorithms.sac.default_sac_rl_module import DefaultSACRLModule from ray.rllib.algorithms.sac.sac_catalog import SACCatalog from ray.rllib.algorithms.sac.sac_learner import ( ACTION_DIST_INPUTS_NEXT, - QF_PREDS, - QF_TWIN_PREDS, - QF_TARGET_NEXT, + ACTION_LOG_PROBS, ACTION_LOG_PROBS_NEXT, - ACTION_PROBS_NEXT, ACTION_PROBS, - ACTION_LOG_PROBS, + ACTION_PROBS_NEXT, + QF_PREDS, 
+ QF_TARGET_NEXT, + QF_TWIN_PREDS, ) from ray.rllib.core.columns import Columns from ray.rllib.core.models.base import ENCODER_OUT, Encoder, Model diff --git a/rllib/algorithms/sac/torch/sac_torch_learner.py b/rllib/algorithms/sac/torch/sac_torch_learner.py index 8d96f22f4730..478970795d85 100644 --- a/rllib/algorithms/sac/torch/sac_torch_learner.py +++ b/rllib/algorithms/sac/torch/sac_torch_learner.py @@ -1,24 +1,25 @@ -import gymnasium as gym from typing import Any, Dict +import gymnasium as gym + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.algorithms.dqn.torch.dqn_torch_learner import DQNTorchLearner from ray.rllib.algorithms.sac.sac import SACConfig from ray.rllib.algorithms.sac.sac_learner import ( + ACTION_LOG_PROBS, + ACTION_LOG_PROBS_NEXT, + ACTION_PROBS, + ACTION_PROBS_NEXT, LOGPS_KEY, QF_LOSS_KEY, - QF_MEAN_KEY, QF_MAX_KEY, + QF_MEAN_KEY, QF_MIN_KEY, QF_PREDS, + QF_TARGET_NEXT, QF_TWIN_LOSS_KEY, QF_TWIN_PREDS, TD_ERROR_MEAN_KEY, - ACTION_LOG_PROBS, - ACTION_LOG_PROBS_NEXT, - ACTION_PROBS, - ACTION_PROBS_NEXT, - QF_TARGET_NEXT, SACLearner, ) from ray.rllib.core.columns import Columns diff --git a/rllib/algorithms/tests/test_algorithm.py b/rllib/algorithms/tests/test_algorithm.py index 39f583f5f722..e31ff3999271 100644 --- a/rllib/algorithms/tests/test_algorithm.py +++ b/rllib/algorithms/tests/test_algorithm.py @@ -1,15 +1,16 @@ -import gymnasium as gym -import numpy as np import os +import unittest from pathlib import Path from random import choice -import unittest + +import gymnasium as gym +import numpy as np import ray -from ray.rllib.algorithms.algorithm import Algorithm import ray.rllib.algorithms.dqn as dqn -from ray.rllib.algorithms.bc import BCConfig import ray.rllib.algorithms.ppo as ppo +from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.algorithms.bc import BCConfig from ray.rllib.core.columns import Columns from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig from ray.rllib.core.rl_module.rl_module import RLModuleSpec @@ -615,7 +616,8 @@ def _assert_modules_added( if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/tests/test_algorithm_config.py b/rllib/algorithms/tests/test_algorithm_config.py index afe48b1117c5..6b6c381f6e56 100644 --- a/rllib/algorithms/tests/test_algorithm_config.py +++ b/rllib/algorithms/tests/test_algorithm_config.py @@ -1,17 +1,18 @@ -import gymnasium as gym -from typing import Type import unittest +from typing import Type + +import gymnasium as gym import ray from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.algorithms.ppo import PPO, PPOConfig from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule -from ray.rllib.core.rl_module.rl_module import RLModuleSpec, RLModule from ray.rllib.core.rl_module.multi_rl_module import ( MultiRLModule, MultiRLModuleSpec, ) +from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec from ray.rllib.utils.test_utils import check @@ -432,7 +433,8 @@ def get_default_rl_module_spec(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/tests/test_algorithm_export_checkpoint.py b/rllib/algorithms/tests/test_algorithm_export_checkpoint.py index 116b68399aee..e978dc961b55 100644 --- 
a/rllib/algorithms/tests/test_algorithm_export_checkpoint.py +++ b/rllib/algorithms/tests/test_algorithm_export_checkpoint.py @@ -1,8 +1,9 @@ -import numpy as np import os import shutil import unittest +import numpy as np + import ray import ray._common from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole @@ -96,7 +97,8 @@ def test_save_appo_multi_agent(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/tests/test_algorithm_imports.py b/rllib/algorithms/tests/test_algorithm_imports.py index f528f082e19c..352dd41d9880 100644 --- a/rllib/algorithms/tests/test_algorithm_imports.py +++ b/rllib/algorithms/tests/test_algorithm_imports.py @@ -17,7 +17,8 @@ def test_algo_import(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/tests/test_algorithm_rl_module_restore.py b/rllib/algorithms/tests/test_algorithm_rl_module_restore.py index bc13d04567f5..7cbb1ec39269 100644 --- a/rllib/algorithms/tests/test_algorithm_rl_module_restore.py +++ b/rllib/algorithms/tests/test_algorithm_rl_module_restore.py @@ -1,25 +1,25 @@ -import gymnasium as gym -import numpy as np import shutil import tempfile -import tree import unittest +import gymnasium as gym +import numpy as np +import tree + import ray from ray.rllib.algorithms.ppo import PPOConfig from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule from ray.rllib.core import DEFAULT_MODULE_ID from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig -from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.core.rl_module.multi_rl_module import ( - MultiRLModuleSpec, MultiRLModule, + MultiRLModuleSpec, ) +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole -from ray.rllib.utils.test_utils import check from ray.rllib.utils.numpy import convert_to_numpy - +from ray.rllib.utils.test_utils import check NUM_AGENTS = 2 @@ -329,7 +329,8 @@ def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/tests/test_algorithm_save_load_checkpoint_connectors.py b/rllib/algorithms/tests/test_algorithm_save_load_checkpoint_connectors.py index ec58a5376faa..3ede13e17215 100644 --- a/rllib/algorithms/tests/test_algorithm_save_load_checkpoint_connectors.py +++ b/rllib/algorithms/tests/test_algorithm_save_load_checkpoint_connectors.py @@ -1,7 +1,6 @@ import tempfile import unittest - import ray from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.algorithms.ppo import PPOConfig @@ -10,7 +9,6 @@ from ray.rllib.utils.filter import RunningStat from ray.rllib.utils.test_utils import check - algorithms_and_configs = { "PPO": (PPOConfig().training(train_batch_size=2, minibatch_size=2)) } @@ -228,6 +226,7 @@ def _assert_running_stats_consistency( if __name__ == "__main__": import sys + import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/tests/test_algorithm_save_load_checkpoint_learner.py b/rllib/algorithms/tests/test_algorithm_save_load_checkpoint_learner.py index 78840a3fe4be..1d0134cc2769 100644 --- 
a/rllib/algorithms/tests/test_algorithm_save_load_checkpoint_learner.py +++ b/rllib/algorithms/tests/test_algorithm_save_load_checkpoint_learner.py @@ -7,7 +7,6 @@ from ray.rllib.core import DEFAULT_MODULE_ID from ray.rllib.utils.metrics import LEARNER_RESULTS - algorithms_and_configs = { "PPO": (PPOConfig().training(train_batch_size=2, minibatch_size=2)) } @@ -126,6 +125,7 @@ def test_save_and_restore(self): if __name__ == "__main__": import sys + import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/tests/test_env_runner_failures.py b/rllib/algorithms/tests/test_env_runner_failures.py index 72808d517522..abc3c7d1e3f9 100644 --- a/rllib/algorithms/tests/test_env_runner_failures.py +++ b/rllib/algorithms/tests/test_env_runner_failures.py @@ -1,14 +1,15 @@ +import time +import unittest from collections import defaultdict + import gymnasium as gym import numpy as np -import time -import unittest import ray from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.algorithms.impala import IMPALAConfig -from ray.rllib.algorithms.sac.sac import SACConfig from ray.rllib.algorithms.ppo import PPOConfig +from ray.rllib.algorithms.sac.sac import SACConfig from ray.rllib.connectors.env_to_module.flatten_observations import FlattenObservations from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig from ray.rllib.env.multi_agent_env import make_multi_agent diff --git a/rllib/algorithms/tests/test_node_failures.py b/rllib/algorithms/tests/test_node_failures.py index 34e6560a8aae..7e1350024740 100644 --- a/rllib/algorithms/tests/test_node_failures.py +++ b/rllib/algorithms/tests/test_node_failures.py @@ -14,7 +14,6 @@ MODULE_TRAIN_BATCH_SIZE_MEAN, ) - object_store_memory = 10**8 num_nodes = 3 @@ -193,7 +192,8 @@ def _train(self, *, config, iters, min_reward, preempt_freq): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/tests/test_registry.py b/rllib/algorithms/tests/test_registry.py index 85e8029691ba..534f79327ede 100644 --- a/rllib/algorithms/tests/test_registry.py +++ b/rllib/algorithms/tests/test_registry.py @@ -1,11 +1,11 @@ import unittest from ray.rllib.algorithms.registry import ( + ALGORITHMS, + ALGORITHMS_CLASS_TO_NAME, POLICIES, get_policy_class, get_policy_class_name, - ALGORITHMS_CLASS_TO_NAME, - ALGORITHMS, ) @@ -31,7 +31,8 @@ def test_registered_algorithm_names(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) From 32f085748eecb93b4192aeaffe1160154e48b30e Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Fri, 19 Sep 2025 14:48:47 +0530 Subject: [PATCH 2/9] LINT: Enable ruff imports for rllib/core Signed-off-by: Gagandeep Singh --- pyproject.toml | 17 +++++++++++++- rllib/core/__init__.py | 1 - rllib/core/distribution/distribution.py | 6 ++--- .../distribution/torch/torch_distribution.py | 10 ++++---- rllib/core/learner/__init__.py | 1 - rllib/core/learner/differentiable_learner.py | 9 ++++---- .../learner/differentiable_learner_config.py | 6 ++--- rllib/core/learner/learner.py | 23 ++++++++++--------- rllib/core/learner/learner_group.py | 6 ++--- rllib/core/learner/tests/test_learner.py | 13 ++++++----- .../core/learner/tests/test_learner_group.py | 10 ++++---- .../torch/tests/test_torch_learner_compile.py | 3 ++- .../torch/torch_differentiable_learner.py | 1 - rllib/core/learner/torch/torch_learner.py | 10 ++++---- 
.../core/learner/torch/torch_meta_learner.py | 5 ++-- rllib/core/learner/training_data.py | 2 +- rllib/core/learner/utils.py | 1 - rllib/core/models/base.py | 1 - rllib/core/models/catalog.py | 15 ++++++------ rllib/core/models/configs.py | 6 ++--- rllib/core/models/tests/test_base_models.py | 9 ++++---- rllib/core/models/tests/test_catalog.py | 23 ++++++++++--------- rllib/core/models/tests/test_cnn_encoders.py | 3 ++- .../models/tests/test_cnn_transpose_heads.py | 3 ++- rllib/core/models/tests/test_mlp_encoders.py | 5 ++-- rllib/core/models/tests/test_mlp_heads.py | 5 ++-- .../models/tests/test_recurrent_encoders.py | 3 ++- rllib/core/models/torch/encoder.py | 9 ++++---- rllib/core/models/torch/primitives.py | 2 +- rllib/core/rl_module/__init__.py | 2 +- rllib/core/rl_module/apis/__init__.py | 3 +-- .../apis/self_supervised_loss_api.py | 2 +- .../core/rl_module/apis/target_network_api.py | 1 - rllib/core/rl_module/multi_rl_module.py | 14 +++++------ rllib/core/rl_module/rl_module.py | 22 +++++++++--------- .../rl_module/tests/test_multi_rl_module.py | 7 +++--- .../rl_module/tests/test_rl_module_specs.py | 9 ++++---- .../torch/tests/test_torch_rl_module.py | 5 ++-- rllib/core/rl_module/torch/torch_rl_module.py | 12 +++++----- rllib/core/testing/bc_algorithm.py | 6 ++--- rllib/core/testing/torch/bc_learner.py | 3 ++- rllib/core/testing/torch/bc_module.py | 2 +- 42 files changed, 159 insertions(+), 137 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c8c2219e451f..71247f839036 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,22 @@ afterray = ["psutil", "setproctitle"] "python/ray/__init__.py" = ["I"] "python/ray/dag/__init__.py" = ["I"] "python/ray/air/__init__.py" = ["I"] -"rllib/*" = ["I"] +"rllib/__init__.py" = ["I"] +"rllib/benchmarks/*" = ["I"] +"rllib/connectors/*" = ["I"] +"rllib/evaluation/*" = ["I"] +"rllib/models/*" = ["I"] +"rllib/utils/*" = ["I"] +"rllib/algorithms/*" = ["I"] +# "rllib/core/*" = ["I"] +"rllib/examples/*" = ["I"] +"rllib/offline/*" = ["I"] +"rllib/tests/*" = ["I"] +"rllib/callbacks/*" = ["I"] +"rllib/env/*" = ["I"] +"rllib/execution/*" = ["I"] +"rllib/policy/*" = ["I"] +"rllib/tuned_examples/*" = ["I"] "release/*" = ["I"] # TODO(matthewdeng): Remove this line diff --git a/rllib/core/__init__.py b/rllib/core/__init__.py index bff33528c9af..42404d51e5d3 100644 --- a/rllib/core/__init__.py +++ b/rllib/core/__init__.py @@ -1,6 +1,5 @@ from ray.rllib.core.columns import Columns - DEFAULT_AGENT_ID = "default_agent" DEFAULT_POLICY_ID = "default_policy" # TODO (sven): Change this to "default_module" diff --git a/rllib/core/distribution/distribution.py b/rllib/core/distribution/distribution.py index f9dbf137917c..a5812058713e 100644 --- a/rllib/core/distribution/distribution.py +++ b/rllib/core/distribution/distribution.py @@ -1,11 +1,11 @@ """This is the next version of action distribution base class.""" +import abc from typing import Tuple + import gymnasium as gym -import abc -from ray.rllib.utils.annotations import ExperimentalAPI +from ray.rllib.utils.annotations import ExperimentalAPI, override from ray.rllib.utils.typing import TensorType, Union -from ray.rllib.utils.annotations import override @ExperimentalAPI diff --git a/rllib/core/distribution/torch/torch_distribution.py b/rllib/core/distribution/torch/torch_distribution.py index e234f6f10c08..d0b94828a9c0 100644 --- a/rllib/core/distribution/torch/torch_distribution.py +++ b/rllib/core/distribution/torch/torch_distribution.py @@ -3,18 +3,18 @@ the code. 
This matches the design pattern of torch distribution which developers may already be familiar with. """ +import abc +from typing import Dict, Iterable, List, Optional + import gymnasium as gym import numpy as np -from typing import Dict, Iterable, List, Optional import tree -import abc - from ray.rllib.core.distribution.distribution import Distribution -from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.utils.annotations import DeveloperAPI, override from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.numpy import MAX_LOG_NN_OUTPUT, MIN_LOG_NN_OUTPUT, SMALL_NUMBER -from ray.rllib.utils.typing import TensorType, Union, Tuple +from ray.rllib.utils.typing import TensorType, Tuple, Union torch, nn = try_import_torch() diff --git a/rllib/core/learner/__init__.py b/rllib/core/learner/__init__.py index 1265532aa05f..8fc450012cde 100644 --- a/rllib/core/learner/__init__.py +++ b/rllib/core/learner/__init__.py @@ -1,7 +1,6 @@ from ray.rllib.core.learner.learner import Learner from ray.rllib.core.learner.learner_group import LearnerGroup - __all__ = [ "Learner", "LearnerGroup", diff --git a/rllib/core/learner/differentiable_learner.py b/rllib/core/learner/differentiable_learner.py index 8d118aacfae8..2d019aaba255 100644 --- a/rllib/core/learner/differentiable_learner.py +++ b/rllib/core/learner/differentiable_learner.py @@ -1,17 +1,18 @@ import abc import logging -import numpy from typing import ( + TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Tuple, - TYPE_CHECKING, Union, ) +import numpy + from ray.rllib.connectors.learner.learner_connector_pipeline import ( LearnerConnectorPipeline, ) @@ -22,19 +23,19 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch from ray.rllib.utils import unflatten_dict from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic, OverrideToImplementCustomLogic_CallToSuperRecommended, + override, ) from ray.rllib.utils.checkpoints import Checkpointable from ray.rllib.utils.metrics import ( DATASET_NUM_ITERS_TRAINED, DATASET_NUM_ITERS_TRAINED_LIFETIME, + MODULE_TRAIN_BATCH_SIZE_MEAN, NUM_ENV_STEPS_TRAINED, NUM_ENV_STEPS_TRAINED_LIFETIME, NUM_MODULE_STEPS_TRAINED, NUM_MODULE_STEPS_TRAINED_LIFETIME, - MODULE_TRAIN_BATCH_SIZE_MEAN, WEIGHTS_SEQ_NO, ) from ray.rllib.utils.metrics.metrics_logger import MetricsLogger diff --git a/rllib/core/learner/differentiable_learner_config.py b/rllib/core/learner/differentiable_learner_config.py index d8b5a134d4aa..a9629cefe715 100644 --- a/rllib/core/learner/differentiable_learner_config.py +++ b/rllib/core/learner/differentiable_learner_config.py @@ -1,12 +1,12 @@ -import gymnasium as gym from dataclasses import dataclass, fields - from typing import Callable, List, Optional, Union +import gymnasium as gym + from ray.rllib.connectors.connector_v2 import ConnectorV2 from ray.rllib.core.learner.differentiable_learner import DifferentiableLearner -from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.utils.typing import DeviceType, ModuleID diff --git a/rllib/core/learner/learner.py b/rllib/core/learner/learner.py index e2c6578c9173..4cecde67838a 100644 --- a/rllib/core/learner/learner.py +++ b/rllib/core/learner/learner.py @@ -1,26 +1,28 @@ import abc -from collections import defaultdict import copy import logging -import numpy import platform -import tree +from collections import 
defaultdict from typing import ( + TYPE_CHECKING, Any, Callable, Collection, Dict, - List, Hashable, Iterable, + List, Optional, Sequence, Tuple, - TYPE_CHECKING, Union, ) +import numpy +import tree + import ray +from ray._common.deprecation import Deprecated from ray.rllib.connectors.learner.learner_connector_pipeline import ( LearnerConnectorPipeline, ) @@ -31,8 +33,8 @@ DEFAULT_MODULE_ID, ) from ray.rllib.core.learner.training_data import TrainingData -from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI from ray.rllib.core.rl_module import validate_module_id +from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI from ray.rllib.core.rl_module.multi_rl_module import ( MultiRLModule, MultiRLModuleSpec, @@ -41,30 +43,29 @@ from ray.rllib.policy.policy import PolicySpec from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic, OverrideToImplementCustomLogic_CallToSuperRecommended, + override, ) from ray.rllib.utils.checkpoints import Checkpointable from ray.rllib.utils.debug import update_global_seed_if_necessary -from ray._common.deprecation import Deprecated from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.metrics import ( ALL_MODULES, DATASET_NUM_ITERS_TRAINED, DATASET_NUM_ITERS_TRAINED_LIFETIME, + MODULE_TRAIN_BATCH_SIZE_MEAN, NUM_ENV_STEPS_SAMPLED_LIFETIME, NUM_ENV_STEPS_TRAINED, NUM_ENV_STEPS_TRAINED_LIFETIME, NUM_MODULE_STEPS_TRAINED, NUM_MODULE_STEPS_TRAINED_LIFETIME, - MODULE_TRAIN_BATCH_SIZE_MEAN, WEIGHTS_SEQ_NO, ) from ray.rllib.utils.metrics.metrics_logger import MetricsLogger from ray.rllib.utils.minibatch_utils import ( - MiniBatchDummyIterator, MiniBatchCyclicIterator, + MiniBatchDummyIterator, MiniBatchRayDataIterator, ) from ray.rllib.utils.schedules.scheduler import Scheduler @@ -74,8 +75,8 @@ ModuleID, Optimizer, Param, - ParamRef, ParamDict, + ParamRef, ResultDict, ShouldModuleBeUpdatedFn, StateDict, diff --git a/rllib/core/learner/learner_group.py b/rllib/core/learner/learner_group.py index b1010950e746..7513cbea2431 100644 --- a/rllib/core/learner/learner_group.py +++ b/rllib/core/learner/learner_group.py @@ -1,8 +1,9 @@ import copy -from functools import partial import itertools import pathlib +from functools import partial from typing import ( + TYPE_CHECKING, Any, Callable, Collection, @@ -11,11 +12,11 @@ Optional, Set, Type, - TYPE_CHECKING, Union, ) import ray +from ray._common.deprecation import Deprecated from ray.rllib.core import ( COMPONENT_LEARNER, COMPONENT_RL_MODULE, @@ -34,7 +35,6 @@ ) from ray.rllib.utils.annotations import override from ray.rllib.utils.checkpoints import Checkpointable -from ray._common.deprecation import Deprecated from ray.rllib.utils.typing import ( EpisodeType, ModuleID, diff --git a/rllib/core/learner/tests/test_learner.py b/rllib/core/learner/tests/test_learner.py index 5c2db2c0c119..0f3574affa08 100644 --- a/rllib/core/learner/tests/test_learner.py +++ b/rllib/core/learner/tests/test_learner.py @@ -1,17 +1,17 @@ -import gymnasium as gym -import numpy as np import tempfile import unittest +import gymnasium as gym +import numpy as np + import ray from ray.rllib.core import DEFAULT_MODULE_ID from ray.rllib.core.learner.learner import Learner from ray.rllib.core.testing.testing_learner import BaseTestingAlgorithmConfig - -from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.test_utils import check, 
get_cartpole_dataset_reader from ray.rllib.utils.metrics import ALL_MODULES +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.test_utils import check, get_cartpole_dataset_reader torch, _ = try_import_torch() @@ -241,7 +241,8 @@ def _check_learner_states(self, framework, learner1, learner2): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/core/learner/tests/test_learner_group.py b/rllib/core/learner/tests/test_learner_group.py index b79d0453050a..20b6ecd97d32 100644 --- a/rllib/core/learner/tests/test_learner_group.py +++ b/rllib/core/learner/tests/test_learner_group.py @@ -1,24 +1,25 @@ -import gymnasium as gym -import numpy as np import tempfile import unittest + +import gymnasium as gym +import numpy as np import pytest import ray from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.algorithms.bc import BCConfig from ray.rllib.core import ( - Columns, COMPONENT_LEARNER, COMPONENT_RL_MODULE, DEFAULT_MODULE_ID, + Columns, ) from ray.rllib.core.learner.learner import Learner from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule, MultiRLModuleSpec from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.core.testing.testing_learner import BaseTestingAlgorithmConfig from ray.rllib.core.testing.torch.bc_learner import BCTorchLearner from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule -from ray.rllib.core.testing.testing_learner import BaseTestingAlgorithmConfig from ray.rllib.env.multi_agent_episode import MultiAgentEpisode from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole @@ -27,7 +28,6 @@ from ray.rllib.utils.test_utils import check from ray.util.timer import _Timer - REMOTE_CONFIGS = { "remote-cpu": AlgorithmConfig.overrides(num_learners=1), "remote-gpu": AlgorithmConfig.overrides(num_learners=1, num_gpus_per_learner=1), diff --git a/rllib/core/learner/torch/tests/test_torch_learner_compile.py b/rllib/core/learner/torch/tests/test_torch_learner_compile.py index 397fc26dbc10..9a6d775297c7 100644 --- a/rllib/core/learner/torch/tests/test_torch_learner_compile.py +++ b/rllib/core/learner/torch/tests/test_torch_learner_compile.py @@ -118,7 +118,8 @@ def test_torch_compile_no_breaks(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/core/learner/torch/torch_differentiable_learner.py b/rllib/core/learner/torch/torch_differentiable_learner.py index 5ceeaf924cec..baa6a6bc18d3 100644 --- a/rllib/core/learner/torch/torch_differentiable_learner.py +++ b/rllib/core/learner/torch/torch_differentiable_learner.py @@ -1,6 +1,5 @@ import contextlib import logging - from typing import Any, Dict, Optional, Tuple from ray.rllib.algorithms.algorithm_config import ( diff --git a/rllib/core/learner/torch/torch_learner.py b/rllib/core/learner/torch/torch_learner.py index e49ce42d5168..b1ad8fdbf9eb 100644 --- a/rllib/core/learner/torch/torch_learner.py +++ b/rllib/core/learner/torch/torch_learner.py @@ -1,7 +1,8 @@ -from collections import defaultdict import contextlib import logging +from collections import defaultdict from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -9,7 +10,6 @@ Optional, Sequence, Tuple, - TYPE_CHECKING, ) from ray.rllib.algorithms.algorithm_config import ( @@ -17,7 +17,7 @@ TorchCompileWhatToCompile, ) from 
ray.rllib.core.columns import Columns -from ray.rllib.core.learner.learner import Learner, LR_KEY +from ray.rllib.core.learner.learner import LR_KEY, Learner from ray.rllib.core.rl_module.multi_rl_module import ( MultiRLModule, MultiRLModuleSpec, @@ -33,16 +33,16 @@ ) from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic, OverrideToImplementCustomLogic_CallToSuperRecommended, + override, ) from ray.rllib.utils.framework import get_device, try_import_torch from ray.rllib.utils.metrics import ( ALL_MODULES, DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY, - NUM_TRAINABLE_PARAMETERS, NUM_NON_TRAINABLE_PARAMETERS, + NUM_TRAINABLE_PARAMETERS, WEIGHTS_SEQ_NO, ) from ray.rllib.utils.numpy import convert_to_numpy diff --git a/rllib/core/learner/torch/torch_meta_learner.py b/rllib/core/learner/torch/torch_meta_learner.py index 284a37ee98dc..0187942af02c 100644 --- a/rllib/core/learner/torch/torch_meta_learner.py +++ b/rllib/core/learner/torch/torch_meta_learner.py @@ -1,10 +1,9 @@ import contextlib import logging -import ray - from itertools import cycle from typing import Any, Dict, List, Optional, Tuple +import ray from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.core import ALL_MODULES from ray.rllib.core.learner.learner import Learner @@ -16,9 +15,9 @@ from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic, OverrideToImplementCustomLogic_CallToSuperRecommended, + override, ) from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.metrics import ( diff --git a/rllib/core/learner/training_data.py b/rllib/core/learner/training_data.py index 5561d075d73e..dc2bb1bd3271 100644 --- a/rllib/core/learner/training_data.py +++ b/rllib/core/learner/training_data.py @@ -1,5 +1,5 @@ -from collections import defaultdict import dataclasses +from collections import defaultdict from typing import List, Optional import tree # pip install dm_tree diff --git a/rllib/core/learner/utils.py b/rllib/core/learner/utils.py index 7682725cf9a2..a511dd71c337 100644 --- a/rllib/core/learner/utils.py +++ b/rllib/core/learner/utils.py @@ -4,7 +4,6 @@ from ray.rllib.utils.typing import NetworkType from ray.util import PublicAPI - torch, _ = try_import_torch() diff --git a/rllib/core/models/base.py b/rllib/core/models/base.py index d3ad4dd18963..337b35421263 100644 --- a/rllib/core/models/base.py +++ b/rllib/core/models/base.py @@ -1,7 +1,6 @@ import abc from typing import List, Optional, Tuple, Union - from ray.rllib.core.columns import Columns from ray.rllib.core.models.configs import ModelConfig from ray.rllib.core.models.specs.specs_base import Spec diff --git a/rllib/core/models/catalog.py b/rllib/core/models/catalog.py index e4f9abe53b88..ce3b42592c0b 100644 --- a/rllib/core/models/catalog.py +++ b/rllib/core/models/catalog.py @@ -8,26 +8,25 @@ import tree from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple +from ray._common.deprecation import DEPRECATED_VALUE, deprecation_warning +from ray.rllib.core.distribution.distribution import Distribution from ray.rllib.core.models.base import Encoder from ray.rllib.core.models.configs import ( CNNEncoderConfig, MLPEncoderConfig, + ModelConfig, RecurrentEncoderConfig, ) -from ray.rllib.core.models.configs import ModelConfig from ray.rllib.core.rl_module.default_model_config import 
DefaultModelConfig -from ray.rllib.core.distribution.distribution import Distribution -from ray.rllib.models.preprocessors import get_preprocessor, Preprocessor +from ray.rllib.models.preprocessors import Preprocessor, get_preprocessor from ray.rllib.models.utils import get_filter_config -from ray._common.deprecation import deprecation_warning, DEPRECATED_VALUE -from ray.rllib.utils.error import UnsupportedSpaceException -from ray.rllib.utils.spaces.simplex import Simplex -from ray.rllib.utils.spaces.space_utils import flatten_space -from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space from ray.rllib.utils.annotations import ( OverrideToImplementCustomLogic, OverrideToImplementCustomLogic_CallToSuperRecommended, ) +from ray.rllib.utils.error import UnsupportedSpaceException +from ray.rllib.utils.spaces.simplex import Simplex +from ray.rllib.utils.spaces.space_utils import flatten_space, get_base_struct_from_space class Catalog: diff --git a/rllib/core/models/configs.py b/rllib/core/models/configs.py index 00acf8ef4132..db2f7cdbe681 100644 --- a/rllib/core/models/configs.py +++ b/rllib/core/models/configs.py @@ -1,7 +1,7 @@ import abc -from dataclasses import dataclass, field import functools -from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -14,7 +14,7 @@ from ray.rllib.utils.annotations import ExperimentalAPI if TYPE_CHECKING: - from ray.rllib.core.models.base import Model, Encoder + from ray.rllib.core.models.base import Encoder, Model @ExperimentalAPI diff --git a/rllib/core/models/tests/test_base_models.py b/rllib/core/models/tests/test_base_models.py index d10390c8f50d..5490d8b23634 100644 --- a/rllib/core/models/tests/test_base_models.py +++ b/rllib/core/models/tests/test_base_models.py @@ -3,13 +3,13 @@ import gymnasium as gym +from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog +from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule from ray.rllib.core.models.configs import ModelConfig from ray.rllib.core.models.torch.base import TorchModel from ray.rllib.core.rl_module.rl_module import RLModuleSpec -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule -from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog from ray.rllib.core.rl_module.torch.torch_compile_config import TorchCompileConfig +from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_utils import _dynamo_is_available torch, nn = try_import_torch() @@ -102,7 +102,8 @@ def test_torch_compile_forwards(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/core/models/tests/test_catalog.py b/rllib/core/models/tests/test_catalog.py index e3e14f93aea5..d6057d1a55fe 100644 --- a/rllib/core/models/tests/test_catalog.py +++ b/rllib/core/models/tests/test_catalog.py @@ -1,21 +1,27 @@ import dataclasses -from collections import namedtuple import functools import itertools import unittest +from collections import namedtuple import gymnasium as gym -from gymnasium.spaces import Box, Discrete, Dict, Tuple, MultiDiscrete import numpy as np import tree +from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple from ray.rllib.algorithms.ppo.ppo import PPOConfig from ray.rllib.algorithms.ppo.ppo_catalog 
import PPOCatalog from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule from ray.rllib.core.columns import Columns +from ray.rllib.core.distribution.torch.torch_distribution import ( + TorchCategorical, + TorchDiagGaussian, + TorchMultiCategorical, + TorchMultiDistribution, +) from ray.rllib.core.models.base import ( - Encoder, ENCODER_OUT, + Encoder, ) from ray.rllib.core.models.catalog import ( Catalog, @@ -23,20 +29,14 @@ _multi_categorical_dist_partial_helper, ) from ray.rllib.core.models.configs import ( + CNNEncoderConfig, MLPEncoderConfig, ModelConfig, - CNNEncoderConfig, ) from ray.rllib.core.models.torch.base import TorchModel from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.models import MODEL_DEFAULTS -from ray.rllib.core.distribution.torch.torch_distribution import ( - TorchCategorical, - TorchDiagGaussian, - TorchMultiCategorical, - TorchMultiDistribution, -) from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space @@ -397,7 +397,8 @@ def _determine_components(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/core/models/tests/test_cnn_encoders.py b/rllib/core/models/tests/test_cnn_encoders.py index d7b344aba375..f73a7f1b3016 100644 --- a/rllib/core/models/tests/test_cnn_encoders.py +++ b/rllib/core/models/tests/test_cnn_encoders.py @@ -105,7 +105,8 @@ def test_cnn_encoders_valid_padding(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/core/models/tests/test_cnn_transpose_heads.py b/rllib/core/models/tests/test_cnn_transpose_heads.py index 3248ce17b24e..057073403566 100644 --- a/rllib/core/models/tests/test_cnn_transpose_heads.py +++ b/rllib/core/models/tests/test_cnn_transpose_heads.py @@ -100,7 +100,8 @@ def test_cnn_transpose_heads(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/core/models/tests/test_mlp_encoders.py b/rllib/core/models/tests/test_mlp_encoders.py index 96b5fc45dbe3..f69fc85552fa 100644 --- a/rllib/core/models/tests/test_mlp_encoders.py +++ b/rllib/core/models/tests/test_mlp_encoders.py @@ -1,8 +1,8 @@ import itertools import unittest -from ray.rllib.core.models.configs import MLPEncoderConfig from ray.rllib.core.models.base import ENCODER_OUT +from ray.rllib.core.models.configs import MLPEncoderConfig from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.test_utils import ModelChecker @@ -80,7 +80,8 @@ def test_mlp_encoders(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/core/models/tests/test_mlp_heads.py b/rllib/core/models/tests/test_mlp_heads.py index d40f9880a5af..d0af874e46b0 100644 --- a/rllib/core/models/tests/test_mlp_heads.py +++ b/rllib/core/models/tests/test_mlp_heads.py @@ -1,7 +1,7 @@ import itertools import unittest -from ray.rllib.core.models.configs import MLPHeadConfig, FreeLogStdMLPHeadConfig +from ray.rllib.core.models.configs import FreeLogStdMLPHeadConfig, MLPHeadConfig from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.test_utils import ModelChecker @@ -85,7 +85,8 @@ def test_mlp_heads(self): 
diff --git a/rllib/core/models/tests/test_recurrent_encoders.py b/rllib/core/models/tests/test_recurrent_encoders.py
index 3ac411bc0945..0b87e8a3ed27 100644
--- a/rllib/core/models/tests/test_recurrent_encoders.py
+++ b/rllib/core/models/tests/test_recurrent_encoders.py
@@ -147,7 +147,8 @@ def test_lstm_encoders(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/core/models/torch/encoder.py b/rllib/core/models/torch/encoder.py
index 82812e43fc61..ea444f65d829 100644
--- a/rllib/core/models/torch/encoder.py
+++ b/rllib/core/models/torch/encoder.py
@@ -2,12 +2,13 @@
 from ray.rllib.core.columns import Columns
 from ray.rllib.core.models.base import (
-    Encoder,
+    ENCODER_OUT,
     ActorCriticEncoder,
+    Encoder,
+    Model,
     StatefulActorCriticEncoder,
-    ENCODER_OUT,
+    tokenize,
 )
-from ray.rllib.core.models.base import Model, tokenize
 from ray.rllib.core.models.configs import (
     ActorCriticEncoderConfig,
     CNNEncoderConfig,
@@ -15,7 +16,7 @@
     RecurrentEncoderConfig,
 )
 from ray.rllib.core.models.torch.base import TorchModel
-from ray.rllib.core.models.torch.primitives import TorchMLP, TorchCNN
+from ray.rllib.core.models.torch.primitives import TorchCNN, TorchMLP
 from ray.rllib.models.utils import get_initializer_fn
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_torch
diff --git a/rllib/core/models/torch/primitives.py b/rllib/core/models/torch/primitives.py
index 9c4e55743510..11ee43167b5d 100644
--- a/rllib/core/models/torch/primitives.py
+++ b/rllib/core/models/torch/primitives.py
@@ -1,4 +1,4 @@
-from typing import Callable, Dict, List, Optional, Union, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 from ray.rllib.core.models.torch.utils import Stride2D
 from ray.rllib.models.torch.misc import (
diff --git a/rllib/core/rl_module/__init__.py b/rllib/core/rl_module/__init__.py
index 490cd7942947..4508636d4c2a 100644
--- a/rllib/core/rl_module/__init__.py
+++ b/rllib/core/rl_module/__init__.py
@@ -1,11 +1,11 @@
 import logging
 import re
 
-from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
 from ray.rllib.core.rl_module.multi_rl_module import (
     MultiRLModule,
     MultiRLModuleSpec,
 )
+from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
 from ray.util import log_once
 from ray.util.annotations import DeveloperAPI
diff --git a/rllib/core/rl_module/apis/__init__.py b/rllib/core/rl_module/apis/__init__.py
index 4e51e91a1b11..d9d89bd07a23 100644
--- a/rllib/core/rl_module/apis/__init__.py
+++ b/rllib/core/rl_module/apis/__init__.py
@@ -2,12 +2,11 @@
 from ray.rllib.core.rl_module.apis.q_net_api import QNetAPI
 from ray.rllib.core.rl_module.apis.self_supervised_loss_api import SelfSupervisedLossAPI
 from ray.rllib.core.rl_module.apis.target_network_api import (
-    TargetNetworkAPI,
     TARGET_NETWORK_ACTION_DIST_INPUTS,
+    TargetNetworkAPI,
 )
 from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
-
 
 __all__ = [
     "InferenceOnlyAPI",
     "QNetAPI",
diff --git a/rllib/core/rl_module/apis/self_supervised_loss_api.py b/rllib/core/rl_module/apis/self_supervised_loss_api.py
index 6f1785a426c9..c907a896b01a 100644
--- a/rllib/core/rl_module/apis/self_supervised_loss_api.py
+++ b/rllib/core/rl_module/apis/self_supervised_loss_api.py
@@ -1,5 +1,5 @@
 import abc
-from typing import Any, Dict, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict
 
 from ray.rllib.utils.typing import ModuleID, TensorType
 from ray.util.annotations import PublicAPI
diff --git a/rllib/core/rl_module/apis/target_network_api.py b/rllib/core/rl_module/apis/target_network_api.py
index d1615edff1e0..a368d4f7c6d3 100644
--- a/rllib/core/rl_module/apis/target_network_api.py
+++ b/rllib/core/rl_module/apis/target_network_api.py
@@ -4,7 +4,6 @@
 from ray.rllib.utils.typing import NetworkType
 from ray.util.annotations import PublicAPI
 
-
 TARGET_NETWORK_ACTION_DIST_INPUTS = "target_network_action_dist_inputs"
diff --git a/rllib/core/rl_module/multi_rl_module.py b/rllib/core/rl_module/multi_rl_module.py
index 5ba41d10931c..d9d04b58866a 100644
--- a/rllib/core/rl_module/multi_rl_module.py
+++ b/rllib/core/rl_module/multi_rl_module.py
@@ -20,23 +20,23 @@
 
 import gymnasium as gym
 
+from ray._common.deprecation import (
+    DEPRECATED_VALUE,
+    Deprecated,
+    deprecation_warning,
+)
 from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
 from ray.rllib.utils import force_list
 from ray.rllib.utils.annotations import (
-    override,
     OverrideToImplementCustomLogic,
+    override,
 )
 from ray.rllib.utils.checkpoints import Checkpointable
-from ray._common.deprecation import (
-    Deprecated,
-    DEPRECATED_VALUE,
-    deprecation_warning,
-)
 from ray.rllib.utils.serialization import (
+    deserialize_type,
     gym_space_from_dict,
     gym_space_to_dict,
     serialize_type,
-    deserialize_type,
 )
 from ray.rllib.utils.typing import ModuleID, StateDict, T
 from ray.util.annotations import PublicAPI
diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py
index eeb75a1cd680..7e63c9cd9662 100644
--- a/rllib/core/rl_module/rl_module.py
+++ b/rllib/core/rl_module/rl_module.py
@@ -1,40 +1,40 @@
 import abc
 import dataclasses
-from dataclasses import dataclass, field
 import logging
-from typing import Any, Collection, Dict, Optional, Type, TYPE_CHECKING, Union
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, Collection, Dict, Optional, Type, Union
 
 import gymnasium as gym
 
+from ray._common.deprecation import (
+    DEPRECATED_VALUE,
+    Deprecated,
+    deprecation_warning,
+)
 from ray.rllib.core import DEFAULT_MODULE_ID
-from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
 from ray.rllib.core.distribution.distribution import Distribution
+from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
 from ray.rllib.utils.annotations import (
-    override,
     OverrideToImplementCustomLogic,
+    override,
 )
 from ray.rllib.utils.checkpoints import Checkpointable
-from ray._common.deprecation import (
-    Deprecated,
-    DEPRECATED_VALUE,
-    deprecation_warning,
-)
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.serialization import (
+    deserialize_type,
     gym_space_from_dict,
     gym_space_to_dict,
     serialize_type,
-    deserialize_type,
 )
 from ray.rllib.utils.typing import StateDict
 from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
+    from ray.rllib.core.models.catalog import Catalog
     from ray.rllib.core.rl_module.multi_rl_module import (
         MultiRLModule,
         MultiRLModuleSpec,
     )
-    from ray.rllib.core.models.catalog import Catalog
 
 logger = logging.getLogger("ray.rllib")
 torch, _ = try_import_torch()
diff --git a/rllib/core/rl_module/tests/test_multi_rl_module.py b/rllib/core/rl_module/tests/test_multi_rl_module.py
index 898556faf206..98dc6b3978b2 100644
--- a/rllib/core/rl_module/tests/test_multi_rl_module.py
+++ b/rllib/core/rl_module/tests/test_multi_rl_module.py
@@ -2,10 +2,10 @@
 import unittest
 
 from ray.rllib.core import DEFAULT_MODULE_ID
-from ray.rllib.core.rl_module.rl_module import RLModuleSpec
 from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule
-from ray.rllib.examples.rl_modules.classes.vpg_torch_rlm import VPGTorchRLModule
+from ray.rllib.core.rl_module.rl_module import RLModuleSpec
 from ray.rllib.env.multi_agent_env import make_multi_agent
+from ray.rllib.examples.rl_modules.classes.vpg_torch_rlm import VPGTorchRLModule
 from ray.rllib.utils.test_utils import check
@@ -203,7 +203,8 @@ def test_save_to_path_and_from_checkpoint(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/core/rl_module/tests/test_rl_module_specs.py b/rllib/core/rl_module/tests/test_rl_module_specs.py
index 5a7904e4d10e..ede405d2bd6b 100644
--- a/rllib/core/rl_module/tests/test_rl_module_specs.py
+++ b/rllib/core/rl_module/tests/test_rl_module_specs.py
@@ -3,18 +3,18 @@
 import gymnasium as gym
 import numpy as np
 
-from ray.rllib.core.rl_module.rl_module import RLModuleSpec
 from ray.rllib.core.rl_module.multi_rl_module import (
     MultiRLModule,
     MultiRLModuleSpec,
 )
+from ray.rllib.core.rl_module.rl_module import RLModuleSpec
+from ray.rllib.examples.rl_modules.classes.vpg_torch_rlm import VPGTorchRLModule
 from ray.rllib.examples.rl_modules.classes.vpg_using_shared_encoder_rlm import (
     SHARED_ENCODER_ID,
     SharedEncoder,
-    VPGPolicyAfterSharedEncoder,
     VPGMultiRLModuleWithSharedEncoder,
+    VPGPolicyAfterSharedEncoder,
 )
-from ray.rllib.examples.rl_modules.classes.vpg_torch_rlm import VPGTorchRLModule
 
 
 class TestRLModuleSpecs(unittest.TestCase):
@@ -235,7 +235,8 @@ def test_update_specs_multi_agent(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/core/rl_module/torch/tests/test_torch_rl_module.py b/rllib/core/rl_module/torch/tests/test_torch_rl_module.py
index b1104f21f036..b4123b83028c 100644
--- a/rllib/core/rl_module/torch/tests/test_torch_rl_module.py
+++ b/rllib/core/rl_module/torch/tests/test_torch_rl_module.py
@@ -1,6 +1,6 @@
+import gc
 import tempfile
 import unittest
-import gc
 
 import gymnasium as gym
 import torch
@@ -155,9 +155,10 @@ def get_memory_usage_cuda():
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     # One can specify the specific TestCase class to run.
     # None for all unittest.TestCase classes in this file.
    class_ = sys.argv[1] if len(sys.argv) > 1 else None
diff --git a/rllib/core/rl_module/torch/torch_rl_module.py b/rllib/core/rl_module/torch/torch_rl_module.py
index 236e35b4f48b..b35c9a6572fa 100644
--- a/rllib/core/rl_module/torch/torch_rl_module.py
+++ b/rllib/core/rl_module/torch/torch_rl_module.py
@@ -1,22 +1,22 @@
-from typing import Any, Collection, Dict, Optional, Union, Type
+from typing import Any, Collection, Dict, Optional, Type, Union
 
 import gymnasium as gym
 from packaging import version
 
-from ray.rllib.core.rl_module.apis import InferenceOnlyAPI
-from ray.rllib.core.rl_module.rl_module import RLModule
-from ray.rllib.core.rl_module.torch.torch_compile_config import TorchCompileConfig
 from ray.rllib.core.distribution.torch.torch_distribution import (
     TorchCategorical,
     TorchDiagGaussian,
     TorchDistribution,
 )
-from ray.rllib.utils.annotations import override, OverrideToImplementCustomLogic
+from ray.rllib.core.rl_module.apis import InferenceOnlyAPI
+from ray.rllib.core.rl_module.rl_module import RLModule
+from ray.rllib.core.rl_module.torch.torch_compile_config import TorchCompileConfig
+from ray.rllib.utils.annotations import OverrideToImplementCustomLogic, override
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.numpy import convert_to_numpy
 from ray.rllib.utils.torch_utils import (
-    convert_to_torch_tensor,
     TORCH_COMPILE_REQUIRED_VERSION,
+    convert_to_torch_tensor,
 )
 from ray.rllib.utils.typing import StateDict
diff --git a/rllib/core/testing/bc_algorithm.py b/rllib/core/testing/bc_algorithm.py
index 950f2aec87ee..b5768e309f69 100644
--- a/rllib/core/testing/bc_algorithm.py
+++ b/rllib/core/testing/bc_algorithm.py
@@ -5,10 +5,10 @@
 """
 
 from ray.rllib.algorithms import Algorithm, AlgorithmConfig
-from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
-from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule
-from ray.rllib.core.testing.torch.bc_learner import BCTorchLearner
 from ray.rllib.core.rl_module.rl_module import RLModuleSpec
+from ray.rllib.core.testing.torch.bc_learner import BCTorchLearner
+from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule
+from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.typing import ResultDict
diff --git a/rllib/core/testing/torch/bc_learner.py b/rllib/core/testing/torch/bc_learner.py
index 1c12aee7a1ee..6e1963a93038 100644
--- a/rllib/core/testing/torch/bc_learner.py
+++ b/rllib/core/testing/torch/bc_learner.py
@@ -1,5 +1,6 @@
+from typing import TYPE_CHECKING, Any, Dict
+
 import torch
-from typing import Any, Dict, TYPE_CHECKING
 
 from ray.rllib.core.columns import Columns
 from ray.rllib.core.learner.torch.torch_learner import TorchLearner
diff --git a/rllib/core/testing/torch/bc_module.py b/rllib/core/testing/torch/bc_module.py
index b3fc35556db6..afaedd7ae581 100644
--- a/rllib/core/testing/torch/bc_module.py
+++ b/rllib/core/testing/torch/bc_module.py
@@ -1,9 +1,9 @@
 from typing import Any, Dict
 
 from ray.rllib.core.columns import Columns
-from ray.rllib.core.rl_module.rl_module import RLModule
 from ray.rllib.core.distribution.torch.torch_distribution import TorchCategorical
 from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule
+from ray.rllib.core.rl_module.rl_module import RLModule
 from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_torch

From 0215f361a600f8d6957fe5625da3d691f7145b25 Mon Sep 17 00:00:00 2001
From: Gagandeep Singh
Date: Fri, 19 Sep 2025 14:56:14 +0530
Subject: [PATCH 3/9] LINT: Enable ruff imports for offline, tests, callbacks
 and env in rllib

Signed-off-by: Gagandeep Singh

---
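A note on the mechanism: the per-file-ignores table in pyproject.toml suppresses ruff's import-sorting ("I") rules per path, so enabling the check for a subtree means removing (here: commenting out) its glob from that list, then fixing every file the rule now covers. All hunks below apply the same layout, sketched here with illustrative module names rather than an excerpt from any single file; within a from-import, the hunks show all-caps constants sorting before CamelCase names, which sort before lowercase ones:

```python
# Sketch of the import grouping ruff's "I" rules enforce (names illustrative).

# 1) Standard library first, alphabetized.
import logging
from typing import TYPE_CHECKING, Any, Dict  # constants sort before classes

# 2) Third-party packages, after one blank line.
import gymnasium as gym
import numpy as np

# 3) First-party (ray) imports last.
import ray
from ray.rllib.utils.annotations import OverrideToImplementCustomLogic, override
```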
 pyproject.toml | 17 +++++++++++++-
 rllib/callbacks/callbacks.py | 2 +-
 .../tests/test_callbacks_old_api_stack.py | 5 +++--
 .../tests/test_callbacks_on_algorithm.py | 5 +++--
 .../tests/test_callbacks_on_env_runner.py | 7 +++---
 rllib/callbacks/tests/test_multicallback.py | 4 +++-
 rllib/env/__init__.py | 6 ++---
 rllib/env/base_env.py | 5 +++--
 rllib/env/env_runner.py | 4 ++--
 rllib/env/env_runner_group.py | 19 ++++++++--------
 rllib/env/external/__init__.py | 3 +--
 ...nv_runner_server_for_external_inference.py | 8 +++----
 rllib/env/external/rllink.py | 2 +-
 rllib/env/external_env.py | 9 ++++----
 rllib/env/external_multi_agent_env.py | 5 +++--
 rllib/env/multi_agent_env.py | 9 ++++----
 rllib/env/multi_agent_env_runner.py | 16 +++++++-------
 rllib/env/multi_agent_episode.py | 6 ++---
 rllib/env/policy_client.py | 12 +++++-----
 rllib/env/remote_base_env.py | 11 +++++-----
 rllib/env/single_agent_env_runner.py | 12 +++++-----
 rllib/env/single_agent_episode.py | 8 +++----
 rllib/env/tests/test_env_runner_group.py | 3 ++-
 .../tests/test_infinite_lookback_buffer.py | 3 ++-
 rllib/env/tests/test_multi_agent_env.py | 12 +++++-----
 .../env/tests/test_multi_agent_env_runner.py | 4 ++--
 rllib/env/tests/test_multi_agent_episode.py | 9 ++++----
 .../env/tests/test_single_agent_env_runner.py | 3 ++-
 rllib/env/tests/test_single_agent_episode.py | 7 +++---
 rllib/env/utils/__init__.py | 1 -
 rllib/env/utils/infinite_lookback_buffer.py | 2 +-
 rllib/env/vector/registration.py | 4 ++--
 .../env/vector/sync_vector_multi_agent_env.py | 6 ++---
 rllib/env/vector/vector_multi_agent_env.py | 3 +--
 rllib/env/vector_env.py | 7 +++---
 rllib/env/wrappers/atari_wrappers.py | 7 +++---
 rllib/env/wrappers/dm_env_wrapper.py | 3 +--
 rllib/env/wrappers/group_agents_wrapper.py | 3 ++-
 rllib/env/wrappers/open_spiel.py | 2 +-
 .../tests/test_group_agents_wrapper.py | 3 ++-
 rllib/env/wrappers/unity3d_env.py | 7 +++---
 rllib/offline/__init__.py | 9 ++++----
 rllib/offline/d4rl_reader.py | 5 +++--
 rllib/offline/dataset_reader.py | 11 +++++-----
 rllib/offline/dataset_writer.py | 4 ++--
 rllib/offline/estimators/__init__.py | 6 ++---
 rllib/offline/estimators/direct_method.py | 11 +++++-----
 rllib/offline/estimators/doubly_robust.py | 22 +++++++++----------
 rllib/offline/estimators/fqe_torch_model.py | 12 +++++-----
 .../offline/estimators/importance_sampling.py | 11 +++++-----
 .../estimators/off_policy_estimator.py | 16 +++++++-------
 .../estimators/tests/test_dm_learning.py | 3 ++-
 .../estimators/tests/test_dr_learning.py | 3 ++-
 rllib/offline/estimators/tests/test_ope.py | 13 ++++++-----
 .../offline/estimators/tests/test_ope_math.py | 13 ++++++-----
 rllib/offline/estimators/tests/utils.py | 5 +++--
 .../weighted_importance_sampling.py | 14 ++++++------
 rllib/offline/feature_importance.py | 8 +++----
 rllib/offline/input_reader.py | 9 ++++----
 rllib/offline/io_context.py | 4 ++--
 rllib/offline/is_estimator.py | 2 +-
 rllib/offline/json_reader.py | 14 ++++++------
 rllib/offline/json_writer.py | 16 ++++++++------
 rllib/offline/mixed_input.py | 5 +++--
 rllib/offline/off_policy_estimator.py | 2 +-
 rllib/offline/offline_data.py | 13 +++++------
 rllib/offline/offline_env_runner.py | 9 ++++----
 rllib/offline/offline_evaluation_runner.py | 5 ++---
 .../offline_evaluation_runner_group.py | 4 ++--
 rllib/offline/offline_evaluation_utils.py | 7 +++---
 rllib/offline/offline_evaluator.py | 5 ++---
 .../offline_policy_evaluation_runner.py | 16 +++++++-------
 rllib/offline/offline_prelearner.py | 10 ++++-----
 rllib/offline/output_writer.py | 2 +-
 rllib/offline/resource.py | 3 ++-
 rllib/offline/shuffled_input.py | 2 +-
 rllib/offline/tests/test_dataset_reader.py | 8 +++----
 .../offline/tests/test_feature_importance.py | 5 +++--
 rllib/offline/tests/test_offline_data.py | 7 +++---
 .../offline/tests/test_offline_env_runner.py | 6 +++--
 .../tests/test_offline_evaluation_runner.py | 9 ++++----
 .../test_offline_evaluation_runner_group.py | 9 ++++----
 .../offline/tests/test_offline_prelearner.py | 9 ++++----
 rllib/offline/wis_estimator.py | 2 +-
 rllib/tests/conftest.py | 7 +++---
 rllib/tests/run_regression_tests.py | 7 +++---
 rllib/tests/test_catalog.py | 10 +++++----
 rllib/tests/test_dependency_torch.py | 1 -
 rllib/tests/test_local.py | 3 ++-
 rllib/tests/test_lstm.py | 6 +++--
 .../tests/test_nn_framework_import_errors.py | 1 +
 rllib/tests/test_pettingzoo_env.py | 7 +++---
 rllib/tests/test_placement_groups.py | 5 +++--
 rllib/tests/test_ray_client.py | 2 +-
 rllib/tests/test_telemetry.py | 3 +--
 rllib/tests/test_timesteps.py | 6 +++--
 96 files changed, 361 insertions(+), 310 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index c8c2219e451f..4cf595ed73cb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,7 +65,22 @@ afterray = ["psutil", "setproctitle"]
 "python/ray/__init__.py" = ["I"]
 "python/ray/dag/__init__.py" = ["I"]
 "python/ray/air/__init__.py" = ["I"]
-"rllib/*" = ["I"]
+"rllib/__init__.py" = ["I"]
+"rllib/benchmarks/*" = ["I"]
+"rllib/connectors/*" = ["I"]
+"rllib/evaluation/*" = ["I"]
+"rllib/models/*" = ["I"]
+"rllib/utils/*" = ["I"]
+"rllib/algorithms/*" = ["I"]
+"rllib/core/*" = ["I"]
+"rllib/examples/*" = ["I"]
+# "rllib/offline/*" = ["I"]
+# "rllib/tests/*" = ["I"]
+# "rllib/callbacks/*" = ["I"]
+# "rllib/env/*" = ["I"]
+"rllib/execution/*" = ["I"]
+"rllib/policy/*" = ["I"]
+"rllib/tuned_examples/*" = ["I"]
 "release/*" = ["I"]
 
 # TODO(matthewdeng): Remove this line
diff --git a/rllib/callbacks/callbacks.py b/rllib/callbacks/callbacks.py
index 208684f780fc..fb4107872ad7 100644
--- a/rllib/callbacks/callbacks.py
+++ b/rllib/callbacks/callbacks.py
@@ -17,9 +17,9 @@
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.annotations import (
     OldAPIStack,
-    override,
     OverrideToImplementCustomLogic,
     PublicAPI,
+    override,
 )
 from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
 from ray.rllib.utils.typing import AgentID, EnvType, EpisodeType, PolicyID
diff --git a/rllib/callbacks/tests/test_callbacks_old_api_stack.py b/rllib/callbacks/tests/test_callbacks_old_api_stack.py
index d836360d1741..ccf05cdf425e 100644
--- a/rllib/callbacks/tests/test_callbacks_old_api_stack.py
+++ b/rllib/callbacks/tests/test_callbacks_old_api_stack.py
@@ -1,5 +1,5 @@
-from collections import Counter
 import unittest
+from collections import Counter
 
 import ray
 from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks
@@ -209,7 +209,8 @@ def test_on_episode_created(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/callbacks/tests/test_callbacks_on_algorithm.py b/rllib/callbacks/tests/test_callbacks_on_algorithm.py
index 9cfbb1f5658c..db7646c3090c 100644
--- a/rllib/callbacks/tests/test_callbacks_on_algorithm.py
+++ b/rllib/callbacks/tests/test_callbacks_on_algorithm.py
@@ -4,8 +4,8 @@
 
 import ray
 from ray import tune
-from ray.rllib.callbacks.callbacks import RLlibCallback
 from ray.rllib.algorithms.ppo import PPOConfig
+from ray.rllib.callbacks.callbacks import RLlibCallback
 from ray.rllib.examples.envs.classes.cartpole_crashing import CartPoleCrashing
 from ray.rllib.utils.test_utils import check
@@ -108,7 +108,8 @@ def test_on_init_and_checkpoint_loaded(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/callbacks/tests/test_callbacks_on_env_runner.py b/rllib/callbacks/tests/test_callbacks_on_env_runner.py
index 577a02dcb0a4..6bc3a82ff247 100644
--- a/rllib/callbacks/tests/test_callbacks_on_env_runner.py
+++ b/rllib/callbacks/tests/test_callbacks_on_env_runner.py
@@ -1,12 +1,12 @@
-from collections import Counter
 import unittest
+from collections import Counter
 
 import gymnasium as gym
 
 import ray
 from ray import tune
-from ray.rllib.callbacks.callbacks import RLlibCallback
 from ray.rllib.algorithms.ppo import PPOConfig
+from ray.rllib.callbacks.callbacks import RLlibCallback
 from ray.rllib.env.env_runner import EnvRunner
 from ray.rllib.env.vector.vector_multi_agent_env import VectorMultiAgentEnv
 from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
@@ -237,7 +237,8 @@ def test_tune_trial_id_visible_in_callbacks(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/callbacks/tests/test_multicallback.py b/rllib/callbacks/tests/test_multicallback.py
index 2cd56ba33c7a..208b0cfa688d 100644
--- a/rllib/callbacks/tests/test_multicallback.py
+++ b/rllib/callbacks/tests/test_multicallback.py
@@ -1,4 +1,5 @@
 import unittest
+
 import ray
 from ray.rllib.algorithms import PPOConfig
 from ray.rllib.callbacks.callbacks import RLlibCallback
@@ -141,7 +142,8 @@ def test_single_callback_validation_error(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/env/__init__.py b/rllib/env/__init__.py
index ca9de7949565..247096892b30 100644
--- a/rllib/env/__init__.py
+++ b/rllib/env/__init__.py
@@ -6,12 +6,10 @@
 from ray.rllib.env.policy_client import PolicyClient
 from ray.rllib.env.remote_base_env import RemoteBaseEnv
 from ray.rllib.env.vector_env import VectorEnv
-
-from ray.rllib.env.wrappers.dm_env_wrapper import DMEnv
 from ray.rllib.env.wrappers.dm_control_wrapper import DMCEnv
+from ray.rllib.env.wrappers.dm_env_wrapper import DMEnv
 from ray.rllib.env.wrappers.group_agents_wrapper import GroupAgentsWrapper
-from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
-from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
+from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv, PettingZooEnv
 from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv
 
 INPUT_ENV_SPACES = "__env__"
diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py
index c67c642e4763..1467472eead8 100644
--- a/rllib/env/base_env.py
+++ b/rllib/env/base_env.py
@@ -1,7 +1,8 @@
 import logging
-from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING, Union, Set
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
 import gymnasium as gym
+
 import ray
 from ray.rllib.utils.annotations import OldAPIStack
 from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiEnvDict
@@ -368,9 +369,9 @@ def convert_to_base_env(
         The resulting BaseEnv object.
     """
-    from ray.rllib.env.remote_base_env import RemoteBaseEnv
     from ray.rllib.env.external_env import ExternalEnv
     from ray.rllib.env.multi_agent_env import MultiAgentEnv
+    from ray.rllib.env.remote_base_env import RemoteBaseEnv
     from ray.rllib.env.vector_env import VectorEnv, VectorEnvWrapper
 
     if remote_envs and num_envs == 1:
diff --git a/rllib/env/env_runner.py b/rllib/env/env_runner.py
index 04191cf28900..de176bf582ff 100644
--- a/rllib/env/env_runner.py
+++ b/rllib/env/env_runner.py
@@ -1,6 +1,6 @@
 import abc
 import logging
-from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
 
 import gymnasium as gym
 import tree  # pip install dm_tree
@@ -15,7 +15,7 @@
 from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
 from ray.rllib.utils.torch_utils import convert_to_torch_tensor
 from ray.rllib.utils.typing import StateDict, TensorType
-from ray.util.annotations import PublicAPI, DeveloperAPI
+from ray.util.annotations import DeveloperAPI, PublicAPI
 
 if TYPE_CHECKING:
     from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
diff --git a/rllib/env/env_runner_group.py b/rllib/env/env_runner_group.py
index 6974c1d30187..bd96cbe0ee90 100644
--- a/rllib/env/env_runner_group.py
+++ b/rllib/env/env_runner_group.py
@@ -1,9 +1,9 @@
 import functools
-import gymnasium as gym
-import logging
 import importlib.util
+import logging
 import os
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Collection,
@@ -12,12 +12,18 @@
     Optional,
     Tuple,
     Type,
-    TYPE_CHECKING,
     TypeVar,
     Union,
 )
 
+import gymnasium as gym
+
 import ray
+from ray._common.deprecation import (
+    DEPRECATED_VALUE,
+    Deprecated,
+    deprecation_warning,
+)
 from ray.actor import ActorHandle
 from ray.exceptions import RayActorError
 from ray.rllib.core import (
@@ -29,19 +35,14 @@
 from ray.rllib.core.learner import LearnerGroup
 from ray.rllib.core.rl_module import validate_module_id
 from ray.rllib.core.rl_module.rl_module import RLModuleSpec
-from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.env.base_env import BaseEnv
 from ray.rllib.env.env_context import EnvContext
 from ray.rllib.env.env_runner import EnvRunner
+from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.offline import get_dataset_and_shards
 from ray.rllib.policy.policy import Policy, PolicyState
 from ray.rllib.utils.actor_manager import FaultTolerantActorManager
 from ray.rllib.utils.annotations import OldAPIStack
-from ray._common.deprecation import (
-    Deprecated,
-    deprecation_warning,
-    DEPRECATED_VALUE,
-)
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED_LIFETIME, WEIGHTS_SEQ_NO
 from ray.rllib.utils.typing import (
diff --git a/rllib/env/external/__init__.py b/rllib/env/external/__init__.py
index 343adb18b3c5..9264f89c978e 100644
--- a/rllib/env/external/__init__.py
+++ b/rllib/env/external/__init__.py
@@ -1,10 +1,9 @@
 from ray.rllib.env.external.rllink import (
+    RLlink,
     get_rllink_message,
     send_rllink_message,
-    RLlink,
 )
-
 
 __all__ = [
     "get_rllink_message",
     "send_rllink_message",
diff --git a/rllib/env/external/env_runner_server_for_external_inference.py b/rllib/env/external/env_runner_server_for_external_inference.py
index 36bb2723c27b..59b55a590fcb 100644
--- a/rllib/env/external/env_runner_server_for_external_inference.py
+++ b/rllib/env/external/env_runner_server_for_external_inference.py
@@ -1,8 +1,8 @@
-from collections import defaultdict
 import pickle
 import socket
 import threading
 import time
+from collections import defaultdict
 from typing import Collection, DefaultDict, List, Optional, Union
 
 from ray.rllib.core import (
@@ -12,13 +12,13 @@
 )
 from ray.rllib.env import INPUT_ENV_SPACES
 from ray.rllib.env.env_runner import EnvRunner
-from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner
-from ray.rllib.env.single_agent_episode import SingleAgentEpisode
 from ray.rllib.env.external.rllink import (
+    RLlink,
     get_rllink_message,
     send_rllink_message,
-    RLlink,
 )
+from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner
+from ray.rllib.env.single_agent_episode import SingleAgentEpisode
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.checkpoints import Checkpointable
 from ray.rllib.utils.framework import try_import_torch
diff --git a/rllib/env/external/rllink.py b/rllib/env/external/rllink.py
index dfb72bda97b6..12ec0c3c5f04 100644
--- a/rllib/env/external/rllink.py
+++ b/rllib/env/external/rllink.py
@@ -1,10 +1,10 @@
 from enum import Enum
+
 from packaging.version import Version
 
 from ray.rllib.utils.checkpoints import try_import_msgpack
 from ray.util.annotations import DeveloperAPI
 
-
 msgpack = None
diff --git a/rllib/env/external_env.py b/rllib/env/external_env.py
index 783ae256cb99..c9aae38f1852 100644
--- a/rllib/env/external_env.py
+++ b/rllib/env/external_env.py
@@ -1,11 +1,13 @@
-import gymnasium as gym
 import queue
 import threading
 import uuid
-from typing import Callable, Tuple, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable, Optional, Tuple
+
+import gymnasium as gym
 
+from ray._common.deprecation import deprecation_warning
 from ray.rllib.env.base_env import BaseEnv
-from ray.rllib.utils.annotations import override, OldAPIStack
+from ray.rllib.utils.annotations import OldAPIStack, override
 from ray.rllib.utils.typing import (
     EnvActionType,
     EnvInfoDict,
@@ -13,7 +15,6 @@
     EnvType,
     MultiEnvDict,
 )
-from ray._common.deprecation import deprecation_warning
 
 if TYPE_CHECKING:
     from ray.rllib.models.preprocessors import Preprocessor
diff --git a/rllib/env/external_multi_agent_env.py b/rllib/env/external_multi_agent_env.py
index 1350d5c7c356..6c7e333700f6 100644
--- a/rllib/env/external_multi_agent_env.py
+++ b/rllib/env/external_multi_agent_env.py
@@ -1,9 +1,10 @@
 import uuid
-import gymnasium as gym
 from typing import Optional
 
-from ray.rllib.utils.annotations import override, OldAPIStack
+import gymnasium as gym
+
 from ray.rllib.env.external_env import ExternalEnv, _ExternalEnvEpisode
+from ray.rllib.utils.annotations import OldAPIStack, override
 from ray.rllib.utils.typing import MultiAgentDict
diff --git a/rllib/env/multi_agent_env.py b/rllib/env/multi_agent_env.py
index 327abe50779e..5e2a2de96435 100644
--- a/rllib/env/multi_agent_env.py
+++ b/rllib/env/multi_agent_env.py
@@ -1,14 +1,14 @@
 import copy
-import gymnasium as gym
 import logging
-from typing import Callable, Dict, List, Tuple, Optional, Union, Set, Type
+from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
 
+import gymnasium as gym
 import numpy as np
 
+from ray._common.deprecation import Deprecated
 from ray.rllib.env.base_env import BaseEnv
 from ray.rllib.env.env_context import EnvContext
 from ray.rllib.utils.annotations import OldAPIStack, override
-from ray._common.deprecation import Deprecated
 from ray.rllib.utils.typing import (
     AgentID,
     EnvCreator,
@@ -250,8 +250,7 @@ class MyMultiAgentEnv(MultiAgentEnv):
         """
-        from ray.rllib.env.wrappers.group_agents_wrapper import \
-            GroupAgentsWrapper
+        from ray.rllib.env.wrappers.group_agents_wrapper import GroupAgentsWrapper
 
         return GroupAgentsWrapper(self, groups, obs_space, act_space)
 
     # __grouping_doc_end__
diff --git a/rllib/env/multi_agent_env_runner.py b/rllib/env/multi_agent_env_runner.py
index f1c5922eab8c..3b2b0ec9e939 100644
--- a/rllib/env/multi_agent_env_runner.py
+++ b/rllib/env/multi_agent_env_runner.py
@@ -1,13 +1,14 @@
-from collections import defaultdict
-from functools import partial
-import math
 import logging
+import math
 import time
+from collections import defaultdict
+from functools import partial
 from typing import Collection, DefaultDict, Dict, List, Optional, Union
 
 import gymnasium as gym
 
 import ray
+from ray._common.deprecation import Deprecated
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
 from ray.rllib.callbacks.utils import make_callback
 from ray.rllib.core import (
@@ -17,18 +18,17 @@
 )
 from ray.rllib.core.columns import Columns
 from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule, MultiRLModuleSpec
-from ray.rllib.env import INPUT_ENV_SPACES, INPUT_ENV_SINGLE_SPACES
+from ray.rllib.env import INPUT_ENV_SINGLE_SPACES, INPUT_ENV_SPACES
 from ray.rllib.env.env_context import EnvContext
-from ray.rllib.env.env_runner import EnvRunner, ENV_STEP_FAILURE
+from ray.rllib.env.env_runner import ENV_STEP_FAILURE, EnvRunner
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
-from ray.rllib.env.vector.vector_multi_agent_env import VectorMultiAgentEnv
-from ray.rllib.env.vector.registration import make_vec
 from ray.rllib.env.utils import _gym_env_creator
+from ray.rllib.env.vector.registration import make_vec
+from ray.rllib.env.vector.vector_multi_agent_env import VectorMultiAgentEnv
 from ray.rllib.utils import force_list
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.checkpoints import Checkpointable
-from ray._common.deprecation import Deprecated
 from ray.rllib.utils.framework import get_device, try_import_torch
 from ray.rllib.utils.metrics import (
     ENV_TO_MODULE_CONNECTOR,
diff --git a/rllib/env/multi_agent_episode.py b/rllib/env/multi_agent_episode.py
index 3e21bac0cb4e..c2b75efec5d3 100644
--- a/rllib/env/multi_agent_episode.py
+++ b/rllib/env/multi_agent_episode.py
@@ -1,6 +1,7 @@
-from collections import defaultdict
 import copy
 import time
+import uuid
+from collections import defaultdict
 from typing import (
     Any,
     Callable,
@@ -12,15 +13,14 @@
     Set,
     Union,
 )
-import uuid
 
 import gymnasium as gym
 
+from ray._common.deprecation import Deprecated
 from ray.rllib.env.single_agent_episode import SingleAgentEpisode
 from ray.rllib.env.utils.infinite_lookback_buffer import InfiniteLookbackBuffer
 from ray.rllib.policy.sample_batch import MultiAgentBatch
 from ray.rllib.utils import force_list
-from ray._common.deprecation import Deprecated
 from ray.rllib.utils.error import MultiAgentEnvError
 from ray.rllib.utils.spaces.space_utils import batch
 from ray.rllib.utils.typing import AgentID, ModuleID, MultiAgentDict
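One idiom worth calling out before the next batch of env hunks: several of these modules (base_env.py, remote_base_env.py, env_runner.py) move TYPE_CHECKING to the front of their typing imports, and the sorted result keeps the usual lazy-annotation guard intact. A minimal sketch of that pattern, where the function is hypothetical but the guarded import is taken from remote_base_env.py:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by type checkers, never at runtime, which avoids
    # circular imports and keeps module import cheap.
    from ray.rllib.evaluation.rollout_worker import RolloutWorker


def ping(worker: "RolloutWorker") -> None:
    # The annotation is a string, so RolloutWorker need not exist at runtime.
    ...
```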
diff --git a/rllib/env/policy_client.py b/rllib/env/policy_client.py
index e4e2e9ad8a62..9ae02d5714b7 100644
--- a/rllib/env/policy_client.py
+++ b/rllib/env/policy_client.py
@@ -1,24 +1,24 @@
 import logging
 import threading
 import time
-from typing import Union, Optional
+from typing import Optional, Union
 
 import ray.cloudpickle as pickle
+
+# Backward compatibility.
+from ray.rllib.env.external.rllink import RLlink as Commands
 from ray.rllib.env.external_env import ExternalEnv
 from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.policy.sample_batch import MultiAgentBatch
 from ray.rllib.utils.annotations import OldAPIStack
 from ray.rllib.utils.typing import (
-    MultiAgentDict,
+    EnvActionType,
     EnvInfoDict,
     EnvObsType,
-    EnvActionType,
+    MultiAgentDict,
 )
 
-# Backward compatibility.
-from ray.rllib.env.external.rllink import RLlink as Commands
-
 logger = logging.getLogger(__name__)
 
 try:
diff --git a/rllib/env/remote_base_env.py b/rllib/env/remote_base_env.py
index 9ff6537a9d32..5d48de3098ec 100644
--- a/rllib/env/remote_base_env.py
+++ b/rllib/env/remote_base_env.py
@@ -1,12 +1,13 @@
-import gymnasium as gym
 import logging
-from typing import Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple
+
+import gymnasium as gym
 
 import ray
-from ray.util import log_once
-from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
-from ray.rllib.utils.annotations import override, OldAPIStack
+from ray.rllib.env.base_env import _DUMMY_AGENT_ID, ASYNC_RESET_RETURN, BaseEnv
+from ray.rllib.utils.annotations import OldAPIStack, override
 from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiEnvDict
+from ray.util import log_once
 
 if TYPE_CHECKING:
     from ray.rllib.evaluation.rollout_worker import RolloutWorker
diff --git a/rllib/env/single_agent_env_runner.py b/rllib/env/single_agent_env_runner.py
index ec30195bec44..bdb4a72d6570 100644
--- a/rllib/env/single_agent_env_runner.py
+++ b/rllib/env/single_agent_env_runner.py
@@ -1,14 +1,15 @@
-from collections import defaultdict
-from functools import partial
 import logging
 import math
 import time
+from collections import defaultdict
+from functools import partial
 from typing import Collection, DefaultDict, List, Optional, Union
 
 import gymnasium as gym
-import ray
 from gymnasium.wrappers.vector import DictInfoToList
 
+import ray
+from ray._common.deprecation import Deprecated
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
 from ray.rllib.callbacks.callbacks import RLlibCallback
 from ray.rllib.callbacks.utils import make_callback
@@ -21,15 +22,14 @@
 )
 from ray.rllib.core.columns import Columns
 from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
-from ray.rllib.env import INPUT_ENV_SPACES, INPUT_ENV_SINGLE_SPACES
+from ray.rllib.env import INPUT_ENV_SINGLE_SPACES, INPUT_ENV_SPACES
 from ray.rllib.env.env_context import EnvContext
-from ray.rllib.env.env_runner import EnvRunner, ENV_STEP_FAILURE
+from ray.rllib.env.env_runner import ENV_STEP_FAILURE, EnvRunner
 from ray.rllib.env.single_agent_episode import SingleAgentEpisode
 from ray.rllib.env.utils import _gym_env_creator
 from ray.rllib.utils import force_list
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.checkpoints import Checkpointable
-from ray._common.deprecation import Deprecated
 from ray.rllib.utils.framework import get_device
 from ray.rllib.utils.metrics import (
     ENV_TO_MODULE_CONNECTOR,
diff --git a/rllib/env/single_agent_episode.py b/rllib/env/single_agent_episode.py
index 03906ff3d692..1087c4a7d58d 100644
--- a/rllib/env/single_agent_episode.py
+++ b/rllib/env/single_agent_episode.py
@@ -1,19 +1,19 @@
-from collections import defaultdict
 import copy
 import functools
-import numpy as np
 import time
 import uuid
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, SupportsFloat, Union
 
 import gymnasium as gym
+import numpy as np
 from gymnasium.core import ActType, ObsType
-from typing import Any, Dict, List, Optional, SupportsFloat, Union
 
+from ray._common.deprecation import Deprecated
 from ray.rllib.core.columns import Columns
 from ray.rllib.env.utils.infinite_lookback_buffer import InfiniteLookbackBuffer
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.serialization import gym_space_from_dict, gym_space_to_dict
-from ray._common.deprecation import Deprecated
 from ray.rllib.utils.typing import AgentID, ModuleID
 from ray.util.annotations import PublicAPI
diff --git a/rllib/env/tests/test_env_runner_group.py b/rllib/env/tests/test_env_runner_group.py
index f615bfc835b8..2b813ab57abe 100644
--- a/rllib/env/tests/test_env_runner_group.py
+++ b/rllib/env/tests/test_env_runner_group.py
@@ -93,7 +93,8 @@ def test_foreach_env_runner_async(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/env/tests/test_infinite_lookback_buffer.py b/rllib/env/tests/test_infinite_lookback_buffer.py
index 7e114c1086ee..7f1db6b98a5c 100644
--- a/rllib/env/tests/test_infinite_lookback_buffer.py
+++ b/rllib/env/tests/test_infinite_lookback_buffer.py
@@ -599,7 +599,8 @@ def test_set_with_complex_space(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/env/tests/test_multi_agent_env.py b/rllib/env/tests/test_multi_agent_env.py
index 31d4c9ea13cc..a9ded7b21da6 100644
--- a/rllib/env/tests/test_multi_agent_env.py
+++ b/rllib/env/tests/test_multi_agent_env.py
@@ -1,11 +1,11 @@
+import random
+import unittest
+
 import gymnasium as gym
 import numpy as np
-import random
 import tree  # pip install dm-tree
-import unittest
 
 import ray
-from ray.tune.registry import register_env
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
 from ray.rllib.algorithms.ppo import PPOConfig
 from ray.rllib.env.multi_agent_env import (
@@ -20,12 +20,13 @@
     convert_ma_batch_to_sample_batch,
 )
 from ray.rllib.utils.metrics import (
-    NUM_ENV_STEPS_SAMPLED_LIFETIME,
     ENV_RUNNER_RESULTS,
     EPISODE_RETURN_MEAN,
+    NUM_ENV_STEPS_SAMPLED_LIFETIME,
 )
 from ray.rllib.utils.numpy import one_hot
 from ray.rllib.utils.test_utils import check
+from ray.tune.registry import register_env
 
 
 class BasicMultiAgent(MultiAgentEnv):
@@ -820,7 +821,8 @@ def is_recurrent(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/env/tests/test_multi_agent_env_runner.py b/rllib/env/tests/test_multi_agent_env_runner.py
index 13fc1021f0fb..5de70bd57f99 100644
--- a/rllib/env/tests/test_multi_agent_env_runner.py
+++ b/rllib/env/tests/test_multi_agent_env_runner.py
@@ -1,7 +1,6 @@
 import unittest
 
 import ray
-
 from ray.rllib.algorithms.ppo.ppo import PPOConfig
 from ray.rllib.env.multi_agent_env_runner import MultiAgentEnvRunner
 from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
@@ -131,7 +130,8 @@ def _build_config(self, num_agents=2, num_policies=2):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/env/tests/test_multi_agent_episode.py b/rllib/env/tests/test_multi_agent_episode.py
index c97f934cf3b7..6b8db36b2382 100644
--- a/rllib/env/tests/test_multi_agent_episode.py
+++ b/rllib/env/tests/test_multi_agent_episode.py
@@ -1,9 +1,9 @@
-import gymnasium as gym
-import numpy as np
 import unittest
-
 from typing import Optional, Tuple
 
+import gymnasium as gym
+import numpy as np
+
 import ray
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
@@ -3573,7 +3573,8 @@ def _mock_multi_agent_records():
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/env/tests/test_single_agent_env_runner.py b/rllib/env/tests/test_single_agent_env_runner.py
index 2f3df1864bf4..2abf9a79bc72 100644
--- a/rllib/env/tests/test_single_agent_env_runner.py
+++ b/rllib/env/tests/test_single_agent_env_runner.py
@@ -197,7 +197,8 @@ def test_vector_env(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/env/tests/test_single_agent_episode.py b/rllib/env/tests/test_single_agent_episode.py
index 8411017bf08f..d4b6e41f768b 100644
--- a/rllib/env/tests/test_single_agent_episode.py
+++ b/rllib/env/tests/test_single_agent_episode.py
@@ -1,10 +1,10 @@
+import unittest
 from collections import defaultdict
 from typing import Any, Dict, Optional, SupportsFloat, Tuple
-import unittest
 
 import gymnasium as gym
-from gymnasium.core import ActType, ObsType
 import numpy as np
+from gymnasium.core import ActType, ObsType
 
 from ray.rllib.env.single_agent_episode import SingleAgentEpisode
 from ray.rllib.utils.test_utils import check
@@ -709,7 +709,8 @@ def _create_episode(self, num_data, t_started=None, len_lookback_buffer=0):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/env/utils/__init__.py b/rllib/env/utils/__init__.py
index d0186ab489c1..1ca7be4f333a 100644
--- a/rllib/env/utils/__init__.py
+++ b/rllib/env/utils/__init__.py
@@ -10,7 +10,6 @@
 )
 from ray.util.annotations import PublicAPI
 
-
 logger = logging.getLogger(__name__)
diff --git a/rllib/env/utils/infinite_lookback_buffer.py b/rllib/env/utils/infinite_lookback_buffer.py
index 76004f0200fa..ed5e6f63d3c5 100644
--- a/rllib/env/utils/infinite_lookback_buffer.py
+++ b/rllib/env/utils/infinite_lookback_buffer.py
@@ -9,8 +9,8 @@
 from ray.rllib.utils.spaces.space_utils import (
     batch,
     from_jsonable_if_needed,
-    get_dummy_batch_for_space,
     get_base_struct_from_space,
+    get_dummy_batch_for_space,
     to_jsonable_if_needed,
 )
 from ray.util.annotations import DeveloperAPI
diff --git a/rllib/env/vector/registration.py b/rllib/env/vector/registration.py
index d9d4a4f59886..ec0a8b43e633 100644
--- a/rllib/env/vector/registration.py
+++ b/rllib/env/vector/registration.py
@@ -1,9 +1,9 @@
 import copy
-import gymnasium as gym
 import logging
+from typing import Any, Dict, Optional
 
+import gymnasium as gym
 from gymnasium.envs.registration import VectorizeMode
-from typing import Any, Dict, Optional
 
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.env.vector.sync_vector_multi_agent_env import SyncVectorMultiAgentEnv
diff --git a/rllib/env/vector/sync_vector_multi_agent_env.py b/rllib/env/vector/sync_vector_multi_agent_env.py
index d1133ebf0d94..e5c375526f76 100644
--- a/rllib/env/vector/sync_vector_multi_agent_env.py
+++ b/rllib/env/vector/sync_vector_multi_agent_env.py
@@ -1,9 +1,9 @@
+from copy import deepcopy
+from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
+
 import gymnasium as gym
 import numpy as np
-
-from copy import deepcopy
 from gymnasium.core import ActType, RenderFrame
-from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
 
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.env.vector.vector_multi_agent_env import ArrayType, VectorMultiAgentEnv
diff --git a/rllib/env/vector/vector_multi_agent_env.py b/rllib/env/vector/vector_multi_agent_env.py
index a5b22d12dafc..a4e712643c18 100644
--- a/rllib/env/vector/vector_multi_agent_env.py
+++ b/rllib/env/vector/vector_multi_agent_env.py
@@ -1,11 +1,10 @@
 from typing import Any, Dict, List, Optional, Tuple, TypeVar
 
 import gymnasium as gym
+import numpy as np
 from gymnasium.core import RenderFrame
 from gymnasium.envs.registration import EnvSpec
 from gymnasium.utils import seeding
-import numpy as np
-
 
 ArrayType = TypeVar("ArrayType")
diff --git a/rllib/env/vector_env.py b/rllib/env/vector_env.py
index b1da92dd0cad..ed1e2dfdb70c 100644
--- a/rllib/env/vector_env.py
+++ b/rllib/env/vector_env.py
@@ -1,18 +1,19 @@
 import logging
+from typing import Callable, List, Optional, Set, Tuple, Union
+
 import gymnasium as gym
 import numpy as np
-from typing import Callable, List, Optional, Tuple, Union, Set
 
-from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID
+from ray.rllib.env.base_env import _DUMMY_AGENT_ID, BaseEnv
 from ray.rllib.utils.annotations import Deprecated, OldAPIStack, override
 from ray.rllib.utils.typing import (
+    AgentID,
     EnvActionType,
     EnvID,
     EnvInfoDict,
     EnvObsType,
     EnvType,
     MultiEnvDict,
-    AgentID,
 )
 from ray.util import log_once
diff --git a/rllib/env/wrappers/atari_wrappers.py b/rllib/env/wrappers/atari_wrappers.py
index 018b93564165..3b20ea23221e 100644
--- a/rllib/env/wrappers/atari_wrappers.py
+++ b/rllib/env/wrappers/atari_wrappers.py
@@ -1,11 +1,12 @@
 from collections import deque
+from typing import Optional, Union
+
 import gymnasium as gym
-from gymnasium import spaces
 import numpy as np
-from typing import Optional, Union
+from gymnasium import spaces
 
 from ray.rllib.utils.annotations import PublicAPI
-from ray.rllib.utils.images import rgb2gray, resize
+from ray.rllib.utils.images import resize, rgb2gray
 
 
 @PublicAPI
diff --git a/rllib/env/wrappers/dm_env_wrapper.py b/rllib/env/wrappers/dm_env_wrapper.py
index 435251df216b..7aef65848ee2 100644
--- a/rllib/env/wrappers/dm_env_wrapper.py
+++ b/rllib/env/wrappers/dm_env_wrapper.py
@@ -1,7 +1,6 @@
 import gymnasium as gym
-from gymnasium import spaces
-
 import numpy as np
+from gymnasium import spaces
 
 try:
     from dm_env import specs
diff --git a/rllib/env/wrappers/group_agents_wrapper.py b/rllib/env/wrappers/group_agents_wrapper.py
index c9bb592a79d0..bb242709a136 100644
--- a/rllib/env/wrappers/group_agents_wrapper.py
+++ b/rllib/env/wrappers/group_agents_wrapper.py
@@ -1,7 +1,8 @@
 from collections import OrderedDict
-import gymnasium as gym
 from typing import Dict, List, Optional
 
+import gymnasium as gym
+
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.utils.annotations import DeveloperAPI
 from ray.rllib.utils.typing import AgentID
diff --git a/rllib/env/wrappers/open_spiel.py b/rllib/env/wrappers/open_spiel.py
index abc051c65770..3823230fb2d4 100644
--- a/rllib/env/wrappers/open_spiel.py
+++ b/rllib/env/wrappers/open_spiel.py
@@ -1,7 +1,7 @@
 from typing import Optional
 
-import numpy as np
 import gymnasium as gym
+import numpy as np
 
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.env.utils import try_import_pyspiel
diff --git a/rllib/env/wrappers/tests/test_group_agents_wrapper.py b/rllib/env/wrappers/tests/test_group_agents_wrapper.py
index 8f295513984c..c755dd869ebf 100644
--- a/rllib/env/wrappers/tests/test_group_agents_wrapper.py
+++ b/rllib/env/wrappers/tests/test_group_agents_wrapper.py
@@ -1,7 +1,7 @@
 import unittest
 
-from ray.rllib.env.wrappers.group_agents_wrapper import GroupAgentsWrapper
 from ray.rllib.env.multi_agent_env import make_multi_agent
+from ray.rllib.env.wrappers.group_agents_wrapper import GroupAgentsWrapper
 
 
 class TestGroupAgentsWrapper(unittest.TestCase):
@@ -20,6 +20,7 @@ def test_group_agents_wrapper(self):
 
 if __name__ == "__main__":
     import sys
+
     import pytest
 
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/env/wrappers/unity3d_env.py b/rllib/env/wrappers/unity3d_env.py
index 82babd666741..f22498e8bd96 100644
--- a/rllib/env/wrappers/unity3d_env.py
+++ b/rllib/env/wrappers/unity3d_env.py
@@ -1,14 +1,15 @@
-from gymnasium.spaces import Box, MultiDiscrete, Tuple as TupleSpace
 import logging
-import numpy as np
 import random
 import time
 from typing import Callable, Optional, Tuple
 
+import numpy as np
+from gymnasium.spaces import Box, MultiDiscrete, Tuple as TupleSpace
+
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.policy.policy import PolicySpec
 from ray.rllib.utils.annotations import OldAPIStack
-from ray.rllib.utils.typing import MultiAgentDict, PolicyID, AgentID
+from ray.rllib.utils.typing import AgentID, MultiAgentDict, PolicyID
 
 logger = logging.getLogger(__name__)
diff --git a/rllib/offline/__init__.py b/rllib/offline/__init__.py
index cc4a0d9bb05d..c58b423c77ba 100644
--- a/rllib/offline/__init__.py
+++ b/rllib/offline/__init__.py
@@ -1,16 +1,15 @@
 from ray.rllib.offline.d4rl_reader import D4RLReader
 from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards
 from ray.rllib.offline.dataset_writer import DatasetWriter
-from ray.rllib.offline.io_context import IOContext
+from ray.rllib.offline.feature_importance import FeatureImportance
 from ray.rllib.offline.input_reader import InputReader
-from ray.rllib.offline.mixed_input import MixedInput
+from ray.rllib.offline.io_context import IOContext
 from ray.rllib.offline.json_reader import JsonReader
 from ray.rllib.offline.json_writer import JsonWriter
-from ray.rllib.offline.output_writer import OutputWriter, NoopOutput
+from ray.rllib.offline.mixed_input import MixedInput
+from ray.rllib.offline.output_writer import NoopOutput, OutputWriter
 from ray.rllib.offline.resource import get_offline_io_resource_bundles
 from ray.rllib.offline.shuffled_input import ShuffledInput
-from ray.rllib.offline.feature_importance import FeatureImportance
-
 
 __all__ = [
     "IOContext",
diff --git a/rllib/offline/d4rl_reader.py b/rllib/offline/d4rl_reader.py
index b9f18634b3d1..2800bf08c04e 100644
--- a/rllib/offline/d4rl_reader.py
+++ b/rllib/offline/d4rl_reader.py
@@ -1,12 +1,13 @@
 import logging
+from typing import Dict
+
 import gymnasium as gym
 
 from ray.rllib.offline.input_reader import InputReader
 from ray.rllib.offline.io_context import IOContext
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.utils.annotations import override, PublicAPI
+from ray.rllib.utils.annotations import PublicAPI, override
 from ray.rllib.utils.typing import SampleBatchType
-from typing import Dict
 
 logger = logging.getLogger(__name__)
diff --git a/rllib/offline/dataset_reader.py b/rllib/offline/dataset_reader.py
index 1172aa7f5d0d..cf5abffc9c40 100644
--- a/rllib/offline/dataset_reader.py
+++ b/rllib/offline/dataset_reader.py
@@ -1,17 +1,18 @@
 import logging
 import math
-from pathlib import Path
 import re
-import numpy as np
-from typing import List, Tuple, TYPE_CHECKING, Optional
 import zipfile
+from pathlib import Path
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+import numpy as np
 
 import ray.data
 from ray.rllib.offline.input_reader import InputReader
 from ray.rllib.offline.io_context import IOContext
 from ray.rllib.offline.json_reader import from_json_data, postprocess_actions
-from ray.rllib.policy.sample_batch import concat_samples, SampleBatch, DEFAULT_POLICY_ID
-from ray.rllib.utils.annotations import override, PublicAPI
+from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch, concat_samples
+from ray.rllib.utils.annotations import PublicAPI, override
 from ray.rllib.utils.typing import SampleBatchType
 
 if TYPE_CHECKING:
diff --git a/rllib/offline/dataset_writer.py b/rllib/offline/dataset_writer.py
index b517933ce985..6b3ac7a15373 100644
--- a/rllib/offline/dataset_writer.py
+++ b/rllib/offline/dataset_writer.py
@@ -1,14 +1,14 @@
 import logging
 import os
 import time
+from typing import Dict, List
 
 from ray import data
 from ray.rllib.offline.io_context import IOContext
 from ray.rllib.offline.json_writer import _to_json_dict
 from ray.rllib.offline.output_writer import OutputWriter
-from ray.rllib.utils.annotations import override, PublicAPI
+from ray.rllib.utils.annotations import PublicAPI, override
 from ray.rllib.utils.typing import SampleBatchType
-from typing import Dict, List
 
 logger = logging.getLogger(__name__)
diff --git a/rllib/offline/estimators/__init__.py b/rllib/offline/estimators/__init__.py
index 74131faf3eb6..f2561b648776 100644
--- a/rllib/offline/estimators/__init__.py
+++ b/rllib/offline/estimators/__init__.py
@@ -1,10 +1,10 @@
+from ray.rllib.offline.estimators.direct_method import DirectMethod
+from ray.rllib.offline.estimators.doubly_robust import DoublyRobust
 from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
+from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
 from ray.rllib.offline.estimators.weighted_importance_sampling import (
     WeightedImportanceSampling,
 )
-from ray.rllib.offline.estimators.direct_method import DirectMethod
-from ray.rllib.offline.estimators.doubly_robust import DoublyRobust
-from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
 
 __all__ = [
     "OffPolicyEstimator",
diff --git a/rllib/offline/estimators/direct_method.py b/rllib/offline/estimators/direct_method.py
index c735b93a5e1b..99e116deae05 100644
--- a/rllib/offline/estimators/direct_method.py
+++ b/rllib/offline/estimators/direct_method.py
@@ -1,20 +1,19 @@
 import logging
-from typing import Dict, Any, Optional, List
 import math
+from typing import Any, Dict, List, Optional
+
 import numpy as np
 
 from ray.data import Dataset
-
+from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
 from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
 from ray.rllib.offline.offline_evaluation_utils import compute_q_and_v_values
 from ray.rllib.offline.offline_evaluator import OfflineEvaluator
-from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
 from ray.rllib.policy import Policy
-from ray.rllib.policy.sample_batch import convert_ma_batch_to_sample_batch
-from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.policy.sample_batch import SampleBatch, convert_ma_batch_to_sample_batch
 from ray.rllib.utils.annotations import DeveloperAPI, override
-from ray.rllib.utils.typing import SampleBatchType
 from ray.rllib.utils.numpy import convert_to_numpy
+from ray.rllib.utils.typing import SampleBatchType
 
 logger = logging.getLogger()
diff --git a/rllib/offline/estimators/doubly_robust.py b/rllib/offline/estimators/doubly_robust.py
index 4341055789b1..3d17ea6c22f1 100644
--- a/rllib/offline/estimators/doubly_robust.py
+++ b/rllib/offline/estimators/doubly_robust.py
@@ -1,25 +1,23 @@
 import logging
-import numpy as np
 import math
-import pandas as pd
+from typing import Any, Dict, List, Optional
 
-from typing import Dict, Any, Optional, List
+import numpy as np
+import pandas as pd
 
 from ray.data import Dataset
-
-from ray.rllib.policy import Policy
-from ray.rllib.policy.sample_batch import SampleBatch, convert_ma_batch_to_sample_batch
-from ray.rllib.utils.annotations import DeveloperAPI, override
-from ray.rllib.utils.typing import SampleBatchType
-from ray.rllib.utils.numpy import convert_to_numpy
-
-from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
 from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
-from ray.rllib.offline.offline_evaluator import OfflineEvaluator
+from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
 from ray.rllib.offline.offline_evaluation_utils import (
     compute_is_weights,
     compute_q_and_v_values,
 )
+from ray.rllib.offline.offline_evaluator import OfflineEvaluator
+from ray.rllib.policy import Policy
+from ray.rllib.policy.sample_batch import SampleBatch, convert_ma_batch_to_sample_batch
+from ray.rllib.utils.annotations import DeveloperAPI, override
+from ray.rllib.utils.numpy import convert_to_numpy
+from ray.rllib.utils.typing import SampleBatchType
 
 logger = logging.getLogger()
diff --git a/rllib/offline/estimators/fqe_torch_model.py b/rllib/offline/estimators/fqe_torch_model.py
index f071640a9afd..b95417e3169b 100644
--- a/rllib/offline/estimators/fqe_torch_model.py
+++ b/rllib/offline/estimators/fqe_torch_model.py
@@ -1,15 +1,15 @@
-from typing import Dict, Any
-from ray.rllib.models.utils import get_initializer
-from ray.rllib.policy import Policy
+from typing import Any, Dict
+
+from gymnasium.spaces import Discrete
 
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
+from ray.rllib.models.utils import get_initializer
+from ray.rllib.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.utils.annotations import DeveloperAPI
+from ray.rllib.utils.annotations import DeveloperAPI, is_overridden
 from ray.rllib.utils.framework import try_import_torch
-from ray.rllib.utils.annotations import is_overridden
 from ray.rllib.utils.typing import ModelConfigDict, TensorType
-from gymnasium.spaces import Discrete
 
 torch, nn = try_import_torch()
diff --git a/rllib/offline/estimators/importance_sampling.py b/rllib/offline/estimators/importance_sampling.py
index 630859820948..0d62a902b9f4 100644
--- a/rllib/offline/estimators/importance_sampling.py
+++ b/rllib/offline/estimators/importance_sampling.py
@@ -1,16 +1,15 @@
-from typing import Dict, List, Any
 import math
+from typing import Any, Dict, List
 
 from ray.data import Dataset
-
-from ray.rllib.utils.annotations import override, DeveloperAPI
-from ray.rllib.offline.offline_evaluator import OfflineEvaluator
+from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
 from ray.rllib.offline.offline_evaluation_utils import (
-    remove_time_dim,
     compute_is_weights,
+    remove_time_dim,
 )
-from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
+from ray.rllib.offline.offline_evaluator import OfflineEvaluator
 from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.utils.annotations import DeveloperAPI, override
 
 
 @DeveloperAPI
diff --git a/rllib/offline/estimators/off_policy_estimator.py b/rllib/offline/estimators/off_policy_estimator.py
index 9abee46c1a12..cc674b303bb4 100644
--- a/rllib/offline/estimators/off_policy_estimator.py
+++ b/rllib/offline/estimators/off_policy_estimator.py
@@ -1,22 +1,22 @@
+import logging
+from typing import Any, Dict, List
+
 import gymnasium as gym
 import numpy as np
 import tree
-from typing import Dict, Any, List
-import logging
 
-from ray.rllib.policy.sample_batch import SampleBatch
+from ray._common.deprecation import Deprecated
+from ray.rllib.offline.offline_evaluator import OfflineEvaluator
 from ray.rllib.policy import Policy
-from ray.rllib.policy.sample_batch import convert_ma_batch_to_sample_batch
-from ray.rllib.utils.policy import compute_log_likelihoods_from_input_dict
+from ray.rllib.policy.sample_batch import SampleBatch, convert_ma_batch_to_sample_batch
 from ray.rllib.utils.annotations import (
     DeveloperAPI,
     ExperimentalAPI,
     OverrideToImplementCustomLogic,
 )
-from ray._common.deprecation import Deprecated
 from ray.rllib.utils.numpy import convert_to_numpy
-from ray.rllib.utils.typing import TensorType, SampleBatchType
-from ray.rllib.offline.offline_evaluator import OfflineEvaluator
+from ray.rllib.utils.policy import compute_log_likelihoods_from_input_dict
+from ray.rllib.utils.typing import SampleBatchType, TensorType
 
 logger = logging.getLogger(__name__)
diff --git a/rllib/offline/estimators/tests/test_dm_learning.py b/rllib/offline/estimators/tests/test_dm_learning.py
index a193e84e89a4..f6760c717fd7 100644
--- a/rllib/offline/estimators/tests/test_dm_learning.py
+++ b/rllib/offline/estimators/tests/test_dm_learning.py
@@ -3,8 +3,8 @@
 import ray
 from ray.rllib.offline.estimators import DirectMethod
 from ray.rllib.offline.estimators.tests.utils import (
-    get_cliff_walking_wall_policy_and_data,
     check_estimate,
+    get_cliff_walking_wall_policy_and_data,
 )
 
 SEED = 0
@@ -197,6 +197,7 @@ def test_dm_expert_policy_expert_data(self):
 
 if __name__ == "__main__":
     import sys
+
     import pytest
 
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/offline/estimators/tests/test_dr_learning.py b/rllib/offline/estimators/tests/test_dr_learning.py
index da79ccdcefa7..8c78b9195a33 100644
--- a/rllib/offline/estimators/tests/test_dr_learning.py
+++ b/rllib/offline/estimators/tests/test_dr_learning.py
@@ -3,8 +3,8 @@
 import ray
 from ray.rllib.offline.estimators import DoublyRobust
 from ray.rllib.offline.estimators.tests.utils import (
-    get_cliff_walking_wall_policy_and_data,
     check_estimate,
+    get_cliff_walking_wall_policy_and_data,
 )
 
 SEED = 0
@@ -197,6 +197,7 @@ def test_dr_expert_policy_expert_data(self):
 
 if __name__ == "__main__":
     import sys
+
     import pytest
 
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/offline/estimators/tests/test_ope.py b/rllib/offline/estimators/tests/test_ope.py
index 51dc6619e881..cbf2a69499c8 100644
--- a/rllib/offline/estimators/tests/test_ope.py
+++ b/rllib/offline/estimators/tests/test_ope.py
@@ -1,20 +1,20 @@
+import copy
+import os
+import unittest
+from pathlib import Path
 from typing import TYPE_CHECKING, Tuple
 
-import copy
 import gymnasium as gym
 import numpy as np
-import os
 import pandas as pd
-from pathlib import Path
-import unittest
 
 import ray
 from ray.data import read_json
 from ray.rllib.algorithms.dqn import DQNConfig
-from ray.rllib.examples.envs.classes.cliff_walking_wall_env import CliffWalkingWallEnv
 from ray.rllib.examples._old_api_stack.policy.cliff_walking_wall_policy import (
     CliffWalkingWallPolicy,
 )
+from ray.rllib.examples.envs.classes.cliff_walking_wall_env import CliffWalkingWallEnv
 from ray.rllib.offline.dataset_reader import DatasetReader
 from ray.rllib.offline.estimators import (
     DirectMethod,
@@ -327,7 +327,8 @@ def test_fqe_optimal_convergence(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/offline/estimators/tests/test_ope_math.py b/rllib/offline/estimators/tests/test_ope_math.py
index 759857e50a69..9b3ec78b9857 100644
--- a/rllib/offline/estimators/tests/test_ope_math.py
+++ b/rllib/offline/estimators/tests/test_ope_math.py
@@ -1,22 +1,22 @@
-import unittest
 import time
+import unittest
+
 import gymnasium as gym
+import numpy as np
 import torch
-import numpy as np
 
+import ray
+from ray.rllib.models.torch.torch_action_dist import TorchCategorical
 from ray.rllib.offline.estimators import (
     DirectMethod,
     DoublyRobust,
     ImportanceSampling,
     WeightedImportanceSampling,
 )
-from ray.rllib.models.torch.torch_action_dist import TorchCategorical
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
 from ray.rllib.utils.test_utils import check
 
-import ray
-
 
 class FakePolicy(TorchPolicyV2):
     """A fake policy used in test ope math to emulate a target policy that is better
@@ -215,7 +215,8 @@ def test_dm_dr_math(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/offline/estimators/tests/utils.py b/rllib/offline/estimators/tests/utils.py
index b7366e8609a3..30d443c7c68f 100644
--- a/rllib/offline/estimators/tests/utils.py
+++ b/rllib/offline/estimators/tests/utils.py
@@ -1,12 +1,13 @@
-from typing import Type, Union, Dict, Tuple
+from typing import Dict, Tuple, Type, Union
 
 import numpy as np
+
 from ray.rllib.algorithms import AlgorithmConfig
 from ray.rllib.env.env_runner_group import EnvRunnerGroup
-from ray.rllib.examples.envs.classes.cliff_walking_wall_env import CliffWalkingWallEnv
 from ray.rllib.examples._old_api_stack.policy.cliff_walking_wall_policy import (
     CliffWalkingWallPolicy,
 )
+from ray.rllib.examples.envs.classes.cliff_walking_wall_env import CliffWalkingWallEnv
 from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
 from ray.rllib.offline.estimators import (
     DirectMethod,
diff --git a/rllib/offline/estimators/weighted_importance_sampling.py b/rllib/offline/estimators/weighted_importance_sampling.py
index cfca393a0212..67d14682a996 100644
--- a/rllib/offline/estimators/weighted_importance_sampling.py
+++ b/rllib/offline/estimators/weighted_importance_sampling.py
@@ -1,18 +1,18 @@
-from typing import Dict, Any, List
-import numpy as np
 import math
+from typing import Any, Dict, List
 
-from ray.data import Dataset
+import numpy as np
 
-from ray.rllib.offline.offline_evaluator import OfflineEvaluator
+from ray.data import Dataset
 from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
 from ray.rllib.offline.offline_evaluation_utils import (
-    remove_time_dim,
     compute_is_weights,
+    remove_time_dim,
 )
-from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.offline.offline_evaluator
import OfflineEvaluator from ray.rllib.policy import Policy -from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import DeveloperAPI, override @DeveloperAPI diff --git a/rllib/offline/feature_importance.py b/rllib/offline/feature_importance.py index 2efe17790a79..bc520f8e5cbe 100644 --- a/rllib/offline/feature_importance.py +++ b/rllib/offline/feature_importance.py @@ -1,16 +1,16 @@ import copy +from typing import Any, Callable, Dict + import numpy as np import pandas as pd -from typing import Callable, Dict, Any import ray from ray.data import Dataset - +from ray.rllib.offline.offline_evaluator import OfflineEvaluator from ray.rllib.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch, convert_ma_batch_to_sample_batch -from ray.rllib.utils.annotations import override, DeveloperAPI, ExperimentalAPI +from ray.rllib.utils.annotations import DeveloperAPI, ExperimentalAPI, override from ray.rllib.utils.typing import SampleBatchType -from ray.rllib.offline.offline_evaluator import OfflineEvaluator @DeveloperAPI diff --git a/rllib/offline/input_reader.py b/rllib/offline/input_reader.py index 042e3783c39d..18f40176072e 100644 --- a/rllib/offline/input_reader.py +++ b/rllib/offline/input_reader.py @@ -1,13 +1,14 @@ -from abc import ABCMeta, abstractmethod import logging -import numpy as np import threading +from abc import ABCMeta, abstractmethod +from typing import Dict, List + +import numpy as np from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils.annotations import PublicAPI from ray.rllib.utils.framework import try_import_tf -from typing import Dict, List -from ray.rllib.utils.typing import TensorType, SampleBatchType +from ray.rllib.utils.typing import SampleBatchType, TensorType tf1, tf, tfv = try_import_tf() diff --git a/rllib/offline/io_context.py b/rllib/offline/io_context.py index 1d0ec1683b93..72576d3730f8 100644 --- a/rllib/offline/io_context.py +++ b/rllib/offline/io_context.py @@ -1,12 +1,12 @@ import os -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from ray.rllib.utils.annotations import PublicAPI if TYPE_CHECKING: from ray.rllib.algorithms.algorithm_config import AlgorithmConfig - from ray.rllib.evaluation.sampler import SamplerInput from ray.rllib.evaluation.rollout_worker import RolloutWorker + from ray.rllib.evaluation.sampler import SamplerInput @PublicAPI diff --git a/rllib/offline/is_estimator.py b/rllib/offline/is_estimator.py index d395e3f9a356..871120ad676b 100644 --- a/rllib/offline/is_estimator.py +++ b/rllib/offline/is_estimator.py @@ -1,5 +1,5 @@ -from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling from ray._common.deprecation import Deprecated +from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling @Deprecated( diff --git a/rllib/offline/json_reader.py b/rllib/offline/json_reader.py index 30562b515aac..076791716d82 100644 --- a/rllib/offline/json_reader.py +++ b/rllib/offline/json_reader.py @@ -2,16 +2,16 @@ import json import logging import math - -import numpy as np import os -from pathlib import Path import random import re -import tree # pip install dm_tree -from typing import List, Optional, TYPE_CHECKING, Union -from urllib.parse import urlparse import zipfile +from pathlib import Path +from typing import TYPE_CHECKING, List, Optional, Union +from urllib.parse import urlparse + +import numpy as np +import tree # pip 
install dm_tree try: from smart_open import smart_open @@ -28,7 +28,7 @@ concat_samples, convert_ma_batch_to_sample_batch, ) -from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI +from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI, override from ray.rllib.utils.compression import unpack_if_needed from ray.rllib.utils.spaces.space_utils import clip_action, normalize_action from ray.rllib.utils.typing import Any, FileType, SampleBatchType diff --git a/rllib/offline/json_writer.py b/rllib/offline/json_writer.py index 4e15bfb2e550..da7b49e5b17b 100644 --- a/rllib/offline/json_writer.py +++ b/rllib/offline/json_writer.py @@ -1,24 +1,26 @@ -from datetime import datetime import json import logging -import numpy as np import os -from urllib.parse import urlparse import time +from datetime import datetime +from urllib.parse import urlparse + +import numpy as np try: from smart_open import smart_open except ImportError: smart_open = None +from typing import Any, Dict, List + from ray.air._internal.json import SafeFallbackEncoder -from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.offline.io_context import IOContext from ray.rllib.offline.output_writer import OutputWriter -from ray.rllib.utils.annotations import override, PublicAPI -from ray.rllib.utils.compression import pack, compression_supported +from ray.rllib.policy.sample_batch import MultiAgentBatch +from ray.rllib.utils.annotations import PublicAPI, override +from ray.rllib.utils.compression import compression_supported, pack from ray.rllib.utils.typing import FileType, SampleBatchType -from typing import Any, Dict, List logger = logging.getLogger(__name__) diff --git a/rllib/offline/mixed_input.py b/rllib/offline/mixed_input.py index 8c8ad60b06f9..171a7a2c85c4 100644 --- a/rllib/offline/mixed_input.py +++ b/rllib/offline/mixed_input.py @@ -2,12 +2,13 @@ from typing import Dict import numpy as np + from ray.rllib.offline.input_reader import InputReader from ray.rllib.offline.io_context import IOContext from ray.rllib.offline.json_reader import JsonReader -from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.utils.annotations import DeveloperAPI, override from ray.rllib.utils.typing import SampleBatchType -from ray.tune.registry import registry_get_input, registry_contains_input +from ray.tune.registry import registry_contains_input, registry_get_input @DeveloperAPI diff --git a/rllib/offline/off_policy_estimator.py b/rllib/offline/off_policy_estimator.py index 9d2b90195a57..71d3a80f0148 100644 --- a/rllib/offline/off_policy_estimator.py +++ b/rllib/offline/off_policy_estimator.py @@ -1,7 +1,7 @@ +from ray._common.deprecation import deprecation_warning from ray.rllib.offline.estimators.off_policy_estimator import ( # noqa: F401 OffPolicyEstimator, ) -from ray._common.deprecation import deprecation_warning deprecation_warning( old="ray.rllib.offline.off_policy_estimator", diff --git a/rllib/offline/offline_data.py b/rllib/offline/offline_data.py index f48346ed0cc7..64c55a6c33d9 100644 --- a/rllib/offline/offline_data.py +++ b/rllib/offline/offline_data.py @@ -1,19 +1,18 @@ import logging -from pathlib import Path -import pyarrow.fs -import numpy as np -import ray import time import types +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict -from typing import Any, Dict, TYPE_CHECKING +import numpy as np +import pyarrow.fs +import ray from ray.rllib.core import COMPONENT_RL_MODULE from ray.rllib.env import INPUT_ENV_SPACES from 
ray.rllib.offline.offline_prelearner import OfflinePreLearner -from ray.rllib.utils import unflatten_dict from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch -from ray.rllib.utils import force_list +from ray.rllib.utils import force_list, unflatten_dict from ray.rllib.utils.annotations import ( OverrideToImplementCustomLogic, OverrideToImplementCustomLogic_CallToSuperRecommended, diff --git a/rllib/offline/offline_env_runner.py b/rllib/offline/offline_env_runner.py index 704ed175be53..e58f2ef99b51 100644 --- a/rllib/offline/offline_env_runner.py +++ b/rllib/offline/offline_env_runner.py @@ -1,24 +1,23 @@ import logging -import ray - from pathlib import Path from typing import List +import ray from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.core.columns import Columns from ray.rllib.env.env_runner import EnvRunner from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.utils.annotations import ( - override, - OverrideToImplementCustomLogic_CallToSuperRecommended, OverrideToImplementCustomLogic, + OverrideToImplementCustomLogic_CallToSuperRecommended, + override, ) from ray.rllib.utils.compression import pack_if_needed from ray.rllib.utils.spaces.space_utils import to_jsonable_if_needed from ray.rllib.utils.typing import EpisodeType -from ray.util.debug import log_once from ray.util.annotations import PublicAPI +from ray.util.debug import log_once logger = logging.Logger(__file__) diff --git a/rllib/offline/offline_evaluation_runner.py b/rllib/offline/offline_evaluation_runner.py index 8578256c560d..c978b6388ff0 100644 --- a/rllib/offline/offline_evaluation_runner.py +++ b/rllib/offline/offline_evaluation_runner.py @@ -1,8 +1,7 @@ -import ray import types +from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Union -from typing import Any, Collection, Dict, Iterable, Optional, TYPE_CHECKING, Union - +import ray from ray.data.iterator import DataIterator from ray.rllib.core import ( ALL_MODULES, diff --git a/rllib/offline/offline_evaluation_runner_group.py b/rllib/offline/offline_evaluation_runner_group.py index 2f92e2d1aabf..470d762211c1 100644 --- a/rllib/offline/offline_evaluation_runner_group.py +++ b/rllib/offline/offline_evaluation_runner_group.py @@ -1,6 +1,6 @@ -import ray -from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +import ray from ray.data.iterator import DataIterator from ray.rllib.core import DEFAULT_MODULE_ID from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec diff --git a/rllib/offline/offline_evaluation_utils.py b/rllib/offline/offline_evaluation_utils.py index de39f149f695..ac7eb8c3f728 100644 --- a/rllib/offline/offline_evaluation_utils.py +++ b/rllib/offline/offline_evaluation_utils.py @@ -1,11 +1,12 @@ +from typing import TYPE_CHECKING, Any, Dict, Type + import numpy as np import pandas as pd -from typing import Any, Dict, Type, TYPE_CHECKING -from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy import Policy -from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.numpy import convert_to_numpy if TYPE_CHECKING: from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel diff --git a/rllib/offline/offline_evaluator.py 
b/rllib/offline/offline_evaluator.py index 60b87ff1296d..277c514e4582 100644 --- a/rllib/offline/offline_evaluator.py +++ b/rllib/offline/offline_evaluator.py @@ -1,10 +1,9 @@ import abc -import os import logging -from typing import Dict, Any +import os +from typing import Any, Dict from ray.data import Dataset - from ray.rllib.policy import Policy from ray.rllib.utils.annotations import DeveloperAPI, ExperimentalAPI from ray.rllib.utils.typing import SampleBatchType diff --git a/rllib/offline/offline_policy_evaluation_runner.py b/rllib/offline/offline_policy_evaluation_runner.py index 2e2f63e9d7ff..ea6dda004bae 100644 --- a/rllib/offline/offline_policy_evaluation_runner.py +++ b/rllib/offline/offline_policy_evaluation_runner.py @@ -1,32 +1,32 @@ -import gymnasium as gym import math -import numpy -import ray - from enum import Enum from typing import ( + TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, - TYPE_CHECKING, Union, ) +import gymnasium as gym +import numpy + +import ray from ray.data.iterator import DataIterator from ray.rllib.connectors.env_to_module import EnvToModulePipeline from ray.rllib.core import ( ALL_MODULES, - DEFAULT_AGENT_ID, - DEFAULT_MODULE_ID, COMPONENT_ENV_TO_MODULE_CONNECTOR, COMPONENT_RL_MODULE, + DEFAULT_AGENT_ID, + DEFAULT_MODULE_ID, ) from ray.rllib.core.columns import Columns from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec from ray.rllib.env.single_agent_episode import SingleAgentEpisode -from ray.rllib.offline.offline_prelearner import OfflinePreLearner, SCHEMA +from ray.rllib.offline.offline_prelearner import SCHEMA, OfflinePreLearner from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.checkpoints import Checkpointable diff --git a/rllib/offline/offline_prelearner.py b/rllib/offline/offline_prelearner.py index 782f5c0f0fdc..0801c68b04f5 100644 --- a/rllib/offline/offline_prelearner.py +++ b/rllib/offline/offline_prelearner.py @@ -1,14 +1,14 @@ import copy -import gymnasium as gym import logging -import numpy as np -import tree import uuid +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union -from typing import Any, Dict, List, Optional, Union, Set, Tuple, TYPE_CHECKING +import gymnasium as gym +import numpy as np +import tree from ray.rllib.core.columns import Columns -from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec, MultiRLModule +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule, MultiRLModuleSpec from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.utils import flatten_dict from ray.rllib.utils.annotations import ( diff --git a/rllib/offline/output_writer.py b/rllib/offline/output_writer.py index ca26c5a538fa..b278ca90ae91 100644 --- a/rllib/offline/output_writer.py +++ b/rllib/offline/output_writer.py @@ -1,4 +1,4 @@ -from ray.rllib.utils.annotations import override, PublicAPI +from ray.rllib.utils.annotations import PublicAPI, override from ray.rllib.utils.typing import SampleBatchType diff --git a/rllib/offline/resource.py b/rllib/offline/resource.py index e658b9b682bc..ff01b2fb7b89 100644 --- a/rllib/offline/resource.py +++ b/rllib/offline/resource.py @@ -1,4 +1,5 @@ -from typing import Dict, List, TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, List + from ray.rllib.utils.annotations import PublicAPI if TYPE_CHECKING: diff --git a/rllib/offline/shuffled_input.py b/rllib/offline/shuffled_input.py index a7c261018594..a6633fa76e5c 100644 --- 
a/rllib/offline/shuffled_input.py +++ b/rllib/offline/shuffled_input.py @@ -2,7 +2,7 @@ import random from ray.rllib.offline.input_reader import InputReader -from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.utils.annotations import DeveloperAPI, override from ray.rllib.utils.typing import SampleBatchType logger = logging.getLogger(__name__) diff --git a/rllib/offline/tests/test_dataset_reader.py b/rllib/offline/tests/test_dataset_reader.py index b8825c49a307..0128b3b67e98 100644 --- a/rllib/offline/tests/test_dataset_reader.py +++ b/rllib/offline/tests/test_dataset_reader.py @@ -1,17 +1,17 @@ -import tempfile import os -from pathlib import Path +import tempfile import unittest -import pytest +from pathlib import Path +import pytest import ray from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.offline import IOContext from ray.rllib.offline.dataset_reader import ( DatasetReader, - get_dataset_and_shards, _unzip_if_needed, + get_dataset_and_shards, ) diff --git a/rllib/offline/tests/test_feature_importance.py b/rllib/offline/tests/test_feature_importance.py index c19953aa4403..af626bc88d60 100644 --- a/rllib/offline/tests/test_feature_importance.py +++ b/rllib/offline/tests/test_feature_importance.py @@ -1,6 +1,6 @@ import unittest -import ray +import ray from ray.rllib.algorithms.marwil import MARWILConfig from ray.rllib.execution import synchronous_parallel_sample from ray.rllib.offline.feature_importance import FeatureImportance @@ -41,7 +41,8 @@ def test_feat_importance_estimate_on_dataset(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/offline/tests/test_offline_data.py b/rllib/offline/tests/test_offline_data.py index f872ffeeef88..087f5bb1b132 100644 --- a/rllib/offline/tests/test_offline_data.py +++ b/rllib/offline/tests/test_offline_data.py @@ -1,10 +1,10 @@ -import gymnasium as gym -import ray import shutil import unittest - from pathlib import Path +import gymnasium as gym + +import ray from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.algorithms.bc import BCConfig from ray.rllib.core.columns import Columns @@ -260,6 +260,7 @@ def __init__(self, config: AlgorithmConfig): if __name__ == "__main__": import sys + import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/offline/tests/test_offline_env_runner.py b/rllib/offline/tests/test_offline_env_runner.py index 6ec7a2b60b57..fedcc1ae5a39 100644 --- a/rllib/offline/tests/test_offline_env_runner.py +++ b/rllib/offline/tests/test_offline_env_runner.py @@ -1,9 +1,10 @@ -import msgpack -import msgpack_numpy as m import pathlib import shutil import unittest +import msgpack +import msgpack_numpy as m + import ray from ray.rllib.algorithms.ppo.ppo import PPOConfig from ray.rllib.core.columns import Columns @@ -205,6 +206,7 @@ def test_offline_env_runner_compress_columns(self): if __name__ == "__main__": import sys + import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/offline/tests/test_offline_evaluation_runner.py b/rllib/offline/tests/test_offline_evaluation_runner.py index e5397b9eda09..58f6e4e9b292 100644 --- a/rllib/offline/tests/test_offline_evaluation_runner.py +++ b/rllib/offline/tests/test_offline_evaluation_runner.py @@ -1,15 +1,15 @@ import unittest -import gymnasium as gym - from pathlib import Path -from typing import Any, Dict, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict + +import gymnasium as gym 
from ray.rllib.algorithms.bc.bc import BCConfig from ray.rllib.core import ALL_MODULES, DEFAULT_MODULE_ID from ray.rllib.core.columns import Columns from ray.rllib.offline.offline_evaluation_runner import ( - OfflineEvaluationRunner, TOTAL_EVAL_LOSS_KEY, + OfflineEvaluationRunner, ) from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED from ray.rllib.utils.typing import ModuleID, ResultDict, TensorType @@ -198,6 +198,7 @@ def _compute_loss_for_module( if __name__ == "__main__": import sys + import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/offline/tests/test_offline_evaluation_runner_group.py b/rllib/offline/tests/test_offline_evaluation_runner_group.py index fe402261a40b..8a1772069cee 100644 --- a/rllib/offline/tests/test_offline_evaluation_runner_group.py +++ b/rllib/offline/tests/test_offline_evaluation_runner_group.py @@ -1,10 +1,10 @@ -import gymnasium as gym -import ray import sys import unittest - from pathlib import Path +import gymnasium as gym + +import ray from ray.rllib.algorithms.bc.bc import BCConfig from ray.rllib.offline.offline_evaluation_runner_group import ( OfflineEvaluationRunnerGroup, @@ -122,7 +122,7 @@ def test_offline_evaluation_runner_group_run(self): self.assertIsInstance(metrics, list) self.assertEqual(len(metrics), offline_runner_group.num_runners) # Ensure that the `eval_total_loss_key` is part of the runner metrics. - from ray.rllib.core import DEFAULT_MODULE_ID, ALL_MODULES + from ray.rllib.core import ALL_MODULES, DEFAULT_MODULE_ID from ray.rllib.offline.offline_evaluation_runner import TOTAL_EVAL_LOSS_KEY from ray.rllib.utils.metrics import ( NUM_ENV_STEPS_SAMPLED, @@ -187,6 +187,7 @@ def test_offline_evaluation_runner_group_with_local_runner(self): if __name__ == "__main__": import sys + import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/offline/tests/test_offline_prelearner.py b/rllib/offline/tests/test_offline_prelearner.py index 1123a4ee9d74..1365625d286a 100644 --- a/rllib/offline/tests/test_offline_prelearner.py +++ b/rllib/offline/tests/test_offline_prelearner.py @@ -1,14 +1,14 @@ import functools -import gymnasium as gym -import ray import shutil import unittest - from pathlib import Path +import gymnasium as gym + +import ray from ray.rllib.algorithms.bc import BCConfig from ray.rllib.algorithms.ppo import PPOConfig -from ray.rllib.core import Columns, COMPONENT_RL_MODULE +from ray.rllib.core import COMPONENT_RL_MODULE, Columns from ray.rllib.env import INPUT_ENV_SPACES from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.offline.offline_prelearner import OfflinePreLearner @@ -313,6 +313,7 @@ def test_offline_prelearner_sample_from_episode_data(self): if __name__ == "__main__": import sys + import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/offline/wis_estimator.py b/rllib/offline/wis_estimator.py index 95c7e3bcec09..d207d7e90428 100644 --- a/rllib/offline/wis_estimator.py +++ b/rllib/offline/wis_estimator.py @@ -1,7 +1,7 @@ +from ray._common.deprecation import Deprecated from ray.rllib.offline.estimators.weighted_importance_sampling import ( WeightedImportanceSampling, ) -from ray._common.deprecation import Deprecated @Deprecated( diff --git a/rllib/tests/conftest.py b/rllib/tests/conftest.py index 6cf458bcb91e..4f40638d36fd 100644 --- a/rllib/tests/conftest.py +++ b/rllib/tests/conftest.py @@ -1,4 +1,5 @@ -from ray.tests.conftest import ray_start_regular_shared # noqa: F401 - # Trigger pytest hook to automatically zip test cluster logs to archive 
dir on failure -from ray.tests.conftest import pytest_runtest_makereport # noqa +from ray.tests.conftest import ( + pytest_runtest_makereport, # noqa + ray_start_regular_shared, # noqa: F401 +) diff --git a/rllib/tests/run_regression_tests.py b/rllib/tests/run_regression_tests.py index 8fc62da78c23..49a61942ad06 100644 --- a/rllib/tests/run_regression_tests.py +++ b/rllib/tests/run_regression_tests.py @@ -6,17 +6,18 @@ import importlib import json import os -from pathlib import Path -import sys import re +import sys import uuid +from pathlib import Path + import yaml import ray from ray import air +from ray._common.deprecation import deprecation_warning from ray.air.integrations.wandb import WandbLoggerCallback from ray.rllib import _register_all -from ray._common.deprecation import deprecation_warning from ray.rllib.utils.metrics import ( ENV_RUNNER_RESULTS, EPISODE_RETURN_MEAN, diff --git a/rllib/tests/test_catalog.py b/rllib/tests/test_catalog.py index 119d1d9614e3..b7f3855707e3 100644 --- a/rllib/tests/test_catalog.py +++ b/rllib/tests/test_catalog.py @@ -1,10 +1,11 @@ +import unittest from functools import partial -from gymnasium.spaces import Box, Dict, Discrete, Tuple + import numpy as np -import unittest +from gymnasium.spaces import Box, Dict, Discrete, Tuple import ray -from ray.rllib.models import ActionDistribution, ModelCatalog, MODEL_DEFAULTS +from ray.rllib.models import MODEL_DEFAULTS, ActionDistribution, ModelCatalog from ray.rllib.models.preprocessors import ( Preprocessor, TupleFlatteningPreprocessor, @@ -259,7 +260,8 @@ class Model: if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/tests/test_dependency_torch.py b/rllib/tests/test_dependency_torch.py index 235ee833d701..bcd720a6c7aa 100755 --- a/rllib/tests/test_dependency_torch.py +++ b/rllib/tests/test_dependency_torch.py @@ -3,7 +3,6 @@ import os import sys - if __name__ == "__main__": # Do not import torch for testing purposes. 
os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1" diff --git a/rllib/tests/test_local.py b/rllib/tests/test_local.py index 3ace9d11080a..42606d3d3745 100644 --- a/rllib/tests/test_local.py +++ b/rllib/tests/test_local.py @@ -29,7 +29,8 @@ def test_local(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/tests/test_lstm.py b/rllib/tests/test_lstm.py index 66ceda0b4f2b..c25b85ced65b 100644 --- a/rllib/tests/test_lstm.py +++ b/rllib/tests/test_lstm.py @@ -1,6 +1,7 @@ -import numpy as np import unittest +import numpy as np + from ray.rllib.policy.rnn_sequencing import chop_into_sequences from ray.rllib.utils.test_utils import check @@ -159,7 +160,8 @@ def test_dynamic_max_len(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/tests/test_nn_framework_import_errors.py b/rllib/tests/test_nn_framework_import_errors.py index d117bf0f385d..85fcb93a55a9 100644 --- a/rllib/tests/test_nn_framework_import_errors.py +++ b/rllib/tests/test_nn_framework_import_errors.py @@ -1,5 +1,6 @@ #!/usr/bin/env python import os + import pytest import ray.rllib.algorithms.ppo as ppo diff --git a/rllib/tests/test_pettingzoo_env.py b/rllib/tests/test_pettingzoo_env.py index e42d18b77f5c..556383ac2c64 100644 --- a/rllib/tests/test_pettingzoo_env.py +++ b/rllib/tests/test_pettingzoo_env.py @@ -1,3 +1,5 @@ +import unittest + from numpy import float32 from pettingzoo.butterfly import pistonball_v6 from pettingzoo.mpe import simple_spread_v3 @@ -10,8 +12,6 @@ ) from supersuit.utils.convert_box import convert_box -import unittest - import ray from ray.rllib.algorithms.ppo import PPOConfig from ray.rllib.env import PettingZooEnv @@ -110,7 +110,8 @@ def test_pettingzoo_env(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/tests/test_placement_groups.py b/rllib/tests/test_placement_groups.py index 789e606f5eff..52dbe2d697df 100644 --- a/rllib/tests/test_placement_groups.py +++ b/rllib/tests/test_placement_groups.py @@ -5,8 +5,8 @@ from ray import tune from ray.rllib.algorithms.ppo import PPO, PPOConfig from ray.tune import Callback -from ray.tune.experiment import Trial from ray.tune.execution.placement_groups import PlacementGroupFactory +from ray.tune.experiment import Trial from ray.tune.result import TRAINING_ITERATION trial_executor = None @@ -126,7 +126,8 @@ def test_default_resource_request_plus_manual_leads_to_error(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/tests/test_ray_client.py b/rllib/tests/test_ray_client.py index f6c83165ff7d..0b9337daf69f 100644 --- a/rllib/tests/test_ray_client.py +++ b/rllib/tests/test_ray_client.py @@ -2,9 +2,9 @@ import pytest +from ray._private.client_mode_hook import client_mode_should_convert, enable_client_mode from ray.rllib.algorithms import dqn from ray.util.client.ray_client_helpers import ray_start_client_server -from ray._private.client_mode_hook import enable_client_mode, client_mode_should_convert def test_basic_dqn(): diff --git a/rllib/tests/test_telemetry.py b/rllib/tests/test_telemetry.py index 14712dca1a94..bcd2e48461bf 100644 --- a/rllib/tests/test_telemetry.py +++ b/rllib/tests/test_telemetry.py @@ -4,8 +4,7 @@ import ray import ray._common.usage.usage_lib as ray_usage_lib - -from ray._common.test_utils import 
check_library_usage_telemetry, TelemetryCallsite
+from ray._common.test_utils import TelemetryCallsite, check_library_usage_telemetry


 @pytest.fixture
diff --git a/rllib/tests/test_timesteps.py b/rllib/tests/test_timesteps.py
index f0a081c57246..07ea0ed8d0f3 100644
--- a/rllib/tests/test_timesteps.py
+++ b/rllib/tests/test_timesteps.py
@@ -1,6 +1,7 @@
-import numpy as np
 import unittest

+import numpy as np
+
 import ray
 import ray.rllib.algorithms.ppo as ppo
 from ray.rllib.examples.envs.classes.random_env import RandomEnv
@@ -61,7 +62,8 @@ def test_timesteps(self):


 if __name__ == "__main__":
-    import pytest
     import sys

+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))

From a6af7a4be6e8c80e57801cbaecef3c985c126cba Mon Sep 17 00:00:00 2001
From: Gagandeep Singh
Date: Wed, 1 Oct 2025 11:16:52 +0530
Subject: [PATCH 4/9] Apply ruff_imports

Signed-off-by: Gagandeep Singh
---
 rllib/env/env_runner_group.py                | 7 -------
 rllib/env/tests/test_env_runner_group.py     | 2 +-
 rllib/env/tests/test_single_agent_episode.py | 4 ++--
 3 files changed, 3 insertions(+), 10 deletions(-)

diff --git a/rllib/env/env_runner_group.py b/rllib/env/env_runner_group.py
index 63e5a5070474..b817bfcdf54e 100644
--- a/rllib/env/env_runner_group.py
+++ b/rllib/env/env_runner_group.py
@@ -1,5 +1,3 @@
-import gymnasium as gym
-import logging
 import importlib.util
 import logging
 import os
@@ -22,7 +20,6 @@
 import ray
 from ray._common.deprecation import (
     DEPRECATED_VALUE,
-    Deprecated,
     deprecation_warning,
 )
 from ray.actor import ActorHandle
@@ -44,10 +41,6 @@
 from ray.rllib.policy.policy import Policy, PolicyState
 from ray.rllib.utils.actor_manager import FaultTolerantActorManager
 from ray.rllib.utils.annotations import OldAPIStack
-from ray._common.deprecation import (
-    deprecation_warning,
-    DEPRECATED_VALUE,
-)
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED_LIFETIME, WEIGHTS_SEQ_NO
 from ray.rllib.utils.typing import (
diff --git a/rllib/env/tests/test_env_runner_group.py b/rllib/env/tests/test_env_runner_group.py
index d33ad05269ae..b79cde6a3dba 100644
--- a/rllib/env/tests/test_env_runner_group.py
+++ b/rllib/env/tests/test_env_runner_group.py
@@ -1,7 +1,7 @@
+import time
 import unittest

 import ray
-import time
 from ray.rllib.algorithms.ppo import PPOConfig
 from ray.rllib.core.rl_module.rl_module import RLModule
 from ray.rllib.env.env_runner_group import EnvRunnerGroup
diff --git a/rllib/env/tests/test_single_agent_episode.py b/rllib/env/tests/test_single_agent_episode.py
index 04f379c1b858..8c4ec890c0d3 100644
--- a/rllib/env/tests/test_single_agent_episode.py
+++ b/rllib/env/tests/test_single_agent_episode.py
@@ -1,8 +1,8 @@
+import copy
 import unittest
 from collections import defaultdict
 from typing import Any, Dict, Optional, SupportsFloat, Tuple
-import unittest
-import copy
+
 import gymnasium as gym
 import numpy as np
 from gymnasium.core import ActType, ObsType

From 9e3e32fa0045ad52059a1cf20ee1aaed2229c469 Mon Sep 17 00:00:00 2001
From: Mark Towers
Date: Fri, 31 Oct 2025 10:26:00 +0000
Subject: [PATCH 5/9] run pre-commit

Signed-off-by: Mark Towers
---
 rllib/core/learner/learner_group.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/rllib/core/learner/learner_group.py b/rllib/core/learner/learner_group.py
index 9bebc0409fd1..5a06f79071f3 100644
--- a/rllib/core/learner/learner_group.py
+++ b/rllib/core/learner/learner_group.py
@@ -3,6 +3,7 @@
 import pathlib
 from functools import partial
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Collection,
@@ -11,7 +12,6 @@
     Optional,
     Set,
     Type,
-    TYPE_CHECKING,
     Union,
 )

@@ -36,8 +36,8 @@
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.checkpoints import Checkpointable
 from ray.rllib.utils.metrics.ray_metrics import (
-    TimerAndPrometheusLogger,
     DEFAULT_HISTOGRAM_BOUNDARIES_SHORT_EVENTS,
+    TimerAndPrometheusLogger,
 )
 from ray.rllib.utils.typing import (
     EpisodeType,

From 541bc3b08886d2eac7d6eb6e4262c80be867d8a1 Mon Sep 17 00:00:00 2001
From: Mark Towers
Date: Fri, 31 Oct 2025 10:29:59 +0000
Subject: [PATCH 6/9] run pre-commit

Signed-off-by: Mark Towers
---
 rllib/env/env_runner.py           | 4 ++--
 rllib/env/single_agent_episode.py | 8 ++++----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/rllib/env/env_runner.py b/rllib/env/env_runner.py
index 840c25a77e84..c407f02c42cb 100644
--- a/rllib/env/env_runner.py
+++ b/rllib/env/env_runner.py
@@ -1,6 +1,6 @@
 import abc
 import logging
-from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

 import gymnasium as gym
 import tree  # pip install dm_tree
@@ -15,7 +15,7 @@
 from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
 from ray.rllib.utils.torch_utils import convert_to_torch_tensor
 from ray.rllib.utils.typing import StateDict, TensorType
-from ray.util.annotations import PublicAPI, DeveloperAPI
+from ray.util.annotations import DeveloperAPI, PublicAPI
 from ray.util.metrics import Counter

 if TYPE_CHECKING:
diff --git a/rllib/env/single_agent_episode.py b/rllib/env/single_agent_episode.py
index 571a37ebe462..4e99090a03b7 100644
--- a/rllib/env/single_agent_episode.py
+++ b/rllib/env/single_agent_episode.py
@@ -1,20 +1,20 @@
-from collections import defaultdict
 import copy
 import functools
-import numpy as np
 import time
 import uuid
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, SupportsFloat, Union

 import gymnasium as gym
+import numpy as np
 import tree
 from gymnasium.core import ActType, ObsType
-from typing import Any, Dict, List, Optional, SupportsFloat, Union

+from ray._common.deprecation import Deprecated
 from ray.rllib.core.columns import Columns
 from ray.rllib.env.utils.infinite_lookback_buffer import InfiniteLookbackBuffer
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.serialization import gym_space_from_dict, gym_space_to_dict
-from ray._common.deprecation import Deprecated
 from ray.rllib.utils.typing import AgentID, ModuleID
 from ray.util.annotations import PublicAPI

From f30f2ec943d09bca1765eacf8117bb975003f40a Mon Sep 17 00:00:00 2001
From: Mark Towers
Date: Fri, 31 Oct 2025 10:33:14 +0000
Subject: [PATCH 7/9] run pre-commit

Signed-off-by: Mark Towers
---
 rllib/algorithms/algorithm.py             | 25 +++++++++++------------
 rllib/algorithms/algorithm_config.py      |  2 +-
 rllib/algorithms/cql/cql.py               |  1 +
 rllib/algorithms/dreamerv3/dreamerv3.py   |  2 +-
 rllib/algorithms/impala/impala.py         |  4 ++--
 rllib/algorithms/impala/impala_learner.py |  2 +-
 rllib/algorithms/marwil/marwil.py         |  1 +
 rllib/algorithms/sac/sac.py               |  1 +
 8 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py
index 701d21bd87c7..3583df0d31e4 100644
--- a/rllib/algorithms/algorithm.py
+++ b/rllib/algorithms/algorithm.py
@@ -13,6 +13,7 @@
 from collections import defaultdict
 from datetime import datetime
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Collection,
@@ -23,7 +24,6 @@
     Set,
     Tuple,
     Type,
-    TYPE_CHECKING,
     Union,
 )

@@ -47,9 +47,9 @@
 from ray.rllib.algorithms.utils import (
     AggregatorActor,
     _get_env_runner_bundles,
-    _get_offline_eval_runner_bundles,
     _get_learner_bundles,
     _get_main_process_bundle,
+    _get_offline_eval_runner_bundles,
 )
 from ray.rllib.callbacks.utils import make_callback
 from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
@@ -84,30 +84,30 @@
 from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
 from ray.rllib.offline import get_dataset_and_shards
 from ray.rllib.offline.estimators import (
-    OffPolicyEstimator,
-    ImportanceSampling,
-    WeightedImportanceSampling,
     DirectMethod,
     DoublyRobust,
+    ImportanceSampling,
+    OffPolicyEstimator,
+    WeightedImportanceSampling,
 )
 from ray.rllib.offline.offline_evaluator import OfflineEvaluator
 from ray.rllib.policy.policy import Policy, PolicySpec
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
-from ray.rllib.utils import deep_update, FilterManager, force_list
+from ray.rllib.utils import FilterManager, deep_update, force_list
 from ray.rllib.utils.actor_manager import FaultTolerantActorManager
 from ray.rllib.utils.annotations import (
     DeveloperAPI,
     ExperimentalAPI,
     OldAPIStack,
-    override,
     OverrideToImplementCustomLogic,
     OverrideToImplementCustomLogic_CallToSuperRecommended,
     PublicAPI,
+    override,
 )
 from ray.rllib.utils.checkpoints import (
-    Checkpointable,
     CHECKPOINT_VERSION,
     CHECKPOINT_VERSION_LEARNER_AND_ENV_RUNNER,
+    Checkpointable,
     get_checkpoint_info,
     try_import_msgpack,
 )
@@ -134,9 +134,9 @@
     NUM_AGENT_STEPS_TRAINED,
     NUM_AGENT_STEPS_TRAINED_LIFETIME,
     NUM_ENV_STEPS_SAMPLED,
+    NUM_ENV_STEPS_SAMPLED_FOR_EVALUATION_THIS_ITER,
     NUM_ENV_STEPS_SAMPLED_LIFETIME,
     NUM_ENV_STEPS_SAMPLED_THIS_ITER,
-    NUM_ENV_STEPS_SAMPLED_FOR_EVALUATION_THIS_ITER,
     NUM_ENV_STEPS_TRAINED,
     NUM_ENV_STEPS_TRAINED_LIFETIME,
     NUM_EPISODES,
@@ -147,13 +147,13 @@
     RESTORE_ENV_RUNNERS_TIMER,
     RESTORE_EVAL_ENV_RUNNERS_TIMER,
     RESTORE_OFFLINE_EVAL_RUNNERS_TIMER,
+    STEPS_TRAINED_THIS_ITER_COUNTER,
     SYNCH_ENV_CONNECTOR_STATES_TIMER,
     SYNCH_EVAL_ENV_CONNECTOR_STATES_TIMER,
     SYNCH_WORKER_WEIGHTS_TIMER,
     TIMERS,
     TRAINING_ITERATION_TIMER,
     TRAINING_STEP_TIMER,
-    STEPS_TRAINED_THIS_ITER_COUNTER,
 )
 from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
 from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
@@ -164,7 +164,7 @@
 )
 from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer, ReplayBuffer
 from ray.rllib.utils.runners.runner_group import RunnerGroup
-from ray.rllib.utils.serialization import deserialize_type, NOT_SERIALIZABLE
+from ray.rllib.utils.serialization import NOT_SERIALIZABLE, deserialize_type
 from ray.rllib.utils.spaces import space_utils
 from ray.rllib.utils.typing import (
     AgentConnectorDataType,
@@ -191,8 +191,7 @@
 from ray.tune.execution.placement_groups import PlacementGroupFactory
 from ray.tune.experiment.trial import ExportFormat
 from ray.tune.logger import Logger, UnifiedLogger
-from ray.tune.registry import ENV_CREATOR, _global_registry
-from ray.tune.registry import get_trainable_cls
+from ray.tune.registry import ENV_CREATOR, _global_registry, get_trainable_cls
 from ray.tune.resources import Resources
 from ray.tune.result import TRAINING_ITERATION
 from ray.tune.trainable import Trainable
diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py
index 781763d4c59e..3953c2c4dfe7 100644
--- a/rllib/algorithms/algorithm_config.py
+++ b/rllib/algorithms/algorithm_config.py
@@ -16,11 +16,11 @@
     Type,
     Union,
 )
-from typing_extensions import Self

 import gymnasium as gym
 import tree
 from packaging import version
+from typing_extensions import Self

 import ray
 from ray._common.deprecation import (
diff --git a/rllib/algorithms/cql/cql.py b/rllib/algorithms/cql/cql.py
index 1b2e2abe46ad..681f5210c6dc 100644
--- a/rllib/algorithms/cql/cql.py
+++ b/rllib/algorithms/cql/cql.py
@@ -1,5 +1,6 @@
 import logging
 from typing import Optional, Type, Union
+
 from typing_extensions import Self

 from ray._common.deprecation import (
diff --git a/rllib/algorithms/dreamerv3/dreamerv3.py b/rllib/algorithms/dreamerv3/dreamerv3.py
index 5c83c7b787bb..935f7a53a738 100644
--- a/rllib/algorithms/dreamerv3/dreamerv3.py
+++ b/rllib/algorithms/dreamerv3/dreamerv3.py
@@ -10,9 +10,9 @@
 import logging
 from typing import Any, Dict, Optional, Union

-from typing_extensions import Self
 import gymnasium as gym
+from typing_extensions import Self

 from ray.rllib.algorithms.algorithm import Algorithm
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
diff --git a/rllib/algorithms/impala/impala.py b/rllib/algorithms/impala/impala.py
index dc3f14013573..ce0f3d8555ce 100644
--- a/rllib/algorithms/impala/impala.py
+++ b/rllib/algorithms/impala/impala.py
@@ -32,8 +32,8 @@
     LEARNER_RESULTS,
     LEARNER_UPDATE_TIMER,
     MEAN_NUM_EPISODE_LISTS_RECEIVED,
-    MEAN_NUM_LEARNER_RESULTS_RECEIVED,
     MEAN_NUM_LEARNER_GROUP_UPDATE_CALLED,
+    MEAN_NUM_LEARNER_RESULTS_RECEIVED,
     NUM_AGENT_STEPS_SAMPLED,
     NUM_AGENT_STEPS_TRAINED,
     NUM_ENV_STEPS_SAMPLED,
@@ -42,8 +42,8 @@
     NUM_ENV_STEPS_TRAINED_LIFETIME,
     NUM_SYNCH_WORKER_WEIGHTS,
     NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS,
-    SYNCH_WORKER_WEIGHTS_TIMER,
     SAMPLE_TIMER,
+    SYNCH_WORKER_WEIGHTS_TIMER,
     TIMERS,
 )
 from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
diff --git a/rllib/algorithms/impala/impala_learner.py b/rllib/algorithms/impala/impala_learner.py
index 9c01248203a0..ecedaf1ce1f7 100644
--- a/rllib/algorithms/impala/impala_learner.py
+++ b/rllib/algorithms/impala/impala_learner.py
@@ -13,8 +13,8 @@
 from ray.rllib.core.learner.training_data import TrainingData
 from ray.rllib.core.rl_module.apis import ValueFunctionAPI
 from ray.rllib.utils.annotations import (
-    override,
     OverrideToImplementCustomLogic_CallToSuperRecommended,
+    override,
 )
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict
diff --git a/rllib/algorithms/marwil/marwil.py b/rllib/algorithms/marwil/marwil.py
index 33cd13fcf01d..e54843213e64 100644
--- a/rllib/algorithms/marwil/marwil.py
+++ b/rllib/algorithms/marwil/marwil.py
@@ -1,4 +1,5 @@
 from typing import Callable, Optional, Type, Union
+
 from typing_extensions import Self

 from ray._common.deprecation import deprecation_warning
diff --git a/rllib/algorithms/sac/sac.py b/rllib/algorithms/sac/sac.py
index 3c9002acfb42..d464e95889db 100644
--- a/rllib/algorithms/sac/sac.py
+++ b/rllib/algorithms/sac/sac.py
@@ -1,5 +1,6 @@
 import logging
 from typing import Any, Dict, Optional, Tuple, Type, Union
+
 from typing_extensions import Self

 from ray._common.deprecation import DEPRECATED_VALUE, deprecation_warning

From 02f639875a7a4c747ece5590b0285723cb8c3576 Mon Sep 17 00:00:00 2001
From: Mark Towers
Date: Fri, 31 Oct 2025 14:04:41 +0000
Subject: [PATCH 8/9] remove all redundant ruff isort configs and run pre-commit

Signed-off-by: Mark Towers
---
 pyproject.toml                                   | 16 ----------------
 rllib/utils/from_config.py                       |  3 +--
 rllib/utils/replay_buffers/__init__.py           |  8 ++++----
 .../tests/test_segment_tree_replay_buffer_api.py |  3 ++-
 rllib/utils/runners/runner.py                    |  1 -
 5 files changed, 7 insertions(+), 24 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 72226dbe2c54..c8bfb82794d6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,22 +65,6 @@ afterray = ["psutil", "setproctitle"]
 "python/ray/__init__.py" = ["I"]
 "python/ray/dag/__init__.py" = ["I"]
 "python/ray/air/__init__.py" = ["I"]
-# "rllib/__init__.py" = ["I"]
-# "rllib/benchmarks/*" = ["I"]
-# "rllib/connectors/*" = ["I"]
-# "rllib/evaluation/*" = ["I"]
-# "rllib/models/*" = ["I"]
-"rllib/utils/*" = ["I"]
-# "rllib/algorithms/*" = ["I"]
-# "rllib/core/*" = ["I"]
-# "rllib/examples/*" = ["I"]
-# "rllib/offline/*" = ["I"]
-# "rllib/tests/*" = ["I"]
-# "rllib/callbacks/*" = ["I"]
-# "rllib/env/*" = ["I"]
-# "rllib/execution/*" = ["I"]
-# "rllib/policy/*" = ["I"]
-# "rllib/tuned_examples/*" = ["I"]
 "release/*" = ["I"]

 # TODO(matthewdeng): Remove this line
diff --git a/rllib/utils/from_config.py b/rllib/utils/from_config.py
index 30b7290999fc..3f80e785265b 100644
--- a/rllib/utils/from_config.py
+++ b/rllib/utils/from_config.py
@@ -4,14 +4,13 @@
 import re
 from copy import deepcopy
 from functools import partial
+from typing import TYPE_CHECKING, Optional

 import yaml

 from ray.rllib.utils import force_list, merge_dicts
 from ray.rllib.utils.annotations import DeveloperAPI

-from typing import Optional, TYPE_CHECKING
-
 if TYPE_CHECKING:
     from ray.rllib.utils.typing import FromConfigSpec
diff --git a/rllib/utils/replay_buffers/__init__.py b/rllib/utils/replay_buffers/__init__.py
index e929ab7d5988..bb39bd80d6cf 100644
--- a/rllib/utils/replay_buffers/__init__.py
+++ b/rllib/utils/replay_buffers/__init__.py
@@ -1,11 +1,12 @@
+from ray.rllib.utils.replay_buffers import utils
 from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer
 from ray.rllib.utils.replay_buffers.fifo_replay_buffer import FifoReplayBuffer
-from ray.rllib.utils.replay_buffers.multi_agent_mixin_replay_buffer import (
-    MultiAgentMixInReplayBuffer,
-)
 from ray.rllib.utils.replay_buffers.multi_agent_episode_buffer import (
     MultiAgentEpisodeReplayBuffer,
 )
+from ray.rllib.utils.replay_buffers.multi_agent_mixin_replay_buffer import (
+    MultiAgentMixInReplayBuffer,
+)
 from ray.rllib.utils.replay_buffers.multi_agent_prioritized_episode_buffer import (
     MultiAgentPrioritizedEpisodeReplayBuffer,
 )
@@ -24,7 +25,6 @@
 )
 from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer, StorageUnit
 from ray.rllib.utils.replay_buffers.reservoir_replay_buffer import ReservoirReplayBuffer
-from ray.rllib.utils.replay_buffers import utils

 __all__ = [
     "EpisodeReplayBuffer",
diff --git a/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py b/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py
index 9deb9e7f1387..1077c4300c22 100644
--- a/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py
+++ b/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py
@@ -1,8 +1,9 @@
 import unittest
+
 import numpy as np

 from ray.rllib.env.single_agent_episode import SingleAgentEpisode
-from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree
+from ray.rllib.execution.segment_tree import MinSegmentTree, SumSegmentTree
 from ray.rllib.utils.replay_buffers import PrioritizedEpisodeReplayBuffer
diff --git a/rllib/utils/runners/runner.py b/rllib/utils/runners/runner.py
index 6d40319b63d7..fb3a8b61d278 100644
--- a/rllib/utils/runners/runner.py
+++ b/rllib/utils/runners/runner.py
@@ -1,6 +1,5 @@
 import abc
 import logging
-
 from typing import TYPE_CHECKING, Any, Union

 from ray.rllib.utils.actor_manager import FaultAwareApply

From f69992a04b81339d28fadc7a7dbf612c7c4e23f6 Mon Sep 17 00:00:00 2001
From: Mark Towers
Date: Fri, 31 Oct 2025 15:44:21 +0000
Subject: [PATCH 9/9] Fix isort for replay-buffer

Signed-off-by: Mark Towers
---
 rllib/utils/replay_buffers/__init__.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/rllib/utils/replay_buffers/__init__.py b/rllib/utils/replay_buffers/__init__.py
index bb39bd80d6cf..c5f53f25e3e3 100644
--- a/rllib/utils/replay_buffers/__init__.py
+++ b/rllib/utils/replay_buffers/__init__.py
@@ -1,4 +1,3 @@
-from ray.rllib.utils.replay_buffers import utils
 from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer
 from ray.rllib.utils.replay_buffers.fifo_replay_buffer import FifoReplayBuffer
 from ray.rllib.utils.replay_buffers.multi_agent_episode_buffer import (
@@ -26,6 +25,8 @@
 from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer, StorageUnit
 from ray.rllib.utils.replay_buffers.reservoir_replay_buffer import ReservoirReplayBuffer

+from ray.rllib.utils.replay_buffers import utils  # isort: skip
+
 __all__ = [
     "EpisodeReplayBuffer",
     "FifoReplayBuffer",
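A note on the final commit, for readers unfamiliar with the directive it uses: PATCH 9/9 pins
"from ray.rllib.utils.replay_buffers import utils" below the sorted submodule imports and marks
it with "# isort: skip", an isort-style action comment that tells import-sorting linters to
leave that one line where it is instead of resorting it to the top. A minimal sketch of the
same pattern follows; the package and module names are invented for illustration and are not
part of the patch:

    # mypkg/__init__.py -- hypothetical package, not part of Ray
    # Imports the linter is free to sort alphabetically.
    from mypkg.buffers import Buffer
    from mypkg.trees import Tree

    # This import must run last -- e.g. because the submodule itself imports
    # names re-exported above at import time -- so it carries an action comment
    # exempting it from import sorting.
    from mypkg import helpers  # isort: skip

    __all__ = ["Buffer", "Tree", "helpers"]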