-
Notifications
You must be signed in to change notification settings - Fork 0
/
ounoise.py
32 lines (27 loc) · 1.09 KB
/
ounoise.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np
from collections import deque
import random
import torch
class OUNoise(object):
    """Ornstein-Uhlenbeck-style exploration noise with a linearly decaying sigma.

    The internal state is pulled towards ``mu`` at rate ``theta`` and perturbed
    by Gaussian noise scaled by ``sigma``.  ``sigma`` is annealed linearly from
    ``max_sigma`` to ``min_sigma`` over ``decay_period`` steps.  Noisy actions
    are clamped to ``[low, high]``, which is fixed to ``[0, 1]``.

    Args:
        action_dim: Number of action components (length of the noise vector).
        mu: Value the noise state reverts to; also the state after ``reset``.
        theta: Mean-reversion rate.
        max_sigma: Noise scale at ``t = 0``.
        min_sigma: Noise scale once ``t >= decay_period``.
        decay_period: Number of steps over which sigma is annealed.
    """

    def __init__(self, action_dim, mu=0.5, theta=0.5, max_sigma=1, min_sigma=0.2, decay_period=10000):
        self.mu = mu
        self.theta = theta
        self.sigma = max_sigma
        self.max_sigma = max_sigma
        self.min_sigma = min_sigma
        self.decay_period = decay_period
        self.action_dim = action_dim
        # Bounds applied to the noisy action in get_action().
        self.low = 0
        self.high = 1
        self.reset()

    def reset(self):
        """Reset the noise state to ``mu`` in every dimension (sigma is untouched)."""
        self.state = np.ones(self.action_dim) * self.mu

    def evolve_state(self):
        """Advance the noise state by one step and return the new state array."""
        x = self.state
        dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(self.action_dim)
        self.state = x + dx
        return self.state

    def get_action(self, action, t=0):
        """Return ``action`` plus the current noise, clamped to ``[low, high]``.

        Args:
            action: Action tensor (anything ``torch.as_tensor`` accepts also works).
            t: Current step index, used to anneal ``sigma``.

        Returns:
            A tensor with the noisy, clamped action.  For floating-point actions
            the result keeps the action's dtype and device.
        """
        # The state is evolved with the sigma from the previous call; the sigma
        # computed below takes effect on the next step.
        ou_state = self.evolve_state()

        # A non-positive decay period means "already fully decayed" instead of
        # dividing by zero.
        if self.decay_period > 0:
            progress = min(1.0, t / self.decay_period)
        else:
            progress = 1.0
        self.sigma = self.max_sigma - (self.max_sigma - self.min_sigma) * progress

        action = torch.as_tensor(action)
        # The NumPy state is float64.  Build the noise in the action's own
        # floating dtype and on its device so a float32 action is not silently
        # promoted to float64 and accelerator-resident actions do not hit a
        # device mismatch.  Non-floating actions keep the float64 promotion.
        noise_dtype = action.dtype if action.is_floating_point() else torch.float64
        noise = torch.tensor(ou_state, dtype=noise_dtype, device=action.device)
        return torch.clamp(action + noise, min=self.low, max=self.high)