-
Notifications
You must be signed in to change notification settings - Fork 0
/
discrete_env.py
59 lines (46 loc) · 1.48 KB
/
discrete_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy as np
from gym import Env, spaces
from gym.utils import seeding
def categorical_sample(prob_n, np_random):
    """
    Sample an index from a categorical distribution.

    Parameters
    ----------
    prob_n : sequence of float
        1-D vector of class probabilities; entry i is the probability of
        returning index i. Expected to sum to 1.
    np_random : object with a ``rand()`` method
        Random source, e.g. the generator produced by
        ``gym.utils.seeding.np_random``. Each successful call consumes
        exactly one ``rand()`` draw, so seeded runs stay reproducible.

    Returns
    -------
    int-like
        The first index i with ``cumsum(prob_n)[i] > u``, where
        ``u = np_random.rand()``. If rounding leaves every cumulative sum
        at or below ``u``, the last index with non-zero probability is
        returned (0 if there is none).

    Raises
    ------
    ValueError
        If ``prob_n`` is empty.
    """
    prob_n = np.asarray(prob_n)
    if prob_n.size == 0:
        raise ValueError("prob_n must contain at least one probability")
    csprob_n = np.cumsum(prob_n)
    exceeds = csprob_n > np_random.rand()
    if exceeds.any():
        # argmax on a boolean array returns the first True position.
        return exceeds.argmax()
    # Floating-point rounding can leave the total slightly below 1, so the
    # draw may exceed every cumulative sum. argmax would then return 0,
    # possibly an outcome with zero probability; pick the last class that
    # can actually occur instead.
    positive = np.flatnonzero(prob_n > 0)
    return positive[-1] if positive.size else 0
class DiscreteEnv(Env):
    """
    Tabular environment with finitely many states and actions.

    Constructor arguments (also stored as attributes of the same name):
    - nS: number of states
    - nA: number of actions
    - P: transition model, a dict of dicts of lists laid out as
      P[s][a] == [(probability, nextstate, reward, done), ...]
    - isd: initial state distribution, a list or array of length nS

    Other attributes:
    - s: current state, set by _reset and _step
    - lastaction: action passed to the most recent _step, or None right
      after a reset; kept for rendering
    - np_random: random generator created by _seed
    - action_space / observation_space: spaces.Discrete(nA) / spaces.Discrete(nS)
    """

    def __init__(self, nS, nA, P, isd):
        self.P = P
        self.isd = isd
        self.lastaction = None  # consulted when rendering
        self.nS, self.nA = nS, nA
        self.action_space = spaces.Discrete(nA)
        self.observation_space = spaces.Discrete(nS)
        # Seeding must happen first: _reset draws the start state from
        # self.np_random.
        self._seed()
        self._reset()

    def _seed(self, seed=None):
        """Build self.np_random from `seed`; return the seed reported by
        seeding.np_random, wrapped in a one-element list."""
        rng, used_seed = seeding.np_random(seed)
        self.np_random = rng
        return [used_seed]

    def _reset(self):
        """Sample a start state from isd, clear lastaction, return the state."""
        start = categorical_sample(self.isd, self.np_random)
        self.s, self.lastaction = start, None
        return start

    def _step(self, a):
        """Apply action `a` from the current state.

        One transition is sampled from P[s][a] according to its listed
        probabilities. Returns (next_state, reward, done, {"prob": p}),
        where p is the probability of the sampled transition.
        """
        outcomes = self.P[self.s][a]
        weights = [outcome[0] for outcome in outcomes]
        chosen = outcomes[categorical_sample(weights, self.np_random)]
        prob, next_state, reward, done = chosen
        self.s, self.lastaction = next_state, a
        return (next_state, reward, done, {"prob": prob})