Skip to content

Commit

Permalink
Fix pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-petrenko committed Jun 20, 2024
1 parent e0c8ad8 commit 20f1312
Show file tree
Hide file tree
Showing 10 changed files with 36 additions and 66 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ check-codestyle:
black --check $(line_len_arg) -t py38 $(code_folders)
isort --check-only $(line_len_arg) --py 38 --profile black $(code_folders)
# ignore some formatting issues already covered by black
flake8 --max-line-length $(line_len) --ignore=E501,F401,E203,W503,E126,E722 $(code_folders)
flake8 --max-line-length $(line_len) --ignore=E501,F401,E203,W503,E126,E722,E704 $(code_folders)


# Run tests for the library
Expand Down
18 changes: 6 additions & 12 deletions sample_factory/algo/learning/batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,28 +122,22 @@ def __init__(
]

@signal
def initialized(self):
...
def initialized(self): ...

@signal
def trajectory_buffers_available(self):
...
def trajectory_buffers_available(self): ...

@signal
def training_batches_available(self):
...
def training_batches_available(self): ...

@signal
def stop_experience_collection(self):
...
def stop_experience_collection(self): ...

@signal
def resume_experience_collection(self):
...
def resume_experience_collection(self): ...

@signal
def stop(self):
...
def stop(self): ...

def init(self):
device = policy_device(self.cfg, self.policy_id)
Expand Down
21 changes: 7 additions & 14 deletions sample_factory/algo/learning/learner_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,32 +76,25 @@ def __init__(
self.cache_cleanup_timer.timeout.connect(self._cleanup_cache)

@signal
def initialized(self):
...
def initialized(self): ...

@signal
def model_initialized(self):
...
def model_initialized(self): ...

@signal
def report_msg(self):
...
def report_msg(self): ...

@signal
def training_batch_released(self):
...
def training_batch_released(self): ...

@signal
def finished_training_iteration(self):
...
def finished_training_iteration(self): ...

@signal
def saved_model(self):
...
def saved_model(self): ...

@signal
def stop(self):
...
def stop(self): ...

def save(self) -> bool:
if self.learner.save():
Expand Down
15 changes: 5 additions & 10 deletions sample_factory/algo/runners/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,29 +185,24 @@ def periodic(period, cb):

# signals emitted by the runner
@signal
def save_periodic(self):
...
def save_periodic(self): ...

@signal
def save_best(self):
...
def save_best(self): ...

@signal
def update_training_info(self):
...
def update_training_info(self): ...

@signal
def save_milestone(self):
...
def save_milestone(self): ...

@signal
def stop(self):
"""Emitted when we're about to stop the experiment."""
...

@signal
def all_components_stopped(self):
...
def all_components_stopped(self): ...

def _handle_restart(self):
exp_dir = experiment_dir(self.cfg, mkdir=False)
Expand Down
9 changes: 3 additions & 6 deletions sample_factory/algo/sampling/evaluation_sampling_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,13 @@ def __init__(self, cfg: Config, env_info: EnvInfo, print_episode_info: bool = Tr
self.print_episode_info = print_episode_info

@signal
def model_initialized(self):
...
def model_initialized(self): ...

@signal
def trajectory_buffers_available(self):
...
def trajectory_buffers_available(self): ...

@signal
def stop(self):
...
def stop(self): ...

def init(
self, buffer_mgr: Optional[BufferMgr] = None, param_servers: Optional[Dict[PolicyID, ParameterServer]] = None
Expand Down
6 changes: 2 additions & 4 deletions sample_factory/algo/sampling/inference_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,10 @@ def __init__(
self.is_initialized = False

@signal
def initialized(self):
...
def initialized(self): ...

@signal
def report_msg(self):
...
def report_msg(self): ...

def init(self, init_model_data: Optional[InitModelData]):
if self.is_initialized:
Expand Down
3 changes: 1 addition & 2 deletions sample_factory/algo/sampling/rollout_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ def __init__(
self.is_initialized: bool = False

@signal
def report_msg(self):
...
def report_msg(self): ...

def init(self):
for split_idx in range(self.num_splits):
Expand Down
12 changes: 4 additions & 8 deletions sample_factory/algo/sampling/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,10 @@ def __init__(
self.env_info: EnvInfo = env_info

@signal
def started(self):
...
def started(self): ...

@signal
def initialized(self):
...
def initialized(self): ...

def init(self) -> None:
raise NotImplementedError()
Expand Down Expand Up @@ -98,12 +96,10 @@ def __init__(

# internal signals used for communication with the workers, these are not a part of the interface
@signal
def _init_inference_workers(self):
...
def _init_inference_workers(self): ...

@signal
def _inference_workers_initialized(self):
...
def _inference_workers_initialized(self): ...

def _make_inference_worker(self, event_loop, policy_id: PolicyID, worker_idx: int, param_server: ParameterServer):
return InferenceWorker(
Expand Down
6 changes: 2 additions & 4 deletions sample_factory/algo/utils/heartbeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@ def __init__(self, evt_loop: EventLoop, unique_name: str, interval_sec: int = 10
self.heartbeat_timer.timeout.connect(self._report_heartbeat)

@signal
def heartbeat(self):
...
def heartbeat(self): ...

@signal
def stop(self):
...
def stop(self): ...

def _report_heartbeat(self):
p_name = process_name(self.event_loop.process)
Expand Down
10 changes: 5 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import platform

import setuptools
from setuptools import setup

Expand Down Expand Up @@ -36,12 +37,10 @@
"mkdocs-git-authors-plugin",
]


def is_macos():
return platform.system() == 'Darwin'
return platform.system() == "Darwin"

platform_specific = []
if not is_macos():
platform_specific += ["opencv-python"] # on arm64 the CI build takes forever otherwise

setup(
# Information
Expand Down Expand Up @@ -74,7 +73,8 @@ def is_macos():
"wandb>=0.12.9",
"huggingface-hub>=0.10.0,<1.0",
"pandas",
] + platform_specific,
"opencv-python",
],
extras_require={
# some tests require Atari and Mujoco so let's make sure dev environment has that
"dev": ["black", "isort>=5.12", "pytest<8.0", "flake8", "pre-commit", "twine"]
Expand Down

0 comments on commit 20f1312

Please sign in to comment.