Add truncated-normal llama-style model init via reset_parameters() (#54)
This PR adds the following:

1 - A full layerwise init for the llama models under /llama, applied via reset_parameters(). The per-layer init std scales with the total model depth:
self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5

2 - The final output projection (head) is initialized with a std equal to the inverse square root of the model dim (dim ** -0.5) and a slightly wider cutoff factor of 3. (A short sketch of both rules follows this list.)

3 - Tangential change: updates run_llama_train.sh with MODEL and MODEL_CONF params to allow direct model control via the sh script. (There was already a MODEL variable, but it was incorrectly being used in place of MODEL_CONF; we should revisit this naming since it's not intuitive.)

4 - Made the debugmodel default to 2 layers as an improved debug check.

5 - Added 1B and 40B configs for additional testing. I can't currently run 70B on my H100 due to OOM, but I can run 40B.
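
For a quick feel for the numbers, here is a minimal sketch of the two init rules from points 1 and 2 (illustrative only: the helper names are made up for this sketch, and the actual implementation is the reset_parameters() code added to torchtrain/models/llama/model.py below):

```python
import torch.nn as nn

def init_block_out_projection(linear: nn.Linear, n_layers: int) -> None:
    # depth-scaled std used for the block output projections (wo / w2 / w3 in the diff)
    std = 0.02 / (2 * n_layers) ** 0.5
    nn.init.trunc_normal_(linear.weight, mean=0.0, std=std)

def init_output_head(head: nn.Linear, model_dim: int) -> None:
    # final head: std = dim ** -0.5, with explicit cutoffs at +/- 3 std
    std = model_dim**-0.5
    nn.init.trunc_normal_(head.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)

# e.g. for the 32-layer / 4096-dim 7B config:
#   block output std = 0.02 / 8 = 0.0025, head std = 4096 ** -0.5 ~= 0.0156
```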

Testing:
Verified proper init and training with 7B, 13B and ~40B:

<img width="1085" alt="Screenshot 2024-02-11 at 10 39 12 PM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/049037ed-63a4-4ab0-bebc-f297857aab72">

[ghstack-poisoned]
H-Huang committed Aug 20, 2024
1 parent 4119639 commit 9e975f3
Showing 4 changed files with 73 additions and 9 deletions.
9 changes: 6 additions & 3 deletions run_llama_train.sh
@@ -8,7 +8,8 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain}
 # e.g.
 # LOG_RANK=0,1 NGPU=4 SP=2 ./run_llama_train.sh
 
-MODEL=${MODEL:-"debugmodel"}
+MODEL=${MODEL:-"llama"}
+MODEL_CONF=${MODEL_CONF:-"debugmodel"}
 NGPU=${NGPU:-"8"}
 PP=${PP:-"1"}
 SP=${SP:-"1"}
@@ -24,6 +25,8 @@ CHECKPOINT_INTERVAL=${CHECKPOINT_INTERVAL:-5}
 
 torchrun --nproc_per_node=${NGPU} \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-train.py --steps 10 --compile \
---pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP}
+train.py --steps 10 \
+--model ${MODEL} --model_conf ${MODEL_CONF} \
+--pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP} \
+--compile
 --checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}
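
With this change the model family and size can be picked at launch time, e.g. `MODEL=llama MODEL_CONF=13B NGPU=8 ./run_llama_train.sh` (an illustrative invocation, using the 13B config registered in torchtrain/models/llama/__init__.py below).
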
4 changes: 3 additions & 1 deletion torchtrain/models/llama/__init__.py
@@ -6,9 +6,11 @@
 __all__ = ["Transformer"]
 
 llama_configs = {
-    "debugmodel": ModelArgs(dim=256, n_layers=1, n_heads=16),
+    "debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=16),
+    "1B": ModelArgs(dim=1024, n_layers=16, n_heads=8),
     "7B": ModelArgs(dim=4096, n_layers=32, n_heads=32),
     "13B": ModelArgs(dim=5120, n_layers=40, n_heads=40),
+    "40B": ModelArgs(dim=5120, n_layers=80, n_heads=40),
     "70B": ModelArgs(
         dim=8192,
         n_layers=80,
66 changes: 62 additions & 4 deletions torchtrain/models/llama/model.py
@@ -8,6 +8,8 @@
 import torch.nn.functional as F
 from torch import nn
 
+from torchtrain.logging_utils import rank0_log
+
 
 @dataclass
 class ModelArgs:
@@ -183,7 +185,6 @@ class Attention(nn.Module):
     """
 
     def __init__(self, model_args: ModelArgs):
-
         super().__init__()
         self.n_heads = model_args.n_heads
         self.n_kv_heads = (
@@ -203,6 +204,20 @@ def __init__(self, model_args: ModelArgs):
             model_args.n_heads * self.head_dim, model_args.dim, bias=False
         )
 
+    def reset_parameters(self, init_std):
+        for item in (self.wq, self.wk, self.wv):
+            nn.init.trunc_normal_(
+                item.weight,
+                mean=0.0,
+                std=0.02,
+            )
+
+        nn.init.trunc_normal_(
+            self.wo.weight,
+            mean=0.0,
+            std=init_std,
+        )
+
     def forward(
         self,
         x: torch.Tensor,
@@ -277,7 +292,6 @@ def __init__(
         multiple_of: int,
         ffn_dim_multiplier: Optional[float],
     ):
-
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
@@ -292,6 +306,20 @@ def __init__(
    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

+    def reset_parameters(self, init_std):
+        nn.init.trunc_normal_(
+            self.w1.weight,
+            mean=0.0,
+            std=0.02,
+        )
+
+        for item in (self.w2, self.w3):
+            nn.init.trunc_normal_(
+                item.weight,
+                mean=0.0,
+                std=init_std,
+            )
+

 class RotaryEmbedding(nn.Module):
     """
@@ -350,7 +378,6 @@ class TransformerBlock(nn.Module):
     """

     def __init__(self, layer_id: int, model_args: ModelArgs):
-
         super().__init__()
         self.n_heads = model_args.n_heads
         self.dim = model_args.dim
@@ -362,8 +389,10 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
             ffn_dim_multiplier=model_args.ffn_dim_multiplier,
         )
         self.layer_id = layer_id
+        self.num_layers = model_args.n_layers
         self.attention_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
         self.ffn_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5

     def forward(
         self,
@@ -385,6 +414,14 @@ def forward(
         out = h + self.feed_forward(self.ffn_norm(h))
         return out

+    def reset_parameters(self):
+        """reset params and norms for entire block"""
+        self.attention_norm.reset_parameters()
+        self.ffn_norm.reset_parameters()
+
+        self.attention.reset_parameters(self.weight_init_std)
+        self.feed_forward.reset_parameters(self.weight_init_std)
+

 class Transformer(nn.Module):
     """
@@ -406,11 +443,11 @@ class Transformer(nn.Module):
     """

     def __init__(self, model_args: ModelArgs):
-
         super().__init__()
         self.model_args = model_args
         self.vocab_size = model_args.vocab_size
         self.n_layers = model_args.n_layers
+        self.model_dim = model_args.dim

         self.embeddings = RotaryEmbedding(model_args)

@@ -421,6 +458,27 @@ def __init__(self, model_args: ModelArgs):
         self.norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
         self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)

+        # init model weights
+        self.reset_parameters()
+        rank0_log(f"Model built with: {self.model_args}")
+
+    def reset_parameters(
+        self,
+    ):
+        for layer in self.layers:
+            layer.reset_parameters()
+        self.norm.reset_parameters()
+        final_out_std = self.model_dim**-0.5
+        cutoff_factor = 3
+        nn.init.trunc_normal_(
+            self.output.weight,
+            mean=0.0,
+            std=final_out_std,
+            a=-cutoff_factor * final_out_std,
+            b=cutoff_factor * final_out_std,
+        )
+        rank0_log("Model fully initialized via reset_params")
+
     def forward(self, tokens: torch.Tensor):
         """
         Perform a forward pass through the Transformer model.
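
As a quick sanity check of the new init path, something like the following should work (a sketch only, not part of the commit; it assumes llama_configs is importable from the package __init__.py above and that ModelArgs is a plain dataclass whose vocab_size can be overridden, since train.py normally fills vocab_size in from the tokenizer):

```python
from dataclasses import replace

from torchtrain.models.llama import Transformer, llama_configs

# vocab_size is arbitrary here; train.py normally derives it from the tokenizer
args = replace(llama_configs["debugmodel"], vocab_size=32000)
model = Transformer(args)  # __init__ now calls reset_parameters()

for block in model.layers:
    # 0.02 / sqrt(2 * n_layers) -> 0.01 for the 2-layer debugmodel
    print(block.weight_init_std)

# the output head was drawn from trunc_normal_ with std = dim ** -0.5 and a 3-std cutoff
print(model.output.weight.std().item(), model.model_dim**-0.5)
```
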
3 changes: 2 additions & 1 deletion train.py
@@ -78,6 +78,7 @@ def main(args):
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
 
     model_name = args.model
+    rank0_log(f"Building {model_name}")
     # build tokenizer
     tokenizer_type = model_name_to_tokenizer[model_name]
     tokenizer = create_tokenizer(tokenizer_type, args.tokenizer_path)
@@ -222,7 +223,7 @@ def main(args):
     parser.add_argument(
         "--warmup_pct",
         type=float,
-        default=0.10,
+        default=0.20,
         help="percentage of total training steps to use for warmup",
     )
     parser.add_argument(
