15 changes: 13 additions & 2 deletions examples/research_projects/codeparrot/README.md
@@ -55,6 +55,16 @@ python scripts/preprocessing.py \
```
During preprocessing the dataset is downloaded and stored locally, along with caches of the computations. Make sure you have more than 500GB of free disk space to execute it.

### Pretokenization
Tokenizing the data during training can be slow, especially for small models. We provide code to pretokenize the data beforehand in `scripts/pretokenizing.py`, but this step is optional. The dataset is downloaded and stored locally, and the tokenized data is pushed to the hub. The tokenized clean [train](https://huggingface.co/datasets/loubnabnl/tokenized-codeparrot-train) and [validation](https://huggingface.co/datasets/loubnabnl/tokenized-codeparrot-valid) datasets are available if you want to use them directly.

Member: Can you also add a note to the training section how to leverage the pretokenized data?

To execute the pretokenization for the clean train data, for instance, run the following command:
```bash
python scripts/pretokenizing.py \
--dataset_name lvwerra/codeparrot-clean-train \
--tokenized_data_repo tokenized-codeparrot-train
```
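
If you would rather skip this step, the pretokenized splits linked above can be loaded from the hub like any other dataset. A minimal sketch (the column names follow `scripts/pretokenizing.py`, which keeps `input_ids` and adds `ratio_char_token`):
```python
from datasets import load_dataset

# Stream the pretokenized train split from the hub instead of tokenizing locally.
ds = load_dataset("loubnabnl/tokenized-codeparrot-train", split="train", streaming=True)

example = next(iter(ds))
print(list(example.keys()))       # expected to include "input_ids" and "ratio_char_token"
print(len(example["input_ids"]))  # number of tokens in this file
```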

## Tokenizer
Before training a new model for code, we create a new tokenizer that is efficient at code tokenization. To train the tokenizer you can run the following command:
```bash
@@ -77,7 +87,8 @@ python scripts/initialize_model.py \
```
This will initialize a new model with the architecture and configuration of `gpt2-large` and use the tokenizer to appropriately size the input embeddings. Finally, the initialized model is pushed to the hub.

Now that the dataset, tokenizer, and model are ready we can start training the model. The main training script is built with `accelerate` to scale across a wide range of platforms and infrastructure scales. We train two models with [110M](https://huggingface.co/lvwerra/codeparrot-small/) and [1.5B](https://huggingface.co/lvwerra/codeparrot/) parameters for 25-30B tokens on a 16xA100 (40GB) machine which takes 1 day and 1 week, respectively.
We can either pass the name of a text dataset or a pretokenized dataset, which speeds up training a bit; for pretokenized data the `tokenized` flag from `scripts/arguments.py` needs to be set (see the sketch below).
Now that the tokenizer and model are also ready, we can start training the model. The main training script is built with `accelerate` to scale across a wide range of platforms and infrastructure scales. We train two models with [110M](https://huggingface.co/lvwerra/codeparrot-small/) and [1.5B](https://huggingface.co/lvwerra/codeparrot/) parameters for 25-30B tokens on a 16xA100 (40GB) machine, which takes 1 day and 1 week, respectively.
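
To leverage the pretokenized data, point the dataset arguments at the tokenized repos and enable the `tokenized` flag. A minimal sketch of the argument parsing, assuming `codeparrot_training.py` builds its `TrainingArguments` with `HfArgumentParser` and is run from the `scripts` directory; passing the same flags on the command line has the same effect:
```python
from arguments import TrainingArguments  # the dataclass defined in scripts/arguments.py
from transformers import HfArgumentParser

# Equivalent to passing these flags to codeparrot_training.py on the command line
# (assuming the script parses TrainingArguments with HfArgumentParser).
parser = HfArgumentParser(TrainingArguments)
args = parser.parse_args(
    [
        "--dataset_name_train", "loubnabnl/tokenized-codeparrot-train",
        "--dataset_name_valid", "loubnabnl/tokenized-codeparrot-valid",
        "--tokenized", "True",
    ]
)
print(args.tokenized)  # True
```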

First you need to configure `accelerate` and log in to Weights & Biases:

@@ -89,7 +100,7 @@ wandb login
Note that during the `accelerate` configuration we enabled FP16. Then to train the large model you can run

```bash
python scripts/codeparrot_training.py
accelerate launch scripts/codeparrot_training.py
```

If you want to train the small model you need to make some modifications:
44 changes: 28 additions & 16 deletions examples/research_projects/codeparrot/scripts/arguments.py
@@ -9,12 +9,10 @@ class TrainingArguments:
"""

model_ckpt: Optional[str] = field(
default="lvwerra/codeparrot",
metadata={"help": "Model name or path of model to be trained."},
default="lvwerra/codeparrot", metadata={"help": "Model name or path of model to be trained."}
)
save_dir: Optional[str] = field(
default="./",
metadata={"help": "Save dir where model repo is cloned and models updates are saved to."},
default="./", metadata={"help": "Save dir where model repo is cloned and models updates are saved to."}
)
dataset_name_train: Optional[str] = field(
default="lvwerra/codeparrot-clean-train", metadata={"help": "Name or path of training dataset."}
@@ -39,7 +37,7 @@ class TrainingArguments:
gradient_checkpointing: Optional[bool] = field(
default=True, metadata={"help": "Use gradient checkpointing to reduce memory footprint."}
)
max_train_steps: Optional[int] = field(default=50_000, metadata={"help": "Maximum number of training steps."})
max_train_steps: Optional[int] = field(default=50000, metadata={"help": "Maximum number of training steps."})
max_eval_steps: Optional[int] = field(
default=-1, metadata={"help": "Maximum number of evaluation steps. If -1 the full dataset is evaluated."}
)
@@ -50,9 +48,9 @@
metadata={"help": "Interval to save checkpoints. Measured as number of forward passes not training steps."},
)
resume_from_checkpoint: Optional[str] = field(
default=None,
metadata={"help": "States path if the training should continue from a checkpoint folder."},
default=None, metadata={"help": "States path if the training should continue from a checkpoint folder."}
)
tokenized: Optional[bool] = field(default=False, metadata={"help": "If True the data is pretokenized."})


@dataclass
@@ -62,8 +60,7 @@ class EvaluationArguments:
"""

model_ckpt: Optional[str] = field(
default="lvwerra/codeparrot",
metadata={"help": "Model name or path of model to be evaluated."},
default="lvwerra/codeparrot", metadata={"help": "Model name or path of model to be evaluated."}
)
dataset_name: Optional[str] = field(
default="lvwerra/codeparrot-clean-valid", metadata={"help": "Name or path of validation dataset."}
@@ -83,8 +80,7 @@ class HumanEvalArguments:
"""

model_ckpt: Optional[str] = field(
default="lvwerra/codeparrot",
metadata={"help": "Model name or path of model to be evaluated."},
default="lvwerra/codeparrot", metadata={"help": "Model name or path of model to be evaluated."}
)
num_workers: Optional[int] = field(default=None, metadata={"help": "Number of workers used for code evaluation."})
num_tasks: Optional[int] = field(
@@ -157,30 +153,46 @@ class TokenizerTrainingArguments:
"""

base_tokenizer: Optional[str] = field(
default="gpt2",
metadata={"help": "Base tokenizer to build new tokenizer from."},
default="gpt2", metadata={"help": "Base tokenizer to build new tokenizer from."}
)
dataset_name: Optional[str] = field(
default="transformersbook/codeparrot-train", metadata={"help": "Dataset to train tokenizer on."}
)
text_column: Optional[str] = field(default="content", metadata={"help": "Column containing text data to process."})
vocab_size: Optional[int] = field(default=200000, metadata={"help": "Number of examples to train tokenizer on."})
vocab_size: Optional[int] = field(default=200_000, metadata={"help": "Number of examples to train tokenizer on."})
n_examples: Optional[int] = field(
default=32768, metadata={"help": "Number of examples to train the tokenizer on."}
)
tokenizer_name: Optional[str] = field(default="codeparrot", metadata={"help": "Name of new tokenizer."})
push_to_hub: Optional[bool] = field(default=True, metadata={"help": "Push saved tokenizer to the hub."})


@dataclass
class PretokenizationArguments:
"""
Configuration for data pretokenization.
"""

tokenizer_dir: Optional[str] = field(
default="lvwerra/codeparrot", metadata={"help": "Name or path to the tokenizer."}
)
dataset_name: Optional[str] = field(
default="lvwerra/codeparrot-clean-train", metadata={"help": "Name or path to the dataset to pretokenize."}
)
tokenized_data_repo: Optional[str] = field(
default="tokenized-codeparrot-train", metadata={"help": "Repo name of the pretokenized data."}
)
    num_workers: Optional[int] = field(default=None, metadata={"help": "Number of workers used for pretokenization."})


@dataclass
class InitializationArguments:
"""
Configuration for initializing new model.
"""

config_name: Optional[str] = field(
default="gpt2-large",
metadata={"help": "Configuration to use for model initialization."},
default="gpt2-large", metadata={"help": "Configuration to use for model initialization."}
)
tokenizer_name: Optional[str] = field(
default="lvwerra/codeparrot", metadata={"help": "Tokenizer attached to model."}
@@ -27,30 +27,45 @@ class ConstantLengthDataset(IterableDataset):
seq_length (int): Length of token sequences to return.
num_of_sequences: Number of token sequences to keep in buffer.
chars_per_token: Number of characters per token used to estimate number of tokens in text buffer.
tokenized: If true we use a pretokenized dataset.
"""

def __init__(
self, tokenizer, dataset, infinite=False, seq_length=1024, num_of_sequences=1024, chars_per_token=3.6
self,
tokenizer,
dataset,
infinite=False,
seq_length=1024,
num_of_sequences=1024,
chars_per_token=3.6,
tokenized=False,
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.bos_token_id
self.dataset = dataset
self.seq_length = seq_length
self.input_characters = seq_length * chars_per_token * num_of_sequences
self.epoch = 0
self.infinite = infinite
self.current_size = 0
self.tokenized = tokenized

if self.tokenized:
self.max_buffer_size = seq_length * num_of_sequences
self.content_field = "input_ids"
else:
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
self.content_field = "content"

def __iter__(self):
iterator = iter(self.dataset)
more_examples = True
while more_examples:
buffer, buffer_len = [], 0
while True:
if buffer_len >= self.input_characters:
if buffer_len >= self.max_buffer_size:
break
try:
buffer.append(next(iterator)["content"])
buffer.append(next(iterator)[self.content_field])
buffer_len += len(buffer[-1])
except StopIteration:
if self.infinite:
@@ -60,7 +75,10 @@ def __iter__(self):
else:
more_examples = False
break
tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
if self.tokenized:
tokenized_inputs = buffer
else:
tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
all_token_ids = []
for tokenized_input in tokenized_inputs:
all_token_ids.extend(tokenized_input + [self.concat_token_id])
@@ -102,8 +120,12 @@ def create_dataloaders(args):
train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs)
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs)
train_dataset = ConstantLengthDataset(tokenizer, train_data, infinite=True, seq_length=args.seq_length)
valid_dataset = ConstantLengthDataset(tokenizer, valid_data, infinite=False, seq_length=args.seq_length)
train_dataset = ConstantLengthDataset(
tokenizer, train_data, infinite=True, seq_length=args.seq_length, tokenized=args.tokenized
)
valid_dataset = ConstantLengthDataset(
tokenizer, valid_data, infinite=False, seq_length=args.seq_length, tokenized=args.tokenized
)
train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
return train_dataloader, eval_dataloader
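
For reference, a minimal sketch of feeding a pretokenized streaming dataset to `ConstantLengthDataset` with `tokenized=True`; the repo and tokenizer names are taken from the README, and the class is assumed to be available in the current scope (it is defined in the training script above):
```python
from datasets import load_dataset
from transformers import AutoTokenizer

# The tokenizer is still needed for its bos_token_id, which is used as the concat token.
tokenizer = AutoTokenizer.from_pretrained("lvwerra/codeparrot")
train_data = load_dataset("loubnabnl/tokenized-codeparrot-train", split="train", streaming=True)

# With tokenized=True the "input_ids" column is used directly and nothing is re-tokenized.
train_dataset = ConstantLengthDataset(
    tokenizer, train_data, infinite=True, seq_length=1024, tokenized=True
)
sequence = next(iter(train_dataset))  # one fixed-length sequence of token ids
```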
@@ -16,7 +16,7 @@
config = AutoConfig.from_pretrained(args.config_name, **config_kwargs)

# Initialize new model with config
model = AutoModelForCausalLM(config)
model = AutoModelForCausalLM.from_config(config)

# Save model to the hub
model.save_pretrained(args.model_name, push_to_hub=args.push_to_hub)
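
The one-line fix above replaces direct instantiation with `from_config`: the `Auto*` classes cannot be called directly, and `from_config` builds a freshly initialized model from the configuration without loading pretrained weights. A minimal, self-contained sketch:
```python
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("gpt2-large")

# Auto* classes refuse direct instantiation (AutoModelForCausalLM(config) raises an error);
# from_config returns a model with randomly initialized weights matching the config.
model = AutoModelForCausalLM.from_config(config)
print(f"{model.num_parameters() / 1e6:.0f}M parameters")
```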
49 changes: 49 additions & 0 deletions examples/research_projects/codeparrot/scripts/pretokenizing.py
@@ -0,0 +1,49 @@
import multiprocessing
import time

from datasets import load_dataset

from arguments import PretokenizationArguments
from transformers import AutoTokenizer, HfArgumentParser


def tokenize(example):
output = dict()
output["input_ids"] = tokenizer(example["content"], truncation=False)["input_ids"]
output["ratio_char_token"] = len(example["content"]) / len(output["input_ids"])
return output


parser = HfArgumentParser(PretokenizationArguments)
args = parser.parse_args()
if args.num_workers is None:
args.num_workers = multiprocessing.cpu_count()
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)

t_start = time.time()
ds = load_dataset(args.dataset_name, split="train")
Member: It should be possible to stream the dataset here, too, right?

Contributor Author: yes it should but when I activate streaming in the data loading the map() function doesn't have the argument remove_columns

print(f"Dataset loaded in {time.time()-t_start:.2f}s")

t_start = time.time()
ds = ds.map(
tokenize,
num_proc=args.num_workers,
remove_columns=[
"repo_name",
"path",
"copies",
"size",
"content",
"license",
"hash",
"line_mean",
"line_max",
"alpha_frac",
"autogenerated",
],
)
print(f"Dataset tokenized in {time.time()-t_start:.2f}s")

t_start = time.time()
ds.push_to_hub(args.tokenized_data_repo)
print(f"Data pushed to the hub in {time.time()-t_start:.2f}s")