
Merge pull request #409 from allenai/epwalsh/tulu-fine-tune
Configs and changes for instruction fine-tuning on Tulu
epwalsh authored Jan 10, 2024
2 parents a2e1d13 + 6514607 commit 3e3df71
Showing 9 changed files with 519 additions and 162 deletions.
102 changes: 102 additions & 0 deletions configs/mcli/mitchish-instruct.yml
@@ -0,0 +1,102 @@
run_name: olmo-7b-instruct
image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
gpu_num: 64
#gpu_num: 8
#cluster: r12z3
cluster: r7z2
gpu_type: a100_40gb
integrations:
- integration_type: git_repo
git_repo: allenai/LLM
git_branch: epwalsh/tulu-fine-tune
pip_install: -e .
ssh_clone: true
command: |-
# NOTE: For some reason getting S3 and R2 authentication working both from the command line and
# from Python proved to be challenging, maybe because Mosaic's servers are in Australia.
# In the end I had to use separate methods to get everything working:
# 1. AWS config files for CLI access.
# 2. Environment variables for boto3 access (to S3 only).
# Since we only need CLI access prior to training, we remove the AWS config files before launching
# the training job. Otherwise the environment variables won't work.
# Install aws cli
apt-get update
apt-get install zip unzip
curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
unzip awscliv2.zip
sudo ./aws/install
cd LLM
pip freeze
# Prepare environment including AWS config files for both S3 and R2 access.
mkdir -p /root/.cache/torch
mkdir /root/checkpoint-unsharded
mkdir /root/data
mkdir /root/.aws
touch /root/.aws/credentials /root/.aws/config
echo '[s3]' >> /root/.aws/credentials
echo "aws_access_key_id = ${AWS_ACCESS_KEY_ID}" >> /root/.aws/credentials
echo "aws_secret_access_key = ${AWS_SECRET_ACCESS_KEY}" >> /root/.aws/credentials
echo '' >> /root/.aws/credentials
echo '[r2]' >> /root/.aws/credentials
echo "aws_access_key_id = ${R2_ACCESS_KEY_ID}" >> /root/.aws/credentials
echo "aws_secret_access_key = ${R2_SECRET_ACCESS_KEY}" >> /root/.aws/credentials
echo "[default]" >> /root/.aws/config
echo "region = auto" >> /root/.aws/config
echo "output = json" >> /root/.aws/config
#export S3_PROFILE=s3
#export R2_PROFILE=r2
export OMP_NUM_THREADS=8
export LOG_FILTER_TYPE=local_rank0_only
# Download checkpoint (everything except optimizer state).
checkpoint=s3://olmo-checkpoints/ai2-llm/olmo-medium/wd2gxrza/step556000-unsharded
echo "Downloading checkpoint '${checkpoint}'..."
# Download config.
aws s3 cp --profile=r2 --region=auto \
--endpoint-url=https://a198dc34621661a1a66a02d6eb7c4dc3.r2.cloudflarestorage.com \
"${checkpoint}/config.yaml" /root/checkpoint-unsharded/
# Download trainer state.
aws s3 cp --profile=r2 --region=auto \
--endpoint-url=https://a198dc34621661a1a66a02d6eb7c4dc3.r2.cloudflarestorage.com \
"${checkpoint}/train.pt" /root/checkpoint-unsharded/
# Download model weights.
aws s3 cp --profile=r2 --region=auto \
--endpoint-url=https://a198dc34621661a1a66a02d6eb7c4dc3.r2.cloudflarestorage.com \
"${checkpoint}/model.pt" /root/checkpoint-unsharded/
# Now remove the AWS configs so they don't mess with data loading / uploading checkpoints to/from S3.
rm -rf /root/.aws
# Download data (it's small enough so might as well).
echo "Downloading data..."
aws s3 cp \
s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/data.npy \
/root/data/data.npy
torchrun \
--master_addr "$MASTER_ADDR" \
--master_port "$MASTER_PORT" \
--nnodes "$NUM_NODES" \
--node_rank "$NODE_RANK" \
--nproc_per_node 8 \
scripts/train.py configs/mitchish-instruct.yaml \
--run_name=mitchish-mcli-2.5T-instruct-2e-6 \
--optimizer.learning_rate=2e-6 \
--save_overwrite \
--time_limit=169200 \
--data.paths=[/root/data/data.npy] \
--save_interval_unsharded=10000 \
--load_path=/root/checkpoint-unsharded \
--reset_optimizer_state \
--reset_trainer_state \
--compile=null \
--activation_checkpointing=fine_grained \
--fsdp.wrapping_strategy=size_based
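
The cleanup of /root/.aws above works because boto3 reads credentials from environment variables rather than from the config files the CLI uses. A minimal sketch of that in-process S3 access (not part of this PR; bucket and key taken from the data path in the command):

import boto3

# Credentials come from AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY in the
# environment; deleting /root/.aws only removes the CLI profiles used above.
s3 = boto3.client("s3")
s3.download_file(
    "ai2-llm",
    "preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/data.npy",
    "/root/data/data.npy",
)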
182 changes: 182 additions & 0 deletions configs/mitchish-instruct.yaml
@@ -0,0 +1,182 @@
run_name: v1_5-mix-medium-mitch-ish
seed: 6198
dry_run: false

wandb:
name: ${run_name}
project: olmo-medium
group: v1_5-mix

model:
d_model: 4096
n_heads: 32
n_layers: 32
# mlp_ratio: 6
mlp_hidden_size: 22016
weight_tying: false
alibi: false
rope: true
flash_attention: true
attention_dropout: 0.0
attention_layer_norm: false
multi_query_attention: false
include_bias: false
block_type: sequential
layer_norm_type: default
layer_norm_with_affine: false
bias_for_layer_norm: false
attention_layer_norm_with_affine: false
activation_type: swiglu
residual_dropout: 0.0
embedding_dropout: 0.0
max_sequence_length: 2048
vocab_size: 50280
embedding_size: 50304
eos_token_id: 0
pad_token_id: 1
init_device: meta
init_fn: mitchell

compile:
fullgraph: false

optimizer:
name: adamw
learning_rate: 2e-5
weight_decay: 0.0
betas:
- 0.9
- 0.999
metrics_log_interval: 10

scheduler:
name: linear_with_warmup
t_warmup: 100
alpha_f: 0.001

tokenizer:
identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json
truncate_direction: right

save_folder: runs/${run_name}
remote_save_folder: s3://ai2-llm/checkpoints/7b/${run_name}
save_overwrite: false
# Sharded checkpoints (best for restarts)
save_interval: 1000
save_num_checkpoints_to_keep: -1
# Unsharded checkpoints (for final storage)
save_interval_unsharded: null # getting errors on LUMI right now
save_num_unsharded_checkpoints_to_keep: -1

load_path: null

max_duration: 2ep
global_train_batch_size: 128
device_train_microbatch_size: 2
time_limit: null

precision: amp_bf16

fsdp:
wrapping_strategy: by_block
precision: mixed

max_grad_norm: 1.0
max_grad_norm_ratio: null

speed_monitor:
window_size: 20

eval_interval: ${save_interval}
eval_subset_num_batches: -1
device_eval_batch_size: ${device_train_microbatch_size}
evaluators:
- label: all-small-ppl-validation
data:
num_workers: 0
drop_last: true
# pin_memory: true
# prefetch_factor: 1
# persistent_workers: false
# timeout: 0
datasets:
4chan-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/4chan/val.npy
c4_100_domains-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/c4_100_domains/val.npy
c4_en-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/c4_en/val.npy
gab-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/gab/val.npy
ice-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/ice/val.npy
m2d2_s2orc-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/m2d2_s2orc/val.npy
m2d2_wiki-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/m2d2_wiki/val.npy
manosphere-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/manosphere/val.npy
mc4_en-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/mc4_en/val.npy
pile-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/pile/val.npy
ptb-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/ptb/val.npy
twitterAEE-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/twitterAEE/val.npy
wikitext_103-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/wikitext_103/val.npy

##########################
# Downstream evaluations #
##########################
- label: piqa
type: downstream

- label: hellaswag
type: downstream

- label: winogrande
type: downstream

- label: openbook_qa
type: downstream

# - label: boolq # requires implementation of the pmi_dc matrix
# type: downstream

- label: sciq
type: downstream

- label: arc_easy
type: downstream

# - label: arc_challenge # requires implementation of the pmi_dc matrix
# type: downstream

- label: copa
type: downstream

- label: rte
type: downstream

- label: commitment_bank
type: downstream

- label: mrpc
type: downstream

- label: sst2
type: downstream

data:
pad_direction: right
num_workers: 0
drop_last: true
pin_memory: true
prefetch_factor: 1
persistent_workers: true
timeout: 0
generate_attention_mask: true
paths:
- s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/data.npy
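
The torchrun call in the MCLI config overrides several of these values with dotted --key=value flags. As a hedged sketch of the same override mechanics when inspecting this file locally, assuming plain OmegaConf rather than the project's own TrainConfig loader:

from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/mitchish-instruct.yaml")
overrides = OmegaConf.from_dotlist([
    "optimizer.learning_rate=2e-6",
    "reset_optimizer_state=true",   # keys added by the command-line overrides
    "reset_trainer_state=true",
])
cfg = OmegaConf.merge(cfg, overrides)
print(cfg.optimizer.learning_rate, cfg.max_duration)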
6 changes: 4 additions & 2 deletions olmo/checkpoint.py
@@ -502,8 +502,10 @@ def _temporary_wd(self, dir: PathOrStr) -> Generator[Path, None, None]:
barrier()

# Finally if all went well replace the temporary directory with the actual
# checkpoint directory.
if get_fs_local_rank() == 0:
# checkpoint directory. Note that for some checkpointers the local rank 0 might
# not use this folder, so it may not exist; FullCheckpointer, for example, only creates
# this for global rank 0.
if get_fs_local_rank() == 0 and checkpoint_dir_tmp.exists():
# Replace temp directory with target checkpoint directory.
try:
checkpoint_dir_tmp.replace(checkpoint_dir)
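
As a stand-alone illustration of the guard added here (the function name is made up; the real logic lives inside _temporary_wd): the temp directory is only promoted when this rank actually created it, since FullCheckpointer writes it on global rank 0 only.

from pathlib import Path

def promote_tmp_dir(checkpoint_dir_tmp: Path, checkpoint_dir: Path, fs_local_rank: int) -> None:
    # Ranks that never created the temp folder (e.g. non-zero global ranks
    # under FullCheckpointer) must skip the rename, or .replace() would raise.
    if fs_local_rank == 0 and checkpoint_dir_tmp.exists():
        checkpoint_dir_tmp.replace(checkpoint_dir)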
8 changes: 7 additions & 1 deletion olmo/config.py
@@ -510,6 +510,7 @@ class DataConfig(BaseConfig):
paths: Optional[List[str]] = None
datasets: Optional[Dict[str, List[str]]] = None
pad_direction: PaddingDirection = PaddingDirection.right
generate_attention_mask: bool = False
num_workers: int = 0
drop_last: bool = False
pin_memory: bool = False
@@ -683,7 +684,7 @@ class TrainConfig(BaseConfig):
Used to seed all initial RNG states.
"""

epoch: int = 0
epoch: Optional[int] = None
"""
Increment this when starting a new epoch.
"""
@@ -832,6 +833,11 @@ class TrainConfig(BaseConfig):
curve (according to the current learning rate schedule settings), and continues from there.
"""

reset_trainer_state: bool = False
"""
When this is set we don't restore the trainer state from a checkpoint.
"""

sharded_checkpointer: ShardedCheckpointerType = ShardedCheckpointerType.torch_legacy
"""
The name of the sharded checkpointer to use to save (sharded) checkpoints throughout training.
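
The new generate_attention_mask flag on DataConfig, combined with the model's pad_token_id, implies mask construction roughly like the sketch below; this is a hedged illustration of the idea, not the dataset's literal code.

import torch

def make_attention_mask(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # 1 for real tokens, 0 for padding positions.
    return (input_ids != pad_token_id).to(torch.long)

make_attention_mask(torch.tensor([5, 8, 13, 1, 1]), pad_token_id=1)
# -> tensor([1, 1, 1, 0, 0])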
4 changes: 3 additions & 1 deletion olmo/data/__init__.py
@@ -37,6 +37,8 @@ def build_memmap_dataset(
chunk_size=train_config.model.max_sequence_length,
metadata=metadata,
include_instance_metadata=include_instance_metadata,
pad_token_id=train_config.model.pad_token_id,
generate_attention_mask=data_config.generate_attention_mask,
)


@@ -93,7 +95,7 @@ def build_train_dataloader(train_config: TrainConfig) -> DataLoader:
IterableDataset(
dataset, # type: ignore
train_config.global_train_batch_size,
seed=train_config.seed + train_config.epoch,
seed=train_config.seed + (train_config.epoch or 0),
shuffle=True,
drop_last=train_config.data.drop_last,
work_dir=work_dir,
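
The seeding tweak above matters because TrainConfig.epoch now defaults to None: the expression seed + (epoch or 0) falls back to the base seed instead of raising a TypeError. A small worked illustration, using the seed from the config above:

base_seed = 6198  # seed from mitchish-instruct.yaml
for epoch in (None, 0, 1):
    print(epoch, base_seed + (epoch or 0))
# None 6198
# 0 6198
# 1 6199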
36 changes: 23 additions & 13 deletions olmo/data/iterable_dataset.py
@@ -64,21 +64,26 @@ def __init__(
assert global_batch_size % self.world_size == 0
self.device_batch_size = global_batch_size // self.world_size
self.global_indices_file: Optional[Path] = None
self.work_dir = work_dir

if work_dir is not None:
self.global_indices_file = Path(work_dir) / "global_indices.npy"
if self.fs_local_rank == 0:
log.info("Saving global data order indices...")
self.global_indices_file.parent.mkdir(parents=True, exist_ok=True)
global_indices = self._build_global_indices()
global_indices_mmap = np.memmap(
self.global_indices_file, dtype=np.uint32, mode="w+", shape=(len(global_indices),)
)
global_indices_mmap[:] = global_indices
global_indices_mmap.flush()
del global_indices_mmap
log.info("Global data order indices saved to '%s'", self.global_indices_file)
barrier()
self._build_and_save_global_indices()

def _build_and_save_global_indices(self):
assert self.work_dir is not None
self.global_indices_file = Path(self.work_dir) / "global_indices.npy"
if self.fs_local_rank == 0:
log.info("Saving global data order indices...")
self.global_indices_file.parent.mkdir(parents=True, exist_ok=True)
global_indices = self._build_global_indices()
global_indices_mmap = np.memmap(
self.global_indices_file, dtype=np.uint32, mode="w+", shape=(len(global_indices),)
)
global_indices_mmap[:] = global_indices
global_indices_mmap.flush()
del global_indices_mmap
log.info("Global data order indices saved to '%s'", self.global_indices_file)
barrier()

def _build_global_indices(self) -> np.ndarray:
assert len(self.dataset) < np.iinfo(np.uint32).max
@@ -111,6 +116,11 @@ def get_global_indices(self) -> np.ndarray:
else:
return self._build_global_indices()

def reshuffle(self):
self.seed += 1
if self.work_dir is not None:
self._build_and_save_global_indices()

def __iter__(self) -> Iterator[Dict[str, Any]]:
indices = self.get_global_indices()

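
The new reshuffle() method bumps the seed and rewrites global_indices.npy, giving the trainer a fresh data order at an epoch boundary. A hedged usage sketch follows; the epoch loop and paths are illustrative, not the trainer's actual wiring, and only the memmap read mirrors what the class itself does.

import numpy as np

# Illustrative only: constructor arguments mirror build_train_dataloader above.
#
#   dataset = IterableDataset(memmap_dataset, global_batch_size=128,
#                             seed=6198, shuffle=True, work_dir=work_dir)
#   for epoch in range(2):
#       train_one_epoch(dataset)
#       dataset.reshuffle()  # seed += 1, global_indices.npy rewritten
#
# The rewritten order can be read back roughly the way get_global_indices() does:
def load_global_indices(path: str) -> np.ndarray:
    return np.memmap(path, dtype=np.uint32, mode="r")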
