Modify envs to be compatible with twrl #8

Status: Closed · wants to merge 32 commits

Changes from 8 commits (of 32)
Commits:

- `70eb229` Modified specs to follow gym API (Oct 1, 2016)
- `a56188c` Update function names (Oct 2, 2016)
- `1fe33fa` Experiment uses new specs (Oct 2, 2016)
- `522b8a4` Changed function calls to getters (Oct 2, 2016)
- `9e92921` table with all rlenv envs (Oct 19, 2016)
- `1c276f2` Update function calls, added new steps method, updated README (Oct 21, 2016)
- `63cf6db` Added timestep limits, exposed start function added _start function t… (Oct 24, 2016)
- `c20c8b8` Added render (Oct 24, 2016)
- `a9ea21e` Added zoom option (Oct 26, 2016)
- `430a2b4` Added zoom variable to experiment (Oct 26, 2016)
- `358a8c7` fixed variable name (Oct 26, 2016)
- `135bf69` Modified specs to follow gym API (Oct 1, 2016)
- `2afdffe` Update function names (Oct 2, 2016)
- `0e3b561` Experiment uses new specs (Oct 2, 2016)
- `3a39f0c` Changed function calls to getters (Oct 2, 2016)
- `834c837` table with all rlenv envs (Oct 19, 2016)
- `b6f086b` Update function calls, added new steps method, updated README (Oct 21, 2016)
- `2319531` Added timestep limits, exposed start function added _start function t… (Oct 24, 2016)
- `b02b760` Added render (Oct 24, 2016)
- `68f773e` Added zoom option (Oct 26, 2016)
- `0f86433` Added zoom variable to experiment (Oct 26, 2016)
- `904f747` fixed variable name (Oct 26, 2016)
- `922f10d` Merge remote-tracking branch 'origin/twrl' into twrl (Nov 9, 2016)
- `c1f3b16` Modified XOWorld to new api standards (Nov 9, 2016)
- `9114305` 1 channel for XOWorld (Nov 9, 2016)
- `796bc58` Added super call (Nov 9, 2016)
- `eb8ea9b` Merge branch 'master' into twrl (Nov 9, 2016)
- `52ad76d` Added base tests (Nov 14, 2016)
- `1096f31` Merge remote-tracking branch 'origin/master' into twrl (Nov 14, 2016)
- `59da7cd` Modified Minecraft functions to support api (Nov 17, 2016)
- `2d23ed7` Exclude minecraft from tests (Nov 17, 2016)
- `55f90eb` Added assertions to tests (Nov 25, 2016)
README.md (10 changes: 7 additions & 3 deletions)

```diff
@@ -47,6 +47,10 @@ local observation = env:start()
 
 **Note that the API is under development and may be subject to change**
 
+### rlenvs.envs
+
+A table of all possible environments implemented in rlenvs.
+
 ### observation = env:start([opts])
 
 Starts a new episode in the environment and returns the first `observation`. May take `opts`.
@@ -55,7 +59,7 @@ Starts a new episode in the environment and returns the first `observation`. May
 
 Performs a step in the environment using `action` (which may be a list - see below), and returns the `reward`, the `observation` of the state transitioned to, and a `terminal` flag. Optionally provides `actionTaken`, if the environment provides supervision in the form of the actual action taken by the agent in spite of the provided action.
 
-### stateSpec = env:getStateSpec()
+### stateSpec = env:getStateSpace()
 
 Returns a state specification as a list with 3 elements:
 
@@ -67,11 +71,11 @@ Returns a state specification as a list with 3 elements:
 
 If several states are returned, `stateSpec` is itself a list of state specifications. Ranges may use `nil` if unknown.
 
-### actionSpec = env:getActionSpec()
+### actionSpec = env:getActionSpace()
 
 Returns an action specification, with the same structure as used for state specifications.
 
-### minReward, maxReward = env:getRewardSpec()
+### minReward, maxReward = env:getRewardSpace()
 
 Returns the minimum and maximum rewards produced by the environment. Values may be `nil` if unknown.
```
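For orientation, a minimal sketch of the renamed getters in use. This assumes an Acrobot instance and the Box/Discrete tables defined later in this diff; the `require 'rlenvs.Acrobot'` path mirrors the `rlenvs.Catch` require in experiment.lua and is an assumption here, not part of the diff:

```lua
require 'rlenvs'
local Acrobot = require 'rlenvs.Acrobot' -- assumed to be exposed like rlenvs.Catch

local env = Acrobot()
local stateSpace = env:getStateSpace()            -- Box table: name, shape, low, high
local actionSpace = env:getActionSpace()          -- Discrete table: name, n
local minReward, maxReward = env:getRewardSpace()

print(stateSpace['name'], stateSpace['shape'][1]) -- Box      4
print(actionSpace['name'], actionSpace['n'])      -- Discrete 3
print(minReward, maxReward)                       -- -1       0
```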
experiment.lua (42 changes: 18 additions & 24 deletions)

```diff
@@ -1,38 +1,32 @@
-local image = require 'image'
-local Catch = require 'rlenvs/Catch'
-
--- Detect QT for image display
-local qt = pcall(require, 'qt')
+require 'rlenvs'
+local Catch = require('rlenvs.Catch')
 
 -- Initialise and start environment
-local env = Catch({level = 2})
-local stateSpec = env:getStateSpec()
-local actionSpec = env:getActionSpec()
+local env = Catch({level = 2, render = true})
+local getActionSpace = env:getActionSpace()
 local observation = env:start()
 
-local reward, terminal
+local reward, terminal = 0, false
 local episodes, totalReward = 0, 0
-local nSteps = 1000 * (stateSpec[2][2] - 1) -- Run for 1000 episodes
+local nEpisodes = 1000
 
 -- Display
-local window = qt and image.display({image=observation, zoom=10})
+env:render()
 
-for i = 1, nSteps do
-  -- Pick random action and execute it
-  local action = torch.random(actionSpec[3][1], actionSpec[3][2])
-  reward, observation, terminal = env:step(action)
-  totalReward = totalReward + reward
+for i = 1, nEpisodes do
+  while not terminal do
+    -- Pick random action and execute it
+    local action = torch.random(0, getActionSpace['n'] - 1)
+    reward, observation, terminal = env:step(action)
+    totalReward = totalReward + reward
 
-  -- Display
-  if qt then
-    image.display({image=observation, zoom=10, win=window})
+    -- Display
+    env:render()
   end
 
-  -- If game finished, start again
-  if terminal then
-    episodes = episodes + 1
-    observation = env:start()
-  end
+  episodes = episodes + 1
+  observation = env:start()
+  terminal = false
 end
 print('Episodes: ' .. episodes)
 print('Total Reward: ' .. totalReward)
```
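The loop now runs per episode and samples 0-based Discrete actions, where the old code read its range from `actionSpec[3]`. The new sampling pattern in isolation, a sketch assuming `env` is any environment instance (such as the Catch instance above) and torch is loaded:

```lua
-- torch.random(a, b) draws an integer uniformly from [a, b] inclusive,
-- so a Discrete space with n actions is sampled over 0 .. n - 1
local n = env:getActionSpace()['n']
local action = torch.random(0, n - 1)
```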
rlenvs/Acrobot.lua (41 changes: 29 additions & 12 deletions)

```diff
@@ -1,11 +1,14 @@
 local classic = require 'classic'
 
 local Acrobot, super = classic.class('Acrobot', Env)
+Acrobot.timeStepLimit = 500
 
 -- Constructor
 function Acrobot:_init(opts)
   opts = opts or {}
 
+  opts.timeStepLimit = Acrobot.timeStepLimit
+  super._init(self, opts)
+
   -- Constants
   self.g = opts.g or 9.8
   self.m1 = opts.m1 or 1 -- Mass of link 1
@@ -21,27 +24,40 @@ function Acrobot:_init(opts)
 end
 
 -- 4 states returned, of type 'real', of dimensionality 1, with differing ranges
-function Acrobot:getStateSpec()
-  return {
-    {'real', 1, {-math.pi, math.pi}}, -- Joint 1 angle
-    {'real', 1, {-math.pi, math.pi}}, -- Joint 2 angle
-    {'real', 1, {-4*math.pi, 4*math.pi}}, -- Joint 1 angular velocity
-    {'real', 1, {-9*math.pi, 9*math.pi}} -- Joint 2 angular velocity
+function Acrobot:getStateSpace()
+  local state = {}
+  state['name'] = 'Box'
+  state['shape'] = {4}
+  state['low'] = {
+    -math.pi, -- Joint 1 angle
+    -math.pi, -- Joint 2 angle
+    -4*math.pi, -- Joint 1 angular velocity
+    -9*math.pi -- Joint 2 angular velocity
   }
+  state['high'] = {
+    math.pi, -- Joint 1 angle
+    math.pi, -- Joint 2 angle
+    4*math.pi, -- Joint 1 angular velocity
+    9*math.pi -- Joint 2 angular velocity
+  }
+  return state
 end
 
 -- 1 action required, of type 'int', of dimensionality 1, with second torque joint in {-1, 0, 1}
-function Acrobot:getActionSpec()
-  return {'int', 1, {-1, 1}}
+function Acrobot:getActionSpace()
+  local action = {}
+  action['name'] = 'Discrete'
+  action['n'] = 3
+  return action
 end
 
 -- Min and max reward
-function Acrobot:getRewardSpec()
+function Acrobot:getRewardSpace()
   return -1, 0
 end
 
 -- Resets the cart
-function Acrobot:start()
+function Acrobot:_start()
   -- Reset angles and velocities
   self.q1 = 0 -- Joint 1 angle
   self.q2 = 0 -- Joint 2 angle
@@ -52,7 +68,8 @@ function Acrobot:start()
 end
 
 -- Swings the pole via torque on second joint
-function Acrobot:step(action)
+function Acrobot:_step(action)
+  action = action - 1 -- rescale the action
   local reward = -1
   local terminal = false
 
```
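Note the new `action = action - 1` line in `_step`: with `getActionSpace()` reporting `n = 3`, agents send 0-based actions, and this rescale maps them back onto the {-1, 0, 1} torques the dynamics expect. A sketch of the mapping:

```lua
-- 0-based Discrete actions -> second-joint torques
for a = 0, 2 do
  print(('agent action %d -> torque %d'):format(a, a - 1))
end
-- agent action 0 -> torque -1
-- agent action 1 -> torque 0
-- agent action 2 -> torque 1
```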
rlenvs/Atari.lua (30 changes: 23 additions & 7 deletions)

```diff
@@ -6,11 +6,15 @@ if not hasALEWrap then
 end
 
 local Atari, super = classic.class('Atari', Env)
+Atari.timeStepLimit = 100000
 
 -- Constructor
 function Atari:_init(opts)
   -- Create ALEWrap options from opts
   opts = opts or {}
+  opts.timeStepLimit = Atari.timeStepLimit
+  super._init(self, opts)
+
   if opts.lifeLossTerminal == nil then
     opts.lifeLossTerminal = true
   end
@@ -44,13 +48,25 @@ function Atari:_init(opts)
 end
 
 -- 1 state returned, of type 'real', of dimensionality 3 x 210 x 160, between 0 and 1
-function Atari:getStateSpec()
-  return {'real', {3, 210, 160}, {0, 1}}
+function Atari:getStateSpace()
+  local state = {}
+  state['name'] = 'Box'
+  state['shape'] = {3, 210, 160}
+  state['low'] = {
+    0
+  }
+  state['high'] = {
+    1
+  }
+  return state
 end
 
 -- 1 action required, of type 'int', of dimensionality 1, between 1 and 18 (max)
-function Atari:getActionSpec()
-  return {'int', 1, {1, #self.actions}}
+function Atari:getActionSpace()
+  local action = {}
+  action['name'] = 'Discrete'
+  action['n'] = #self.actions
+  return action
 end
 
 -- RGB screen of height 210 and width 160
@@ -59,12 +75,12 @@ function Atari:getDisplaySpec()
 end
 
 -- Min and max reward (unknown)
-function Atari:getRewardSpec()
+function Atari:getRewardSpace()
   return nil, nil
 end
 
 -- Starts a new game, possibly with a random number of no-ops
-function Atari:start()
+function Atari:_start()
   local screen, reward, terminal
 
   if self.gameEnv._random_starts > 0 then
@@ -77,7 +93,7 @@ end
 end
 
 -- Steps in a game
-function Atari:step(action)
+function Atari:_step(action)
   -- Map action index to action for game
   action = self.actions[action]
 
```
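Unlike Acrobot, Atari's `low` and `high` hold a single entry each, applied across the whole `3 x 210 x 160` shape rather than listed per element. A sketch of reading such a Box table, assuming an Atari instance `env`:

```lua
-- Scalar low/high bounds broadcast over the full observation shape
local space = env:getStateSpace()
local numel = 1
for _, dim in ipairs(space['shape']) do
  numel = numel * dim
end
print(numel)                             -- 100800 (3 * 210 * 160)
print(space['low'][1], space['high'][1]) -- 0  1, the bounds for every element
```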
rlenvs/Blackjack.lua (34 changes: 24 additions & 10 deletions)

```diff
@@ -7,31 +7,45 @@ local Blackjack, super = classic.class('Blackjack', Env)
 function Blackjack:_init(opts)
   opts = opts or {}
 
+  super._init(self, opts)
+
   -- Create number-only suit
   self.suit = torch.Tensor({2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 11})
 end
 
 -- 2 states returned, of type 'int', of dimensionality 1, for the player sum, dealer's showing card, and player-usable ace
-function Blackjack:getStateSpec()
-  return {
-    {'int', 1, {2, 20}},
-    {'int', 1, {1, 10}},
-    {'int', 1, {0, 1}}
+function Blackjack:getStateSpace()
+  local state = {}
+  state['name'] = 'Box'
+  state['shape'] = {3}
+  state['low'] = {
+    2,
+    1,
+    0
   }
+  state['high'] = {
+    20,
+    10,
+    1
+  }
+  return state
 end
 
 -- 1 action required, of type 'int', of dimensionality 1, either stand or hit
-function Blackjack:getActionSpec()
-  return {'int', 1, {0, 1}}
+function Blackjack:getActionSpace()
+  local action = {}
+  action['name'] = 'Discrete'
+  action['n'] = 2
+  return action
 end
 
 -- Min and max reward
-function Blackjack:getRewardSpec()
+function Blackjack:getRewardSpace()
   return -1, 1
 end
 
 -- Draw 2 cards for player and dealer
-function Blackjack:start()
+function Blackjack:_start()
   -- Shuffle deck
   self.deck = torch.cat({self.suit, self.suit, self.suit, self.suit}, 1):index(1, torch.randperm(52):long())
 
@@ -51,7 +65,7 @@ function Blackjack:start()
 end
 
 -- Player stands or hits
-function Blackjack:step(action)
+function Blackjack:_step(action)
   local reward = 0
   local terminal = false
 
```
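Blackjack's Box space packs three integer-valued components into one shape-{3} state. A sketch of how the per-index bounds above line up, assuming a Blackjack instance `env` (the labels come from the comment over the old spec):

```lua
-- Per-dimension bounds of the Blackjack state
local labels = {'player sum', 'dealer showing card', 'player-usable ace'}
local space = env:getStateSpace()
for i, label in ipairs(labels) do
  print(label, space['low'][i], space['high'][i])
end
-- player sum           2   20
-- dealer showing card  1   10
-- player-usable ace    0   1
```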
rlenvs/CartPole.lua (44 changes: 30 additions & 14 deletions)

```diff
@@ -1,11 +1,14 @@
 local classic = require 'classic'
 
 local CartPole, super = classic.class('CartPole', Env)
+CartPole.timeStepLimit = 200
 
 -- Constructor
 function CartPole:_init(opts)
   opts = opts or {}
 
+  opts.timeStepLimit = CartPole.timeStepLimit
+  super._init(self, opts)
+
   -- Constants
   self.gravity = opts.gravity or 9.8
   self.cartMass = opts.cartMass or 1.0
@@ -19,27 +22,40 @@ function CartPole:_init(opts)
 end
 
 -- 4 states returned, of type 'real', of dimensionality 1, with differing ranges
-function CartPole:getStateSpec()
-  return {
-    {'real', 1, {-2.4, 2.4}}, -- Cart position
-    {'real', 1, {nil, nil}}, -- Cart velocity
-    {'real', 1, {math.rad(-12), math.rad(12)}}, -- Pole angle
-    {'real', 1, {nil, nil}} -- Pole angular velocity
+function CartPole:getStateSpace()
+  local state = {}
+  state['name'] = 'Box'
+  state['shape'] = {4}
+  state['low'] = {
+    -2.4, -- Cart position
+    math.huge, -- Cart velocity
+    math.rad(-12), -- Pole angle
+    math.huge -- Pole angular velocity
   }
+  state['high'] = {
+    2.4, -- Cart position
+    math.huge, -- Cart velocity
+    math.rad(12), -- Pole angle
+    math.huge -- Pole angular velocity
+  }
+  return state
 end
 
 -- 1 action required, of type 'int', of dimensionality 1, between 0 and 1 (left, right)
-function CartPole:getActionSpec()
-  return {'int', 1, {0, 1}}
+function CartPole:getActionSpace()
+  local action = {}
+  action['name'] = 'Discrete'
+  action['n'] = 2
+  return action
 end
 
 -- Min and max reward
-function CartPole:getRewardSpec()
+function CartPole:getRewardSpace()
   return -1, 0
 end
 
 -- Resets the cart
-function CartPole:start()
+function CartPole:_start()
   -- Reset position, angle and velocities
   self.x = 0 -- Cart position (m)
   self.xDot = 0 -- Cart velocity
@@ -50,7 +66,7 @@ function CartPole:start()
 end
 
 -- Drives the cart
-function CartPole:step(action)
+function CartPole:_step(action)
   -- Calculate acceleration
   local force = action == 1 and self.forceMagnitude or -self.forceMagnitude
   local cosTheta = math.cos(self.theta)
@@ -66,10 +82,10 @@ function CartPole:step(action)
   self.thetaDot = self.thetaDot + self.tau * thetaDotDot
 
   -- Check failure (if cart reaches sides of track/pole tips too much)
-  local reward = 0
+  local reward = 1
   local terminal = false
   if self.x < -2.4 or self.x > 2.4 or self.theta < math.rad(-12) or self.theta > math.rad(12) then
-    reward = -1
+    reward = 0
     terminal = true
  end
 
```
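Two behavioural notes on this diff: the unbounded velocity ranges move from `{nil, nil}` in the old spec to `math.huge` entries in the Box tables, and the per-step reward flips from 0 while balancing / -1 on failure to 1 while balancing / 0 on failure. A sketch of flagging unbounded dimensions under the `math.huge` convention, assuming a CartPole instance `env`:

```lua
-- Treat infinite low/high entries as "range unknown"
local space = env:getStateSpace()
for i = 1, space['shape'][1] do
  local lo, hi = space['low'][i], space['high'][i]
  local bounded = math.abs(lo) ~= math.huge and math.abs(hi) ~= math.huge
  print(i, bounded) -- true for position and angle, false for the velocities
end
```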