Updated training code. Training is working fine now.
meet-minimalist committed Aug 5, 2024
1 parent e343f8f commit 3dc53aa
Showing 15 changed files with 431 additions and 101 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -2,4 +2,6 @@
__pycache__/
*.pyc
*.pyo
*.pyd
*.pyd
exp/
wandb/
14 changes: 7 additions & 7 deletions .pre-commit-config.yaml
@@ -1,13 +1,13 @@
# .pre-commit-config.yaml

repos:
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
language_version: python3
files: \.py$
exclude: ^tests/
# - repo: https://github.com/pycqa/isort
# rev: 5.13.2
# hooks:
# - id: isort
# language_version: python3
# files: \.py$
# exclude: ^tests/
- repo: https://github.com/psf/black
rev: 24.4.2
hooks:
92 changes: 78 additions & 14 deletions dataset_helper.py
Expand Up @@ -14,6 +14,7 @@
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Sampler
from transformers import PreTrainedTokenizer


class BatchSamplerSimilarLength(Sampler):
@@ -33,7 +34,8 @@ def __init__(
batch_size (int): Batch size to be used to compute upper limit of
tokens.
seq_len (int): Sequence length to be used to compute upper limit of
tokens.
tokens in a given batch. This will be directly correlated to
available GPU memory.
shuffle (bool, optional): Shuffle the dataset before sorting and
after getting the buckets. Defaults to True.
"""
@@ -92,19 +94,43 @@ def __len__(self) -> int:
class DatasetHelper:
def __init__(
self,
tokenizer,
batch_size,
seq_len,
num_workers,
persistent_workers,
split="train",
tokenizer: PreTrainedTokenizer,
batch_size: int,
seq_len: int,
max_seq_len: int,
num_workers: int,
persistent_workers: bool,
split: str = "train",
):
"""
Construct a dataset loader within this class.
Args:
tokenizer (PreTrainedTokenizer): Tokenizer instance.
batch_size (int): Batch size to be used to compute upper limit of
tokens.
seq_len (int): Sequence length to be used to compute upper limit of
tokens in a given batch. This will be directly correlated to
available GPU memory.
max_seq_len (int): Maximum sequence length to preserve. If the
sequence is larger than max_seq_len then it will be clipped to
max_seq_len. If it is shorter than this, then it will be padded
to maximum sequence length of the batch.
num_workers (int): Number of workers for dataset loader.
persistent_workers (bool): Use persistent_workers in the torch DataLoader.
split (str, optional): Dataset split. Either "train" or "validation".
Defaults to "train".
Raises:
RuntimeError: If an unsupported split is provided.
"""
if split not in ["train", "validation"]:
raise RuntimeError(
f"Split for dataloader shall be from 'train' or 'validation' only."
)
data_loader = load_dataset("roneneldan/TinyStories", split=split)

self.max_seq_len = max_seq_len
self.tokenizer = tokenizer
self.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
batch_sampler = BatchSamplerSimilarLength(
@@ -132,20 +158,23 @@ def collate_batch(self, batch_data: List) -> Tuple[torch.Tensor]:

for data in batch_data:
text = data["text"]
tokenized_data = self.tokenizer(text)
tokenized_data = self.tokenizer(
text, max_length=self.max_seq_len, truncation=True
)
tokenized_input_ids = tokenized_data["input_ids"]
tokenized_attn_mask = tokenized_data["attention_mask"]

input_ids.append(
torch.tensor(tokenized_data["input_ids"], dtype=torch.int32)
torch.tensor(tokenized_input_ids, dtype=torch.int32)
)
labels.append(
torch.tensor(
tokenized_data["input_ids"][1:] + [self.pad_token_id],
tokenized_input_ids[1:] + [self.pad_token_id],
dtype=torch.int32,
)
)
attention_mask.append(
torch.tensor(
tokenized_data["attention_mask"], dtype=torch.int32
)
torch.tensor(tokenized_attn_mask, dtype=torch.int32)
)
input_ids = pad_sequence(
input_ids, batch_first=True, padding_value=self.pad_token_id
@@ -159,5 +188,40 @@ def collate_batch(self, batch_data: List) -> Tuple[torch.Tensor]:
labels = labels.to(torch.long)
return input_ids, attention_mask, labels

def get_loader(self):
def get_loader(self) -> DataLoader:
"""
Get instance of dataloader.
Returns:
DataLoader: Dataloader based on split.
"""
return self.dataloader


if __name__ == "__main__":
from models.helper import train_config_factory
from utils.misc import get_tokenizer

model_type = "gpt"
train_config = train_config_factory(model_type)
tokenizer = get_tokenizer(model_type)

valid_helper = DatasetHelper(
tokenizer,
train_config.batch_size,
train_config.avg_seq_len_in_batch,
train_config.max_seq_len,
train_config.num_workers,
train_config.persistent_workers,
"validation",
)
valid_loader = valid_helper.get_loader()

for batch_idx, (input_ids, attn_mask, labels) in enumerate(valid_loader):
print(f"Input ids: {input_ids}")
print(f"Attention mask: {attn_mask}")
print(f"Labels: {labels}")
print(f"Input ids shape: {input_ids.shape}")
print(f"Attention mask shape: {attn_mask.shape}")
print(f"Labels shape: {labels.shape}")
break
Empty file added export.py
Empty file.
6 changes: 3 additions & 3 deletions models/gpt_config.py → models/gpt_model_config.py
@@ -25,12 +25,12 @@ class GPTConfig(Config):

vocab_size = 50257
emb_dim = 128
max_seq_len = 2048
num_heads = 4
drop_prob = 0.1
ff_multiplier = 1
num_blocks = 2


g = GPTConfig()
g.print_config()
if __name__ == "__main__":
g = GPTConfig()
g.print_config()
17 changes: 14 additions & 3 deletions train_config.py → models/gpt_train_config.py
@@ -17,12 +17,23 @@ class GPTTrainConfig(Config):
model_type = "gpt"

num_epochs = 10
batch_size = 16
avg_seq_len_in_batch = 128
batch_size = 4
avg_seq_len_in_batch = 1024
max_seq_len = 1024
num_workers = 4
persistent_workers = True

lr_scheduler_type = "cosine"
init_lr = 1e-3
warmup_epochs = 0
warmup_epochs = 2
label_smoothing = 0.1

use_wandb = True
resume_wandb_id = None
track_gradients = False
fp16_training = True


if __name__ == "__main__":
g = GPTTrainConfig()
g.print_config()
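
The scheduler that consumes lr_scheduler_type, init_lr, and warmup_epochs is not shown in this commit; assuming the usual linear-warmup-plus-cosine-decay formulation, it behaves roughly like this sketch (steps_per_epoch is an assumed value):

import math

init_lr = 1e-3
warmup_epochs = 2
num_epochs = 10
steps_per_epoch = 1000  # assumed; depends on dataset size and batch size

warmup_steps = warmup_epochs * steps_per_epoch
total_steps = num_epochs * steps_per_epoch


def lr_at(step: int) -> float:
    # Linear warmup to init_lr, then cosine decay towards zero.
    if step < warmup_steps:
        return init_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return init_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


print(lr_at(0), lr_at(warmup_steps), lr_at(total_steps))  # 0.0, 1e-3, ~0.0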
69 changes: 62 additions & 7 deletions models/helper.py
@@ -8,12 +8,21 @@

from __future__ import annotations

import os
from abc import ABC
from datetime import datetime

from utils.misc import logger
from tabulate import tabulate

from utils.logger_utils import logger


class Config(ABC):
exp_time = datetime.now().strftime("%Y-%m-%d-%H-%M")
base_exp_path = f"./exp/{exp_time}"
os.makedirs(base_exp_path, exist_ok=True)
log_file = os.path.join(base_exp_path, "log.txt")

@classmethod
def print_config(cls: Config):
"""
@@ -31,11 +40,33 @@ def print_config(cls: Config):
continue
class_vars[name] = value

logger.info("#" * 49)
logger.info("#" * 50)
logger.info(f"#{cls.__name__.center(47)}#")
for name, value in class_vars.items():
logger.info(f"#\t{name} \t\t: {value}\t\t#")
logger.info("#" * 49)
data = ([name, value] for (name, value) in class_vars.items())
table_data = tabulate(data, tablefmt="grid")
table_data_lines = table_data.split("\n")
for line in table_data_lines:
logger.info(line)
logger.info("#" * 50)

@classmethod
def to_dict(cls: Config):
"""
Return the static variables (class variables) of the class as a dictionary.
"""
class_vars = {}
for name, value in cls.__dict__.items():
if callable(value):
continue
if isinstance(value, classmethod) or isinstance(
value, staticmethod
):
continue
if name.startswith("__") or name == "_abc_impl":
continue
class_vars[name] = value

return class_vars
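
Together with the new use_wandb and resume_wandb_id fields in the training config, to_dict() is presumably meant for experiment tracking; a hypothetical wiring (the project name and the exact call site are assumptions, not part of this commit):

import wandb

from models.helper import train_config_factory

cfg = train_config_factory("gpt")
if cfg.use_wandb:
    wandb.init(
        project="tinystories-gpt",                       # assumed project name
        config=cfg.to_dict(),                            # log all static config fields
        id=cfg.resume_wandb_id,                          # None starts a fresh run
        resume="must" if cfg.resume_wandb_id else None,  # resume only when an id is given
    )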


def model_factory(model_type: str, config: Config):
@@ -63,7 +94,7 @@ def model_factory(model_type: str, config: Config):
)


def config_factory(model_type: str) -> Config:
def model_config_factory(model_type: str) -> Config:
"""
Get the model config based on the model_type.
@@ -78,10 +109,34 @@ def config_factory(model_type: str) -> Config:
Config: Config class returned for the given model_type.
"""
if model_type == "gpt":
from models.gpt_config import GPTConfig
from models.gpt_model_config import GPTConfig

return GPTConfig
else:
raise NotImplementedError(
f"No model config implemented for model type: {model_type}"
)


def train_config_factory(model_type: str) -> Config:
"""
Get the model training config based on the model_type.
Args:
model_type (str): Name of the model type.
Raises:
NotImplementedError: If an unsupported model_type is provided, then this
exception will be thrown.
Returns:
Config: Config class returned for the given model_type.
"""
if model_type == "gpt":
from models.gpt_train_config import GPTTrainConfig

return GPTTrainConfig
else:
raise NotImplementedError(
f"No model config implemented for model type: {model_type}"
)
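
A likely way the renamed factories fit together (a sketch; the actual training entry point is not part of the hunks shown here):

from models.helper import model_config_factory, model_factory, train_config_factory

model_type = "gpt"
model_config = model_config_factory(model_type)  # architecture hyperparameters
train_config = train_config_factory(model_type)  # training hyperparameters
model = model_factory(model_type, model_config)  # builds the model from its config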
2 changes: 1 addition & 1 deletion models/layers/learnable_pos_emb.py
@@ -38,7 +38,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
batch_seq_len = x.shape[1]
assert (
batch_seq_len < self.max_seq_len
batch_seq_len <= self.max_seq_len
), "Sequence length of the batch is more than max sequence length."
pos_emb_value = self.pos_emb[:, :batch_seq_len, :]
x = x + pos_emb_value # [batch, seq, emb_dim]
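
The relaxed assertion above admits the boundary case where a batch is exactly max_seq_len tokens long, which the old strict "<" would wrongly reject; a small check illustrating the boundary:

max_seq_len = 1024
batch_seq_len = 1024  # e.g. a batch whose longest sequence hits the limit exactly
assert batch_seq_len <= max_seq_len   # passes with the new check
# assert batch_seq_len < max_seq_len  # the old strict check would raise here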
7 changes: 7 additions & 0 deletions test.py
@@ -0,0 +1,7 @@
"""
# @ Author: Meet Patel
# @ Create Time: 2024-08-05 19:20:12
# @ Modified by: Meet Patel
# @ Modified time: 2024-08-05 19:20:17
# @ Description:
"""