-
Notifications
You must be signed in to change notification settings - Fork 0
/
engine.js
123 lines (111 loc) · 2.73 KB
/
engine.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
/**
 * Scalar node of a computation graph supporting reverse-mode automatic
 * differentiation.
 *
 * Every arithmetic method returns a new `Value` whose `prev` set holds its
 * operands and whose `_backward` closure adds the local gradient contribution
 * into those operands' `grad` fields. Calling `backward()` on a result node
 * runs those closures from the result back towards the leaves.
 */
class Value {
  /**
   * @param {number} data - Numeric payload of this node.
   * @param {string} [_label=""] - Free-form label stored on the node.
   * @param {Value[]} [_children=[]] - Operand nodes that produced this node.
   * @param {string} [_op=""] - Symbol of the operation that produced this node.
   */
  constructor(data, _label = "", _children = [], _op = "") {
    this.data = data;
    this.label = _label;
    this.op = _op;
    this.childs = _children;
    this.prev = new Set(_children);
    this.grad = 0.0;
    // Leaves have nothing to propagate; operations overwrite this closure.
    this._backward = () => {};
  }

  /**
   * Returns `other` unchanged when it already is a `Value`, otherwise wraps
   * the raw operand in a fresh leaf node. Doing this once per operation lets
   * the forward pass and the backward closure share the same operand node,
   * so the closure never dereferences `.data` / `.grad` on a primitive.
   *
   * @param {Value|number} other
   * @returns {Value}
   */
  static _wrap(other) {
    return other instanceof Value ? other : new Value(other);
  }

  /** @returns {number} The stored `data`, so the node coerces to a number. */
  valueOf() {
    return this.data;
  }

  /**
   * Addition: `this + other`.
   *
   * @param {Value|number} other
   * @returns {Value} Node with op `"+"` and children `[this, other-node]`.
   */
  add(other) {
    const rhs = Value._wrap(other);
    const out = new Value(this.data + rhs.data, "", [this, rhs], "+");
    out._backward = () => {
      // d(a+b)/da = d(a+b)/db = 1
      this.grad += out.grad;
      rhs.grad += out.grad;
    };
    return out;
  }

  /**
   * Multiplication: `this * other`.
   *
   * @param {Value|number} other
   * @returns {Value} Node with op `"*"` and children `[this, other-node]`.
   */
  mult(other) {
    const rhs = Value._wrap(other);
    const out = new Value(this.data * rhs.data, "", [this, rhs], "*");
    out._backward = () => {
      // d(a*b)/da = b, d(a*b)/db = a
      this.grad += rhs.data * out.grad;
      rhs.grad += this.data * out.grad;
    };
    return out;
  }

  /**
   * Division: `this / other`. Follows plain JavaScript number semantics, so
   * dividing by zero produces `Infinity`/`NaN` rather than throwing.
   *
   * @param {Value|number} other
   * @returns {Value} Node with op `"/"` and children `[this, other-node]`.
   */
  div(other) {
    const rhs = Value._wrap(other);
    const out = new Value(this.data / rhs.data, "", [this, rhs], "/");
    out._backward = () => {
      // d(a/b)/da = 1/b, d(a/b)/db = -a/b^2
      this.grad += out.grad / rhs.data;
      rhs.grad -= (this.data / (rhs.data * rhs.data)) * out.grad;
    };
    return out;
  }

  /**
   * Exponentiation: `this ** other`.
   *
   * @param {Value|number} other - Exponent.
   * @returns {Value} Node with op `"^"` and children `[this, other-node]`.
   */
  pow(other) {
    const rhs = Value._wrap(other);
    const out = new Value(
      Math.pow(this.data, rhs.data),
      "",
      [this, rhs],
      "^",
    );
    out._backward = () => {
      // d(a^n)/da = n * a^(n-1)
      this.grad += rhs.data * Math.pow(this.data, rhs.data - 1) * out.grad;
      // d(a^n)/dn = a^n * ln(a); only finite for a positive base, so the
      // exponent receives no contribution otherwise.
      if (this.data > 0) {
        rhs.grad += out.data * Math.log(this.data) * out.grad;
      }
    };
    return out;
  }

  /** @returns {Value} `this * -1`. */
  neg() {
    return this.mult(-1);
  }

  /**
   * Subtraction: `this - other`, expressed as `this + (-other)`.
   *
   * @param {Value|number} other
   * @returns {Value}
   */
  sub(other) {
    // A Value operand is negated through the graph so it still receives a
    // gradient; a plain number is simply negated before being added.
    return this.add(other instanceof Value ? other.neg() : -other);
  }

  /**
   * Rectified linear unit: `0` for negative data, the data itself otherwise.
   *
   * @returns {Value} Node labelled `"relu"` with op `"relu"`.
   */
  relu() {
    const out = new Value(this.data < 0 ? 0 : this.data, "relu", [this], "relu");
    out._backward = () => {
      // Gradient passes through only where the output is strictly positive.
      this.grad += out.data > 0 ? out.grad : 0;
    };
    return out;
  }

  /**
   * Hyperbolic tangent.
   *
   * @returns {Value} Node labelled `"tanh"` with op `"tanh"`.
   */
  tanh() {
    const out = new Value(Math.tanh(this.data), "tanh", [this], "tanh");
    out._backward = () => {
      // d tanh(x)/dx = 1 - tanh(x)^2; out.data already holds tanh(x).
      this.grad += (1 - out.data ** 2) * out.grad;
    };
    return out;
  }

  /**
   * Logistic sigmoid `1 / (1 + e^-x)`.
   *
   * @returns {Value} Node labelled `"sigmoid"` with op `"sigmoid"`.
   */
  sigmoid() {
    const sig = 1 / (1 + Math.exp(-this.data));
    const out = new Value(sig, "sigmoid", [this], "sigmoid");
    out._backward = () => {
      // d sigma(x)/dx = sigma(x) * (1 - sigma(x))
      this.grad += sig * (1 - sig) * out.grad;
    };
    return out;
  }

  /**
   * Back-propagates from this node: sets `this.grad` to 1 and then invokes
   * `_backward` on every reachable node in reverse topological order.
   *
   * Gradients are accumulated with `+=` and are not reset here, so calling
   * this again on the same graph adds to the existing `grad` values.
   */
  backward() {
    const topo = [];
    const visited = new Set([this]);
    // Iterative post-order DFS with an explicit stack of
    // [node, iterator over its operands]; a recursive walk would exhaust the
    // call stack on long operation chains.
    const stack = [[this, this.prev.values()]];
    while (stack.length > 0) {
      const [node, operands] = stack[stack.length - 1];
      const step = operands.next();
      if (step.done) {
        // All operands emitted: the node itself can now be placed.
        topo.push(node);
        stack.pop();
      } else if (!visited.has(step.value)) {
        visited.add(step.value);
        stack.push([step.value, step.value.prev.values()]);
      }
    }

    this.grad = 1;
    for (let i = topo.length - 1; i >= 0; i--) {
      topo[i]._backward();
    }
  }
}