Commit e343f8f (parent f2ef802)
Showing 21 changed files with 1,096 additions and 2 deletions.
@@ -0,0 +1,5 @@
# Ignore __pycache__ directories and compiled Python files
__pycache__/
*.pyc
*.pyo
*.pyd
@@ -0,0 +1,19 @@
# .pre-commit-config.yaml

repos:
  - repo: https://github.com/pycqa/isort
    rev: 5.13.2
    hooks:
      - id: isort
        language_version: python3
        files: \.py$
        exclude: ^tests/
  - repo: https://github.com/psf/black
    rev: 24.4.2
    hooks:
      - id: black
        language_version: python3
        additional_dependencies: []

        files: \.py$
        exclude: ^tests/  # Exclude the tests directory if desired
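These hooks run isort and then black over staged .py files, skipping anything under tests/. They are typically enabled once per clone with pre-commit install and can be applied to the whole repository with pre-commit run --all-files.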
@@ -0,0 +1,163 @@
"""
# @ Author: Meet Patel
# @ Create Time: 2024-06-30 13:55:42
# @ Modified by: Meet Patel
# @ Modified time: 2024-07-20 16:53:09
# @ Description:
"""

import random
from typing import Generator, List, Tuple

import datasets
import torch
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Sampler


class BatchSamplerSimilarLength(Sampler):
    def __init__(
        self,
        dataset_iterator: datasets.Dataset,
        batch_size: int,
        seq_len: int,
        shuffle: bool = True,
    ):
        """Initializer to load the dataset, sort it by sequence length and
        prepare the indices in a bucketed manner, so that sequences with
        similar lengths are grouped together.
        Args:
            dataset_iterator (datasets.Dataset): Dataset iterator.
            batch_size (int): Batch size used to compute the upper limit of
                tokens per batch.
            seq_len (int): Sequence length used to compute the upper limit of
                tokens per batch.
            shuffle (bool, optional): Shuffle the dataset before sorting and
                shuffle the resulting buckets. Defaults to True.
        """
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.total_tokens_in_batch = self.batch_size * self.seq_len
        self.shuffle = shuffle

        # Whitespace-split token count per sample, used as a cheap proxy for
        # the tokenized length when bucketing.
        self.indices = [
            (i, len(data["text"].split(" ")))
            for i, data in enumerate(dataset_iterator)
        ]

        if self.shuffle:
            random.shuffle(self.indices)

        sorted_indices = sorted(self.indices, key=lambda x: x[1])

        self.all_batch_idx = []
        single_batch_idx = []
        cumulative_token_len = 0

        for idx, token_len in sorted_indices:
            cumulative_token_len += token_len

            single_batch_idx.append(idx)

            if cumulative_token_len > self.total_tokens_in_batch:
                self.all_batch_idx.append(single_batch_idx.copy())
                single_batch_idx.clear()
                cumulative_token_len = 0

        if self.shuffle:
            random.shuffle(self.all_batch_idx)

    def __iter__(self) -> Generator[List[int], None, None]:
        """
        Fetches the list of indices to be used to generate a batch.
        Yields:
            List[int]: List of indices for batch generation.
        """
        for batch_idx in self.all_batch_idx:
            random.shuffle(batch_idx)
            yield batch_idx

    def __len__(self) -> int:
        """Returns the total number of batches which can be generated.
        Returns:
            int: Number of batches from the given dataset.
        """
        return len(self.all_batch_idx)


class DatasetHelper:
    def __init__(
        self,
        tokenizer,
        batch_size,
        seq_len,
        num_workers,
        persistent_workers,
        split="train",
    ):
        if split not in ["train", "validation"]:
            raise RuntimeError(
                "Split for dataloader must be either 'train' or 'validation'."
            )
        dataset = load_dataset("roneneldan/TinyStories", split=split)

        self.tokenizer = tokenizer
        self.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
        batch_sampler = BatchSamplerSimilarLength(
            dataset, batch_size, seq_len, shuffle=True
        )
        self.dataloader = DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            collate_fn=self.collate_batch,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
        )

    def collate_batch(
        self, batch_data: List
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Tokenizes the sequences and prepares the padded tensors for training.
        Args:
            batch_data (List): List of dataset samples, each containing a
                "text" field.
        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of padded
                input tokens, attention mask and target labels.
        """
        input_ids, attention_mask, labels = [], [], []

        for data in batch_data:
            text = data["text"]
            tokenized_data = self.tokenizer(text)
            input_ids.append(
                torch.tensor(tokenized_data["input_ids"], dtype=torch.int32)
            )
            # Labels are the input ids shifted left by one position and padded
            # at the end, for next-token prediction.
            labels.append(
                torch.tensor(
                    tokenized_data["input_ids"][1:] + [self.pad_token_id],
                    dtype=torch.int32,
                )
            )
            attention_mask.append(
                torch.tensor(
                    tokenized_data["attention_mask"], dtype=torch.int32
                )
            )
        input_ids = pad_sequence(
            input_ids, batch_first=True, padding_value=self.pad_token_id
        )
        attention_mask = pad_sequence(
            attention_mask, batch_first=True, padding_value=0
        )
        labels = pad_sequence(
            labels, batch_first=True, padding_value=self.pad_token_id
        )
        labels = labels.to(torch.long)
        return input_ids, attention_mask, labels

    def get_loader(self):
        return self.dataloader
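For context, a minimal usage sketch of the helper above (not part of the diff). It assumes a GPT-2 tokenizer from the transformers library with the pad token mapped to the end-of-text token, and that this file is importable as dataset_util; both names are illustrative.

from transformers import GPT2TokenizerFast

from dataset_util import DatasetHelper  # assumed module name for the file above

# GPT-2 has no pad token by default, so reuse the end-of-text token for padding.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

helper = DatasetHelper(
    tokenizer,
    batch_size=8,
    seq_len=512,
    num_workers=2,
    persistent_workers=True,
    split="validation",
)
loader = helper.get_loader()

# Each batch is (input_ids, attention_mask, labels), padded to the longest
# sequence in the batch by collate_batch.
input_ids, attention_mask, labels = next(iter(loader))
print(input_ids.shape, attention_mask.shape, labels.shape)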
@@ -0,0 +1,7 @@
"""
# @ Author: Meet Patel
# @ Create Time: 2024-07-12 22:44:09
# @ Modified by: Meet Patel
# @ Modified time: 2024-07-20 16:42:39
# @ Description:
"""
@@ -0,0 +1,86 @@
"""
# @ Author: Daniel Lin
# @ Create Time: 2024-07-06 22:43:40
# @ Modified by: Daniel Lin
# @ Modified time: 2024-07-08 21:56:31
# @ Description:
"""

import torch
import torch.nn as nn

from models.helper import Config
from models.layers.decoder_block import TransformerDecoderBlock
from models.layers.learnable_pos_emb import LearnablePositionalEmbeddings


class GPTModel(nn.Module):
    def __init__(self, config: Config):
        """
        Creates a GPT model from the given config.
        Args:
            config (Config): Instance of Config which contains model
                architecture related parameters.
        """
        super().__init__()
        self.emb_layer = nn.Embedding(config.vocab_size, config.emb_dim)
        self.pos_emb_layer = LearnablePositionalEmbeddings(
            config.max_seq_len, config.emb_dim
        )
        self.transformer_blocks = nn.ModuleList(
            [
                TransformerDecoderBlock(
                    config.emb_dim,
                    config.num_heads,
                    config.ff_multiplier,
                    config.drop_prob,
                )
                for _ in range(config.num_blocks)
            ]
        )
        self.MAX_NEG = torch.tensor(float("-inf"))
        self.lm_head = nn.Linear(config.emb_dim, config.vocab_size)
        self.lm_head.weight = nn.Parameter(
            self.emb_layer.weight
        )  # share weights between the embedding layer and the language model head

    def update_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """
        Updates the attention mask from [batch, seq] into [batch, 1, 1, seq].
        During this, the mask is transformed to make it usable in the
        attention block.
        Args:
            mask (torch.Tensor): Attention mask input tensor of shape [batch, seq].
        Returns:
            torch.Tensor: Updated attention mask of shape [batch, 1, 1, seq].
        """
        mask = mask.to(torch.float32)
        mask = mask.view(
            mask.shape[0], 1, 1, mask.shape[1]
        )  # [batch, 1, 1, seq]
        mask = 1 - mask  # invert the mask
        mask = torch.where(mask == 0, 0, self.MAX_NEG)  # 0 -> keep, -inf -> masked
        return mask  # [batch, 1, 1, seq]

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Forward function for the GPT model.
        Args:
            x (torch.Tensor): Input tensor of shape [batch, seq].
            mask (torch.Tensor): Attention mask of shape [batch, seq].
        Returns:
            torch.Tensor: Output of the GPT model of shape [batch, seq, vocab_size].
        """
        x = self.emb_layer(x)  # [batch, seq, emb_dim]
        x = self.pos_emb_layer(x)  # [batch, seq, emb_dim]
        mask = self.update_mask(mask)  # [batch, 1, 1, seq]
        for transformer_block in self.transformer_blocks:
            x = transformer_block(x, mask)  # [batch, seq, emb_dim]

        x = self.lm_head(x)  # [batch, seq, vocab_size]
        return x
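A minimal forward-pass sketch for the model above (illustrative, not part of the diff). The import path models.gpt and the ad-hoc config object are assumptions; any object exposing the attributes read in __init__ should work, and the values mirror the GPTConfig defined elsewhere in this commit.

from types import SimpleNamespace

import torch

from models.gpt import GPTModel  # assumed import path for the class above

# Ad-hoc config carrying only the attributes GPTModel reads.
cfg = SimpleNamespace(
    vocab_size=50257,
    emb_dim=128,
    max_seq_len=2048,
    num_heads=4,
    drop_prob=0.1,
    ff_multiplier=1,
    num_blocks=2,
)

model = GPTModel(cfg)
tokens = torch.randint(0, cfg.vocab_size, (2, 16))  # [batch, seq]
mask = torch.ones(2, 16, dtype=torch.int32)         # 1 = real token, 0 = padding
logits = model(tokens, mask)                        # [batch, seq, vocab_size]
print(logits.shape)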
@@ -0,0 +1,36 @@
"""
# @ Author: Meet Patel
# @ Create Time: 2024-07-06 22:45:44
# @ Modified by: Meet Patel
# @ Modified time: 2024-07-12 22:10:36
# @ Description:
"""

from models.helper import Config

# class GPTConfig:
#     vocab_size = 40000
#     emb_dim = 512
#     max_seq_len = 512
#     num_heads = 8
#     drop_prob = 0.1
#     ff_multiplier = 4
#     num_blocks = 12


class GPTConfig(Config):
    """
    Model Config for GPT model.
    """

    vocab_size = 50257
    emb_dim = 128
    max_seq_len = 2048
    num_heads = 4
    drop_prob = 0.1
    ff_multiplier = 1
    num_blocks = 2


g = GPTConfig()
g.print_config()
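A note on these defaults: a vocab_size of 50257 matches the GPT-2 tokenizer vocabulary, which is consistent with using a GPT-2-style tokenizer in the data pipeline above (an assumption, since the tokenizer is not fixed in this commit); the remaining fields describe a much smaller model than the commented-out earlier values.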