N-step Off-policy n-step Q(σ).kt
@file:Suppress("NAME_SHADOWING")

package lab.mars.rl.algo.ntd

import lab.mars.rl.algo.V_from_Q
import lab.mars.rl.algo.`ε-greedy`
import lab.mars.rl.model.impl.mdp.IndexedAction
import lab.mars.rl.model.impl.mdp.IndexedMDP
import lab.mars.rl.model.impl.mdp.IndexedState
import lab.mars.rl.model.impl.mdp.OptimalSolution
import lab.mars.rl.model.isTerminal
import lab.mars.rl.model.log
import lab.mars.rl.util.buf.newBuf
import lab.mars.rl.util.log.debug
import lab.mars.rl.util.math.Σ
import lab.mars.rl.util.tuples.tuple3
import org.apache.commons.math3.util.FastMath.min
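
/**
 * Off-policy n-step Q(σ) (Sutton & Barto, "Reinforcement Learning: An Introduction",
 * 2nd ed., Section 7.6): unifies n-step Sarsa and n-step Tree Backup. At every step the
 * function σ decides whether the backup samples the action actually taken (σ = 1) or
 * takes an expectation over the target policy π (σ = 0); sampled steps are corrected by
 * the importance-sampling ratio π/b, where b is the behaviour policy.
 *
 * @param n        number of backup steps
 * @param σ        maps a step index to 1 (sample) or 0 (expectation)
 * @param ε        exploration rate for keeping the target policy π ε-greedy
 * @param α        step size, possibly depending on the updated state-action pair
 * @param episodes number of episodes to run
 * @return the learned (π, V, Q) triple as an [OptimalSolution]
 */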
fun IndexedMDP.`N-step off-policy n-step Q(σ)`(
    n: Int,
    σ: (Int) -> Int,
    ε: Double,
    α: (IndexedState, IndexedAction) -> Double,
    episodes: Int): OptimalSolution {
  val b = equiprobablePolicy()  // behaviour policy, used to select actions
  val π = equiprobablePolicy()  // target policy, improved towards ε-greedy below
  val Q = QFunc { 0.0 }         // action-value estimates, initialised to zero
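
  // Sliding windows over the most recent min(n, MAX_N) time steps (MAX_N is a constant
  // defined elsewhere in the codebase; it is not imported here). They hold, per step:
  //   _Q: Q(S_t, A_t) at the time it was stored     _π: π(A_t | S_t)
  //   ρ:  importance ratios π(A_t|S_t) / b(A_t|S_t) _σ: the chosen σ values
  //   δ:  per-step TD errors                        _S, _A: visited states and actions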
  val _Q = newBuf<Double>(min(n, MAX_N))
  val _π = newBuf<Double>(min(n, MAX_N))
  val ρ = newBuf<Double>(min(n, MAX_N))
  val _σ = newBuf<Int>(min(n, MAX_N))
  val δ = newBuf<Double>(min(n, MAX_N))
  val _S = newBuf<IndexedState>(min(n, MAX_N))
  val _A = newBuf<IndexedAction>(min(n, MAX_N))
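
  // One pass per episode: reset the windows with the start state and the first action
  // drawn from the behaviour policy b, then walk through the episode step by step.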
  for (episode in 1..episodes) {
    log.debug { "$episode/$episodes" }
    var n = n                 // may shrink below if the episode ends in fewer than n steps
    var T = Int.MAX_VALUE     // episode length, unknown until termination
    var t = 0
    var s = started()
    var a = b(s)
    _Q.clear(); _Q.append(0.0)
    _π.clear(); _π.append(π[s, a])
    ρ.clear(); ρ.append(π[s, a] / b[s, a])
    _σ.clear(); _σ.append(σ(0))
    δ.clear()
    _S.clear(); _S.append(s)
    _A.clear(); _A.append(a)
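    // Time loop: while the episode is running, sample one transition and record its TD
    // error; once τ = t - n + 1 becomes non-negative, perform one n-step update per step.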
    do {
      if (t >= n) { // keep at most n entries in each window: drop the oldest step
        _Q.removeFirst()
        _π.removeFirst()
        ρ.removeFirst()
        _σ.removeFirst()
        δ.removeFirst()
        _S.removeFirst()
        _A.removeFirst()
      }
      if (t < T) {
        // The episode is still running: take the stored action, observe the next state
        // and reward from the environment.
        val (s_next, reward) = a.sample()
        _S.append(s_next)
        s = s_next
        if (s.isTerminal) {
          δ.append(reward - _Q.last)
          T = t + 1
          val τ = t - n + 1
          if (τ < 0) n = T // the episode ended in fewer than n steps: shrink the window
        } else {
          a = b(s); _A.append(a)
          val tmp_σ = σ(t + 1)
          _σ.append(tmp_σ)
          // Q(σ) TD error: σ = 1 bootstraps on the sampled action, σ = 0 on the
          // expectation over the target policy π.
          δ.append(reward + γ * tmp_σ * Q[s, a] + γ * (1 - tmp_σ) * Σ(s.actions) { π[s, it] * Q[s, it] } - _Q.last)
          _Q.append(Q[s, a])
          _π.append(π[s, a])
          ρ.append(π[s, a] / b[s, a])
        }
      }
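      // τ is the time step whose estimate is updated now. Rebuild the n-step return G
      // from the buffered TD errors, weighting each by Z (which mixes σ and π exactly as
      // in the Q(σ) backup), accumulate the truncated importance-sampling ratio _ρ, then
      // move Q(S_τ, A_τ) towards G and make π ε-greedy in S_τ.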
      val τ = t - n + 1
      if (τ >= 0) {
        var _ρ = 1.0
        var Z = 1.0
        var G = _Q[0]
        val end = min(n - 1, T - 1 - τ)
        for (k in 0..end) {
          G += Z * δ[k]
          if (k < end) Z *= γ * ((1 - _σ[k + 1]) * _π[k + 1] + _σ[k + 1])
          _ρ *= 1 - _σ[k] + _σ[k] * ρ[k]
        }
        Q[_S[0], _A[0]] += α(_S[0], _A[0]) * _ρ * (G - Q[_S[0], _A[0]])
        `ε-greedy`(_S[0], Q, π, ε)
      }
      t++
    } while (τ < T - 1)
    log.debug { "n=$n,T=$T" }
  }
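
  // After all episodes, fill in the state values V from the learned Q via V_from_Q and
  // return the (π, V, Q) triple.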
  val V = VFunc { 0.0 }
  val result = tuple3(π, V, Q)
  V_from_Q(states, result)
  return result
}
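
/*
 * Usage sketch (not part of the original file). It shows how this extension function
 * could be invoked; `myMDP()` is a hypothetical placeholder for whatever builds an
 * IndexedMDP in your setup, and the parameter values are illustrative only.
 *
 *   val mdp: IndexedMDP = myMDP()                 // hypothetical problem factory
 *   val solution = mdp.`N-step off-policy n-step Q(σ)`(
 *       n = 4,
 *       σ = { step -> step % 2 },                 // alternate sampling (1) and expectation (0)
 *       ε = 0.1,
 *       α = { _, _ -> 0.1 },                      // constant step size
 *       episodes = 1000)
 *   // solution holds the (π, V, Q) triple returned above
 */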