Merge pull request #629 from allenai/epwalsh/amberish
Amberish runs
epwalsh authored Jul 25, 2024
2 parents 4e00460 + 8b7afd0 commit 3e30710
Showing 38 changed files with 9,454 additions and 11 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- Added support for document masking via flash-attn during training with `--data.generate_doc_lengths`.
- Added config options for `model.norm_after`, `model.scale_emb_init`, and `auxiliary_loss_multiplier` (used with zloss).
- Added scripts for running experiments on qk_norm, norm reordering, and zloss.
- Added `model.rope_theta` configuration option.
- Added `model.embedding_layer_norm` configuration option for adding a LN to the embeddings.
- Added `model.emb_init_std` configuration option to override the standard deviation used to initialize the embeddings.

### Changed

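The three new `ModelConfig` options announced in the changelog are exercised together in the configs and scripts below. A minimal sketch of how they might be set programmatically, assuming the structured config accepts these fields as constructor arguments (as the dataclass definitions in `olmo/config.py` suggest); all values here are illustrative:

```python
from olmo.config import ModelConfig

# Illustrative values only; defaults come from olmo/config.py below.
cfg = ModelConfig(
    d_model=2048,
    n_heads=16,
    rope=True,
    rope_theta=500_000,          # new: base period for RoPE (default 10_000)
    embedding_layer_norm=True,   # new: apply a LayerNorm to the embeddings
    emb_init_std=0.02,           # new: override the embedding init std
)
```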
1,297 changes: 1,297 additions & 0 deletions configs/amberish1-weka.yaml


1,293 changes: 1,293 additions & 0 deletions configs/amberish13-weka.yaml


1,293 changes: 1,293 additions & 0 deletions configs/amberish7-weka.yaml


1,294 changes: 1,294 additions & 0 deletions configs/amberish70-weka.yaml


1,383 changes: 1,383 additions & 0 deletions configs/peteish1-weka.yaml


1,382 changes: 1,382 additions & 0 deletions configs/peteish7-weka.yaml


26 changes: 22 additions & 4 deletions olmo/config.py
@@ -315,6 +315,11 @@ class ModelConfig(BaseConfig):
apply RoPE at the precision of the input.
"""

rope_theta: int = 10_000
"""
The theta setting for RoPE.
"""

flash_attention: bool = False
"""
If ``True``, use ``FlashAttention``.
@@ -346,6 +351,11 @@
The dropout probability for embeddings.
"""

embedding_layer_norm: bool = False
"""
Apply layer norm directly to the embeddings.
"""

layer_norm_type: LayerNormType = LayerNormType.default
"""
The layernorm implementation to use.
@@ -449,7 +459,13 @@

scale_emb_init: bool = False
"""
If ``True``, embeddings are scaled up by ``sqrt(d_model)`` during initialization. To be used with `full_megatron` init.
If ``True``, embeddings are scaled up by ``sqrt(d_model)`` during initialization.
Currently this is only used with `full_megatron` init when ``emb_init_std`` is unset.
"""

emb_init_std: Optional[float] = None
"""
Override the standard deviation to use when initializing the embedding weights.
"""

norm_after: bool = False
@@ -791,7 +807,7 @@ class FSDPConfig(BaseConfig):
FSDP instance.
"""

precision: FSDPPrecision = FSDPPrecision.pure
precision: Optional[FSDPPrecision] = FSDPPrecision.pure

hybrid_sharding_num_model_replicas: Optional[int] = None
"""
@@ -1213,9 +1229,11 @@ def autocast_precision(self) -> torch.dtype:
raise ValueError(f"Unexpected precision type '{self.precision}'")

@property
def fsdp_precision(self) -> MixedPrecision:
def fsdp_precision(self) -> Optional[MixedPrecision]:
if self.fsdp is not None:
if self.fsdp.precision == FSDPPrecision.pure:
if self.fsdp.precision is None:
return None
elif self.fsdp.precision == FSDPPrecision.pure:
return MixedPrecision(
param_dtype=self.autocast_precision,
reduce_dtype=self.autocast_precision,
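With `precision` now `Optional`, setting it to `None` disables FSDP mixed precision entirely instead of the property unconditionally returning a `MixedPrecision` policy. A condensed sketch of the `pure` branch shown above (the other `FSDPPrecision` variants are folded out of this diff and omitted here):

```python
from typing import Optional

import torch
from torch.distributed.fsdp import MixedPrecision

def fsdp_precision_sketch(
    precision: Optional[str], autocast_dtype: torch.dtype
) -> Optional[MixedPrecision]:
    if precision is None:
        # New behavior: no mixed-precision policy at all.
        return None
    if precision == "pure":
        # Params, gradient reduction, and buffers all use the autocast dtype.
        return MixedPrecision(
            param_dtype=autocast_dtype,
            reduce_dtype=autocast_dtype,
            buffer_dtype=autocast_dtype,
        )
    raise ValueError(f"Unexpected precision type '{precision}'")
```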
22 changes: 16 additions & 6 deletions olmo/model.py
@@ -277,7 +279,9 @@ def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:

with torch.autocast(device.type, enabled=False):
dim = self.config.d_model // self.config.n_heads
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
inv_freq = 1.0 / (
self.config.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim)
)
seq = torch.arange(seq_len, device=device, dtype=torch.float)
freqs = einsum("i , j -> i j", seq, inv_freq)
positions = torch.cat((freqs, freqs), dim=-1)
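`rope_theta` is the base of the geometric progression of rotary frequencies; raising it stretches the longest wavelengths, which is the usual lever for extending context length (the `rtheta` run below sets it to 500,000 at 8k sequence length). A standalone sketch of the computation above:

```python
import torch

def rotary_inv_freq(d_model: int, n_heads: int, rope_theta: float, device=None) -> torch.Tensor:
    # Per-head dimension; RoPE rotates pairs of channels, hence the step of 2.
    dim = d_model // n_heads
    # Geometric progression of inverse frequencies with base `rope_theta`.
    return 1.0 / (rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))

# Larger theta -> slower-varying (longer-wavelength) rotations, which is why
# the 8k-context "rtheta" run below overrides --model.rope_theta=500000.
low = rotary_inv_freq(2048, 16, 10_000)
high = rotary_inv_freq(2048, 16, 500_000)
assert high[-1] < low[-1]  # the highest-index frequency shrinks as theta grows
```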
@@ -535,7 +537,6 @@ def _scaled_dot_product_attention(
if max_doc_len is not None and cu_doc_lens is not None:
assert self.flash_attn_varlen_func is not None, "flash-attn is required for document masking"
assert attn_mask is None, "attn-mask is currently not supported with document masking"
assert self.training, "document masking is only supported for training, not inference"
B, T, D = q.size(0), q.size(2), q.size(3)
r = self.flash_attn_varlen_func(
q.transpose(1, 2).view(B * T, -1, D),
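This hunk also drops the `assert self.training` guard, so document masking now works at inference time as well as during training. `flash_attn_varlen_func` consumes cumulative document boundaries; a sketch of how such a tensor can be built (the helper name is ours — the real plumbing sits behind `--data.generate_doc_lengths` in the data pipeline):

```python
import torch
import torch.nn.functional as F

def cu_doc_lens_from_lengths(doc_lens: torch.Tensor) -> torch.Tensor:
    # flash_attn_varlen_func expects cumulative sequence boundaries as int32,
    # starting at 0, e.g. docs of lengths [3, 5, 2] -> [0, 3, 8, 10].
    return F.pad(torch.cumsum(doc_lens, dim=0, dtype=torch.int32), (1, 0))

cu_doc_lens = cu_doc_lens_from_lengths(torch.tensor([3, 5, 2]))
max_doc_len = 5  # passed alongside cu_doc_lens as max_seqlen_{q,k}
```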
@@ -1121,6 +1122,9 @@ def __init__(self, config: ModelConfig, init_params: bool = True):
)
}
)
if config.embedding_layer_norm:
self.transformer.update({"emb_norm": LayerNorm.build(config)})

# When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
if init_params and self.config.init_device != "meta":
self.reset_parameters()
@@ -1157,14 +1161,16 @@ def reset_parameters(self):
# Note: We may potentially want to multiply the std by a factor of sqrt(d) in case of `scale_logits`
# and `weight_tying`. However, we are currently not using either, and may need to rethink the init logic
# if/when we do want it.
wte_std = self.config.init_std
wte_std = self.config.emb_init_std or self.config.init_std
wte_cutoff_factor = self.config.init_cutoff_factor
elif self.config.init_fn == InitFnType.mitchell:
wte_std = 1.0 / math.sqrt(self.config.d_model)
wte_std = self.config.emb_init_std or 1.0 / math.sqrt(self.config.d_model)
wte_cutoff_factor = self.config.init_cutoff_factor or 3.0
elif self.config.init_fn == InitFnType.full_megatron:
wte_std = self.config.init_std
if self.config.scale_emb_init:
if self.config.emb_init_std is not None:
wte_std = self.config.emb_init_std
elif self.config.scale_emb_init:
wte_std *= math.sqrt(self.config.d_model)
wte_cutoff_factor = self.config.init_cutoff_factor or 3.0
else:
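The initialization precedence after this change: an explicit `emb_init_std` always wins, otherwise each init function keeps its previous default, with `scale_emb_init` still applying only to `full_megatron`. Restated as a self-contained sketch (function and argument names are illustrative):

```python
import math

def embedding_init_std(init_fn: str, init_std: float, d_model: int,
                       emb_init_std=None, scale_emb_init=False) -> float:
    # Condensed restatement of the precedence in reset_parameters() above:
    # an explicit emb_init_std always wins.
    if emb_init_std is not None:
        return emb_init_std
    if init_fn == "mitchell":
        return 1.0 / math.sqrt(d_model)
    if init_fn == "full_megatron" and scale_emb_init:
        return init_std * math.sqrt(d_model)
    return init_std  # "normal" and unscaled full_megatron fall through here
```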
Expand Down Expand Up @@ -1294,6 +1300,10 @@ def forward(
# shape: (batch_size, seq_len, d_model)
x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore

# Apply embedding layer norm.
if self.config.embedding_layer_norm:
x = self.transformer.emb_norm(x)

if not (self.config.alibi or self.config.rope):
# Get positional embeddings.
# shape: (1, seq_len)
@@ -1302,7 +1312,7 @@ def forward(
pos_emb = self.transformer.wpe(pos) # type: ignore
x = pos_emb + x

# Add input + positional embeddings and apply dropout.
# Apply dropout.
# shape: (batch_size, seq_len, d_model)
x = self.transformer.emb_drop(x) # type: ignore

40 changes: 40 additions & 0 deletions scripts/beaker/amberish/amberish1-8k-cham-launch.sh
@@ -0,0 +1,40 @@
#!/usr/bin/env bash

set -ex

NUM_NODES=16

gantry run \
--workspace ai2/OLMo-pretraining-stability \
--task-name amberish1-8k-cham \
--description "Amberish 1B with 8k context length and chameleon fixes" \
--priority urgent \
--preemptible \
--beaker-image petew/olmo-torch23-gantry \
--cluster ai2/jupiter-cirrascale-2 \
--gpus 8 \
--replicas "${NUM_NODES}" \
--leader-selection \
--host-networking \
--budget ai2/oe-training \
--no-nfs \
--weka oe-training-default:/weka/oe-training-default \
--propagate-failure \
--propagate-preemption \
--synchronized-start-timeout 90m \
--no-python \
--env LOG_FILTER_TYPE=local_rank0_only \
--env OMP_NUM_THREADS=8 \
--env OLMO_TASK=model \
--env R2_PROFILE=R2 \
--env S3_PROFILE=S3 \
--env WEKA_PROFILE=WEKA \
--env-secret AWS_CONFIG=PETEW_AWS_CONFIG \
--env-secret AWS_CREDENTIALS=PETEW_AWS_CREDENTIALS \
--env-secret R2_ENDPOINT_URL=R2_ENDPOINT_URL \
--env-secret WEKA_ENDPOINT_URL=WEKA_ENDPOINT_URL \
--env-secret WANDB_API_KEY=PETEW_WANDB_API_KEY \
--shared-memory 10GiB \
--yes \
--timeout=-1 \
-- /bin/bash -c "scripts/beaker/amberish/amberish1-8k-cham.sh \$BEAKER_LEADER_REPLICA_HOSTNAME ${NUM_NODES} \$BEAKER_REPLICA_RANK"
64 changes: 64 additions & 0 deletions scripts/beaker/amberish/amberish1-8k-cham.sh
@@ -0,0 +1,64 @@
#!/usr/bin/env bash

set -exuo pipefail
IFS=$'\n\t'

BEAKER_LEADER_REPLICA_HOSTNAME=$1
shift

NUM_NODES=$1
shift

BEAKER_REPLICA_RANK=$1
shift

# Setup Python environment.
conda shell.bash activate base

# Install flash-attn
#conda install -y -c nvidia cuda-python
pip install packaging ninja
export FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE
pip install flash-attn==2.5.9.post1 --no-build-isolation
# pip install awscli
pip install '.[train]'
pip freeze

# Move AWS credentials from env to relevant files
mkdir -p ~/.aws
printenv AWS_CONFIG > ~/.aws/config
printenv AWS_CREDENTIALS > ~/.aws/credentials

# Force processes to synchronize at init_process_group
export TORCH_DIST_INIT_BARRIER=1

# Tell OLMo all ranks share the same filesystem for checkpoints.
export OLMO_SHARED_FS=1

export NCCL_DEBUG=INFO
export NCCL_IB_HCA="^=mlx5_bond_0"
export NCCL_SOCKET_IFNAME=ib
# export NCCL_IB_GID_INDEX=0

torchrun \
--nnodes "${NUM_NODES}:${NUM_NODES}" \
--nproc-per-node 8 \
--rdzv_id 12347 \
--rdzv_backend static \
--rdzv_endpoint "${BEAKER_LEADER_REPLICA_HOSTNAME}:29400" \
--node_rank "${BEAKER_REPLICA_RANK}" \
--rdzv_conf 'read_timeout=420' \
scripts/train.py \
configs/amberish1-weka.yaml \
--run_name="${GANTRY_TASK_NAME}" \
--model.max_sequence_length=8192 \
--device_train_microbatch_size=2 \
--global_train_batch_size=512 \
--fused_loss=true \
--softmax_auxiliary_loss=true \
--auxiliary_loss_multiplier=1e-5 \
--model.attention_layer_norm=true \
--model.norm_after=true \
--save_overwrite

# '--load_path=${path.last_checkpoint:${save_folder}}' \
40 changes: 40 additions & 0 deletions scripts/beaker/amberish/amberish1-8k-doc-mask-cham-launch.sh
@@ -0,0 +1,40 @@
#!/usr/bin/env bash

set -ex

NUM_NODES=16

gantry run \
--workspace ai2/OLMo-pretraining-stability \
--task-name amberish1-8k-doc-mask-cham \
--description "Amberish 1B with 8k context length, doc masking, and chameleon fixes" \
--priority urgent \
--preemptible \
--beaker-image petew/olmo-torch23-gantry \
--cluster ai2/jupiter-cirrascale-2 \
--gpus 8 \
--replicas "${NUM_NODES}" \
--leader-selection \
--host-networking \
--budget ai2/oe-training \
--no-nfs \
--weka oe-training-default:/weka/oe-training-default \
--propagate-failure \
--propagate-preemption \
--synchronized-start-timeout 90m \
--no-python \
--env LOG_FILTER_TYPE=local_rank0_only \
--env OMP_NUM_THREADS=8 \
--env OLMO_TASK=model \
--env R2_PROFILE=R2 \
--env S3_PROFILE=S3 \
--env WEKA_PROFILE=WEKA \
--env-secret AWS_CONFIG=PETEW_AWS_CONFIG \
--env-secret AWS_CREDENTIALS=PETEW_AWS_CREDENTIALS \
--env-secret R2_ENDPOINT_URL=R2_ENDPOINT_URL \
--env-secret WEKA_ENDPOINT_URL=WEKA_ENDPOINT_URL \
--env-secret WANDB_API_KEY=PETEW_WANDB_API_KEY \
--shared-memory 10GiB \
--yes \
--timeout=-1 \
-- /bin/bash -c "scripts/beaker/amberish/amberish1-8k-doc-mask-cham.sh \$BEAKER_LEADER_REPLICA_HOSTNAME ${NUM_NODES} \$BEAKER_REPLICA_RANK"
40 changes: 40 additions & 0 deletions scripts/beaker/amberish/amberish1-8k-doc-mask-cham-rtheta-launch.sh
@@ -0,0 +1,40 @@
#!/usr/bin/env bash

set -ex

NUM_NODES=16

gantry run \
--workspace ai2/OLMo-pretraining-stability \
--task-name amberish1-8k-doc-mask-cham-rtheta \
--description "Amberish 1B with 8k context length, doc masking, and chameleon fixes" \
--priority urgent \
--preemptible \
--beaker-image petew/olmo-torch23-gantry \
--cluster ai2/jupiter-cirrascale-2 \
--gpus 8 \
--replicas "${NUM_NODES}" \
--leader-selection \
--host-networking \
--budget ai2/oe-training \
--no-nfs \
--weka oe-training-default:/weka/oe-training-default \
--propagate-failure \
--propagate-preemption \
--synchronized-start-timeout 90m \
--no-python \
--env LOG_FILTER_TYPE=local_rank0_only \
--env OMP_NUM_THREADS=8 \
--env OLMO_TASK=model \
--env R2_PROFILE=R2 \
--env S3_PROFILE=S3 \
--env WEKA_PROFILE=WEKA \
--env-secret AWS_CONFIG=PETEW_AWS_CONFIG \
--env-secret AWS_CREDENTIALS=PETEW_AWS_CREDENTIALS \
--env-secret R2_ENDPOINT_URL=R2_ENDPOINT_URL \
--env-secret WEKA_ENDPOINT_URL=WEKA_ENDPOINT_URL \
--env-secret WANDB_API_KEY=PETEW_WANDB_API_KEY \
--shared-memory 10GiB \
--yes \
--timeout=-1 \
-- /bin/bash -c "scripts/beaker/amberish/amberish1-8k-doc-mask-cham-rtheta.sh \$BEAKER_LEADER_REPLICA_HOSTNAME ${NUM_NODES} \$BEAKER_REPLICA_RANK"
66 changes: 66 additions & 0 deletions scripts/beaker/amberish/amberish1-8k-doc-mask-cham-rtheta.sh
@@ -0,0 +1,66 @@
#!/usr/bin/env bash

set -exuo pipefail
IFS=$'\n\t'

BEAKER_LEADER_REPLICA_HOSTNAME=$1
shift

NUM_NODES=$1
shift

BEAKER_REPLICA_RANK=$1
shift

# Setup Python environment.
conda shell.bash activate base

# Install flash-attn
#conda install -y -c nvidia cuda-python
pip install packaging ninja
export FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE
pip install flash-attn==2.5.9.post1 --no-build-isolation
# pip install awscli
pip install '.[train]'
pip freeze

# Move AWS credentials from env to relevant files
mkdir -p ~/.aws
printenv AWS_CONFIG > ~/.aws/config
printenv AWS_CREDENTIALS > ~/.aws/credentials

# Force processes to synchronize at init_process_group
export TORCH_DIST_INIT_BARRIER=1

# Tell OLMo all ranks share the same filesystem for checkpoints.
export OLMO_SHARED_FS=1

export NCCL_DEBUG=INFO
export NCCL_IB_HCA="^=mlx5_bond_0"
export NCCL_SOCKET_IFNAME=ib
# export NCCL_IB_GID_INDEX=0

torchrun \
--nnodes "${NUM_NODES}:${NUM_NODES}" \
--nproc-per-node 8 \
--rdzv_id 12347 \
--rdzv_backend static \
--rdzv_endpoint "${BEAKER_LEADER_REPLICA_HOSTNAME}:29400" \
--node_rank "${BEAKER_REPLICA_RANK}" \
--rdzv_conf 'read_timeout=420' \
scripts/train.py \
configs/amberish1-weka.yaml \
--run_name="${GANTRY_TASK_NAME}" \
--model.max_sequence_length=8192 \
--device_train_microbatch_size=2 \
--global_train_batch_size=512 \
--fused_loss=true \
--data.generate_doc_lengths=true \
--softmax_auxiliary_loss=true \
--auxiliary_loss_multiplier=1e-5 \
--model.attention_layer_norm=true \
--model.norm_after=true \
--model.rope_theta=500000 \
--save_overwrite

# '--load_path=${path.last_checkpoint:${save_folder}}' \