diff --git a/.gitignore b/.gitignore index bdc5b27..75035a1 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,6 @@ __pycache__/ *.pyc *.pyo -*.pyd \ No newline at end of file +*.pyd +exp/ +wandb/ \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 10a8326..0a5caef 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,13 +1,13 @@ # .pre-commit-config.yaml repos: - - repo: https://github.com/pycqa/isort - rev: 5.13.2 - hooks: - - id: isort - language_version: python3 - files: \.py$ - exclude: ^tests/ + # - repo: https://github.com/pycqa/isort + # rev: 5.13.2 + # hooks: + # - id: isort + # language_version: python3 + # files: \.py$ + # exclude: ^tests/ - repo: https://github.com/psf/black rev: 24.4.2 hooks: diff --git a/dataset_helper.py b/dataset_helper.py index d99a8e9..cf15c0d 100644 --- a/dataset_helper.py +++ b/dataset_helper.py @@ -14,6 +14,7 @@ from datasets import load_dataset from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader, Sampler +from transformers import PreTrainedTokenizer class BatchSamplerSimilarLength(Sampler): @@ -33,7 +34,8 @@ def __init__( batch_size (int): Batch size to be used to compute upper limit of tokens. seq_len (int): Sequence length to be used to compute upper limit of - tokens. + tokens in a given batch. This will be directly correlated to + available GPU memory. shuffle (bool, optional): Shuffle the dataset before sorting and after getting the buckets. Defaults to True. """ @@ -92,19 +94,43 @@ def __len__(self) -> int: class DatasetHelper: def __init__( self, - tokenizer, - batch_size, - seq_len, - num_workers, - persistent_workers, - split="train", + tokenizer: PreTrainedTokenizer, + batch_size: int, + seq_len: int, + max_seq_len: int, + num_workers: int, + persistent_workers: bool, + split: str = "train", ): + """ + Construct a dataset loader for the given split. + + Args: + tokenizer (PreTrainedTokenizer): Tokenizer instance. + batch_size (int): Batch size to be used to compute upper limit of + tokens. + seq_len (int): Sequence length to be used to compute upper limit of + tokens in a given batch. This will be directly correlated to + available GPU memory. + max_seq_len (int): Maximum sequence length to preserve. If the + sequence is larger than max_seq_len then it will be clipped to + max_seq_len. If it is shorter than this then it will be padded + to the maximum sequence length of the batch. + num_workers (int): Number of workers for dataset loader. + persistent_workers (bool): Use persistent_workers in torch dataloader. + split (str, optional): Dataset split. Either "train" or "validation". + Defaults to "train". + + Raises: + RuntimeError: If unsupported split is provided. + """ if split not in ["train", "validation"]: raise RuntimeError( f"Split for dataloader shall be from 'train' or 'validation' only."
) data_loader = load_dataset("roneneldan/TinyStories", split=split) + self.max_seq_len = max_seq_len self.tokenizer = tokenizer self.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token) batch_sampler = BatchSamplerSimilarLength( @@ -132,20 +158,23 @@ def collate_batch(self, batch_data: List) -> Tuple[torch.Tensor]: for data in batch_data: text = data["text"] - tokenized_data = self.tokenizer(text) + tokenized_data = self.tokenizer( + text, max_length=self.max_seq_len, truncation=True + ) + tokenized_input_ids = tokenized_data["input_ids"] + tokenized_attn_mask = tokenized_data["attention_mask"] + input_ids.append( - torch.tensor(tokenized_data["input_ids"], dtype=torch.int32) + torch.tensor(tokenized_input_ids, dtype=torch.int32) ) labels.append( torch.tensor( - tokenized_data["input_ids"][1:] + [self.pad_token_id], + tokenized_input_ids[1:] + [self.pad_token_id], dtype=torch.int32, ) ) attention_mask.append( - torch.tensor( - tokenized_data["attention_mask"], dtype=torch.int32 - ) + torch.tensor(tokenized_attn_mask, dtype=torch.int32) ) input_ids = pad_sequence( input_ids, batch_first=True, padding_value=self.pad_token_id @@ -159,5 +188,40 @@ def collate_batch(self, batch_data: List) -> Tuple[torch.Tensor]: labels = labels.to(torch.long) return input_ids, attention_mask, labels - def get_loader(self): + def get_loader(self) -> DataLoader: + """ + Get instance of dataloader. + + Returns: + DataLoader: Dataloader based on split. + """ return self.dataloader + + +if __name__ == "__main__": + from models.helper import train_config_factory + from utils.misc import get_tokenizer + + model_type = "gpt" + train_config = train_config_factory(model_type) + tokenizer = get_tokenizer(model_type) + + valid_helper = DatasetHelper( + tokenizer, + train_config.batch_size, + train_config.avg_seq_len_in_batch, + train_config.max_seq_len, + train_config.num_workers, + train_config.persistent_workers, + "validation", + ) + valid_loader = valid_helper.get_loader() + + for batch_idx, (input_ids, attn_mask, labels) in enumerate(valid_loader): + print(f"Input ids: {input_ids}") + print(f"Attention mask: {attn_mask}") + print(f"Labels: {labels}") + print(f"Input ids shape: {input_ids.shape}") + print(f"Attention mask shape: {attn_mask.shape}") + print(f"Labels shape: {labels.shape}") + break diff --git a/export.py b/export.py new file mode 100644 index 0000000..e69de29 diff --git a/models/gpt_config.py b/models/gpt_model_config.py similarity index 89% rename from models/gpt_config.py rename to models/gpt_model_config.py index ace12f3..ee917b2 100644 --- a/models/gpt_config.py +++ b/models/gpt_model_config.py @@ -25,12 +25,12 @@ class GPTConfig(Config): vocab_size = 50257 emb_dim = 128 - max_seq_len = 2048 num_heads = 4 drop_prob = 0.1 ff_multiplier = 1 num_blocks = 2 -g = GPTConfig() -g.print_config() +if __name__ == "__main__": + g = GPTConfig() + g.print_config() diff --git a/train_config.py b/models/gpt_train_config.py similarity index 62% rename from train_config.py rename to models/gpt_train_config.py index db63da9..496cacf 100644 --- a/train_config.py +++ b/models/gpt_train_config.py @@ -17,12 +17,23 @@ class GPTTrainConfig(Config): model_type = "gpt" num_epochs = 10 - batch_size = 16 - avg_seq_len_in_batch = 128 + batch_size = 4 + avg_seq_len_in_batch = 1024 + max_seq_len = 1024 num_workers = 4 persistent_workers = True lr_scheduler_type = "cosine" init_lr = 1e-3 - warmup_epochs = 0 + warmup_epochs = 2 label_smoothing = 0.1 + + use_wandb = True + resume_wandb_id = None + track_gradients = 
False + fp16_training = True + + +if __name__ == "__main__": + g = GPTTrainConfig() + g.print_config() diff --git a/models/helper.py b/models/helper.py index 64cde8c..a41bc69 100644 --- a/models/helper.py +++ b/models/helper.py @@ -8,12 +8,21 @@ from __future__ import annotations +import os from abc import ABC +from datetime import datetime -from utils.misc import logger +from tabulate import tabulate + +from utils.logger_utils import logger class Config(ABC): + exp_time = datetime.now().strftime("%Y-%m-%d-%H-%M") + base_exp_path = f"./exp/{exp_time}" + os.makedirs(base_exp_path, exist_ok=True) + log_file = os.path.join(base_exp_path, "log.txt") + @classmethod def print_config(cls: Config): """ @@ -31,11 +40,33 @@ def print_config(cls: Config): continue class_vars[name] = value - logger.info("#" * 49) + logger.info("#" * 50) logger.info(f"#{cls.__name__.center(47)}#") - for name, value in class_vars.items(): - logger.info(f"#\t{name} \t\t: {value}\t\t#") - logger.info("#" * 49) + data = ([name, value] for (name, value) in class_vars.items()) + table_data = tabulate(data, tablefmt="grid") + table_data_lines = table_data.split("\n") + for line in table_data_lines: + logger.info(line) + logger.info("#" * 50) + + @classmethod + def to_dict(cls: Config): + """ + Return the static variables (class variables) of the class as a dict. + """ + class_vars = {} + for name, value in cls.__dict__.items(): + if callable(value): + continue + if isinstance(value, classmethod) or isinstance( + value, staticmethod + ): + continue + if name.startswith("__") or name == "_abc_impl": + continue + class_vars[name] = value + + return class_vars def model_factory(model_type: str, config: Config): @@ -63,7 +94,7 @@ def model_factory(model_type: str, config: Config): ) -def config_factory(model_type: str) -> Config: +def model_config_factory(model_type: str) -> Config: """ Get the model config based on the model_type. @@ -78,10 +109,34 @@ def config_factory(model_type: str) -> Config: Config: Config class return for given model_type. """ if model_type == "gpt": - from models.gpt_config import GPTConfig + from models.gpt_model_config import GPTConfig return GPTConfig else: raise NotImplementedError( f"No model config implemented for model type: {model_type}" ) + + +def train_config_factory(model_type: str) -> Config: + """ + Get the model training config based on the model_type. + + Args: + model_type (str): Name of the model type. + + Raises: + NotImplementedError: If unsupported model_type is provided then this + exception will be thrown. + + Returns: + Config: Config class return for given model_type. + """ + if model_type == "gpt": + from models.gpt_train_config import GPTTrainConfig + + return GPTTrainConfig + else: + raise NotImplementedError( + f"No train config implemented for model type: {model_type}" + ) diff --git a/models/layers/learnable_pos_emb.py b/models/layers/learnable_pos_emb.py index be618a6..302fa04 100644 --- a/models/layers/learnable_pos_emb.py +++ b/models/layers/learnable_pos_emb.py @@ -38,7 +38,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ batch_seq_len = x.shape[1] assert ( - batch_seq_len < self.max_seq_len + batch_seq_len <= self.max_seq_len ), "Sequence length of the batch is more than max sequence length."
pos_emb_value = self.pos_emb[:, :batch_seq_len, :] x = x + pos_emb_value # [batch, seq, emb_dim] diff --git a/test.py b/test.py new file mode 100644 index 0000000..debc7c3 --- /dev/null +++ b/test.py @@ -0,0 +1,7 @@ +""" + # @ Author: Meet Patel + # @ Create Time: 2024-08-05 19:20:12 + # @ Modified by: Meet Patel + # @ Modified time: 2024-08-05 19:20:17 + # @ Description: + """ diff --git a/train.py b/train.py index 4f6698c..5bf552c 100644 --- a/train.py +++ b/train.py @@ -10,42 +10,60 @@ os.environ["CUDA_LAUNCH_BLOCKING"] = "1" import torch +from torch.cuda.amp import GradScaler, autocast from torch.nn import CrossEntropyLoss +from tqdm import tqdm -import train_config +import wandb from dataset_helper import DatasetHelper -from models.helper import config_factory, model_factory -from utils.misc import get_tokenizer, lr_scheduler_factory +from models.helper import ( + model_config_factory, + model_factory, + train_config_factory, +) +from utils.checkpoint_handler import CheckpointHandler +from utils.logger_utils import configure_logging, logger +from utils.misc import get_tokenizer, init_wandb, lr_scheduler_factory -def run(): +def run(model_type): cuda = torch.device("cuda") - model_config = config_factory(train_config.model_type) - model = model_factory(train_config.model_type, model_config) + train_config = train_config_factory(model_type) + exp_path = train_config.base_exp_path + configure_logging(train_config.log_file) + + model_config = model_config_factory(model_type) + model_config.max_seq_len = ( + train_config.max_seq_len + ) # For Positional Embeddings. + model = model_factory(model_type, model_config) model.to(cuda) - tokenizer = get_tokenizer(train_config.model_type) + tokenizer = get_tokenizer(model_type) train_helper = DatasetHelper( tokenizer, train_config.batch_size, train_config.avg_seq_len_in_batch, + train_config.max_seq_len, train_config.num_workers, train_config.persistent_workers, - "validation", + "train", ) train_loader = train_helper.get_loader() valid_helper = DatasetHelper( tokenizer, train_config.batch_size, train_config.avg_seq_len_in_batch, + train_config.max_seq_len, train_config.num_workers, train_config.persistent_workers, "validation", ) valid_loader = valid_helper.get_loader() + ckpt_handler = CheckpointHandler(exp_path, "model", max_to_keep=3) lr_scheduler = lr_scheduler_factory( train_config.lr_scheduler_type, init_lr=train_config.init_lr, @@ -60,52 +78,107 @@ def run(): ignore_index=tokenizer.convert_tokens_to_ids(tokenizer.pad_token), ) + if train_config.fp16_training: + scaler = GradScaler() + + if train_config.use_wandb: + init_wandb(train_config, model_config, train_config.resume_wandb_id) g_step = 0 + if train_config.use_wandb and train_config.track_gradients: + wandb.watch(model) for eps_num in range(train_config.num_epochs): model.train() for batch_idx, (input_ids, attn_mask, labels) in enumerate( train_loader ): optimizer.zero_grad() - input_ids = input_ids.to(cuda) - attn_mask = attn_mask.to(cuda) - labels = labels.to(cuda) + batch_size = input_ids.shape[0] + input_ids = input_ids.to(cuda, non_blocking=True) + attn_mask = attn_mask.to(cuda, non_blocking=True) + labels = labels.to(cuda, non_blocking=True) - logits = model(input_ids, attn_mask) + if train_config.fp16_training: + with autocast(): + logits = model(input_ids, attn_mask) - batch_size = logits.shape[0] - logits = logits.view(-1, logits.shape[2]) - labels = labels.view(-1).to(torch.long) + logits = logits.view(-1, logits.shape[2]) + labels = labels.view(-1).to(torch.long) - # We 
would take mean across all sequence length and all batches. - loss = loss_fn(logits, labels) * batch_size - loss.backward() + # We would take mean across all sequence length and all batches. + loss = loss_fn(logits, labels) * batch_size + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + else: + logits = model(input_ids, attn_mask) - lr = lr_scheduler.step(g_step, optimizer) - optimizer.step() + logits = logits.view(-1, logits.shape[2]) + labels = labels.view(-1).to(torch.long) - print( - f"Epoch: {eps_num+1}/{train_config.num_epochs}, Batch: {batch_idx}/{len(train_loader)}, Loss: {loss:.4f}, LR: {lr:.4f}" - ) + # We would take mean across all sequence length and all batches. + loss = loss_fn(logits, labels) * batch_size + loss.backward() - # model.eval() - # total_eval_loss = 0 - # for batch in eval_dataloader: - # with torch.no_grad(): - # input_ids = batch['input_ids'] - # attention_mask = batch['attention_mask'] - # labels = batch['labels'] - # outputs = model(input_ids, attention_mask=attention_mask, labels=labels) - # loss = outputs.loss - # total_eval_loss += loss.item() + lr = lr_scheduler.step(g_step, optimizer) - # avg_eval_loss = total_eval_loss / len(eval_dataloader) - # print(f"Epoch {epoch+1}, Evaluation Loss: {avg_eval_loss}") + if not train_config.fp16_training: + optimizer.step() - # Save the model - # model.save_pretrained("path/to/save/model") - # tokenizer.save_pretrained("path/to/save/tokenizer") + logger.info( + f"Epoch: {eps_num+1}/{train_config.num_epochs}, Batch: {batch_idx}/{len(train_loader)}, Batch Size: {batch_size}, Loss: {loss:.4f}, LR: {lr:.4f}" + ) + metrics = { + "Epoch": eps_num + 1, + "Batch": batch_idx + 1, + "Loss": loss, + "LR": lr, + } + if train_config.use_wandb: + wandb.log(metrics, step=g_step) + g_step += 1 + + model.eval() + total_eval_loss = 0 + with torch.no_grad(): + for input_ids, attn_mask, labels in tqdm(valid_loader): + input_ids = input_ids.to(cuda, non_blocking=True) + attn_mask = attn_mask.to(cuda, non_blocking=True) + labels = labels.to(cuda, non_blocking=True) + + logits = model(input_ids, attn_mask) + + batch_size = logits.shape[0] + logits = logits.view(-1, logits.shape[2]) + labels = labels.view(-1).to(torch.long) + + # We would take mean across all sequence length and all batches. 
+ loss = loss_fn(logits, labels) * batch_size + total_eval_loss += loss.item() + + avg_eval_loss = total_eval_loss / len(valid_loader) + logger.info(f"Epoch {eps_num+1}, Evaluation Loss: {avg_eval_loss:.4f}") + if train_config.use_wandb: + metrics = {"Test Loss": avg_eval_loss} + wandb.log(metrics, step=g_step) + + # Save the model + torch.save(model.state_dict(), "model.pth") + + checkpoint = { + "epoch": eps_num, + "global_step": g_step, + "test_loss": avg_eval_loss, + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "scaler": ( + scaler.state_dict() if train_config.fp16_training else None + ), + } + ckpt_handler.save(checkpoint) + + if train_config.use_wandb: + wandb.finish() if __name__ == "__main__": - run() + run(model_type="gpt") diff --git a/utils/checkpoint_handler.py b/utils/checkpoint_handler.py new file mode 100644 index 0000000..dbee620 --- /dev/null +++ b/utils/checkpoint_handler.py @@ -0,0 +1,68 @@ +""" + # @ Author: Meet Patel + # @ Create Time: 2024-08-05 20:35:25 + # @ Modified by: Meet Patel + # @ Modified time: 2024-08-05 20:35:28 + # @ Description: + """ + +import os +from typing import Dict + +import torch + + +class CheckpointHandler: + def __init__( + self, ckpt_dir: str, model_name: str = "model", max_to_keep: int = 3 + ): + """Initializer for CheckpointHandler. + This will save the model whenever called and will only keep the + last n checkpoints. + + Args: + ckpt_dir (str): Directory where all the checkpoints are saved. + model_name (str, optional): Model name to use while saving the + checkpoint. Defaults to "model". + max_to_keep (int, optional): Number of last checkpoints to retain. + Defaults to 3. + """ + self.ckpt_dir = ckpt_dir + self.model_name = model_name + self.max_to_keep = max_to_keep + self.ckpt_path_history = [] + + def __get_ckpt_path(self, eps: int, loss: float) -> str: + """Function to get the checkpoint path based on the given epoch and loss value. + + Args: + eps (int): Epoch number. + loss (float): Loss value. + + Returns: + str: Checkpoint path. + """ + ckpt_name = f"{self.model_name}_eps_{eps}_test_loss_{loss:.4f}.pt" + cur_ckpt_path = os.path.join(self.ckpt_dir, ckpt_name) + return cur_ckpt_path + + def save(self, checkpoint_state: Dict) -> None: + """Function to save the current checkpoint based on the provided checkpoint + dict. + + Args: + checkpoint_state (Dict): Checkpoint dict which contains epoch_num, + test_loss value and the checkpoint state dict.
+ """ + eps = checkpoint_state["epoch"] + test_loss = checkpoint_state["test_loss"] + + cur_ckpt_path = self.__get_ckpt_path(eps, test_loss) + + torch.save(checkpoint_state, cur_ckpt_path) + + self.ckpt_path_history.append(cur_ckpt_path) + + if len(self.ckpt_path_history) > self.max_to_keep: + remove_ckpt_path = self.ckpt_path_history.pop(0) + os.remove(remove_ckpt_path) diff --git a/utils/logger_utils.py b/utils/logger_utils.py new file mode 100644 index 0000000..d503c20 --- /dev/null +++ b/utils/logger_utils.py @@ -0,0 +1,40 @@ +""" + # @ Author: Meet Patel + # @ Create Time: 2024-07-21 08:24:18 + # @ Modified by: Meet Patel + # @ Modified time: 2024-07-21 08:24:20 + # @ Description: + """ + +import logging + +logger = logging.getLogger() +logger.setLevel(logging.INFO) +formatter = logging.Formatter( + "%(asctime)s - %(levelname)s - Line %(lineno)d - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) + +console_handler = logging.StreamHandler() +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) + + +def configure_logging(log_file: str) -> None: + """ + Configure the logger to dump the logs in given file. + + Args: + log_file (str): Log file to store the logs. + """ + for handler in logger.handlers: + logger.removeHandler(handler) + + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + if log_file: + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) diff --git a/utils/lr_utils/cosine_annealing_lr.py b/utils/lr_utils/cosine_annealing_lr.py index 68a5773..9ded68b 100644 --- a/utils/lr_utils/cosine_annealing_lr.py +++ b/utils/lr_utils/cosine_annealing_lr.py @@ -9,8 +9,8 @@ import numpy as np +from utils.logger_utils import logger from utils.lr_utils.lr_scheduler import LearningRateScheduler -from utils.misc import logger class CosineAnnealing(LearningRateScheduler): diff --git a/utils/lr_utils/exp_decay_lr.py b/utils/lr_utils/exp_decay_lr.py index d6cce81..cae0ae5 100644 --- a/utils/lr_utils/exp_decay_lr.py +++ b/utils/lr_utils/exp_decay_lr.py @@ -9,8 +9,8 @@ import numpy as np +from utils.logger_utils import logger from utils.lr_utils.lr_scheduler import LearningRateScheduler -from utils.misc import logger class ExpDecay(LearningRateScheduler): diff --git a/utils/misc.py b/utils/misc.py index 43f160b..e294e49 100644 --- a/utils/misc.py +++ b/utils/misc.py @@ -6,24 +6,17 @@ # @ Description: """ -import logging +import datetime +import os from transformers import PreTrainedTokenizer +import wandb +from models.helper import Config from utils.lr_utils.cosine_annealing_lr import CosineAnnealing from utils.lr_utils.exp_decay_lr import ExpDecay from utils.lr_utils.lr_scheduler import LearningRateScheduler -logger = logging.getLogger() -logger.setLevel(logging.DEBUG) -formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) - -console_handler = logging.StreamHandler() -console_handler.setFormatter(formatter) -logger.addHandler(console_handler) - def get_tokenizer(model_type: str) -> PreTrainedTokenizer: """ @@ -77,21 +70,38 @@ def lr_scheduler_factory( return scheduler_class(*args, **kwargs) -def configure_logging(log_file: str) -> None: - """ - Configure the logger to dump the logs in given file. +def init_wandb( + train_config: Config, model_config: Config, resume_wandb_id: int +) -> None: + """Initiate the weights and bias tracking. To be called at the start of experiment. 
Args: - log_file (str): Log file to store the logs. + train_config (Config): Config instance representing training parameters. + model_config (Config): Config instance representing model architecture + parameters. + resume_wandb_id (int): Weights & Biases run id to be reused in + case of resuming training. Pass None to start a new run. """ - for handler in logger.handlers: - logger.removeHandler(handler) + config_dict = {**train_config.to_dict(), **model_config.to_dict()} + wandb.init( + project="TinyLLM", + config=config_dict, + resume="allow", + id=resume_wandb_id, + ) + - console_handler = logging.StreamHandler() - console_handler.setFormatter(formatter) - logger.addHandler(console_handler) +def get_exp_path(base_dir: str) -> str: + """Function to get the directory to save the experiment related data. - if log_file: - file_handler = logging.FileHandler(log_file) - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) + Args: + base_dir (str): Directory to store all experiments. + + Returns: + str: Path for current experiment. + """ + start_time = datetime.datetime.now() + exp_name = start_time.strftime("%Y_%m_%d_%H_%M_%S") + cur_exp_path = os.path.join(base_dir, exp_name) + os.makedirs(cur_exp_path, exist_ok=True) + return cur_exp_path