Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 23 additions & 29 deletions environments/environment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ mutable struct Environment{X,T,M,A,O,I}
dynamics_jacobian_state::Matrix{T}
dynamics_jacobian_input::Matrix{T}
input_previous::Vector{T}
control_map::Matrix{T}
control_map::Matrix{T}
num_states::Int
num_inputs::Int
num_observations::Int
Expand Down Expand Up @@ -66,33 +66,33 @@ end
attitude_decompress: flag for pre- and post-concatenating Jacobians with attitude Jacobians
"""
function Base.step(env::Environment, x, u;
gradients=false,
attitude_decompress=false)
gradients = false,
attitude_decompress = false)

mechanism = env.mechanism
timestep= mechanism.timestep
timestep = mechanism.timestep

x0 = x
# u = clip(env.input_space, u) # control limits
env.input_previous .= u # for rendering in Gym
u_scaled = env.control_map * u
u_scaled = env.control_map * u

z0 = env.representation == :minimal ? minimal_to_maximal(mechanism, x0) : x0
z1 = step!(mechanism, z0, u_scaled; opts=env.opts_step)
z1 = step!(mechanism, z0, u_scaled; opts = env.opts_step)
env.state .= env.representation == :minimal ? maximal_to_minimal(mechanism, z1) : z1

# Compute cost
costs = cost(env, x, u)

# Check termination
done = is_done(env, x)
# Check termination
done = is_done(env, x)

# Gradients
if gradients
if env.representation == :minimal
fx, fu = get_minimal_gradients!(env.mechanism, z0, u_scaled, opts=env.opts_grad)
fx, fu = get_minimal_gradients!(env.mechanism, z0, u_scaled, opts = env.opts_grad)
elseif env.representation == :maximal
fx, fu = get_maximal_gradients!(env.mechanism, z0, u_scaled, opts=env.opts_grad)
fx, fu = get_maximal_gradients!(env.mechanism, z0, u_scaled, opts = env.opts_grad)
if attitude_decompress
A0 = attitude_jacobian(z0, length(env.mechanism.bodies))
A1 = attitude_jacobian(z1, length(env.mechanism.bodies))
Expand All @@ -109,11 +109,11 @@ function Base.step(env::Environment, x, u;
end

function Base.step(env::Environment, u;
gradients=false,
attitude_decompress=false)
step(env, env.state, u;
gradients=gradients,
attitude_decompress=attitude_decompress)
gradients = false,
attitude_decompress = false)
step(env, env.state, u;
gradients = gradients,
attitude_decompress = attitude_decompress)
end

"""
Expand Down Expand Up @@ -156,7 +156,7 @@ is_done(env::Environment, x) = false
x: state
"""
function Base.reset(env::Environment{X};
x=nothing) where X
x = nothing) where {X}

initialize!(env.mechanism, type2symbol(X))
if x != nothing
Expand All @@ -172,14 +172,14 @@ function Base.reset(env::Environment{X};
return get_observation(env)
end

function MeshCat.render(env::Environment,
mode="human")
function MeshCat.render(env::Environment,
mode = "human")
z = env.representation == :minimal ? minimal_to_maximal(env.mechanism, env.state) : env.state
set_robot(env.vis, env.mechanism, z, name=:robot)
set_robot(env.vis, env.mechanism, z, name = :robot)
return nothing
end

function seed(env::Environment, s=0)
function seed(env::Environment, s = 0)
env.rng[1] = MersenneTwister(s)
return nothing
end
Expand Down Expand Up @@ -214,26 +214,20 @@ mutable struct BoxSpace{T,N} <: Space{T,N}
dtype::DataType # this is always T, it's needed to interface with Stable-Baselines
end

function BoxSpace(n::Int; low::AbstractVector{T} = -ones(n), high::AbstractVector{T} = ones(n)) where T
function BoxSpace(n::Int; low::AbstractVector{T} = -ones(n), high::AbstractVector{T} = ones(n)) where {T}
return BoxSpace{T,n}(n, low, high, (n,), T)
end

function sample(s::BoxSpace{T,N}) where {T,N}
return rand(T,N) .* (s.high .- s.low) .+ s.low
return rand(T, N) .* (s.high .- s.low) .+ s.low
end

function contains(s::BoxSpace{T,N}, v::AbstractVector{T}) where {T,N}
all(v .>= s.low) && all(v .<= s.high)
end

# For compat with RLBase
Base.length(s::BoxSpace) = s.n
Base.in(v::AbstractVector{T}, s::BoxSpace{T,N}) where {T,N} = all(v .>= s.low) && all(v .<= s.high)
Random.rand(rng::Random.AbstractRNG, s::BoxSpace{T,N}) where {T,N} = return rand(rng, T,N) .* (s.high .- s.low) .+ s.low

function clip(s::BoxSpace, u)
clamp.(u, s.low, s.high)
end



Random.rand(rng::Random.AbstractRNG, s::BoxSpace{T,N}) where {T,N} = return rand(rng, T, N) .* (s.high .- s.low) .+ s.low
10 changes: 7 additions & 3 deletions environments/rlenv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@ function DojoRLEnv(name::String; kwargs...)
DojoRLEnv(Dojo.get_environment(name; kwargs...))
end

RLBase.action_space(env::DojoRLEnv) = env.dojoenv.input_space
RLBase.state_space(env::DojoRLEnv) = env.dojoenv.observation_space
function Base.convert(::Type{RLBase.Space}, s::BoxSpace)
RLBase.Space([BoxSpace(1; low = s.low[i:i], high = s.high[i:i]) for i in 1:s.n])
end

RLBase.action_space(env::DojoRLEnv) = convert(RLBase.Space, env.dojoenv.input_space)
RLBase.state_space(env::DojoRLEnv) = convert(RLBase.Space, env.dojoenv.observation_space)
RLBase.is_terminated(env::DojoRLEnv) = env.done

RLBase.reset!(env::DojoRLEnv) = reset(env.dojoenv)
Expand All @@ -39,4 +43,4 @@ function (env::DojoRLEnv)(a)
env.info = i
return nothing
end
(env::DojoRLEnv)(a::AbstractFloat) = env([a])
(env::DojoRLEnv)(a::Number) = env([a])