diff --git a/environments/environment.jl b/environments/environment.jl index ae813ea79..94d1ddb8e 100644 --- a/environments/environment.jl +++ b/environments/environment.jl @@ -31,7 +31,7 @@ mutable struct Environment{X,T,M,A,O,I} dynamics_jacobian_state::Matrix{T} dynamics_jacobian_input::Matrix{T} input_previous::Vector{T} - control_map::Matrix{T} + control_map::Matrix{T} num_states::Int num_inputs::Int num_observations::Int @@ -66,33 +66,33 @@ end attitude_decompress: flag for pre- and post-concatenating Jacobians with attitude Jacobians """ function Base.step(env::Environment, x, u; - gradients=false, - attitude_decompress=false) + gradients = false, + attitude_decompress = false) mechanism = env.mechanism - timestep= mechanism.timestep + timestep = mechanism.timestep x0 = x # u = clip(env.input_space, u) # control limits env.input_previous .= u # for rendering in Gym - u_scaled = env.control_map * u + u_scaled = env.control_map * u z0 = env.representation == :minimal ? minimal_to_maximal(mechanism, x0) : x0 - z1 = step!(mechanism, z0, u_scaled; opts=env.opts_step) + z1 = step!(mechanism, z0, u_scaled; opts = env.opts_step) env.state .= env.representation == :minimal ? maximal_to_minimal(mechanism, z1) : z1 # Compute cost costs = cost(env, x, u) - # Check termination - done = is_done(env, x) + # Check termination + done = is_done(env, x) # Gradients if gradients if env.representation == :minimal - fx, fu = get_minimal_gradients!(env.mechanism, z0, u_scaled, opts=env.opts_grad) + fx, fu = get_minimal_gradients!(env.mechanism, z0, u_scaled, opts = env.opts_grad) elseif env.representation == :maximal - fx, fu = get_maximal_gradients!(env.mechanism, z0, u_scaled, opts=env.opts_grad) + fx, fu = get_maximal_gradients!(env.mechanism, z0, u_scaled, opts = env.opts_grad) if attitude_decompress A0 = attitude_jacobian(z0, length(env.mechanism.bodies)) A1 = attitude_jacobian(z1, length(env.mechanism.bodies)) @@ -109,11 +109,11 @@ function Base.step(env::Environment, x, u; end function Base.step(env::Environment, u; - gradients=false, - attitude_decompress=false) - step(env, env.state, u; - gradients=gradients, - attitude_decompress=attitude_decompress) + gradients = false, + attitude_decompress = false) + step(env, env.state, u; + gradients = gradients, + attitude_decompress = attitude_decompress) end """ @@ -156,7 +156,7 @@ is_done(env::Environment, x) = false x: state """ function Base.reset(env::Environment{X}; - x=nothing) where X + x = nothing) where {X} initialize!(env.mechanism, type2symbol(X)) if x != nothing @@ -172,14 +172,14 @@ function Base.reset(env::Environment{X}; return get_observation(env) end -function MeshCat.render(env::Environment, - mode="human") +function MeshCat.render(env::Environment, + mode = "human") z = env.representation == :minimal ? minimal_to_maximal(env.mechanism, env.state) : env.state - set_robot(env.vis, env.mechanism, z, name=:robot) + set_robot(env.vis, env.mechanism, z, name = :robot) return nothing end -function seed(env::Environment, s=0) +function seed(env::Environment, s = 0) env.rng[1] = MersenneTwister(s) return nothing end @@ -214,26 +214,20 @@ mutable struct BoxSpace{T,N} <: Space{T,N} dtype::DataType # this is always T, it's needed to interface with Stable-Baselines end -function BoxSpace(n::Int; low::AbstractVector{T} = -ones(n), high::AbstractVector{T} = ones(n)) where T +function BoxSpace(n::Int; low::AbstractVector{T} = -ones(n), high::AbstractVector{T} = ones(n)) where {T} return BoxSpace{T,n}(n, low, high, (n,), T) end function sample(s::BoxSpace{T,N}) where {T,N} - return rand(T,N) .* (s.high .- s.low) .+ s.low + return rand(T, N) .* (s.high .- s.low) .+ s.low end function contains(s::BoxSpace{T,N}, v::AbstractVector{T}) where {T,N} all(v .>= s.low) && all(v .<= s.high) end -# For compat with RLBase -Base.length(s::BoxSpace) = s.n -Base.in(v::AbstractVector{T}, s::BoxSpace{T,N}) where {T,N} = all(v .>= s.low) && all(v .<= s.high) -Random.rand(rng::Random.AbstractRNG, s::BoxSpace{T,N}) where {T,N} = return rand(rng, T,N) .* (s.high .- s.low) .+ s.low - function clip(s::BoxSpace, u) clamp.(u, s.low, s.high) end - - +Random.rand(rng::Random.AbstractRNG, s::BoxSpace{T,N}) where {T,N} = return rand(rng, T, N) .* (s.high .- s.low) .+ s.low diff --git a/environments/rlenv.jl b/environments/rlenv.jl index 947026973..3baf47d12 100644 --- a/environments/rlenv.jl +++ b/environments/rlenv.jl @@ -17,8 +17,12 @@ function DojoRLEnv(name::String; kwargs...) DojoRLEnv(Dojo.get_environment(name; kwargs...)) end -RLBase.action_space(env::DojoRLEnv) = env.dojoenv.input_space -RLBase.state_space(env::DojoRLEnv) = env.dojoenv.observation_space +function Base.convert(::Type{RLBase.Space}, s::BoxSpace) + RLBase.Space([BoxSpace(1; low = s.low[i:i], high = s.high[i:i]) for i in 1:s.n]) +end + +RLBase.action_space(env::DojoRLEnv) = convert(RLBase.Space, env.dojoenv.input_space) +RLBase.state_space(env::DojoRLEnv) = convert(RLBase.Space, env.dojoenv.observation_space) RLBase.is_terminated(env::DojoRLEnv) = env.done RLBase.reset!(env::DojoRLEnv) = reset(env.dojoenv) @@ -39,4 +43,4 @@ function (env::DojoRLEnv)(a) env.info = i return nothing end -(env::DojoRLEnv)(a::AbstractFloat) = env([a]) +(env::DojoRLEnv)(a::Number) = env([a])