-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsum_tree.py
76 lines (65 loc) · 3.33 KB
/
sum_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import numpy as np
import torch
class SumTree:
    r"""
    Binary sum tree storing one priority per replay-buffer slot.

    Every leaf holds the priority of one transition and every internal node
    holds the sum of its two children, so the root holds the total priority.
    Updating a priority and sampling proportionally to priority both walk a
    single root-to-leaf path.

    Tree structure and array storage (example with buffer_capacity = 4):

        Tree index:
                 0          -> stores the total priority sum
               /   \
              1     2
             / \   / \
            3   4 5   6     -> store the priority of each transition

        Array used for storage:
            [0, 1, 2, 3, 4, 5, 6]

    Leaf ``tree_index`` corresponds to buffer slot
    ``data_index = tree_index - buffer_capacity + 1``.
    """

    def __init__(self, buffer_capacity):
        """Create an all-zero sum tree for a buffer of ``buffer_capacity`` slots."""
        self.buffer_capacity = buffer_capacity  # capacity of the replay buffer
        self.tree_capacity = 2 * buffer_capacity - 1  # number of nodes in the sum tree
        self.tree = np.zeros(self.tree_capacity)

    def update(self, data_index, priority):
        """Set the priority of buffer slot ``data_index`` and refresh its ancestors.

        Args:
            data_index: index of the transition in the replay buffer.
            priority: new priority value for that transition.
        """
        # Convert the buffer index into the index of the corresponding leaf.
        tree_index = data_index + self.buffer_capacity - 1
        self.tree[tree_index] = priority  # update the leaf itself
        # Walk up to the root, recomputing each ancestor from its two children
        # rather than adding a running delta. This keeps every internal sum
        # consistent with its leaves and avoids floating-point drift that would
        # otherwise accumulate over many updates.
        while tree_index != 0:
            tree_index = (tree_index - 1) // 2
            left = 2 * tree_index + 1
            self.tree[tree_index] = self.tree[left] + self.tree[left + 1]

    def get_index(self, v):
        """Find the leaf whose cumulative-priority interval contains ``v``.

        Args:
            v: a value expected to lie in ``[0, priority_sum]``.

        Returns:
            ``(data_index, priority)``: the sampled slot's index in the buffer
            and the priority stored for it.
        """
        parent_idx = 0  # start the search at the root
        while True:
            child_left_idx = 2 * parent_idx + 1  # children of the current node
            child_right_idx = child_left_idx + 1
            if child_left_idx >= self.tree_capacity:  # reached a leaf, stop
                tree_index = parent_idx  # index of the sampled leaf in the tree
                break
            left_sum = self.tree[child_left_idx]
            right_sum = self.tree[child_right_idx]
            # Descend by prefix sum: go left when v falls inside the left
            # subtree's mass, otherwise subtract it and go right. Never step
            # into a subtree with zero mass while its sibling has mass. With
            # exact arithmetic that only matters on measure-zero boundaries,
            # but rounding (e.g. v landing on priority_sum) could otherwise
            # select an empty, zero-priority slot and yield an infinite IS weight.
            if right_sum <= 0 or (left_sum > 0 and v <= left_sum):
                parent_idx = child_left_idx
            else:
                v -= left_sum
                parent_idx = child_right_idx
        data_index = tree_index - self.buffer_capacity + 1  # tree_index -> data_index
        return data_index, self.tree[tree_index]

    def get_batch_index(self, current_size, batch_size, beta):
        """Sample ``batch_size`` buffer indices proportionally to priority.

        Stratified sampling: ``[0, priority_sum]`` is split into ``batch_size``
        equal segments and one value is drawn uniformly from each segment.

        Args:
            current_size: number of transitions currently stored in the buffer.
            batch_size: number of indices to sample.
            beta: importance-sampling exponent.

        Returns:
            ``(batch_index, IS_weight)``: an int64 NumPy array of buffer indices
            and a float32 tensor of importance-sampling weights, normalized so
            the largest weight is 1.
        """
        # np.long was removed in NumPy 1.24; int64 matches what it produced on
        # 64-bit Linux/macOS and works on every NumPy version.
        batch_index = np.zeros(batch_size, dtype=np.int64)
        IS_weight = torch.zeros(batch_size, dtype=torch.float32)
        segment = self.priority_sum / batch_size  # width of each sampling segment
        for i in range(batch_size):
            a = segment * i
            b = segment * (i + 1)
            v = np.random.uniform(a, b)
            index, priority = self.get_index(v)
            batch_index[i] = index
            prob = priority / self.priority_sum  # probability this transition is sampled
            IS_weight[i] = (current_size * prob) ** (-beta)
        IS_weight /= IS_weight.max()  # normalize so the largest weight is 1
        return batch_index, IS_weight

    @property
    def priority_sum(self):
        """Total priority of all stored transitions (held at the root)."""
        return self.tree[0]

    @property
    def priority_max(self):
        """Largest leaf priority; leaves are the last ``buffer_capacity`` nodes."""
        return self.tree[self.buffer_capacity - 1:].max()