Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/08c0770f-17fc-44cd-971d-734a7a28a3e3.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/11dd9171-d060-4279-a6e5-5ba91fb7758e.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/1316b79f-d02b-4cd6-b98a-43b48023aedf.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/13de21d5-e0e9-4dab-b42d-ad13e73bc402.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/1711f2cf-76af-46e0-b2df-47bd0d6bec0c.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/27974127-6559-494e-9941-2d88325c2e52.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/351728bd-3438-40d2-a006-41ed492e139f.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/550ba6aa-d6a7-4a20-8303-f2b8d93c5f52.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/559a562e-aaa1-46ef-aa6a-06f46a3b019d.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/5c44ff06-998a-4310-af6e-f0f5441452f4.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/68d6605a-f386-4e4f-84c0-2582dc6989d8.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/7fa6fb13-cac4-46c4-bf34-83c8290e17f0.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/8237bd60-bbc4-4ad6-8f7f-9a2a654a1c5a.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/82d579bc-45e2-4600-8436-7d425016e9b3.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/b21c8cc7-c09c-401f-b654-23e947ad3e38.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/b54670db-06ce-4aa6-b50f-869bfe329c8b.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/cb4e8b78-b9ab-4c83-9ff9-3fdfb8bb1b9b.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/d89c0dc1-c0ce-4346-a405-af9e88ed79bc.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/e6622691-5ab5-4066-995d-41dada989dab.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/efa9ba5e-7c95-4d47-8873-ad23d1f28e80.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/f26e4a90-074c-4ed4-b3e3-ce69223863c4.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/f713f5c8-a6e3-446a-9ec4-5014917cb254.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/f790419e-3027-441e-a5ab-11549e63fc1c.txt

Large diffs are not rendered by default.

3,206 changes: 3,206 additions & 0 deletions records/092725_BF16CE/f7c90ea9-95b0-4652-b933-a73edab09583.txt

Large diffs are not rendered by default.

39 changes: 22 additions & 17 deletions train_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,10 +413,10 @@ class Muon(torch.optim.Optimizer):
This hyper-optimized class has faster execution time than the current impl of Adam for small params

Custom distributed sizing:
The model stores all attn and mlp weights in the same shape, and then updates the view as
needed on the forward pass. This enables attn and mlp weights to be contained within the same
dist.reduce_scatter_tensor() call. The model architecture has been customized to enable
(n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn.
The model stores all attn and mlp weights in the same shape, and then updates the view as
needed on the forward pass. This enables attn and mlp weights to be contained within the same
dist.reduce_scatter_tensor() call. The model architecture has been customized to enable
(n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn.
The scheduling is:
1. reduce scatter smear_gate (1 param 7 padding params)
2. reduce scatter attn_gate (10 params 6 padding params)
Expand Down Expand Up @@ -456,10 +456,10 @@ def generate_standard_param_groups(self, params):
group_params = [p for p in non_attn_subset if p.shape == size]
param_groups.append(dict(params=group_params))
return param_groups

def generate_custom_param_groups(self, params):
"""
Implementation requires that a single GPU does not receive both attn
Implementation requires that a single GPU does not receive both attn
and mlp params when a param group is split across GPUs.
"""
module_ranks = {
Expand Down Expand Up @@ -614,7 +614,7 @@ def step(self):
for p in params[module_idx:module_idx+chunk_size]:
assert getattr(params[module_idx],'module','none')=='attn'
batch = 4 * original_shape[0]
d1 = original_shape[1]
d1 = original_shape[1]
d2 = original_shape[2] // 4
batched = batched_update_grads.view(batch, d1, d2)
v_chunk = newton_schulz_triton(batched)
Expand Down Expand Up @@ -777,7 +777,7 @@ def __init__(self, head_dim, max_seq_len):
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.reset()

def reset(self):
angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device)
# half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
Expand Down Expand Up @@ -1020,10 +1020,15 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_sho
skip_connections.append(x)

x = norm(x)
logits = self.lm_head(x).float()
logits = self.lm_head(x)
# @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
logits = 30 * torch.sigmoid(logits / 7.5)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0)
logits_for_loss = logits.float() if not self.training else logits
loss = F.cross_entropy(
logits_for_loss.view(-1, logits_for_loss.size(-1)),
target_seq,
reduction="sum" if self.training else "mean",
)
return loss

# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -1065,12 +1070,12 @@ def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False)
def _load(self):
self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
self.ready.set()

def start(self):
self.ready.clear()
self.thread = threading.Thread(target=self._load)
self.thread.start()

def get(self):
if self.thread:
self.ready.wait()
Expand Down Expand Up @@ -1113,17 +1118,17 @@ def __init__(self, file_iter, world_size: int = 1):
self.thread = None
self.data = None
self.ready = threading.Event()

def _load(self):
tokens = _load_data_shard(next(self.file_iter))
self.data = (tokens, BOSFinder(tokens, self.world_size))
self.ready.set()

def start(self):
self.ready.clear()
self.thread = threading.Thread(target=self._load)
self.thread.start()

def get(self):
if self.thread:
self.ready.wait()
Expand Down Expand Up @@ -1390,7 +1395,7 @@ def get_ws(step: int):
assert args.val_tokens % args.val_batch_size == 0
val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
val_loss = 0
val_loss = torch.zeros((), device=device, dtype=torch.float32)
with torch.no_grad():
for _ in range(val_steps):
inputs, targets, cum_seqlens = next(val_loader)
Expand Down