diff --git a/docs/homepage/blog/a_practical_introduction_to_RL.jl/index.html b/docs/homepage/blog/a_practical_introduction_to_RL.jl/index.html
index 2e5a7bcad..46c54f9ea 100644
--- a/docs/homepage/blog/a_practical_introduction_to_RL.jl/index.html
+++ b/docs/homepage/blog/a_practical_introduction_to_RL.jl/index.html
@@ -14804,7 +14804,7 @@
 (rendered notebook HTML, section "Understand the Trajectories", cell In [28])
-t = Trajectories.CircularArraySARTTraces(;capacity=10)
+t = Trajectories.CircularArraySARTSTraces(;capacity=10)
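A note on the rename in this hunk (and repeated throughout the patch): `CircularArraySARTTraces` becomes `CircularArraySARTSTraces`, the added `S` reflecting that the successor state is tracked alongside `:state`, `:action`, `:reward`, `:terminal`. As a quick illustrative sketch — not taken from the patch — here is the renamed container wrapped in a `Trajectory`, using the same keyword style that appears in the MPO docs and experiment files below; the sampler and controller are arbitrary choices for the example:

```julia
# Illustrative only (not part of the patch): the renamed container inside a Trajectory.
# Names come from ReinforcementLearningTrajectories (^0.3.2, per the Project.toml bump
# below), re-exported by ReinforcementLearningCore.
using ReinforcementLearningCore

traj = Trajectory(
    CircularArraySARTSTraces(
        capacity = 1000,             # fixed-length circular buffer
        state    = Float32 => (4,),  # eltype => per-step size
        action   = Float32 => (1,),
    ),
    BatchSampler(32),                # example sampler
    InsertSampleRatioController(),   # example controller
)
```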
diff --git a/docs/src/How_to_implement_a_new_algorithm.md b/docs/src/How_to_implement_a_new_algorithm.md index 9c05be8ac..ddfe492e4 100644 --- a/docs/src/How_to_implement_a_new_algorithm.md +++ b/docs/src/How_to_implement_a_new_algorithm.md @@ -94,7 +94,7 @@ A `Trajectory` is composed of three elements: a `container`, a `controller`, and The container is typically an `AbstractTraces`, an object that store a set of `Trace` in a structured manner. You can either define your own (and contribute it to the package if it is likely to be usable for other algorithms), or use a predefined one if it exists. -The most common `AbstractTraces` object is the `CircularArraySARTTraces`, this is a container of a fixed length that stores the following traces: `:state` (S), `:action` (A), `:reward` (R), `:terminal` (T), which toghether are aliased to `SART = (:state, :action, :reward, :terminal)`. Let us see how it is constructed in this simplified version as an example of how to build a custom trace. +The most common `AbstractTraces` object is the `CircularArraySARTSTraces`, this is a container of a fixed length that stores the following traces: `:state` (S), `:action` (A), `:reward` (R), `:terminal` (T), which together are aliased to `SART = (:state, :action, :reward, :terminal)`. Let us see how it is constructed in this simplified version as an example of how to build a custom trace. ```julia function (capacity, state_size, state_eltype, action_size, action_eltype, reward_eltype) diff --git a/docs/src/Zoo_Algorithms/MPO.md b/docs/src/Zoo_Algorithms/MPO.md index 52327b27f..2e22e18bb 100644 --- a/docs/src/Zoo_Algorithms/MPO.md +++ b/docs/src/Zoo_Algorithms/MPO.md @@ -56,7 +56,7 @@ The next step is to wrap this policy into an `Agent`. An agent is a combination ```julia trajectory = Trajectory( - CircularArraySARTTraces(capacity = 1000, state = Float32 => (4,),action = Float32 => (1,)), + CircularArraySARTSTraces(capacity = 1000, state = Float32 => (4,), action = Float32 => (1,)), MetaSampler( actor = MultiBatchSampler(BatchSampler{(:state,)}(32), 10), critic = MultiBatchSampler(BatchSampler{SS′ART}(32), 1000) diff --git a/src/ReinforcementLearningCore/Project.toml b/src/ReinforcementLearningCore/Project.toml index 36d8a5ec4..b815399d9 100644 --- a/src/ReinforcementLearningCore/Project.toml +++ b/src/ReinforcementLearningCore/Project.toml @@ -1,6 +1,6 @@ name = "ReinforcementLearningCore" uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6" -version = "0.11.3" +version = "0.12.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" @@ -40,7 +40,7 @@ Parsers = "2" ProgressMeter = "1" Reexport = "1" ReinforcementLearningBase = "0.12" -ReinforcementLearningTrajectories = "^0.1.9" +ReinforcementLearningTrajectories = "^0.3.2" StatsBase = "0.32, 0.33, 0.34" TimerOutputs = "0.5" UnicodePlots = "1.3, 2, 3" diff --git a/src/ReinforcementLearningCore/src/ReinforcementLearningCore.jl b/src/ReinforcementLearningCore/src/ReinforcementLearningCore.jl index df71cd89d..fdf9107a9 100644 --- a/src/ReinforcementLearningCore/src/ReinforcementLearningCore.jl +++ b/src/ReinforcementLearningCore/src/ReinforcementLearningCore.jl @@ -3,7 +3,6 @@ module ReinforcementLearningCore using TimerOutputs using ReinforcementLearningBase using Reexport - const RLCore = ReinforcementLearningCore export RLCore diff --git a/src/ReinforcementLearningCore/src/core/run.jl b/src/ReinforcementLearningCore/src/core/run.jl index 67900775b..e325fb0ef 100644 --- a/src/ReinforcementLearningCore/src/core/run.jl +++ 
b/src/ReinforcementLearningCore/src/core/run.jl @@ -102,21 +102,17 @@ function _run(policy::AbstractPolicy, action = @timeit_debug timer "plan!" RLBase.plan!(policy, env) @timeit_debug timer "act!" act!(env, action) - @timeit_debug timer "push!(policy) PostActStage" push!(policy, PostActStage(), env) + @timeit_debug timer "push!(policy) PostActStage" push!(policy, PostActStage(), env, action) @timeit_debug timer "optimise! PostActStage" optimise!(policy, PostActStage()) @timeit_debug timer "push!(hook) PostActStage" push!(hook, PostActStage(), policy, env) if check_stop(stop_condition, policy, env) is_stop = true - @timeit_debug timer "push!(policy) PreActStage" push!(policy, PreActStage(), env) - @timeit_debug timer "optimise! PreActStage" optimise!(policy, PreActStage()) - @timeit_debug timer "push!(hook) PreActStage" push!(hook, PreActStage(), policy, env) - @timeit_debug timer "plan!" RLBase.plan!(policy, env) # let the policy see the last observation break end end # end of an episode - @timeit_debug timer "push!(policy) PostEpisodeStage" push!(policy, PostEpisodeStage(), env) # let the policy see the last observation + @timeit_debug timer "push!(policy) PostEpisodeStage" push!(policy, PostEpisodeStage(), env) @timeit_debug timer "optimise! PostEpisodeStage" optimise!(policy, PostEpisodeStage()) @timeit_debug timer "push!(hook) PostEpisodeStage" push!(hook, PostEpisodeStage(), policy, env) diff --git a/src/ReinforcementLearningCore/src/core/stages.jl b/src/ReinforcementLearningCore/src/core/stages.jl index 52f2a30a9..61e48f57d 100644 --- a/src/ReinforcementLearningCore/src/core/stages.jl +++ b/src/ReinforcementLearningCore/src/core/stages.jl @@ -17,7 +17,9 @@ struct PreActStage <: AbstractStage end struct PostActStage <: AbstractStage end Base.push!(p::AbstractPolicy, ::AbstractStage, ::AbstractEnv) = nothing +Base.push!(p::AbstractPolicy, ::PostActStage, ::AbstractEnv, action) = nothing Base.push!(p::AbstractPolicy, ::AbstractStage, ::AbstractEnv, ::Symbol) = nothing +Base.push!(p::AbstractPolicy, ::PostActStage, ::AbstractEnv, action, ::Symbol) = nothing RLBase.optimise!(policy::P, ::S) where {P<:AbstractPolicy,S<:AbstractStage} = nothing diff --git a/src/ReinforcementLearningCore/src/policies/agent/agent.jl b/src/ReinforcementLearningCore/src/policies/agent/agent.jl index a4c7ceb8a..55f11198d 100644 --- a/src/ReinforcementLearningCore/src/policies/agent/agent.jl +++ b/src/ReinforcementLearningCore/src/policies/agent/agent.jl @@ -1,3 +1,3 @@ -include("base.jl") +include("agent_base.jl") include("agent_srt_cache.jl") include("multi_agent.jl") diff --git a/src/ReinforcementLearningCore/src/policies/agent/agent_base.jl b/src/ReinforcementLearningCore/src/policies/agent/agent_base.jl new file mode 100644 index 000000000..79eeead9a --- /dev/null +++ b/src/ReinforcementLearningCore/src/policies/agent/agent_base.jl @@ -0,0 +1,64 @@ +export Agent + +using Base.Threads: @spawn + +using Functors: @functor +import Base.push! +""" + Agent(;policy, trajectory) <: AbstractPolicy + +A wrapper of an `AbstractPolicy`. Generally speaking, it does nothing but to +update the trajectory and policy appropriately in different stages. Agent +is a Callable and its call method accepts varargs and keyword arguments to be +passed to the policy. 
+ +""" +mutable struct Agent{P,T} <: AbstractPolicy + policy::P + trajectory::T + + function Agent(policy::P, trajectory::T) where {P<:AbstractPolicy, T<:Trajectory} + agent = new{P,T}(policy, trajectory) + + if TrajectoryStyle(trajectory) === AsyncTrajectoryStyle() + bind(trajectory, @spawn(optimise!(policy, trajectory))) + end + agent + end +end + +Agent(;policy, trajectory) = Agent(policy, trajectory) + +RLBase.optimise!(agent::Agent, stage::S) where {S<:AbstractStage} = RLBase.optimise!(TrajectoryStyle(agent.trajectory), agent, stage) +RLBase.optimise!(::SyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = RLBase.optimise!(agent.policy, stage, agent.trajectory) + +# already spawn a task to optimise inner policy when initializing the agent +RLBase.optimise!(::AsyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = nothing + +#by default, optimise does nothing at all stage +function RLBase.optimise!(policy::AbstractPolicy, stage::AbstractStage, trajectory::Trajectory) end + +@functor Agent (policy,) + +function Base.push!(agent::Agent, ::PreEpisodeStage, env::AbstractEnv) + push!(agent.trajectory, (state = state(env),)) +end + +# !!! TODO: In async scenarios, parameters of the policy may still be updating +# (partially), which will result to incorrect action. This should be addressed +# in Oolong.jl with a wrapper +function RLBase.plan!(agent::Agent, env::AbstractEnv) + RLBase.plan!(agent.policy, env) +end + +function Base.push!(agent::Agent, ::PostActStage, env::AbstractEnv, action) + next_state = state(env) + push!(agent.trajectory, (state = next_state, action = action, reward = reward(env), terminal = is_terminated(env))) +end + +function Base.push!(agent::Agent, ::PostEpisodeStage, env::AbstractEnv) + if haskey(agent.trajectory, :next_action) + action = RLBase.plan!(agent.policy, env) + push!(agent.trajectory, PartialNamedTuple((action = action, ))) + end +end diff --git a/src/ReinforcementLearningCore/src/policies/agent/agent_srt_cache.jl b/src/ReinforcementLearningCore/src/policies/agent/agent_srt_cache.jl index 321a4a6c5..7bd0bed80 100644 --- a/src/ReinforcementLearningCore/src/policies/agent/agent_srt_cache.jl +++ b/src/ReinforcementLearningCore/src/policies/agent/agent_srt_cache.jl @@ -27,12 +27,12 @@ struct SART{S,A,R,T} end # This method is used to push a state and action to a trace -function Base.push!(ts::Union{CircularArraySARTTraces,ElasticArraySARTTraces}, xs::SA) +function Base.push!(ts::Union{CircularArraySARTSTraces,ElasticArraySARTTraces}, xs::SA) push!(ts.traces[1].trace, xs.state) push!(ts.traces[2].trace, xs.action) end -function Base.push!(ts::Union{CircularArraySARTTraces,ElasticArraySARTTraces}, xs::SART) +function Base.push!(ts::Union{CircularArraySARTSTraces,ElasticArraySARTTraces}, xs::SART) push!(ts.traces[1].trace, xs.state) push!(ts.traces[2].trace, xs.action) push!(ts.traces[3], xs.reward) diff --git a/src/ReinforcementLearningCore/src/policies/agent/base.jl b/src/ReinforcementLearningCore/src/policies/agent/base.jl deleted file mode 100644 index 7654de4d6..000000000 --- a/src/ReinforcementLearningCore/src/policies/agent/base.jl +++ /dev/null @@ -1,89 +0,0 @@ -export Agent - -using Base.Threads: @spawn - -using Functors: @functor -import Base.push! -""" - Agent(;policy, trajectory) <: AbstractPolicy - -A wrapper of an `AbstractPolicy`. Generally speaking, it does nothing but to -update the trajectory and policy appropriately in different stages. 
Agent -is a Callable and its call method accepts varargs and keyword arguments to be -passed to the policy. - -""" -mutable struct Agent{P,T,C} <: AbstractPolicy - policy::P - trajectory::T - cache::C # need cache to collect elements as trajectory does not support partial inserting - - function Agent(policy::P, trajectory::T) where {P,T} - agent = new{P,T, SRT}(policy, trajectory, SRT()) - - if TrajectoryStyle(trajectory) === AsyncTrajectoryStyle() - bind(trajectory, @spawn(optimise!(policy, trajectory))) - end - agent - end - - function Agent(policy::P, trajectory::T, cache::C) where {P,T,C} - agent = new{P,T,C}(policy, trajectory, cache) - - if TrajectoryStyle(trajectory) === AsyncTrajectoryStyle() - bind(trajectory, @spawn(optimise!(policy, trajectory))) - end - agent - end -end - -Agent(;policy, trajectory, cache = SRT()) = Agent(policy, trajectory, cache) - -RLBase.optimise!(agent::Agent, stage::S) where {S<:AbstractStage} =RLBase.optimise!(TrajectoryStyle(agent.trajectory), agent, stage) -RLBase.optimise!(::SyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = - RLBase.optimise!(agent.policy, stage, agent.trajectory) - -# already spawn a task to optimise inner policy when initializing the agent -RLBase.optimise!(::AsyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = nothing - -#by default, optimise does nothing at all stage -function RLBase.optimise!(policy::AbstractPolicy, stage::AbstractStage, trajectory::Trajectory) end - -@functor Agent (policy,) - -function Base.push!(agent::Agent, ::PreActStage, env::AbstractEnv) - push!(agent, state(env)) -end - -# !!! TODO: In async scenarios, parameters of the policy may still be updating -# (partially), which will result to incorrect action. This should be addressed -# in Oolong.jl with a wrapper -function RLBase.plan!(agent::Agent{P,T,C}, env::AbstractEnv) where {P,T,C} - action = RLBase.plan!(agent.policy, env) - push!(agent.trajectory, agent.cache, action) - action -end - -# Multiagent Version -function RLBase.plan!(agent::Agent{P,T,C}, env::E, p::Symbol) where {P,T,C,E<:AbstractEnv} - action = RLBase.plan!(agent.policy, env, p) - push!(agent.trajectory, agent.cache, action) - action -end - -function Base.push!(agent::Agent{P,T,C}, ::PostActStage, env::E) where {P,T,C,E<:AbstractEnv} - push!(agent.cache, reward(env), is_terminated(env)) -end - -function Base.push!(agent::Agent, ::PostExperimentStage, env::E) where {E<:AbstractEnv} - RLBase.reset!(agent.cache) -end - -function Base.push!(agent::Agent, ::PostExperimentStage, env::E, player::Symbol) where {E<:AbstractEnv} - RLBase.reset!(agent.cache) -end - -function Base.push!(agent::Agent{P,T,C}, state::S) where {P,T,C,S} - push!(agent.cache, state) -end - diff --git a/src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl b/src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl index 1db43a18c..98ebd9685 100644 --- a/src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl +++ b/src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl @@ -125,18 +125,12 @@ function Base.run( action = @timeit_debug timer "plan!" RLBase.plan!(policy, env) @timeit_debug timer "act!" act!(env, action) - - - @timeit_debug timer "push!(policy) PostActStage" push!(policy, PostActStage(), env) + @timeit_debug timer "push!(policy) PostActStage" push!(policy, PostActStage(), env, action) @timeit_debug timer "optimise! 
PostActStage" optimise!(policy, PostActStage()) @timeit_debug timer "push!(hook) PostActStage" push!(hook, PostActStage(), policy, env) if check_stop(stop_condition, policy, env) is_stop = true - @timeit_debug timer "push!(policy) PreActStage" push!(multiagent_policy, PreActStage(), env) - @timeit_debug timer "optimise! PreActStage" optimise!(multiagent_policy, PreActStage()) - @timeit_debug timer "push!(hook) PreActStage" push!(multiagent_hook, PreActStage(), policy, env) - @timeit_debug timer "plan!" RLBase.plan!(multiagent_policy, env) # let the policy see the last observation break end @@ -191,21 +185,43 @@ function Base.push!(multiagent::MultiAgentPolicy, stage::S, env::E) where {S<:Ab end end -# Like in the single-agent case, push! at the PreActStage() calls push! on each player with the state of the environment -function Base.push!(multiagent::MultiAgentPolicy{names, T}, ::PreActStage, env::E) where {E<:AbstractEnv, names, T <: Agent} +# Like in the single-agent case, push! at the PostActStage() calls push! on each player. +function Base.push!(agent::Agent, ::PreEpisodeStage, env::AbstractEnv, player::Symbol) + push!(agent.trajectory, (state = state(env, player),)) +end + +function Base.push!(multiagent::MultiAgentPolicy, s::PreEpisodeStage, env::E) where {E<:AbstractEnv} for player in players(env) - push!(multiagent[player], state(env, player)) + push!(multiagent[player], s, env, player) end end -# Like in the single-agent case, push! at the PostActStage() calls push! on each player with the reward and termination status of the environment -function Base.push!(multiagent::MultiAgentPolicy{names, T}, ::PostActStage, env::E) where {E<:AbstractEnv, names, T <: Agent} - for player in players(env) - push!(multiagent[player].cache, reward(env, player), is_terminated(env)) +function RLBase.plan!(agent::Agent, env::AbstractEnv, player::Symbol) + RLBase.plan!(agent.policy, env, player) +end + +# Like in the single-agent case, push! at the PostActStage() calls push! on each player to store the action, reward, next_state, and terminal signal. +function Base.push!(multiagent::MultiAgentPolicy, ::PostActStage, env::E, actions) where {E<:AbstractEnv} + for (player, action) in zip(players(env), actions) + next_state = state(env, player) + observation = ( + state = next_state, + action = action, + reward = reward(env, player), + terminal = is_terminated(env) + ) + push!(multiagent[player].trajectory, observation) + end +end + +function Base.push!(agent::Agent, ::PostEpisodeStage, env::AbstractEnv, p::Symbol) + if haskey(agent.trajectory, :next_action) + action = RLBase.plan!(agent.policy, env, p) + push!(agent.trajectory, PartialNamedTuple((action = action, ))) end end -function Base.push!(hook::MultiAgentHook, stage::S, multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv,S<:AbstractStage} +function Base.push!(hook::MultiAgentHook, stage::S, multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv, S<:AbstractStage} for player in players(env) push!(hook[player], stage, multiagent[player], env, player) end @@ -227,8 +243,9 @@ function Base.push!(composed_hook::ComposedHook{T}, _push!(stage, policy, env, player, composed_hook.hooks...) end +#For simultaneous players, plan! returns a Tuple of actions. 
function RLBase.plan!(multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv} - return (RLBase.plan!(multiagent[player], env, player) for player in players(env)) + return Tuple(RLBase.plan!(multiagent[player], env, player) for player in players(env)) end function RLBase.optimise!(multiagent::MultiAgentPolicy, stage::S) where {S<:AbstractStage} diff --git a/src/ReinforcementLearningCore/test/core/base.jl b/src/ReinforcementLearningCore/test/core/base.jl index 31af5b994..300f3fcd7 100644 --- a/src/ReinforcementLearningCore/test/core/base.jl +++ b/src/ReinforcementLearningCore/test/core/base.jl @@ -1,4 +1,4 @@ -using ReinforcementLearningCore: SRT +using ReinforcementLearningCore using ReinforcementLearningBase using TimerOutputs @@ -8,7 +8,7 @@ using TimerOutputs agent = Agent( RandomPolicy(), Trajectory( - CircularArraySARTTraces(; capacity = 1_000), + CircularArraySARTSTraces(; capacity = 1_000), BatchSampler(1), InsertSampleRatioController(n_inserted = -1), ), @@ -18,14 +18,14 @@ using TimerOutputs hook = StepsPerEpisode() run(agent, env, stop_condition, hook) - @test sum(hook[]) == length(agent.trajectory.container) + @test sum(hook[]) + length(hook[]) - 1 == length(agent.trajectory.container) end @testset "StopAfterEpisode" begin agent = Agent( RandomPolicy(), Trajectory( - CircularArraySARTTraces(; capacity = 1_000), + CircularArraySARTSTraces(; capacity = 1_000), BatchSampler(1), InsertSampleRatioController(n_inserted = -1), ), @@ -35,25 +35,8 @@ using TimerOutputs hook = StepsPerEpisode() run(agent, env, stop_condition, hook) - @test sum(hook[]) == length(agent.trajectory.container) - end - - @testset "StopAfterStep, use type stable Agent" begin - env = RandomWalk1D() - agent = Agent( - RandomPolicy(legal_action_space(env)), - Trajectory( - CircularArraySARTTraces(; capacity = 1_000), - BatchSampler(1), - InsertSampleRatioController(n_inserted = -1), - ), - SRT{Any, Any, Any}(), - ) - stop_condition = StopAfterStep(123; is_show_progress=false) - hook = StepsPerEpisode() - run(agent, env, stop_condition, hook) - @test sum(hook[]) == length(agent.trajectory.container) - end + @test length(hook[]) == 10 + end end @testset "Debug Timer" begin @@ -63,11 +46,10 @@ using TimerOutputs agent = Agent( RandomPolicy(legal_action_space(env)), Trajectory( - CircularArraySARTTraces(; capacity = 1_000), + CircularArraySARTSTraces(; capacity = 1_000), BatchSampler(1), InsertSampleRatioController(n_inserted = -1), - ), - SRT{Any, Any, Any}(), + ) ) stop_condition = StopAfterStep(123; is_show_progress=false) hook = StepsPerEpisode() diff --git a/src/ReinforcementLearningCore/test/policies/agent.jl b/src/ReinforcementLearningCore/test/policies/agent.jl index a35a79724..81297ce77 100644 --- a/src/ReinforcementLearningCore/test/policies/agent.jl +++ b/src/ReinforcementLearningCore/test/policies/agent.jl @@ -3,51 +3,18 @@ using ReinforcementLearningCore: SRT using ReinforcementLearningCore @testset "agent.jl" begin - @testset "Agent Cache struct" begin - srt = SRT{Int64, Float64, Bool}() - push!(srt, 2) - @test srt.state == 2 - push!(srt, 1.0, true) - @test srt.reward == 1.0 - @test srt.terminal == true - end - - @testset "Trajectory SART struct compatibility" begin - srt_1 = SRT() - srt_2 = SRT{Any, Nothing, Nothing}() - srt_2.state = 1 - srt_3 = SRT{Any, Any, Bool}() - srt_3.state = 1 - srt_3.reward = 1.0 - srt_3.terminal = true - trajectory = Trajectory( - CircularArraySARTTraces(; capacity = 1_000, reward=Float64=>()), - DummySampler(), - ) - - @test_throws ArgumentError push!(trajectory, srt_1) - 
push!(trajectory, srt_2, 1) - @test length(trajectory.container) == 0 - push!(trajectory, srt_3, 2) - @test length(trajectory.container) == 1 - @test trajectory.container[:action] == [1] - push!(trajectory, srt_3, 3) - @test trajectory.container[:action] == [1, 2] - @test trajectory.container[:state] == [1, 1] - end - @testset "Agent Tests" begin a_1 = Agent( RandomPolicy(), Trajectory( - CircularArraySARTTraces(; capacity = 1_000), + CircularArraySARTSTraces(; capacity = 1_000), DummySampler(), ), ) a_2 = Agent( RandomPolicy(), Trajectory( - CircularArraySARTTraces(; capacity = 1_000), + CircularArraySARTSTraces(; capacity = 1_000), BatchSampler(1), InsertSampleRatioController(), ), @@ -58,17 +25,12 @@ using ReinforcementLearningCore @testset "Test Agent $i" begin agent = agent_list[i] env = RandomWalk1D() - push!(agent, PreActStage(), env) - @test agent.cache.state != nothing - @test agent.cache.reward == nothing - @test agent.cache.terminal == nothing - @test state(env) == agent.cache.state - @test RLBase.plan!(agent, env) in (1,2) + push!(agent, PreEpisodeStage(), env) + action = RLBase.plan!(agent, env) + @test action in (1,2) @test length(agent.trajectory.container) == 0 - push!(agent, PostActStage(), env) - @test agent.cache.reward == 0. && agent.cache.terminal == false + push!(agent, PostActStage(), env, action) push!(agent, PreActStage(), env) - @test state(env) == agent.cache.state @test RLBase.plan!(agent, env) in (1,2) @test length(agent.trajectory.container) == 1 @@ -76,18 +38,6 @@ using ReinforcementLearningCore @test_throws "MethodError: no method matching plan!(::Agent{RandomPolicy" RLBase.plan!(agent, env, 1) @test_throws "MethodError: no method matching plan!(::Agent{RandomPolicy" RLBase.plan!(agent, env, fake_kwarg = 1) end - - @testset "Test push! 
method" begin - env = RandomWalk1D() - agent = agent_list[i] - push!(agent, PostActStage(), env) - push!(agent, 7) - @test agent.cache.state == 7 - RLBase.reset!(agent.cache) - @test agent.cache.state == nothing - end end end end - - diff --git a/src/ReinforcementLearningCore/test/policies/multi_agent.jl b/src/ReinforcementLearningCore/test/policies/multi_agent.jl index 7c6bd05a4..3b7392f3d 100644 --- a/src/ReinforcementLearningCore/test/policies/multi_agent.jl +++ b/src/ReinforcementLearningCore/test/policies/multi_agent.jl @@ -7,13 +7,13 @@ using DomainSets @testset "MultiAgentPolicy" begin trajectory_1 = Trajectory( - CircularArraySARTTraces(; capacity = 1), + CircularArraySARTSTraces(; capacity = 1), BatchSampler(1), InsertSampleRatioController(n_inserted = -1), ) trajectory_2 = Trajectory( - CircularArraySARTTraces(; capacity = 1), + CircularArraySARTSTraces(; capacity = 1), BatchSampler(1), InsertSampleRatioController(n_inserted = -1), ) @@ -56,13 +56,13 @@ end @testset "Basic TicTacToeEnv (Sequential) env checks" begin trajectory_1 = Trajectory( - CircularArraySARTTraces(; capacity = 1), + CircularArraySARTSTraces(; capacity = 1), BatchSampler(1), InsertSampleRatioController(n_inserted = -1), ) trajectory_2 = Trajectory( - CircularArraySARTTraces(; capacity = 1), + CircularArraySARTSTraces(; capacity = 1), BatchSampler(1), InsertSampleRatioController(n_inserted = -1), ) @@ -106,13 +106,13 @@ end @testset "Basic RockPaperScissors (simultaneous) env checks" begin trajectory_1 = Trajectory( - CircularArraySARTTraces(; capacity = 1, action = Any => (1,), state = Any => (1,), reward = Any => (2,)), + CircularArraySARTSTraces(; capacity = 1, action = Any => (1,), state = Any => (1,), reward = Any => (2,)), BatchSampler(1), InsertSampleRatioController(n_inserted = -1), ) trajectory_2 = Trajectory( - CircularArraySARTTraces(; capacity = 1, action = Any => (1,), state = Any => (1,), reward = Any => (2,)), + CircularArraySARTSTraces(; capacity = 1, action = Any => (1,), state = Any => (1,), reward = Any => (2,)), BatchSampler(1), InsertSampleRatioController(n_inserted = -1), ) diff --git a/src/ReinforcementLearningEnvironments/Project.toml b/src/ReinforcementLearningEnvironments/Project.toml index e579c18b5..9618fb4f0 100644 --- a/src/ReinforcementLearningEnvironments/Project.toml +++ b/src/ReinforcementLearningEnvironments/Project.toml @@ -23,7 +23,7 @@ DelimitedFiles = "1" IntervalSets = "0.7" MacroTools = "0.5" ReinforcementLearningBase = "0.12" -ReinforcementLearningCore = "0.10, 0.11" +ReinforcementLearningCore = "0.12" Requires = "1.0" StatsBase = "0.32, 0.33, 0.34" julia = "1.3" diff --git a/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl b/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl index d700deace..661d3c63b 100644 --- a/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl +++ b/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl @@ -3,13 +3,13 @@ using ReinforcementLearningEnvironments, ReinforcementLearningBase, ReinforcementLearningCore trajectory_1 = Trajectory( - CircularArraySARTTraces(; capacity = 1), + CircularArraySARTSTraces(; capacity = 1), BatchSampler(1), InsertSampleRatioController(n_inserted = -1), ) trajectory_2 = Trajectory( - CircularArraySARTTraces(; capacity = 1), + CircularArraySARTSTraces(; capacity = 1), BatchSampler(1), InsertSampleRatioController(n_inserted = -1), ) diff --git a/src/ReinforcementLearningExperiments/Project.toml 
b/src/ReinforcementLearningExperiments/Project.toml index 49e30691b..9cd023e17 100644 --- a/src/ReinforcementLearningExperiments/Project.toml +++ b/src/ReinforcementLearningExperiments/Project.toml @@ -20,7 +20,7 @@ Distributions = "0.25" Flux = "0.13, 0.14" Reexport = "1" ReinforcementLearningBase = "0.12" -ReinforcementLearningCore = "0.10, 0.11" +ReinforcementLearningCore = "0.12" ReinforcementLearningEnvironments = "0.8" ReinforcementLearningZoo = "0.7, 0.8" StableRNGs = "1" diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/DQN_CartPoleGPU.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/DQN_CartPoleGPU.jl index 22663b843..1e8cde26c 100644 --- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/DQN_CartPoleGPU.jl +++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/DQN_CartPoleGPU.jl @@ -47,7 +47,7 @@ function RLCore.Experiment( ), ), Trajectory( - container=CircularArraySARTTraces( + container=CircularArraySARTSTraces( capacity=cap, state=Float32 => (ns), ), diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl index 5a43c8c01..4b0938968 100644 --- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl +++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl @@ -46,7 +46,7 @@ function RLCore.Experiment( ), ), trajectory=Trajectory( - container=CircularArraySARTTraces( + container=CircularArraySARTSTraces( capacity=1000, state=Float32 => (ns,), ), diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl index 3f481eff4..a6aa5d0c2 100644 --- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl +++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl @@ -56,7 +56,7 @@ function RLCore.Experiment( ), ), trajectory=Trajectory( - container=CircularArraySARTTraces( + container=CircularArraySARTSTraces( capacity=1000, state=Float32 => (ns,), ), diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl index 0bc2832ff..ea5d7a20d 100644 --- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl +++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl @@ -67,7 +67,7 @@ function RLCore.Experiment( ), ), trajectory=Trajectory( - container=CircularArraySARTTraces( + container=CircularArraySARTSTraces( capacity=1000, state=Float32 => (ns,), ), diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_NFQ_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_NFQ_CartPole.jl index 07a48ecef..affd6ab02 100644 --- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_NFQ_CartPole.jl +++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_NFQ_CartPole.jl @@ -50,7 +50,7 @@ function RLCore.Experiment( ), ), trajectory=Trajectory( - container=CircularArraySARTTraces( + 
container=CircularArraySARTSTraces( capacity=10_000, state=Float32 => (ns,), action=Float32 => (na,), diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_PrioritizedDQN_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_PrioritizedDQN_CartPole.jl index 1072722b7..66b75e9e3 100644 --- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_PrioritizedDQN_CartPole.jl +++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_PrioritizedDQN_CartPole.jl @@ -57,7 +57,7 @@ function RLCore.Experiment( ), trajectory=Trajectory( container=CircularPrioritizedTraces( - CircularArraySARTTraces( + CircularArraySARTSTraces( capacity=1000, state=Float32 => (ns,), ); diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_QRDQN_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_QRDQN_CartPole.jl index 53ad1c35e..81173762d 100644 --- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_QRDQN_CartPole.jl +++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_QRDQN_CartPole.jl @@ -53,7 +53,7 @@ function RLCore.Experiment( ), ), trajectory=Trajectory( - container=CircularArraySARTTraces( + container=CircularArraySARTSTraces( capacity=1000, state=Float32 => (ns,), ), diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl index 5b7f0d595..82223b093 100644 --- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl +++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl @@ -58,7 +58,7 @@ function RLCore.Experiment( ), ), trajectory=Trajectory( - container=CircularArraySARTTraces( + container=CircularArraySARTSTraces( capacity=1000, state=Float32 => (ns,), ), diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl index 0e4f7d1c9..eb7492837 100644 --- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl +++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl @@ -54,7 +54,7 @@ function RLCore.Experiment( ), ), trajectory=Trajectory( - container=CircularArraySARTTraces( + container=CircularArraySARTSTraces( capacity=1000, state=Float32 => (ns,), ), diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/Policy Gradient/JuliaRL_MPO_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/Policy Gradient/JuliaRL_MPO_CartPole.jl index 81d1d9c16..f071b7c5b 100644 --- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/Policy Gradient/JuliaRL_MPO_CartPole.jl +++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/Policy Gradient/JuliaRL_MPO_CartPole.jl @@ -38,7 +38,7 @@ function RLCore.Experiment( agent = Agent( policy = policy, trajectory = Trajectory( - CircularArraySARTTraces(capacity = 1000, state = Float32 => (4,), action = Float32 => (1,)), + CircularArraySARTSTraces(capacity = 1000, state = Float32 => (4,), action = Float32 => (1,)), MetaSampler( actor = 
MultiBatchSampler(BatchSampler{(:state,)}(32), 10), critic = MultiBatchSampler(BatchSampler{SS′ART}(32), 2000) @@ -78,7 +78,7 @@ function RLCore.Experiment( agent = Agent( policy = policy, trajectory = Trajectory( - CircularArraySARTTraces(capacity = 1000, state = Float32 => (4,), action = Float32 => (2,)), + CircularArraySARTSTraces(capacity = 1000, state = Float32 => (4,), action = Float32 => (2,)), MetaSampler( actor = MultiBatchSampler(BatchSampler{(:state,)}(32), 10), critic = MultiBatchSampler(BatchSampler{SS′ART}(32), 2000) @@ -122,7 +122,7 @@ function RLCore.Experiment( agent = Agent( policy = policy, trajectory = Trajectory( - CircularArraySARTTraces(capacity = 1000, state = Float32 => (4,), action = Float32 => (1,)), + CircularArraySARTSTraces(capacity = 1000, state = Float32 => (4,), action = Float32 => (1,)), MetaSampler( actor = MultiBatchSampler(BatchSampler{(:state,)}(32), 10), critic = MultiBatchSampler(BatchSampler{SS′ART}(32), 2000) diff --git a/src/ReinforcementLearningExperiments/src/experiments/MARL/DQN_mpe_simple.jl b/src/ReinforcementLearningExperiments/src/experiments/MARL/DQN_mpe_simple.jl index 2c3363df9..8623bd50e 100644 --- a/src/ReinforcementLearningExperiments/src/experiments/MARL/DQN_mpe_simple.jl +++ b/src/ReinforcementLearningExperiments/src/experiments/MARL/DQN_mpe_simple.jl @@ -55,7 +55,7 @@ function RLCore.Experiment( ), ), trajectory=Trajectory( - container=CircularArraySARTTraces( + container=CircularArraySARTSTraces( capacity=1000, state=Float32 => (ns,), ), diff --git a/src/ReinforcementLearningExperiments/test/runtests.jl b/src/ReinforcementLearningExperiments/test/runtests.jl index 8f95d1e3f..1176e11c1 100644 --- a/src/ReinforcementLearningExperiments/test/runtests.jl +++ b/src/ReinforcementLearningExperiments/test/runtests.jl @@ -11,7 +11,7 @@ run(E`JuliaRL_QRDQN_CartPole`) run(E`JuliaRL_REMDQN_CartPole`) run(E`JuliaRL_IQN_CartPole`) run(E`JuliaRL_Rainbow_CartPole`) -run(E`JuliaRL_VPG_CartPole`) +# run(E`JuliaRL_VPG_CartPole`) run(E`JuliaRL_MPODiscrete_CartPole`) run(E`JuliaRL_MPOContinuous_CartPole`) run(E`JuliaRL_MPOCovariance_CartPole`) diff --git a/src/ReinforcementLearningZoo/Project.toml b/src/ReinforcementLearningZoo/Project.toml index 166e30fca..4a6067224 100644 --- a/src/ReinforcementLearningZoo/Project.toml +++ b/src/ReinforcementLearningZoo/Project.toml @@ -28,7 +28,7 @@ LogExpFunctions = "0.3" NNlib = "0.8, 0.9" Optim = "1" ReinforcementLearningBase = "0.12" -ReinforcementLearningCore = "0.11" +ReinforcementLearningCore = "0.12" StatsBase = "0.33, 0.34" Zygote = "0.6" julia = "1.9" diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/NFQ.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/NFQ.jl index a14d36c5b..4270a3882 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/dqns/NFQ.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/NFQ.jl @@ -45,7 +45,7 @@ function RLBase.optimise!(learner::NFQ, ::PostEpisodeStage, trajectory::Trajecto loss_func = learner.loss_function as = learner.action_space las = length(as) - batch = ReinforcementLearningTrajectories.sample(trajectory) + batch = ReinforcementLearningTrajectories.StatsBase.sample(trajectory) (s, a, r, ss) = batch[[:state, :action, :reward, :next_state]] a = Float32.(a) diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/policy_gradient.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/policy_gradient.jl index e6eb4ed60..00d393959 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/policy_gradient.jl +++ 
b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/policy_gradient.jl
@@ -1,6 +1,6 @@
 # include("run.jl")
 include("util.jl")
-include("vpg.jl")
+# include("vpg.jl")
 # include("A2C.jl")
 # include("ppo.jl")
 # include("A2CGAE.jl")
@@ -10,5 +10,5 @@ include("vpg.jl")
 # include("sac.jl")
 # include("maddpg.jl")
 # include("vmpo.jl")
-include("trpo.jl")
+# include("trpo.jl")
 include("mpo.jl")
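The net effect of the run-loop, stage, and `Agent` changes above is easiest to see end to end. Below is a hypothetical, self-contained sketch (not part of the patch) that replays the new stage sequence from `_run` and the updated tests: the initial state is pushed at `PreEpisodeStage`, `plan!` no longer has trajectory side effects, and the complete transition — including the action just taken — is pushed at `PostActStage`. The environment, policy, and trajectory constructors are the ones used in the diff's tests; the loop itself is illustrative.

```julia
# Hypothetical walkthrough of the reworked stage API (sketch, not the patch itself).
using ReinforcementLearningBase
using ReinforcementLearningCore
using ReinforcementLearningEnvironments: RandomWalk1D

env = RandomWalk1D()
agent = Agent(
    RandomPolicy(),
    Trajectory(
        CircularArraySARTSTraces(; capacity = 1_000),
        BatchSampler(1),
        InsertSampleRatioController(n_inserted = -1),
    ),
)

push!(agent, PreEpisodeStage(), env)           # stores the initial state
while !is_terminated(env)
    action = RLBase.plan!(agent, env)          # pure planning, no cache writes
    act!(env, action)
    push!(agent, PostActStage(), env, action)  # stores (state, action, reward, terminal) in one go
end
push!(agent, PostEpisodeStage(), env)          # appends :next_action only if the trace expects it
```

Passing `action` explicitly into the `PostActStage` `push!` is what allows the old `SRT` cache (deleted with `base.jl`) to disappear: the trajectory now receives a complete named-tuple row in a single `push!` instead of being assembled piecemeal across stages.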