# coding=utf-8

# Copyright [2024] [SkywardAI]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import unittest

from pathlib import Path
import torch
from torch.nn import functional as F

from models.mlp_batchnorm import MlpBatchNormTrainer, Linear, BatchNorm1d, Tanh
from pkg.dataset_helper import DatasetHelper


class TestMLPBatchNorm(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.n_embd = MlpBatchNormTrainer.n_embed
        cls.n_hidden = MlpBatchNormTrainer.n_hidden
        cls.n_block_size = MlpBatchNormTrainer.n_block_size

        src_dir = Path(os.path.dirname(os.path.abspath(__file__))).parent
        abs_file_path = os.path.join(src_dir, "input.txt")
        _ = DatasetHelper.download_remote_file(MlpBatchNormTrainer.ds_url, abs_file_path)
        cls.data = MlpBatchNormTrainer.load_dataset(abs_file_path)
        cls.unique_chars = MlpBatchNormTrainer.unique_chars(cls.data.splitlines())
        cls.stoi = MlpBatchNormTrainer.stoi(cls.unique_chars)
        cls.itos = MlpBatchNormTrainer.itos(cls.unique_chars)
        cls.vocab_size = MlpBatchNormTrainer.build_vocab(cls.itos)

    def test_mlp_batchnorm_trainer(self):

        self.assertEqual(self.vocab_size, 27)
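        # 26 lowercase letters plus one marker token give a vocabulary size of 27;
        # index 0 is used as the start/end marker (the sampling loop below seeds
        # the context with zeros and stops once index 0 is drawn).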
        random.seed(42)
        words = self.data.splitlines()
        random.shuffle(words)

        n1 = int(0.8 * len(words))
        n2 = int(0.9 * len(words))

        Xtr, Ytr = MlpBatchNormTrainer.build_dataset(words[:n1], self.stoi)      # 80%
        Xdev, Ydev = MlpBatchNormTrainer.build_dataset(words[n1:n2], self.stoi)  # 10%
        Xte, Yte = MlpBatchNormTrainer.build_dataset(words[n2:], self.stoi)      # 10%
        g = torch.Generator().manual_seed(2147483647)
        self.assertEqual(self.n_embd, 10)
        C = torch.randn((self.vocab_size, self.n_embd), generator=g)
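        # C is the character embedding table: row k holds the n_embd-dimensional
        # vector for token k, and it is trained together with the layer parameters.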

        self.assertEqual(C.shape, torch.Size([27, 10]))

        # a sequential MLP built from six Linear layers
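        # Each hidden block is Linear -> BatchNorm1d -> Tanh; the final block is
        # Linear -> BatchNorm1d without a Tanh, so it emits the vocab_size logits.
        # The first Linear consumes the flattened context window, i.e.
        # n_embd * n_block_size input features.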
        layers = [
            Linear(self.n_embd * self.n_block_size, self.n_hidden, bias=False), BatchNorm1d(self.n_hidden), Tanh(),
            Linear(self.n_hidden, self.n_hidden, bias=False), BatchNorm1d(self.n_hidden), Tanh(),
            Linear(self.n_hidden, self.n_hidden, bias=False), BatchNorm1d(self.n_hidden), Tanh(),
            Linear(self.n_hidden, self.n_hidden, bias=False), BatchNorm1d(self.n_hidden), Tanh(),
            Linear(self.n_hidden, self.n_hidden, bias=False), BatchNorm1d(self.n_hidden), Tanh(),
            Linear(self.n_hidden, self.vocab_size, bias=False), BatchNorm1d(self.vocab_size)
        ]

        with torch.no_grad():
            # Here the last layer is a batch norm layer, so there is no weight matrix
            # to shrink in order to make the softmax less confident at initialization.
            # Instead we scale down gamma (Algorithm 1 in the batch norm paper),
            # because gamma is the variable that multiplicatively interacts with the
            # output of the normalization.
            layers[-1].gamma *= 0.1
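            # With near-uniform logits the initial loss should land close to
            # -ln(1/27) ≈ 3.3 instead of being inflated by confidently wrong predictions.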

            # all other layers: boost the Linear weights by the gain
            for layer in layers[:-1]:
                if isinstance(layer, Linear):
                    layer.weight *= 5 / 3  # the gain; the value comes from the torch documentation
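                    # 5/3 matches torch.nn.init.calculate_gain('tanh'); the gain
                    # compensates for tanh squashing activations so their scale
                    # does not shrink from layer to layer.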

        # the embedding matrix C and the parameters of all the layers
        parameters = [C] + [p for layer in layers for p in layer.parameters()]
        print(sum(p.nelement() for p in parameters))  # number of parameters in total
        for p in parameters:
            p.requires_grad = True

        # training loop
        lossi = []
        ud = []
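        # lossi records log10 of the minibatch loss at every step; ud records the
        # update-to-data ratio of every parameter, both purely for diagnostics.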

        for i in range(MlpBatchNormTrainer.max_steps):
            # minibatch construct
            ix = torch.randint(0, Xtr.shape[0], (MlpBatchNormTrainer.batch_size,), generator=g)
            Xb, Yb = Xtr[ix], Ytr[ix]  # batch X, Y

            # forward pass
            emb = C[Xb]  # embed the characters into vectors
            x = emb.view(emb.shape[0], -1)  # flatten/concatenate the vectors
            for layer in layers:
                x = layer(x)
            loss = F.cross_entropy(x, Yb)  # loss function
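            # F.cross_entropy takes the raw logits and applies log-softmax plus
            # negative log-likelihood internally, so no explicit softmax is needed here.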

            # backward pass
            for layer in layers:
                layer.out.retain_grad()
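            # retain_grad() keeps gradients on the intermediate activations (each layer
            # caches its output in .out) so they can be inspected after backward();
            # it is not required for the parameter update itself.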

            for p in parameters:
                p.grad = None

            loss.backward()

            # update
            lr = 0.1 if i < 100000 else 0.01  # step learning rate decay
            for p in parameters:
                p.data += -lr * p.grad

            # track stats
            if i % 10000 == 0:  # print every once in a while
                print(f'{i:7d}/{MlpBatchNormTrainer.max_steps:7d}: {loss.item():.4f}')
            lossi.append(loss.log10().item())

            with torch.no_grad():
                ud.append([(lr * p.grad.std() / p.data.std()).log10().item() for p in parameters])
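                # update-to-data ratio: log10(lr * grad.std() / param.std()); as a rough
                # rule of thumb, values around -3 suggest a reasonably sized learning rate.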

            if i >= 1000:
                break
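            # the loop is cut off after ~1000 steps so the test stays fast; a real
            # training run would continue up to MlpBatchNormTrainer.max_steps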

        for layer in layers:
            layer.training = False
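        # the loop above switches the layers into evaluation mode; for the BatchNorm1d
        # layers this presumably means using running statistics instead of batch statistics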

        g = torch.Generator().manual_seed(2147483647 + 10)

        for _ in range(20):
            out = []
            context = [0] * self.n_block_size
            while True:
                # forward pass the neural net
                emb = C[torch.tensor([context])]  # (1, block_size, n_embd)
                x = emb.view(emb.shape[0], -1)  # concatenate the vectors
                for layer in layers:
                    x = layer(x)
                logits = x
                probs = F.softmax(logits, dim=1)
                self.assertEqual(probs.shape, torch.Size([1, 27]))
                # sample from the distribution
                ix = torch.multinomial(probs, num_samples=1, generator=g).item()
                # shift the context window and track the samples
                context = context[1:] + [ix]
                out.append(ix)
                if ix == 0:
                    break
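            # decode the sampled indices through itos, dropping the trailing
            # index-0 marker via out[:-1]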
            self.assertIsNotNone(''.join(self.itos[i] for i in out[:-1]))