From 20f1312597d819bdec8c986804e84d9672b7abf5 Mon Sep 17 00:00:00 2001 From: Aleksei Petrenko Date: Wed, 19 Jun 2024 20:34:18 -0700 Subject: [PATCH] Fix pre-commit --- Makefile | 2 +- sample_factory/algo/learning/batcher.py | 18 ++++++---------- .../algo/learning/learner_worker.py | 21 +++++++------------ sample_factory/algo/runners/runner.py | 15 +++++-------- .../algo/sampling/evaluation_sampling_api.py | 9 +++----- .../algo/sampling/inference_worker.py | 6 ++---- .../algo/sampling/rollout_worker.py | 3 +-- sample_factory/algo/sampling/sampler.py | 12 ++++------- sample_factory/algo/utils/heartbeat.py | 6 ++---- setup.py | 10 ++++----- 10 files changed, 36 insertions(+), 66 deletions(-) diff --git a/Makefile b/Makefile index 918430fa0..5fe3d1996 100644 --- a/Makefile +++ b/Makefile @@ -39,7 +39,7 @@ check-codestyle: black --check $(line_len_arg) -t py38 $(code_folders) isort --check-only $(line_len_arg) --py 38 --profile black $(code_folders) # ignore some formatting issues already covered by black - flake8 --max-line-length $(line_len) --ignore=E501,F401,E203,W503,E126,E722 $(code_folders) + flake8 --max-line-length $(line_len) --ignore=E501,F401,E203,W503,E126,E722,E704 $(code_folders) # Run tests for the library diff --git a/sample_factory/algo/learning/batcher.py b/sample_factory/algo/learning/batcher.py index 14ba6b0a5..d94e621a0 100644 --- a/sample_factory/algo/learning/batcher.py +++ b/sample_factory/algo/learning/batcher.py @@ -122,28 +122,22 @@ def __init__( ] @signal - def initialized(self): - ... + def initialized(self): ... @signal - def trajectory_buffers_available(self): - ... + def trajectory_buffers_available(self): ... @signal - def training_batches_available(self): - ... + def training_batches_available(self): ... @signal - def stop_experience_collection(self): - ... + def stop_experience_collection(self): ... @signal - def resume_experience_collection(self): - ... + def resume_experience_collection(self): ... @signal - def stop(self): - ... + def stop(self): ... def init(self): device = policy_device(self.cfg, self.policy_id) diff --git a/sample_factory/algo/learning/learner_worker.py b/sample_factory/algo/learning/learner_worker.py index 3855002eb..cb2ded364 100644 --- a/sample_factory/algo/learning/learner_worker.py +++ b/sample_factory/algo/learning/learner_worker.py @@ -76,32 +76,25 @@ def __init__( self.cache_cleanup_timer.timeout.connect(self._cleanup_cache) @signal - def initialized(self): - ... + def initialized(self): ... @signal - def model_initialized(self): - ... + def model_initialized(self): ... @signal - def report_msg(self): - ... + def report_msg(self): ... @signal - def training_batch_released(self): - ... + def training_batch_released(self): ... @signal - def finished_training_iteration(self): - ... + def finished_training_iteration(self): ... @signal - def saved_model(self): - ... + def saved_model(self): ... @signal - def stop(self): - ... + def stop(self): ... def save(self) -> bool: if self.learner.save(): diff --git a/sample_factory/algo/runners/runner.py b/sample_factory/algo/runners/runner.py index b91cad8b4..a14fab83c 100644 --- a/sample_factory/algo/runners/runner.py +++ b/sample_factory/algo/runners/runner.py @@ -185,20 +185,16 @@ def periodic(period, cb): # signals emitted by the runner @signal - def save_periodic(self): - ... + def save_periodic(self): ... @signal - def save_best(self): - ... + def save_best(self): ... @signal - def update_training_info(self): - ... + def update_training_info(self): ... @signal - def save_milestone(self): - ... + def save_milestone(self): ... @signal def stop(self): @@ -206,8 +202,7 @@ def stop(self): ... @signal - def all_components_stopped(self): - ... + def all_components_stopped(self): ... def _handle_restart(self): exp_dir = experiment_dir(self.cfg, mkdir=False) diff --git a/sample_factory/algo/sampling/evaluation_sampling_api.py b/sample_factory/algo/sampling/evaluation_sampling_api.py index 4d7e014bc..7524b1b25 100644 --- a/sample_factory/algo/sampling/evaluation_sampling_api.py +++ b/sample_factory/algo/sampling/evaluation_sampling_api.py @@ -80,16 +80,13 @@ def __init__(self, cfg: Config, env_info: EnvInfo, print_episode_info: bool = Tr self.print_episode_info = print_episode_info @signal - def model_initialized(self): - ... + def model_initialized(self): ... @signal - def trajectory_buffers_available(self): - ... + def trajectory_buffers_available(self): ... @signal - def stop(self): - ... + def stop(self): ... def init( self, buffer_mgr: Optional[BufferMgr] = None, param_servers: Optional[Dict[PolicyID, ParameterServer]] = None diff --git a/sample_factory/algo/sampling/inference_worker.py b/sample_factory/algo/sampling/inference_worker.py index a1a57297a..1fede875b 100644 --- a/sample_factory/algo/sampling/inference_worker.py +++ b/sample_factory/algo/sampling/inference_worker.py @@ -134,12 +134,10 @@ def __init__( self.is_initialized = False @signal - def initialized(self): - ... + def initialized(self): ... @signal - def report_msg(self): - ... + def report_msg(self): ... def init(self, init_model_data: Optional[InitModelData]): if self.is_initialized: diff --git a/sample_factory/algo/sampling/rollout_worker.py b/sample_factory/algo/sampling/rollout_worker.py index 635d34322..f32dc1bf1 100644 --- a/sample_factory/algo/sampling/rollout_worker.py +++ b/sample_factory/algo/sampling/rollout_worker.py @@ -129,8 +129,7 @@ def __init__( self.is_initialized: bool = False @signal - def report_msg(self): - ... + def report_msg(self): ... def init(self): for split_idx in range(self.num_splits): diff --git a/sample_factory/algo/sampling/sampler.py b/sample_factory/algo/sampling/sampler.py index 6404f8700..23dbd3a2a 100644 --- a/sample_factory/algo/sampling/sampler.py +++ b/sample_factory/algo/sampling/sampler.py @@ -36,12 +36,10 @@ def __init__( self.env_info: EnvInfo = env_info @signal - def started(self): - ... + def started(self): ... @signal - def initialized(self): - ... + def initialized(self): ... def init(self) -> None: raise NotImplementedError() @@ -98,12 +96,10 @@ def __init__( # internal signals used for communication with the workers, these are not a part of the interface @signal - def _init_inference_workers(self): - ... + def _init_inference_workers(self): ... @signal - def _inference_workers_initialized(self): - ... + def _inference_workers_initialized(self): ... def _make_inference_worker(self, event_loop, policy_id: PolicyID, worker_idx: int, param_server: ParameterServer): return InferenceWorker( diff --git a/sample_factory/algo/utils/heartbeat.py b/sample_factory/algo/utils/heartbeat.py index c78ce5157..64b66ce01 100644 --- a/sample_factory/algo/utils/heartbeat.py +++ b/sample_factory/algo/utils/heartbeat.py @@ -10,12 +10,10 @@ def __init__(self, evt_loop: EventLoop, unique_name: str, interval_sec: int = 10 self.heartbeat_timer.timeout.connect(self._report_heartbeat) @signal - def heartbeat(self): - ... + def heartbeat(self): ... @signal - def stop(self): - ... + def stop(self): ... def _report_heartbeat(self): p_name = process_name(self.event_loop.process) diff --git a/setup.py b/setup.py index e6a7f2c87..6097520ee 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ import platform + import setuptools from setuptools import setup @@ -36,12 +37,10 @@ "mkdocs-git-authors-plugin", ] + def is_macos(): - return platform.system() == 'Darwin' + return platform.system() == "Darwin" -platform_specific = [] -if not is_macos(): - platform_specific += ["opencv-python"] # on arm64 the CI build takes forever otherwise setup( # Information @@ -74,7 +73,8 @@ def is_macos(): "wandb>=0.12.9", "huggingface-hub>=0.10.0,<1.0", "pandas", - ] + platform_specific, + "opencv-python", + ], extras_require={ # some tests require Atari and Mujoco so let's make sure dev environment has that "dev": ["black", "isort>=5.12", "pytest<8.0", "flake8", "pre-commit", "twine"]