R2D2 #34

Merged
merged 77 commits into from
Mar 3, 2025

Commits
0543ae4
R2D2 initial commit, not tested
garymm Feb 19, 2025
1055fbe
more WIP. Trying to get Atari frames in
garymm Feb 20, 2025
f9f60f3
atari input at least doesn't crash
garymm Feb 20, 2025
8ea67e1
add test for atari input and fix q value size
garymm Feb 20, 2025
a236ce0
fix static fields in r2d2 networks
garymm Feb 20, 2025
21327c0
gymnasium_loop: fix mixed up argument order
garymm Feb 20, 2025
260cbf9
r2d2 WIP
garymm Feb 20, 2025
ab84e4a
r2d2: some fixes. still not done
garymm Feb 20, 2025
844b796
r2d2: add value rescaling
garymm Feb 20, 2025
940d216
port code over from acme/agents/jax/r2d2
garymm Feb 21, 2025
66507b3
update experience state in loss
garymm Feb 21, 2025
afb5a99
gymnasium_loop: fix bug. only copy one net replica
garymm Feb 21, 2025
1ba65bc
sharding: fix docstring
garymm Feb 21, 2025
9a892ac
r2d2: runs for 2 cycles!
garymm Feb 21, 2025
5647447
change grad means metric names to make Mlflow happy
garymm Feb 22, 2025
ad743ba
scripts to train r2d2 on atari
garymm Feb 22, 2025
db4d7c5
gymnasium_loop: fix buffer donation bug
garymm Feb 22, 2025
40749b2
r2d2: experiment runner files
garymm Feb 22, 2025
7a82127
fix bug in _sample_from_experience
garymm Feb 24, 2025
6489689
use better import path
garymm Feb 24, 2025
18bd558
fix buffer update
garymm Feb 24, 2025
d444d6b
minor cleanup
garymm Feb 24, 2025
7ab4825
r2d2: epsilon-greedy and remove incremental updates
garymm Feb 26, 2025
961cb01
suppress warnings
garymm Feb 26, 2025
9e2525b
epsilon greedy schedule and more debugging
garymm Feb 26, 2025
a48c3cb
make lstm optional
garymm Feb 26, 2025
d6d001f
add TODO about stop grad after burn in
garymm Feb 26, 2025
76e6f2f
use optax.incremental_update
garymm Feb 27, 2025
ac6bcde
r2d2: tests passing!
garymm Feb 27, 2025
68f52d8
set adam eps value to what it was in r2d2 paper
garymm Feb 27, 2025
49e90af
test_learns_cartpole passing with LSTM
garymm Feb 27, 2025
cfaec6b
use jax-loop-utils from PyPI
garymm Feb 27, 2025
059475a
remove unused filter_incremental_update
garymm Feb 27, 2025
5a1fcee
fix test_sample_from_experince
garymm Feb 27, 2025
cee2658
make cartpole test easier for now
garymm Feb 27, 2025
25c1621
implement prioritized replay
garymm Feb 27, 2025
1ea3bd1
support sticky actions for exploration
garymm Feb 28, 2025
773a36f
use tensorboard instead of mlflow
garymm Feb 28, 2025
3540374
some updates to run_gymnax_cartpole
garymm Feb 28, 2025
e840390
ignore logs dir
garymm Feb 28, 2025
50d8c55
upgrade basedpyright
garymm Feb 28, 2025
3bcbe79
new notebook for asterix
garymm Feb 28, 2025
c3bb26d
remove commented out code
garymm Feb 28, 2025
073584e
remove unneeded warning suppression
garymm Feb 28, 2025
86a3aa5
better hyperparameters for test_learns_cartpole
garymm Feb 28, 2025
206084d
require python 3.11, not 3.12, for Colab
garymm Feb 28, 2025
c4401c9
WIP: asterix notebook
garymm Feb 28, 2025
293f3c5
fix sharding.pytree_get_index_0
garymm Mar 1, 2025
5106531
env_info_from_gymnasium: support vecenv
garymm Mar 1, 2025
750ab7b
WIP: support envpool
garymm Mar 1, 2025
c4d7f7a
delete unused runners
garymm Mar 1, 2025
219916e
fix shard agent state
garymm Mar 1, 2025
80381fe
add render atari observe cycle
garymm Mar 1, 2025
f750cb0
WIP: asterix atari
garymm Mar 1, 2025
0f837eb
run_atari: assert num envs per learner even
garymm Mar 2, 2025
1981640
r2d2: support replaying larger batches
garymm Mar 2, 2025
8336410
ignore warning triggered by envpool
garymm Mar 2, 2025
5e6f4f2
test_r2d2: use env_factory
garymm Mar 2, 2025
b881812
fix metrickey import
garymm Mar 2, 2025
a8d2c08
restore test_learns_cartpole
garymm Mar 2, 2025
db2ac3c
set priority to 1 for new experience
garymm Mar 2, 2025
2faf2b0
double replay batch size
garymm Mar 2, 2025
84b81ef
improve error message
garymm Mar 2, 2025
0e12c9c
fix dtype support in resnet
garymm Mar 2, 2025
cb5ceca
gymnasium_loop: fix bug when len(learner_devices) > 1
garymm Mar 2, 2025
f0258a7
run minasterix longer
garymm Mar 3, 2025
5b519b4
vs code setting: ignore git limit warning
garymm Mar 3, 2025
14ce38d
start to fix bazel test
garymm Mar 3, 2025
727603f
fix gymnasium tests
garymm Mar 3, 2025
152145c
fix run_experiment for env_factory
garymm Mar 3, 2025
a8fb7d7
suppress false pyright error
garymm Mar 3, 2025
f3a52c3
rename and delete runners
garymm Mar 3, 2025
4ad97b3
fix some broken stuff
garymm Mar 3, 2025
587a9df
set long timeout for slow github runner
garymm Mar 3, 2025
9781362
shard test_run_experiment
garymm Mar 3, 2025
36891f1
split learns cartpole to separate test
garymm Mar 3, 2025
a18c16f
shorten test
garymm Mar 3, 2025
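
Commit 844b796 ("r2d2: add value rescaling") is not shown in the diff excerpt below, so its actual implementation cannot be quoted here. As a rough sketch only, the invertible value rescaling described in the R2D2 paper is h(x) = sign(x)(sqrt(|x| + 1) - 1) + eps * x. The function names and the eps = 1e-3 default below are assumptions for illustration and are not taken from this PR.

```python
import math

EPSILON = 1e-3  # assumed default for illustration (the value reported in the R2D2 paper)


def value_rescale(x: float, eps: float = EPSILON) -> float:
    """h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x."""
    return math.copysign(math.sqrt(abs(x) + 1.0) - 1.0, x) + eps * x


def inverse_value_rescale(y: float, eps: float = EPSILON) -> float:
    """Closed-form inverse of value_rescale.

    Substituting u = sqrt(|x| + 1) turns h into a quadratic in u,
    eps * u**2 + u - (|y| + 1 + eps) = 0, whose positive root is used here.
    """
    root = (math.sqrt(1.0 + 4.0 * eps * (abs(y) + 1.0 + eps)) - 1.0) / (2.0 * eps)
    return math.copysign(root * root - 1.0, y)


if __name__ == "__main__":
    for value in (-100.0, -1.5, 0.0, 0.5, 42.0):
        squashed = value_rescale(value)
        print(value, squashed, inverse_value_rescale(squashed))
```

In the paper, bootstrapped targets are computed by un-scaling the target network's Q-values with the inverse, adding the discounted rewards, and re-applying h, which compresses large returns without clipping rewards.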
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,3 +1,6 @@
 bazel-*
 __pycache__
 uv.lock
+# mlflow
+mlruns
+logs
3 changes: 2 additions & 1 deletion .vscode/settings.json
@@ -24,5 +24,6 @@
     "bazel.buildifierExecutable": "@buildifier_prebuilt//:buildifier",
     "python.testing.pytestArgs": [
         "earl"
-    ]
+    ],
+    "git.ignoreLimitWarning": true
 }
4 changes: 3 additions & 1 deletion BUILD.bazel
@@ -8,8 +8,10 @@ pip_compile(
         "--emit-index-url",
         "--no-strip-extras",
         "--extra=test",
+        "--extra=agent-r2d2",
+        "--index=https://download.pytorch.org/whl/cpu",
     ],
-    python_platform = "x86_64-unknown-linux-gnu",
+    python_platform = "x86_64-manylinux_2_28",  # envpool needs at least 2_24
     requirements_in = "//:pyproject.toml",
     requirements_txt = "requirements_linux_x86_64.txt",
 )
4 changes: 2 additions & 2 deletions MODULE.bazel
@@ -6,15 +6,15 @@ module(
 )

 # BEGIN python toolchain
-_PYTHON_VERSION = "3.12"
+_PYTHON_VERSION = "3.11"  # latest that supports envpool

 bazel_dep(name = "rules_python", version = "1.1.0", dev_dependency = True)

 python = use_extension("@rules_python//python/extensions:python.bzl", "python")
 python.toolchain(python_version = _PYTHON_VERSION)
 # END python toolchain

-# BEGIN python dependencies
+# BEGIN python dependenciesP
 bazel_dep(name = "rules_uv", version = "0.53.0", dev_dependency = True)

 pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip")
1,219 changes: 871 additions & 348 deletions MODULE.bazel.lock

Large diffs are not rendered by default.

67 changes: 67 additions & 0 deletions earl/agents/r2d2/BUILD.bazel
@@ -0,0 +1,67 @@
load("@aspect_rules_py//py:defs.bzl", "py_library")
load("//tools/py_test:py_test.bzl", "py_test")

py_library(
    name = "r2d2",
    srcs = [
        "networks.py",
        "r2d2.py",
        "utils.py",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "//earl:core",
        "@pypi//chex",
        "@pypi//distrax",
        "@pypi//equinox",
        "@pypi//jax",
        "@pypi//jaxtyping",
        "@pypi//optax",
    ],
)

py_test(
    name = "test_r2d2",
    timeout = "long",
    srcs = ["test_r2d2.py"],
    filterwarnings = [
        "ignore:jax.interpreters.xla.pytype_aval_mappings is deprecated.:DeprecationWarning",
        "ignore:Shape is deprecated; use StableHLO instead.:DeprecationWarning",
    ],
    shard_count = 2,
    deps = [
        ":r2d2",
        "//earl:core",
        "//earl/environment_loop:gymnasium_loop",
        "//earl/environment_loop:gymnax_loop",
        "@pypi//envpool",
        "@pypi//gymnasium",
        "@pypi//gymnax",
        "@pypi//jax",
        "@pypi//jax_loop_utils",
        "@pypi//numpy",
        "@pypi//optax",
        "@pypi//pytest",
    ],
)

py_test(
    name = "test_r2d2_learns",
    timeout = "long",
    srcs = ["test_r2d2_learns.py"],
    filterwarnings = [
        "ignore:jax.interpreters.xla.pytype_aval_mappings is deprecated.:DeprecationWarning",
        "ignore:Shape is deprecated; use StableHLO instead.:DeprecationWarning",
    ],
    tags = ["manual"],
    deps = [
        ":r2d2",
        "//earl:core",
        "//earl/environment_loop:gymnax_loop",
        "@pypi//gymnax",
        "@pypi//jax",
        "@pypi//jax_loop_utils",
        "@pypi//numpy",
        "@pypi//pytest",
    ],
)
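
The `filterwarnings` entries above use the `action:message:category` form of a Python warning filter. How the repository's own `py_test` macro (`//tools/py_test:py_test.bzl`) forwards them to the test runner is not shown in this diff, so the following is only a sketch of what such a filter does, using the standard `warnings` module. The helper name `surviving_warnings` is hypothetical, and treating the message as literal text (rather than a regular expression) is an assumption made here for simplicity.

```python
import builtins
import re
import warnings

# The two filter strings from the BUILD file above.
FILTERS = [
    "ignore:jax.interpreters.xla.pytype_aval_mappings is deprecated.:DeprecationWarning",
    "ignore:Shape is deprecated; use StableHLO instead.:DeprecationWarning",
]


def surviving_warnings(filters: list[str], messages: list[str]) -> list[str]:
    """Emit each message as a DeprecationWarning; return those not filtered out."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # baseline: record every warning
        for spec in filters:
            action, message, category_name = spec.split(":", 2)
            category = getattr(builtins, category_name)
            # Assumption: match the message as literal text. The filter is
            # applied to the start of the warning text, case-insensitively.
            warnings.filterwarnings(action, message=re.escape(message), category=category)
        for text in messages:
            warnings.warn(text, DeprecationWarning)
    return [str(w.message) for w in caught]


if __name__ == "__main__":
    print(
        surviving_warnings(
            FILTERS,
            [
                "Shape is deprecated; use StableHLO instead.",
                "some other API is deprecated",
            ],
        )
    )
```

Because each filter names both a message prefix and a category, unrelated deprecation warnings still surface in the tests instead of being silenced wholesale.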