-t = Trajectories.CircularArraySARTTraces(;capacity=10)
+t = Trajectories.CircularArraySARTSTraces(;capacity=10)
diff --git a/docs/src/How_to_implement_a_new_algorithm.md b/docs/src/How_to_implement_a_new_algorithm.md
index 9c05be8ac..ddfe492e4 100644
--- a/docs/src/How_to_implement_a_new_algorithm.md
+++ b/docs/src/How_to_implement_a_new_algorithm.md
@@ -94,7 +94,7 @@ A `Trajectory` is composed of three elements: a `container`, a `controller`, and
The container is typically an `AbstractTraces`, an object that stores a set of `Trace`s in a structured manner. You can either define your own (and contribute it to the package if it is likely to be usable for other algorithms), or use a predefined one if it exists.
-The most common `AbstractTraces` object is the `CircularArraySARTTraces`, this is a container of a fixed length that stores the following traces: `:state` (S), `:action` (A), `:reward` (R), `:terminal` (T), which toghether are aliased to `SART = (:state, :action, :reward, :terminal)`. Let us see how it is constructed in this simplified version as an example of how to build a custom trace.
+The most common `AbstractTraces` object is the `CircularArraySARTSTraces`, a fixed-length container that stores the following traces: `:state` (S), `:action` (A), `:reward` (R), `:terminal` (T), which together are aliased to `SART = (:state, :action, :reward, :terminal)`. Let us look at a simplified version of its constructor as an example of how to build a custom trace.
```julia
function (capacity, state_size, state_eltype, action_size, action_eltype, reward_eltype)
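For reference while migrating, below is a minimal sketch of how the renamed container is typically created and filled. It mirrors the `NamedTuple` pushes performed by the rewritten `Agent` elsewhere in this patch and uses the `Trajectory`/`DummySampler` combination that appears in the updated tests; the `eltype => size` keyword pair and the four-element state vector are illustrative assumptions, not part of this diff.

```julia
using ReinforcementLearningTrajectories

# Fixed-length SARTS container wrapped in a Trajectory, as in the updated tests.
trajectory = Trajectory(
    CircularArraySARTSTraces(; capacity = 10, state = Float32 => (4,)),
    DummySampler(),
)

# At the start of an episode only the initial state is known...
push!(trajectory, (state = zeros(Float32, 4),))
# ...and each subsequent step records the action taken together with the
# resulting reward, terminal flag, and next state.
push!(trajectory, (state = ones(Float32, 4), action = 1, reward = 0.0f0, terminal = false))
```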
diff --git a/docs/src/Zoo_Algorithms/MPO.md b/docs/src/Zoo_Algorithms/MPO.md
index 52327b27f..2e22e18bb 100644
--- a/docs/src/Zoo_Algorithms/MPO.md
+++ b/docs/src/Zoo_Algorithms/MPO.md
@@ -56,7 +56,7 @@ The next step is to wrap this policy into an `Agent`. An agent is a combination
```julia
trajectory = Trajectory(
- CircularArraySARTTraces(capacity = 1000, state = Float32 => (4,),action = Float32 => (1,)),
+ CircularArraySARTSTraces(capacity = 1000, state = Float32 => (4,), action = Float32 => (1,)),
MetaSampler(
actor = MultiBatchSampler(BatchSampler{(:state,)}(32), 10),
critic = MultiBatchSampler(BatchSampler{SS′ART}(32), 1000)
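The surrounding documentation goes on to wrap this trajectory and the MPO policy into an `Agent`. As a quick orientation, the wrapping itself is a one-liner, following the `Agent(policy = ..., trajectory = ...)` pattern used by the MPO experiment scripts later in this patch; `policy` here is assumed to be the MPO policy constructed in the preceding section.

```julia
agent = Agent(
    policy = policy,          # the MPO policy built earlier in this guide
    trajectory = trajectory,  # the Trajectory defined just above
)
```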
diff --git a/src/ReinforcementLearningCore/Project.toml b/src/ReinforcementLearningCore/Project.toml
index 36d8a5ec4..b815399d9 100644
--- a/src/ReinforcementLearningCore/Project.toml
+++ b/src/ReinforcementLearningCore/Project.toml
@@ -1,6 +1,6 @@
name = "ReinforcementLearningCore"
uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
-version = "0.11.3"
+version = "0.12.0"
[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -40,7 +40,7 @@ Parsers = "2"
ProgressMeter = "1"
Reexport = "1"
ReinforcementLearningBase = "0.12"
-ReinforcementLearningTrajectories = "^0.1.9"
+ReinforcementLearningTrajectories = "^0.3.2"
StatsBase = "0.32, 0.33, 0.34"
TimerOutputs = "0.5"
UnicodePlots = "1.3, 2, 3"
diff --git a/src/ReinforcementLearningCore/src/ReinforcementLearningCore.jl b/src/ReinforcementLearningCore/src/ReinforcementLearningCore.jl
index df71cd89d..fdf9107a9 100644
--- a/src/ReinforcementLearningCore/src/ReinforcementLearningCore.jl
+++ b/src/ReinforcementLearningCore/src/ReinforcementLearningCore.jl
@@ -3,7 +3,6 @@ module ReinforcementLearningCore
using TimerOutputs
using ReinforcementLearningBase
using Reexport
-
const RLCore = ReinforcementLearningCore
export RLCore
diff --git a/src/ReinforcementLearningCore/src/core/run.jl b/src/ReinforcementLearningCore/src/core/run.jl
index 67900775b..e325fb0ef 100644
--- a/src/ReinforcementLearningCore/src/core/run.jl
+++ b/src/ReinforcementLearningCore/src/core/run.jl
@@ -102,21 +102,17 @@ function _run(policy::AbstractPolicy,
action = @timeit_debug timer "plan!" RLBase.plan!(policy, env)
@timeit_debug timer "act!" act!(env, action)
- @timeit_debug timer "push!(policy) PostActStage" push!(policy, PostActStage(), env)
+ @timeit_debug timer "push!(policy) PostActStage" push!(policy, PostActStage(), env, action)
@timeit_debug timer "optimise! PostActStage" optimise!(policy, PostActStage())
@timeit_debug timer "push!(hook) PostActStage" push!(hook, PostActStage(), policy, env)
if check_stop(stop_condition, policy, env)
is_stop = true
- @timeit_debug timer "push!(policy) PreActStage" push!(policy, PreActStage(), env)
- @timeit_debug timer "optimise! PreActStage" optimise!(policy, PreActStage())
- @timeit_debug timer "push!(hook) PreActStage" push!(hook, PreActStage(), policy, env)
- @timeit_debug timer "plan!" RLBase.plan!(policy, env) # let the policy see the last observation
break
end
end # end of an episode
- @timeit_debug timer "push!(policy) PostEpisodeStage" push!(policy, PostEpisodeStage(), env) # let the policy see the last observation
+ @timeit_debug timer "push!(policy) PostEpisodeStage" push!(policy, PostEpisodeStage(), env)
@timeit_debug timer "optimise! PostEpisodeStage" optimise!(policy, PostEpisodeStage())
@timeit_debug timer "push!(hook) PostEpisodeStage" push!(hook, PostEpisodeStage(), policy, env)
diff --git a/src/ReinforcementLearningCore/src/core/stages.jl b/src/ReinforcementLearningCore/src/core/stages.jl
index 52f2a30a9..61e48f57d 100644
--- a/src/ReinforcementLearningCore/src/core/stages.jl
+++ b/src/ReinforcementLearningCore/src/core/stages.jl
@@ -17,7 +17,9 @@ struct PreActStage <: AbstractStage end
struct PostActStage <: AbstractStage end
Base.push!(p::AbstractPolicy, ::AbstractStage, ::AbstractEnv) = nothing
+Base.push!(p::AbstractPolicy, ::PostActStage, ::AbstractEnv, action) = nothing
Base.push!(p::AbstractPolicy, ::AbstractStage, ::AbstractEnv, ::Symbol) = nothing
+Base.push!(p::AbstractPolicy, ::PostActStage, ::AbstractEnv, action, ::Symbol) = nothing
RLBase.optimise!(policy::P, ::S) where {P<:AbstractPolicy,S<:AbstractStage} = nothing
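With these fallbacks in place, a custom policy only needs to override the stages it cares about, and as of this change the chosen action is passed alongside the environment at `PostActStage`. A minimal sketch of such a hook (the `MyPolicy` type is hypothetical, introduced only for illustration):

```julia
using ReinforcementLearningBase, ReinforcementLearningCore

# Hypothetical policy type, used only to illustrate the new hook signature.
struct MyPolicy <: AbstractPolicy end

# Receive the action chosen at the current step together with the environment.
function Base.push!(p::MyPolicy, ::PostActStage, env::AbstractEnv, action)
    # e.g. cache the action or update running statistics here
    nothing
end
```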
diff --git a/src/ReinforcementLearningCore/src/policies/agent/agent.jl b/src/ReinforcementLearningCore/src/policies/agent/agent.jl
index a4c7ceb8a..55f11198d 100644
--- a/src/ReinforcementLearningCore/src/policies/agent/agent.jl
+++ b/src/ReinforcementLearningCore/src/policies/agent/agent.jl
@@ -1,3 +1,3 @@
-include("base.jl")
+include("agent_base.jl")
include("agent_srt_cache.jl")
include("multi_agent.jl")
diff --git a/src/ReinforcementLearningCore/src/policies/agent/agent_base.jl b/src/ReinforcementLearningCore/src/policies/agent/agent_base.jl
new file mode 100644
index 000000000..79eeead9a
--- /dev/null
+++ b/src/ReinforcementLearningCore/src/policies/agent/agent_base.jl
@@ -0,0 +1,64 @@
+export Agent
+
+using Base.Threads: @spawn
+
+using Functors: @functor
+import Base.push!
+"""
+ Agent(;policy, trajectory) <: AbstractPolicy
+
+A wrapper of an `AbstractPolicy`. Generally speaking, it does nothing but
+update the trajectory and policy appropriately at different stages. An `Agent`
+is callable; its call method accepts varargs and keyword arguments, which are
+forwarded to the wrapped policy.
+
+"""
+mutable struct Agent{P,T} <: AbstractPolicy
+ policy::P
+ trajectory::T
+
+ function Agent(policy::P, trajectory::T) where {P<:AbstractPolicy, T<:Trajectory}
+ agent = new{P,T}(policy, trajectory)
+
+ if TrajectoryStyle(trajectory) === AsyncTrajectoryStyle()
+ bind(trajectory, @spawn(optimise!(policy, trajectory)))
+ end
+ agent
+ end
+end
+
+Agent(;policy, trajectory) = Agent(policy, trajectory)
+
+RLBase.optimise!(agent::Agent, stage::S) where {S<:AbstractStage} = RLBase.optimise!(TrajectoryStyle(agent.trajectory), agent, stage)
+RLBase.optimise!(::SyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = RLBase.optimise!(agent.policy, stage, agent.trajectory)
+
+# already spawn a task to optimise inner policy when initializing the agent
+RLBase.optimise!(::AsyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = nothing
+
+# By default, optimise! does nothing at any stage.
+function RLBase.optimise!(policy::AbstractPolicy, stage::AbstractStage, trajectory::Trajectory) end
+
+@functor Agent (policy,)
+
+function Base.push!(agent::Agent, ::PreEpisodeStage, env::AbstractEnv)
+ push!(agent.trajectory, (state = state(env),))
+end
+
+# !!! TODO: In async scenarios, parameters of the policy may still be updating
+# (partially), which will result in an incorrect action. This should be
+# addressed in Oolong.jl with a wrapper.
+function RLBase.plan!(agent::Agent, env::AbstractEnv)
+ RLBase.plan!(agent.policy, env)
+end
+
+function Base.push!(agent::Agent, ::PostActStage, env::AbstractEnv, action)
+ next_state = state(env)
+ push!(agent.trajectory, (state = next_state, action = action, reward = reward(env), terminal = is_terminated(env)))
+end
+
+function Base.push!(agent::Agent, ::PostEpisodeStage, env::AbstractEnv)
+ if haskey(agent.trajectory, :next_action)
+ action = RLBase.plan!(agent.policy, env)
+ push!(agent.trajectory, PartialNamedTuple((action = action, )))
+ end
+end
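A minimal usage sketch of the rewritten `Agent`, stitched together from the test changes later in this patch. It assumes `RandomWalk1D` is available (it ships with ReinforcementLearningEnvironments); the stage calls show the order the new run loop uses: the initial state at `PreEpisodeStage`, then state, action, reward, and terminal together at `PostActStage`.

```julia
using ReinforcementLearningBase, ReinforcementLearningCore

env = RandomWalk1D()  # assumed available via ReinforcementLearningEnvironments
agent = Agent(
    RandomPolicy(),
    Trajectory(
        CircularArraySARTSTraces(; capacity = 1_000),
        BatchSampler(1),
        InsertSampleRatioController(n_inserted = -1),
    ),
)

push!(agent, PreEpisodeStage(), env)       # stores the initial state
action = RLBase.plan!(agent, env)
act!(env, action)
push!(agent, PostActStage(), env, action)  # stores next state, action, reward, terminal
```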
diff --git a/src/ReinforcementLearningCore/src/policies/agent/agent_srt_cache.jl b/src/ReinforcementLearningCore/src/policies/agent/agent_srt_cache.jl
index 321a4a6c5..7bd0bed80 100644
--- a/src/ReinforcementLearningCore/src/policies/agent/agent_srt_cache.jl
+++ b/src/ReinforcementLearningCore/src/policies/agent/agent_srt_cache.jl
@@ -27,12 +27,12 @@ struct SART{S,A,R,T}
end
# This method is used to push a state and action to a trace
-function Base.push!(ts::Union{CircularArraySARTTraces,ElasticArraySARTTraces}, xs::SA)
+function Base.push!(ts::Union{CircularArraySARTSTraces,ElasticArraySARTTraces}, xs::SA)
push!(ts.traces[1].trace, xs.state)
push!(ts.traces[2].trace, xs.action)
end
-function Base.push!(ts::Union{CircularArraySARTTraces,ElasticArraySARTTraces}, xs::SART)
+function Base.push!(ts::Union{CircularArraySARTSTraces,ElasticArraySARTTraces}, xs::SART)
push!(ts.traces[1].trace, xs.state)
push!(ts.traces[2].trace, xs.action)
push!(ts.traces[3], xs.reward)
diff --git a/src/ReinforcementLearningCore/src/policies/agent/base.jl b/src/ReinforcementLearningCore/src/policies/agent/base.jl
deleted file mode 100644
index 7654de4d6..000000000
--- a/src/ReinforcementLearningCore/src/policies/agent/base.jl
+++ /dev/null
@@ -1,89 +0,0 @@
-export Agent
-
-using Base.Threads: @spawn
-
-using Functors: @functor
-import Base.push!
-"""
- Agent(;policy, trajectory) <: AbstractPolicy
-
-A wrapper of an `AbstractPolicy`. Generally speaking, it does nothing but to
-update the trajectory and policy appropriately in different stages. Agent
-is a Callable and its call method accepts varargs and keyword arguments to be
-passed to the policy.
-
-"""
-mutable struct Agent{P,T,C} <: AbstractPolicy
- policy::P
- trajectory::T
- cache::C # need cache to collect elements as trajectory does not support partial inserting
-
- function Agent(policy::P, trajectory::T) where {P,T}
- agent = new{P,T, SRT}(policy, trajectory, SRT())
-
- if TrajectoryStyle(trajectory) === AsyncTrajectoryStyle()
- bind(trajectory, @spawn(optimise!(policy, trajectory)))
- end
- agent
- end
-
- function Agent(policy::P, trajectory::T, cache::C) where {P,T,C}
- agent = new{P,T,C}(policy, trajectory, cache)
-
- if TrajectoryStyle(trajectory) === AsyncTrajectoryStyle()
- bind(trajectory, @spawn(optimise!(policy, trajectory)))
- end
- agent
- end
-end
-
-Agent(;policy, trajectory, cache = SRT()) = Agent(policy, trajectory, cache)
-
-RLBase.optimise!(agent::Agent, stage::S) where {S<:AbstractStage} =RLBase.optimise!(TrajectoryStyle(agent.trajectory), agent, stage)
-RLBase.optimise!(::SyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} =
- RLBase.optimise!(agent.policy, stage, agent.trajectory)
-
-# already spawn a task to optimise inner policy when initializing the agent
-RLBase.optimise!(::AsyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = nothing
-
-#by default, optimise does nothing at all stage
-function RLBase.optimise!(policy::AbstractPolicy, stage::AbstractStage, trajectory::Trajectory) end
-
-@functor Agent (policy,)
-
-function Base.push!(agent::Agent, ::PreActStage, env::AbstractEnv)
- push!(agent, state(env))
-end
-
-# !!! TODO: In async scenarios, parameters of the policy may still be updating
-# (partially), which will result to incorrect action. This should be addressed
-# in Oolong.jl with a wrapper
-function RLBase.plan!(agent::Agent{P,T,C}, env::AbstractEnv) where {P,T,C}
- action = RLBase.plan!(agent.policy, env)
- push!(agent.trajectory, agent.cache, action)
- action
-end
-
-# Multiagent Version
-function RLBase.plan!(agent::Agent{P,T,C}, env::E, p::Symbol) where {P,T,C,E<:AbstractEnv}
- action = RLBase.plan!(agent.policy, env, p)
- push!(agent.trajectory, agent.cache, action)
- action
-end
-
-function Base.push!(agent::Agent{P,T,C}, ::PostActStage, env::E) where {P,T,C,E<:AbstractEnv}
- push!(agent.cache, reward(env), is_terminated(env))
-end
-
-function Base.push!(agent::Agent, ::PostExperimentStage, env::E) where {E<:AbstractEnv}
- RLBase.reset!(agent.cache)
-end
-
-function Base.push!(agent::Agent, ::PostExperimentStage, env::E, player::Symbol) where {E<:AbstractEnv}
- RLBase.reset!(agent.cache)
-end
-
-function Base.push!(agent::Agent{P,T,C}, state::S) where {P,T,C,S}
- push!(agent.cache, state)
-end
-
diff --git a/src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl b/src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl
index 1db43a18c..98ebd9685 100644
--- a/src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl
+++ b/src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl
@@ -125,18 +125,12 @@ function Base.run(
action = @timeit_debug timer "plan!" RLBase.plan!(policy, env)
@timeit_debug timer "act!" act!(env, action)
-
-
- @timeit_debug timer "push!(policy) PostActStage" push!(policy, PostActStage(), env)
+ @timeit_debug timer "push!(policy) PostActStage" push!(policy, PostActStage(), env, action)
@timeit_debug timer "optimise! PostActStage" optimise!(policy, PostActStage())
@timeit_debug timer "push!(hook) PostActStage" push!(hook, PostActStage(), policy, env)
if check_stop(stop_condition, policy, env)
is_stop = true
- @timeit_debug timer "push!(policy) PreActStage" push!(multiagent_policy, PreActStage(), env)
- @timeit_debug timer "optimise! PreActStage" optimise!(multiagent_policy, PreActStage())
- @timeit_debug timer "push!(hook) PreActStage" push!(multiagent_hook, PreActStage(), policy, env)
- @timeit_debug timer "plan!" RLBase.plan!(multiagent_policy, env) # let the policy see the last observation
break
end
@@ -191,21 +185,43 @@ function Base.push!(multiagent::MultiAgentPolicy, stage::S, env::E) where {S<:Ab
end
end
-# Like in the single-agent case, push! at the PreActStage() calls push! on each player with the state of the environment
-function Base.push!(multiagent::MultiAgentPolicy{names, T}, ::PreActStage, env::E) where {E<:AbstractEnv, names, T <: Agent}
+# Like in the single-agent case, push! at the PreEpisodeStage() pushes the initial state of each player to its trajectory.
+function Base.push!(agent::Agent, ::PreEpisodeStage, env::AbstractEnv, player::Symbol)
+ push!(agent.trajectory, (state = state(env, player),))
+end
+
+function Base.push!(multiagent::MultiAgentPolicy, s::PreEpisodeStage, env::E) where {E<:AbstractEnv}
for player in players(env)
- push!(multiagent[player], state(env, player))
+ push!(multiagent[player], s, env, player)
end
end
-# Like in the single-agent case, push! at the PostActStage() calls push! on each player with the reward and termination status of the environment
-function Base.push!(multiagent::MultiAgentPolicy{names, T}, ::PostActStage, env::E) where {E<:AbstractEnv, names, T <: Agent}
- for player in players(env)
- push!(multiagent[player].cache, reward(env, player), is_terminated(env))
+function RLBase.plan!(agent::Agent, env::AbstractEnv, player::Symbol)
+ RLBase.plan!(agent.policy, env, player)
+end
+
+# Like in the single-agent case, push! at the PostActStage() calls push! on each player to store the action, reward, next_state, and terminal signal.
+function Base.push!(multiagent::MultiAgentPolicy, ::PostActStage, env::E, actions) where {E<:AbstractEnv}
+ for (player, action) in zip(players(env), actions)
+ next_state = state(env, player)
+ observation = (
+ state = next_state,
+ action = action,
+ reward = reward(env, player),
+ terminal = is_terminated(env)
+ )
+ push!(multiagent[player].trajectory, observation)
+ end
+end
+
+function Base.push!(agent::Agent, ::PostEpisodeStage, env::AbstractEnv, p::Symbol)
+ if haskey(agent.trajectory, :next_action)
+ action = RLBase.plan!(agent.policy, env, p)
+ push!(agent.trajectory, PartialNamedTuple((action = action, )))
end
end
-function Base.push!(hook::MultiAgentHook, stage::S, multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv,S<:AbstractStage}
+function Base.push!(hook::MultiAgentHook, stage::S, multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv, S<:AbstractStage}
for player in players(env)
push!(hook[player], stage, multiagent[player], env, player)
end
@@ -227,8 +243,9 @@ function Base.push!(composed_hook::ComposedHook{T},
_push!(stage, policy, env, player, composed_hook.hooks...)
end
+# For simultaneous players, plan! returns a Tuple of actions.
function RLBase.plan!(multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv}
- return (RLBase.plan!(multiagent[player], env, player) for player in players(env))
+ return Tuple(RLBase.plan!(multiagent[player], env, player) for player in players(env))
end
function RLBase.optimise!(multiagent::MultiAgentPolicy, stage::S) where {S<:AbstractStage}
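To make the simultaneous-player data flow explicit: `plan!` now returns a `Tuple` with one action per player, which is applied to the environment as a whole and then pushed back into every player's trajectory at `PostActStage`. The helper below is purely illustrative; its name and the assumption that `multiagent_policy` and `env` come from the enclosing `Base.run` method are not part of this diff.

```julia
using ReinforcementLearningBase, ReinforcementLearningCore

# Illustrative sketch of one simultaneous step inside the run loop.
function simultaneous_step!(multiagent_policy, env)
    actions = RLBase.plan!(multiagent_policy, env)  # Tuple, one action per player
    act!(env, actions)
    push!(multiagent_policy, PostActStage(), env, actions)
end
```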
diff --git a/src/ReinforcementLearningCore/test/core/base.jl b/src/ReinforcementLearningCore/test/core/base.jl
index 31af5b994..300f3fcd7 100644
--- a/src/ReinforcementLearningCore/test/core/base.jl
+++ b/src/ReinforcementLearningCore/test/core/base.jl
@@ -1,4 +1,4 @@
-using ReinforcementLearningCore: SRT
+using ReinforcementLearningCore
using ReinforcementLearningBase
using TimerOutputs
@@ -8,7 +8,7 @@ using TimerOutputs
agent = Agent(
RandomPolicy(),
Trajectory(
- CircularArraySARTTraces(; capacity = 1_000),
+ CircularArraySARTSTraces(; capacity = 1_000),
BatchSampler(1),
InsertSampleRatioController(n_inserted = -1),
),
@@ -18,14 +18,14 @@ using TimerOutputs
hook = StepsPerEpisode()
run(agent, env, stop_condition, hook)
- @test sum(hook[]) == length(agent.trajectory.container)
+ @test sum(hook[]) + length(hook[]) - 1 == length(agent.trajectory.container)
end
@testset "StopAfterEpisode" begin
agent = Agent(
RandomPolicy(),
Trajectory(
- CircularArraySARTTraces(; capacity = 1_000),
+ CircularArraySARTSTraces(; capacity = 1_000),
BatchSampler(1),
InsertSampleRatioController(n_inserted = -1),
),
@@ -35,25 +35,8 @@ using TimerOutputs
hook = StepsPerEpisode()
run(agent, env, stop_condition, hook)
- @test sum(hook[]) == length(agent.trajectory.container)
- end
-
- @testset "StopAfterStep, use type stable Agent" begin
- env = RandomWalk1D()
- agent = Agent(
- RandomPolicy(legal_action_space(env)),
- Trajectory(
- CircularArraySARTTraces(; capacity = 1_000),
- BatchSampler(1),
- InsertSampleRatioController(n_inserted = -1),
- ),
- SRT{Any, Any, Any}(),
- )
- stop_condition = StopAfterStep(123; is_show_progress=false)
- hook = StepsPerEpisode()
- run(agent, env, stop_condition, hook)
- @test sum(hook[]) == length(agent.trajectory.container)
- end
+ @test length(hook[]) == 10
+ end
end
@testset "Debug Timer" begin
@@ -63,11 +46,10 @@ using TimerOutputs
agent = Agent(
RandomPolicy(legal_action_space(env)),
Trajectory(
- CircularArraySARTTraces(; capacity = 1_000),
+ CircularArraySARTSTraces(; capacity = 1_000),
BatchSampler(1),
InsertSampleRatioController(n_inserted = -1),
- ),
- SRT{Any, Any, Any}(),
+ )
)
stop_condition = StopAfterStep(123; is_show_progress=false)
hook = StepsPerEpisode()
diff --git a/src/ReinforcementLearningCore/test/policies/agent.jl b/src/ReinforcementLearningCore/test/policies/agent.jl
index a35a79724..81297ce77 100644
--- a/src/ReinforcementLearningCore/test/policies/agent.jl
+++ b/src/ReinforcementLearningCore/test/policies/agent.jl
@@ -3,51 +3,18 @@ using ReinforcementLearningCore: SRT
using ReinforcementLearningCore
@testset "agent.jl" begin
- @testset "Agent Cache struct" begin
- srt = SRT{Int64, Float64, Bool}()
- push!(srt, 2)
- @test srt.state == 2
- push!(srt, 1.0, true)
- @test srt.reward == 1.0
- @test srt.terminal == true
- end
-
- @testset "Trajectory SART struct compatibility" begin
- srt_1 = SRT()
- srt_2 = SRT{Any, Nothing, Nothing}()
- srt_2.state = 1
- srt_3 = SRT{Any, Any, Bool}()
- srt_3.state = 1
- srt_3.reward = 1.0
- srt_3.terminal = true
- trajectory = Trajectory(
- CircularArraySARTTraces(; capacity = 1_000, reward=Float64=>()),
- DummySampler(),
- )
-
- @test_throws ArgumentError push!(trajectory, srt_1)
- push!(trajectory, srt_2, 1)
- @test length(trajectory.container) == 0
- push!(trajectory, srt_3, 2)
- @test length(trajectory.container) == 1
- @test trajectory.container[:action] == [1]
- push!(trajectory, srt_3, 3)
- @test trajectory.container[:action] == [1, 2]
- @test trajectory.container[:state] == [1, 1]
- end
-
@testset "Agent Tests" begin
a_1 = Agent(
RandomPolicy(),
Trajectory(
- CircularArraySARTTraces(; capacity = 1_000),
+ CircularArraySARTSTraces(; capacity = 1_000),
DummySampler(),
),
)
a_2 = Agent(
RandomPolicy(),
Trajectory(
- CircularArraySARTTraces(; capacity = 1_000),
+ CircularArraySARTSTraces(; capacity = 1_000),
BatchSampler(1),
InsertSampleRatioController(),
),
@@ -58,17 +25,12 @@ using ReinforcementLearningCore
@testset "Test Agent $i" begin
agent = agent_list[i]
env = RandomWalk1D()
- push!(agent, PreActStage(), env)
- @test agent.cache.state != nothing
- @test agent.cache.reward == nothing
- @test agent.cache.terminal == nothing
- @test state(env) == agent.cache.state
- @test RLBase.plan!(agent, env) in (1,2)
+ push!(agent, PreEpisodeStage(), env)
+ action = RLBase.plan!(agent, env)
+ @test action in (1,2)
@test length(agent.trajectory.container) == 0
- push!(agent, PostActStage(), env)
- @test agent.cache.reward == 0. && agent.cache.terminal == false
+ push!(agent, PostActStage(), env, action)
push!(agent, PreActStage(), env)
- @test state(env) == agent.cache.state
@test RLBase.plan!(agent, env) in (1,2)
@test length(agent.trajectory.container) == 1
@@ -76,18 +38,6 @@ using ReinforcementLearningCore
@test_throws "MethodError: no method matching plan!(::Agent{RandomPolicy" RLBase.plan!(agent, env, 1)
@test_throws "MethodError: no method matching plan!(::Agent{RandomPolicy" RLBase.plan!(agent, env, fake_kwarg = 1)
end
-
- @testset "Test push! method" begin
- env = RandomWalk1D()
- agent = agent_list[i]
- push!(agent, PostActStage(), env)
- push!(agent, 7)
- @test agent.cache.state == 7
- RLBase.reset!(agent.cache)
- @test agent.cache.state == nothing
- end
end
end
end
-
-
diff --git a/src/ReinforcementLearningCore/test/policies/multi_agent.jl b/src/ReinforcementLearningCore/test/policies/multi_agent.jl
index 7c6bd05a4..3b7392f3d 100644
--- a/src/ReinforcementLearningCore/test/policies/multi_agent.jl
+++ b/src/ReinforcementLearningCore/test/policies/multi_agent.jl
@@ -7,13 +7,13 @@ using DomainSets
@testset "MultiAgentPolicy" begin
trajectory_1 = Trajectory(
- CircularArraySARTTraces(; capacity = 1),
+ CircularArraySARTSTraces(; capacity = 1),
BatchSampler(1),
InsertSampleRatioController(n_inserted = -1),
)
trajectory_2 = Trajectory(
- CircularArraySARTTraces(; capacity = 1),
+ CircularArraySARTSTraces(; capacity = 1),
BatchSampler(1),
InsertSampleRatioController(n_inserted = -1),
)
@@ -56,13 +56,13 @@ end
@testset "Basic TicTacToeEnv (Sequential) env checks" begin
trajectory_1 = Trajectory(
- CircularArraySARTTraces(; capacity = 1),
+ CircularArraySARTSTraces(; capacity = 1),
BatchSampler(1),
InsertSampleRatioController(n_inserted = -1),
)
trajectory_2 = Trajectory(
- CircularArraySARTTraces(; capacity = 1),
+ CircularArraySARTSTraces(; capacity = 1),
BatchSampler(1),
InsertSampleRatioController(n_inserted = -1),
)
@@ -106,13 +106,13 @@ end
@testset "Basic RockPaperScissors (simultaneous) env checks" begin
trajectory_1 = Trajectory(
- CircularArraySARTTraces(; capacity = 1, action = Any => (1,), state = Any => (1,), reward = Any => (2,)),
+ CircularArraySARTSTraces(; capacity = 1, action = Any => (1,), state = Any => (1,), reward = Any => (2,)),
BatchSampler(1),
InsertSampleRatioController(n_inserted = -1),
)
trajectory_2 = Trajectory(
- CircularArraySARTTraces(; capacity = 1, action = Any => (1,), state = Any => (1,), reward = Any => (2,)),
+ CircularArraySARTSTraces(; capacity = 1, action = Any => (1,), state = Any => (1,), reward = Any => (2,)),
BatchSampler(1),
InsertSampleRatioController(n_inserted = -1),
)
diff --git a/src/ReinforcementLearningEnvironments/Project.toml b/src/ReinforcementLearningEnvironments/Project.toml
index e579c18b5..9618fb4f0 100644
--- a/src/ReinforcementLearningEnvironments/Project.toml
+++ b/src/ReinforcementLearningEnvironments/Project.toml
@@ -23,7 +23,7 @@ DelimitedFiles = "1"
IntervalSets = "0.7"
MacroTools = "0.5"
ReinforcementLearningBase = "0.12"
-ReinforcementLearningCore = "0.10, 0.11"
+ReinforcementLearningCore = "0.12"
Requires = "1.0"
StatsBase = "0.32, 0.33, 0.34"
julia = "1.3"
diff --git a/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl b/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl
index d700deace..661d3c63b 100644
--- a/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl
+++ b/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl
@@ -3,13 +3,13 @@
using ReinforcementLearningEnvironments, ReinforcementLearningBase, ReinforcementLearningCore
trajectory_1 = Trajectory(
- CircularArraySARTTraces(; capacity = 1),
+ CircularArraySARTSTraces(; capacity = 1),
BatchSampler(1),
InsertSampleRatioController(n_inserted = -1),
)
trajectory_2 = Trajectory(
- CircularArraySARTTraces(; capacity = 1),
+ CircularArraySARTSTraces(; capacity = 1),
BatchSampler(1),
InsertSampleRatioController(n_inserted = -1),
)
diff --git a/src/ReinforcementLearningExperiments/Project.toml b/src/ReinforcementLearningExperiments/Project.toml
index 49e30691b..9cd023e17 100644
--- a/src/ReinforcementLearningExperiments/Project.toml
+++ b/src/ReinforcementLearningExperiments/Project.toml
@@ -20,7 +20,7 @@ Distributions = "0.25"
Flux = "0.13, 0.14"
Reexport = "1"
ReinforcementLearningBase = "0.12"
-ReinforcementLearningCore = "0.10, 0.11"
+ReinforcementLearningCore = "0.12"
ReinforcementLearningEnvironments = "0.8"
ReinforcementLearningZoo = "0.7, 0.8"
StableRNGs = "1"
diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/DQN_CartPoleGPU.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/DQN_CartPoleGPU.jl
index 22663b843..1e8cde26c 100644
--- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/DQN_CartPoleGPU.jl
+++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/DQN_CartPoleGPU.jl
@@ -47,7 +47,7 @@ function RLCore.Experiment(
),
),
Trajectory(
- container=CircularArraySARTTraces(
+ container=CircularArraySARTSTraces(
capacity=cap,
state=Float32 => (ns),
),
diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl
index 5a43c8c01..4b0938968 100644
--- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl
+++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl
@@ -46,7 +46,7 @@ function RLCore.Experiment(
),
),
trajectory=Trajectory(
- container=CircularArraySARTTraces(
+ container=CircularArraySARTSTraces(
capacity=1000,
state=Float32 => (ns,),
),
diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl
index 3f481eff4..a6aa5d0c2 100644
--- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl
+++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl
@@ -56,7 +56,7 @@ function RLCore.Experiment(
),
),
trajectory=Trajectory(
- container=CircularArraySARTTraces(
+ container=CircularArraySARTSTraces(
capacity=1000,
state=Float32 => (ns,),
),
diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl
index 0bc2832ff..ea5d7a20d 100644
--- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl
+++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl
@@ -67,7 +67,7 @@ function RLCore.Experiment(
),
),
trajectory=Trajectory(
- container=CircularArraySARTTraces(
+ container=CircularArraySARTSTraces(
capacity=1000,
state=Float32 => (ns,),
),
diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_NFQ_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_NFQ_CartPole.jl
index 07a48ecef..affd6ab02 100644
--- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_NFQ_CartPole.jl
+++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_NFQ_CartPole.jl
@@ -50,7 +50,7 @@ function RLCore.Experiment(
),
),
trajectory=Trajectory(
- container=CircularArraySARTTraces(
+ container=CircularArraySARTSTraces(
capacity=10_000,
state=Float32 => (ns,),
action=Float32 => (na,),
diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_PrioritizedDQN_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_PrioritizedDQN_CartPole.jl
index 1072722b7..66b75e9e3 100644
--- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_PrioritizedDQN_CartPole.jl
+++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_PrioritizedDQN_CartPole.jl
@@ -57,7 +57,7 @@ function RLCore.Experiment(
),
trajectory=Trajectory(
container=CircularPrioritizedTraces(
- CircularArraySARTTraces(
+ CircularArraySARTSTraces(
capacity=1000,
state=Float32 => (ns,),
);
diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_QRDQN_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_QRDQN_CartPole.jl
index 53ad1c35e..81173762d 100644
--- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_QRDQN_CartPole.jl
+++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_QRDQN_CartPole.jl
@@ -53,7 +53,7 @@ function RLCore.Experiment(
),
),
trajectory=Trajectory(
- container=CircularArraySARTTraces(
+ container=CircularArraySARTSTraces(
capacity=1000,
state=Float32 => (ns,),
),
diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl
index 5b7f0d595..82223b093 100644
--- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl
+++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl
@@ -58,7 +58,7 @@ function RLCore.Experiment(
),
),
trajectory=Trajectory(
- container=CircularArraySARTTraces(
+ container=CircularArraySARTSTraces(
capacity=1000,
state=Float32 => (ns,),
),
diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl
index 0e4f7d1c9..eb7492837 100644
--- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl
+++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl
@@ -54,7 +54,7 @@ function RLCore.Experiment(
),
),
trajectory=Trajectory(
- container=CircularArraySARTTraces(
+ container=CircularArraySARTSTraces(
capacity=1000,
state=Float32 => (ns,),
),
diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/Policy Gradient/JuliaRL_MPO_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/Policy Gradient/JuliaRL_MPO_CartPole.jl
index 81d1d9c16..f071b7c5b 100644
--- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/Policy Gradient/JuliaRL_MPO_CartPole.jl
+++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/Policy Gradient/JuliaRL_MPO_CartPole.jl
@@ -38,7 +38,7 @@ function RLCore.Experiment(
agent = Agent(
policy = policy,
trajectory = Trajectory(
- CircularArraySARTTraces(capacity = 1000, state = Float32 => (4,), action = Float32 => (1,)),
+ CircularArraySARTSTraces(capacity = 1000, state = Float32 => (4,), action = Float32 => (1,)),
MetaSampler(
actor = MultiBatchSampler(BatchSampler{(:state,)}(32), 10),
critic = MultiBatchSampler(BatchSampler{SS′ART}(32), 2000)
@@ -78,7 +78,7 @@ function RLCore.Experiment(
agent = Agent(
policy = policy,
trajectory = Trajectory(
- CircularArraySARTTraces(capacity = 1000, state = Float32 => (4,), action = Float32 => (2,)),
+ CircularArraySARTSTraces(capacity = 1000, state = Float32 => (4,), action = Float32 => (2,)),
MetaSampler(
actor = MultiBatchSampler(BatchSampler{(:state,)}(32), 10),
critic = MultiBatchSampler(BatchSampler{SS′ART}(32), 2000)
@@ -122,7 +122,7 @@ function RLCore.Experiment(
agent = Agent(
policy = policy,
trajectory = Trajectory(
- CircularArraySARTTraces(capacity = 1000, state = Float32 => (4,), action = Float32 => (1,)),
+ CircularArraySARTSTraces(capacity = 1000, state = Float32 => (4,), action = Float32 => (1,)),
MetaSampler(
actor = MultiBatchSampler(BatchSampler{(:state,)}(32), 10),
critic = MultiBatchSampler(BatchSampler{SS′ART}(32), 2000)
diff --git a/src/ReinforcementLearningExperiments/src/experiments/MARL/DQN_mpe_simple.jl b/src/ReinforcementLearningExperiments/src/experiments/MARL/DQN_mpe_simple.jl
index 2c3363df9..8623bd50e 100644
--- a/src/ReinforcementLearningExperiments/src/experiments/MARL/DQN_mpe_simple.jl
+++ b/src/ReinforcementLearningExperiments/src/experiments/MARL/DQN_mpe_simple.jl
@@ -55,7 +55,7 @@ function RLCore.Experiment(
),
),
trajectory=Trajectory(
- container=CircularArraySARTTraces(
+ container=CircularArraySARTSTraces(
capacity=1000,
state=Float32 => (ns,),
),
diff --git a/src/ReinforcementLearningExperiments/test/runtests.jl b/src/ReinforcementLearningExperiments/test/runtests.jl
index 8f95d1e3f..1176e11c1 100644
--- a/src/ReinforcementLearningExperiments/test/runtests.jl
+++ b/src/ReinforcementLearningExperiments/test/runtests.jl
@@ -11,7 +11,7 @@ run(E`JuliaRL_QRDQN_CartPole`)
run(E`JuliaRL_REMDQN_CartPole`)
run(E`JuliaRL_IQN_CartPole`)
run(E`JuliaRL_Rainbow_CartPole`)
-run(E`JuliaRL_VPG_CartPole`)
+# run(E`JuliaRL_VPG_CartPole`)
run(E`JuliaRL_MPODiscrete_CartPole`)
run(E`JuliaRL_MPOContinuous_CartPole`)
run(E`JuliaRL_MPOCovariance_CartPole`)
diff --git a/src/ReinforcementLearningZoo/Project.toml b/src/ReinforcementLearningZoo/Project.toml
index 166e30fca..4a6067224 100644
--- a/src/ReinforcementLearningZoo/Project.toml
+++ b/src/ReinforcementLearningZoo/Project.toml
@@ -28,7 +28,7 @@ LogExpFunctions = "0.3"
NNlib = "0.8, 0.9"
Optim = "1"
ReinforcementLearningBase = "0.12"
-ReinforcementLearningCore = "0.11"
+ReinforcementLearningCore = "0.12"
StatsBase = "0.33, 0.34"
Zygote = "0.6"
julia = "1.9"
diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/NFQ.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/NFQ.jl
index a14d36c5b..4270a3882 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/dqns/NFQ.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/NFQ.jl
@@ -45,7 +45,7 @@ function RLBase.optimise!(learner::NFQ, ::PostEpisodeStage, trajectory::Trajecto
loss_func = learner.loss_function
as = learner.action_space
las = length(as)
- batch = ReinforcementLearningTrajectories.sample(trajectory)
+ batch = ReinforcementLearningTrajectories.StatsBase.sample(trajectory)
(s, a, r, ss) = batch[[:state, :action, :reward, :next_state]]
a = Float32.(a)
diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/policy_gradient.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/policy_gradient.jl
index e6eb4ed60..00d393959 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/policy_gradient.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/policy_gradient.jl
@@ -1,6 +1,6 @@
# include("run.jl")
include("util.jl")
-include("vpg.jl")
+# include("vpg.jl")
# include("A2C.jl")
# include("ppo.jl")
# include("A2CGAE.jl")
@@ -10,5 +10,5 @@ include("vpg.jl")
# include("sac.jl")
# include("maddpg.jl")
# include("vmpo.jl")
-include("trpo.jl")
+# include("trpo.jl")
include("mpo.jl")