Commit ddbad75

Implement simple character level GPT and the trainer for it (#122)
* Add simple GPT (Signed-off-by: Aisuko <[email protected]>)
* implement simple gpt (Signed-off-by: Aisuko <[email protected]>)
* ruff format (Signed-off-by: Aisuko <[email protected]>)
1 parent add6b31 commit ddbad75

File tree: 7 files changed, +320 −14 lines changed


Diff for: .gitignore

+3 −1

```diff
@@ -162,4 +162,6 @@ params/
 
 model/
 
-runs/
+runs/
+
+*.txt
```

Diff for: Makefile

+5 −1

```diff
@@ -35,7 +35,11 @@ install:
 
 .PHONY: install-dev
 install-dev:
-	@poetry install -vvv --no-root
+	@poetry install --only dev -vvv --no-root
+
+.PHONY: format
+format:
+	@ruff format
 
 .PHONY: lint
 lint:
```

Diff for: src/models/simple_gpt.py

+231 (new file)

```python
# coding=utf-8

# Copyright [2024] [SkywardAI]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter


class SimpleGPT(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from the lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx and targets are both (B,T) tensors of integers
        logits = self.token_embedding_table(idx)  # (B,T,C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)  # .view(-1) would also work here
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is a (B,T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)  # calls forward automatically
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B,C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B,C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B,1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B,T+1)
        return idx


class SimpleGPTTrainer:
    ds_url = "https://www.kaggleusercontent.com/kf/189948176/eyJhbGciOiJkaXIiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2In0..1Jb9UiqqB4H5KVymvfAOrw.cX1aqGGJcBvzxM56ysKgmNbqhQdDIr5UuLBnx2OOHQJlOAAZMwG4n27TKm2K-KN7cxSiUxsLV-Ua53hQa7Y-Eup4QhqYRs47y_IFRVHxqUYILGfbzcZHaTtdvZM2UlGcMjO3-htDg3huWl_bT6vD0wEIpWWjw_vFA8MBiFndQUgBQcjnwMI4W-KKfOpeKcaonl-3HLaIBoDau-fGAFq1KPY7h6M1Oy20c4goF86AGyVYC1E3rbipDcIuF5jLjiUXLh6B5TgpybwmygfdsKrz8qOoK0W2UFEwH0pNQ1a3le222k1s7iwnLofU7P0cznFKa4glCa6U7UQ4JMcB371Pcz9YQXA5f8dvfOymgpFQ7Jwjx6FJZ211bD3zHYq2RYM1pE5N_0U-iPOnAHlNKVSgnOWbGkaJtckDUa7MHgfbJEEcPMjPdEZRf1AofQJKoFK3QTH87wpjboUxo8F-KfKr-40K5HbNisTOuJbSeZrBE1y1EDvbBQ1rFQxei9bjyz71eZdV9pjwdYEso1C1M8I669mAGmJ4X9TDkl2eO3wItIZzE5Jy5CIug8j6-kghz-jBDr9wkiMiwWoZ3rcM8JM1dbPDV-8HDTBfiAZFDl5w4tLH8o7bKXbd004X3l4H-O5uIj0inEv07OsU-80CSzkuuQ._myoCNJ7mrE61Hp6wJtDbw/input.txt"

    # hyperparameters
    batch_size = 32  # how many independent sequences will we process in parallel?
    block_size = 8  # what is the maximum context length for predictions?
    max_iters = 100
    eval_interval = 300
    learning_rate = 1e-2
    device = "cuda" if torch.cuda.is_available() else "cpu"
    eval_iters = 200

    def __init__(self) -> None:
        raise Exception("This class is not meant to be instantiated")

    @classmethod
    def set_hyperparameters(cls, **kwargs) -> None:
        """
        Set hyperparameters, falling back to the defaults above
        """
        cls.batch_size = kwargs.get("batch_size", 32)
        cls.block_size = kwargs.get("block_size", 8)
        cls.max_iters = kwargs.get("max_iters", 3000)
        cls.eval_interval = kwargs.get("eval_interval", 300)
        cls.learning_rate = kwargs.get("learning_rate", 1e-2)
        cls.device = kwargs.get(
            "device", "cuda" if torch.cuda.is_available() else "cpu"
        )
        cls.eval_iters = kwargs.get("eval_iters", 200)

    @classmethod
    def load_data(cls, ds_file) -> str:
        """
        Load the dataset from a local file
        """
        with open(ds_file, "r", encoding="utf-8") as f:
            text = f.read()
        return text

    @classmethod
    def unique_chars(cls, text: str) -> list:
        """
        Get all the unique characters in the text
        """
        return sorted(list(set(text)))

    @classmethod
    def build_vocab(cls, chars: list) -> int:
        """
        Return the vocabulary size: the number of unique characters
        """
        return len(chars)

    @classmethod
    def stoi(cls, chars: list) -> dict:
        """
        Map each character to an index
        """
        return {char: i for i, char in enumerate(chars)}

    @classmethod
    def itos(cls, chars: list) -> dict:
        """
        Map each index to a character
        """
        return {i: char for i, char in enumerate(chars)}

    @classmethod
    def encoder(cls, stoi: dict):
        """
        Return a function that converts a string to a list of indices
        """
        return lambda s: [stoi[c] for c in s]

    @classmethod
    def decoder(cls, itos: dict):
        """
        Return a function that converts a list of indices to a string
        """
        return lambda x: "".join([itos[i] for i in x])

    @classmethod
    def split_to_train_validate(
        cls, text: list, train_frac: float
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Split the encoded text into training and validation sets
        """
        data = torch.tensor(
            text, dtype=torch.long
        )  # construct a tensor with no autograd history
        n = int(train_frac * len(data))
        train_data = data[:n]
        val_data = data[n:]
        return train_data, val_data

    @classmethod
    def get_batch(
        cls, split: str, train_data: torch.Tensor, val_data: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Generate a small batch of data of inputs x and targets y
        """
        data = train_data if split == "train" else val_data
        ix = torch.randint(len(data) - cls.block_size, (cls.batch_size,))
        x = torch.stack([data[i : i + cls.block_size] for i in ix])
        y = torch.stack([data[i + 1 : i + cls.block_size + 1] for i in ix])
        x, y = x.to(cls.device), y.to(cls.device)
        return x, y

    @classmethod
    def adam_optimizer(cls, model: SimpleGPT) -> torch.optim.Adam:
        """
        Create an Adam optimizer for the model
        """
        return torch.optim.Adam(model.parameters(), lr=cls.learning_rate)

    @classmethod
    def train(
        cls,
        model: SimpleGPT,
        optimizer: torch.optim.Adam,
        train_data: torch.Tensor,
        val_data: torch.Tensor,
    ) -> None:
        """
        Train the model
        """

        @torch.no_grad()
        def estimate_loss():
            out = {}
            model.eval()
            for split in ["train", "val"]:
                losses = torch.zeros(cls.eval_iters)
                for k in range(cls.eval_iters):
                    X, Y = cls.get_batch(split, train_data, val_data)
                    logits, loss = model(X, Y)
                    losses[k] = loss.item()
                out[split] = losses.mean()
            model.train()
            return out

        writer = SummaryWriter()
        torch.manual_seed(1337)

        for i in range(cls.max_iters):
            # every once in a while evaluate the loss on the train and val sets
            if i % cls.eval_interval == 0:
                losses = estimate_loss()
                print(
                    f"step {i}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
                )
                writer.add_scalar("Loss/train", losses["train"], i)

            # sample a batch of data
            xb, yb = cls.get_batch("train", train_data, val_data)

            # evaluate the loss
            logits, loss = model(xb, yb)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
        writer.flush()
        writer.close()

    @classmethod
    def sample(cls, model: SimpleGPT, max_new_tokens: int) -> list:
        """
        Get a sample from the model, starting from a single zero-token context
        """
        context = torch.zeros((1, 1), dtype=torch.long, device=cls.device)
        return model.generate(context, max_new_tokens=max_new_tokens)[0].tolist()
```
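For orientation, here is a minimal sketch of how these pieces compose end to end, mirroring the flow of the new test further down. It assumes the corpus has already been downloaded to a local `input.txt` (e.g. with the `DatasetHelper` introduced below); the hyperparameter overrides are illustrative, not part of this commit.

```python
# Minimal usage sketch, assuming input.txt already exists locally.
from models.simple_gpt import SimpleGPT, SimpleGPTTrainer

text = SimpleGPTTrainer.load_data("input.txt")
chars = SimpleGPTTrainer.unique_chars(text)
vocab_size = SimpleGPTTrainer.build_vocab(chars)

encode = SimpleGPTTrainer.encoder(SimpleGPTTrainer.stoi(chars))
decode = SimpleGPTTrainer.decoder(SimpleGPTTrainer.itos(chars))

# Optionally override the class-level defaults before training.
SimpleGPTTrainer.set_hyperparameters(batch_size=32, max_iters=100)

# 90/10 train/validation split over the encoded corpus.
train_data, val_data = SimpleGPTTrainer.split_to_train_validate(encode(text), 0.9)

model = SimpleGPT(vocab_size).to(SimpleGPTTrainer.device)
optimizer = SimpleGPTTrainer.adam_optimizer(model)
SimpleGPTTrainer.train(model, optimizer, train_data, val_data)

# Generate 100 characters from the zero-token context and decode them.
print(decode(SimpleGPTTrainer.sample(model, 100)))
```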

Diff for: src/models/simple_nn.py renamed to src/pkg/dataset_helper.py

+15 −3

```diff
@@ -14,7 +14,19 @@
 # limitations under the License.
 
 
-class SimpleNN:
-    def __init__(self):
-        pass
+import requests
+from pathlib import Path
 
+
+class DatasetHelper:
+    def __init__(self) -> None:
+        raise Exception("This class is not meant to be instantiated")
+
+    @classmethod
+    def download_remote_file(cls, url: str, filename: Path) -> Path:
+        response = requests.get(url)
+        response.raise_for_status()
+        with open(filename, "wb") as f:
+            f.write(response.content)
+
+        return filename
```
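A short sketch of the helper in use (the destination path here is illustrative):

```python
from pathlib import Path

from models.simple_gpt import SimpleGPTTrainer
from pkg.dataset_helper import DatasetHelper

# Fetch the training corpus once; raise_for_status aborts on a failed download.
dataset_path = DatasetHelper.download_remote_file(
    SimpleGPTTrainer.ds_url, Path("input.txt")
)
text = SimpleGPTTrainer.load_data(dataset_path)
```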

Diff for: src/tests/test_simple_gpt.py

+55 (new file)

```python
# coding=utf-8

# Copyright [2024] [SkywardAI]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
from pathlib import Path

from models.simple_gpt import SimpleGPT, SimpleGPTTrainer
from pkg.dataset_helper import DatasetHelper


class TestSimpleGPT(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        # download the dataset into the src/ directory and cache it for all tests
        src_dir = Path(os.path.dirname(os.path.abspath(__file__))).parent
        abs_file_path = os.path.join(src_dir, "input.txt")
        _ = DatasetHelper.download_remote_file(SimpleGPTTrainer.ds_url, abs_file_path)

        cls.dataset = SimpleGPTTrainer.load_data(abs_file_path)
        cls.chars = SimpleGPTTrainer.unique_chars(cls.dataset)
        cls.vocabsize = SimpleGPTTrainer.build_vocab(cls.chars)
        cls.stoi = SimpleGPTTrainer.stoi(cls.chars)
        cls.itos = SimpleGPTTrainer.itos(cls.chars)

    def test_simple_gpt_trainer(self):
        encoder = SimpleGPTTrainer.encoder(self.stoi)

        ds = encoder(self.dataset)

        train_data, val_data = SimpleGPTTrainer.split_to_train_validate(ds, 0.9)

        model = SimpleGPT(self.vocabsize)
        optimizer = SimpleGPTTrainer.adam_optimizer(model)

        SimpleGPTTrainer.train(model, optimizer, train_data, val_data)

        output = SimpleGPTTrainer.sample(model, 100)

        decoder = SimpleGPTTrainer.decoder(self.itos)
        output = decoder(output)

        self.assertTrue(output)
```

Diff for: src/tests/test_simple_nn.py

+3 −4

```diff
@@ -19,18 +19,17 @@
 
 
 class TestSimpleNN(unittest.TestCase):
-
     @classmethod
     def setUpClass(cls) -> None:
         cls.model = torch.nn.Linear(1, 1)
         cls.criterion = torch.nn.MSELoss()
         cls.optimizer = torch.optim.SGD(cls.model.parameters(), lr=0.1)
 
-
     def test_simple_trainer(self):
         x = torch.arange(-5, 5, 0.1).view(-1, 1)
         y = -5 * x + 0.1 * torch.randn(x.size())
-        trainer=SimpleTrainer(model=self.model, loss_func=self.criterion, optimizer=self.optimizer)
+        trainer = SimpleTrainer(
+            model=self.model, loss_func=self.criterion, optimizer=self.optimizer
+        )
         self.assertTrue(trainer)
         trainer.train(x, y)
-
```
Diff for: src/trainers/simple_trainer.py

+8 −5

```diff
@@ -17,15 +17,18 @@
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
+
 class SimpleTrainer:
     def __init__(self, **kwargs):
-        self.model: Any=kwargs.get('model') or None
-        self.loss_func: Any=kwargs.get('loss_func') or None
-        self.optimizer: Any=kwargs.get('optimizer') or None
+        self.model: Any = kwargs.get("model") or None
+        self.loss_func: Any = kwargs.get("loss_func") or None
+        self.optimizer: Any = kwargs.get("optimizer") or None
         self.writer: Any = SummaryWriter()
-        assert (self.model and self.loss_func and self.optimizer), "Model, Loss Function and Optimizer are required"
+        assert (
+            self.model and self.loss_func and self.optimizer
+        ), "Model, Loss Function and Optimizer are required"
 
-    def train(self, x:torch.Tensor, y:torch.Tensor)-> None:
+    def train(self, x: torch.Tensor, y: torch.Tensor) -> None:
         for epoch in range(10):
             y1 = self.model(x)
             loss = self.loss_func(y1, y)
```