Commit 299aafd

Merge branch 'main' into epwalsh/tulu-fine-tune

epwalsh authored Jan 10, 2024
2 parents: 2d16b0e + a2e1d13
Showing 5 changed files with 29 additions and 9 deletions.
configs/v1_5-mix-medium-mitch-ish-s3.yaml (3 additions, 1 deletion)

@@ -53,6 +53,8 @@ scheduler:
   name: linear_with_warmup
   t_warmup: 5000
   alpha_f: 0.1
+  grad_clip_warmup_steps: 1000
+  grad_clip_warmup_factor: 10.0

 tokenizer:
   identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json
@@ -70,7 +72,7 @@ save_num_unsharded_checkpoints_to_keep: -1

 load_path: null

-max_duration: 476837  # 2T tokens
+max_duration: 2e12T  # 2T tokens
 global_train_batch_size: 2048
 device_train_microbatch_size: 2
 time_limit: null
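Note: the two new scheduler keys are not explained in this commit. Judging by their names, they loosen gradient clipping during the first steps of training. A minimal sketch of that interpretation follows; the `effective_max_grad_norm` helper is hypothetical and the real logic lives in the repository's scheduler/optimizer code.

def effective_max_grad_norm(step: int, max_grad_norm: float,
                            grad_clip_warmup_steps: int = 1000,
                            grad_clip_warmup_factor: float = 10.0) -> float:
    # Assumed semantics: for the first `grad_clip_warmup_steps` steps, clip at a
    # threshold `grad_clip_warmup_factor` times larger than the configured norm,
    # then fall back to the normal threshold.
    if step < grad_clip_warmup_steps:
        return max_grad_norm * grad_clip_warmup_factor
    return max_grad_norm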
configs/v1_5-mix-medium-mitch-ish.yaml (3 additions, 1 deletion)

@@ -52,6 +52,8 @@ scheduler:
   name: linear_with_warmup
   t_warmup: 5000
   alpha_f: 0.1
+  grad_clip_warmup_steps: 1000
+  grad_clip_warmup_factor: 10.0

 tokenizer:
   identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json
@@ -68,7 +70,7 @@ save_num_unsharded_checkpoints_to_keep: -1

 load_path: null

-max_duration: 476837  # 2T tokens
+max_duration: 2e12T  # 2T tokens
 global_train_batch_size: 2048
 device_train_microbatch_size: 2

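Note: the old and new `max_duration` values describe roughly the same amount of training. The step count was chosen to reach about 2T tokens, while `2e12T` states the token budget directly. A quick arithmetic check, assuming the 2048-sequence global batch shown above and 2048-token sequences (the sequence length is not shown in this diff):

steps = 476_837
global_train_batch_size = 2048   # sequences per optimizer step (from the config)
sequence_length = 2048           # assumed tokens per sequence
tokens = steps * global_train_batch_size * sequence_length
print(tokens)                    # 1_999_999_336_448, i.e. about 2e12 tokens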
olmo/checkpoint.py (9 additions, 3 deletions)

@@ -686,7 +686,7 @@ def _make_optim_state_dict_compatible(
     # This state dict comes in two forms: one where the state keys are integers and one where the
     # keys are fully qualified parameter names. The latter case is easier to deal with here so we
     # first transform the integer key form into the FQN key form.
-    if isinstance(next(iter(optim_state_dict["state"].keys())), int):
+    if isinstance(optim_state_dict["param_groups"][0]["params"][0], int):
         id_to_fqn: Dict[int, str] = {}
         for group in optim_state_dict["param_groups"]:
             new_param_names = []
@@ -712,7 +712,9 @@ def _make_optim_state_dict_compatible(
     # Now we can transform the state dict by renaming parameters according to `og_keys_to_new`.
     # First fix param names in the state.
     for og_key, new_keys in og_keys_to_new.items():
-        og_state = optim_state_dict["state"].pop(og_key)
+        og_state = optim_state_dict["state"].pop(og_key, None)
+        if og_state is None:
+            continue
         for i, new_key in enumerate(new_keys):
             if i == len(new_keys) - 1:
                 optim_state_dict["state"][new_key] = og_state
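Note: one plausible reason for switching the integer-key check from `state` to `param_groups` is that a freshly constructed optimizer (or one whose state was pruned) has an empty `state` dict, so peeking at its first key fails, while `param_groups` always lists every parameter. A small standalone illustration with stock PyTorch:

import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters())
sd = opt.state_dict()

print(sd["state"])                      # {} -- empty before the first optimizer step
print(sd["param_groups"][0]["params"])  # [0, 1] -- integer parameter ids are always present
# The old check, next(iter(sd["state"].keys())), would raise StopIteration here,
# and the new `.pop(og_key, None)` guard likewise tolerates missing state entries.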
@@ -1117,7 +1119,11 @@ def _fsdp_handles(self, fsdp_model: FSDP) -> List[FlatParamHandle]:
         if version.parse(torch.__version__) < version.parse("2.1.0"):
             return fsdp_model._handles  # type: ignore
         elif version.parse(torch.__version__) < version.parse("2.2.0"):
-            return [fsdp_model._handle]  # type: ignore
+            # Handle could be None if the FSDP wrapper doesn't manage any parameters.
+            if hasattr(fsdp_model, "_handle") and fsdp_model._handle is not None:
+                return [fsdp_model._handle]  # type: ignore
+            else:
+                return []
         else:
             # Need to verify FSDP internals with newer versions.
             raise NotImplementedError
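Note: the torch 2.1 branch now tolerates a wrapper whose `_handle` attribute is missing or `None` (an FSDP instance that manages no parameters itself). A toy illustration of the same defensive pattern, using stand-in objects rather than real FSDP wrappers:

from types import SimpleNamespace

def handles_for(wrapper):
    # Same guard as the change above: fall back to an empty list when the
    # wrapper exposes no flat-parameter handle.
    handle = getattr(wrapper, "_handle", None)
    return [handle] if handle is not None else []

print(handles_for(SimpleNamespace(_handle=None)))        # []
print(handles_for(SimpleNamespace(_handle="handle-0")))  # ['handle-0']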
scripts/lumi/lumi-interactive.sh (1 addition, 1 deletion)

@@ -8,4 +8,4 @@ singularity shell \
   -B /opt/cray:/opt/cray \
   -B /usr/lib64/libcxi.so.1:/usr/lib64/libcxi.so.1 \
   -B /usr/lib64/libjson-c.so.3:/usr/lib64/libjson-c.so.3 \
-  $PROJECT_DIR/containers/llm-lumi_latest.sif
+  $PROJECT_DIR/containers/llm-lumi-torch21_latest.sif
scripts/lumi/mitch-ish-7b.sh (13 additions, 3 deletions)

@@ -2,7 +2,7 @@
 #SBATCH --job-name=v1.5-mix-medium-mitch-ish
 #SBATCH --account=project_462000229
 #SBATCH --output=/pfs/lustref1/flash/project_462000229/logs/%j.log
-#SBATCH --nodes=256              # Total number of nodes
+#SBATCH --nodes=128              # Total number of nodes
 #SBATCH --ntasks-per-node=8
 #SBATCH --gpus-per-node=8        # Allocate one gpu per MPI rank
 #SBATCH --cpus-per-task=6
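Note: with the allocation halved to 128 nodes, the job runs on 128 x 8 = 1024 GPUs. Combined with the batch settings from the config diff above (and assuming they are not overridden elsewhere), each rank processes exactly one micro-batch per optimizer step:

nodes, gpus_per_node = 128, 8
world_size = nodes * gpus_per_node                 # 1024 ranks
global_train_batch_size = 2048                     # sequences per optimizer step (from the config)
device_train_microbatch_size = 2                   # sequences per rank per forward/backward
per_rank = global_train_batch_size // world_size   # 2 sequences per rank per step
print(world_size, per_rank // device_train_microbatch_size)  # 1024 1 -> no gradient accumulation needed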
@@ -25,6 +25,7 @@ export MIOPEN_CUSTOM_CACHE_DIR=${MIOPEN_USER_DB_PATH}
 export CXI_FORK_SAFE=1
 export CXI_FORK_SAFE_HP=1
 export FI_CXI_DISABLE_CQ_HUGETLB=1
+export GPU_MAX_HW_QUEUES=8

 # We need to set this to avoid "Cassini Event Queue overflow detected." errors.
 export FI_CXI_DEFAULT_CQ_SIZE=131072
@@ -37,6 +38,8 @@ export SINGULARITYENV_LD_LIBRARY_PATH=/usr/local/lib:/opt/cray/libfabric/1.15.2.
 # Try playing with max_split_size_mb if you run into OOM errors.
 #export PYTORCH_HIP_ALLOC_CONF=max_split_size_mb:128

+export HF_DATASETS_OFFLINE=1
+
 export DATA_PATH=$FLASH_DIR/preprocessed/olmo-mix
 export CHECKPOINTS_PATH=$FLASH_DIR/checkpoints
 export EVAL_DATA_PATH=$SCRATCH_DIR/eval-data
@@ -56,5 +59,12 @@ srun \
   $PROJECT_DIR/containers/$OLMO_CONTAINER \
   python scripts/train.py configs/v1_5-mix-medium-mitch-ish.yaml ${@} \
     --run_name=${SLURM_JOB_ID} \
-    --global_train_batch_size=4096 \
-    --max_duration=238418
+    --activation_checkpointing=fine_grained \
+    --fsdp.wrapping_strategy=one_in_four \
+    --fsdp.sharding_strategy=FULL_SHARD \
+    --sharded_checkpointer=local \
+    --wandb.name=v1_5-mix-mitch-ish-lumi \
+    --save_interval=10000 \
+    --save_interval_ephemeral=1000 \
+    --remote_save_folder=s3://ai2-llm/checkpoints/7b/mitchish-lumi \
+    --save_folder=${FLASH_DIR}/checkpoints/mitchish-lumi
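Note: the dotted flags passed to `scripts/train.py` (for example `--fsdp.wrapping_strategy=one_in_four`) override nested keys in the YAML config. A minimal sketch of how such overrides can be merged, assuming an OmegaConf-style config like the ones in `configs/`; this is an illustration, not the script's actual argument handling:

from omegaconf import OmegaConf

base = OmegaConf.create({
    "fsdp": {"wrapping_strategy": None, "sharding_strategy": "FULL_SHARD"},
    "save_interval": 1000,
})
overrides = OmegaConf.from_dotlist([
    "fsdp.wrapping_strategy=one_in_four",
    "save_interval=10000",
])
cfg = OmegaConf.merge(base, overrides)
print(cfg.fsdp.wrapping_strategy, cfg.save_interval)  # one_in_four 10000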
