-
Notifications
You must be signed in to change notification settings - Fork 0
/
adam.py
67 lines (51 loc) · 1.8 KB
/
adam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import cupy
import numpy
import util
# Adam update as a hand-written CuPy ElementwiseKernel.  Every operand is
# declared with the single type placeholder ``T``.  ``param``, ``m`` and ``v``
# are output operands that the operation both reads and updates in place.
# ``lr`` is applied as given: the operation contains no bias-correction term.
adam = cupy.ElementwiseKernel(
    in_params=('T grad, T lr, T one_minus_beta1, T one_minus_beta2, '
               'T eps, T eta, T weight_decay_rate'),
    out_params='T param, T m, T v',
    operation='''m += one_minus_beta1 * (grad - m);
v += one_minus_beta2 * (grad * grad - v);
param -= eta * (lr * m / (sqrt(v) + eps) +
weight_decay_rate * param);''',
    name='adam',
)
@cupy.fuse()
def adam_fuse(
        grad, lr, one_minus_beta1, one_minus_beta2, eps, eta,
        weight_decay_rate, param, m, v):
    """Apply one Adam step, written as a ``cupy.fuse`` function.

    Performs the same arithmetic as the ``adam`` ElementwiseKernel.
    ``m``, ``v`` and ``param`` are updated in place through augmented
    assignment, in that order, and nothing is returned.  ``xp`` is chosen
    from ``grad`` so the square root comes from the matching array module.
    """
    xp = cupy.get_array_module(grad)
    # Moving averages of the gradient and of its square.
    m += one_minus_beta1 * (grad - m)
    v += one_minus_beta2 * (grad * grad - v)
    # Adaptive step (uses the freshly updated m and v), then the
    # weight-decay term proportional to param, all scaled by eta.
    denom = xp.sqrt(v) + eps
    adaptive_step = lr * m / denom
    param -= eta * (adaptive_step + weight_decay_rate * param)
class hp:
    """Fixed Adam hyperparameters read by call_adam and call_adam_fuse.

    Used as a plain namespace: the values are class attributes and are read
    directly from the class (e.g. ``hp.lr``).
    """

    # Base step size and the two moment-decay coefficients.
    lr, beta1, beta2 = 0.1, 0.9, 0.999
    # Denominator guard, overall scale on the update, weight-decay strength.
    eps, eta, weight_decay_rate = 1e-8, 1.0, 0.9
def call_adam(grad, data, state_m, state_v):
    """Run one Adam step through the ``adam`` ElementwiseKernel.

    ``data``, ``state_m`` and ``state_v`` are bound to the kernel's output
    operands ``param``, ``m`` and ``v`` and are updated in place.  The
    hyperparameters come from ``hp`` and are passed as Python floats.
    Returns None.
    """
    one_minus_beta1 = 1 - hp.beta1
    one_minus_beta2 = 1 - hp.beta2
    adam(grad, hp.lr, one_minus_beta1, one_minus_beta2, hp.eps, hp.eta,
         hp.weight_decay_rate, data, state_m, state_v)
def call_adam_fuse(grad, data, state_m, state_v):
    """Run one Adam step through the fused ``adam_fuse`` function.

    Unlike ``call_adam``, every hyperparameter (and derived coefficient) from
    ``hp`` is first converted to a ``numpy.float32`` scalar.  ``data``,
    ``state_m`` and ``state_v`` are updated in place.  Returns None.
    """
    lr, one_minus_beta1, one_minus_beta2, eps, eta, decay = [
        numpy.float32(value) for value in (
            hp.lr, 1 - hp.beta1, 1 - hp.beta2,
            hp.eps, hp.eta, hp.weight_decay_rate)]
    adam_fuse(grad, lr, one_minus_beta1, one_minus_beta2, eps, eta, decay,
              data, state_m, state_v)
sizes = [1, 10, 100, 1000, 2000, 5000]


def _bench(label, step, xp, run_sizes, repeat):
    """Time ``step`` with ``util.measure`` on square zero arrays.

    For each ``n`` in ``run_sizes`` a fresh ``xp.zeros((n, n))`` array is
    allocated.  That single array is passed as all four operands (grad,
    param, m, v).  ``label`` is a %-format string that receives ``n``.
    ``repeat`` is forwarded to ``util.measure`` unchanged.
    """
    for n in run_sizes:
        buf = xp.zeros((n, n))

        def f():
            step(buf, buf, buf, buf)
        util.measure(f, label % (n), repeat)


# ElementwiseKernel on CuPy arrays, the fused function on CuPy arrays, then
# the fused function on NumPy arrays (first four sizes, fewer repetitions).
_bench("adam , %4d", call_adam, cupy, sizes, 100)
_bench("adam_fuse, %4d", call_adam_fuse, cupy, sizes, 100)
_bench("adam_np , %4d", call_adam_fuse, numpy, sizes[:4], 10)