Skip to content

Commit

Permalink
feat: intermediate work
Browse files Browse the repository at this point in the history
  • Loading branch information
EdanToledo committed Aug 20, 2024
1 parent 56714f1 commit bc4cc50
Show file tree
Hide file tree
Showing 14 changed files with 910 additions and 84 deletions.
6 changes: 6 additions & 0 deletions stoix/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,12 @@ class CoreLearnerState(NamedTuple):
timestep: TimeStep


class CoreOffPolicyLearnerState(NamedTuple):
params: Parameters
opt_states: OptStates
key: chex.PRNGKey


class OnPolicyLearnerState(NamedTuple):
"""State of the learner. Used for on-policy learners."""

Expand Down
4 changes: 2 additions & 2 deletions stoix/configs/arch/anakin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ architecture_name: anakin
# --- Training ---
seed: 42 # RNG seed.
update_batch_size: 1 # Number of vectorised gradient updates per device.
total_num_envs: 1024 # Total Number of vectorised environments across all devices and batched_updates. Needs to be divisible by n_devices*update_batch_size.
total_timesteps: 1e7 # Set the total environment steps.
total_num_envs: 1 # Total Number of vectorised environments across all devices and batched_updates. Needs to be divisible by n_devices*update_batch_size.
total_timesteps: 1e4 # Set the total environment steps.
# If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value.
num_updates: ~ # Number of updates

Expand Down
11 changes: 11 additions & 0 deletions stoix/configs/default/sebulba/default_ff_dqn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
defaults:
- logger: base_logger
- arch: sebulba
- system: ff_dqn
- network: mlp_dqn
- env: envpool/cartpole
- _self_

hydra:
searchpath:
- file://stoix/configs
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def run_experiment(_config: DictConfig) -> float:


@hydra.main(
config_path="../../configs/default/anakin",
config_path="../../../configs/default/anakin",
config_name="default_ff_c51.yaml",
version_base="1.2",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def run_experiment(_config: DictConfig) -> float:


@hydra.main(
config_path="../../configs/default/anakin",
config_path="../../../configs/default/anakin",
config_name="default_ff_ddqn.yaml",
version_base="1.2",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def run_experiment(_config: DictConfig) -> float:


@hydra.main(
config_path="../../configs/default/anakin",
config_path="../../../configs/default/anakin",
config_name="default_ff_dqn.yaml",
version_base="1.2",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ def run_experiment(_config: DictConfig) -> float:


@hydra.main(
config_path="../../configs/default/anakin",
config_path="../../../configs/default/anakin",
config_name="default_ff_dqn_reg.yaml",
version_base="1.2",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ def run_experiment(_config: DictConfig) -> float:


@hydra.main(
config_path="../../configs/default/anakin",
config_path="../../../configs/default/anakin",
config_name="default_ff_mdqn.yaml",
version_base="1.2",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ def run_experiment(_config: DictConfig) -> float:


@hydra.main(
config_path="../../configs/default/anakin",
config_path="../../../configs/default/anakin",
config_name="default_ff_qr_dqn.yaml",
version_base="1.2",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ def run_experiment(_config: DictConfig) -> float:


@hydra.main(
config_path="../../configs/default/anakin",
config_path="../../../configs/default/anakin",
config_name="default_ff_rainbow.yaml",
version_base="1.2",
)
Expand Down
Loading

0 comments on commit bc4cc50

Please sign in to comment.