GQA into Mitchish65 #443

Merged
92 commits merged into mitchish65-2 from mitchish65-2-gqa on Mar 18, 2024
Commits (92)
1acac5c
Switch to flash attention proper.
dirkgr Feb 8, 2024
9c6a144
LUMI Dockerfile that builds Flash Attention
dirkgr Feb 9, 2024
53f9ba6
Adds another mitchish65 script that lets us specify a random seed
dirkgr Feb 9, 2024
d4e5ef7
Merge branch 'mitchish65-2' of https://github.com/allenai/LLM into mi…
dirkgr Feb 9, 2024
d4b14f9
Merge branch 'mitchish65-2' into mitchish65-2-gqa
dirkgr Feb 9, 2024
bc41306
Makes the config actually work
dirkgr Feb 10, 2024
fbc8cba
Run script
dirkgr Feb 10, 2024
95dd681
Remove obsolete docstring
dirkgr Feb 10, 2024
b9ed52c
Merge branch 'mitchish65-2-gqa' of https://github.com/allenai/LLM int…
dirkgr Feb 10, 2024
5bf7211
We do batch size warmup.
dirkgr Feb 10, 2024
559ab28
Offline HF Datasets
dirkgr Feb 10, 2024
96abd69
Adds a config that runs against data in S3
dirkgr Feb 15, 2024
6980f07
Merge branch 'mitchish65-2-gqa' of https://github.com/allenai/LLM int…
dirkgr Feb 15, 2024
e04ee47
Gantry script
dirkgr Feb 15, 2024
c91ecb4
Nail down torch dependency
dirkgr Feb 15, 2024
d8c1ef9
Add flash to requirements
dirkgr Feb 15, 2024
013bd44
New Docker image for gantry
dirkgr Feb 15, 2024
c9c928d
Trying for a standalone Docker image
dirkgr Feb 15, 2024
e286076
Why is everything so abstracted away? What's wrong with just writing …
dirkgr Feb 15, 2024
275463a
Tighter pin for torch
dirkgr Feb 15, 2024
cb1343b
Lazy import of flash to make this work on Mac
dirkgr Feb 15, 2024
ee49920
Optional dependency on flash
dirkgr Feb 15, 2024
010ef38
Script for re-grouping wandb runs
dirkgr Feb 15, 2024
4ffa5ea
This script is obsolete.
dirkgr Feb 15, 2024
e6fd8f3
Mitchish70b in S3
dirkgr Feb 16, 2024
5369c4c
Hack to force a certain wrapping
dirkgr Feb 18, 2024
e4894bb
Fix wrap policy without hack
epwalsh Feb 19, 2024
859eb9c
use the right wrapping strategy
epwalsh Feb 19, 2024
ea847d0
Fix
epwalsh Feb 19, 2024
4f586ba
revert and try debug
epwalsh Feb 19, 2024
f66c248
I'm sick of PyTorch
epwalsh Feb 19, 2024
527f3ed
Fix dockerfile
dirkgr Feb 19, 2024
deb353b
This is how you set base images?
dirkgr Feb 20, 2024
241761f
Fix ninja install
dirkgr Feb 20, 2024
9d4c9fe
Forgot to install torch
dirkgr Feb 20, 2024
05003ef
Match torch version on LUMI
dirkgr Feb 20, 2024
f84d7e4
add mcli config
epwalsh Feb 22, 2024
95fd624
Fall back to torch when flash-attn not available
epwalsh Feb 22, 2024
2f70619
Fix the way we call flash attention
dirkgr Feb 23, 2024
7925438
Test for the fix for how we call flash attention
dirkgr Feb 23, 2024
b82a51f
Uses flash-attn's "fused" CE loss function, when available (#465)
epwalsh Feb 23, 2024
20d6a4c
Remove tests for parallel block
dirkgr Feb 23, 2024
523b23b
Fix test for flash attention
dirkgr Feb 23, 2024
4ab7aa4
Flash Attention does not support Alibi yet
dirkgr Feb 23, 2024
9e6f251
beaker configs
epwalsh Feb 23, 2024
666ffbb
Workaround so we can use GQA on LUMI
dirkgr Feb 24, 2024
dd59f98
Merge branch 'mitchish65-2-gqa' of https://github.com/allenai/LLM int…
dirkgr Feb 24, 2024
37eaf9c
The flash attention flag is now about Tri Dao's flash attention.
dirkgr Feb 24, 2024
4cb93c4
Flash attention is unsafe in general, so it's off by default, and on …
dirkgr Feb 24, 2024
c3084c5
Turn on FA for MosaicML hardware
dirkgr Feb 24, 2024
4fccc21
update gantry image build
epwalsh Feb 26, 2024
54919e0
Make fused CE loss opt-in
epwalsh Feb 26, 2024
b66a599
Allow setting different data order seed
epwalsh Feb 26, 2024
f2331a8
Skip saving final checkpoint if already exists
epwalsh Feb 27, 2024
2292afd
update config
epwalsh Feb 27, 2024
21e288a
update config
epwalsh Feb 29, 2024
394bb22
fix merge conflicts
epwalsh Mar 1, 2024
f537966
Fix?
epwalsh Mar 1, 2024
951b7ca
Add type annotation
epwalsh Mar 1, 2024
ddf8467
Update install in mcli config
epwalsh Mar 1, 2024
05d29b8
Fix
epwalsh Mar 2, 2024
e80587a
try GQA directly with SDPA
epwalsh Mar 5, 2024
e9c6f20
revert
epwalsh Mar 5, 2024
66a99ff
Prepare for sequence length ablation
epwalsh Mar 7, 2024
f84bb62
Pick and choose eval sets
epwalsh Mar 7, 2024
1ae86c5
fix merge conflicts
epwalsh Mar 8, 2024
9556dd9
clean up
epwalsh Mar 8, 2024
1e2729c
Merge branch 'mitchish65-2-gqa' into epwalsh/mitchish65-2-gqa
epwalsh Mar 8, 2024
f60ae3a
Fix config
epwalsh Mar 8, 2024
40cf61c
don't set start method, this causes a crash sometimes
epwalsh Mar 9, 2024
5da1c7e
try setting start method again
epwalsh Mar 9, 2024
1e60b73
update config
epwalsh Mar 9, 2024
fb15986
Ensure we raise pickle-able errors with new-style checkpointing
epwalsh Mar 11, 2024
f2fdf99
Fix canceling run via W&B tags
epwalsh Mar 12, 2024
1394236
Force setting start method
epwalsh Mar 13, 2024
3897360
Always print the start method
epwalsh Mar 13, 2024
d625ccb
Create S3 clients before launching threads
epwalsh Mar 13, 2024
bc73854
Default to 1 thread with RemoteFileSystemWriter
epwalsh Mar 13, 2024
17b3316
Log start method
epwalsh Mar 13, 2024
058e20e
Update configs for longer seq length
epwalsh Mar 14, 2024
7ddb124
update default microbatch size
epwalsh Mar 14, 2024
671c852
clean up how we apply activation checkpointing to blocks
epwalsh Mar 14, 2024
439e18f
Add "two_in_three" and "three_in_four" strategies
epwalsh Mar 14, 2024
fba3dee
update configs with improved settings for throughput
epwalsh Mar 14, 2024
b99b93e
update how we set env vars
epwalsh Mar 15, 2024
6103643
welp, back to our own checkpointing
epwalsh Mar 15, 2024
9ed5965
Make `find_latest_checkpoint` more robust
epwalsh Mar 15, 2024
8133f58
clean up config
epwalsh Mar 15, 2024
026c26c
Merge pull request #492 from allenai/epwalsh/mitchish65-2-gqa
epwalsh Mar 15, 2024
ca0fe2b
update evals
epwalsh Mar 18, 2024
856860d
Revert metric name change
epwalsh Mar 18, 2024
b1e3855
update name and branch
epwalsh Mar 18, 2024
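
Two of the commits above ("Uses flash-attn's 'fused' CE loss function, when available (#465)" and "Make fused CE loss opt-in") introduce an optional fused cross-entropy path. A minimal sketch of that opt-in pattern, assuming the loss lives at flash_attn.losses.cross_entropy as in recent flash-attn releases; this is an illustration, not the repo's exact wiring:

```python
import torch.nn as nn

try:
    # Only importable when flash-attn is installed with its CUDA extensions.
    from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCELoss
except ImportError:
    FusedCELoss = None

def build_loss_fn(fused_loss: bool) -> nn.Module:
    # Opt-in: use the fused kernel only when explicitly requested AND available.
    if fused_loss and FusedCELoss is not None:
        return FusedCELoss(ignore_index=-100)
    return nn.CrossEntropyLoss(ignore_index=-100)
```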
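
Likewise, "Create S3 clients before launching threads" addresses a boto3 thread-safety pitfall: the default session is not safe for concurrent client creation, while an already-built client is safe to share across threads. A hedged sketch of the pattern (bucket and file names are hypothetical):

```python
import boto3
from concurrent.futures import ThreadPoolExecutor

# Create the client up front, in the main thread, before any workers start.
s3 = boto3.client("s3")

def upload(path: str, bucket: str, key: str) -> None:
    # boto3 clients are safe to *use* from multiple threads once constructed.
    s3.upload_file(path, bucket, key)

with ThreadPoolExecutor(max_workers=4) as pool:
    pool.submit(upload, "model.pt", "my-bucket", "checkpoints/model.pt")
```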
Files changed
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## Unreleased
 
+### Added
+
+- Added support for Grouped Query Attention.
+
 ### Changed
 
 - Rename `Olmo` to `OLMo` everywhere in the codebase
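
For reference, grouped query attention shares each key/value head across a group of query heads, shrinking the KV cache without collapsing all the way to multi-query attention. A minimal PyTorch sketch of the idea (an illustration, not OLMo's exact implementation):

```python
import torch
import torch.nn.functional as F

def gqa(q, k, v):
    # q: (batch, n_heads, seq, head_dim); k, v: (batch, n_kv_heads, seq, head_dim)
    n_heads, n_kv_heads = q.shape[1], k.shape[1]
    group = n_heads // n_kv_heads  # assumes n_heads % n_kv_heads == 0
    # Repeat each KV head so every group of query heads attends to its copy.
    k = k.repeat_interleave(group, dim=1)
    v = v.repeat_interleave(group, dim=1)
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

q = torch.randn(1, 32, 16, 64)   # 32 query heads
k = torch.randn(1, 8, 16, 64)    # 8 KV heads -> groups of 4
v = torch.randn(1, 8, 16, 64)
out = gqa(q, k, v)               # (1, 32, 16, 64)
```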
2 changes: 1 addition & 1 deletion Makefile
@@ -29,7 +29,7 @@ base-image :
 	docker build -f docker/Dockerfile.base -t $(IMAGE_NAME_BASE)-base .
 
 .PHONY : gantry-image
-gantry-image : base-image
+gantry-image :
 	docker build -f docker/Dockerfile.gantry -t $(IMAGE_NAME_BASE)-gantry .
 	beaker image create $(IMAGE_NAME_BASE)-gantry --name $(IMAGE_NAME_BASE)-gantry-tmp --workspace $(BEAKER_WORKSPACE)
 	beaker image delete $(GANTRY_IMAGE) || true
2 changes: 1 addition & 1 deletion configs/llama7-s3.yaml
@@ -13,7 +13,7 @@ model:
   n_layers: 32
   mlp_hidden_size: 22016
   rope: true
-  flash_attention: true
+  flash_attention: false
   attention_dropout: 0.0
   attention_layer_norm: false
   multi_query_attention: false
2 changes: 1 addition & 1 deletion configs/llama7.yaml
@@ -13,7 +13,7 @@ model:
   n_layers: 32
   mlp_hidden_size: 22016
   rope: true
-  flash_attention: true
+  flash_attention: false
   attention_dropout: 0.0
   attention_layer_norm: false
   multi_query_attention: false
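
These two configs flip flash_attention off by default, matching the commits that make flash-attn a lazy, optional dependency ("Lazy import of flash to make this work on Mac", "Fall back to torch when flash-attn not available"). A sketch of that import-and-fallback pattern, assuming flash-attn's public flash_attn_func and its (batch, seq, heads, dim) layout; not the repo's exact code:

```python
import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_func  # only importable with a CUDA build
except ImportError:
    flash_attn_func = None

def attend(q, k, v, dropout_p: float = 0.0):
    # q, k, v: (batch, seq_len, n_heads, head_dim)
    if flash_attn_func is not None and q.is_cuda:
        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=True)
    # Fall back to PyTorch SDPA, which wants (batch, n_heads, seq_len, head_dim).
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        dropout_p=dropout_p, is_causal=True,
    )
    return out.transpose(1, 2)
```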
1 change: 1 addition & 0 deletions configs/mcli/ananya-1b-ib.yaml
@@ -31,6 +31,7 @@ model:
   pad_token_id: 1
   init_device: meta
   init_fn: mitchell
+  flash_attention: true
 
 compile: null # causes instability on AMD GPUs
 
1 change: 1 addition & 0 deletions configs/mcli/ananya-1b.yaml
@@ -31,6 +31,7 @@ model:
   pad_token_id: 1
   init_device: meta
   init_fn: normal
+  flash_attention: true
 
 compile: null # causes instability on AMD GPUs
 
1 change: 1 addition & 0 deletions configs/mcli/mitchish-final.yaml
@@ -97,6 +97,7 @@ command: |-
     --save_interval_unsharded=10000 \
     --load_path=/root/checkpoint-unsharded \
     --compile=null \
+    --model.flash_attention=true \
     --activation_checkpointing=fine_grained \
     --fsdp.wrapping_strategy=size_based \
     --remote_save_folder=s3://ai2-llm/checkpoints/7b/${run_name} \
1 change: 1 addition & 0 deletions configs/mcli/mitchish-instruct.yml
@@ -98,6 +98,7 @@ command: |-
     --reset_trainer_state \
     --reset_optimizer_state \
     --compile=null \
+    --model.flash_attention=true \
     --activation_checkpointing=whole_layer \
     --fsdp.wrapping_strategy=size_based \
     --max_duration=5ep
1 change: 1 addition & 0 deletions configs/mcli/mitchish.yaml
@@ -37,6 +37,7 @@ command: |-
     --save_interval_unsharded=10000 \
     --load_path=${checkpoint} \
     --compile=null \
+    --model.flash_attention=true \
     --activation_checkpointing=fine_grained \
     --fsdp.wrapping_strategy=size_based \
     --remote_save_folder=s3://ai2-llm/checkpoints/7b/${run_name} \
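
The mcli configs above enable flash attention per-run with dotted CLI overrides rather than by editing the YAML. A sketch of how overrides like --model.flash_attention=true are typically merged with omegaconf; treat the exact entry point as an assumption rather than OLMo's actual loader:

```python
from omegaconf import OmegaConf

# Load the base YAML, then merge dotted-key overrides on top of it.
base = OmegaConf.load("configs/mitchish70-s3.yaml")
overrides = OmegaConf.from_dotlist([
    "model.flash_attention=true",
    "compile=null",
])
cfg = OmegaConf.merge(base, overrides)
print(cfg.model.flash_attention)  # True
```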
40 changes: 40 additions & 0 deletions configs/mcli/mitchish70.yaml
@@ -0,0 +1,40 @@
+name: olmo-70b
+image: mosaicml/pytorch:2.2.1_cu121-python3.11-ubuntu20.04
+compute:
+  cluster: r14z3p2
+  gpus: 256
+  gpu_type: h100_80gb
+integrations:
+  - integration_type: git_repo
+    git_repo: allenai/OLMo
+    git_branch: mitchish65-2-gqa
+    pip_install: -e .[train]
+    ssh_clone: true
+env_variables:
+  PIP_DISABLE_PIP_VERSION_CHECK: "1"
+  OMP_NUM_THREADS: "8"
+  LOG_FILTER_TYPE: local_rank0_only
+command: |-
+  # Make sure we have a recent flash-attn.
+  # NOTE: only pinning flash-attn here to future proof it.
+  pip install flash-attn==2.5.3 --no-build-isolation
+
+  # Show packages for debugging.
+  pip freeze
+
+  # Prepare environment.
+  cd OLMo
+  mkdir -p /root/.cache/torch
+
+  torchrun \
+    --master_addr "$MASTER_ADDR" \
+    --master_port "$MASTER_PORT" \
+    --nnodes "$NUM_NODES" \
+    --node_rank "$NODE_RANK" \
+    --nproc_per_node 8 \
+    scripts/train.py configs/mitchish70-s3.yaml \
+      --run_name=OLMo-70B-001 \
+      --wandb.group=mitchish70 \
+      --global_train_batch_size=768 \
+      --device_train_microbatch_size=3 \
+      --save_overwrite
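
A quick sanity check on the launch parameters above (assuming pure data parallelism with no tensor or pipeline parallelism):

```python
# 256 GPUs x microbatch 3 = 768 sequences per optimizer step, so
# global_train_batch_size=768 implies no gradient accumulation.
gpus = 256
device_train_microbatch_size = 3
global_train_batch_size = 768
grad_accum = global_train_batch_size // (gpus * device_train_microbatch_size)
assert grad_accum == 1
```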
1 change: 1 addition & 0 deletions configs/mcli/v1-mix-medium-mitch-ish.yaml
@@ -28,4 +28,5 @@ command: |-
     --nproc_per_node 8 \
     scripts/train.py configs/v1-mix-medium-mitch-ish-s3.yaml \
     --run_name=v1-mix-mitch-ish \
+    --model.flash_attention=true \
     --global_train_batch_size=2160
1 change: 1 addition & 0 deletions configs/mcli/v1-mix-medium.yaml
@@ -28,5 +28,6 @@ command: |-
     --nproc_per_node 8 \
     scripts/train.py configs/v1-mix-medium-s3.yaml \
     --run_name=v1-mix-medium \
+    --model.flash_attention=true \
     --scheduler.name=linear_with_warmup \
     --global_train_batch_size=2160
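
Several of these runs pass --scheduler.name=linear_with_warmup. A hedged sketch of what such a schedule computes; parameter names and the final-LR factor alpha_f are assumptions, not OLMo's exact scheduler:

```python
def linear_with_warmup(step: int, peak_lr: float, warmup: int, t_max: int,
                       alpha_f: float = 0.1) -> float:
    # Ramp linearly from 0 to peak_lr over `warmup` steps...
    if step < warmup:
        return peak_lr * step / warmup
    # ...then decay linearly toward alpha_f * peak_lr by step t_max.
    frac = min(1.0, (step - warmup) / max(1, t_max - warmup))
    return peak_lr * ((1 - frac) + frac * alpha_f)
```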
1 change: 1 addition & 0 deletions configs/mcli/v1_5-mix-medium-mitch-ish.yaml
@@ -31,6 +31,7 @@ command: |-
     --run_name=v1_5-mix-mitch-ish \
     --wandb.name=v1_5-mix-mitch-ish-mcli-final \
     --global_train_batch_size=2160 \
+    --model.flash_attention=true \
     --time_limit=169200
 
   # We added these flags in order to get a final checkpoint where we decayed the LR down to 0.
1 change: 1 addition & 0 deletions configs/mcli/v1_5-mix-medium.yaml
@@ -29,4 +29,5 @@ command: |-
     scripts/train.py configs/v1_5-mix-medium-s3.yaml \
     --run_name=v1_5-mix-mcli \
     --scheduler.name=linear_with_warmup \
+    --model.flash_attention=true \
     --global_train_batch_size=2160
2 changes: 1 addition & 1 deletion configs/mitchish35.yaml
@@ -16,7 +16,7 @@ model:
   weight_tying: false
   alibi: false
   rope: true
-  flash_attention: true
+  flash_attention: false
   attention_dropout: 0.0
   attention_layer_norm: false
   multi_query_attention: false
2 changes: 1 addition & 1 deletion configs/mitchish50.yaml
@@ -16,7 +16,7 @@ model:
   weight_tying: false
   alibi: false
   rope: true
-  flash_attention: true
+  flash_attention: false
   attention_dropout: 0.0
   attention_layer_norm: false
   multi_query_attention: false