Merge pull request #443 from allenai/mitchish65-2-gqa
GQA into Mitchish65
epwalsh authored Mar 18, 2024
2 parents cfc362c + b1e3855 commit fa8ec33
Showing 46 changed files with 1,963 additions and 206 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- Added support for Grouped Query Attention.

### Changed

- Rename `Olmo` to `OLMo` everywhere in the codebase
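The headline change is Grouped Query Attention (GQA): each key/value head is shared by a group of query heads, sitting between standard multi-head attention (one KV head per query head) and the multi-query attention already exposed through the `multi_query_attention` flag in the configs below (a single KV head shared by all query heads). A minimal PyTorch sketch of the idea; names such as `n_kv_heads` are illustrative and not taken from the OLMo codebase:

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v, n_heads: int, n_kv_heads: int):
    # q: (batch, seq, n_heads * head_dim); k, v: (batch, seq, n_kv_heads * head_dim).
    # n_heads must be divisible by n_kv_heads.
    B, T, _ = q.shape
    head_dim = q.shape[-1] // n_heads
    q = q.view(B, T, n_heads, head_dim).transpose(1, 2)     # (B, n_heads, T, head_dim)
    k = k.view(B, T, n_kv_heads, head_dim).transpose(1, 2)  # (B, n_kv_heads, T, head_dim)
    v = v.view(B, T, n_kv_heads, head_dim).transpose(1, 2)
    # Repeat each KV head so every group of query heads attends to its shared KV head.
    group_size = n_heads // n_kv_heads
    k = k.repeat_interleave(group_size, dim=1)              # (B, n_heads, T, head_dim)
    v = v.repeat_interleave(group_size, dim=1)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return out.transpose(1, 2).reshape(B, T, n_heads * head_dim)
```

With `n_kv_heads == n_heads` this reduces to ordinary multi-head attention; with `n_kv_heads == 1` it reduces to multi-query attention.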
2 changes: 1 addition & 1 deletion Makefile
@@ -29,7 +29,7 @@ base-image :
docker build -f docker/Dockerfile.base -t $(IMAGE_NAME_BASE)-base .

.PHONY : gantry-image
gantry-image : base-image
gantry-image :
docker build -f docker/Dockerfile.gantry -t $(IMAGE_NAME_BASE)-gantry .
beaker image create $(IMAGE_NAME_BASE)-gantry --name $(IMAGE_NAME_BASE)-gantry-tmp --workspace $(BEAKER_WORKSPACE)
beaker image delete $(GANTRY_IMAGE) || true
2 changes: 1 addition & 1 deletion configs/llama7-s3.yaml
@@ -13,7 +13,7 @@ model:
n_layers: 32
mlp_hidden_size: 22016
rope: true
flash_attention: true
flash_attention: false
attention_dropout: 0.0
attention_layer_norm: false
multi_query_attention: false
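Several configs in this commit touch the `flash_attention` flag, which selects the attention backend used by the model. As a rough sketch only (the function and argument names here are assumptions, not OLMo's implementation), a config-driven switch between the `flash-attn` kernel and PyTorch's built-in scaled dot-product attention could look like this, assuming `flash-attn` is installed whenever the flag is enabled:

```python
import torch.nn.functional as F

def attend(q, k, v, use_flash: bool, dropout_p: float = 0.0):
    # Illustrative backend switch keyed off a `flash_attention`-style config flag.
    if use_flash:
        # flash-attn expects (batch, seq, n_heads, head_dim) tensors.
        from flash_attn import flash_attn_func
        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=True)
    # Fallback: PyTorch SDPA expects (batch, n_heads, seq, head_dim).
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=True)
    return out.transpose(1, 2)
```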
2 changes: 1 addition & 1 deletion configs/llama7.yaml
@@ -13,7 +13,7 @@ model:
n_layers: 32
mlp_hidden_size: 22016
rope: true
flash_attention: true
flash_attention: false
attention_dropout: 0.0
attention_layer_norm: false
multi_query_attention: false
1 change: 1 addition & 0 deletions configs/mcli/ananya-1b-ib.yaml
@@ -31,6 +31,7 @@ model:
pad_token_id: 1
init_device: meta
init_fn: mitchell
flash_attention: true

compile: null # causes instability on AMD GPUs

1 change: 1 addition & 0 deletions configs/mcli/ananya-1b.yaml
@@ -31,6 +31,7 @@ model:
pad_token_id: 1
init_device: meta
init_fn: normal
flash_attention: true

compile: null # causes instability on AMD GPUs

1 change: 1 addition & 0 deletions configs/mcli/mitchish-final.yaml
@@ -97,6 +97,7 @@ command: |-
--save_interval_unsharded=10000 \
--load_path=/root/checkpoint-unsharded \
--compile=null \
--model.flash_attention=true \
--activation_checkpointing=fine_grained \
--fsdp.wrapping_strategy=size_based \
--remote_save_folder=s3://ai2-llm/checkpoints/7b/${run_name} \
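Note that the mcli launch configs turn on flash attention through a dotted command-line override (`--model.flash_attention=true`) rather than by editing the training YAML. As an illustration of the general dotlist-override pattern (shown with OmegaConf; this is not a claim about how scripts/train.py actually parses its arguments):

```python
from omegaconf import OmegaConf

# Base config as it might be loaded from YAML (trimmed to the relevant fields).
base = OmegaConf.create({"model": {"flash_attention": False, "n_layers": 32}})

# Command-line style override in dotted "key=value" form.
overrides = OmegaConf.from_dotlist(["model.flash_attention=true"])

cfg = OmegaConf.merge(base, overrides)
print(cfg.model.flash_attention)  # True
```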
1 change: 1 addition & 0 deletions configs/mcli/mitchish-instruct.yml
@@ -98,6 +98,7 @@ command: |-
--reset_trainer_state \
--reset_optimizer_state \
--compile=null \
--model.flash_attention=true \
--activation_checkpointing=whole_layer \
--fsdp.wrapping_strategy=size_based \
--max_duration=5ep
1 change: 1 addition & 0 deletions configs/mcli/mitchish.yaml
@@ -37,6 +37,7 @@ command: |-
--save_interval_unsharded=10000 \
--load_path=${checkpoint} \
--compile=null \
--model.flash_attention=true \
--activation_checkpointing=fine_grained \
--fsdp.wrapping_strategy=size_based \
--remote_save_folder=s3://ai2-llm/checkpoints/7b/${run_name} \
40 changes: 40 additions & 0 deletions configs/mcli/mitchish70.yaml
@@ -0,0 +1,40 @@
name: olmo-70b
image: mosaicml/pytorch:2.2.1_cu121-python3.11-ubuntu20.04
compute:
cluster: r14z3p2
gpus: 256
gpu_type: h100_80gb
integrations:
- integration_type: git_repo
git_repo: allenai/OLMo
git_branch: mitchish65-2
pip_install: -e .[train]
ssh_clone: true
env_variables:
PIP_DISABLE_PIP_VERSION_CHECK: "1"
OMP_NUM_THREADS: "8"
LOG_FILTER_TYPE: local_rank0_only
command: |-
# Make sure we have a recent flash-attn.
# NOTE: only pinning flash-attn here to future proof it.
pip install flash-attn==2.5.3 --no-build-isolation
# Show packages for debugging.
pip freeze
# Prepare environment.
cd OLMo
mkdir -p /root/.cache/torch
torchrun \
--master_addr "$MASTER_ADDR" \
--master_port "$MASTER_PORT" \
--nnodes "$NUM_NODES" \
--node_rank "$NODE_RANK" \
--nproc_per_node 8 \
scripts/train.py configs/mitchish70-s3.yaml \
--run_name=mitchish70-001 \
--wandb.group=mitchish70 \
--global_train_batch_size=768 \
--device_train_microbatch_size=3 \
--save_overwrite
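The batch-size settings in this new 70B launch config are internally consistent: 256 GPUs times a device micro-batch of 3 sequences gives 768 sequences per forward/backward pass, which matches `global_train_batch_size=768`, so no gradient accumulation is needed. A quick sanity check using the usual data-parallel relationship (nothing OLMo-specific):

```python
# Sanity-check the batch-size arithmetic in configs/mcli/mitchish70.yaml.
gpus = 256                        # compute.gpus
device_train_microbatch_size = 3  # sequences per GPU per forward/backward pass
global_train_batch_size = 768     # sequences per optimizer step

per_pass = gpus * device_train_microbatch_size
assert global_train_batch_size % per_pass == 0
grad_accum_steps = global_train_batch_size // per_pass
print(per_pass, grad_accum_steps)  # 768 1
```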
1 change: 1 addition & 0 deletions configs/mcli/v1-mix-medium-mitch-ish.yaml
@@ -28,4 +28,5 @@ command: |-
--nproc_per_node 8 \
scripts/train.py configs/v1-mix-medium-mitch-ish-s3.yaml \
--run_name=v1-mix-mitch-ish \
--model.flash_attention=true \
--global_train_batch_size=2160
1 change: 1 addition & 0 deletions configs/mcli/v1-mix-medium.yaml
@@ -28,5 +28,6 @@ command: |-
--nproc_per_node 8 \
scripts/train.py configs/v1-mix-medium-s3.yaml \
--run_name=v1-mix-medium \
--model.flash_attention=true \
--scheduler.name=linear_with_warmup \
--global_train_batch_size=2160
1 change: 1 addition & 0 deletions configs/mcli/v1_5-mix-medium-mitch-ish.yaml
@@ -31,6 +31,7 @@ command: |-
--run_name=v1_5-mix-mitch-ish \
--wandb.name=v1_5-mix-mitch-ish-mcli-final \
--global_train_batch_size=2160 \
--model.flash_attention=true \
--time_limit=169200
# We added these flags in order to get a final checkpoint where we decayed the LR down to 0.
1 change: 1 addition & 0 deletions configs/mcli/v1_5-mix-medium.yaml
@@ -29,4 +29,5 @@ command: |-
scripts/train.py configs/v1_5-mix-medium-s3.yaml \
--run_name=v1_5-mix-mcli \
--scheduler.name=linear_with_warmup \
--model.flash_attention=true \
--global_train_batch_size=2160
2 changes: 1 addition & 1 deletion configs/mitchish35.yaml
@@ -16,7 +16,7 @@ model:
weight_tying: false
alibi: false
rope: true
flash_attention: true
flash_attention: false
attention_dropout: 0.0
attention_layer_norm: false
multi_query_attention: false
2 changes: 1 addition & 1 deletion configs/mitchish50.yaml
@@ -16,7 +16,7 @@ model:
weight_tying: false
alibi: false
rope: true
flash_attention: true
flash_attention: false
attention_dropout: 0.0
attention_layer_norm: false
multi_query_attention: false