diff --git a/records/092925_PolarExpress/0e3f0af5-ad08-47a6-813d-0c709b50d422.txt b/records/092925_PolarExpress/0e3f0af5-ad08-47a6-813d-0c709b50d422.txt new file mode 100644 index 000000000..45f8d6a8e --- /dev/null +++ b/records/092925_PolarExpress/0e3f0af5-ad08-47a6-813d-0c709b50d422.txt @@ -0,0 +1,3252 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from kernels import get_kernel +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[ + Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return 
grad_x, grad_w, None, None, None + + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + + +mm_op.register_autograd(backward, setup_context=setup_context) + + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def XXT_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = 
(offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def XXT(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + XXT_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ba_plus_cAA_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from XXT_kernel, but also loads and adds a block of A + # Performance is slightly slower than XXT_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = 
C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ba_plus_cAA(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ba_plus_cAA_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +# Computed for num_iters=5, safety_factor=2e-2, cushion=2 +coeffs_list = [ + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323) +] + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def polar_express(G: torch.Tensor): + """ + Polar Express Sign Method: https://arxiv.org/pdf/2505.16932 + by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. + Code adapted from https://github.com/NoahAmsel/PolarExpress/tree/main by @varunneal. + """ + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) * (1 + 2e-2) + 1e-6) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + aX_plus_BX = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the iterations + for a, b, c in coeffs_list: + XXT(X, out=A) # A = X @ X.mT + ba_plus_cAA(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + aX_plus_BX(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. 
To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + Though empirically small 1D params perform efficiently here: + NS approximately performs a magnitude normalization of the grad + This hyper-optimized class has faster execution time than the current impl of Adam for small params + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized to enable + (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param 7 padding params) + 2. reduce scatter attn_gate (10 params 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive params of type attn reshape before NS + 8. wait on 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size() == 8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on less than 8 GPU or experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. 
+ """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1, 10, 16, 16] + assert len(params) == sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx + size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. 
+ update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + original_shape = batched_update_grads.shape + # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS indepedently to Q,K,V,O + module_idx = start_idx if start_idx < len(params) else 0 + if getattr(params[module_idx], 'module', 'none') == 'attn': + for p in params[module_idx:module_idx + chunk_size]: + assert getattr(params[module_idx], 'module', 'none') == 'attn' + batch = 4 * original_shape[0] + d1 = original_shape[1] + d2 = original_shape[2] // 4 + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = polar_express(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = polar_express(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, + weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append( + dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim // 4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim // 4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +flash_attn_interface = get_kernel('varunneal/flash-attention-3').flash_attn_interface + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim * 4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module = 'attn' + with torch.no_grad(): + self.qkvo_w.view(4, 
self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4, self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4, self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, + 3 * self.num_heads, + self.head_dim).chunk( + 3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_interface.flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, + max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4, self.hdim, self.dim)[3].type_as(y)) + return y + + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module = 'mlp' + self.c_proj.module = 'mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu( + x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + 
self.mlp(norm(x)) + return x + + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, + max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim ** 0.5) / 448, w_s=2 ** -9, + grad_s=1 / 448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), # extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, + long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1, len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i < 11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + + +BOS_ID = 50256 + + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens = tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() 
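+        # Illustrative note on the quickload path above: the partial index built from the first
+        # 4M tokens serves the first 5 batches, after which next_batch() calls get() to swap in
+        # the full index produced by the background thread (see _load), so training can start
+        # before the whole shard has been scanned for BOS tokens.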
+ self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter == 5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter += 1 + return starts, ends + + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, + align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
+ tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1630 # number of iterations to run + iteration_extension = 40 # number of iterations to continue training at final cooldown and window size + cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. 
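+
+# Illustrative batching arithmetic for the world_size / grad_accum_steps setup above
+# (uses the Hyperparameters defaults defined earlier in this file):
+#   args.train_batch_size = 2048 * 24 * 8 = 393216 tokens per optimizer step (global)
+#   world_size = 8 -> grad_accum_steps = 1 -> 393216 // (1 * 8) = 49152 tokens per GPU per forward
+#   world_size = 4 -> grad_accum_steps = 2 -> 393216 // (2 * 4) = 49152 tokens per GPU per forward
+#   world_size = 1 -> grad_accum_steps = 8 -> 393216 // (8 * 1) = 49152 tokens per GPU per forward
+# so each forward/backward always sees the same per-GPU token count, and gradient accumulation
+# restores the full global batch before every optimizer step regardless of GPU count.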
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) + + +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + + +# begin by printing this file (the Python code) +print0(code) +print0("=" * 100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout + + +print0(nvidia_smi()) +print0("=" * 100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if + p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.06, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999, step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + + +def get_ws(step: int): + if step == args.num_iterations + args.iteration_extension: + return args.ws_validate // 2, args.ws_validate + x = min(step / (1 + args.num_iterations), 0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx] + + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, + grad_accum_steps=grad_accum_steps) +ws_long = args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws_long = 
args.ws_schedule[ + step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params + if new_ws_long > ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + elif new_ws_long < ws_long: + model.yarn.reset() + ws_long = new_ws_long + model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.yarn.reset() +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, + grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations + args.iteration_extension +ws_short, ws_long = get_ws(0) +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws_short, new_ws_long = get_ws(step) + if new_ws_long != ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + if last_step: + ws_long = args.ws_long_validate + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, + grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = torch.zeros((), device=device, dtype=torch.float32) + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0( + f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms", + console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), + optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * 
(time.perf_counter() - t0) + print0( + f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms", + console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, Feb 4 2025, 14:57:36) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.5.0 +Mon Sep 29 06:26:49 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.127.08 Driver Version: 550.127.08 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 40C P0 130W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 33C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 38C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 39C P0 128W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 33C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 38C P0 123W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 
train_time:0ms step_avg:0.02ms +step:1/1670 train_time:143ms step_avg:143.06ms +step:2/1670 train_time:164ms step_avg:81.87ms +step:3/1670 train_time:226ms step_avg:75.49ms +step:4/1670 train_time:312ms step_avg:77.97ms +step:5/1670 train_time:398ms step_avg:79.61ms +step:6/1670 train_time:485ms step_avg:80.83ms +step:7/1670 train_time:572ms step_avg:81.66ms +step:8/1670 train_time:658ms step_avg:82.27ms +step:9/1670 train_time:745ms step_avg:82.82ms +step:10/1670 train_time:832ms step_avg:83.20ms +step:11/1670 train_time:919ms step_avg:83.55ms +step:12/1670 train_time:1008ms step_avg:83.98ms +step:13/1670 train_time:1101ms step_avg:84.67ms +step:14/1670 train_time:1194ms step_avg:85.27ms +step:15/1670 train_time:1282ms step_avg:85.48ms +step:16/1670 train_time:1369ms step_avg:85.57ms +step:17/1670 train_time:1456ms step_avg:85.67ms +step:18/1670 train_time:1544ms step_avg:85.78ms +step:19/1670 train_time:1631ms step_avg:85.83ms +step:20/1670 train_time:1718ms step_avg:85.88ms +step:21/1670 train_time:1805ms step_avg:85.94ms +step:22/1670 train_time:1892ms step_avg:85.98ms +step:23/1670 train_time:1980ms step_avg:86.10ms +step:24/1670 train_time:2069ms step_avg:86.23ms +step:25/1670 train_time:2160ms step_avg:86.38ms +step:26/1670 train_time:2248ms step_avg:86.46ms +step:27/1670 train_time:2336ms step_avg:86.53ms +step:28/1670 train_time:2424ms step_avg:86.58ms +step:29/1670 train_time:2512ms step_avg:86.60ms +step:30/1670 train_time:2599ms step_avg:86.63ms +step:31/1670 train_time:2686ms step_avg:86.64ms +step:32/1670 train_time:2774ms step_avg:86.68ms +step:33/1670 train_time:2862ms step_avg:86.72ms +step:34/1670 train_time:2950ms step_avg:86.76ms +step:35/1670 train_time:3040ms step_avg:86.84ms +step:36/1670 train_time:3129ms step_avg:86.92ms +step:37/1670 train_time:3218ms step_avg:86.98ms +step:38/1670 train_time:3306ms step_avg:87.01ms +step:39/1670 train_time:3395ms step_avg:87.04ms +step:40/1670 train_time:3483ms step_avg:87.08ms +step:41/1670 train_time:3571ms step_avg:87.09ms +step:42/1670 train_time:3658ms step_avg:87.09ms +step:43/1670 train_time:3745ms step_avg:87.10ms +step:44/1670 train_time:3833ms step_avg:87.11ms +step:45/1670 train_time:3921ms step_avg:87.12ms +step:46/1670 train_time:4008ms step_avg:87.14ms +step:47/1670 train_time:4098ms step_avg:87.18ms +step:48/1670 train_time:4186ms step_avg:87.21ms +step:49/1670 train_time:4276ms step_avg:87.26ms +step:50/1670 train_time:4364ms step_avg:87.28ms +step:51/1670 train_time:4452ms step_avg:87.29ms +step:52/1670 train_time:4539ms step_avg:87.29ms +step:53/1670 train_time:4627ms step_avg:87.30ms +step:54/1670 train_time:4714ms step_avg:87.30ms +step:55/1670 train_time:4802ms step_avg:87.31ms +step:56/1670 train_time:4889ms step_avg:87.31ms +step:57/1670 train_time:4977ms step_avg:87.31ms +step:58/1670 train_time:5065ms step_avg:87.33ms +step:59/1670 train_time:5153ms step_avg:87.34ms +step:60/1670 train_time:5242ms step_avg:87.37ms +step:61/1670 train_time:5330ms step_avg:87.38ms +step:62/1670 train_time:5419ms step_avg:87.40ms +step:63/1670 train_time:5506ms step_avg:87.40ms +step:64/1670 train_time:5593ms step_avg:87.39ms +step:65/1670 train_time:5681ms step_avg:87.41ms +step:66/1670 train_time:5769ms step_avg:87.41ms +step:67/1670 train_time:5857ms step_avg:87.42ms +step:68/1670 train_time:5944ms step_avg:87.41ms +step:69/1670 train_time:6031ms step_avg:87.41ms +step:70/1670 train_time:6119ms step_avg:87.41ms +step:71/1670 train_time:6207ms step_avg:87.42ms +step:72/1670 train_time:6295ms step_avg:87.43ms +step:73/1670 
train_time:6384ms step_avg:87.45ms +step:74/1670 train_time:6471ms step_avg:87.45ms +step:75/1670 train_time:6559ms step_avg:87.45ms +step:76/1670 train_time:6647ms step_avg:87.45ms +step:77/1670 train_time:6734ms step_avg:87.46ms +step:78/1670 train_time:6823ms step_avg:87.47ms +step:79/1670 train_time:6911ms step_avg:87.48ms +step:80/1670 train_time:7000ms step_avg:87.50ms +step:81/1670 train_time:7088ms step_avg:87.50ms +step:82/1670 train_time:7176ms step_avg:87.51ms +step:83/1670 train_time:7264ms step_avg:87.52ms +step:84/1670 train_time:7352ms step_avg:87.52ms +step:85/1670 train_time:7440ms step_avg:87.52ms +step:86/1670 train_time:7527ms step_avg:87.52ms +step:87/1670 train_time:7615ms step_avg:87.53ms +step:88/1670 train_time:7703ms step_avg:87.53ms +step:89/1670 train_time:7790ms step_avg:87.53ms +step:90/1670 train_time:7877ms step_avg:87.53ms +step:91/1670 train_time:7965ms step_avg:87.53ms +step:92/1670 train_time:8053ms step_avg:87.53ms +step:93/1670 train_time:8142ms step_avg:87.54ms +step:94/1670 train_time:8229ms step_avg:87.55ms +step:95/1670 train_time:8318ms step_avg:87.56ms +step:96/1670 train_time:8406ms step_avg:87.56ms +step:97/1670 train_time:8493ms step_avg:87.56ms +step:98/1670 train_time:8581ms step_avg:87.56ms +step:99/1670 train_time:8669ms step_avg:87.56ms +step:100/1670 train_time:8757ms step_avg:87.57ms +step:101/1670 train_time:8844ms step_avg:87.57ms +step:102/1670 train_time:8932ms step_avg:87.57ms +step:103/1670 train_time:9021ms step_avg:87.58ms +step:104/1670 train_time:9108ms step_avg:87.58ms +step:105/1670 train_time:9196ms step_avg:87.58ms +step:106/1670 train_time:9284ms step_avg:87.59ms +step:107/1670 train_time:9371ms step_avg:87.58ms +step:108/1670 train_time:9459ms step_avg:87.58ms +step:109/1670 train_time:9546ms step_avg:87.58ms +step:110/1670 train_time:9634ms step_avg:87.58ms +step:111/1670 train_time:9723ms step_avg:87.59ms +step:112/1670 train_time:9810ms step_avg:87.59ms +step:113/1670 train_time:9897ms step_avg:87.59ms +step:114/1670 train_time:9986ms step_avg:87.59ms +step:115/1670 train_time:10073ms step_avg:87.59ms +step:116/1670 train_time:10161ms step_avg:87.59ms +step:117/1670 train_time:10249ms step_avg:87.60ms +step:118/1670 train_time:10337ms step_avg:87.60ms +step:119/1670 train_time:10425ms step_avg:87.60ms +step:120/1670 train_time:10512ms step_avg:87.60ms +step:121/1670 train_time:10599ms step_avg:87.59ms +step:122/1670 train_time:10687ms step_avg:87.59ms +step:123/1670 train_time:10774ms step_avg:87.59ms +step:124/1670 train_time:10862ms step_avg:87.60ms +step:125/1670 train_time:10949ms step_avg:87.59ms +step:125/1670 val_loss:4.3341 train_time:11038ms step_avg:88.30ms +step:126/1670 train_time:11057ms step_avg:87.76ms +step:127/1670 train_time:11128ms step_avg:87.62ms +step:128/1670 train_time:11223ms step_avg:87.68ms +step:129/1670 train_time:11313ms step_avg:87.70ms +step:130/1670 train_time:11400ms step_avg:87.69ms +step:131/1670 train_time:11487ms step_avg:87.68ms +step:132/1670 train_time:11574ms step_avg:87.68ms +step:133/1670 train_time:11660ms step_avg:87.67ms +step:134/1670 train_time:11747ms step_avg:87.67ms +step:135/1670 train_time:11834ms step_avg:87.66ms +step:136/1670 train_time:11922ms step_avg:87.66ms +step:137/1670 train_time:12010ms step_avg:87.66ms +step:138/1670 train_time:12099ms step_avg:87.67ms +step:139/1670 train_time:12190ms step_avg:87.69ms +step:140/1670 train_time:12278ms step_avg:87.70ms +step:141/1670 train_time:12367ms step_avg:87.71ms +step:142/1670 train_time:12454ms step_avg:87.71ms 
+step:143/1670 train_time:12541ms step_avg:87.70ms +step:144/1670 train_time:12628ms step_avg:87.70ms +step:145/1670 train_time:12715ms step_avg:87.69ms +step:146/1670 train_time:12803ms step_avg:87.69ms +step:147/1670 train_time:12890ms step_avg:87.69ms +step:148/1670 train_time:12978ms step_avg:87.69ms +step:149/1670 train_time:13067ms step_avg:87.70ms +step:150/1670 train_time:13156ms step_avg:87.70ms +step:151/1670 train_time:13245ms step_avg:87.72ms +step:152/1670 train_time:13333ms step_avg:87.72ms +step:153/1670 train_time:13421ms step_avg:87.72ms +step:154/1670 train_time:13510ms step_avg:87.72ms +step:155/1670 train_time:13596ms step_avg:87.72ms +step:156/1670 train_time:13683ms step_avg:87.71ms +step:157/1670 train_time:13771ms step_avg:87.71ms +step:158/1670 train_time:13857ms step_avg:87.71ms +step:159/1670 train_time:13945ms step_avg:87.70ms +step:160/1670 train_time:14033ms step_avg:87.70ms +step:161/1670 train_time:14122ms step_avg:87.71ms +step:162/1670 train_time:14211ms step_avg:87.72ms +step:163/1670 train_time:14300ms step_avg:87.73ms +step:164/1670 train_time:14388ms step_avg:87.73ms +step:165/1670 train_time:14476ms step_avg:87.73ms +step:166/1670 train_time:14563ms step_avg:87.73ms +step:167/1670 train_time:14651ms step_avg:87.73ms +step:168/1670 train_time:14738ms step_avg:87.73ms +step:169/1670 train_time:14826ms step_avg:87.73ms +step:170/1670 train_time:14913ms step_avg:87.72ms +step:171/1670 train_time:15001ms step_avg:87.72ms +step:172/1670 train_time:15090ms step_avg:87.73ms +step:173/1670 train_time:15177ms step_avg:87.73ms +step:174/1670 train_time:15266ms step_avg:87.73ms +step:175/1670 train_time:15354ms step_avg:87.73ms +step:176/1670 train_time:15443ms step_avg:87.74ms +step:177/1670 train_time:15531ms step_avg:87.75ms +step:178/1670 train_time:15618ms step_avg:87.74ms +step:179/1670 train_time:15706ms step_avg:87.74ms +step:180/1670 train_time:15792ms step_avg:87.74ms +step:181/1670 train_time:15879ms step_avg:87.73ms +step:182/1670 train_time:15968ms step_avg:87.74ms +step:183/1670 train_time:16056ms step_avg:87.74ms +step:184/1670 train_time:16144ms step_avg:87.74ms +step:185/1670 train_time:16233ms step_avg:87.74ms +step:186/1670 train_time:16320ms step_avg:87.74ms +step:187/1670 train_time:16409ms step_avg:87.75ms +step:188/1670 train_time:16497ms step_avg:87.75ms +step:189/1670 train_time:16585ms step_avg:87.75ms +step:190/1670 train_time:16672ms step_avg:87.75ms +step:191/1670 train_time:16761ms step_avg:87.75ms +step:192/1670 train_time:16849ms step_avg:87.75ms +step:193/1670 train_time:16936ms step_avg:87.75ms +step:194/1670 train_time:17023ms step_avg:87.75ms +step:195/1670 train_time:17112ms step_avg:87.75ms +step:196/1670 train_time:17200ms step_avg:87.75ms +step:197/1670 train_time:17288ms step_avg:87.76ms +step:198/1670 train_time:17376ms step_avg:87.76ms +step:199/1670 train_time:17464ms step_avg:87.76ms +step:200/1670 train_time:17552ms step_avg:87.76ms +step:201/1670 train_time:17640ms step_avg:87.76ms +step:202/1670 train_time:17729ms step_avg:87.77ms +step:203/1670 train_time:17816ms step_avg:87.76ms +step:204/1670 train_time:17904ms step_avg:87.76ms +step:205/1670 train_time:17991ms step_avg:87.76ms +step:206/1670 train_time:18078ms step_avg:87.76ms +step:207/1670 train_time:18166ms step_avg:87.76ms +step:208/1670 train_time:18254ms step_avg:87.76ms +step:209/1670 train_time:18342ms step_avg:87.76ms +step:210/1670 train_time:18429ms step_avg:87.76ms +step:211/1670 train_time:18516ms step_avg:87.75ms +step:212/1670 train_time:18604ms 
step_avg:87.75ms +step:213/1670 train_time:18692ms step_avg:87.75ms +step:214/1670 train_time:18779ms step_avg:87.75ms +step:215/1670 train_time:18867ms step_avg:87.75ms +step:216/1670 train_time:18953ms step_avg:87.75ms +step:217/1670 train_time:19041ms step_avg:87.75ms +step:218/1670 train_time:19128ms step_avg:87.74ms +step:219/1670 train_time:19216ms step_avg:87.74ms +step:220/1670 train_time:19303ms step_avg:87.74ms +step:221/1670 train_time:19391ms step_avg:87.74ms +step:222/1670 train_time:19479ms step_avg:87.74ms +step:223/1670 train_time:19567ms step_avg:87.74ms +step:224/1670 train_time:19654ms step_avg:87.74ms +step:225/1670 train_time:19742ms step_avg:87.74ms +step:226/1670 train_time:19830ms step_avg:87.74ms +step:227/1670 train_time:19917ms step_avg:87.74ms +step:228/1670 train_time:20005ms step_avg:87.74ms +step:229/1670 train_time:20092ms step_avg:87.74ms +step:230/1670 train_time:20180ms step_avg:87.74ms +step:231/1670 train_time:20268ms step_avg:87.74ms +step:232/1670 train_time:20355ms step_avg:87.74ms +step:233/1670 train_time:20443ms step_avg:87.74ms +step:234/1670 train_time:20530ms step_avg:87.74ms +step:235/1670 train_time:20618ms step_avg:87.74ms +step:236/1670 train_time:20706ms step_avg:87.74ms +step:237/1670 train_time:20793ms step_avg:87.73ms +step:238/1670 train_time:20880ms step_avg:87.73ms +step:239/1670 train_time:20969ms step_avg:87.74ms +step:240/1670 train_time:21056ms step_avg:87.73ms +step:241/1670 train_time:21144ms step_avg:87.73ms +step:242/1670 train_time:21232ms step_avg:87.74ms +step:243/1670 train_time:21320ms step_avg:87.74ms +step:244/1670 train_time:21407ms step_avg:87.74ms +step:245/1670 train_time:21495ms step_avg:87.73ms +step:246/1670 train_time:21582ms step_avg:87.73ms +step:247/1670 train_time:21670ms step_avg:87.73ms +step:248/1670 train_time:21758ms step_avg:87.73ms +step:249/1670 train_time:21846ms step_avg:87.74ms +step:250/1670 train_time:21933ms step_avg:87.73ms +step:250/1670 val_loss:3.9771 train_time:22022ms step_avg:88.09ms +step:251/1670 train_time:22042ms step_avg:87.82ms +step:252/1670 train_time:22112ms step_avg:87.75ms +step:253/1670 train_time:22204ms step_avg:87.76ms +step:254/1670 train_time:22291ms step_avg:87.76ms +step:255/1670 train_time:22379ms step_avg:87.76ms +step:256/1670 train_time:22465ms step_avg:87.76ms +step:257/1670 train_time:22552ms step_avg:87.75ms +step:258/1670 train_time:22639ms step_avg:87.75ms +step:259/1670 train_time:22725ms step_avg:87.74ms +step:260/1670 train_time:22813ms step_avg:87.74ms +step:261/1670 train_time:22899ms step_avg:87.74ms +step:262/1670 train_time:22988ms step_avg:87.74ms +step:263/1670 train_time:23079ms step_avg:87.75ms +step:264/1670 train_time:23169ms step_avg:87.76ms +step:265/1670 train_time:23257ms step_avg:87.76ms +step:266/1670 train_time:23345ms step_avg:87.76ms +step:267/1670 train_time:23432ms step_avg:87.76ms +step:268/1670 train_time:23518ms step_avg:87.76ms +step:269/1670 train_time:23605ms step_avg:87.75ms +step:270/1670 train_time:23692ms step_avg:87.75ms +step:271/1670 train_time:23779ms step_avg:87.75ms +step:272/1670 train_time:23866ms step_avg:87.74ms +step:273/1670 train_time:23954ms step_avg:87.74ms +step:274/1670 train_time:24042ms step_avg:87.75ms +step:275/1670 train_time:24130ms step_avg:87.75ms +step:276/1670 train_time:24219ms step_avg:87.75ms +step:277/1670 train_time:24307ms step_avg:87.75ms +step:278/1670 train_time:24395ms step_avg:87.75ms +step:279/1670 train_time:24482ms step_avg:87.75ms +step:280/1670 train_time:24570ms step_avg:87.75ms 
+step:281/1670 train_time:24657ms step_avg:87.75ms +step:282/1670 train_time:24744ms step_avg:87.75ms +step:283/1670 train_time:24832ms step_avg:87.74ms +step:284/1670 train_time:24919ms step_avg:87.74ms +step:285/1670 train_time:25007ms step_avg:87.74ms +step:286/1670 train_time:25096ms step_avg:87.75ms +step:287/1670 train_time:25186ms step_avg:87.75ms +step:288/1670 train_time:25273ms step_avg:87.75ms +step:289/1670 train_time:25361ms step_avg:87.75ms +step:290/1670 train_time:25448ms step_avg:87.75ms +step:291/1670 train_time:25536ms step_avg:87.75ms +step:292/1670 train_time:25623ms step_avg:87.75ms +step:293/1670 train_time:25710ms step_avg:87.75ms +step:294/1670 train_time:25798ms step_avg:87.75ms +step:295/1670 train_time:25885ms step_avg:87.75ms +step:296/1670 train_time:25974ms step_avg:87.75ms +step:297/1670 train_time:26061ms step_avg:87.75ms +step:298/1670 train_time:26149ms step_avg:87.75ms +step:299/1670 train_time:26238ms step_avg:87.75ms +step:300/1670 train_time:26325ms step_avg:87.75ms +step:301/1670 train_time:26413ms step_avg:87.75ms +step:302/1670 train_time:26500ms step_avg:87.75ms +step:303/1670 train_time:26588ms step_avg:87.75ms +step:304/1670 train_time:26676ms step_avg:87.75ms +step:305/1670 train_time:26762ms step_avg:87.75ms +step:306/1670 train_time:26850ms step_avg:87.74ms +step:307/1670 train_time:26937ms step_avg:87.74ms +step:308/1670 train_time:27025ms step_avg:87.74ms +step:309/1670 train_time:27113ms step_avg:87.74ms +step:310/1670 train_time:27201ms step_avg:87.74ms +step:311/1670 train_time:27289ms step_avg:87.75ms +step:312/1670 train_time:27377ms step_avg:87.75ms +step:313/1670 train_time:27463ms step_avg:87.74ms +step:314/1670 train_time:27551ms step_avg:87.74ms +step:315/1670 train_time:27638ms step_avg:87.74ms +step:316/1670 train_time:27725ms step_avg:87.74ms +step:317/1670 train_time:27813ms step_avg:87.74ms +step:318/1670 train_time:27900ms step_avg:87.74ms +step:319/1670 train_time:27988ms step_avg:87.74ms +step:320/1670 train_time:28077ms step_avg:87.74ms +step:321/1670 train_time:28164ms step_avg:87.74ms +step:322/1670 train_time:28252ms step_avg:87.74ms +step:323/1670 train_time:28340ms step_avg:87.74ms +step:324/1670 train_time:28428ms step_avg:87.74ms +step:325/1670 train_time:28516ms step_avg:87.74ms +step:326/1670 train_time:28603ms step_avg:87.74ms +step:327/1670 train_time:28690ms step_avg:87.74ms +step:328/1670 train_time:28778ms step_avg:87.74ms +step:329/1670 train_time:28865ms step_avg:87.74ms +step:330/1670 train_time:28953ms step_avg:87.74ms +step:331/1670 train_time:29040ms step_avg:87.74ms +step:332/1670 train_time:29128ms step_avg:87.74ms +step:333/1670 train_time:29216ms step_avg:87.74ms +step:334/1670 train_time:29304ms step_avg:87.74ms +step:335/1670 train_time:29391ms step_avg:87.73ms +step:336/1670 train_time:29479ms step_avg:87.73ms +step:337/1670 train_time:29567ms step_avg:87.73ms +step:338/1670 train_time:29655ms step_avg:87.74ms +step:339/1670 train_time:29743ms step_avg:87.74ms +step:340/1670 train_time:29830ms step_avg:87.74ms +step:341/1670 train_time:29918ms step_avg:87.74ms +step:342/1670 train_time:30006ms step_avg:87.74ms +step:343/1670 train_time:30094ms step_avg:87.74ms +step:344/1670 train_time:30181ms step_avg:87.74ms +step:345/1670 train_time:30270ms step_avg:87.74ms +step:346/1670 train_time:30358ms step_avg:87.74ms +step:347/1670 train_time:30445ms step_avg:87.74ms +step:348/1670 train_time:30533ms step_avg:87.74ms +step:349/1670 train_time:30621ms step_avg:87.74ms +step:350/1670 train_time:30708ms 
step_avg:87.74ms +step:351/1670 train_time:30797ms step_avg:87.74ms +step:352/1670 train_time:30884ms step_avg:87.74ms +step:353/1670 train_time:30972ms step_avg:87.74ms +step:354/1670 train_time:31059ms step_avg:87.74ms +step:355/1670 train_time:31147ms step_avg:87.74ms +step:356/1670 train_time:31235ms step_avg:87.74ms +step:357/1670 train_time:31322ms step_avg:87.74ms +step:358/1670 train_time:31410ms step_avg:87.74ms +step:359/1670 train_time:31498ms step_avg:87.74ms +step:360/1670 train_time:31585ms step_avg:87.74ms +step:361/1670 train_time:31673ms step_avg:87.74ms +step:362/1670 train_time:31760ms step_avg:87.74ms +step:363/1670 train_time:31848ms step_avg:87.73ms +step:364/1670 train_time:31936ms step_avg:87.74ms +step:365/1670 train_time:32023ms step_avg:87.73ms +step:366/1670 train_time:32113ms step_avg:87.74ms +step:367/1670 train_time:32200ms step_avg:87.74ms +step:368/1670 train_time:32287ms step_avg:87.74ms +step:369/1670 train_time:32375ms step_avg:87.74ms +step:370/1670 train_time:32462ms step_avg:87.74ms +step:371/1670 train_time:32550ms step_avg:87.74ms +step:372/1670 train_time:32638ms step_avg:87.74ms +step:373/1670 train_time:32726ms step_avg:87.74ms +step:374/1670 train_time:32814ms step_avg:87.74ms +step:375/1670 train_time:32901ms step_avg:87.74ms +step:375/1670 val_loss:3.8204 train_time:32989ms step_avg:87.97ms +step:376/1670 train_time:33010ms step_avg:87.79ms +step:377/1670 train_time:33082ms step_avg:87.75ms +step:378/1670 train_time:33173ms step_avg:87.76ms +step:379/1670 train_time:33262ms step_avg:87.76ms +step:380/1670 train_time:33349ms step_avg:87.76ms +step:381/1670 train_time:33436ms step_avg:87.76ms +step:382/1670 train_time:33524ms step_avg:87.76ms +step:383/1670 train_time:33609ms step_avg:87.75ms +step:384/1670 train_time:33696ms step_avg:87.75ms +step:385/1670 train_time:33784ms step_avg:87.75ms +step:386/1670 train_time:33870ms step_avg:87.75ms +step:387/1670 train_time:33959ms step_avg:87.75ms +step:388/1670 train_time:34050ms step_avg:87.76ms +step:389/1670 train_time:34140ms step_avg:87.76ms +step:390/1670 train_time:34230ms step_avg:87.77ms +step:391/1670 train_time:34317ms step_avg:87.77ms +step:392/1670 train_time:34405ms step_avg:87.77ms +step:393/1670 train_time:34492ms step_avg:87.77ms +step:394/1670 train_time:34579ms step_avg:87.76ms +step:395/1670 train_time:34665ms step_avg:87.76ms +step:396/1670 train_time:34752ms step_avg:87.76ms +step:397/1670 train_time:34839ms step_avg:87.75ms +step:398/1670 train_time:34926ms step_avg:87.75ms +step:399/1670 train_time:35014ms step_avg:87.75ms +step:400/1670 train_time:35104ms step_avg:87.76ms +step:401/1670 train_time:35192ms step_avg:87.76ms +step:402/1670 train_time:35280ms step_avg:87.76ms +step:403/1670 train_time:35368ms step_avg:87.76ms +step:404/1670 train_time:35456ms step_avg:87.76ms +step:405/1670 train_time:35543ms step_avg:87.76ms +step:406/1670 train_time:35630ms step_avg:87.76ms +step:407/1670 train_time:35717ms step_avg:87.76ms +step:408/1670 train_time:35804ms step_avg:87.76ms +step:409/1670 train_time:35891ms step_avg:87.75ms +step:410/1670 train_time:35979ms step_avg:87.75ms +step:411/1670 train_time:36068ms step_avg:87.76ms +step:412/1670 train_time:36156ms step_avg:87.76ms +step:413/1670 train_time:36244ms step_avg:87.76ms +step:414/1670 train_time:36332ms step_avg:87.76ms +step:415/1670 train_time:36419ms step_avg:87.76ms +step:416/1670 train_time:36507ms step_avg:87.76ms +step:417/1670 train_time:36594ms step_avg:87.76ms +step:418/1670 train_time:36681ms step_avg:87.75ms 
+step:419/1670 train_time:36768ms step_avg:87.75ms +step:420/1670 train_time:36855ms step_avg:87.75ms +step:421/1670 train_time:36943ms step_avg:87.75ms +step:422/1670 train_time:37030ms step_avg:87.75ms +step:423/1670 train_time:37119ms step_avg:87.75ms +step:424/1670 train_time:37207ms step_avg:87.75ms +step:425/1670 train_time:37295ms step_avg:87.75ms +step:426/1670 train_time:37383ms step_avg:87.75ms +step:427/1670 train_time:37470ms step_avg:87.75ms +step:428/1670 train_time:37558ms step_avg:87.75ms +step:429/1670 train_time:37645ms step_avg:87.75ms +step:430/1670 train_time:37733ms step_avg:87.75ms +step:431/1670 train_time:37820ms step_avg:87.75ms +step:432/1670 train_time:37907ms step_avg:87.75ms +step:433/1670 train_time:37995ms step_avg:87.75ms +step:434/1670 train_time:38083ms step_avg:87.75ms +step:435/1670 train_time:38170ms step_avg:87.75ms +step:436/1670 train_time:38258ms step_avg:87.75ms +step:437/1670 train_time:38346ms step_avg:87.75ms +step:438/1670 train_time:38434ms step_avg:87.75ms +step:439/1670 train_time:38523ms step_avg:87.75ms +step:440/1670 train_time:38610ms step_avg:87.75ms +step:441/1670 train_time:38697ms step_avg:87.75ms +step:442/1670 train_time:38785ms step_avg:87.75ms +step:443/1670 train_time:38872ms step_avg:87.75ms +step:444/1670 train_time:38960ms step_avg:87.75ms +step:445/1670 train_time:39047ms step_avg:87.75ms +step:446/1670 train_time:39135ms step_avg:87.75ms +step:447/1670 train_time:39224ms step_avg:87.75ms +step:448/1670 train_time:39311ms step_avg:87.75ms +step:449/1670 train_time:39399ms step_avg:87.75ms +step:450/1670 train_time:39487ms step_avg:87.75ms +step:451/1670 train_time:39574ms step_avg:87.75ms +step:452/1670 train_time:39662ms step_avg:87.75ms +step:453/1670 train_time:39749ms step_avg:87.75ms +step:454/1670 train_time:39836ms step_avg:87.74ms +step:455/1670 train_time:39924ms step_avg:87.75ms +step:456/1670 train_time:40011ms step_avg:87.74ms +step:457/1670 train_time:40099ms step_avg:87.74ms +step:458/1670 train_time:40187ms step_avg:87.74ms +step:459/1670 train_time:40274ms step_avg:87.74ms +step:460/1670 train_time:40363ms step_avg:87.75ms +step:461/1670 train_time:40450ms step_avg:87.74ms +step:462/1670 train_time:40537ms step_avg:87.74ms +step:463/1670 train_time:40626ms step_avg:87.75ms +step:464/1670 train_time:40714ms step_avg:87.75ms +step:465/1670 train_time:40801ms step_avg:87.74ms +step:466/1670 train_time:40889ms step_avg:87.74ms +step:467/1670 train_time:40976ms step_avg:87.74ms +step:468/1670 train_time:41064ms step_avg:87.74ms +step:469/1670 train_time:41151ms step_avg:87.74ms +step:470/1670 train_time:41240ms step_avg:87.74ms +step:471/1670 train_time:41328ms step_avg:87.75ms +step:472/1670 train_time:41416ms step_avg:87.75ms +step:473/1670 train_time:41504ms step_avg:87.75ms +step:474/1670 train_time:41591ms step_avg:87.75ms +step:475/1670 train_time:41679ms step_avg:87.74ms +step:476/1670 train_time:41767ms step_avg:87.75ms +step:477/1670 train_time:41855ms step_avg:87.75ms +step:478/1670 train_time:41942ms step_avg:87.74ms +step:479/1670 train_time:42029ms step_avg:87.74ms +step:480/1670 train_time:42117ms step_avg:87.74ms +step:481/1670 train_time:42205ms step_avg:87.74ms +step:482/1670 train_time:42293ms step_avg:87.74ms +step:483/1670 train_time:42380ms step_avg:87.74ms +step:484/1670 train_time:42468ms step_avg:87.74ms +step:485/1670 train_time:42555ms step_avg:87.74ms +step:486/1670 train_time:42642ms step_avg:87.74ms +step:487/1670 train_time:42730ms step_avg:87.74ms +step:488/1670 train_time:42817ms 
step_avg:87.74ms +step:489/1670 train_time:42905ms step_avg:87.74ms +step:490/1670 train_time:42993ms step_avg:87.74ms +step:491/1670 train_time:43081ms step_avg:87.74ms +step:492/1670 train_time:43168ms step_avg:87.74ms +step:493/1670 train_time:43256ms step_avg:87.74ms +step:494/1670 train_time:43344ms step_avg:87.74ms +step:495/1670 train_time:43431ms step_avg:87.74ms +step:496/1670 train_time:43519ms step_avg:87.74ms +step:497/1670 train_time:43607ms step_avg:87.74ms +step:498/1670 train_time:43694ms step_avg:87.74ms +step:499/1670 train_time:43783ms step_avg:87.74ms +step:500/1670 train_time:43870ms step_avg:87.74ms +step:500/1670 val_loss:3.7193 train_time:43960ms step_avg:87.92ms +step:501/1670 train_time:43981ms step_avg:87.79ms +step:502/1670 train_time:44049ms step_avg:87.75ms +step:503/1670 train_time:44140ms step_avg:87.75ms +step:504/1670 train_time:44227ms step_avg:87.75ms +step:505/1670 train_time:44314ms step_avg:87.75ms +step:506/1670 train_time:44402ms step_avg:87.75ms +step:507/1670 train_time:44489ms step_avg:87.75ms +step:508/1670 train_time:44576ms step_avg:87.75ms +step:509/1670 train_time:44663ms step_avg:87.75ms +step:510/1670 train_time:44751ms step_avg:87.75ms +step:511/1670 train_time:44837ms step_avg:87.74ms +step:512/1670 train_time:44926ms step_avg:87.75ms +step:513/1670 train_time:45016ms step_avg:87.75ms +step:514/1670 train_time:45104ms step_avg:87.75ms +step:515/1670 train_time:45192ms step_avg:87.75ms +step:516/1670 train_time:45280ms step_avg:87.75ms +step:517/1670 train_time:45367ms step_avg:87.75ms +step:518/1670 train_time:45455ms step_avg:87.75ms +step:519/1670 train_time:45542ms step_avg:87.75ms +step:520/1670 train_time:45630ms step_avg:87.75ms +step:521/1670 train_time:45717ms step_avg:87.75ms +step:522/1670 train_time:45803ms step_avg:87.75ms +step:523/1670 train_time:45891ms step_avg:87.75ms +step:524/1670 train_time:45979ms step_avg:87.75ms +step:525/1670 train_time:46068ms step_avg:87.75ms +step:526/1670 train_time:46158ms step_avg:87.75ms +step:527/1670 train_time:46245ms step_avg:87.75ms +step:528/1670 train_time:46332ms step_avg:87.75ms +step:529/1670 train_time:46421ms step_avg:87.75ms +step:530/1670 train_time:46507ms step_avg:87.75ms +step:531/1670 train_time:46595ms step_avg:87.75ms +step:532/1670 train_time:46683ms step_avg:87.75ms +step:533/1670 train_time:46770ms step_avg:87.75ms +step:534/1670 train_time:46859ms step_avg:87.75ms +step:535/1670 train_time:46946ms step_avg:87.75ms +step:536/1670 train_time:47034ms step_avg:87.75ms +step:537/1670 train_time:47123ms step_avg:87.75ms +step:538/1670 train_time:47211ms step_avg:87.75ms +step:539/1670 train_time:47300ms step_avg:87.76ms +step:540/1670 train_time:47388ms step_avg:87.75ms +step:541/1670 train_time:47476ms step_avg:87.76ms +step:542/1670 train_time:47563ms step_avg:87.75ms +step:543/1670 train_time:47650ms step_avg:87.75ms +step:544/1670 train_time:47738ms step_avg:87.75ms +step:545/1670 train_time:47826ms step_avg:87.75ms +step:546/1670 train_time:47916ms step_avg:87.76ms +step:547/1670 train_time:48004ms step_avg:87.76ms +step:548/1670 train_time:48094ms step_avg:87.76ms +step:549/1670 train_time:48183ms step_avg:87.77ms +step:550/1670 train_time:48273ms step_avg:87.77ms +step:551/1670 train_time:48361ms step_avg:87.77ms +step:552/1670 train_time:48450ms step_avg:87.77ms +step:553/1670 train_time:48539ms step_avg:87.77ms +step:554/1670 train_time:48628ms step_avg:87.78ms +step:555/1670 train_time:48717ms step_avg:87.78ms +step:556/1670 train_time:48806ms step_avg:87.78ms 
+step:557/1670 train_time:48895ms step_avg:87.78ms +step:558/1670 train_time:48984ms step_avg:87.78ms +step:559/1670 train_time:49073ms step_avg:87.79ms +step:560/1670 train_time:49163ms step_avg:87.79ms +step:561/1670 train_time:49252ms step_avg:87.79ms +step:562/1670 train_time:49341ms step_avg:87.80ms +step:563/1670 train_time:49430ms step_avg:87.80ms +step:564/1670 train_time:49520ms step_avg:87.80ms +step:565/1670 train_time:49608ms step_avg:87.80ms +step:566/1670 train_time:49697ms step_avg:87.80ms +step:567/1670 train_time:49786ms step_avg:87.81ms +step:568/1670 train_time:49875ms step_avg:87.81ms +step:569/1670 train_time:49964ms step_avg:87.81ms +step:570/1670 train_time:50052ms step_avg:87.81ms +step:571/1670 train_time:50141ms step_avg:87.81ms +step:572/1670 train_time:50231ms step_avg:87.82ms +step:573/1670 train_time:50320ms step_avg:87.82ms +step:574/1670 train_time:50409ms step_avg:87.82ms +step:575/1670 train_time:50498ms step_avg:87.82ms +step:576/1670 train_time:50586ms step_avg:87.82ms +step:577/1670 train_time:50675ms step_avg:87.83ms +step:578/1670 train_time:50764ms step_avg:87.83ms +step:579/1670 train_time:50853ms step_avg:87.83ms +step:580/1670 train_time:50941ms step_avg:87.83ms +step:581/1670 train_time:51030ms step_avg:87.83ms +step:582/1670 train_time:51119ms step_avg:87.83ms +step:583/1670 train_time:51208ms step_avg:87.84ms +step:584/1670 train_time:51298ms step_avg:87.84ms +step:585/1670 train_time:51387ms step_avg:87.84ms +step:586/1670 train_time:51477ms step_avg:87.84ms +step:587/1670 train_time:51565ms step_avg:87.84ms +step:588/1670 train_time:51654ms step_avg:87.85ms +step:589/1670 train_time:51742ms step_avg:87.85ms +step:590/1670 train_time:51831ms step_avg:87.85ms +step:591/1670 train_time:51920ms step_avg:87.85ms +step:592/1670 train_time:52009ms step_avg:87.85ms +step:593/1670 train_time:52099ms step_avg:87.86ms +step:594/1670 train_time:52188ms step_avg:87.86ms +step:595/1670 train_time:52278ms step_avg:87.86ms +step:596/1670 train_time:52366ms step_avg:87.86ms +step:597/1670 train_time:52455ms step_avg:87.86ms +step:598/1670 train_time:52544ms step_avg:87.87ms +step:599/1670 train_time:52633ms step_avg:87.87ms +step:600/1670 train_time:52722ms step_avg:87.87ms +step:601/1670 train_time:52811ms step_avg:87.87ms +step:602/1670 train_time:52899ms step_avg:87.87ms +step:603/1670 train_time:52988ms step_avg:87.87ms +step:604/1670 train_time:53076ms step_avg:87.87ms +step:605/1670 train_time:53165ms step_avg:87.88ms +step:606/1670 train_time:53254ms step_avg:87.88ms +step:607/1670 train_time:53343ms step_avg:87.88ms +step:608/1670 train_time:53432ms step_avg:87.88ms +step:609/1670 train_time:53522ms step_avg:87.88ms +step:610/1670 train_time:53612ms step_avg:87.89ms +step:611/1670 train_time:53702ms step_avg:87.89ms +step:612/1670 train_time:53791ms step_avg:87.89ms +step:613/1670 train_time:53882ms step_avg:87.90ms +step:614/1670 train_time:53971ms step_avg:87.90ms +step:615/1670 train_time:54060ms step_avg:87.90ms +step:616/1670 train_time:54148ms step_avg:87.90ms +step:617/1670 train_time:54237ms step_avg:87.90ms +step:618/1670 train_time:54325ms step_avg:87.91ms +step:619/1670 train_time:54415ms step_avg:87.91ms +step:620/1670 train_time:54504ms step_avg:87.91ms +step:621/1670 train_time:54593ms step_avg:87.91ms +step:622/1670 train_time:54682ms step_avg:87.91ms +step:623/1670 train_time:54771ms step_avg:87.91ms +step:624/1670 train_time:54860ms step_avg:87.92ms +step:625/1670 train_time:54950ms step_avg:87.92ms +step:625/1670 val_loss:3.6189 
train_time:55041ms step_avg:88.07ms +step:626/1670 train_time:55060ms step_avg:87.96ms +step:627/1670 train_time:55131ms step_avg:87.93ms +step:628/1670 train_time:55220ms step_avg:87.93ms +step:629/1670 train_time:55310ms step_avg:87.93ms +step:630/1670 train_time:55398ms step_avg:87.93ms +step:631/1670 train_time:55486ms step_avg:87.93ms +step:632/1670 train_time:55574ms step_avg:87.93ms +step:633/1670 train_time:55661ms step_avg:87.93ms +step:634/1670 train_time:55749ms step_avg:87.93ms +step:635/1670 train_time:55838ms step_avg:87.93ms +step:636/1670 train_time:55926ms step_avg:87.93ms +step:637/1670 train_time:56019ms step_avg:87.94ms +step:638/1670 train_time:56108ms step_avg:87.94ms +step:639/1670 train_time:56199ms step_avg:87.95ms +step:640/1670 train_time:56288ms step_avg:87.95ms +step:641/1670 train_time:56377ms step_avg:87.95ms +step:642/1670 train_time:56465ms step_avg:87.95ms +step:643/1670 train_time:56552ms step_avg:87.95ms +step:644/1670 train_time:56640ms step_avg:87.95ms +step:645/1670 train_time:56728ms step_avg:87.95ms +step:646/1670 train_time:56817ms step_avg:87.95ms +step:647/1670 train_time:56906ms step_avg:87.95ms +step:648/1670 train_time:56996ms step_avg:87.96ms +step:649/1670 train_time:57086ms step_avg:87.96ms +step:650/1670 train_time:57176ms step_avg:87.96ms +step:651/1670 train_time:57265ms step_avg:87.96ms +step:652/1670 train_time:57353ms step_avg:87.96ms +step:653/1670 train_time:57442ms step_avg:87.97ms +step:654/1670 train_time:57530ms step_avg:87.97ms +step:655/1670 train_time:57618ms step_avg:87.97ms +step:656/1670 train_time:57706ms step_avg:87.97ms +step:657/1670 train_time:57795ms step_avg:87.97ms +step:658/1670 train_time:57883ms step_avg:87.97ms +step:659/1670 train_time:57973ms step_avg:87.97ms +step:660/1670 train_time:58063ms step_avg:87.97ms +step:661/1670 train_time:58152ms step_avg:87.98ms +step:662/1670 train_time:58241ms step_avg:87.98ms +step:663/1670 train_time:58330ms step_avg:87.98ms +step:664/1670 train_time:58419ms step_avg:87.98ms +step:665/1670 train_time:58507ms step_avg:87.98ms +step:666/1670 train_time:58596ms step_avg:87.98ms +step:667/1670 train_time:58685ms step_avg:87.98ms +step:668/1670 train_time:58773ms step_avg:87.98ms +step:669/1670 train_time:58863ms step_avg:87.99ms +step:670/1670 train_time:58952ms step_avg:87.99ms +step:671/1670 train_time:59042ms step_avg:87.99ms +step:672/1670 train_time:59130ms step_avg:87.99ms +step:673/1670 train_time:59220ms step_avg:87.99ms +step:674/1670 train_time:59309ms step_avg:87.99ms +step:675/1670 train_time:59398ms step_avg:88.00ms +step:676/1670 train_time:59486ms step_avg:88.00ms +step:677/1670 train_time:59575ms step_avg:88.00ms +step:678/1670 train_time:59664ms step_avg:88.00ms +step:679/1670 train_time:59753ms step_avg:88.00ms +step:680/1670 train_time:59842ms step_avg:88.00ms +step:681/1670 train_time:59930ms step_avg:88.00ms +step:682/1670 train_time:60020ms step_avg:88.01ms +step:683/1670 train_time:60109ms step_avg:88.01ms +step:684/1670 train_time:60199ms step_avg:88.01ms +step:685/1670 train_time:60287ms step_avg:88.01ms +step:686/1670 train_time:60376ms step_avg:88.01ms +step:687/1670 train_time:60465ms step_avg:88.01ms +step:688/1670 train_time:60554ms step_avg:88.01ms +step:689/1670 train_time:60643ms step_avg:88.02ms +step:690/1670 train_time:60731ms step_avg:88.02ms +step:691/1670 train_time:60820ms step_avg:88.02ms +step:692/1670 train_time:60908ms step_avg:88.02ms +step:693/1670 train_time:60999ms step_avg:88.02ms +step:694/1670 train_time:61087ms step_avg:88.02ms 
+step:695/1670 train_time:61176ms step_avg:88.02ms +step:696/1670 train_time:61266ms step_avg:88.03ms +step:697/1670 train_time:61355ms step_avg:88.03ms +step:698/1670 train_time:61445ms step_avg:88.03ms +step:699/1670 train_time:61534ms step_avg:88.03ms +step:700/1670 train_time:61624ms step_avg:88.03ms +step:701/1670 train_time:61713ms step_avg:88.04ms +step:702/1670 train_time:61802ms step_avg:88.04ms +step:703/1670 train_time:61890ms step_avg:88.04ms +step:704/1670 train_time:61979ms step_avg:88.04ms +step:705/1670 train_time:62068ms step_avg:88.04ms +step:706/1670 train_time:62157ms step_avg:88.04ms +step:707/1670 train_time:62246ms step_avg:88.04ms +step:708/1670 train_time:62334ms step_avg:88.04ms +step:709/1670 train_time:62423ms step_avg:88.04ms +step:710/1670 train_time:62512ms step_avg:88.04ms +step:711/1670 train_time:62602ms step_avg:88.05ms +step:712/1670 train_time:62690ms step_avg:88.05ms +step:713/1670 train_time:62779ms step_avg:88.05ms +step:714/1670 train_time:62867ms step_avg:88.05ms +step:715/1670 train_time:62956ms step_avg:88.05ms +step:716/1670 train_time:63045ms step_avg:88.05ms +step:717/1670 train_time:63134ms step_avg:88.05ms +step:718/1670 train_time:63224ms step_avg:88.06ms +step:719/1670 train_time:63313ms step_avg:88.06ms +step:720/1670 train_time:63402ms step_avg:88.06ms +step:721/1670 train_time:63490ms step_avg:88.06ms +step:722/1670 train_time:63579ms step_avg:88.06ms +step:723/1670 train_time:63668ms step_avg:88.06ms +step:724/1670 train_time:63757ms step_avg:88.06ms +step:725/1670 train_time:63845ms step_avg:88.06ms +step:726/1670 train_time:63934ms step_avg:88.06ms +step:727/1670 train_time:64023ms step_avg:88.06ms +step:728/1670 train_time:64111ms step_avg:88.06ms +step:729/1670 train_time:64200ms step_avg:88.07ms +step:730/1670 train_time:64289ms step_avg:88.07ms +step:731/1670 train_time:64377ms step_avg:88.07ms +step:732/1670 train_time:64466ms step_avg:88.07ms +step:733/1670 train_time:64556ms step_avg:88.07ms +step:734/1670 train_time:64645ms step_avg:88.07ms +step:735/1670 train_time:64734ms step_avg:88.07ms +step:736/1670 train_time:64823ms step_avg:88.08ms +step:737/1670 train_time:64912ms step_avg:88.08ms +step:738/1670 train_time:65001ms step_avg:88.08ms +step:739/1670 train_time:65089ms step_avg:88.08ms +step:740/1670 train_time:65178ms step_avg:88.08ms +step:741/1670 train_time:65267ms step_avg:88.08ms +step:742/1670 train_time:65357ms step_avg:88.08ms +step:743/1670 train_time:65445ms step_avg:88.08ms +step:744/1670 train_time:65534ms step_avg:88.08ms +step:745/1670 train_time:65624ms step_avg:88.09ms +step:746/1670 train_time:65713ms step_avg:88.09ms +step:747/1670 train_time:65802ms step_avg:88.09ms +step:748/1670 train_time:65890ms step_avg:88.09ms +step:749/1670 train_time:65979ms step_avg:88.09ms +step:750/1670 train_time:66068ms step_avg:88.09ms +step:750/1670 val_loss:3.5662 train_time:66159ms step_avg:88.21ms +step:751/1670 train_time:66178ms step_avg:88.12ms +step:752/1670 train_time:66250ms step_avg:88.10ms +step:753/1670 train_time:66346ms step_avg:88.11ms +step:754/1670 train_time:66436ms step_avg:88.11ms +step:755/1670 train_time:66524ms step_avg:88.11ms +step:756/1670 train_time:66612ms step_avg:88.11ms +step:757/1670 train_time:66700ms step_avg:88.11ms +step:758/1670 train_time:66788ms step_avg:88.11ms +step:759/1670 train_time:66876ms step_avg:88.11ms +step:760/1670 train_time:66965ms step_avg:88.11ms +step:761/1670 train_time:67054ms step_avg:88.11ms +step:762/1670 train_time:67143ms step_avg:88.11ms +step:763/1670 
train_time:67234ms step_avg:88.12ms +step:764/1670 train_time:67324ms step_avg:88.12ms +step:765/1670 train_time:67415ms step_avg:88.12ms +step:766/1670 train_time:67504ms step_avg:88.13ms +step:767/1670 train_time:67594ms step_avg:88.13ms +step:768/1670 train_time:67682ms step_avg:88.13ms +step:769/1670 train_time:67771ms step_avg:88.13ms +step:770/1670 train_time:67860ms step_avg:88.13ms +step:771/1670 train_time:67948ms step_avg:88.13ms +step:772/1670 train_time:68037ms step_avg:88.13ms +step:773/1670 train_time:68126ms step_avg:88.13ms +step:774/1670 train_time:68215ms step_avg:88.13ms +step:775/1670 train_time:68305ms step_avg:88.14ms +step:776/1670 train_time:68395ms step_avg:88.14ms +step:777/1670 train_time:68485ms step_avg:88.14ms +step:778/1670 train_time:68575ms step_avg:88.14ms +step:779/1670 train_time:68664ms step_avg:88.14ms +step:780/1670 train_time:68753ms step_avg:88.14ms +step:781/1670 train_time:68842ms step_avg:88.15ms +step:782/1670 train_time:68930ms step_avg:88.15ms +step:783/1670 train_time:69019ms step_avg:88.15ms +step:784/1670 train_time:69108ms step_avg:88.15ms +step:785/1670 train_time:69197ms step_avg:88.15ms +step:786/1670 train_time:69286ms step_avg:88.15ms +step:787/1670 train_time:69376ms step_avg:88.15ms +step:788/1670 train_time:69465ms step_avg:88.15ms +step:789/1670 train_time:69555ms step_avg:88.16ms +step:790/1670 train_time:69644ms step_avg:88.16ms +step:791/1670 train_time:69733ms step_avg:88.16ms +step:792/1670 train_time:69822ms step_avg:88.16ms +step:793/1670 train_time:69910ms step_avg:88.16ms +step:794/1670 train_time:70000ms step_avg:88.16ms +step:795/1670 train_time:70088ms step_avg:88.16ms +step:796/1670 train_time:70178ms step_avg:88.16ms +step:797/1670 train_time:70267ms step_avg:88.16ms +step:798/1670 train_time:70357ms step_avg:88.17ms +step:799/1670 train_time:70446ms step_avg:88.17ms +step:800/1670 train_time:70535ms step_avg:88.17ms +step:801/1670 train_time:70624ms step_avg:88.17ms +step:802/1670 train_time:70712ms step_avg:88.17ms +step:803/1670 train_time:70802ms step_avg:88.17ms +step:804/1670 train_time:70890ms step_avg:88.17ms +step:805/1670 train_time:70979ms step_avg:88.17ms +step:806/1670 train_time:71068ms step_avg:88.17ms +step:807/1670 train_time:71158ms step_avg:88.18ms +step:808/1670 train_time:71247ms step_avg:88.18ms +step:809/1670 train_time:71337ms step_avg:88.18ms +step:810/1670 train_time:71425ms step_avg:88.18ms +step:811/1670 train_time:71515ms step_avg:88.18ms +step:812/1670 train_time:71603ms step_avg:88.18ms +step:813/1670 train_time:71692ms step_avg:88.18ms +step:814/1670 train_time:71781ms step_avg:88.18ms +step:815/1670 train_time:71869ms step_avg:88.18ms +step:816/1670 train_time:71959ms step_avg:88.18ms +step:817/1670 train_time:72047ms step_avg:88.18ms +step:818/1670 train_time:72136ms step_avg:88.19ms +step:819/1670 train_time:72224ms step_avg:88.19ms +step:820/1670 train_time:72314ms step_avg:88.19ms +step:821/1670 train_time:72403ms step_avg:88.19ms +step:822/1670 train_time:72493ms step_avg:88.19ms +step:823/1670 train_time:72582ms step_avg:88.19ms +step:824/1670 train_time:72670ms step_avg:88.19ms +step:825/1670 train_time:72759ms step_avg:88.19ms +step:826/1670 train_time:72847ms step_avg:88.19ms +step:827/1670 train_time:72937ms step_avg:88.19ms +step:828/1670 train_time:73025ms step_avg:88.19ms +step:829/1670 train_time:73114ms step_avg:88.20ms +step:830/1670 train_time:73203ms step_avg:88.20ms +step:831/1670 train_time:73292ms step_avg:88.20ms +step:832/1670 train_time:73381ms step_avg:88.20ms 
+step:833/1670 train_time:73469ms step_avg:88.20ms +step:834/1670 train_time:73558ms step_avg:88.20ms +step:835/1670 train_time:73647ms step_avg:88.20ms +step:836/1670 train_time:73736ms step_avg:88.20ms +step:837/1670 train_time:73825ms step_avg:88.20ms +step:838/1670 train_time:73914ms step_avg:88.20ms +step:839/1670 train_time:74003ms step_avg:88.20ms +step:840/1670 train_time:74093ms step_avg:88.21ms +step:841/1670 train_time:74182ms step_avg:88.21ms +step:842/1670 train_time:74271ms step_avg:88.21ms +step:843/1670 train_time:74360ms step_avg:88.21ms +step:844/1670 train_time:74448ms step_avg:88.21ms +step:845/1670 train_time:74538ms step_avg:88.21ms +step:846/1670 train_time:74626ms step_avg:88.21ms +step:847/1670 train_time:74716ms step_avg:88.21ms +step:848/1670 train_time:74804ms step_avg:88.21ms +step:849/1670 train_time:74894ms step_avg:88.21ms +step:850/1670 train_time:74984ms step_avg:88.22ms +step:851/1670 train_time:75072ms step_avg:88.22ms +step:852/1670 train_time:75162ms step_avg:88.22ms +step:853/1670 train_time:75251ms step_avg:88.22ms +step:854/1670 train_time:75339ms step_avg:88.22ms +step:855/1670 train_time:75429ms step_avg:88.22ms +step:856/1670 train_time:75518ms step_avg:88.22ms +step:857/1670 train_time:75607ms step_avg:88.22ms +step:858/1670 train_time:75696ms step_avg:88.22ms +step:859/1670 train_time:75785ms step_avg:88.22ms +step:860/1670 train_time:75874ms step_avg:88.23ms +step:861/1670 train_time:75964ms step_avg:88.23ms +step:862/1670 train_time:76053ms step_avg:88.23ms +step:863/1670 train_time:76143ms step_avg:88.23ms +step:864/1670 train_time:76232ms step_avg:88.23ms +step:865/1670 train_time:76321ms step_avg:88.23ms +step:866/1670 train_time:76410ms step_avg:88.23ms +step:867/1670 train_time:76499ms step_avg:88.23ms +step:868/1670 train_time:76588ms step_avg:88.23ms +step:869/1670 train_time:76677ms step_avg:88.24ms +step:870/1670 train_time:76766ms step_avg:88.24ms +step:871/1670 train_time:76855ms step_avg:88.24ms +step:872/1670 train_time:76945ms step_avg:88.24ms +step:873/1670 train_time:77034ms step_avg:88.24ms +step:874/1670 train_time:77123ms step_avg:88.24ms +step:875/1670 train_time:77212ms step_avg:88.24ms +step:875/1670 val_loss:3.5193 train_time:77302ms step_avg:88.34ms +step:876/1670 train_time:77321ms step_avg:88.27ms +step:877/1670 train_time:77395ms step_avg:88.25ms +step:878/1670 train_time:77486ms step_avg:88.25ms +step:879/1670 train_time:77576ms step_avg:88.25ms +step:880/1670 train_time:77664ms step_avg:88.25ms +step:881/1670 train_time:77752ms step_avg:88.25ms +step:882/1670 train_time:77839ms step_avg:88.25ms +step:883/1670 train_time:77926ms step_avg:88.25ms +step:884/1670 train_time:78015ms step_avg:88.25ms +step:885/1670 train_time:78103ms step_avg:88.25ms +step:886/1670 train_time:78192ms step_avg:88.25ms +step:887/1670 train_time:78283ms step_avg:88.26ms +step:888/1670 train_time:78376ms step_avg:88.26ms +step:889/1670 train_time:78467ms step_avg:88.26ms +step:890/1670 train_time:78556ms step_avg:88.27ms +step:891/1670 train_time:78644ms step_avg:88.27ms +step:892/1670 train_time:78734ms step_avg:88.27ms +step:893/1670 train_time:78821ms step_avg:88.27ms +step:894/1670 train_time:78908ms step_avg:88.26ms +step:895/1670 train_time:78997ms step_avg:88.26ms +step:896/1670 train_time:79085ms step_avg:88.26ms +step:897/1670 train_time:79173ms step_avg:88.26ms +step:898/1670 train_time:79263ms step_avg:88.27ms +step:899/1670 train_time:79355ms step_avg:88.27ms +step:900/1670 train_time:79446ms step_avg:88.27ms +step:901/1670 
train_time:79537ms step_avg:88.28ms +step:902/1670 train_time:79626ms step_avg:88.28ms +step:903/1670 train_time:79715ms step_avg:88.28ms +step:904/1670 train_time:79803ms step_avg:88.28ms +step:905/1670 train_time:79893ms step_avg:88.28ms +step:906/1670 train_time:79981ms step_avg:88.28ms +step:907/1670 train_time:80069ms step_avg:88.28ms +step:908/1670 train_time:80157ms step_avg:88.28ms +step:909/1670 train_time:80246ms step_avg:88.28ms +step:910/1670 train_time:80336ms step_avg:88.28ms +step:911/1670 train_time:80426ms step_avg:88.28ms +step:912/1670 train_time:80517ms step_avg:88.29ms +step:913/1670 train_time:80607ms step_avg:88.29ms +step:914/1670 train_time:80697ms step_avg:88.29ms +step:915/1670 train_time:80786ms step_avg:88.29ms +step:916/1670 train_time:80874ms step_avg:88.29ms +step:917/1670 train_time:80963ms step_avg:88.29ms +step:918/1670 train_time:81051ms step_avg:88.29ms +step:919/1670 train_time:81139ms step_avg:88.29ms +step:920/1670 train_time:81228ms step_avg:88.29ms +step:921/1670 train_time:81318ms step_avg:88.29ms +step:922/1670 train_time:81408ms step_avg:88.30ms +step:923/1670 train_time:81497ms step_avg:88.30ms +step:924/1670 train_time:81587ms step_avg:88.30ms +step:925/1670 train_time:81677ms step_avg:88.30ms +step:926/1670 train_time:81766ms step_avg:88.30ms +step:927/1670 train_time:81856ms step_avg:88.30ms +step:928/1670 train_time:81945ms step_avg:88.30ms +step:929/1670 train_time:82034ms step_avg:88.30ms +step:930/1670 train_time:82121ms step_avg:88.30ms +step:931/1670 train_time:82211ms step_avg:88.30ms +step:932/1670 train_time:82300ms step_avg:88.30ms +step:933/1670 train_time:82389ms step_avg:88.31ms +step:934/1670 train_time:82479ms step_avg:88.31ms +step:935/1670 train_time:82568ms step_avg:88.31ms +step:936/1670 train_time:82657ms step_avg:88.31ms +step:937/1670 train_time:82747ms step_avg:88.31ms +step:938/1670 train_time:82836ms step_avg:88.31ms +step:939/1670 train_time:82926ms step_avg:88.31ms +step:940/1670 train_time:83015ms step_avg:88.31ms +step:941/1670 train_time:83103ms step_avg:88.31ms +step:942/1670 train_time:83192ms step_avg:88.31ms +step:943/1670 train_time:83280ms step_avg:88.31ms +step:944/1670 train_time:83370ms step_avg:88.32ms +step:945/1670 train_time:83459ms step_avg:88.32ms +step:946/1670 train_time:83547ms step_avg:88.32ms +step:947/1670 train_time:83637ms step_avg:88.32ms +step:948/1670 train_time:83726ms step_avg:88.32ms +step:949/1670 train_time:83816ms step_avg:88.32ms +step:950/1670 train_time:83905ms step_avg:88.32ms +step:951/1670 train_time:83996ms step_avg:88.32ms +step:952/1670 train_time:84085ms step_avg:88.32ms +step:953/1670 train_time:84175ms step_avg:88.33ms +step:954/1670 train_time:84263ms step_avg:88.33ms +step:955/1670 train_time:84352ms step_avg:88.33ms +step:956/1670 train_time:84441ms step_avg:88.33ms +step:957/1670 train_time:84530ms step_avg:88.33ms +step:958/1670 train_time:84619ms step_avg:88.33ms +step:959/1670 train_time:84708ms step_avg:88.33ms +step:960/1670 train_time:84797ms step_avg:88.33ms +step:961/1670 train_time:84886ms step_avg:88.33ms +step:962/1670 train_time:84975ms step_avg:88.33ms +step:963/1670 train_time:85063ms step_avg:88.33ms +step:964/1670 train_time:85153ms step_avg:88.33ms +step:965/1670 train_time:85241ms step_avg:88.33ms +step:966/1670 train_time:85330ms step_avg:88.33ms +step:967/1670 train_time:85418ms step_avg:88.33ms +step:968/1670 train_time:85507ms step_avg:88.33ms +step:969/1670 train_time:85597ms step_avg:88.34ms +step:970/1670 train_time:85686ms step_avg:88.34ms 
+step:971/1670 train_time:85775ms step_avg:88.34ms +step:972/1670 train_time:85864ms step_avg:88.34ms +step:973/1670 train_time:85953ms step_avg:88.34ms +step:974/1670 train_time:86041ms step_avg:88.34ms +step:975/1670 train_time:86131ms step_avg:88.34ms +step:976/1670 train_time:86220ms step_avg:88.34ms +step:977/1670 train_time:86309ms step_avg:88.34ms +step:978/1670 train_time:86397ms step_avg:88.34ms +step:979/1670 train_time:86486ms step_avg:88.34ms +step:980/1670 train_time:86575ms step_avg:88.34ms +step:981/1670 train_time:86664ms step_avg:88.34ms +step:982/1670 train_time:86754ms step_avg:88.34ms +step:983/1670 train_time:86842ms step_avg:88.34ms +step:984/1670 train_time:86932ms step_avg:88.35ms +step:985/1670 train_time:87020ms step_avg:88.35ms +step:986/1670 train_time:87109ms step_avg:88.35ms +step:987/1670 train_time:87198ms step_avg:88.35ms +step:988/1670 train_time:87287ms step_avg:88.35ms +step:989/1670 train_time:87376ms step_avg:88.35ms +step:990/1670 train_time:87466ms step_avg:88.35ms +step:991/1670 train_time:87556ms step_avg:88.35ms +step:992/1670 train_time:87644ms step_avg:88.35ms +step:993/1670 train_time:87733ms step_avg:88.35ms +step:994/1670 train_time:87821ms step_avg:88.35ms +step:995/1670 train_time:87910ms step_avg:88.35ms +step:996/1670 train_time:87999ms step_avg:88.35ms +step:997/1670 train_time:88089ms step_avg:88.35ms +step:998/1670 train_time:88178ms step_avg:88.35ms +step:999/1670 train_time:88267ms step_avg:88.36ms +step:1000/1670 train_time:88355ms step_avg:88.36ms +step:1000/1670 val_loss:3.4681 train_time:88446ms step_avg:88.45ms +step:1001/1670 train_time:88466ms step_avg:88.38ms +step:1002/1670 train_time:88539ms step_avg:88.36ms +step:1003/1670 train_time:88632ms step_avg:88.37ms +step:1004/1670 train_time:88723ms step_avg:88.37ms +step:1005/1670 train_time:88810ms step_avg:88.37ms +step:1006/1670 train_time:88899ms step_avg:88.37ms +step:1007/1670 train_time:88987ms step_avg:88.37ms +step:1008/1670 train_time:89074ms step_avg:88.37ms +step:1009/1670 train_time:89161ms step_avg:88.37ms +step:1010/1670 train_time:89249ms step_avg:88.37ms +step:1011/1670 train_time:89337ms step_avg:88.37ms +step:1012/1670 train_time:89428ms step_avg:88.37ms +step:1013/1670 train_time:89519ms step_avg:88.37ms +step:1014/1670 train_time:89609ms step_avg:88.37ms +step:1015/1670 train_time:89699ms step_avg:88.37ms +step:1016/1670 train_time:89788ms step_avg:88.37ms +step:1017/1670 train_time:89876ms step_avg:88.37ms +step:1018/1670 train_time:89965ms step_avg:88.37ms +step:1019/1670 train_time:90054ms step_avg:88.37ms +step:1020/1670 train_time:90142ms step_avg:88.37ms +step:1021/1670 train_time:90230ms step_avg:88.37ms +step:1022/1670 train_time:90317ms step_avg:88.37ms +step:1023/1670 train_time:90406ms step_avg:88.37ms +step:1024/1670 train_time:90496ms step_avg:88.38ms +step:1025/1670 train_time:90587ms step_avg:88.38ms +step:1026/1670 train_time:90676ms step_avg:88.38ms +step:1027/1670 train_time:90767ms step_avg:88.38ms +step:1028/1670 train_time:90856ms step_avg:88.38ms +step:1029/1670 train_time:90945ms step_avg:88.38ms +step:1030/1670 train_time:91033ms step_avg:88.38ms +step:1031/1670 train_time:91122ms step_avg:88.38ms +step:1032/1670 train_time:91210ms step_avg:88.38ms +step:1033/1670 train_time:91299ms step_avg:88.38ms +step:1034/1670 train_time:91388ms step_avg:88.38ms +step:1035/1670 train_time:91476ms step_avg:88.38ms +step:1036/1670 train_time:91567ms step_avg:88.39ms +step:1037/1670 train_time:91657ms step_avg:88.39ms +step:1038/1670 
train_time:91748ms step_avg:88.39ms +step:1039/1670 train_time:91837ms step_avg:88.39ms +step:1040/1670 train_time:91926ms step_avg:88.39ms +step:1041/1670 train_time:92014ms step_avg:88.39ms +step:1042/1670 train_time:92103ms step_avg:88.39ms +step:1043/1670 train_time:92191ms step_avg:88.39ms +step:1044/1670 train_time:92280ms step_avg:88.39ms +step:1045/1670 train_time:92369ms step_avg:88.39ms +step:1046/1670 train_time:92457ms step_avg:88.39ms +step:1047/1670 train_time:92546ms step_avg:88.39ms +step:1048/1670 train_time:92635ms step_avg:88.39ms +step:1049/1670 train_time:92725ms step_avg:88.39ms +step:1050/1670 train_time:92814ms step_avg:88.39ms +step:1051/1670 train_time:92904ms step_avg:88.40ms +step:1052/1670 train_time:92993ms step_avg:88.40ms +step:1053/1670 train_time:93081ms step_avg:88.40ms +step:1054/1670 train_time:93170ms step_avg:88.40ms +step:1055/1670 train_time:93258ms step_avg:88.40ms +step:1056/1670 train_time:93346ms step_avg:88.40ms +step:1057/1670 train_time:93435ms step_avg:88.40ms +step:1058/1670 train_time:93524ms step_avg:88.40ms +step:1059/1670 train_time:93614ms step_avg:88.40ms +step:1060/1670 train_time:93704ms step_avg:88.40ms +step:1061/1670 train_time:93793ms step_avg:88.40ms +step:1062/1670 train_time:93883ms step_avg:88.40ms +step:1063/1670 train_time:93972ms step_avg:88.40ms +step:1064/1670 train_time:94061ms step_avg:88.40ms +step:1065/1670 train_time:94150ms step_avg:88.40ms +step:1066/1670 train_time:94238ms step_avg:88.40ms +step:1067/1670 train_time:94327ms step_avg:88.40ms +step:1068/1670 train_time:94415ms step_avg:88.40ms +step:1069/1670 train_time:94505ms step_avg:88.40ms +step:1070/1670 train_time:94594ms step_avg:88.41ms +step:1071/1670 train_time:94683ms step_avg:88.41ms +step:1072/1670 train_time:94773ms step_avg:88.41ms +step:1073/1670 train_time:94863ms step_avg:88.41ms +step:1074/1670 train_time:94952ms step_avg:88.41ms +step:1075/1670 train_time:95040ms step_avg:88.41ms +step:1076/1670 train_time:95129ms step_avg:88.41ms +step:1077/1670 train_time:95218ms step_avg:88.41ms +step:1078/1670 train_time:95306ms step_avg:88.41ms +step:1079/1670 train_time:95395ms step_avg:88.41ms +step:1080/1670 train_time:95483ms step_avg:88.41ms +step:1081/1670 train_time:95574ms step_avg:88.41ms +step:1082/1670 train_time:95663ms step_avg:88.41ms +step:1083/1670 train_time:95752ms step_avg:88.41ms +step:1084/1670 train_time:95842ms step_avg:88.41ms +step:1085/1670 train_time:95931ms step_avg:88.42ms +step:1086/1670 train_time:96019ms step_avg:88.42ms +step:1087/1670 train_time:96109ms step_avg:88.42ms +step:1088/1670 train_time:96197ms step_avg:88.42ms +step:1089/1670 train_time:96285ms step_avg:88.42ms +step:1090/1670 train_time:96375ms step_avg:88.42ms +step:1091/1670 train_time:96465ms step_avg:88.42ms +step:1092/1670 train_time:96554ms step_avg:88.42ms +step:1093/1670 train_time:96644ms step_avg:88.42ms +step:1094/1670 train_time:96734ms step_avg:88.42ms +step:1095/1670 train_time:96824ms step_avg:88.42ms +step:1096/1670 train_time:96914ms step_avg:88.42ms +step:1097/1670 train_time:97003ms step_avg:88.43ms +step:1098/1670 train_time:97093ms step_avg:88.43ms +step:1099/1670 train_time:97182ms step_avg:88.43ms +step:1100/1670 train_time:97273ms step_avg:88.43ms +step:1101/1670 train_time:97363ms step_avg:88.43ms +step:1102/1670 train_time:97452ms step_avg:88.43ms +step:1103/1670 train_time:97543ms step_avg:88.43ms +step:1104/1670 train_time:97632ms step_avg:88.43ms +step:1105/1670 train_time:97722ms step_avg:88.44ms +step:1106/1670 train_time:97813ms 
step_avg:88.44ms +step:1107/1670 train_time:97903ms step_avg:88.44ms +step:1108/1670 train_time:97992ms step_avg:88.44ms +step:1109/1670 train_time:98082ms step_avg:88.44ms +step:1110/1670 train_time:98172ms step_avg:88.44ms +step:1111/1670 train_time:98262ms step_avg:88.44ms +step:1112/1670 train_time:98352ms step_avg:88.45ms +step:1113/1670 train_time:98442ms step_avg:88.45ms +step:1114/1670 train_time:98531ms step_avg:88.45ms +step:1115/1670 train_time:98621ms step_avg:88.45ms +step:1116/1670 train_time:98711ms step_avg:88.45ms +step:1117/1670 train_time:98801ms step_avg:88.45ms +step:1118/1670 train_time:98890ms step_avg:88.45ms +step:1119/1670 train_time:98980ms step_avg:88.45ms +step:1120/1670 train_time:99070ms step_avg:88.46ms +step:1121/1670 train_time:99159ms step_avg:88.46ms +step:1122/1670 train_time:99249ms step_avg:88.46ms +step:1123/1670 train_time:99338ms step_avg:88.46ms +step:1124/1670 train_time:99427ms step_avg:88.46ms +step:1125/1670 train_time:99516ms step_avg:88.46ms +step:1125/1670 val_loss:3.4149 train_time:99608ms step_avg:88.54ms +step:1126/1670 train_time:99627ms step_avg:88.48ms +step:1127/1670 train_time:99698ms step_avg:88.46ms +step:1128/1670 train_time:99789ms step_avg:88.47ms +step:1129/1670 train_time:99881ms step_avg:88.47ms +step:1130/1670 train_time:99970ms step_avg:88.47ms +step:1131/1670 train_time:100058ms step_avg:88.47ms +step:1132/1670 train_time:100147ms step_avg:88.47ms +step:1133/1670 train_time:100236ms step_avg:88.47ms +step:1134/1670 train_time:100324ms step_avg:88.47ms +step:1135/1670 train_time:100416ms step_avg:88.47ms +step:1136/1670 train_time:100506ms step_avg:88.47ms +step:1137/1670 train_time:100600ms step_avg:88.48ms +step:1138/1670 train_time:100692ms step_avg:88.48ms +step:1139/1670 train_time:100782ms step_avg:88.48ms +step:1140/1670 train_time:100873ms step_avg:88.48ms +step:1141/1670 train_time:100962ms step_avg:88.49ms +step:1142/1670 train_time:101050ms step_avg:88.49ms +step:1143/1670 train_time:101139ms step_avg:88.49ms +step:1144/1670 train_time:101227ms step_avg:88.49ms +step:1145/1670 train_time:101317ms step_avg:88.49ms +step:1146/1670 train_time:101407ms step_avg:88.49ms +step:1147/1670 train_time:101497ms step_avg:88.49ms +step:1148/1670 train_time:101587ms step_avg:88.49ms +step:1149/1670 train_time:101679ms step_avg:88.49ms +step:1150/1670 train_time:101770ms step_avg:88.50ms +step:1151/1670 train_time:101859ms step_avg:88.50ms +step:1152/1670 train_time:101949ms step_avg:88.50ms +step:1153/1670 train_time:102038ms step_avg:88.50ms +step:1154/1670 train_time:102127ms step_avg:88.50ms +step:1155/1670 train_time:102216ms step_avg:88.50ms +step:1156/1670 train_time:102305ms step_avg:88.50ms +step:1157/1670 train_time:102394ms step_avg:88.50ms +step:1158/1670 train_time:102484ms step_avg:88.50ms +step:1159/1670 train_time:102575ms step_avg:88.50ms +step:1160/1670 train_time:102665ms step_avg:88.50ms +step:1161/1670 train_time:102755ms step_avg:88.51ms +step:1162/1670 train_time:102844ms step_avg:88.51ms +step:1163/1670 train_time:102934ms step_avg:88.51ms +step:1164/1670 train_time:103023ms step_avg:88.51ms +step:1165/1670 train_time:103112ms step_avg:88.51ms +step:1166/1670 train_time:103201ms step_avg:88.51ms +step:1167/1670 train_time:103290ms step_avg:88.51ms +step:1168/1670 train_time:103379ms step_avg:88.51ms +step:1169/1670 train_time:103468ms step_avg:88.51ms +step:1170/1670 train_time:103558ms step_avg:88.51ms +step:1171/1670 train_time:103648ms step_avg:88.51ms +step:1172/1670 train_time:103738ms 
step_avg:88.51ms +step:1173/1670 train_time:103828ms step_avg:88.52ms +step:1174/1670 train_time:103917ms step_avg:88.52ms +step:1175/1670 train_time:104007ms step_avg:88.52ms +step:1176/1670 train_time:104097ms step_avg:88.52ms +step:1177/1670 train_time:104187ms step_avg:88.52ms +step:1178/1670 train_time:104276ms step_avg:88.52ms +step:1179/1670 train_time:104366ms step_avg:88.52ms +step:1180/1670 train_time:104455ms step_avg:88.52ms +step:1181/1670 train_time:104545ms step_avg:88.52ms +step:1182/1670 train_time:104634ms step_avg:88.52ms +step:1183/1670 train_time:104724ms step_avg:88.52ms +step:1184/1670 train_time:104815ms step_avg:88.53ms +step:1185/1670 train_time:104905ms step_avg:88.53ms +step:1186/1670 train_time:104995ms step_avg:88.53ms +step:1187/1670 train_time:105085ms step_avg:88.53ms +step:1188/1670 train_time:105175ms step_avg:88.53ms +step:1189/1670 train_time:105264ms step_avg:88.53ms +step:1190/1670 train_time:105354ms step_avg:88.53ms +step:1191/1670 train_time:105443ms step_avg:88.53ms +step:1192/1670 train_time:105533ms step_avg:88.53ms +step:1193/1670 train_time:105622ms step_avg:88.54ms +step:1194/1670 train_time:105714ms step_avg:88.54ms +step:1195/1670 train_time:105803ms step_avg:88.54ms +step:1196/1670 train_time:105894ms step_avg:88.54ms +step:1197/1670 train_time:105983ms step_avg:88.54ms +step:1198/1670 train_time:106073ms step_avg:88.54ms +step:1199/1670 train_time:106162ms step_avg:88.54ms +step:1200/1670 train_time:106252ms step_avg:88.54ms +step:1201/1670 train_time:106341ms step_avg:88.54ms +step:1202/1670 train_time:106431ms step_avg:88.54ms +step:1203/1670 train_time:106520ms step_avg:88.55ms +step:1204/1670 train_time:106610ms step_avg:88.55ms +step:1205/1670 train_time:106699ms step_avg:88.55ms +step:1206/1670 train_time:106790ms step_avg:88.55ms +step:1207/1670 train_time:106879ms step_avg:88.55ms +step:1208/1670 train_time:106970ms step_avg:88.55ms +step:1209/1670 train_time:107059ms step_avg:88.55ms +step:1210/1670 train_time:107149ms step_avg:88.55ms +step:1211/1670 train_time:107238ms step_avg:88.55ms +step:1212/1670 train_time:107328ms step_avg:88.55ms +step:1213/1670 train_time:107418ms step_avg:88.56ms +step:1214/1670 train_time:107507ms step_avg:88.56ms +step:1215/1670 train_time:107597ms step_avg:88.56ms +step:1216/1670 train_time:107688ms step_avg:88.56ms +step:1217/1670 train_time:107777ms step_avg:88.56ms +step:1218/1670 train_time:107867ms step_avg:88.56ms +step:1219/1670 train_time:107957ms step_avg:88.56ms +step:1220/1670 train_time:108047ms step_avg:88.56ms +step:1221/1670 train_time:108137ms step_avg:88.56ms +step:1222/1670 train_time:108226ms step_avg:88.56ms +step:1223/1670 train_time:108316ms step_avg:88.57ms +step:1224/1670 train_time:108406ms step_avg:88.57ms +step:1225/1670 train_time:108496ms step_avg:88.57ms +step:1226/1670 train_time:108586ms step_avg:88.57ms +step:1227/1670 train_time:108676ms step_avg:88.57ms +step:1228/1670 train_time:108766ms step_avg:88.57ms +step:1229/1670 train_time:108856ms step_avg:88.57ms +step:1230/1670 train_time:108945ms step_avg:88.57ms +step:1231/1670 train_time:109035ms step_avg:88.57ms +step:1232/1670 train_time:109124ms step_avg:88.57ms +step:1233/1670 train_time:109214ms step_avg:88.58ms +step:1234/1670 train_time:109304ms step_avg:88.58ms +step:1235/1670 train_time:109393ms step_avg:88.58ms +step:1236/1670 train_time:109483ms step_avg:88.58ms +step:1237/1670 train_time:109573ms step_avg:88.58ms +step:1238/1670 train_time:109662ms step_avg:88.58ms +step:1239/1670 train_time:109752ms 
step_avg:88.58ms +step:1240/1670 train_time:109841ms step_avg:88.58ms +step:1241/1670 train_time:109931ms step_avg:88.58ms +step:1242/1670 train_time:110020ms step_avg:88.58ms +step:1243/1670 train_time:110110ms step_avg:88.58ms +step:1244/1670 train_time:110200ms step_avg:88.58ms +step:1245/1670 train_time:110289ms step_avg:88.59ms +step:1246/1670 train_time:110379ms step_avg:88.59ms +step:1247/1670 train_time:110468ms step_avg:88.59ms +step:1248/1670 train_time:110557ms step_avg:88.59ms +step:1249/1670 train_time:110647ms step_avg:88.59ms +step:1250/1670 train_time:110737ms step_avg:88.59ms +step:1250/1670 val_loss:3.3763 train_time:110829ms step_avg:88.66ms +step:1251/1670 train_time:110849ms step_avg:88.61ms +step:1252/1670 train_time:110923ms step_avg:88.60ms +step:1253/1670 train_time:111014ms step_avg:88.60ms +step:1254/1670 train_time:111103ms step_avg:88.60ms +step:1255/1670 train_time:111191ms step_avg:88.60ms +step:1256/1670 train_time:111281ms step_avg:88.60ms +step:1257/1670 train_time:111369ms step_avg:88.60ms +step:1258/1670 train_time:111458ms step_avg:88.60ms +step:1259/1670 train_time:111547ms step_avg:88.60ms +step:1260/1670 train_time:111635ms step_avg:88.60ms +step:1261/1670 train_time:111724ms step_avg:88.60ms +step:1262/1670 train_time:111816ms step_avg:88.60ms +step:1263/1670 train_time:111910ms step_avg:88.61ms +step:1264/1670 train_time:112003ms step_avg:88.61ms +step:1265/1670 train_time:112092ms step_avg:88.61ms +step:1266/1670 train_time:112182ms step_avg:88.61ms +step:1267/1670 train_time:112271ms step_avg:88.61ms +step:1268/1670 train_time:112360ms step_avg:88.61ms +step:1269/1670 train_time:112449ms step_avg:88.61ms +step:1270/1670 train_time:112538ms step_avg:88.61ms +step:1271/1670 train_time:112627ms step_avg:88.61ms +step:1272/1670 train_time:112717ms step_avg:88.61ms +step:1273/1670 train_time:112808ms step_avg:88.62ms +step:1274/1670 train_time:112900ms step_avg:88.62ms +step:1275/1670 train_time:112990ms step_avg:88.62ms +step:1276/1670 train_time:113080ms step_avg:88.62ms +step:1277/1670 train_time:113170ms step_avg:88.62ms +step:1278/1670 train_time:113259ms step_avg:88.62ms +step:1279/1670 train_time:113348ms step_avg:88.62ms +step:1280/1670 train_time:113437ms step_avg:88.62ms +step:1281/1670 train_time:113527ms step_avg:88.62ms +step:1282/1670 train_time:113617ms step_avg:88.63ms +step:1283/1670 train_time:113707ms step_avg:88.63ms +step:1284/1670 train_time:113797ms step_avg:88.63ms +step:1285/1670 train_time:113888ms step_avg:88.63ms +step:1286/1670 train_time:113978ms step_avg:88.63ms +step:1287/1670 train_time:114068ms step_avg:88.63ms +step:1288/1670 train_time:114158ms step_avg:88.63ms +step:1289/1670 train_time:114248ms step_avg:88.63ms +step:1290/1670 train_time:114337ms step_avg:88.63ms +step:1291/1670 train_time:114427ms step_avg:88.63ms +step:1292/1670 train_time:114516ms step_avg:88.63ms +step:1293/1670 train_time:114607ms step_avg:88.64ms +step:1294/1670 train_time:114695ms step_avg:88.64ms +step:1295/1670 train_time:114785ms step_avg:88.64ms +step:1296/1670 train_time:114875ms step_avg:88.64ms +step:1297/1670 train_time:114965ms step_avg:88.64ms +step:1298/1670 train_time:115054ms step_avg:88.64ms +step:1299/1670 train_time:115144ms step_avg:88.64ms +step:1300/1670 train_time:115234ms step_avg:88.64ms +step:1301/1670 train_time:115324ms step_avg:88.64ms +step:1302/1670 train_time:115413ms step_avg:88.64ms +step:1303/1670 train_time:115503ms step_avg:88.64ms +step:1304/1670 train_time:115593ms step_avg:88.64ms +step:1305/1670 
train_time:115681ms step_avg:88.64ms +step:1306/1670 train_time:115771ms step_avg:88.65ms +step:1307/1670 train_time:115861ms step_avg:88.65ms +step:1308/1670 train_time:115951ms step_avg:88.65ms +step:1309/1670 train_time:116043ms step_avg:88.65ms +step:1310/1670 train_time:116133ms step_avg:88.65ms +step:1311/1670 train_time:116223ms step_avg:88.65ms +step:1312/1670 train_time:116313ms step_avg:88.65ms +step:1313/1670 train_time:116402ms step_avg:88.65ms +step:1314/1670 train_time:116492ms step_avg:88.65ms +step:1315/1670 train_time:116582ms step_avg:88.66ms +step:1316/1670 train_time:116671ms step_avg:88.66ms +step:1317/1670 train_time:116760ms step_avg:88.66ms +step:1318/1670 train_time:116850ms step_avg:88.66ms +step:1319/1670 train_time:116940ms step_avg:88.66ms +step:1320/1670 train_time:117032ms step_avg:88.66ms +step:1321/1670 train_time:117123ms step_avg:88.66ms +step:1322/1670 train_time:117213ms step_avg:88.66ms +step:1323/1670 train_time:117304ms step_avg:88.66ms +step:1324/1670 train_time:117392ms step_avg:88.67ms +step:1325/1670 train_time:117482ms step_avg:88.67ms +step:1326/1670 train_time:117572ms step_avg:88.67ms +step:1327/1670 train_time:117661ms step_avg:88.67ms +step:1328/1670 train_time:117751ms step_avg:88.67ms +step:1329/1670 train_time:117840ms step_avg:88.67ms +step:1330/1670 train_time:117930ms step_avg:88.67ms +step:1331/1670 train_time:118021ms step_avg:88.67ms +step:1332/1670 train_time:118112ms step_avg:88.67ms +step:1333/1670 train_time:118203ms step_avg:88.67ms +step:1334/1670 train_time:118292ms step_avg:88.67ms +step:1335/1670 train_time:118383ms step_avg:88.68ms +step:1336/1670 train_time:118472ms step_avg:88.68ms +step:1337/1670 train_time:118561ms step_avg:88.68ms +step:1338/1670 train_time:118651ms step_avg:88.68ms +step:1339/1670 train_time:118741ms step_avg:88.68ms +step:1340/1670 train_time:118830ms step_avg:88.68ms +step:1341/1670 train_time:118920ms step_avg:88.68ms +step:1342/1670 train_time:119011ms step_avg:88.68ms +step:1343/1670 train_time:119100ms step_avg:88.68ms +step:1344/1670 train_time:119191ms step_avg:88.68ms +step:1345/1670 train_time:119281ms step_avg:88.68ms +step:1346/1670 train_time:119370ms step_avg:88.68ms +step:1347/1670 train_time:119459ms step_avg:88.68ms +step:1348/1670 train_time:119548ms step_avg:88.69ms +step:1349/1670 train_time:119638ms step_avg:88.69ms +step:1350/1670 train_time:119729ms step_avg:88.69ms +step:1351/1670 train_time:119819ms step_avg:88.69ms +step:1352/1670 train_time:119910ms step_avg:88.69ms +step:1353/1670 train_time:120000ms step_avg:88.69ms +step:1354/1670 train_time:120090ms step_avg:88.69ms +step:1355/1670 train_time:120181ms step_avg:88.69ms +step:1356/1670 train_time:120271ms step_avg:88.70ms +step:1357/1670 train_time:120360ms step_avg:88.70ms +step:1358/1670 train_time:120450ms step_avg:88.70ms +step:1359/1670 train_time:120539ms step_avg:88.70ms +step:1360/1670 train_time:120629ms step_avg:88.70ms +step:1361/1670 train_time:120718ms step_avg:88.70ms +step:1362/1670 train_time:120808ms step_avg:88.70ms +step:1363/1670 train_time:120898ms step_avg:88.70ms +step:1364/1670 train_time:120987ms step_avg:88.70ms +step:1365/1670 train_time:121077ms step_avg:88.70ms +step:1366/1670 train_time:121167ms step_avg:88.70ms +step:1367/1670 train_time:121256ms step_avg:88.70ms +step:1368/1670 train_time:121347ms step_avg:88.70ms +step:1369/1670 train_time:121436ms step_avg:88.70ms +step:1370/1670 train_time:121526ms step_avg:88.71ms +step:1371/1670 train_time:121615ms step_avg:88.71ms +step:1372/1670 
train_time:121705ms step_avg:88.71ms +step:1373/1670 train_time:121794ms step_avg:88.71ms +step:1374/1670 train_time:121884ms step_avg:88.71ms +step:1375/1670 train_time:121973ms step_avg:88.71ms +step:1375/1670 val_loss:3.3415 train_time:122065ms step_avg:88.77ms +step:1376/1670 train_time:122085ms step_avg:88.72ms +step:1377/1670 train_time:122158ms step_avg:88.71ms +step:1378/1670 train_time:122252ms step_avg:88.72ms +step:1379/1670 train_time:122341ms step_avg:88.72ms +step:1380/1670 train_time:122430ms step_avg:88.72ms +step:1381/1670 train_time:122518ms step_avg:88.72ms +step:1382/1670 train_time:122606ms step_avg:88.72ms +step:1383/1670 train_time:122695ms step_avg:88.72ms +step:1384/1670 train_time:122783ms step_avg:88.72ms +step:1385/1670 train_time:122872ms step_avg:88.72ms +step:1386/1670 train_time:122961ms step_avg:88.72ms +step:1387/1670 train_time:123053ms step_avg:88.72ms +step:1388/1670 train_time:123145ms step_avg:88.72ms +step:1389/1670 train_time:123237ms step_avg:88.72ms +step:1390/1670 train_time:123328ms step_avg:88.73ms +step:1391/1670 train_time:123417ms step_avg:88.73ms +step:1392/1670 train_time:123506ms step_avg:88.73ms +step:1393/1670 train_time:123595ms step_avg:88.73ms +step:1394/1670 train_time:123684ms step_avg:88.73ms +step:1395/1670 train_time:123774ms step_avg:88.73ms +step:1396/1670 train_time:123863ms step_avg:88.73ms +step:1397/1670 train_time:123954ms step_avg:88.73ms +step:1398/1670 train_time:124044ms step_avg:88.73ms +step:1399/1670 train_time:124136ms step_avg:88.73ms +step:1400/1670 train_time:124228ms step_avg:88.73ms +step:1401/1670 train_time:124318ms step_avg:88.74ms +step:1402/1670 train_time:124408ms step_avg:88.74ms +step:1403/1670 train_time:124497ms step_avg:88.74ms +step:1404/1670 train_time:124586ms step_avg:88.74ms +step:1405/1670 train_time:124675ms step_avg:88.74ms +step:1406/1670 train_time:124764ms step_avg:88.74ms +step:1407/1670 train_time:124854ms step_avg:88.74ms +step:1408/1670 train_time:124944ms step_avg:88.74ms +step:1409/1670 train_time:125035ms step_avg:88.74ms +step:1410/1670 train_time:125127ms step_avg:88.74ms +step:1411/1670 train_time:125218ms step_avg:88.74ms +step:1412/1670 train_time:125308ms step_avg:88.75ms +step:1413/1670 train_time:125399ms step_avg:88.75ms +step:1414/1670 train_time:125488ms step_avg:88.75ms +step:1415/1670 train_time:125576ms step_avg:88.75ms +step:1416/1670 train_time:125666ms step_avg:88.75ms +step:1417/1670 train_time:125757ms step_avg:88.75ms +step:1418/1670 train_time:125846ms step_avg:88.75ms +step:1419/1670 train_time:125936ms step_avg:88.75ms +step:1420/1670 train_time:126026ms step_avg:88.75ms +step:1421/1670 train_time:126117ms step_avg:88.75ms +step:1422/1670 train_time:126207ms step_avg:88.75ms +step:1423/1670 train_time:126297ms step_avg:88.75ms +step:1424/1670 train_time:126387ms step_avg:88.75ms +step:1425/1670 train_time:126477ms step_avg:88.76ms +step:1426/1670 train_time:126567ms step_avg:88.76ms +step:1427/1670 train_time:126656ms step_avg:88.76ms +step:1428/1670 train_time:126746ms step_avg:88.76ms +step:1429/1670 train_time:126836ms step_avg:88.76ms +step:1430/1670 train_time:126926ms step_avg:88.76ms +step:1431/1670 train_time:127016ms step_avg:88.76ms +step:1432/1670 train_time:127107ms step_avg:88.76ms +step:1433/1670 train_time:127197ms step_avg:88.76ms +step:1434/1670 train_time:127286ms step_avg:88.76ms +step:1435/1670 train_time:127376ms step_avg:88.76ms +step:1436/1670 train_time:127466ms step_avg:88.76ms +step:1437/1670 train_time:127556ms step_avg:88.77ms 
+step:1438/1670 train_time:127646ms step_avg:88.77ms +step:1439/1670 train_time:127736ms step_avg:88.77ms +step:1440/1670 train_time:127825ms step_avg:88.77ms +step:1441/1670 train_time:127915ms step_avg:88.77ms +step:1442/1670 train_time:128005ms step_avg:88.77ms +step:1443/1670 train_time:128095ms step_avg:88.77ms +step:1444/1670 train_time:128184ms step_avg:88.77ms +step:1445/1670 train_time:128275ms step_avg:88.77ms +step:1446/1670 train_time:128365ms step_avg:88.77ms +step:1447/1670 train_time:128455ms step_avg:88.77ms +step:1448/1670 train_time:128545ms step_avg:88.77ms +step:1449/1670 train_time:128636ms step_avg:88.78ms +step:1450/1670 train_time:128725ms step_avg:88.78ms +step:1451/1670 train_time:128814ms step_avg:88.78ms +step:1452/1670 train_time:128903ms step_avg:88.78ms +step:1453/1670 train_time:128993ms step_avg:88.78ms +step:1454/1670 train_time:129082ms step_avg:88.78ms +step:1455/1670 train_time:129173ms step_avg:88.78ms +step:1456/1670 train_time:129263ms step_avg:88.78ms +step:1457/1670 train_time:129354ms step_avg:88.78ms +step:1458/1670 train_time:129443ms step_avg:88.78ms +step:1459/1670 train_time:129534ms step_avg:88.78ms +step:1460/1670 train_time:129625ms step_avg:88.78ms +step:1461/1670 train_time:129714ms step_avg:88.78ms +step:1462/1670 train_time:129803ms step_avg:88.78ms +step:1463/1670 train_time:129892ms step_avg:88.79ms +step:1464/1670 train_time:129982ms step_avg:88.79ms +step:1465/1670 train_time:130072ms step_avg:88.79ms +step:1466/1670 train_time:130161ms step_avg:88.79ms +step:1467/1670 train_time:130250ms step_avg:88.79ms +step:1468/1670 train_time:130340ms step_avg:88.79ms +step:1469/1670 train_time:130431ms step_avg:88.79ms +step:1470/1670 train_time:130521ms step_avg:88.79ms +step:1471/1670 train_time:130611ms step_avg:88.79ms +step:1472/1670 train_time:130700ms step_avg:88.79ms +step:1473/1670 train_time:130789ms step_avg:88.79ms +step:1474/1670 train_time:130879ms step_avg:88.79ms +step:1475/1670 train_time:130969ms step_avg:88.79ms +step:1476/1670 train_time:131058ms step_avg:88.79ms +step:1477/1670 train_time:131148ms step_avg:88.79ms +step:1478/1670 train_time:131238ms step_avg:88.79ms +step:1479/1670 train_time:131328ms step_avg:88.80ms +step:1480/1670 train_time:131418ms step_avg:88.80ms +step:1481/1670 train_time:131508ms step_avg:88.80ms +step:1482/1670 train_time:131598ms step_avg:88.80ms +step:1483/1670 train_time:131688ms step_avg:88.80ms +step:1484/1670 train_time:131778ms step_avg:88.80ms +step:1485/1670 train_time:131868ms step_avg:88.80ms +step:1486/1670 train_time:131958ms step_avg:88.80ms +step:1487/1670 train_time:132047ms step_avg:88.80ms +step:1488/1670 train_time:132138ms step_avg:88.80ms +step:1489/1670 train_time:132228ms step_avg:88.80ms +step:1490/1670 train_time:132318ms step_avg:88.80ms +step:1491/1670 train_time:132408ms step_avg:88.80ms +step:1492/1670 train_time:132497ms step_avg:88.81ms +step:1493/1670 train_time:132587ms step_avg:88.81ms +step:1494/1670 train_time:132677ms step_avg:88.81ms +step:1495/1670 train_time:132767ms step_avg:88.81ms +step:1496/1670 train_time:132857ms step_avg:88.81ms +step:1497/1670 train_time:132947ms step_avg:88.81ms +step:1498/1670 train_time:133037ms step_avg:88.81ms +step:1499/1670 train_time:133127ms step_avg:88.81ms +step:1500/1670 train_time:133216ms step_avg:88.81ms +step:1500/1670 val_loss:3.3119 train_time:133308ms step_avg:88.87ms +step:1501/1670 train_time:133327ms step_avg:88.83ms +step:1502/1670 train_time:133402ms step_avg:88.82ms +step:1503/1670 train_time:133499ms 
step_avg:88.82ms +step:1504/1670 train_time:133589ms step_avg:88.82ms +step:1505/1670 train_time:133679ms step_avg:88.82ms +step:1506/1670 train_time:133768ms step_avg:88.82ms +step:1507/1670 train_time:133855ms step_avg:88.82ms +step:1508/1670 train_time:133945ms step_avg:88.82ms +step:1509/1670 train_time:134033ms step_avg:88.82ms +step:1510/1670 train_time:134123ms step_avg:88.82ms +step:1511/1670 train_time:134212ms step_avg:88.82ms +step:1512/1670 train_time:134303ms step_avg:88.82ms +step:1513/1670 train_time:134394ms step_avg:88.83ms +step:1514/1670 train_time:134486ms step_avg:88.83ms +step:1515/1670 train_time:134578ms step_avg:88.83ms +step:1516/1670 train_time:134668ms step_avg:88.83ms +step:1517/1670 train_time:134758ms step_avg:88.83ms +step:1518/1670 train_time:134847ms step_avg:88.83ms +step:1519/1670 train_time:134936ms step_avg:88.83ms +step:1520/1670 train_time:135025ms step_avg:88.83ms +step:1521/1670 train_time:135113ms step_avg:88.83ms +step:1522/1670 train_time:135204ms step_avg:88.83ms +step:1523/1670 train_time:135294ms step_avg:88.83ms +step:1524/1670 train_time:135385ms step_avg:88.84ms +step:1525/1670 train_time:135474ms step_avg:88.84ms +step:1526/1670 train_time:135565ms step_avg:88.84ms +step:1527/1670 train_time:135655ms step_avg:88.84ms +step:1528/1670 train_time:135745ms step_avg:88.84ms +step:1529/1670 train_time:135834ms step_avg:88.84ms +step:1530/1670 train_time:135924ms step_avg:88.84ms +step:1531/1670 train_time:136012ms step_avg:88.84ms +step:1532/1670 train_time:136102ms step_avg:88.84ms +step:1533/1670 train_time:136191ms step_avg:88.84ms +step:1534/1670 train_time:136282ms step_avg:88.84ms +step:1535/1670 train_time:136372ms step_avg:88.84ms +step:1536/1670 train_time:136463ms step_avg:88.84ms +step:1537/1670 train_time:136553ms step_avg:88.84ms +step:1538/1670 train_time:136643ms step_avg:88.84ms +step:1539/1670 train_time:136732ms step_avg:88.84ms +step:1540/1670 train_time:136822ms step_avg:88.85ms +step:1541/1670 train_time:136911ms step_avg:88.85ms +step:1542/1670 train_time:137002ms step_avg:88.85ms +step:1543/1670 train_time:137090ms step_avg:88.85ms +step:1544/1670 train_time:137180ms step_avg:88.85ms +step:1545/1670 train_time:137271ms step_avg:88.85ms +step:1546/1670 train_time:137362ms step_avg:88.85ms +step:1547/1670 train_time:137451ms step_avg:88.85ms +step:1548/1670 train_time:137541ms step_avg:88.85ms +step:1549/1670 train_time:137630ms step_avg:88.85ms +step:1550/1670 train_time:137720ms step_avg:88.85ms +step:1551/1670 train_time:137810ms step_avg:88.85ms +step:1552/1670 train_time:137900ms step_avg:88.85ms +step:1553/1670 train_time:137990ms step_avg:88.85ms +step:1554/1670 train_time:138080ms step_avg:88.85ms +step:1555/1670 train_time:138170ms step_avg:88.86ms +step:1556/1670 train_time:138261ms step_avg:88.86ms +step:1557/1670 train_time:138351ms step_avg:88.86ms +step:1558/1670 train_time:138441ms step_avg:88.86ms +step:1559/1670 train_time:138530ms step_avg:88.86ms +step:1560/1670 train_time:138620ms step_avg:88.86ms +step:1561/1670 train_time:138710ms step_avg:88.86ms +step:1562/1670 train_time:138800ms step_avg:88.86ms +step:1563/1670 train_time:138889ms step_avg:88.86ms +step:1564/1670 train_time:138979ms step_avg:88.86ms +step:1565/1670 train_time:139069ms step_avg:88.86ms +step:1566/1670 train_time:139160ms step_avg:88.86ms +step:1567/1670 train_time:139251ms step_avg:88.86ms +step:1568/1670 train_time:139342ms step_avg:88.87ms +step:1569/1670 train_time:139431ms step_avg:88.87ms +step:1570/1670 train_time:139521ms 
step_avg:88.87ms +step:1571/1670 train_time:139611ms step_avg:88.87ms +step:1572/1670 train_time:139701ms step_avg:88.87ms +step:1573/1670 train_time:139791ms step_avg:88.87ms +step:1574/1670 train_time:139881ms step_avg:88.87ms +step:1575/1670 train_time:139971ms step_avg:88.87ms +step:1576/1670 train_time:140061ms step_avg:88.87ms +step:1577/1670 train_time:140150ms step_avg:88.87ms +step:1578/1670 train_time:140240ms step_avg:88.87ms +step:1579/1670 train_time:140330ms step_avg:88.87ms +step:1580/1670 train_time:140419ms step_avg:88.87ms +step:1581/1670 train_time:140510ms step_avg:88.87ms +step:1582/1670 train_time:140600ms step_avg:88.87ms +step:1583/1670 train_time:140689ms step_avg:88.88ms +step:1584/1670 train_time:140781ms step_avg:88.88ms +step:1585/1670 train_time:140871ms step_avg:88.88ms +step:1586/1670 train_time:140961ms step_avg:88.88ms +step:1587/1670 train_time:141049ms step_avg:88.88ms +step:1588/1670 train_time:141139ms step_avg:88.88ms +step:1589/1670 train_time:141229ms step_avg:88.88ms +step:1590/1670 train_time:141320ms step_avg:88.88ms +step:1591/1670 train_time:141411ms step_avg:88.88ms +step:1592/1670 train_time:141500ms step_avg:88.88ms +step:1593/1670 train_time:141590ms step_avg:88.88ms +step:1594/1670 train_time:141680ms step_avg:88.88ms +step:1595/1670 train_time:141771ms step_avg:88.88ms +step:1596/1670 train_time:141860ms step_avg:88.88ms +step:1597/1670 train_time:141950ms step_avg:88.89ms +step:1598/1670 train_time:142039ms step_avg:88.89ms +step:1599/1670 train_time:142129ms step_avg:88.89ms +step:1600/1670 train_time:142219ms step_avg:88.89ms +step:1601/1670 train_time:142311ms step_avg:88.89ms +step:1602/1670 train_time:142401ms step_avg:88.89ms +step:1603/1670 train_time:142490ms step_avg:88.89ms +step:1604/1670 train_time:142579ms step_avg:88.89ms +step:1605/1670 train_time:142669ms step_avg:88.89ms +step:1606/1670 train_time:142759ms step_avg:88.89ms +step:1607/1670 train_time:142848ms step_avg:88.89ms +step:1608/1670 train_time:142938ms step_avg:88.89ms +step:1609/1670 train_time:143028ms step_avg:88.89ms +step:1610/1670 train_time:143117ms step_avg:88.89ms +step:1611/1670 train_time:143208ms step_avg:88.89ms +step:1612/1670 train_time:143298ms step_avg:88.89ms +step:1613/1670 train_time:143389ms step_avg:88.90ms +step:1614/1670 train_time:143479ms step_avg:88.90ms +step:1615/1670 train_time:143570ms step_avg:88.90ms +step:1616/1670 train_time:143661ms step_avg:88.90ms +step:1617/1670 train_time:143751ms step_avg:88.90ms +step:1618/1670 train_time:143841ms step_avg:88.90ms +step:1619/1670 train_time:143931ms step_avg:88.90ms +step:1620/1670 train_time:144020ms step_avg:88.90ms +step:1621/1670 train_time:144110ms step_avg:88.90ms +step:1622/1670 train_time:144200ms step_avg:88.90ms +step:1623/1670 train_time:144290ms step_avg:88.90ms +step:1624/1670 train_time:144381ms step_avg:88.90ms +step:1625/1670 train_time:144471ms step_avg:88.91ms +step:1625/1670 val_loss:3.2887 train_time:144562ms step_avg:88.96ms +step:1626/1670 train_time:144583ms step_avg:88.92ms +step:1627/1670 train_time:144655ms step_avg:88.91ms +step:1628/1670 train_time:144748ms step_avg:88.91ms +step:1629/1670 train_time:144839ms step_avg:88.91ms +step:1630/1670 train_time:144929ms step_avg:88.91ms +step:1631/1670 train_time:145018ms step_avg:88.91ms +step:1632/1670 train_time:145106ms step_avg:88.91ms +step:1633/1670 train_time:145195ms step_avg:88.91ms +step:1634/1670 train_time:145284ms step_avg:88.91ms +step:1635/1670 train_time:145373ms step_avg:88.91ms +step:1636/1670 
train_time:145463ms step_avg:88.91ms +step:1637/1670 train_time:145554ms step_avg:88.91ms +step:1638/1670 train_time:145647ms step_avg:88.92ms +step:1639/1670 train_time:145738ms step_avg:88.92ms +step:1640/1670 train_time:145829ms step_avg:88.92ms +step:1641/1670 train_time:145918ms step_avg:88.92ms +step:1642/1670 train_time:146007ms step_avg:88.92ms +step:1643/1670 train_time:146097ms step_avg:88.92ms +step:1644/1670 train_time:146186ms step_avg:88.92ms +step:1645/1670 train_time:146275ms step_avg:88.92ms +step:1646/1670 train_time:146364ms step_avg:88.92ms +step:1647/1670 train_time:146453ms step_avg:88.92ms +step:1648/1670 train_time:146543ms step_avg:88.92ms +step:1649/1670 train_time:146634ms step_avg:88.92ms +step:1650/1670 train_time:146726ms step_avg:88.92ms +step:1651/1670 train_time:146817ms step_avg:88.93ms +step:1652/1670 train_time:146908ms step_avg:88.93ms +step:1653/1670 train_time:146998ms step_avg:88.93ms +step:1654/1670 train_time:147088ms step_avg:88.93ms +step:1655/1670 train_time:147176ms step_avg:88.93ms +step:1656/1670 train_time:147265ms step_avg:88.93ms +step:1657/1670 train_time:147354ms step_avg:88.93ms +step:1658/1670 train_time:147443ms step_avg:88.93ms +step:1659/1670 train_time:147533ms step_avg:88.93ms +step:1660/1670 train_time:147623ms step_avg:88.93ms +step:1661/1670 train_time:147715ms step_avg:88.93ms +step:1662/1670 train_time:147806ms step_avg:88.93ms +step:1663/1670 train_time:147898ms step_avg:88.93ms +step:1664/1670 train_time:147988ms step_avg:88.94ms +step:1665/1670 train_time:148078ms step_avg:88.94ms +step:1666/1670 train_time:148167ms step_avg:88.94ms +step:1667/1670 train_time:148256ms step_avg:88.94ms +step:1668/1670 train_time:148345ms step_avg:88.94ms +step:1669/1670 train_time:148435ms step_avg:88.94ms +step:1670/1670 train_time:148525ms step_avg:88.94ms +step:1670/1670 val_loss:3.2789 train_time:148617ms step_avg:88.99ms +peak memory allocated: 30760 MiB reserved: 45574 MiB diff --git a/records/092925_PolarExpress/16ae9716-24a6-4b5f-ad2e-ce0986903334.txt b/records/092925_PolarExpress/16ae9716-24a6-4b5f-ad2e-ce0986903334.txt new file mode 100644 index 000000000..bd8cec213 --- /dev/null +++ b/records/092925_PolarExpress/16ae9716-24a6-4b5f-ad2e-ce0986903334.txt @@ -0,0 +1,3252 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from kernels import get_kernel +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 
= x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[ + Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + + +mm_op.register_autograd(backward, setup_context=setup_context) + + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + + +@triton.autotune( + 
configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def XXT_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def XXT(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + XXT_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ba_plus_cAA_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: 
tl.constexpr, +): + # This is mostly duplicated from XXT_kernel, but also loads and adds a block of A + # Performance is slightly slower than XXT_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ba_plus_cAA(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ba_plus_cAA_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +# Computed for num_iters=5, safety_factor=2e-2, cushion=2 +coeffs_list = [ + (8.156554524902461, 
-22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323) +] + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def polar_express(G: torch.Tensor): + """ + Polar Express Sign Method: https://arxiv.org/pdf/2505.16932 + by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. + Code adapted from https://github.com/NoahAmsel/PolarExpress/tree/main by @varunneal. + """ + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) * (1 + 2e-2) + 1e-6) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + aX_plus_BX = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the iterations + for a, b, c in coeffs_list: + XXT(X, out=A) # A = X @ X.mT + ba_plus_cAA(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + aX_plus_BX(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + Though empirically small 1D params perform efficiently here: + NS approximately performs a magnitude normalization of the grad + This hyper-optimized class has faster execution time than the current impl of Adam for small params + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized to enable + (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param 7 padding params) + 2. reduce scatter attn_gate (10 params 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive params of type attn reshape before NS + 8. wait on 4, then compute NS of 4 and schedule all gather + 9. 
wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size() == 8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on less than 8 GPU or experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1, 10, 16, 16] + assert len(params) == sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx + size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. 
+ for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + original_shape = batched_update_grads.shape + # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS indepedently to Q,K,V,O + module_idx = start_idx if start_idx < len(params) else 0 + if getattr(params[module_idx], 'module', 'none') == 'attn': + for p in params[module_idx:module_idx + chunk_size]: + assert getattr(params[module_idx], 'module', 'none') == 'attn' + batch = 4 * original_shape[0] + d1 = original_shape[1] + d2 = original_shape[2] // 4 + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = polar_express(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = polar_express(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. 
+ for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, + weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append( + dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * 
(torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim // 4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim // 4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +flash_attn_interface = get_kernel('varunneal/flash-attention-3').flash_attn_interface + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: 
int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim * 4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module = 'attn' + with torch.no_grad(): + self.qkvo_w.view(4, self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4, self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4, self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, + 3 * self.num_heads, + self.head_dim).chunk( + 3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_interface.flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, + max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4, self.hdim, self.dim)[3].type_as(y)) + return y + + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module = 'mlp' + self.c_proj.module = 'mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu( + x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, + max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
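+        # a quick worked example of the padding above, assuming the vocab_size=50257 that the model is
+        # constructed with later in this file: next_multiple_of_n(50257, n=128) == 50304, i.e. 47 extra
+        # rows the tokenizer never emits. keeping the width a multiple of 128 gives well-aligned shapes
+        # for the fp8 torch._scaled_mm path in CastedLinear; the 448 appearing in the scales below is
+        # float8_e4m3fn's largest finite value, presumably why x_s and grad_s are expressed relative to it.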
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim ** 0.5) / 448, w_s=2 ** -9, + grad_s=1 / 448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), # extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, + long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1, len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i < 11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert 
header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + + +BOS_ID = 50256 + + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens = tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter == 5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter += 1 + return starts, ends + + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, + align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * 
grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1630 # number of iterations to run + iteration_extension = 40 # number of iterations to continue training at final cooldown and window size + cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 
0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) + + +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + + +# begin by printing this file (the Python code) +print0(code) +print0("=" * 100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout + + +print0(nvidia_smi()) +print0("=" * 100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if + p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.06, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999, step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + + +def get_ws(step: int): + if step == args.num_iterations + args.iteration_extension: + return args.ws_validate // 2, args.ws_validate + x = min(step / (1 + args.num_iterations), 0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx] + + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, + grad_accum_steps=grad_accum_steps) +ws_long = args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws_long = args.ws_schedule[ + step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params + if new_ws_long > ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + elif new_ws_long < ws_long: + model.yarn.reset() + ws_long = new_ws_long + model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.yarn.reset() +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, + grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations + args.iteration_extension +ws_short, ws_long = get_ws(0) +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws_short, new_ws_long = get_ws(step) + if new_ws_long != ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + if last_step: + ws_long = args.ws_long_validate + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // 
args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, + grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = torch.zeros((), device=device, dtype=torch.float32) + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0( + f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms", + console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), + optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0( + f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms", + console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, Feb 4 2025, 14:57:36) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.5.0 +Mon Sep 29 06:35:14 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.127.08 Driver Version: 550.127.08 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 40C P0 130W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 33C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 38C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 39C P0 129W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 33C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 39C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:142ms step_avg:141.53ms +step:2/1670 train_time:163ms step_avg:81.29ms +step:3/1670 train_time:226ms step_avg:75.35ms +step:4/1670 train_time:311ms step_avg:77.87ms +step:5/1670 train_time:398ms step_avg:79.55ms +step:6/1670 train_time:484ms step_avg:80.73ms +step:7/1670 train_time:571ms step_avg:81.52ms +step:8/1670 train_time:658ms step_avg:82.20ms +step:9/1670 train_time:744ms step_avg:82.70ms +step:10/1670 train_time:831ms step_avg:83.09ms +step:11/1670 train_time:918ms step_avg:83.41ms +step:12/1670 train_time:1008ms step_avg:83.98ms +step:13/1670 train_time:1101ms step_avg:84.71ms +step:14/1670 train_time:1191ms step_avg:85.09ms +step:15/1670 train_time:1280ms step_avg:85.36ms +step:16/1670 train_time:1368ms step_avg:85.53ms +step:17/1670 train_time:1456ms step_avg:85.64ms +step:18/1670 train_time:1543ms step_avg:85.73ms +step:19/1670 train_time:1630ms step_avg:85.81ms +step:20/1670 train_time:1717ms step_avg:85.87ms +step:21/1670 train_time:1805ms step_avg:85.95ms +step:22/1670 train_time:1892ms step_avg:86.00ms +step:23/1670 
train_time:1980ms step_avg:86.07ms +step:24/1670 train_time:2070ms step_avg:86.23ms +step:25/1670 train_time:2159ms step_avg:86.37ms +step:26/1670 train_time:2248ms step_avg:86.46ms +step:27/1670 train_time:2337ms step_avg:86.54ms +step:28/1670 train_time:2424ms step_avg:86.58ms +step:29/1670 train_time:2512ms step_avg:86.63ms +step:30/1670 train_time:2600ms step_avg:86.65ms +step:31/1670 train_time:2687ms step_avg:86.68ms +step:32/1670 train_time:2774ms step_avg:86.69ms +step:33/1670 train_time:2861ms step_avg:86.70ms +step:34/1670 train_time:2949ms step_avg:86.72ms +step:35/1670 train_time:3037ms step_avg:86.78ms +step:36/1670 train_time:3126ms step_avg:86.84ms +step:37/1670 train_time:3215ms step_avg:86.90ms +step:38/1670 train_time:3304ms step_avg:86.94ms +step:39/1670 train_time:3391ms step_avg:86.95ms +step:40/1670 train_time:3479ms step_avg:86.97ms +step:41/1670 train_time:3567ms step_avg:86.99ms +step:42/1670 train_time:3655ms step_avg:87.01ms +step:43/1670 train_time:3742ms step_avg:87.03ms +step:44/1670 train_time:3831ms step_avg:87.06ms +step:45/1670 train_time:3918ms step_avg:87.07ms +step:46/1670 train_time:4006ms step_avg:87.09ms +step:47/1670 train_time:4095ms step_avg:87.13ms +step:48/1670 train_time:4184ms step_avg:87.16ms +step:49/1670 train_time:4273ms step_avg:87.20ms +step:50/1670 train_time:4360ms step_avg:87.21ms +step:51/1670 train_time:4448ms step_avg:87.22ms +step:52/1670 train_time:4536ms step_avg:87.23ms +step:53/1670 train_time:4623ms step_avg:87.23ms +step:54/1670 train_time:4711ms step_avg:87.24ms +step:55/1670 train_time:4798ms step_avg:87.24ms +step:56/1670 train_time:4886ms step_avg:87.25ms +step:57/1670 train_time:4974ms step_avg:87.26ms +step:58/1670 train_time:5062ms step_avg:87.27ms +step:59/1670 train_time:5150ms step_avg:87.29ms +step:60/1670 train_time:5239ms step_avg:87.32ms +step:61/1670 train_time:5327ms step_avg:87.33ms +step:62/1670 train_time:5415ms step_avg:87.34ms +step:63/1670 train_time:5503ms step_avg:87.35ms +step:64/1670 train_time:5591ms step_avg:87.36ms +step:65/1670 train_time:5679ms step_avg:87.37ms +step:66/1670 train_time:5766ms step_avg:87.37ms +step:67/1670 train_time:5854ms step_avg:87.37ms +step:68/1670 train_time:5941ms step_avg:87.36ms +step:69/1670 train_time:6029ms step_avg:87.37ms +step:70/1670 train_time:6117ms step_avg:87.39ms +step:71/1670 train_time:6205ms step_avg:87.40ms +step:72/1670 train_time:6295ms step_avg:87.43ms +step:73/1670 train_time:6382ms step_avg:87.42ms +step:74/1670 train_time:6470ms step_avg:87.44ms +step:75/1670 train_time:6558ms step_avg:87.44ms +step:76/1670 train_time:6646ms step_avg:87.45ms +step:77/1670 train_time:6734ms step_avg:87.45ms +step:78/1670 train_time:6821ms step_avg:87.45ms +step:79/1670 train_time:6909ms step_avg:87.46ms +step:80/1670 train_time:6997ms step_avg:87.47ms +step:81/1670 train_time:7084ms step_avg:87.46ms +step:82/1670 train_time:7173ms step_avg:87.47ms +step:83/1670 train_time:7261ms step_avg:87.48ms +step:84/1670 train_time:7348ms step_avg:87.48ms +step:85/1670 train_time:7437ms step_avg:87.49ms +step:86/1670 train_time:7525ms step_avg:87.50ms +step:87/1670 train_time:7612ms step_avg:87.50ms +step:88/1670 train_time:7700ms step_avg:87.50ms +step:89/1670 train_time:7787ms step_avg:87.49ms +step:90/1670 train_time:7875ms step_avg:87.50ms +step:91/1670 train_time:7962ms step_avg:87.50ms +step:92/1670 train_time:8051ms step_avg:87.51ms +step:93/1670 train_time:8139ms step_avg:87.51ms +step:94/1670 train_time:8227ms step_avg:87.52ms +step:95/1670 train_time:8315ms 
step_avg:87.53ms +step:96/1670 train_time:8403ms step_avg:87.53ms +step:97/1670 train_time:8491ms step_avg:87.53ms +step:98/1670 train_time:8579ms step_avg:87.54ms +step:99/1670 train_time:8667ms step_avg:87.54ms +step:100/1670 train_time:8754ms step_avg:87.54ms +step:101/1670 train_time:8842ms step_avg:87.54ms +step:102/1670 train_time:8930ms step_avg:87.55ms +step:103/1670 train_time:9018ms step_avg:87.56ms +step:104/1670 train_time:9106ms step_avg:87.56ms +step:105/1670 train_time:9195ms step_avg:87.57ms +step:106/1670 train_time:9282ms step_avg:87.56ms +step:107/1670 train_time:9370ms step_avg:87.57ms +step:108/1670 train_time:9458ms step_avg:87.57ms +step:109/1670 train_time:9545ms step_avg:87.57ms +step:110/1670 train_time:9633ms step_avg:87.57ms +step:111/1670 train_time:9720ms step_avg:87.57ms +step:112/1670 train_time:9808ms step_avg:87.57ms +step:113/1670 train_time:9896ms step_avg:87.57ms +step:114/1670 train_time:9983ms step_avg:87.57ms +step:115/1670 train_time:10071ms step_avg:87.57ms +step:116/1670 train_time:10159ms step_avg:87.58ms +step:117/1670 train_time:10246ms step_avg:87.58ms +step:118/1670 train_time:10335ms step_avg:87.59ms +step:119/1670 train_time:10423ms step_avg:87.58ms +step:120/1670 train_time:10510ms step_avg:87.59ms +step:121/1670 train_time:10598ms step_avg:87.59ms +step:122/1670 train_time:10686ms step_avg:87.59ms +step:123/1670 train_time:10773ms step_avg:87.58ms +step:124/1670 train_time:10860ms step_avg:87.58ms +step:125/1670 train_time:10947ms step_avg:87.58ms +step:125/1670 val_loss:4.3607 train_time:11037ms step_avg:88.30ms +step:126/1670 train_time:11057ms step_avg:87.75ms +step:127/1670 train_time:11128ms step_avg:87.62ms +step:128/1670 train_time:11224ms step_avg:87.69ms +step:129/1670 train_time:11315ms step_avg:87.71ms +step:130/1670 train_time:11402ms step_avg:87.71ms +step:131/1670 train_time:11489ms step_avg:87.70ms +step:132/1670 train_time:11576ms step_avg:87.70ms +step:133/1670 train_time:11662ms step_avg:87.69ms +step:134/1670 train_time:11749ms step_avg:87.68ms +step:135/1670 train_time:11835ms step_avg:87.67ms +step:136/1670 train_time:11922ms step_avg:87.66ms +step:137/1670 train_time:12009ms step_avg:87.66ms +step:138/1670 train_time:12099ms step_avg:87.67ms +step:139/1670 train_time:12189ms step_avg:87.69ms +step:140/1670 train_time:12279ms step_avg:87.71ms +step:141/1670 train_time:12367ms step_avg:87.71ms +step:142/1670 train_time:12454ms step_avg:87.71ms +step:143/1670 train_time:12541ms step_avg:87.70ms +step:144/1670 train_time:12629ms step_avg:87.70ms +step:145/1670 train_time:12716ms step_avg:87.70ms +step:146/1670 train_time:12802ms step_avg:87.69ms +step:147/1670 train_time:12890ms step_avg:87.68ms +step:148/1670 train_time:12977ms step_avg:87.68ms +step:149/1670 train_time:13065ms step_avg:87.69ms +step:150/1670 train_time:13156ms step_avg:87.70ms +step:151/1670 train_time:13245ms step_avg:87.72ms +step:152/1670 train_time:13333ms step_avg:87.72ms +step:153/1670 train_time:13420ms step_avg:87.72ms +step:154/1670 train_time:13509ms step_avg:87.72ms +step:155/1670 train_time:13596ms step_avg:87.71ms +step:156/1670 train_time:13683ms step_avg:87.71ms +step:157/1670 train_time:13770ms step_avg:87.71ms +step:158/1670 train_time:13857ms step_avg:87.70ms +step:159/1670 train_time:13944ms step_avg:87.70ms +step:160/1670 train_time:14032ms step_avg:87.70ms +step:161/1670 train_time:14120ms step_avg:87.70ms +step:162/1670 train_time:14209ms step_avg:87.71ms +step:163/1670 train_time:14297ms step_avg:87.71ms +step:164/1670 
train_time:14385ms step_avg:87.71ms +step:165/1670 train_time:14473ms step_avg:87.72ms +step:166/1670 train_time:14560ms step_avg:87.71ms +step:167/1670 train_time:14647ms step_avg:87.71ms +step:168/1670 train_time:14735ms step_avg:87.71ms +step:169/1670 train_time:14823ms step_avg:87.71ms +step:170/1670 train_time:14911ms step_avg:87.71ms +step:171/1670 train_time:14998ms step_avg:87.71ms +step:172/1670 train_time:15086ms step_avg:87.71ms +step:173/1670 train_time:15174ms step_avg:87.71ms +step:174/1670 train_time:15262ms step_avg:87.71ms +step:175/1670 train_time:15351ms step_avg:87.72ms +step:176/1670 train_time:15438ms step_avg:87.72ms +step:177/1670 train_time:15527ms step_avg:87.72ms +step:178/1670 train_time:15614ms step_avg:87.72ms +step:179/1670 train_time:15702ms step_avg:87.72ms +step:180/1670 train_time:15789ms step_avg:87.72ms +step:181/1670 train_time:15876ms step_avg:87.72ms +step:182/1670 train_time:15964ms step_avg:87.71ms +step:183/1670 train_time:16051ms step_avg:87.71ms +step:184/1670 train_time:16138ms step_avg:87.71ms +step:185/1670 train_time:16226ms step_avg:87.71ms +step:186/1670 train_time:16314ms step_avg:87.71ms +step:187/1670 train_time:16401ms step_avg:87.71ms +step:188/1670 train_time:16489ms step_avg:87.71ms +step:189/1670 train_time:16577ms step_avg:87.71ms +step:190/1670 train_time:16665ms step_avg:87.71ms +step:191/1670 train_time:16752ms step_avg:87.71ms +step:192/1670 train_time:16839ms step_avg:87.70ms +step:193/1670 train_time:16927ms step_avg:87.70ms +step:194/1670 train_time:17014ms step_avg:87.70ms +step:195/1670 train_time:17102ms step_avg:87.70ms +step:196/1670 train_time:17191ms step_avg:87.71ms +step:197/1670 train_time:17278ms step_avg:87.71ms +step:198/1670 train_time:17366ms step_avg:87.71ms +step:199/1670 train_time:17454ms step_avg:87.71ms +step:200/1670 train_time:17541ms step_avg:87.71ms +step:201/1670 train_time:17628ms step_avg:87.70ms +step:202/1670 train_time:17716ms step_avg:87.70ms +step:203/1670 train_time:17803ms step_avg:87.70ms +step:204/1670 train_time:17891ms step_avg:87.70ms +step:205/1670 train_time:17977ms step_avg:87.70ms +step:206/1670 train_time:18065ms step_avg:87.70ms +step:207/1670 train_time:18154ms step_avg:87.70ms +step:208/1670 train_time:18242ms step_avg:87.70ms +step:209/1670 train_time:18330ms step_avg:87.70ms +step:210/1670 train_time:18417ms step_avg:87.70ms +step:211/1670 train_time:18505ms step_avg:87.70ms +step:212/1670 train_time:18592ms step_avg:87.70ms +step:213/1670 train_time:18680ms step_avg:87.70ms +step:214/1670 train_time:18768ms step_avg:87.70ms +step:215/1670 train_time:18856ms step_avg:87.70ms +step:216/1670 train_time:18943ms step_avg:87.70ms +step:217/1670 train_time:19030ms step_avg:87.69ms +step:218/1670 train_time:19117ms step_avg:87.69ms +step:219/1670 train_time:19205ms step_avg:87.69ms +step:220/1670 train_time:19292ms step_avg:87.69ms +step:221/1670 train_time:19380ms step_avg:87.69ms +step:222/1670 train_time:19469ms step_avg:87.70ms +step:223/1670 train_time:19556ms step_avg:87.70ms +step:224/1670 train_time:19643ms step_avg:87.69ms +step:225/1670 train_time:19731ms step_avg:87.69ms +step:226/1670 train_time:19818ms step_avg:87.69ms +step:227/1670 train_time:19905ms step_avg:87.69ms +step:228/1670 train_time:19993ms step_avg:87.69ms +step:229/1670 train_time:20080ms step_avg:87.68ms +step:230/1670 train_time:20167ms step_avg:87.68ms +step:231/1670 train_time:20255ms step_avg:87.69ms +step:232/1670 train_time:20343ms step_avg:87.69ms +step:233/1670 train_time:20431ms step_avg:87.69ms 
+step:234/1670 train_time:20518ms step_avg:87.68ms +step:235/1670 train_time:20605ms step_avg:87.68ms +step:236/1670 train_time:20693ms step_avg:87.68ms +step:237/1670 train_time:20780ms step_avg:87.68ms +step:238/1670 train_time:20868ms step_avg:87.68ms +step:239/1670 train_time:20956ms step_avg:87.68ms +step:240/1670 train_time:21043ms step_avg:87.68ms +step:241/1670 train_time:21130ms step_avg:87.68ms +step:242/1670 train_time:21217ms step_avg:87.67ms +step:243/1670 train_time:21305ms step_avg:87.67ms +step:244/1670 train_time:21393ms step_avg:87.67ms +step:245/1670 train_time:21480ms step_avg:87.67ms +step:246/1670 train_time:21567ms step_avg:87.67ms +step:247/1670 train_time:21656ms step_avg:87.67ms +step:248/1670 train_time:21744ms step_avg:87.68ms +step:249/1670 train_time:21833ms step_avg:87.68ms +step:250/1670 train_time:21919ms step_avg:87.68ms +step:250/1670 val_loss:3.9859 train_time:22009ms step_avg:88.04ms +step:251/1670 train_time:22028ms step_avg:87.76ms +step:252/1670 train_time:22098ms step_avg:87.69ms +step:253/1670 train_time:22192ms step_avg:87.72ms +step:254/1670 train_time:22281ms step_avg:87.72ms +step:255/1670 train_time:22368ms step_avg:87.72ms +step:256/1670 train_time:22455ms step_avg:87.72ms +step:257/1670 train_time:22541ms step_avg:87.71ms +step:258/1670 train_time:22629ms step_avg:87.71ms +step:259/1670 train_time:22715ms step_avg:87.70ms +step:260/1670 train_time:22801ms step_avg:87.70ms +step:261/1670 train_time:22887ms step_avg:87.69ms +step:262/1670 train_time:22976ms step_avg:87.69ms +step:263/1670 train_time:23066ms step_avg:87.70ms +step:264/1670 train_time:23157ms step_avg:87.72ms +step:265/1670 train_time:23245ms step_avg:87.72ms +step:266/1670 train_time:23333ms step_avg:87.72ms +step:267/1670 train_time:23420ms step_avg:87.72ms +step:268/1670 train_time:23507ms step_avg:87.71ms +step:269/1670 train_time:23594ms step_avg:87.71ms +step:270/1670 train_time:23680ms step_avg:87.71ms +step:271/1670 train_time:23768ms step_avg:87.70ms +step:272/1670 train_time:23854ms step_avg:87.70ms +step:273/1670 train_time:23942ms step_avg:87.70ms +step:274/1670 train_time:24031ms step_avg:87.70ms +step:275/1670 train_time:24120ms step_avg:87.71ms +step:276/1670 train_time:24210ms step_avg:87.72ms +step:277/1670 train_time:24298ms step_avg:87.72ms +step:278/1670 train_time:24385ms step_avg:87.72ms +step:279/1670 train_time:24472ms step_avg:87.71ms +step:280/1670 train_time:24560ms step_avg:87.71ms +step:281/1670 train_time:24647ms step_avg:87.71ms +step:282/1670 train_time:24733ms step_avg:87.71ms +step:283/1670 train_time:24820ms step_avg:87.70ms +step:284/1670 train_time:24908ms step_avg:87.70ms +step:285/1670 train_time:24995ms step_avg:87.70ms +step:286/1670 train_time:25082ms step_avg:87.70ms +step:287/1670 train_time:25170ms step_avg:87.70ms +step:288/1670 train_time:25259ms step_avg:87.70ms +step:289/1670 train_time:25347ms step_avg:87.71ms +step:290/1670 train_time:25435ms step_avg:87.71ms +step:291/1670 train_time:25522ms step_avg:87.70ms +step:292/1670 train_time:25610ms step_avg:87.70ms +step:293/1670 train_time:25696ms step_avg:87.70ms +step:294/1670 train_time:25784ms step_avg:87.70ms +step:295/1670 train_time:25871ms step_avg:87.70ms +step:296/1670 train_time:25958ms step_avg:87.70ms +step:297/1670 train_time:26046ms step_avg:87.70ms +step:298/1670 train_time:26134ms step_avg:87.70ms +step:299/1670 train_time:26222ms step_avg:87.70ms +step:300/1670 train_time:26311ms step_avg:87.70ms +step:301/1670 train_time:26399ms step_avg:87.71ms +step:302/1670 
train_time:26488ms step_avg:87.71ms +step:303/1670 train_time:26575ms step_avg:87.71ms +step:304/1670 train_time:26663ms step_avg:87.71ms +step:305/1670 train_time:26750ms step_avg:87.70ms +step:306/1670 train_time:26837ms step_avg:87.70ms +step:307/1670 train_time:26923ms step_avg:87.70ms +step:308/1670 train_time:27012ms step_avg:87.70ms +step:309/1670 train_time:27100ms step_avg:87.70ms +step:310/1670 train_time:27189ms step_avg:87.71ms +step:311/1670 train_time:27276ms step_avg:87.71ms +step:312/1670 train_time:27364ms step_avg:87.70ms +step:313/1670 train_time:27451ms step_avg:87.70ms +step:314/1670 train_time:27539ms step_avg:87.70ms +step:315/1670 train_time:27626ms step_avg:87.70ms +step:316/1670 train_time:27713ms step_avg:87.70ms +step:317/1670 train_time:27800ms step_avg:87.70ms +step:318/1670 train_time:27887ms step_avg:87.69ms +step:319/1670 train_time:27974ms step_avg:87.69ms +step:320/1670 train_time:28061ms step_avg:87.69ms +step:321/1670 train_time:28151ms step_avg:87.70ms +step:322/1670 train_time:28238ms step_avg:87.70ms +step:323/1670 train_time:28327ms step_avg:87.70ms +step:324/1670 train_time:28414ms step_avg:87.70ms +step:325/1670 train_time:28501ms step_avg:87.69ms +step:326/1670 train_time:28590ms step_avg:87.70ms +step:327/1670 train_time:28677ms step_avg:87.70ms +step:328/1670 train_time:28764ms step_avg:87.69ms +step:329/1670 train_time:28851ms step_avg:87.69ms +step:330/1670 train_time:28938ms step_avg:87.69ms +step:331/1670 train_time:29026ms step_avg:87.69ms +step:332/1670 train_time:29113ms step_avg:87.69ms +step:333/1670 train_time:29202ms step_avg:87.69ms +step:334/1670 train_time:29290ms step_avg:87.70ms +step:335/1670 train_time:29378ms step_avg:87.70ms +step:336/1670 train_time:29465ms step_avg:87.69ms +step:337/1670 train_time:29552ms step_avg:87.69ms +step:338/1670 train_time:29639ms step_avg:87.69ms +step:339/1670 train_time:29727ms step_avg:87.69ms +step:340/1670 train_time:29814ms step_avg:87.69ms +step:341/1670 train_time:29901ms step_avg:87.69ms +step:342/1670 train_time:29989ms step_avg:87.69ms +step:343/1670 train_time:30077ms step_avg:87.69ms +step:344/1670 train_time:30165ms step_avg:87.69ms +step:345/1670 train_time:30253ms step_avg:87.69ms +step:346/1670 train_time:30340ms step_avg:87.69ms +step:347/1670 train_time:30428ms step_avg:87.69ms +step:348/1670 train_time:30515ms step_avg:87.69ms +step:349/1670 train_time:30602ms step_avg:87.68ms +step:350/1670 train_time:30690ms step_avg:87.68ms +step:351/1670 train_time:30777ms step_avg:87.68ms +step:352/1670 train_time:30864ms step_avg:87.68ms +step:353/1670 train_time:30951ms step_avg:87.68ms +step:354/1670 train_time:31038ms step_avg:87.68ms +step:355/1670 train_time:31126ms step_avg:87.68ms +step:356/1670 train_time:31214ms step_avg:87.68ms +step:357/1670 train_time:31302ms step_avg:87.68ms +step:358/1670 train_time:31391ms step_avg:87.68ms +step:359/1670 train_time:31479ms step_avg:87.68ms +step:360/1670 train_time:31566ms step_avg:87.68ms +step:361/1670 train_time:31653ms step_avg:87.68ms +step:362/1670 train_time:31740ms step_avg:87.68ms +step:363/1670 train_time:31828ms step_avg:87.68ms +step:364/1670 train_time:31915ms step_avg:87.68ms +step:365/1670 train_time:32002ms step_avg:87.68ms +step:366/1670 train_time:32090ms step_avg:87.68ms +step:367/1670 train_time:32177ms step_avg:87.68ms +step:368/1670 train_time:32265ms step_avg:87.68ms +step:369/1670 train_time:32353ms step_avg:87.68ms +step:370/1670 train_time:32440ms step_avg:87.68ms +step:371/1670 train_time:32528ms step_avg:87.68ms 
+step:372/1670 train_time:32615ms step_avg:87.67ms +step:373/1670 train_time:32702ms step_avg:87.67ms +step:374/1670 train_time:32791ms step_avg:87.68ms +step:375/1670 train_time:32877ms step_avg:87.67ms +step:375/1670 val_loss:3.8308 train_time:32966ms step_avg:87.91ms +step:376/1670 train_time:32986ms step_avg:87.73ms +step:377/1670 train_time:33057ms step_avg:87.68ms +step:378/1670 train_time:33147ms step_avg:87.69ms +step:379/1670 train_time:33236ms step_avg:87.69ms +step:380/1670 train_time:33323ms step_avg:87.69ms +step:381/1670 train_time:33410ms step_avg:87.69ms +step:382/1670 train_time:33496ms step_avg:87.69ms +step:383/1670 train_time:33582ms step_avg:87.68ms +step:384/1670 train_time:33669ms step_avg:87.68ms +step:385/1670 train_time:33757ms step_avg:87.68ms +step:386/1670 train_time:33844ms step_avg:87.68ms +step:387/1670 train_time:33933ms step_avg:87.68ms +step:388/1670 train_time:34023ms step_avg:87.69ms +step:389/1670 train_time:34112ms step_avg:87.69ms +step:390/1670 train_time:34200ms step_avg:87.69ms +step:391/1670 train_time:34287ms step_avg:87.69ms +step:392/1670 train_time:34375ms step_avg:87.69ms +step:393/1670 train_time:34462ms step_avg:87.69ms +step:394/1670 train_time:34548ms step_avg:87.69ms +step:395/1670 train_time:34636ms step_avg:87.69ms +step:396/1670 train_time:34722ms step_avg:87.68ms +step:397/1670 train_time:34809ms step_avg:87.68ms +step:398/1670 train_time:34896ms step_avg:87.68ms +step:399/1670 train_time:34985ms step_avg:87.68ms +step:400/1670 train_time:35075ms step_avg:87.69ms +step:401/1670 train_time:35162ms step_avg:87.69ms +step:402/1670 train_time:35250ms step_avg:87.69ms +step:403/1670 train_time:35338ms step_avg:87.69ms +step:404/1670 train_time:35425ms step_avg:87.68ms +step:405/1670 train_time:35512ms step_avg:87.68ms +step:406/1670 train_time:35598ms step_avg:87.68ms +step:407/1670 train_time:35685ms step_avg:87.68ms +step:408/1670 train_time:35772ms step_avg:87.68ms +step:409/1670 train_time:35859ms step_avg:87.68ms +step:410/1670 train_time:35947ms step_avg:87.68ms +step:411/1670 train_time:36036ms step_avg:87.68ms +step:412/1670 train_time:36125ms step_avg:87.68ms +step:413/1670 train_time:36212ms step_avg:87.68ms +step:414/1670 train_time:36300ms step_avg:87.68ms +step:415/1670 train_time:36388ms step_avg:87.68ms +step:416/1670 train_time:36476ms step_avg:87.68ms +step:417/1670 train_time:36563ms step_avg:87.68ms +step:418/1670 train_time:36650ms step_avg:87.68ms +step:419/1670 train_time:36737ms step_avg:87.68ms +step:420/1670 train_time:36826ms step_avg:87.68ms +step:421/1670 train_time:36914ms step_avg:87.68ms +step:422/1670 train_time:37001ms step_avg:87.68ms +step:423/1670 train_time:37090ms step_avg:87.68ms +step:424/1670 train_time:37178ms step_avg:87.68ms +step:425/1670 train_time:37266ms step_avg:87.68ms +step:426/1670 train_time:37353ms step_avg:87.68ms +step:427/1670 train_time:37440ms step_avg:87.68ms +step:428/1670 train_time:37528ms step_avg:87.68ms +step:429/1670 train_time:37615ms step_avg:87.68ms +step:430/1670 train_time:37702ms step_avg:87.68ms +step:431/1670 train_time:37790ms step_avg:87.68ms +step:432/1670 train_time:37878ms step_avg:87.68ms +step:433/1670 train_time:37965ms step_avg:87.68ms +step:434/1670 train_time:38053ms step_avg:87.68ms +step:435/1670 train_time:38140ms step_avg:87.68ms +step:436/1670 train_time:38228ms step_avg:87.68ms +step:437/1670 train_time:38316ms step_avg:87.68ms +step:438/1670 train_time:38404ms step_avg:87.68ms +step:439/1670 train_time:38492ms step_avg:87.68ms +step:440/1670 
train_time:38579ms step_avg:87.68ms +step:441/1670 train_time:38667ms step_avg:87.68ms +step:442/1670 train_time:38756ms step_avg:87.68ms +step:443/1670 train_time:38843ms step_avg:87.68ms +step:444/1670 train_time:38931ms step_avg:87.68ms +step:445/1670 train_time:39018ms step_avg:87.68ms +step:446/1670 train_time:39106ms step_avg:87.68ms +step:447/1670 train_time:39193ms step_avg:87.68ms +step:448/1670 train_time:39281ms step_avg:87.68ms +step:449/1670 train_time:39369ms step_avg:87.68ms +step:450/1670 train_time:39458ms step_avg:87.68ms +step:451/1670 train_time:39545ms step_avg:87.68ms +step:452/1670 train_time:39633ms step_avg:87.68ms +step:453/1670 train_time:39720ms step_avg:87.68ms +step:454/1670 train_time:39807ms step_avg:87.68ms +step:455/1670 train_time:39895ms step_avg:87.68ms +step:456/1670 train_time:39982ms step_avg:87.68ms +step:457/1670 train_time:40071ms step_avg:87.68ms +step:458/1670 train_time:40158ms step_avg:87.68ms +step:459/1670 train_time:40247ms step_avg:87.68ms +step:460/1670 train_time:40335ms step_avg:87.68ms +step:461/1670 train_time:40422ms step_avg:87.68ms +step:462/1670 train_time:40510ms step_avg:87.68ms +step:463/1670 train_time:40597ms step_avg:87.68ms +step:464/1670 train_time:40685ms step_avg:87.68ms +step:465/1670 train_time:40772ms step_avg:87.68ms +step:466/1670 train_time:40860ms step_avg:87.68ms +step:467/1670 train_time:40947ms step_avg:87.68ms +step:468/1670 train_time:41035ms step_avg:87.68ms +step:469/1670 train_time:41122ms step_avg:87.68ms +step:470/1670 train_time:41210ms step_avg:87.68ms +step:471/1670 train_time:41298ms step_avg:87.68ms +step:472/1670 train_time:41385ms step_avg:87.68ms +step:473/1670 train_time:41474ms step_avg:87.68ms +step:474/1670 train_time:41561ms step_avg:87.68ms +step:475/1670 train_time:41649ms step_avg:87.68ms +step:476/1670 train_time:41738ms step_avg:87.68ms +step:477/1670 train_time:41825ms step_avg:87.68ms +step:478/1670 train_time:41912ms step_avg:87.68ms +step:479/1670 train_time:42000ms step_avg:87.68ms +step:480/1670 train_time:42087ms step_avg:87.68ms +step:481/1670 train_time:42175ms step_avg:87.68ms +step:482/1670 train_time:42263ms step_avg:87.68ms +step:483/1670 train_time:42351ms step_avg:87.68ms +step:484/1670 train_time:42438ms step_avg:87.68ms +step:485/1670 train_time:42526ms step_avg:87.68ms +step:486/1670 train_time:42614ms step_avg:87.68ms +step:487/1670 train_time:42701ms step_avg:87.68ms +step:488/1670 train_time:42789ms step_avg:87.68ms +step:489/1670 train_time:42877ms step_avg:87.68ms +step:490/1670 train_time:42964ms step_avg:87.68ms +step:491/1670 train_time:43051ms step_avg:87.68ms +step:492/1670 train_time:43139ms step_avg:87.68ms +step:493/1670 train_time:43227ms step_avg:87.68ms +step:494/1670 train_time:43315ms step_avg:87.68ms +step:495/1670 train_time:43403ms step_avg:87.68ms +step:496/1670 train_time:43490ms step_avg:87.68ms +step:497/1670 train_time:43578ms step_avg:87.68ms +step:498/1670 train_time:43665ms step_avg:87.68ms +step:499/1670 train_time:43753ms step_avg:87.68ms +step:500/1670 train_time:43841ms step_avg:87.68ms +step:500/1670 val_loss:3.7255 train_time:43930ms step_avg:87.86ms +step:501/1670 train_time:43949ms step_avg:87.72ms +step:502/1670 train_time:44021ms step_avg:87.69ms +step:503/1670 train_time:44112ms step_avg:87.70ms +step:504/1670 train_time:44200ms step_avg:87.70ms +step:505/1670 train_time:44287ms step_avg:87.70ms +step:506/1670 train_time:44374ms step_avg:87.70ms +step:507/1670 train_time:44461ms step_avg:87.69ms +step:508/1670 train_time:44548ms 
step_avg:87.69ms +step:509/1670 train_time:44634ms step_avg:87.69ms +step:510/1670 train_time:44722ms step_avg:87.69ms +step:511/1670 train_time:44808ms step_avg:87.69ms +step:512/1670 train_time:44896ms step_avg:87.69ms +step:513/1670 train_time:44987ms step_avg:87.69ms +step:514/1670 train_time:45077ms step_avg:87.70ms +step:515/1670 train_time:45166ms step_avg:87.70ms +step:516/1670 train_time:45254ms step_avg:87.70ms +step:517/1670 train_time:45341ms step_avg:87.70ms +step:518/1670 train_time:45428ms step_avg:87.70ms +step:519/1670 train_time:45515ms step_avg:87.70ms +step:520/1670 train_time:45603ms step_avg:87.70ms +step:521/1670 train_time:45691ms step_avg:87.70ms +step:522/1670 train_time:45778ms step_avg:87.70ms +step:523/1670 train_time:45866ms step_avg:87.70ms +step:524/1670 train_time:45954ms step_avg:87.70ms +step:525/1670 train_time:46044ms step_avg:87.70ms +step:526/1670 train_time:46132ms step_avg:87.70ms +step:527/1670 train_time:46220ms step_avg:87.70ms +step:528/1670 train_time:46308ms step_avg:87.70ms +step:529/1670 train_time:46395ms step_avg:87.70ms +step:530/1670 train_time:46483ms step_avg:87.70ms +step:531/1670 train_time:46569ms step_avg:87.70ms +step:532/1670 train_time:46657ms step_avg:87.70ms +step:533/1670 train_time:46744ms step_avg:87.70ms +step:534/1670 train_time:46831ms step_avg:87.70ms +step:535/1670 train_time:46920ms step_avg:87.70ms +step:536/1670 train_time:47008ms step_avg:87.70ms +step:537/1670 train_time:47096ms step_avg:87.70ms +step:538/1670 train_time:47184ms step_avg:87.70ms +step:539/1670 train_time:47272ms step_avg:87.70ms +step:540/1670 train_time:47359ms step_avg:87.70ms +step:541/1670 train_time:47447ms step_avg:87.70ms +step:542/1670 train_time:47533ms step_avg:87.70ms +step:543/1670 train_time:47620ms step_avg:87.70ms +step:544/1670 train_time:47707ms step_avg:87.70ms +step:545/1670 train_time:47795ms step_avg:87.70ms +step:546/1670 train_time:47885ms step_avg:87.70ms +step:547/1670 train_time:47974ms step_avg:87.70ms +step:548/1670 train_time:48063ms step_avg:87.71ms +step:549/1670 train_time:48151ms step_avg:87.71ms +step:550/1670 train_time:48240ms step_avg:87.71ms +step:551/1670 train_time:48329ms step_avg:87.71ms +step:552/1670 train_time:48418ms step_avg:87.71ms +step:553/1670 train_time:48506ms step_avg:87.71ms +step:554/1670 train_time:48595ms step_avg:87.72ms +step:555/1670 train_time:48684ms step_avg:87.72ms +step:556/1670 train_time:48772ms step_avg:87.72ms +step:557/1670 train_time:48861ms step_avg:87.72ms +step:558/1670 train_time:48950ms step_avg:87.72ms +step:559/1670 train_time:49038ms step_avg:87.72ms +step:560/1670 train_time:49128ms step_avg:87.73ms +step:561/1670 train_time:49217ms step_avg:87.73ms +step:562/1670 train_time:49307ms step_avg:87.73ms +step:563/1670 train_time:49396ms step_avg:87.74ms +step:564/1670 train_time:49486ms step_avg:87.74ms +step:565/1670 train_time:49575ms step_avg:87.74ms +step:566/1670 train_time:49664ms step_avg:87.75ms +step:567/1670 train_time:49752ms step_avg:87.75ms +step:568/1670 train_time:49843ms step_avg:87.75ms +step:569/1670 train_time:49930ms step_avg:87.75ms +step:570/1670 train_time:50019ms step_avg:87.75ms +step:571/1670 train_time:50109ms step_avg:87.76ms +step:572/1670 train_time:50199ms step_avg:87.76ms +step:573/1670 train_time:50287ms step_avg:87.76ms +step:574/1670 train_time:50376ms step_avg:87.76ms +step:575/1670 train_time:50465ms step_avg:87.76ms +step:576/1670 train_time:50554ms step_avg:87.77ms +step:577/1670 train_time:50643ms step_avg:87.77ms +step:578/1670 
train_time:50731ms step_avg:87.77ms +step:579/1670 train_time:50820ms step_avg:87.77ms +step:580/1670 train_time:50908ms step_avg:87.77ms +step:581/1670 train_time:50998ms step_avg:87.78ms +step:582/1670 train_time:51087ms step_avg:87.78ms +step:583/1670 train_time:51177ms step_avg:87.78ms +step:584/1670 train_time:51265ms step_avg:87.78ms +step:585/1670 train_time:51354ms step_avg:87.78ms +step:586/1670 train_time:51443ms step_avg:87.79ms +step:587/1670 train_time:51531ms step_avg:87.79ms +step:588/1670 train_time:51620ms step_avg:87.79ms +step:589/1670 train_time:51709ms step_avg:87.79ms +step:590/1670 train_time:51799ms step_avg:87.80ms +step:591/1670 train_time:51888ms step_avg:87.80ms +step:592/1670 train_time:51977ms step_avg:87.80ms +step:593/1670 train_time:52067ms step_avg:87.80ms +step:594/1670 train_time:52156ms step_avg:87.80ms +step:595/1670 train_time:52245ms step_avg:87.81ms +step:596/1670 train_time:52333ms step_avg:87.81ms +step:597/1670 train_time:52423ms step_avg:87.81ms +step:598/1670 train_time:52512ms step_avg:87.81ms +step:599/1670 train_time:52601ms step_avg:87.81ms +step:600/1670 train_time:52690ms step_avg:87.82ms +step:601/1670 train_time:52779ms step_avg:87.82ms +step:602/1670 train_time:52867ms step_avg:87.82ms +step:603/1670 train_time:52956ms step_avg:87.82ms +step:604/1670 train_time:53045ms step_avg:87.82ms +step:605/1670 train_time:53135ms step_avg:87.83ms +step:606/1670 train_time:53224ms step_avg:87.83ms +step:607/1670 train_time:53313ms step_avg:87.83ms +step:608/1670 train_time:53402ms step_avg:87.83ms +step:609/1670 train_time:53491ms step_avg:87.83ms +step:610/1670 train_time:53581ms step_avg:87.84ms +step:611/1670 train_time:53670ms step_avg:87.84ms +step:612/1670 train_time:53759ms step_avg:87.84ms +step:613/1670 train_time:53847ms step_avg:87.84ms +step:614/1670 train_time:53936ms step_avg:87.84ms +step:615/1670 train_time:54025ms step_avg:87.85ms +step:616/1670 train_time:54115ms step_avg:87.85ms +step:617/1670 train_time:54205ms step_avg:87.85ms +step:618/1670 train_time:54294ms step_avg:87.85ms +step:619/1670 train_time:54383ms step_avg:87.86ms +step:620/1670 train_time:54472ms step_avg:87.86ms +step:621/1670 train_time:54561ms step_avg:87.86ms +step:622/1670 train_time:54649ms step_avg:87.86ms +step:623/1670 train_time:54738ms step_avg:87.86ms +step:624/1670 train_time:54827ms step_avg:87.86ms +step:625/1670 train_time:54916ms step_avg:87.87ms +step:625/1670 val_loss:3.6211 train_time:55006ms step_avg:88.01ms +step:626/1670 train_time:55026ms step_avg:87.90ms +step:627/1670 train_time:55097ms step_avg:87.87ms +step:628/1670 train_time:55186ms step_avg:87.88ms +step:629/1670 train_time:55275ms step_avg:87.88ms +step:630/1670 train_time:55363ms step_avg:87.88ms +step:631/1670 train_time:55450ms step_avg:87.88ms +step:632/1670 train_time:55538ms step_avg:87.88ms +step:633/1670 train_time:55626ms step_avg:87.88ms +step:634/1670 train_time:55714ms step_avg:87.88ms +step:635/1670 train_time:55803ms step_avg:87.88ms +step:636/1670 train_time:55892ms step_avg:87.88ms +step:637/1670 train_time:55985ms step_avg:87.89ms +step:638/1670 train_time:56074ms step_avg:87.89ms +step:639/1670 train_time:56164ms step_avg:87.89ms +step:640/1670 train_time:56252ms step_avg:87.89ms +step:641/1670 train_time:56341ms step_avg:87.90ms +step:642/1670 train_time:56428ms step_avg:87.89ms +step:643/1670 train_time:56516ms step_avg:87.89ms +step:644/1670 train_time:56605ms step_avg:87.90ms +step:645/1670 train_time:56692ms step_avg:87.89ms +step:646/1670 train_time:56781ms 
step_avg:87.90ms +step:647/1670 train_time:56871ms step_avg:87.90ms +step:648/1670 train_time:56961ms step_avg:87.90ms +step:649/1670 train_time:57051ms step_avg:87.91ms +step:650/1670 train_time:57141ms step_avg:87.91ms +step:651/1670 train_time:57230ms step_avg:87.91ms +step:652/1670 train_time:57318ms step_avg:87.91ms +step:653/1670 train_time:57407ms step_avg:87.91ms +step:654/1670 train_time:57495ms step_avg:87.91ms +step:655/1670 train_time:57583ms step_avg:87.91ms +step:656/1670 train_time:57671ms step_avg:87.91ms +step:657/1670 train_time:57759ms step_avg:87.91ms +step:658/1670 train_time:57850ms step_avg:87.92ms +step:659/1670 train_time:57941ms step_avg:87.92ms +step:660/1670 train_time:58031ms step_avg:87.93ms +step:661/1670 train_time:58121ms step_avg:87.93ms +step:662/1670 train_time:58210ms step_avg:87.93ms +step:663/1670 train_time:58299ms step_avg:87.93ms +step:664/1670 train_time:58387ms step_avg:87.93ms +step:665/1670 train_time:58476ms step_avg:87.93ms +step:666/1670 train_time:58564ms step_avg:87.93ms +step:667/1670 train_time:58652ms step_avg:87.93ms +step:668/1670 train_time:58740ms step_avg:87.93ms +step:669/1670 train_time:58829ms step_avg:87.94ms +step:670/1670 train_time:58920ms step_avg:87.94ms +step:671/1670 train_time:59009ms step_avg:87.94ms +step:672/1670 train_time:59099ms step_avg:87.94ms +step:673/1670 train_time:59189ms step_avg:87.95ms +step:674/1670 train_time:59278ms step_avg:87.95ms +step:675/1670 train_time:59367ms step_avg:87.95ms +step:676/1670 train_time:59455ms step_avg:87.95ms +step:677/1670 train_time:59544ms step_avg:87.95ms +step:678/1670 train_time:59632ms step_avg:87.95ms +step:679/1670 train_time:59720ms step_avg:87.95ms +step:680/1670 train_time:59809ms step_avg:87.95ms +step:681/1670 train_time:59898ms step_avg:87.96ms +step:682/1670 train_time:59988ms step_avg:87.96ms +step:683/1670 train_time:60078ms step_avg:87.96ms +step:684/1670 train_time:60168ms step_avg:87.97ms +step:685/1670 train_time:60257ms step_avg:87.97ms +step:686/1670 train_time:60346ms step_avg:87.97ms +step:687/1670 train_time:60434ms step_avg:87.97ms +step:688/1670 train_time:60523ms step_avg:87.97ms +step:689/1670 train_time:60611ms step_avg:87.97ms +step:690/1670 train_time:60701ms step_avg:87.97ms +step:691/1670 train_time:60789ms step_avg:87.97ms +step:692/1670 train_time:60878ms step_avg:87.97ms +step:693/1670 train_time:60967ms step_avg:87.98ms +step:694/1670 train_time:61056ms step_avg:87.98ms +step:695/1670 train_time:61146ms step_avg:87.98ms +step:696/1670 train_time:61234ms step_avg:87.98ms +step:697/1670 train_time:61323ms step_avg:87.98ms +step:698/1670 train_time:61412ms step_avg:87.98ms +step:699/1670 train_time:61502ms step_avg:87.98ms +step:700/1670 train_time:61590ms step_avg:87.99ms +step:701/1670 train_time:61679ms step_avg:87.99ms +step:702/1670 train_time:61767ms step_avg:87.99ms +step:703/1670 train_time:61855ms step_avg:87.99ms +step:704/1670 train_time:61945ms step_avg:87.99ms +step:705/1670 train_time:62034ms step_avg:87.99ms +step:706/1670 train_time:62123ms step_avg:87.99ms +step:707/1670 train_time:62211ms step_avg:87.99ms +step:708/1670 train_time:62301ms step_avg:88.00ms +step:709/1670 train_time:62390ms step_avg:88.00ms +step:710/1670 train_time:62479ms step_avg:88.00ms +step:711/1670 train_time:62567ms step_avg:88.00ms +step:712/1670 train_time:62656ms step_avg:88.00ms +step:713/1670 train_time:62746ms step_avg:88.00ms +step:714/1670 train_time:62834ms step_avg:88.00ms +step:715/1670 train_time:62923ms step_avg:88.00ms +step:716/1670 
train_time:63012ms step_avg:88.01ms +step:717/1670 train_time:63102ms step_avg:88.01ms +step:718/1670 train_time:63190ms step_avg:88.01ms +step:719/1670 train_time:63279ms step_avg:88.01ms +step:720/1670 train_time:63368ms step_avg:88.01ms +step:721/1670 train_time:63456ms step_avg:88.01ms +step:722/1670 train_time:63546ms step_avg:88.01ms +step:723/1670 train_time:63635ms step_avg:88.01ms +step:724/1670 train_time:63724ms step_avg:88.02ms +step:725/1670 train_time:63813ms step_avg:88.02ms +step:726/1670 train_time:63902ms step_avg:88.02ms +step:727/1670 train_time:63991ms step_avg:88.02ms +step:728/1670 train_time:64079ms step_avg:88.02ms +step:729/1670 train_time:64168ms step_avg:88.02ms +step:730/1670 train_time:64257ms step_avg:88.02ms +step:731/1670 train_time:64346ms step_avg:88.02ms +step:732/1670 train_time:64434ms step_avg:88.02ms +step:733/1670 train_time:64523ms step_avg:88.03ms +step:734/1670 train_time:64612ms step_avg:88.03ms +step:735/1670 train_time:64702ms step_avg:88.03ms +step:736/1670 train_time:64790ms step_avg:88.03ms +step:737/1670 train_time:64879ms step_avg:88.03ms +step:738/1670 train_time:64968ms step_avg:88.03ms +step:739/1670 train_time:65057ms step_avg:88.03ms +step:740/1670 train_time:65147ms step_avg:88.04ms +step:741/1670 train_time:65236ms step_avg:88.04ms +step:742/1670 train_time:65325ms step_avg:88.04ms +step:743/1670 train_time:65413ms step_avg:88.04ms +step:744/1670 train_time:65502ms step_avg:88.04ms +step:745/1670 train_time:65590ms step_avg:88.04ms +step:746/1670 train_time:65679ms step_avg:88.04ms +step:747/1670 train_time:65769ms step_avg:88.04ms +step:748/1670 train_time:65858ms step_avg:88.05ms +step:749/1670 train_time:65948ms step_avg:88.05ms +step:750/1670 train_time:66038ms step_avg:88.05ms +step:750/1670 val_loss:3.5708 train_time:66130ms step_avg:88.17ms +step:751/1670 train_time:66150ms step_avg:88.08ms +step:752/1670 train_time:66223ms step_avg:88.06ms +step:753/1670 train_time:66316ms step_avg:88.07ms +step:754/1670 train_time:66406ms step_avg:88.07ms +step:755/1670 train_time:66494ms step_avg:88.07ms +step:756/1670 train_time:66582ms step_avg:88.07ms +step:757/1670 train_time:66670ms step_avg:88.07ms +step:758/1670 train_time:66757ms step_avg:88.07ms +step:759/1670 train_time:66845ms step_avg:88.07ms +step:760/1670 train_time:66933ms step_avg:88.07ms +step:761/1670 train_time:67021ms step_avg:88.07ms +step:762/1670 train_time:67113ms step_avg:88.07ms +step:763/1670 train_time:67204ms step_avg:88.08ms +step:764/1670 train_time:67294ms step_avg:88.08ms +step:765/1670 train_time:67386ms step_avg:88.09ms +step:766/1670 train_time:67475ms step_avg:88.09ms +step:767/1670 train_time:67564ms step_avg:88.09ms +step:768/1670 train_time:67652ms step_avg:88.09ms +step:769/1670 train_time:67740ms step_avg:88.09ms +step:770/1670 train_time:67828ms step_avg:88.09ms +step:771/1670 train_time:67916ms step_avg:88.09ms +step:772/1670 train_time:68005ms step_avg:88.09ms +step:773/1670 train_time:68094ms step_avg:88.09ms +step:774/1670 train_time:68184ms step_avg:88.09ms +step:775/1670 train_time:68274ms step_avg:88.10ms +step:776/1670 train_time:68365ms step_avg:88.10ms +step:777/1670 train_time:68454ms step_avg:88.10ms +step:778/1670 train_time:68543ms step_avg:88.10ms +step:779/1670 train_time:68631ms step_avg:88.10ms +step:780/1670 train_time:68719ms step_avg:88.10ms +step:781/1670 train_time:68807ms step_avg:88.10ms +step:782/1670 train_time:68896ms step_avg:88.10ms +step:783/1670 train_time:68985ms step_avg:88.10ms +step:784/1670 train_time:69073ms 
step_avg:88.10ms +step:785/1670 train_time:69163ms step_avg:88.11ms +step:786/1670 train_time:69252ms step_avg:88.11ms +step:787/1670 train_time:69341ms step_avg:88.11ms +step:788/1670 train_time:69430ms step_avg:88.11ms +step:789/1670 train_time:69520ms step_avg:88.11ms +step:790/1670 train_time:69609ms step_avg:88.11ms +step:791/1670 train_time:69699ms step_avg:88.11ms +step:792/1670 train_time:69786ms step_avg:88.11ms +step:793/1670 train_time:69875ms step_avg:88.11ms +step:794/1670 train_time:69963ms step_avg:88.11ms +step:795/1670 train_time:70051ms step_avg:88.12ms +step:796/1670 train_time:70141ms step_avg:88.12ms +step:797/1670 train_time:70231ms step_avg:88.12ms +step:798/1670 train_time:70320ms step_avg:88.12ms +step:799/1670 train_time:70409ms step_avg:88.12ms +step:800/1670 train_time:70497ms step_avg:88.12ms +step:801/1670 train_time:70587ms step_avg:88.12ms +step:802/1670 train_time:70676ms step_avg:88.12ms +step:803/1670 train_time:70764ms step_avg:88.12ms +step:804/1670 train_time:70853ms step_avg:88.13ms +step:805/1670 train_time:70942ms step_avg:88.13ms +step:806/1670 train_time:71030ms step_avg:88.13ms +step:807/1670 train_time:71120ms step_avg:88.13ms +step:808/1670 train_time:71209ms step_avg:88.13ms +step:809/1670 train_time:71298ms step_avg:88.13ms +step:810/1670 train_time:71388ms step_avg:88.13ms +step:811/1670 train_time:71476ms step_avg:88.13ms +step:812/1670 train_time:71565ms step_avg:88.13ms +step:813/1670 train_time:71654ms step_avg:88.14ms +step:814/1670 train_time:71743ms step_avg:88.14ms +step:815/1670 train_time:71831ms step_avg:88.14ms +step:816/1670 train_time:71921ms step_avg:88.14ms +step:817/1670 train_time:72010ms step_avg:88.14ms +step:818/1670 train_time:72099ms step_avg:88.14ms +step:819/1670 train_time:72187ms step_avg:88.14ms +step:820/1670 train_time:72276ms step_avg:88.14ms +step:821/1670 train_time:72366ms step_avg:88.14ms +step:822/1670 train_time:72455ms step_avg:88.15ms +step:823/1670 train_time:72545ms step_avg:88.15ms +step:824/1670 train_time:72634ms step_avg:88.15ms +step:825/1670 train_time:72722ms step_avg:88.15ms +step:826/1670 train_time:72811ms step_avg:88.15ms +step:827/1670 train_time:72900ms step_avg:88.15ms +step:828/1670 train_time:72989ms step_avg:88.15ms +step:829/1670 train_time:73079ms step_avg:88.15ms +step:830/1670 train_time:73168ms step_avg:88.15ms +step:831/1670 train_time:73258ms step_avg:88.16ms +step:832/1670 train_time:73347ms step_avg:88.16ms +step:833/1670 train_time:73437ms step_avg:88.16ms +step:834/1670 train_time:73525ms step_avg:88.16ms +step:835/1670 train_time:73614ms step_avg:88.16ms +step:836/1670 train_time:73703ms step_avg:88.16ms +step:837/1670 train_time:73792ms step_avg:88.16ms +step:838/1670 train_time:73880ms step_avg:88.16ms +step:839/1670 train_time:73969ms step_avg:88.16ms +step:840/1670 train_time:74059ms step_avg:88.17ms +step:841/1670 train_time:74149ms step_avg:88.17ms +step:842/1670 train_time:74238ms step_avg:88.17ms +step:843/1670 train_time:74327ms step_avg:88.17ms +step:844/1670 train_time:74417ms step_avg:88.17ms +step:845/1670 train_time:74506ms step_avg:88.17ms +step:846/1670 train_time:74595ms step_avg:88.17ms +step:847/1670 train_time:74685ms step_avg:88.18ms +step:848/1670 train_time:74772ms step_avg:88.17ms +step:849/1670 train_time:74861ms step_avg:88.18ms +step:850/1670 train_time:74950ms step_avg:88.18ms +step:851/1670 train_time:75039ms step_avg:88.18ms +step:852/1670 train_time:75128ms step_avg:88.18ms +step:853/1670 train_time:75218ms step_avg:88.18ms +step:854/1670 
train_time:75308ms step_avg:88.18ms +step:855/1670 train_time:75397ms step_avg:88.18ms +step:856/1670 train_time:75486ms step_avg:88.18ms +step:857/1670 train_time:75576ms step_avg:88.19ms +step:858/1670 train_time:75665ms step_avg:88.19ms +step:859/1670 train_time:75753ms step_avg:88.19ms +step:860/1670 train_time:75843ms step_avg:88.19ms +step:861/1670 train_time:75931ms step_avg:88.19ms +step:862/1670 train_time:76020ms step_avg:88.19ms +step:863/1670 train_time:76109ms step_avg:88.19ms +step:864/1670 train_time:76198ms step_avg:88.19ms +step:865/1670 train_time:76287ms step_avg:88.19ms +step:866/1670 train_time:76375ms step_avg:88.19ms +step:867/1670 train_time:76464ms step_avg:88.19ms +step:868/1670 train_time:76553ms step_avg:88.19ms +step:869/1670 train_time:76643ms step_avg:88.20ms +step:870/1670 train_time:76731ms step_avg:88.20ms +step:871/1670 train_time:76820ms step_avg:88.20ms +step:872/1670 train_time:76908ms step_avg:88.20ms +step:873/1670 train_time:76997ms step_avg:88.20ms +step:874/1670 train_time:77086ms step_avg:88.20ms +step:875/1670 train_time:77175ms step_avg:88.20ms +step:875/1670 val_loss:3.5191 train_time:77266ms step_avg:88.30ms +step:876/1670 train_time:77285ms step_avg:88.22ms +step:877/1670 train_time:77358ms step_avg:88.21ms +step:878/1670 train_time:77452ms step_avg:88.21ms +step:879/1670 train_time:77542ms step_avg:88.22ms +step:880/1670 train_time:77631ms step_avg:88.22ms +step:881/1670 train_time:77719ms step_avg:88.22ms +step:882/1670 train_time:77807ms step_avg:88.22ms +step:883/1670 train_time:77895ms step_avg:88.22ms +step:884/1670 train_time:77982ms step_avg:88.22ms +step:885/1670 train_time:78070ms step_avg:88.21ms +step:886/1670 train_time:78158ms step_avg:88.21ms +step:887/1670 train_time:78248ms step_avg:88.22ms +step:888/1670 train_time:78339ms step_avg:88.22ms +step:889/1670 train_time:78430ms step_avg:88.22ms +step:890/1670 train_time:78521ms step_avg:88.23ms +step:891/1670 train_time:78610ms step_avg:88.23ms +step:892/1670 train_time:78699ms step_avg:88.23ms +step:893/1670 train_time:78789ms step_avg:88.23ms +step:894/1670 train_time:78877ms step_avg:88.23ms +step:895/1670 train_time:78965ms step_avg:88.23ms +step:896/1670 train_time:79053ms step_avg:88.23ms +step:897/1670 train_time:79141ms step_avg:88.23ms +step:898/1670 train_time:79230ms step_avg:88.23ms +step:899/1670 train_time:79319ms step_avg:88.23ms +step:900/1670 train_time:79409ms step_avg:88.23ms +step:901/1670 train_time:79499ms step_avg:88.23ms +step:902/1670 train_time:79590ms step_avg:88.24ms +step:903/1670 train_time:79679ms step_avg:88.24ms +step:904/1670 train_time:79768ms step_avg:88.24ms +step:905/1670 train_time:79856ms step_avg:88.24ms +step:906/1670 train_time:79944ms step_avg:88.24ms +step:907/1670 train_time:80032ms step_avg:88.24ms +step:908/1670 train_time:80120ms step_avg:88.24ms +step:909/1670 train_time:80209ms step_avg:88.24ms +step:910/1670 train_time:80299ms step_avg:88.24ms +step:911/1670 train_time:80389ms step_avg:88.24ms +step:912/1670 train_time:80479ms step_avg:88.24ms +step:913/1670 train_time:80569ms step_avg:88.25ms +step:914/1670 train_time:80658ms step_avg:88.25ms +step:915/1670 train_time:80747ms step_avg:88.25ms +step:916/1670 train_time:80835ms step_avg:88.25ms +step:917/1670 train_time:80924ms step_avg:88.25ms +step:918/1670 train_time:81012ms step_avg:88.25ms +step:919/1670 train_time:81101ms step_avg:88.25ms +step:920/1670 train_time:81190ms step_avg:88.25ms +step:921/1670 train_time:81278ms step_avg:88.25ms +step:922/1670 train_time:81367ms 
step_avg:88.25ms +step:923/1670 train_time:81456ms step_avg:88.25ms +step:924/1670 train_time:81546ms step_avg:88.25ms +step:925/1670 train_time:81634ms step_avg:88.25ms +step:926/1670 train_time:81724ms step_avg:88.25ms +step:927/1670 train_time:81813ms step_avg:88.26ms +step:928/1670 train_time:81902ms step_avg:88.26ms +step:929/1670 train_time:81991ms step_avg:88.26ms +step:930/1670 train_time:82080ms step_avg:88.26ms +step:931/1670 train_time:82169ms step_avg:88.26ms +step:932/1670 train_time:82258ms step_avg:88.26ms +step:933/1670 train_time:82347ms step_avg:88.26ms +step:934/1670 train_time:82436ms step_avg:88.26ms +step:935/1670 train_time:82526ms step_avg:88.26ms +step:936/1670 train_time:82614ms step_avg:88.26ms +step:937/1670 train_time:82703ms step_avg:88.26ms +step:938/1670 train_time:82791ms step_avg:88.26ms +step:939/1670 train_time:82881ms step_avg:88.26ms +step:940/1670 train_time:82970ms step_avg:88.27ms +step:941/1670 train_time:83058ms step_avg:88.27ms +step:942/1670 train_time:83147ms step_avg:88.27ms +step:943/1670 train_time:83236ms step_avg:88.27ms +step:944/1670 train_time:83325ms step_avg:88.27ms +step:945/1670 train_time:83414ms step_avg:88.27ms +step:946/1670 train_time:83503ms step_avg:88.27ms +step:947/1670 train_time:83592ms step_avg:88.27ms +step:948/1670 train_time:83681ms step_avg:88.27ms +step:949/1670 train_time:83770ms step_avg:88.27ms +step:950/1670 train_time:83858ms step_avg:88.27ms +step:951/1670 train_time:83948ms step_avg:88.27ms +step:952/1670 train_time:84037ms step_avg:88.27ms +step:953/1670 train_time:84126ms step_avg:88.27ms +step:954/1670 train_time:84215ms step_avg:88.28ms +step:955/1670 train_time:84304ms step_avg:88.28ms +step:956/1670 train_time:84392ms step_avg:88.28ms +step:957/1670 train_time:84480ms step_avg:88.28ms +step:958/1670 train_time:84569ms step_avg:88.28ms +step:959/1670 train_time:84658ms step_avg:88.28ms +step:960/1670 train_time:84747ms step_avg:88.28ms +step:961/1670 train_time:84835ms step_avg:88.28ms +step:962/1670 train_time:84925ms step_avg:88.28ms +step:963/1670 train_time:85013ms step_avg:88.28ms +step:964/1670 train_time:85102ms step_avg:88.28ms +step:965/1670 train_time:85191ms step_avg:88.28ms +step:966/1670 train_time:85280ms step_avg:88.28ms +step:967/1670 train_time:85368ms step_avg:88.28ms +step:968/1670 train_time:85457ms step_avg:88.28ms +step:969/1670 train_time:85546ms step_avg:88.28ms +step:970/1670 train_time:85634ms step_avg:88.28ms +step:971/1670 train_time:85724ms step_avg:88.28ms +step:972/1670 train_time:85812ms step_avg:88.28ms +step:973/1670 train_time:85902ms step_avg:88.29ms +step:974/1670 train_time:85991ms step_avg:88.29ms +step:975/1670 train_time:86080ms step_avg:88.29ms +step:976/1670 train_time:86169ms step_avg:88.29ms +step:977/1670 train_time:86258ms step_avg:88.29ms +step:978/1670 train_time:86347ms step_avg:88.29ms +step:979/1670 train_time:86435ms step_avg:88.29ms +step:980/1670 train_time:86524ms step_avg:88.29ms +step:981/1670 train_time:86613ms step_avg:88.29ms +step:982/1670 train_time:86702ms step_avg:88.29ms +step:983/1670 train_time:86791ms step_avg:88.29ms +step:984/1670 train_time:86881ms step_avg:88.29ms +step:985/1670 train_time:86970ms step_avg:88.29ms +step:986/1670 train_time:87060ms step_avg:88.30ms +step:987/1670 train_time:87148ms step_avg:88.30ms +step:988/1670 train_time:87237ms step_avg:88.30ms +step:989/1670 train_time:87327ms step_avg:88.30ms +step:990/1670 train_time:87415ms step_avg:88.30ms +step:991/1670 train_time:87505ms step_avg:88.30ms +step:992/1670 
train_time:87593ms step_avg:88.30ms +step:993/1670 train_time:87683ms step_avg:88.30ms +step:994/1670 train_time:87771ms step_avg:88.30ms +step:995/1670 train_time:87861ms step_avg:88.30ms +step:996/1670 train_time:87950ms step_avg:88.30ms +step:997/1670 train_time:88038ms step_avg:88.30ms +step:998/1670 train_time:88128ms step_avg:88.30ms +step:999/1670 train_time:88217ms step_avg:88.31ms +step:1000/1670 train_time:88307ms step_avg:88.31ms +step:1000/1670 val_loss:3.4687 train_time:88396ms step_avg:88.40ms +step:1001/1670 train_time:88416ms step_avg:88.33ms +step:1002/1670 train_time:88488ms step_avg:88.31ms +step:1003/1670 train_time:88582ms step_avg:88.32ms +step:1004/1670 train_time:88670ms step_avg:88.32ms +step:1005/1670 train_time:88759ms step_avg:88.32ms +step:1006/1670 train_time:88848ms step_avg:88.32ms +step:1007/1670 train_time:88936ms step_avg:88.32ms +step:1008/1670 train_time:89024ms step_avg:88.32ms +step:1009/1670 train_time:89112ms step_avg:88.32ms +step:1010/1670 train_time:89201ms step_avg:88.32ms +step:1011/1670 train_time:89289ms step_avg:88.32ms +step:1012/1670 train_time:89378ms step_avg:88.32ms +step:1013/1670 train_time:89470ms step_avg:88.32ms +step:1014/1670 train_time:89561ms step_avg:88.32ms +step:1015/1670 train_time:89651ms step_avg:88.33ms +step:1016/1670 train_time:89740ms step_avg:88.33ms +step:1017/1670 train_time:89829ms step_avg:88.33ms +step:1018/1670 train_time:89919ms step_avg:88.33ms +step:1019/1670 train_time:90007ms step_avg:88.33ms +step:1020/1670 train_time:90095ms step_avg:88.33ms +step:1021/1670 train_time:90183ms step_avg:88.33ms +step:1022/1670 train_time:90271ms step_avg:88.33ms +step:1023/1670 train_time:90361ms step_avg:88.33ms +step:1024/1670 train_time:90451ms step_avg:88.33ms +step:1025/1670 train_time:90542ms step_avg:88.33ms +step:1026/1670 train_time:90631ms step_avg:88.33ms +step:1027/1670 train_time:90721ms step_avg:88.34ms +step:1028/1670 train_time:90809ms step_avg:88.34ms +step:1029/1670 train_time:90898ms step_avg:88.34ms +step:1030/1670 train_time:90987ms step_avg:88.34ms +step:1031/1670 train_time:91074ms step_avg:88.34ms +step:1032/1670 train_time:91163ms step_avg:88.34ms +step:1033/1670 train_time:91251ms step_avg:88.34ms +step:1034/1670 train_time:91340ms step_avg:88.34ms +step:1035/1670 train_time:91430ms step_avg:88.34ms +step:1036/1670 train_time:91519ms step_avg:88.34ms +step:1037/1670 train_time:91609ms step_avg:88.34ms +step:1038/1670 train_time:91699ms step_avg:88.34ms +step:1039/1670 train_time:91787ms step_avg:88.34ms +step:1040/1670 train_time:91876ms step_avg:88.34ms +step:1041/1670 train_time:91965ms step_avg:88.34ms +step:1042/1670 train_time:92054ms step_avg:88.34ms +step:1043/1670 train_time:92143ms step_avg:88.34ms +step:1044/1670 train_time:92232ms step_avg:88.34ms +step:1045/1670 train_time:92322ms step_avg:88.35ms +step:1046/1670 train_time:92410ms step_avg:88.35ms +step:1047/1670 train_time:92500ms step_avg:88.35ms +step:1048/1670 train_time:92589ms step_avg:88.35ms +step:1049/1670 train_time:92679ms step_avg:88.35ms +step:1050/1670 train_time:92768ms step_avg:88.35ms +step:1051/1670 train_time:92857ms step_avg:88.35ms +step:1052/1670 train_time:92946ms step_avg:88.35ms +step:1053/1670 train_time:93036ms step_avg:88.35ms +step:1054/1670 train_time:93125ms step_avg:88.35ms +step:1055/1670 train_time:93213ms step_avg:88.35ms +step:1056/1670 train_time:93302ms step_avg:88.35ms +step:1057/1670 train_time:93391ms step_avg:88.35ms +step:1058/1670 train_time:93480ms step_avg:88.36ms +step:1059/1670 
train_time:93569ms step_avg:88.36ms +step:1060/1670 train_time:93658ms step_avg:88.36ms +step:1061/1670 train_time:93747ms step_avg:88.36ms +step:1062/1670 train_time:93837ms step_avg:88.36ms +step:1063/1670 train_time:93925ms step_avg:88.36ms +step:1064/1670 train_time:94014ms step_avg:88.36ms +step:1065/1670 train_time:94104ms step_avg:88.36ms +step:1066/1670 train_time:94192ms step_avg:88.36ms +step:1067/1670 train_time:94281ms step_avg:88.36ms +step:1068/1670 train_time:94369ms step_avg:88.36ms +step:1069/1670 train_time:94460ms step_avg:88.36ms +step:1070/1670 train_time:94549ms step_avg:88.36ms +step:1071/1670 train_time:94637ms step_avg:88.36ms +step:1072/1670 train_time:94726ms step_avg:88.36ms +step:1073/1670 train_time:94816ms step_avg:88.37ms +step:1074/1670 train_time:94905ms step_avg:88.37ms +step:1075/1670 train_time:94993ms step_avg:88.37ms +step:1076/1670 train_time:95082ms step_avg:88.37ms +step:1077/1670 train_time:95171ms step_avg:88.37ms +step:1078/1670 train_time:95260ms step_avg:88.37ms +step:1079/1670 train_time:95349ms step_avg:88.37ms +step:1080/1670 train_time:95438ms step_avg:88.37ms +step:1081/1670 train_time:95528ms step_avg:88.37ms +step:1082/1670 train_time:95617ms step_avg:88.37ms +step:1083/1670 train_time:95706ms step_avg:88.37ms +step:1084/1670 train_time:95795ms step_avg:88.37ms +step:1085/1670 train_time:95884ms step_avg:88.37ms +step:1086/1670 train_time:95973ms step_avg:88.37ms +step:1087/1670 train_time:96062ms step_avg:88.37ms +step:1088/1670 train_time:96150ms step_avg:88.37ms +step:1089/1670 train_time:96239ms step_avg:88.37ms +step:1090/1670 train_time:96329ms step_avg:88.38ms +step:1091/1670 train_time:96420ms step_avg:88.38ms +step:1092/1670 train_time:96510ms step_avg:88.38ms +step:1093/1670 train_time:96599ms step_avg:88.38ms +step:1094/1670 train_time:96688ms step_avg:88.38ms +step:1095/1670 train_time:96778ms step_avg:88.38ms +step:1096/1670 train_time:96867ms step_avg:88.38ms +step:1097/1670 train_time:96957ms step_avg:88.38ms +step:1098/1670 train_time:97047ms step_avg:88.39ms +step:1099/1670 train_time:97138ms step_avg:88.39ms +step:1100/1670 train_time:97228ms step_avg:88.39ms +step:1101/1670 train_time:97319ms step_avg:88.39ms +step:1102/1670 train_time:97409ms step_avg:88.39ms +step:1103/1670 train_time:97499ms step_avg:88.39ms +step:1104/1670 train_time:97589ms step_avg:88.40ms +step:1105/1670 train_time:97678ms step_avg:88.40ms +step:1106/1670 train_time:97767ms step_avg:88.40ms +step:1107/1670 train_time:97857ms step_avg:88.40ms +step:1108/1670 train_time:97947ms step_avg:88.40ms +step:1109/1670 train_time:98037ms step_avg:88.40ms +step:1110/1670 train_time:98126ms step_avg:88.40ms +step:1111/1670 train_time:98216ms step_avg:88.40ms +step:1112/1670 train_time:98306ms step_avg:88.40ms +step:1113/1670 train_time:98395ms step_avg:88.41ms +step:1114/1670 train_time:98486ms step_avg:88.41ms +step:1115/1670 train_time:98576ms step_avg:88.41ms +step:1116/1670 train_time:98666ms step_avg:88.41ms +step:1117/1670 train_time:98756ms step_avg:88.41ms +step:1118/1670 train_time:98846ms step_avg:88.41ms +step:1119/1670 train_time:98935ms step_avg:88.41ms +step:1120/1670 train_time:99026ms step_avg:88.42ms +step:1121/1670 train_time:99115ms step_avg:88.42ms +step:1122/1670 train_time:99204ms step_avg:88.42ms +step:1123/1670 train_time:99294ms step_avg:88.42ms +step:1124/1670 train_time:99384ms step_avg:88.42ms +step:1125/1670 train_time:99473ms step_avg:88.42ms +step:1125/1670 val_loss:3.4156 train_time:99565ms step_avg:88.50ms +step:1126/1670 
train_time:99585ms step_avg:88.44ms +step:1127/1670 train_time:99655ms step_avg:88.42ms +step:1128/1670 train_time:99745ms step_avg:88.43ms +step:1129/1670 train_time:99836ms step_avg:88.43ms +step:1130/1670 train_time:99925ms step_avg:88.43ms +step:1131/1670 train_time:100014ms step_avg:88.43ms +step:1132/1670 train_time:100102ms step_avg:88.43ms +step:1133/1670 train_time:100190ms step_avg:88.43ms +step:1134/1670 train_time:100279ms step_avg:88.43ms +step:1135/1670 train_time:100369ms step_avg:88.43ms +step:1136/1670 train_time:100458ms step_avg:88.43ms +step:1137/1670 train_time:100550ms step_avg:88.43ms +step:1138/1670 train_time:100641ms step_avg:88.44ms +step:1139/1670 train_time:100732ms step_avg:88.44ms +step:1140/1670 train_time:100822ms step_avg:88.44ms +step:1141/1670 train_time:100911ms step_avg:88.44ms +step:1142/1670 train_time:101002ms step_avg:88.44ms +step:1143/1670 train_time:101090ms step_avg:88.44ms +step:1144/1670 train_time:101179ms step_avg:88.44ms +step:1145/1670 train_time:101268ms step_avg:88.44ms +step:1146/1670 train_time:101358ms step_avg:88.45ms +step:1147/1670 train_time:101447ms step_avg:88.45ms +step:1148/1670 train_time:101538ms step_avg:88.45ms +step:1149/1670 train_time:101627ms step_avg:88.45ms +step:1150/1670 train_time:101718ms step_avg:88.45ms +step:1151/1670 train_time:101807ms step_avg:88.45ms +step:1152/1670 train_time:101897ms step_avg:88.45ms +step:1153/1670 train_time:101986ms step_avg:88.45ms +step:1154/1670 train_time:102076ms step_avg:88.45ms +step:1155/1670 train_time:102164ms step_avg:88.45ms +step:1156/1670 train_time:102253ms step_avg:88.45ms +step:1157/1670 train_time:102342ms step_avg:88.45ms +step:1158/1670 train_time:102431ms step_avg:88.46ms +step:1159/1670 train_time:102521ms step_avg:88.46ms +step:1160/1670 train_time:102611ms step_avg:88.46ms +step:1161/1670 train_time:102702ms step_avg:88.46ms +step:1162/1670 train_time:102793ms step_avg:88.46ms +step:1163/1670 train_time:102882ms step_avg:88.46ms +step:1164/1670 train_time:102972ms step_avg:88.46ms +step:1165/1670 train_time:103061ms step_avg:88.46ms +step:1166/1670 train_time:103151ms step_avg:88.47ms +step:1167/1670 train_time:103240ms step_avg:88.47ms +step:1168/1670 train_time:103330ms step_avg:88.47ms +step:1169/1670 train_time:103419ms step_avg:88.47ms +step:1170/1670 train_time:103508ms step_avg:88.47ms +step:1171/1670 train_time:103600ms step_avg:88.47ms +step:1172/1670 train_time:103690ms step_avg:88.47ms +step:1173/1670 train_time:103780ms step_avg:88.47ms +step:1174/1670 train_time:103870ms step_avg:88.48ms +step:1175/1670 train_time:103959ms step_avg:88.48ms +step:1176/1670 train_time:104049ms step_avg:88.48ms +step:1177/1670 train_time:104139ms step_avg:88.48ms +step:1178/1670 train_time:104228ms step_avg:88.48ms +step:1179/1670 train_time:104318ms step_avg:88.48ms +step:1180/1670 train_time:104406ms step_avg:88.48ms +step:1181/1670 train_time:104496ms step_avg:88.48ms +step:1182/1670 train_time:104586ms step_avg:88.48ms +step:1183/1670 train_time:104676ms step_avg:88.48ms +step:1184/1670 train_time:104765ms step_avg:88.48ms +step:1185/1670 train_time:104856ms step_avg:88.49ms +step:1186/1670 train_time:104945ms step_avg:88.49ms +step:1187/1670 train_time:105036ms step_avg:88.49ms +step:1188/1670 train_time:105125ms step_avg:88.49ms +step:1189/1670 train_time:105214ms step_avg:88.49ms +step:1190/1670 train_time:105303ms step_avg:88.49ms +step:1191/1670 train_time:105394ms step_avg:88.49ms +step:1192/1670 train_time:105483ms step_avg:88.49ms +step:1193/1670 
train_time:105575ms step_avg:88.50ms +step:1194/1670 train_time:105665ms step_avg:88.50ms +step:1195/1670 train_time:105755ms step_avg:88.50ms +step:1196/1670 train_time:105844ms step_avg:88.50ms +step:1197/1670 train_time:105933ms step_avg:88.50ms +step:1198/1670 train_time:106023ms step_avg:88.50ms +step:1199/1670 train_time:106112ms step_avg:88.50ms +step:1200/1670 train_time:106201ms step_avg:88.50ms +step:1201/1670 train_time:106292ms step_avg:88.50ms +step:1202/1670 train_time:106382ms step_avg:88.50ms +step:1203/1670 train_time:106471ms step_avg:88.50ms +step:1204/1670 train_time:106561ms step_avg:88.51ms +step:1205/1670 train_time:106651ms step_avg:88.51ms +step:1206/1670 train_time:106741ms step_avg:88.51ms +step:1207/1670 train_time:106831ms step_avg:88.51ms +step:1208/1670 train_time:106921ms step_avg:88.51ms +step:1209/1670 train_time:107011ms step_avg:88.51ms +step:1210/1670 train_time:107101ms step_avg:88.51ms +step:1211/1670 train_time:107191ms step_avg:88.51ms +step:1212/1670 train_time:107281ms step_avg:88.52ms +step:1213/1670 train_time:107372ms step_avg:88.52ms +step:1214/1670 train_time:107462ms step_avg:88.52ms +step:1215/1670 train_time:107552ms step_avg:88.52ms +step:1216/1670 train_time:107642ms step_avg:88.52ms +step:1217/1670 train_time:107731ms step_avg:88.52ms +step:1218/1670 train_time:107822ms step_avg:88.52ms +step:1219/1670 train_time:107911ms step_avg:88.52ms +step:1220/1670 train_time:108002ms step_avg:88.53ms +step:1221/1670 train_time:108092ms step_avg:88.53ms +step:1222/1670 train_time:108181ms step_avg:88.53ms +step:1223/1670 train_time:108271ms step_avg:88.53ms +step:1224/1670 train_time:108361ms step_avg:88.53ms +step:1225/1670 train_time:108451ms step_avg:88.53ms +step:1226/1670 train_time:108542ms step_avg:88.53ms +step:1227/1670 train_time:108631ms step_avg:88.53ms +step:1228/1670 train_time:108721ms step_avg:88.54ms +step:1229/1670 train_time:108810ms step_avg:88.54ms +step:1230/1670 train_time:108900ms step_avg:88.54ms +step:1231/1670 train_time:108991ms step_avg:88.54ms +step:1232/1670 train_time:109080ms step_avg:88.54ms +step:1233/1670 train_time:109171ms step_avg:88.54ms +step:1234/1670 train_time:109262ms step_avg:88.54ms +step:1235/1670 train_time:109351ms step_avg:88.54ms +step:1236/1670 train_time:109441ms step_avg:88.54ms +step:1237/1670 train_time:109532ms step_avg:88.55ms +step:1238/1670 train_time:109621ms step_avg:88.55ms +step:1239/1670 train_time:109710ms step_avg:88.55ms +step:1240/1670 train_time:109800ms step_avg:88.55ms +step:1241/1670 train_time:109889ms step_avg:88.55ms +step:1242/1670 train_time:109980ms step_avg:88.55ms +step:1243/1670 train_time:110069ms step_avg:88.55ms +step:1244/1670 train_time:110160ms step_avg:88.55ms +step:1245/1670 train_time:110249ms step_avg:88.55ms +step:1246/1670 train_time:110339ms step_avg:88.55ms +step:1247/1670 train_time:110428ms step_avg:88.55ms +step:1248/1670 train_time:110519ms step_avg:88.56ms +step:1249/1670 train_time:110607ms step_avg:88.56ms +step:1250/1670 train_time:110697ms step_avg:88.56ms +step:1250/1670 val_loss:3.3774 train_time:110788ms step_avg:88.63ms +step:1251/1670 train_time:110807ms step_avg:88.58ms +step:1252/1670 train_time:110881ms step_avg:88.56ms +step:1253/1670 train_time:110973ms step_avg:88.57ms +step:1254/1670 train_time:111063ms step_avg:88.57ms +step:1255/1670 train_time:111151ms step_avg:88.57ms +step:1256/1670 train_time:111240ms step_avg:88.57ms +step:1257/1670 train_time:111329ms step_avg:88.57ms +step:1258/1670 train_time:111417ms step_avg:88.57ms 
+step:1259/1670 train_time:111506ms step_avg:88.57ms +step:1260/1670 train_time:111595ms step_avg:88.57ms +step:1261/1670 train_time:111684ms step_avg:88.57ms +step:1262/1670 train_time:111775ms step_avg:88.57ms +step:1263/1670 train_time:111867ms step_avg:88.57ms +step:1264/1670 train_time:111959ms step_avg:88.58ms +step:1265/1670 train_time:112049ms step_avg:88.58ms +step:1266/1670 train_time:112138ms step_avg:88.58ms +step:1267/1670 train_time:112228ms step_avg:88.58ms +step:1268/1670 train_time:112316ms step_avg:88.58ms +step:1269/1670 train_time:112405ms step_avg:88.58ms +step:1270/1670 train_time:112493ms step_avg:88.58ms +step:1271/1670 train_time:112582ms step_avg:88.58ms +step:1272/1670 train_time:112672ms step_avg:88.58ms +step:1273/1670 train_time:112763ms step_avg:88.58ms +step:1274/1670 train_time:112854ms step_avg:88.58ms +step:1275/1670 train_time:112946ms step_avg:88.58ms +step:1276/1670 train_time:113036ms step_avg:88.59ms +step:1277/1670 train_time:113125ms step_avg:88.59ms +step:1278/1670 train_time:113214ms step_avg:88.59ms +step:1279/1670 train_time:113304ms step_avg:88.59ms +step:1280/1670 train_time:113393ms step_avg:88.59ms +step:1281/1670 train_time:113483ms step_avg:88.59ms +step:1282/1670 train_time:113573ms step_avg:88.59ms +step:1283/1670 train_time:113662ms step_avg:88.59ms +step:1284/1670 train_time:113752ms step_avg:88.59ms +step:1285/1670 train_time:113843ms step_avg:88.59ms +step:1286/1670 train_time:113934ms step_avg:88.60ms +step:1287/1670 train_time:114025ms step_avg:88.60ms +step:1288/1670 train_time:114115ms step_avg:88.60ms +step:1289/1670 train_time:114205ms step_avg:88.60ms +step:1290/1670 train_time:114294ms step_avg:88.60ms +step:1291/1670 train_time:114384ms step_avg:88.60ms +step:1292/1670 train_time:114473ms step_avg:88.60ms +step:1293/1670 train_time:114563ms step_avg:88.60ms +step:1294/1670 train_time:114652ms step_avg:88.60ms +step:1295/1670 train_time:114741ms step_avg:88.60ms +step:1296/1670 train_time:114832ms step_avg:88.60ms +step:1297/1670 train_time:114922ms step_avg:88.61ms +step:1298/1670 train_time:115013ms step_avg:88.61ms +step:1299/1670 train_time:115103ms step_avg:88.61ms +step:1300/1670 train_time:115194ms step_avg:88.61ms +step:1301/1670 train_time:115284ms step_avg:88.61ms +step:1302/1670 train_time:115373ms step_avg:88.61ms +step:1303/1670 train_time:115462ms step_avg:88.61ms +step:1304/1670 train_time:115551ms step_avg:88.61ms +step:1305/1670 train_time:115640ms step_avg:88.61ms +step:1306/1670 train_time:115729ms step_avg:88.61ms +step:1307/1670 train_time:115819ms step_avg:88.61ms +step:1308/1670 train_time:115910ms step_avg:88.62ms +step:1309/1670 train_time:116000ms step_avg:88.62ms +step:1310/1670 train_time:116090ms step_avg:88.62ms +step:1311/1670 train_time:116181ms step_avg:88.62ms +step:1312/1670 train_time:116271ms step_avg:88.62ms +step:1313/1670 train_time:116360ms step_avg:88.62ms +step:1314/1670 train_time:116451ms step_avg:88.62ms +step:1315/1670 train_time:116540ms step_avg:88.62ms +step:1316/1670 train_time:116630ms step_avg:88.62ms +step:1317/1670 train_time:116720ms step_avg:88.63ms +step:1318/1670 train_time:116811ms step_avg:88.63ms +step:1319/1670 train_time:116902ms step_avg:88.63ms +step:1320/1670 train_time:116993ms step_avg:88.63ms +step:1321/1670 train_time:117085ms step_avg:88.63ms +step:1322/1670 train_time:117175ms step_avg:88.63ms +step:1323/1670 train_time:117266ms step_avg:88.64ms +step:1324/1670 train_time:117354ms step_avg:88.64ms +step:1325/1670 train_time:117443ms step_avg:88.64ms 
+step:1326/1670 train_time:117534ms step_avg:88.64ms +step:1327/1670 train_time:117623ms step_avg:88.64ms +step:1328/1670 train_time:117713ms step_avg:88.64ms +step:1329/1670 train_time:117803ms step_avg:88.64ms +step:1330/1670 train_time:117893ms step_avg:88.64ms +step:1331/1670 train_time:117982ms step_avg:88.64ms +step:1332/1670 train_time:118072ms step_avg:88.64ms +step:1333/1670 train_time:118162ms step_avg:88.64ms +step:1334/1670 train_time:118252ms step_avg:88.64ms +step:1335/1670 train_time:118341ms step_avg:88.65ms +step:1336/1670 train_time:118431ms step_avg:88.65ms +step:1337/1670 train_time:118521ms step_avg:88.65ms +step:1338/1670 train_time:118610ms step_avg:88.65ms +step:1339/1670 train_time:118699ms step_avg:88.65ms +step:1340/1670 train_time:118789ms step_avg:88.65ms +step:1341/1670 train_time:118879ms step_avg:88.65ms +step:1342/1670 train_time:118969ms step_avg:88.65ms +step:1343/1670 train_time:119059ms step_avg:88.65ms +step:1344/1670 train_time:119149ms step_avg:88.65ms +step:1345/1670 train_time:119238ms step_avg:88.65ms +step:1346/1670 train_time:119328ms step_avg:88.65ms +step:1347/1670 train_time:119417ms step_avg:88.65ms +step:1348/1670 train_time:119507ms step_avg:88.66ms +step:1349/1670 train_time:119597ms step_avg:88.66ms +step:1350/1670 train_time:119687ms step_avg:88.66ms +step:1351/1670 train_time:119777ms step_avg:88.66ms +step:1352/1670 train_time:119867ms step_avg:88.66ms +step:1353/1670 train_time:119957ms step_avg:88.66ms +step:1354/1670 train_time:120047ms step_avg:88.66ms +step:1355/1670 train_time:120136ms step_avg:88.66ms +step:1356/1670 train_time:120226ms step_avg:88.66ms +step:1357/1670 train_time:120315ms step_avg:88.66ms +step:1358/1670 train_time:120406ms step_avg:88.66ms +step:1359/1670 train_time:120495ms step_avg:88.66ms +step:1360/1670 train_time:120584ms step_avg:88.66ms +step:1361/1670 train_time:120673ms step_avg:88.66ms +step:1362/1670 train_time:120762ms step_avg:88.67ms +step:1363/1670 train_time:120852ms step_avg:88.67ms +step:1364/1670 train_time:120942ms step_avg:88.67ms +step:1365/1670 train_time:121033ms step_avg:88.67ms +step:1366/1670 train_time:121123ms step_avg:88.67ms +step:1367/1670 train_time:121213ms step_avg:88.67ms +step:1368/1670 train_time:121303ms step_avg:88.67ms +step:1369/1670 train_time:121393ms step_avg:88.67ms +step:1370/1670 train_time:121483ms step_avg:88.67ms +step:1371/1670 train_time:121572ms step_avg:88.67ms +step:1372/1670 train_time:121662ms step_avg:88.68ms +step:1373/1670 train_time:121751ms step_avg:88.68ms +step:1374/1670 train_time:121841ms step_avg:88.68ms +step:1375/1670 train_time:121930ms step_avg:88.68ms +step:1375/1670 val_loss:3.3423 train_time:122022ms step_avg:88.74ms +step:1376/1670 train_time:122041ms step_avg:88.69ms +step:1377/1670 train_time:122116ms step_avg:88.68ms +step:1378/1670 train_time:122207ms step_avg:88.68ms +step:1379/1670 train_time:122297ms step_avg:88.69ms +step:1380/1670 train_time:122385ms step_avg:88.69ms +step:1381/1670 train_time:122475ms step_avg:88.69ms +step:1382/1670 train_time:122563ms step_avg:88.69ms +step:1383/1670 train_time:122652ms step_avg:88.69ms +step:1384/1670 train_time:122741ms step_avg:88.69ms +step:1385/1670 train_time:122830ms step_avg:88.69ms +step:1386/1670 train_time:122920ms step_avg:88.69ms +step:1387/1670 train_time:123011ms step_avg:88.69ms +step:1388/1670 train_time:123104ms step_avg:88.69ms +step:1389/1670 train_time:123195ms step_avg:88.69ms +step:1390/1670 train_time:123285ms step_avg:88.69ms +step:1391/1670 train_time:123374ms 
step_avg:88.69ms +step:1392/1670 train_time:123463ms step_avg:88.69ms +step:1393/1670 train_time:123553ms step_avg:88.70ms +step:1394/1670 train_time:123641ms step_avg:88.70ms +step:1395/1670 train_time:123731ms step_avg:88.70ms +step:1396/1670 train_time:123821ms step_avg:88.70ms +step:1397/1670 train_time:123911ms step_avg:88.70ms +step:1398/1670 train_time:124002ms step_avg:88.70ms +step:1399/1670 train_time:124094ms step_avg:88.70ms +step:1400/1670 train_time:124184ms step_avg:88.70ms +step:1401/1670 train_time:124275ms step_avg:88.70ms +step:1402/1670 train_time:124364ms step_avg:88.70ms +step:1403/1670 train_time:124455ms step_avg:88.71ms +step:1404/1670 train_time:124543ms step_avg:88.71ms +step:1405/1670 train_time:124633ms step_avg:88.71ms +step:1406/1670 train_time:124722ms step_avg:88.71ms +step:1407/1670 train_time:124811ms step_avg:88.71ms +step:1408/1670 train_time:124900ms step_avg:88.71ms +step:1409/1670 train_time:124991ms step_avg:88.71ms +step:1410/1670 train_time:125082ms step_avg:88.71ms +step:1411/1670 train_time:125174ms step_avg:88.71ms +step:1412/1670 train_time:125264ms step_avg:88.71ms +step:1413/1670 train_time:125356ms step_avg:88.72ms +step:1414/1670 train_time:125446ms step_avg:88.72ms +step:1415/1670 train_time:125537ms step_avg:88.72ms +step:1416/1670 train_time:125625ms step_avg:88.72ms +step:1417/1670 train_time:125714ms step_avg:88.72ms +step:1418/1670 train_time:125803ms step_avg:88.72ms +step:1419/1670 train_time:125893ms step_avg:88.72ms +step:1420/1670 train_time:125983ms step_avg:88.72ms +step:1421/1670 train_time:126073ms step_avg:88.72ms +step:1422/1670 train_time:126163ms step_avg:88.72ms +step:1423/1670 train_time:126253ms step_avg:88.72ms +step:1424/1670 train_time:126343ms step_avg:88.72ms +step:1425/1670 train_time:126434ms step_avg:88.73ms +step:1426/1670 train_time:126524ms step_avg:88.73ms +step:1427/1670 train_time:126613ms step_avg:88.73ms +step:1428/1670 train_time:126702ms step_avg:88.73ms +step:1429/1670 train_time:126792ms step_avg:88.73ms +step:1430/1670 train_time:126882ms step_avg:88.73ms +step:1431/1670 train_time:126971ms step_avg:88.73ms +step:1432/1670 train_time:127062ms step_avg:88.73ms +step:1433/1670 train_time:127152ms step_avg:88.73ms +step:1434/1670 train_time:127241ms step_avg:88.73ms +step:1435/1670 train_time:127332ms step_avg:88.73ms +step:1436/1670 train_time:127421ms step_avg:88.73ms +step:1437/1670 train_time:127511ms step_avg:88.73ms +step:1438/1670 train_time:127601ms step_avg:88.73ms +step:1439/1670 train_time:127690ms step_avg:88.74ms +step:1440/1670 train_time:127780ms step_avg:88.74ms +step:1441/1670 train_time:127870ms step_avg:88.74ms +step:1442/1670 train_time:127960ms step_avg:88.74ms +step:1443/1670 train_time:128050ms step_avg:88.74ms +step:1444/1670 train_time:128140ms step_avg:88.74ms +step:1445/1670 train_time:128230ms step_avg:88.74ms +step:1446/1670 train_time:128320ms step_avg:88.74ms +step:1447/1670 train_time:128410ms step_avg:88.74ms +step:1448/1670 train_time:128499ms step_avg:88.74ms +step:1449/1670 train_time:128589ms step_avg:88.74ms +step:1450/1670 train_time:128678ms step_avg:88.74ms +step:1451/1670 train_time:128768ms step_avg:88.74ms +step:1452/1670 train_time:128858ms step_avg:88.74ms +step:1453/1670 train_time:128946ms step_avg:88.74ms +step:1454/1670 train_time:129037ms step_avg:88.75ms +step:1455/1670 train_time:129127ms step_avg:88.75ms +step:1456/1670 train_time:129218ms step_avg:88.75ms +step:1457/1670 train_time:129308ms step_avg:88.75ms +step:1458/1670 train_time:129398ms 
step_avg:88.75ms +step:1459/1670 train_time:129487ms step_avg:88.75ms +step:1460/1670 train_time:129578ms step_avg:88.75ms +step:1461/1670 train_time:129667ms step_avg:88.75ms +step:1462/1670 train_time:129757ms step_avg:88.75ms +step:1463/1670 train_time:129846ms step_avg:88.75ms +step:1464/1670 train_time:129936ms step_avg:88.75ms +step:1465/1670 train_time:130026ms step_avg:88.75ms +step:1466/1670 train_time:130115ms step_avg:88.76ms +step:1467/1670 train_time:130204ms step_avg:88.76ms +step:1468/1670 train_time:130295ms step_avg:88.76ms +step:1469/1670 train_time:130384ms step_avg:88.76ms +step:1470/1670 train_time:130474ms step_avg:88.76ms +step:1471/1670 train_time:130564ms step_avg:88.76ms +step:1472/1670 train_time:130653ms step_avg:88.76ms +step:1473/1670 train_time:130743ms step_avg:88.76ms +step:1474/1670 train_time:130833ms step_avg:88.76ms +step:1475/1670 train_time:130923ms step_avg:88.76ms +step:1476/1670 train_time:131013ms step_avg:88.76ms +step:1477/1670 train_time:131103ms step_avg:88.76ms +step:1478/1670 train_time:131193ms step_avg:88.76ms +step:1479/1670 train_time:131283ms step_avg:88.76ms +step:1480/1670 train_time:131373ms step_avg:88.77ms +step:1481/1670 train_time:131462ms step_avg:88.77ms +step:1482/1670 train_time:131553ms step_avg:88.77ms +step:1483/1670 train_time:131642ms step_avg:88.77ms +step:1484/1670 train_time:131732ms step_avg:88.77ms +step:1485/1670 train_time:131822ms step_avg:88.77ms +step:1486/1670 train_time:131912ms step_avg:88.77ms +step:1487/1670 train_time:132002ms step_avg:88.77ms +step:1488/1670 train_time:132093ms step_avg:88.77ms +step:1489/1670 train_time:132183ms step_avg:88.77ms +step:1490/1670 train_time:132272ms step_avg:88.77ms +step:1491/1670 train_time:132362ms step_avg:88.77ms +step:1492/1670 train_time:132452ms step_avg:88.77ms +step:1493/1670 train_time:132542ms step_avg:88.78ms +step:1494/1670 train_time:132632ms step_avg:88.78ms +step:1495/1670 train_time:132723ms step_avg:88.78ms +step:1496/1670 train_time:132813ms step_avg:88.78ms +step:1497/1670 train_time:132902ms step_avg:88.78ms +step:1498/1670 train_time:132992ms step_avg:88.78ms +step:1499/1670 train_time:133082ms step_avg:88.78ms +step:1500/1670 train_time:133170ms step_avg:88.78ms +step:1500/1670 val_loss:3.3122 train_time:133262ms step_avg:88.84ms +step:1501/1670 train_time:133281ms step_avg:88.79ms +step:1502/1670 train_time:133355ms step_avg:88.79ms +step:1503/1670 train_time:133449ms step_avg:88.79ms +step:1504/1670 train_time:133539ms step_avg:88.79ms +step:1505/1670 train_time:133628ms step_avg:88.79ms +step:1506/1670 train_time:133717ms step_avg:88.79ms +step:1507/1670 train_time:133805ms step_avg:88.79ms +step:1508/1670 train_time:133893ms step_avg:88.79ms +step:1509/1670 train_time:133982ms step_avg:88.79ms +step:1510/1670 train_time:134070ms step_avg:88.79ms +step:1511/1670 train_time:134161ms step_avg:88.79ms +step:1512/1670 train_time:134253ms step_avg:88.79ms +step:1513/1670 train_time:134346ms step_avg:88.79ms +step:1514/1670 train_time:134437ms step_avg:88.80ms +step:1515/1670 train_time:134529ms step_avg:88.80ms +step:1516/1670 train_time:134618ms step_avg:88.80ms +step:1517/1670 train_time:134708ms step_avg:88.80ms +step:1518/1670 train_time:134796ms step_avg:88.80ms +step:1519/1670 train_time:134885ms step_avg:88.80ms +step:1520/1670 train_time:134973ms step_avg:88.80ms +step:1521/1670 train_time:135062ms step_avg:88.80ms +step:1522/1670 train_time:135153ms step_avg:88.80ms +step:1523/1670 train_time:135242ms step_avg:88.80ms +step:1524/1670 
train_time:135333ms step_avg:88.80ms +step:1525/1670 train_time:135425ms step_avg:88.80ms +step:1526/1670 train_time:135514ms step_avg:88.80ms +step:1527/1670 train_time:135604ms step_avg:88.80ms +step:1528/1670 train_time:135693ms step_avg:88.80ms +step:1529/1670 train_time:135782ms step_avg:88.80ms +step:1530/1670 train_time:135871ms step_avg:88.80ms +step:1531/1670 train_time:135961ms step_avg:88.81ms +step:1532/1670 train_time:136051ms step_avg:88.81ms +step:1533/1670 train_time:136141ms step_avg:88.81ms +step:1534/1670 train_time:136230ms step_avg:88.81ms +step:1535/1670 train_time:136321ms step_avg:88.81ms +step:1536/1670 train_time:136412ms step_avg:88.81ms +step:1537/1670 train_time:136503ms step_avg:88.81ms +step:1538/1670 train_time:136592ms step_avg:88.81ms +step:1539/1670 train_time:136682ms step_avg:88.81ms +step:1540/1670 train_time:136772ms step_avg:88.81ms +step:1541/1670 train_time:136861ms step_avg:88.81ms +step:1542/1670 train_time:136951ms step_avg:88.81ms +step:1543/1670 train_time:137041ms step_avg:88.81ms +step:1544/1670 train_time:137130ms step_avg:88.81ms +step:1545/1670 train_time:137220ms step_avg:88.82ms +step:1546/1670 train_time:137310ms step_avg:88.82ms +step:1547/1670 train_time:137401ms step_avg:88.82ms +step:1548/1670 train_time:137491ms step_avg:88.82ms +step:1549/1670 train_time:137581ms step_avg:88.82ms +step:1550/1670 train_time:137671ms step_avg:88.82ms +step:1551/1670 train_time:137761ms step_avg:88.82ms +step:1552/1670 train_time:137851ms step_avg:88.82ms +step:1553/1670 train_time:137940ms step_avg:88.82ms +step:1554/1670 train_time:138030ms step_avg:88.82ms +step:1555/1670 train_time:138120ms step_avg:88.82ms +step:1556/1670 train_time:138210ms step_avg:88.82ms +step:1557/1670 train_time:138301ms step_avg:88.83ms +step:1558/1670 train_time:138391ms step_avg:88.83ms +step:1559/1670 train_time:138482ms step_avg:88.83ms +step:1560/1670 train_time:138571ms step_avg:88.83ms +step:1561/1670 train_time:138661ms step_avg:88.83ms +step:1562/1670 train_time:138751ms step_avg:88.83ms +step:1563/1670 train_time:138841ms step_avg:88.83ms +step:1564/1670 train_time:138930ms step_avg:88.83ms +step:1565/1670 train_time:139019ms step_avg:88.83ms +step:1566/1670 train_time:139109ms step_avg:88.83ms +step:1567/1670 train_time:139198ms step_avg:88.83ms +step:1568/1670 train_time:139289ms step_avg:88.83ms +step:1569/1670 train_time:139380ms step_avg:88.83ms +step:1570/1670 train_time:139470ms step_avg:88.83ms +step:1571/1670 train_time:139561ms step_avg:88.84ms +step:1572/1670 train_time:139651ms step_avg:88.84ms +step:1573/1670 train_time:139741ms step_avg:88.84ms +step:1574/1670 train_time:139830ms step_avg:88.84ms +step:1575/1670 train_time:139921ms step_avg:88.84ms +step:1576/1670 train_time:140010ms step_avg:88.84ms +step:1577/1670 train_time:140099ms step_avg:88.84ms +step:1578/1670 train_time:140190ms step_avg:88.84ms +step:1579/1670 train_time:140279ms step_avg:88.84ms +step:1580/1670 train_time:140369ms step_avg:88.84ms +step:1581/1670 train_time:140459ms step_avg:88.84ms +step:1582/1670 train_time:140550ms step_avg:88.84ms +step:1583/1670 train_time:140639ms step_avg:88.84ms +step:1584/1670 train_time:140729ms step_avg:88.84ms +step:1585/1670 train_time:140819ms step_avg:88.84ms +step:1586/1670 train_time:140910ms step_avg:88.85ms +step:1587/1670 train_time:140999ms step_avg:88.85ms +step:1588/1670 train_time:141089ms step_avg:88.85ms +step:1589/1670 train_time:141179ms step_avg:88.85ms +step:1590/1670 train_time:141269ms step_avg:88.85ms +step:1591/1670 
train_time:141359ms step_avg:88.85ms +step:1592/1670 train_time:141449ms step_avg:88.85ms +step:1593/1670 train_time:141540ms step_avg:88.85ms +step:1594/1670 train_time:141630ms step_avg:88.85ms +step:1595/1670 train_time:141719ms step_avg:88.85ms +step:1596/1670 train_time:141809ms step_avg:88.85ms +step:1597/1670 train_time:141898ms step_avg:88.85ms +step:1598/1670 train_time:141988ms step_avg:88.85ms +step:1599/1670 train_time:142078ms step_avg:88.85ms +step:1600/1670 train_time:142169ms step_avg:88.86ms +step:1601/1670 train_time:142260ms step_avg:88.86ms +step:1602/1670 train_time:142350ms step_avg:88.86ms +step:1603/1670 train_time:142440ms step_avg:88.86ms +step:1604/1670 train_time:142530ms step_avg:88.86ms +step:1605/1670 train_time:142621ms step_avg:88.86ms +step:1606/1670 train_time:142710ms step_avg:88.86ms +step:1607/1670 train_time:142800ms step_avg:88.86ms +step:1608/1670 train_time:142890ms step_avg:88.86ms +step:1609/1670 train_time:142980ms step_avg:88.86ms +step:1610/1670 train_time:143071ms step_avg:88.86ms +step:1611/1670 train_time:143162ms step_avg:88.87ms +step:1612/1670 train_time:143252ms step_avg:88.87ms +step:1613/1670 train_time:143342ms step_avg:88.87ms +step:1614/1670 train_time:143431ms step_avg:88.87ms +step:1615/1670 train_time:143521ms step_avg:88.87ms +step:1616/1670 train_time:143611ms step_avg:88.87ms +step:1617/1670 train_time:143701ms step_avg:88.87ms +step:1618/1670 train_time:143792ms step_avg:88.87ms +step:1619/1670 train_time:143882ms step_avg:88.87ms +step:1620/1670 train_time:143971ms step_avg:88.87ms +step:1621/1670 train_time:144062ms step_avg:88.87ms +step:1622/1670 train_time:144152ms step_avg:88.87ms +step:1623/1670 train_time:144242ms step_avg:88.87ms +step:1624/1670 train_time:144332ms step_avg:88.87ms +step:1625/1670 train_time:144421ms step_avg:88.87ms +step:1625/1670 val_loss:3.2890 train_time:144512ms step_avg:88.93ms +step:1626/1670 train_time:144531ms step_avg:88.89ms +step:1627/1670 train_time:144604ms step_avg:88.88ms +step:1628/1670 train_time:144696ms step_avg:88.88ms +step:1629/1670 train_time:144787ms step_avg:88.88ms +step:1630/1670 train_time:144876ms step_avg:88.88ms +step:1631/1670 train_time:144964ms step_avg:88.88ms +step:1632/1670 train_time:145053ms step_avg:88.88ms +step:1633/1670 train_time:145142ms step_avg:88.88ms +step:1634/1670 train_time:145232ms step_avg:88.88ms +step:1635/1670 train_time:145320ms step_avg:88.88ms +step:1636/1670 train_time:145411ms step_avg:88.88ms +step:1637/1670 train_time:145501ms step_avg:88.88ms +step:1638/1670 train_time:145594ms step_avg:88.89ms +step:1639/1670 train_time:145687ms step_avg:88.89ms +step:1640/1670 train_time:145777ms step_avg:88.89ms +step:1641/1670 train_time:145867ms step_avg:88.89ms +step:1642/1670 train_time:145956ms step_avg:88.89ms +step:1643/1670 train_time:146045ms step_avg:88.89ms +step:1644/1670 train_time:146134ms step_avg:88.89ms +step:1645/1670 train_time:146223ms step_avg:88.89ms +step:1646/1670 train_time:146312ms step_avg:88.89ms +step:1647/1670 train_time:146402ms step_avg:88.89ms +step:1648/1670 train_time:146494ms step_avg:88.89ms +step:1649/1670 train_time:146587ms step_avg:88.89ms +step:1650/1670 train_time:146678ms step_avg:88.90ms +step:1651/1670 train_time:146768ms step_avg:88.90ms +step:1652/1670 train_time:146857ms step_avg:88.90ms +step:1653/1670 train_time:146947ms step_avg:88.90ms +step:1654/1670 train_time:147036ms step_avg:88.90ms +step:1655/1670 train_time:147125ms step_avg:88.90ms +step:1656/1670 train_time:147214ms step_avg:88.90ms 
+step:1657/1670 train_time:147304ms step_avg:88.90ms +step:1658/1670 train_time:147394ms step_avg:88.90ms +step:1659/1670 train_time:147485ms step_avg:88.90ms +step:1660/1670 train_time:147576ms step_avg:88.90ms +step:1661/1670 train_time:147666ms step_avg:88.90ms +step:1662/1670 train_time:147756ms step_avg:88.90ms +step:1663/1670 train_time:147846ms step_avg:88.90ms +step:1664/1670 train_time:147934ms step_avg:88.90ms +step:1665/1670 train_time:148024ms step_avg:88.90ms +step:1666/1670 train_time:148114ms step_avg:88.90ms +step:1667/1670 train_time:148204ms step_avg:88.90ms +step:1668/1670 train_time:148293ms step_avg:88.90ms +step:1669/1670 train_time:148384ms step_avg:88.91ms +step:1670/1670 train_time:148476ms step_avg:88.91ms +step:1670/1670 val_loss:3.2796 train_time:148569ms step_avg:88.96ms +peak memory allocated: 30760 MiB reserved: 45994 MiB diff --git a/records/092925_PolarExpress/188c5c21-a850-4b45-ab17-d168a5bec7e7.txt b/records/092925_PolarExpress/188c5c21-a850-4b45-ab17-d168a5bec7e7.txt new file mode 100644 index 000000000..1e1ca18b1 --- /dev/null +++ b/records/092925_PolarExpress/188c5c21-a850-4b45-ab17-d168a5bec7e7.txt @@ -0,0 +1,3252 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from kernels import get_kernel +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[ + Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = 
grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + + +mm_op.register_autograd(backward, setup_context=setup_context) + + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def XXT_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = 
A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def XXT(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + XXT_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ba_plus_cAA_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from XXT_kernel, but also loads and adds a block of A + # Performance is slightly slower than XXT_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), 
dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ba_plus_cAA(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ba_plus_cAA_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +# Computed for num_iters=5, safety_factor=2e-2, cushion=2 +coeffs_list = [ + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323) +] + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def polar_express(G: torch.Tensor): + """ + Polar Express Sign Method: https://arxiv.org/pdf/2505.16932 + by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. + Code adapted from https://github.com/NoahAmsel/PolarExpress/tree/main by @varunneal. 
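+
+ Each tuple (a, b, c) in coeffs_list above defines one iteration of the odd matrix polynomial
+ X = a*X + b*(X @ X.mT) @ X + c*(X @ X.mT) @ (X @ X.mT) @ X, evaluated below via the fused
+ XXT / ba_plus_cAA Triton kernels. Repeated application drives the singular values of X
+ toward 1, so X approaches the semi-orthogonal polar factor of G.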
+ """ + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) * (1 + 2e-2) + 1e-6) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + aX_plus_BX = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the iterations + for a, b, c in coeffs_list: + XXT(X, out=A) # A = X @ X.mT + ba_plus_cAA(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + aX_plus_BX(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + Though empirically small 1D params perform efficiently here: + NS approximately performs a magnitude normalization of the grad + This hyper-optimized class has faster execution time than the current impl of Adam for small params + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized to enable + (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param 7 padding params) + 2. reduce scatter attn_gate (10 params 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive params of type attn reshape before NS + 8. wait on 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size() == 8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on less than 8 GPU or experimenting with additional attn or mlp modules. 
+ Creates one param group per size, while giving attn its own param group for resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1, 10, 16, 16] + assert len(params) == sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx + size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. 
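+ # (Illustrative sizing, following the 8-GPU layout in the class docstring: a 16-param
+ # attn/mlp group gives chunk_size = 2, so each rank allocates a (2, *p_example.shape) buffer.)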
+ updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + original_shape = batched_update_grads.shape + # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS indepedently to Q,K,V,O + module_idx = start_idx if start_idx < len(params) else 0 + if getattr(params[module_idx], 'module', 'none') == 'attn': + for p in params[module_idx:module_idx + chunk_size]: + assert getattr(params[module_idx], 'module', 'none') == 'attn' + batch = 4 * original_shape[0] + d1 = original_shape[1] + d2 = original_shape[2] // 4 + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = polar_express(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = polar_express(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
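+ # stacked_params has shape (padded_num_params, *param.shape); any trailing padding rows
+ # are simply never read back, since the copy below enumerates orig_params only.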
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, + weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append( + dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim // 4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim // 4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +flash_attn_interface = get_kernel('varunneal/flash-attention-3').flash_attn_interface + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim * 4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module = 'attn' + with torch.no_grad(): + self.qkvo_w.view(4, 
self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4, self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4, self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, + 3 * self.num_heads, + self.head_dim).chunk( + 3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_interface.flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, + max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4, self.hdim, self.dim)[3].type_as(y)) + return y + + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module = 'mlp' + self.c_proj.module = 'mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu( + x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + 
self.mlp(norm(x)) + return x + + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, + max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim ** 0.5) / 448, w_s=2 ** -9, + grad_s=1 / 448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), # extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, + long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1, len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i < 11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + + +BOS_ID = 50256 + + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens = tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() 
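+ # bos_idx now holds the positions of every BOS token found so far (only the first 4M tokens
+ # when quickload is used, until the async scan finishes); next_batch() walks this index to
+ # build document-aligned sequences.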
+ self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter == 5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter += 1 + return starts, ends + + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, + align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
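+ # preloader.get() blocks until the background thread has finished reading and BOS-indexing
+ # the next shard; preloader.start() then immediately begins prefetching the shard after that.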
+ tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1630 # number of iterations to run + iteration_extension = 40 # number of iterations to continue training at final cooldown and window size + cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. 
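+
+# Worked example of the resulting sharding (illustrative arithmetic from the Hyperparameters above):
+# on the full 8-GPU run, world_size=8 gives grad_accum_steps=1 and each rank processes
+# train_batch_size/8 = 2048*24 = 49152 tokens per step; with world_size=1, grad_accum_steps=8
+# micro-batches of 49152 tokens are accumulated per step instead.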
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) + + +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + + +# begin by printing this file (the Python code) +print0(code) +print0("=" * 100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout + + +print0(nvidia_smi()) +print0("=" * 100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if + p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.06, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999, step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + + +def get_ws(step: int): + if step == args.num_iterations + args.iteration_extension: + return args.ws_validate // 2, args.ws_validate + x = min(step / (1 + args.num_iterations), 0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx] + + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, + grad_accum_steps=grad_accum_steps) +ws_long = args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws_long = 
args.ws_schedule[ + step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params + if new_ws_long > ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + elif new_ws_long < ws_long: + model.yarn.reset() + ws_long = new_ws_long + model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.yarn.reset() +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, + grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations + args.iteration_extension +ws_short, ws_long = get_ws(0) +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws_short, new_ws_long = get_ws(step) + if new_ws_long != ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + if last_step: + ws_long = args.ws_long_validate + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, + grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = torch.zeros((), device=device, dtype=torch.float32) + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0( + f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms", + console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), + optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * 
(time.perf_counter() - t0) + print0( + f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms", + console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, Feb 4 2025, 14:57:36) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.5.0 +Mon Sep 29 06:31:01 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.127.08 Driver Version: 550.127.08 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 40C P0 129W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 33C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 38C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 39C P0 128W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 33C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 38C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 
train_time:0ms step_avg:0.02ms +step:1/1670 train_time:153ms step_avg:152.58ms +step:2/1670 train_time:175ms step_avg:87.36ms +step:3/1670 train_time:237ms step_avg:78.90ms +step:4/1670 train_time:322ms step_avg:80.62ms +step:5/1670 train_time:409ms step_avg:81.70ms +step:6/1670 train_time:495ms step_avg:82.52ms +step:7/1670 train_time:582ms step_avg:83.15ms +step:8/1670 train_time:669ms step_avg:83.63ms +step:9/1670 train_time:756ms step_avg:83.98ms +step:10/1670 train_time:843ms step_avg:84.28ms +step:11/1670 train_time:929ms step_avg:84.50ms +step:12/1670 train_time:1019ms step_avg:84.89ms +step:13/1670 train_time:1110ms step_avg:85.37ms +step:14/1670 train_time:1201ms step_avg:85.75ms +step:15/1670 train_time:1289ms step_avg:85.92ms +step:16/1670 train_time:1377ms step_avg:86.05ms +step:17/1670 train_time:1464ms step_avg:86.14ms +step:18/1670 train_time:1551ms step_avg:86.18ms +step:19/1670 train_time:1638ms step_avg:86.24ms +step:20/1670 train_time:1725ms step_avg:86.26ms +step:21/1670 train_time:1813ms step_avg:86.32ms +step:22/1670 train_time:1900ms step_avg:86.38ms +step:23/1670 train_time:1988ms step_avg:86.43ms +step:24/1670 train_time:2077ms step_avg:86.56ms +step:25/1670 train_time:2167ms step_avg:86.67ms +step:26/1670 train_time:2256ms step_avg:86.77ms +step:27/1670 train_time:2344ms step_avg:86.83ms +step:28/1670 train_time:2433ms step_avg:86.90ms +step:29/1670 train_time:2521ms step_avg:86.92ms +step:30/1670 train_time:2608ms step_avg:86.94ms +step:31/1670 train_time:2695ms step_avg:86.93ms +step:32/1670 train_time:2782ms step_avg:86.94ms +step:33/1670 train_time:2869ms step_avg:86.95ms +step:34/1670 train_time:2957ms step_avg:86.97ms +step:35/1670 train_time:3045ms step_avg:87.01ms +step:36/1670 train_time:3133ms step_avg:87.04ms +step:37/1670 train_time:3223ms step_avg:87.10ms +step:38/1670 train_time:3311ms step_avg:87.12ms +step:39/1670 train_time:3400ms step_avg:87.19ms +step:40/1670 train_time:3488ms step_avg:87.20ms +step:41/1670 train_time:3577ms step_avg:87.24ms +step:42/1670 train_time:3664ms step_avg:87.24ms +step:43/1670 train_time:3752ms step_avg:87.25ms +step:44/1670 train_time:3839ms step_avg:87.26ms +step:45/1670 train_time:3926ms step_avg:87.26ms +step:46/1670 train_time:4014ms step_avg:87.27ms +step:47/1670 train_time:4103ms step_avg:87.30ms +step:48/1670 train_time:4192ms step_avg:87.33ms +step:49/1670 train_time:4280ms step_avg:87.35ms +step:50/1670 train_time:4367ms step_avg:87.35ms +step:51/1670 train_time:4455ms step_avg:87.36ms +step:52/1670 train_time:4544ms step_avg:87.38ms +step:53/1670 train_time:4632ms step_avg:87.39ms +step:54/1670 train_time:4719ms step_avg:87.40ms +step:55/1670 train_time:4807ms step_avg:87.39ms +step:56/1670 train_time:4894ms step_avg:87.40ms +step:57/1670 train_time:4983ms step_avg:87.41ms +step:58/1670 train_time:5070ms step_avg:87.42ms +step:59/1670 train_time:5158ms step_avg:87.42ms +step:60/1670 train_time:5246ms step_avg:87.43ms +step:61/1670 train_time:5333ms step_avg:87.43ms +step:62/1670 train_time:5421ms step_avg:87.44ms +step:63/1670 train_time:5509ms step_avg:87.45ms +step:64/1670 train_time:5598ms step_avg:87.46ms +step:65/1670 train_time:5685ms step_avg:87.47ms +step:66/1670 train_time:5773ms step_avg:87.47ms +step:67/1670 train_time:5861ms step_avg:87.47ms +step:68/1670 train_time:5948ms step_avg:87.47ms +step:69/1670 train_time:6035ms step_avg:87.47ms +step:70/1670 train_time:6123ms step_avg:87.47ms +step:71/1670 train_time:6211ms step_avg:87.48ms +step:72/1670 train_time:6299ms step_avg:87.48ms +step:73/1670 
train_time:6387ms step_avg:87.49ms +step:74/1670 train_time:6475ms step_avg:87.50ms +step:75/1670 train_time:6563ms step_avg:87.51ms +step:76/1670 train_time:6651ms step_avg:87.52ms +step:77/1670 train_time:6740ms step_avg:87.53ms +step:78/1670 train_time:6827ms step_avg:87.53ms +step:79/1670 train_time:6916ms step_avg:87.54ms +step:80/1670 train_time:7004ms step_avg:87.55ms +step:81/1670 train_time:7092ms step_avg:87.55ms +step:82/1670 train_time:7180ms step_avg:87.56ms +step:83/1670 train_time:7267ms step_avg:87.56ms +step:84/1670 train_time:7355ms step_avg:87.56ms +step:85/1670 train_time:7443ms step_avg:87.57ms +step:86/1670 train_time:7531ms step_avg:87.57ms +step:87/1670 train_time:7619ms step_avg:87.57ms +step:88/1670 train_time:7707ms step_avg:87.58ms +step:89/1670 train_time:7795ms step_avg:87.58ms +step:90/1670 train_time:7883ms step_avg:87.59ms +step:91/1670 train_time:7970ms step_avg:87.59ms +step:92/1670 train_time:8059ms step_avg:87.59ms +step:93/1670 train_time:8146ms step_avg:87.59ms +step:94/1670 train_time:8235ms step_avg:87.60ms +step:95/1670 train_time:8323ms step_avg:87.61ms +step:96/1670 train_time:8411ms step_avg:87.61ms +step:97/1670 train_time:8499ms step_avg:87.62ms +step:98/1670 train_time:8587ms step_avg:87.62ms +step:99/1670 train_time:8675ms step_avg:87.62ms +step:100/1670 train_time:8762ms step_avg:87.62ms +step:101/1670 train_time:8850ms step_avg:87.62ms +step:102/1670 train_time:8938ms step_avg:87.63ms +step:103/1670 train_time:9026ms step_avg:87.63ms +step:104/1670 train_time:9114ms step_avg:87.63ms +step:105/1670 train_time:9202ms step_avg:87.63ms +step:106/1670 train_time:9289ms step_avg:87.63ms +step:107/1670 train_time:9377ms step_avg:87.64ms +step:108/1670 train_time:9464ms step_avg:87.63ms +step:109/1670 train_time:9552ms step_avg:87.64ms +step:110/1670 train_time:9641ms step_avg:87.64ms +step:111/1670 train_time:9728ms step_avg:87.64ms +step:112/1670 train_time:9815ms step_avg:87.64ms +step:113/1670 train_time:9903ms step_avg:87.64ms +step:114/1670 train_time:9991ms step_avg:87.64ms +step:115/1670 train_time:10079ms step_avg:87.64ms +step:116/1670 train_time:10167ms step_avg:87.64ms +step:117/1670 train_time:10254ms step_avg:87.64ms +step:118/1670 train_time:10342ms step_avg:87.64ms +step:119/1670 train_time:10429ms step_avg:87.64ms +step:120/1670 train_time:10517ms step_avg:87.64ms +step:121/1670 train_time:10605ms step_avg:87.64ms +step:122/1670 train_time:10692ms step_avg:87.64ms +step:123/1670 train_time:10780ms step_avg:87.64ms +step:124/1670 train_time:10867ms step_avg:87.64ms +step:125/1670 train_time:10955ms step_avg:87.64ms +step:125/1670 val_loss:4.3461 train_time:11046ms step_avg:88.37ms +step:126/1670 train_time:11066ms step_avg:87.83ms +step:127/1670 train_time:11134ms step_avg:87.67ms +step:128/1670 train_time:11230ms step_avg:87.73ms +step:129/1670 train_time:11322ms step_avg:87.77ms +step:130/1670 train_time:11410ms step_avg:87.77ms +step:131/1670 train_time:11497ms step_avg:87.76ms +step:132/1670 train_time:11584ms step_avg:87.76ms +step:133/1670 train_time:11671ms step_avg:87.75ms +step:134/1670 train_time:11757ms step_avg:87.74ms +step:135/1670 train_time:11844ms step_avg:87.73ms +step:136/1670 train_time:11930ms step_avg:87.72ms +step:137/1670 train_time:12017ms step_avg:87.72ms +step:138/1670 train_time:12105ms step_avg:87.72ms +step:139/1670 train_time:12195ms step_avg:87.73ms +step:140/1670 train_time:12285ms step_avg:87.75ms +step:141/1670 train_time:12374ms step_avg:87.76ms +step:142/1670 train_time:12461ms step_avg:87.75ms 
+step:143/1670 train_time:12549ms step_avg:87.76ms +step:144/1670 train_time:12636ms step_avg:87.75ms +step:145/1670 train_time:12723ms step_avg:87.74ms +step:146/1670 train_time:12810ms step_avg:87.74ms +step:147/1670 train_time:12897ms step_avg:87.73ms +step:148/1670 train_time:12984ms step_avg:87.73ms +step:149/1670 train_time:13071ms step_avg:87.72ms +step:150/1670 train_time:13159ms step_avg:87.73ms +step:151/1670 train_time:13247ms step_avg:87.73ms +step:152/1670 train_time:13336ms step_avg:87.73ms +step:153/1670 train_time:13424ms step_avg:87.74ms +step:154/1670 train_time:13510ms step_avg:87.73ms +step:155/1670 train_time:13598ms step_avg:87.73ms +step:156/1670 train_time:13685ms step_avg:87.72ms +step:157/1670 train_time:13772ms step_avg:87.72ms +step:158/1670 train_time:13859ms step_avg:87.72ms +step:159/1670 train_time:13946ms step_avg:87.71ms +step:160/1670 train_time:14034ms step_avg:87.71ms +step:161/1670 train_time:14122ms step_avg:87.71ms +step:162/1670 train_time:14209ms step_avg:87.71ms +step:163/1670 train_time:14299ms step_avg:87.72ms +step:164/1670 train_time:14386ms step_avg:87.72ms +step:165/1670 train_time:14474ms step_avg:87.72ms +step:166/1670 train_time:14562ms step_avg:87.72ms +step:167/1670 train_time:14649ms step_avg:87.72ms +step:168/1670 train_time:14736ms step_avg:87.72ms +step:169/1670 train_time:14823ms step_avg:87.71ms +step:170/1670 train_time:14910ms step_avg:87.70ms +step:171/1670 train_time:14998ms step_avg:87.71ms +step:172/1670 train_time:15085ms step_avg:87.70ms +step:173/1670 train_time:15174ms step_avg:87.71ms +step:174/1670 train_time:15262ms step_avg:87.71ms +step:175/1670 train_time:15350ms step_avg:87.71ms +step:176/1670 train_time:15439ms step_avg:87.72ms +step:177/1670 train_time:15526ms step_avg:87.72ms +step:178/1670 train_time:15614ms step_avg:87.72ms +step:179/1670 train_time:15701ms step_avg:87.72ms +step:180/1670 train_time:15789ms step_avg:87.72ms +step:181/1670 train_time:15877ms step_avg:87.72ms +step:182/1670 train_time:15964ms step_avg:87.71ms +step:183/1670 train_time:16051ms step_avg:87.71ms +step:184/1670 train_time:16139ms step_avg:87.71ms +step:185/1670 train_time:16226ms step_avg:87.71ms +step:186/1670 train_time:16313ms step_avg:87.71ms +step:187/1670 train_time:16402ms step_avg:87.71ms +step:188/1670 train_time:16490ms step_avg:87.71ms +step:189/1670 train_time:16578ms step_avg:87.71ms +step:190/1670 train_time:16665ms step_avg:87.71ms +step:191/1670 train_time:16753ms step_avg:87.71ms +step:192/1670 train_time:16840ms step_avg:87.71ms +step:193/1670 train_time:16928ms step_avg:87.71ms +step:194/1670 train_time:17015ms step_avg:87.71ms +step:195/1670 train_time:17102ms step_avg:87.70ms +step:196/1670 train_time:17189ms step_avg:87.70ms +step:197/1670 train_time:17277ms step_avg:87.70ms +step:198/1670 train_time:17364ms step_avg:87.70ms +step:199/1670 train_time:17452ms step_avg:87.70ms +step:200/1670 train_time:17541ms step_avg:87.70ms +step:201/1670 train_time:17628ms step_avg:87.70ms +step:202/1670 train_time:17716ms step_avg:87.70ms +step:203/1670 train_time:17803ms step_avg:87.70ms +step:204/1670 train_time:17891ms step_avg:87.70ms +step:205/1670 train_time:17979ms step_avg:87.70ms +step:206/1670 train_time:18066ms step_avg:87.70ms +step:207/1670 train_time:18154ms step_avg:87.70ms +step:208/1670 train_time:18241ms step_avg:87.70ms +step:209/1670 train_time:18329ms step_avg:87.70ms +step:210/1670 train_time:18416ms step_avg:87.70ms +step:211/1670 train_time:18504ms step_avg:87.70ms +step:212/1670 train_time:18591ms 
step_avg:87.69ms +step:213/1670 train_time:18679ms step_avg:87.69ms +step:214/1670 train_time:18766ms step_avg:87.69ms +step:215/1670 train_time:18853ms step_avg:87.69ms +step:216/1670 train_time:18940ms step_avg:87.69ms +step:217/1670 train_time:19028ms step_avg:87.69ms +step:218/1670 train_time:19115ms step_avg:87.69ms +step:219/1670 train_time:19203ms step_avg:87.68ms +step:220/1670 train_time:19290ms step_avg:87.68ms +step:221/1670 train_time:19379ms step_avg:87.69ms +step:222/1670 train_time:19465ms step_avg:87.68ms +step:223/1670 train_time:19553ms step_avg:87.68ms +step:224/1670 train_time:19642ms step_avg:87.69ms +step:225/1670 train_time:19729ms step_avg:87.69ms +step:226/1670 train_time:19817ms step_avg:87.68ms +step:227/1670 train_time:19904ms step_avg:87.68ms +step:228/1670 train_time:19991ms step_avg:87.68ms +step:229/1670 train_time:20079ms step_avg:87.68ms +step:230/1670 train_time:20166ms step_avg:87.68ms +step:231/1670 train_time:20254ms step_avg:87.68ms +step:232/1670 train_time:20342ms step_avg:87.68ms +step:233/1670 train_time:20429ms step_avg:87.68ms +step:234/1670 train_time:20517ms step_avg:87.68ms +step:235/1670 train_time:20604ms step_avg:87.68ms +step:236/1670 train_time:20692ms step_avg:87.68ms +step:237/1670 train_time:20779ms step_avg:87.68ms +step:238/1670 train_time:20866ms step_avg:87.67ms +step:239/1670 train_time:20953ms step_avg:87.67ms +step:240/1670 train_time:21041ms step_avg:87.67ms +step:241/1670 train_time:21129ms step_avg:87.67ms +step:242/1670 train_time:21217ms step_avg:87.67ms +step:243/1670 train_time:21304ms step_avg:87.67ms +step:244/1670 train_time:21391ms step_avg:87.67ms +step:245/1670 train_time:21479ms step_avg:87.67ms +step:246/1670 train_time:21566ms step_avg:87.67ms +step:247/1670 train_time:21654ms step_avg:87.67ms +step:248/1670 train_time:21742ms step_avg:87.67ms +step:249/1670 train_time:21829ms step_avg:87.67ms +step:250/1670 train_time:21916ms step_avg:87.66ms +step:250/1670 val_loss:3.9850 train_time:22005ms step_avg:88.02ms +step:251/1670 train_time:22026ms step_avg:87.75ms +step:252/1670 train_time:22097ms step_avg:87.69ms +step:253/1670 train_time:22193ms step_avg:87.72ms +step:254/1670 train_time:22281ms step_avg:87.72ms +step:255/1670 train_time:22368ms step_avg:87.72ms +step:256/1670 train_time:22454ms step_avg:87.71ms +step:257/1670 train_time:22541ms step_avg:87.71ms +step:258/1670 train_time:22628ms step_avg:87.70ms +step:259/1670 train_time:22714ms step_avg:87.70ms +step:260/1670 train_time:22801ms step_avg:87.70ms +step:261/1670 train_time:22888ms step_avg:87.69ms +step:262/1670 train_time:22976ms step_avg:87.69ms +step:263/1670 train_time:23065ms step_avg:87.70ms +step:264/1670 train_time:23155ms step_avg:87.71ms +step:265/1670 train_time:23245ms step_avg:87.72ms +step:266/1670 train_time:23333ms step_avg:87.72ms +step:267/1670 train_time:23420ms step_avg:87.71ms +step:268/1670 train_time:23506ms step_avg:87.71ms +step:269/1670 train_time:23593ms step_avg:87.71ms +step:270/1670 train_time:23679ms step_avg:87.70ms +step:271/1670 train_time:23766ms step_avg:87.70ms +step:272/1670 train_time:23853ms step_avg:87.69ms +step:273/1670 train_time:23940ms step_avg:87.69ms +step:274/1670 train_time:24030ms step_avg:87.70ms +step:275/1670 train_time:24117ms step_avg:87.70ms +step:276/1670 train_time:24207ms step_avg:87.71ms +step:277/1670 train_time:24295ms step_avg:87.71ms +step:278/1670 train_time:24383ms step_avg:87.71ms +step:279/1670 train_time:24471ms step_avg:87.71ms +step:280/1670 train_time:24558ms step_avg:87.71ms 
+step:281/1670 train_time:24645ms step_avg:87.70ms +step:282/1670 train_time:24731ms step_avg:87.70ms +step:283/1670 train_time:24818ms step_avg:87.70ms +step:284/1670 train_time:24905ms step_avg:87.69ms +step:285/1670 train_time:24993ms step_avg:87.69ms +step:286/1670 train_time:25081ms step_avg:87.69ms +step:287/1670 train_time:25169ms step_avg:87.70ms +step:288/1670 train_time:25257ms step_avg:87.70ms +step:289/1670 train_time:25345ms step_avg:87.70ms +step:290/1670 train_time:25433ms step_avg:87.70ms +step:291/1670 train_time:25520ms step_avg:87.70ms +step:292/1670 train_time:25607ms step_avg:87.70ms +step:293/1670 train_time:25695ms step_avg:87.70ms +step:294/1670 train_time:25782ms step_avg:87.69ms +step:295/1670 train_time:25869ms step_avg:87.69ms +step:296/1670 train_time:25957ms step_avg:87.69ms +step:297/1670 train_time:26045ms step_avg:87.69ms +step:298/1670 train_time:26132ms step_avg:87.69ms +step:299/1670 train_time:26221ms step_avg:87.70ms +step:300/1670 train_time:26310ms step_avg:87.70ms +step:301/1670 train_time:26397ms step_avg:87.70ms +step:302/1670 train_time:26485ms step_avg:87.70ms +step:303/1670 train_time:26573ms step_avg:87.70ms +step:304/1670 train_time:26660ms step_avg:87.70ms +step:305/1670 train_time:26747ms step_avg:87.70ms +step:306/1670 train_time:26834ms step_avg:87.69ms +step:307/1670 train_time:26922ms step_avg:87.69ms +step:308/1670 train_time:27009ms step_avg:87.69ms +step:309/1670 train_time:27097ms step_avg:87.69ms +step:310/1670 train_time:27185ms step_avg:87.69ms +step:311/1670 train_time:27273ms step_avg:87.69ms +step:312/1670 train_time:27360ms step_avg:87.69ms +step:313/1670 train_time:27448ms step_avg:87.69ms +step:314/1670 train_time:27535ms step_avg:87.69ms +step:315/1670 train_time:27624ms step_avg:87.69ms +step:316/1670 train_time:27710ms step_avg:87.69ms +step:317/1670 train_time:27798ms step_avg:87.69ms +step:318/1670 train_time:27886ms step_avg:87.69ms +step:319/1670 train_time:27973ms step_avg:87.69ms +step:320/1670 train_time:28060ms step_avg:87.69ms +step:321/1670 train_time:28149ms step_avg:87.69ms +step:322/1670 train_time:28236ms step_avg:87.69ms +step:323/1670 train_time:28325ms step_avg:87.69ms +step:324/1670 train_time:28413ms step_avg:87.69ms +step:325/1670 train_time:28500ms step_avg:87.69ms +step:326/1670 train_time:28589ms step_avg:87.69ms +step:327/1670 train_time:28676ms step_avg:87.69ms +step:328/1670 train_time:28764ms step_avg:87.69ms +step:329/1670 train_time:28851ms step_avg:87.69ms +step:330/1670 train_time:28939ms step_avg:87.69ms +step:331/1670 train_time:29026ms step_avg:87.69ms +step:332/1670 train_time:29114ms step_avg:87.69ms +step:333/1670 train_time:29201ms step_avg:87.69ms +step:334/1670 train_time:29289ms step_avg:87.69ms +step:335/1670 train_time:29376ms step_avg:87.69ms +step:336/1670 train_time:29464ms step_avg:87.69ms +step:337/1670 train_time:29552ms step_avg:87.69ms +step:338/1670 train_time:29639ms step_avg:87.69ms +step:339/1670 train_time:29727ms step_avg:87.69ms +step:340/1670 train_time:29814ms step_avg:87.69ms +step:341/1670 train_time:29901ms step_avg:87.69ms +step:342/1670 train_time:29989ms step_avg:87.69ms +step:343/1670 train_time:30077ms step_avg:87.69ms +step:344/1670 train_time:30165ms step_avg:87.69ms +step:345/1670 train_time:30252ms step_avg:87.69ms +step:346/1670 train_time:30339ms step_avg:87.68ms +step:347/1670 train_time:30428ms step_avg:87.69ms +step:348/1670 train_time:30515ms step_avg:87.69ms +step:349/1670 train_time:30602ms step_avg:87.68ms +step:350/1670 train_time:30690ms 
step_avg:87.69ms +step:351/1670 train_time:30777ms step_avg:87.68ms +step:352/1670 train_time:30865ms step_avg:87.68ms +step:353/1670 train_time:30953ms step_avg:87.68ms +step:354/1670 train_time:31040ms step_avg:87.68ms +step:355/1670 train_time:31128ms step_avg:87.68ms +step:356/1670 train_time:31216ms step_avg:87.68ms +step:357/1670 train_time:31303ms step_avg:87.68ms +step:358/1670 train_time:31392ms step_avg:87.69ms +step:359/1670 train_time:31480ms step_avg:87.69ms +step:360/1670 train_time:31568ms step_avg:87.69ms +step:361/1670 train_time:31655ms step_avg:87.69ms +step:362/1670 train_time:31742ms step_avg:87.69ms +step:363/1670 train_time:31831ms step_avg:87.69ms +step:364/1670 train_time:31918ms step_avg:87.69ms +step:365/1670 train_time:32006ms step_avg:87.69ms +step:366/1670 train_time:32093ms step_avg:87.69ms +step:367/1670 train_time:32181ms step_avg:87.69ms +step:368/1670 train_time:32269ms step_avg:87.69ms +step:369/1670 train_time:32357ms step_avg:87.69ms +step:370/1670 train_time:32444ms step_avg:87.69ms +step:371/1670 train_time:32532ms step_avg:87.69ms +step:372/1670 train_time:32620ms step_avg:87.69ms +step:373/1670 train_time:32708ms step_avg:87.69ms +step:374/1670 train_time:32796ms step_avg:87.69ms +step:375/1670 train_time:32885ms step_avg:87.69ms +step:375/1670 val_loss:3.8254 train_time:32974ms step_avg:87.93ms +step:376/1670 train_time:32993ms step_avg:87.75ms +step:377/1670 train_time:33065ms step_avg:87.71ms +step:378/1670 train_time:33158ms step_avg:87.72ms +step:379/1670 train_time:33246ms step_avg:87.72ms +step:380/1670 train_time:33334ms step_avg:87.72ms +step:381/1670 train_time:33420ms step_avg:87.72ms +step:382/1670 train_time:33506ms step_avg:87.71ms +step:383/1670 train_time:33593ms step_avg:87.71ms +step:384/1670 train_time:33679ms step_avg:87.71ms +step:385/1670 train_time:33766ms step_avg:87.70ms +step:386/1670 train_time:33852ms step_avg:87.70ms +step:387/1670 train_time:33940ms step_avg:87.70ms +step:388/1670 train_time:34030ms step_avg:87.71ms +step:389/1670 train_time:34119ms step_avg:87.71ms +step:390/1670 train_time:34208ms step_avg:87.71ms +step:391/1670 train_time:34297ms step_avg:87.72ms +step:392/1670 train_time:34384ms step_avg:87.71ms +step:393/1670 train_time:34472ms step_avg:87.71ms +step:394/1670 train_time:34558ms step_avg:87.71ms +step:395/1670 train_time:34645ms step_avg:87.71ms +step:396/1670 train_time:34733ms step_avg:87.71ms +step:397/1670 train_time:34819ms step_avg:87.71ms +step:398/1670 train_time:34907ms step_avg:87.71ms +step:399/1670 train_time:34995ms step_avg:87.71ms +step:400/1670 train_time:35084ms step_avg:87.71ms +step:401/1670 train_time:35174ms step_avg:87.72ms +step:402/1670 train_time:35262ms step_avg:87.72ms +step:403/1670 train_time:35350ms step_avg:87.72ms +step:404/1670 train_time:35438ms step_avg:87.72ms +step:405/1670 train_time:35525ms step_avg:87.72ms +step:406/1670 train_time:35612ms step_avg:87.71ms +step:407/1670 train_time:35699ms step_avg:87.71ms +step:408/1670 train_time:35786ms step_avg:87.71ms +step:409/1670 train_time:35874ms step_avg:87.71ms +step:410/1670 train_time:35962ms step_avg:87.71ms +step:411/1670 train_time:36051ms step_avg:87.72ms +step:412/1670 train_time:36139ms step_avg:87.72ms +step:413/1670 train_time:36227ms step_avg:87.72ms +step:414/1670 train_time:36315ms step_avg:87.72ms +step:415/1670 train_time:36403ms step_avg:87.72ms +step:416/1670 train_time:36490ms step_avg:87.72ms +step:417/1670 train_time:36577ms step_avg:87.72ms +step:418/1670 train_time:36664ms step_avg:87.71ms 
+step:419/1670 train_time:36752ms step_avg:87.71ms +step:420/1670 train_time:36839ms step_avg:87.71ms +step:421/1670 train_time:36927ms step_avg:87.71ms +step:422/1670 train_time:37015ms step_avg:87.71ms +step:423/1670 train_time:37103ms step_avg:87.71ms +step:424/1670 train_time:37192ms step_avg:87.72ms +step:425/1670 train_time:37280ms step_avg:87.72ms +step:426/1670 train_time:37367ms step_avg:87.72ms +step:427/1670 train_time:37455ms step_avg:87.72ms +step:428/1670 train_time:37542ms step_avg:87.72ms +step:429/1670 train_time:37630ms step_avg:87.72ms +step:430/1670 train_time:37717ms step_avg:87.71ms +step:431/1670 train_time:37806ms step_avg:87.72ms +step:432/1670 train_time:37893ms step_avg:87.72ms +step:433/1670 train_time:37981ms step_avg:87.71ms +step:434/1670 train_time:38068ms step_avg:87.72ms +step:435/1670 train_time:38156ms step_avg:87.72ms +step:436/1670 train_time:38244ms step_avg:87.72ms +step:437/1670 train_time:38334ms step_avg:87.72ms +step:438/1670 train_time:38422ms step_avg:87.72ms +step:439/1670 train_time:38510ms step_avg:87.72ms +step:440/1670 train_time:38597ms step_avg:87.72ms +step:441/1670 train_time:38684ms step_avg:87.72ms +step:442/1670 train_time:38773ms step_avg:87.72ms +step:443/1670 train_time:38860ms step_avg:87.72ms +step:444/1670 train_time:38947ms step_avg:87.72ms +step:445/1670 train_time:39036ms step_avg:87.72ms +step:446/1670 train_time:39124ms step_avg:87.72ms +step:447/1670 train_time:39212ms step_avg:87.72ms +step:448/1670 train_time:39299ms step_avg:87.72ms +step:449/1670 train_time:39387ms step_avg:87.72ms +step:450/1670 train_time:39475ms step_avg:87.72ms +step:451/1670 train_time:39562ms step_avg:87.72ms +step:452/1670 train_time:39650ms step_avg:87.72ms +step:453/1670 train_time:39738ms step_avg:87.72ms +step:454/1670 train_time:39826ms step_avg:87.72ms +step:455/1670 train_time:39913ms step_avg:87.72ms +step:456/1670 train_time:40001ms step_avg:87.72ms +step:457/1670 train_time:40088ms step_avg:87.72ms +step:458/1670 train_time:40176ms step_avg:87.72ms +step:459/1670 train_time:40266ms step_avg:87.72ms +step:460/1670 train_time:40354ms step_avg:87.73ms +step:461/1670 train_time:40441ms step_avg:87.73ms +step:462/1670 train_time:40529ms step_avg:87.73ms +step:463/1670 train_time:40617ms step_avg:87.72ms +step:464/1670 train_time:40705ms step_avg:87.73ms +step:465/1670 train_time:40793ms step_avg:87.73ms +step:466/1670 train_time:40880ms step_avg:87.73ms +step:467/1670 train_time:40969ms step_avg:87.73ms +step:468/1670 train_time:41056ms step_avg:87.73ms +step:469/1670 train_time:41144ms step_avg:87.73ms +step:470/1670 train_time:41233ms step_avg:87.73ms +step:471/1670 train_time:41320ms step_avg:87.73ms +step:472/1670 train_time:41408ms step_avg:87.73ms +step:473/1670 train_time:41496ms step_avg:87.73ms +step:474/1670 train_time:41583ms step_avg:87.73ms +step:475/1670 train_time:41672ms step_avg:87.73ms +step:476/1670 train_time:41759ms step_avg:87.73ms +step:477/1670 train_time:41848ms step_avg:87.73ms +step:478/1670 train_time:41935ms step_avg:87.73ms +step:479/1670 train_time:42023ms step_avg:87.73ms +step:480/1670 train_time:42111ms step_avg:87.73ms +step:481/1670 train_time:42199ms step_avg:87.73ms +step:482/1670 train_time:42287ms step_avg:87.73ms +step:483/1670 train_time:42375ms step_avg:87.73ms +step:484/1670 train_time:42462ms step_avg:87.73ms +step:485/1670 train_time:42550ms step_avg:87.73ms +step:486/1670 train_time:42637ms step_avg:87.73ms +step:487/1670 train_time:42725ms step_avg:87.73ms +step:488/1670 train_time:42814ms 
step_avg:87.73ms +step:489/1670 train_time:42901ms step_avg:87.73ms +step:490/1670 train_time:42989ms step_avg:87.73ms +step:491/1670 train_time:43077ms step_avg:87.73ms +step:492/1670 train_time:43164ms step_avg:87.73ms +step:493/1670 train_time:43252ms step_avg:87.73ms +step:494/1670 train_time:43339ms step_avg:87.73ms +step:495/1670 train_time:43427ms step_avg:87.73ms +step:496/1670 train_time:43515ms step_avg:87.73ms +step:497/1670 train_time:43603ms step_avg:87.73ms +step:498/1670 train_time:43690ms step_avg:87.73ms +step:499/1670 train_time:43778ms step_avg:87.73ms +step:500/1670 train_time:43865ms step_avg:87.73ms +step:500/1670 val_loss:3.7207 train_time:43954ms step_avg:87.91ms +step:501/1670 train_time:43976ms step_avg:87.78ms +step:502/1670 train_time:44046ms step_avg:87.74ms +step:503/1670 train_time:44140ms step_avg:87.75ms +step:504/1670 train_time:44227ms step_avg:87.75ms +step:505/1670 train_time:44314ms step_avg:87.75ms +step:506/1670 train_time:44401ms step_avg:87.75ms +step:507/1670 train_time:44487ms step_avg:87.75ms +step:508/1670 train_time:44574ms step_avg:87.74ms +step:509/1670 train_time:44660ms step_avg:87.74ms +step:510/1670 train_time:44746ms step_avg:87.74ms +step:511/1670 train_time:44833ms step_avg:87.74ms +step:512/1670 train_time:44921ms step_avg:87.74ms +step:513/1670 train_time:45011ms step_avg:87.74ms +step:514/1670 train_time:45102ms step_avg:87.75ms +step:515/1670 train_time:45191ms step_avg:87.75ms +step:516/1670 train_time:45279ms step_avg:87.75ms +step:517/1670 train_time:45366ms step_avg:87.75ms +step:518/1670 train_time:45453ms step_avg:87.75ms +step:519/1670 train_time:45540ms step_avg:87.75ms +step:520/1670 train_time:45626ms step_avg:87.74ms +step:521/1670 train_time:45713ms step_avg:87.74ms +step:522/1670 train_time:45801ms step_avg:87.74ms +step:523/1670 train_time:45888ms step_avg:87.74ms +step:524/1670 train_time:45976ms step_avg:87.74ms +step:525/1670 train_time:46065ms step_avg:87.74ms +step:526/1670 train_time:46155ms step_avg:87.75ms +step:527/1670 train_time:46243ms step_avg:87.75ms +step:528/1670 train_time:46332ms step_avg:87.75ms +step:529/1670 train_time:46420ms step_avg:87.75ms +step:530/1670 train_time:46507ms step_avg:87.75ms +step:531/1670 train_time:46594ms step_avg:87.75ms +step:532/1670 train_time:46681ms step_avg:87.75ms +step:533/1670 train_time:46768ms step_avg:87.74ms +step:534/1670 train_time:46855ms step_avg:87.74ms +step:535/1670 train_time:46942ms step_avg:87.74ms +step:536/1670 train_time:47030ms step_avg:87.74ms +step:537/1670 train_time:47119ms step_avg:87.74ms +step:538/1670 train_time:47207ms step_avg:87.74ms +step:539/1670 train_time:47295ms step_avg:87.75ms +step:540/1670 train_time:47383ms step_avg:87.75ms +step:541/1670 train_time:47471ms step_avg:87.75ms +step:542/1670 train_time:47558ms step_avg:87.75ms +step:543/1670 train_time:47645ms step_avg:87.74ms +step:544/1670 train_time:47733ms step_avg:87.74ms +step:545/1670 train_time:47821ms step_avg:87.75ms +step:546/1670 train_time:47910ms step_avg:87.75ms +step:547/1670 train_time:47999ms step_avg:87.75ms +step:548/1670 train_time:48087ms step_avg:87.75ms +step:549/1670 train_time:48177ms step_avg:87.75ms +step:550/1670 train_time:48266ms step_avg:87.76ms +step:551/1670 train_time:48355ms step_avg:87.76ms +step:552/1670 train_time:48444ms step_avg:87.76ms +step:553/1670 train_time:48532ms step_avg:87.76ms +step:554/1670 train_time:48621ms step_avg:87.76ms +step:555/1670 train_time:48711ms step_avg:87.77ms +step:556/1670 train_time:48800ms step_avg:87.77ms 
+step:557/1670 train_time:48890ms step_avg:87.77ms +step:558/1670 train_time:48980ms step_avg:87.78ms +step:559/1670 train_time:49069ms step_avg:87.78ms +step:560/1670 train_time:49159ms step_avg:87.78ms +step:561/1670 train_time:49249ms step_avg:87.79ms +step:562/1670 train_time:49339ms step_avg:87.79ms +step:563/1670 train_time:49427ms step_avg:87.79ms +step:564/1670 train_time:49516ms step_avg:87.79ms +step:565/1670 train_time:49604ms step_avg:87.80ms +step:566/1670 train_time:49693ms step_avg:87.80ms +step:567/1670 train_time:49783ms step_avg:87.80ms +step:568/1670 train_time:49872ms step_avg:87.80ms +step:569/1670 train_time:49962ms step_avg:87.81ms +step:570/1670 train_time:50051ms step_avg:87.81ms +step:571/1670 train_time:50140ms step_avg:87.81ms +step:572/1670 train_time:50228ms step_avg:87.81ms +step:573/1670 train_time:50318ms step_avg:87.81ms +step:574/1670 train_time:50406ms step_avg:87.82ms +step:575/1670 train_time:50495ms step_avg:87.82ms +step:576/1670 train_time:50583ms step_avg:87.82ms +step:577/1670 train_time:50673ms step_avg:87.82ms +step:578/1670 train_time:50762ms step_avg:87.82ms +step:579/1670 train_time:50852ms step_avg:87.83ms +step:580/1670 train_time:50941ms step_avg:87.83ms +step:581/1670 train_time:51030ms step_avg:87.83ms +step:582/1670 train_time:51119ms step_avg:87.83ms +step:583/1670 train_time:51209ms step_avg:87.84ms +step:584/1670 train_time:51298ms step_avg:87.84ms +step:585/1670 train_time:51386ms step_avg:87.84ms +step:586/1670 train_time:51476ms step_avg:87.84ms +step:587/1670 train_time:51564ms step_avg:87.84ms +step:588/1670 train_time:51653ms step_avg:87.85ms +step:589/1670 train_time:51742ms step_avg:87.85ms +step:590/1670 train_time:51831ms step_avg:87.85ms +step:591/1670 train_time:51920ms step_avg:87.85ms +step:592/1670 train_time:52009ms step_avg:87.85ms +step:593/1670 train_time:52099ms step_avg:87.86ms +step:594/1670 train_time:52188ms step_avg:87.86ms +step:595/1670 train_time:52278ms step_avg:87.86ms +step:596/1670 train_time:52367ms step_avg:87.86ms +step:597/1670 train_time:52456ms step_avg:87.87ms +step:598/1670 train_time:52546ms step_avg:87.87ms +step:599/1670 train_time:52635ms step_avg:87.87ms +step:600/1670 train_time:52724ms step_avg:87.87ms +step:601/1670 train_time:52813ms step_avg:87.87ms +step:602/1670 train_time:52901ms step_avg:87.88ms +step:603/1670 train_time:52991ms step_avg:87.88ms +step:604/1670 train_time:53081ms step_avg:87.88ms +step:605/1670 train_time:53169ms step_avg:87.88ms +step:606/1670 train_time:53258ms step_avg:87.88ms +step:607/1670 train_time:53347ms step_avg:87.89ms +step:608/1670 train_time:53436ms step_avg:87.89ms +step:609/1670 train_time:53525ms step_avg:87.89ms +step:610/1670 train_time:53615ms step_avg:87.89ms +step:611/1670 train_time:53704ms step_avg:87.89ms +step:612/1670 train_time:53793ms step_avg:87.90ms +step:613/1670 train_time:53882ms step_avg:87.90ms +step:614/1670 train_time:53972ms step_avg:87.90ms +step:615/1670 train_time:54061ms step_avg:87.90ms +step:616/1670 train_time:54149ms step_avg:87.90ms +step:617/1670 train_time:54238ms step_avg:87.91ms +step:618/1670 train_time:54327ms step_avg:87.91ms +step:619/1670 train_time:54417ms step_avg:87.91ms +step:620/1670 train_time:54505ms step_avg:87.91ms +step:621/1670 train_time:54595ms step_avg:87.92ms +step:622/1670 train_time:54684ms step_avg:87.92ms +step:623/1670 train_time:54772ms step_avg:87.92ms +step:624/1670 train_time:54861ms step_avg:87.92ms +step:625/1670 train_time:54949ms step_avg:87.92ms +step:625/1670 val_loss:3.6177 
train_time:55040ms step_avg:88.06ms +step:626/1670 train_time:55061ms step_avg:87.96ms +step:627/1670 train_time:55131ms step_avg:87.93ms +step:628/1670 train_time:55220ms step_avg:87.93ms +step:629/1670 train_time:55311ms step_avg:87.94ms +step:630/1670 train_time:55399ms step_avg:87.93ms +step:631/1670 train_time:55487ms step_avg:87.93ms +step:632/1670 train_time:55574ms step_avg:87.93ms +step:633/1670 train_time:55662ms step_avg:87.93ms +step:634/1670 train_time:55751ms step_avg:87.94ms +step:635/1670 train_time:55842ms step_avg:87.94ms +step:636/1670 train_time:55931ms step_avg:87.94ms +step:637/1670 train_time:56022ms step_avg:87.95ms +step:638/1670 train_time:56112ms step_avg:87.95ms +step:639/1670 train_time:56202ms step_avg:87.95ms +step:640/1670 train_time:56291ms step_avg:87.95ms +step:641/1670 train_time:56380ms step_avg:87.96ms +step:642/1670 train_time:56469ms step_avg:87.96ms +step:643/1670 train_time:56557ms step_avg:87.96ms +step:644/1670 train_time:56645ms step_avg:87.96ms +step:645/1670 train_time:56733ms step_avg:87.96ms +step:646/1670 train_time:56823ms step_avg:87.96ms +step:647/1670 train_time:56912ms step_avg:87.96ms +step:648/1670 train_time:57003ms step_avg:87.97ms +step:649/1670 train_time:57092ms step_avg:87.97ms +step:650/1670 train_time:57184ms step_avg:87.97ms +step:651/1670 train_time:57273ms step_avg:87.98ms +step:652/1670 train_time:57363ms step_avg:87.98ms +step:653/1670 train_time:57451ms step_avg:87.98ms +step:654/1670 train_time:57539ms step_avg:87.98ms +step:655/1670 train_time:57627ms step_avg:87.98ms +step:656/1670 train_time:57715ms step_avg:87.98ms +step:657/1670 train_time:57803ms step_avg:87.98ms +step:658/1670 train_time:57892ms step_avg:87.98ms +step:659/1670 train_time:57982ms step_avg:87.98ms +step:660/1670 train_time:58071ms step_avg:87.99ms +step:661/1670 train_time:58161ms step_avg:87.99ms +step:662/1670 train_time:58249ms step_avg:87.99ms +step:663/1670 train_time:58339ms step_avg:87.99ms +step:664/1670 train_time:58427ms step_avg:87.99ms +step:665/1670 train_time:58516ms step_avg:87.99ms +step:666/1670 train_time:58604ms step_avg:87.99ms +step:667/1670 train_time:58692ms step_avg:87.99ms +step:668/1670 train_time:58781ms step_avg:88.00ms +step:669/1670 train_time:58869ms step_avg:88.00ms +step:670/1670 train_time:58958ms step_avg:88.00ms +step:671/1670 train_time:59046ms step_avg:88.00ms +step:672/1670 train_time:59136ms step_avg:88.00ms +step:673/1670 train_time:59226ms step_avg:88.00ms +step:674/1670 train_time:59315ms step_avg:88.00ms +step:675/1670 train_time:59405ms step_avg:88.01ms +step:676/1670 train_time:59494ms step_avg:88.01ms +step:677/1670 train_time:59583ms step_avg:88.01ms +step:678/1670 train_time:59671ms step_avg:88.01ms +step:679/1670 train_time:59759ms step_avg:88.01ms +step:680/1670 train_time:59848ms step_avg:88.01ms +step:681/1670 train_time:59937ms step_avg:88.01ms +step:682/1670 train_time:60027ms step_avg:88.02ms +step:683/1670 train_time:60116ms step_avg:88.02ms +step:684/1670 train_time:60205ms step_avg:88.02ms +step:685/1670 train_time:60295ms step_avg:88.02ms +step:686/1670 train_time:60384ms step_avg:88.02ms +step:687/1670 train_time:60472ms step_avg:88.02ms +step:688/1670 train_time:60560ms step_avg:88.02ms +step:689/1670 train_time:60649ms step_avg:88.02ms +step:690/1670 train_time:60738ms step_avg:88.03ms +step:691/1670 train_time:60827ms step_avg:88.03ms +step:692/1670 train_time:60915ms step_avg:88.03ms +step:693/1670 train_time:61004ms step_avg:88.03ms +step:694/1670 train_time:61094ms step_avg:88.03ms 
+step:695/1670 train_time:61183ms step_avg:88.03ms +step:696/1670 train_time:61272ms step_avg:88.03ms +step:697/1670 train_time:61362ms step_avg:88.04ms +step:698/1670 train_time:61451ms step_avg:88.04ms +step:699/1670 train_time:61539ms step_avg:88.04ms +step:700/1670 train_time:61628ms step_avg:88.04ms +step:701/1670 train_time:61717ms step_avg:88.04ms +step:702/1670 train_time:61806ms step_avg:88.04ms +step:703/1670 train_time:61895ms step_avg:88.04ms +step:704/1670 train_time:61984ms step_avg:88.05ms +step:705/1670 train_time:62073ms step_avg:88.05ms +step:706/1670 train_time:62163ms step_avg:88.05ms +step:707/1670 train_time:62252ms step_avg:88.05ms +step:708/1670 train_time:62341ms step_avg:88.05ms +step:709/1670 train_time:62430ms step_avg:88.05ms +step:710/1670 train_time:62519ms step_avg:88.05ms +step:711/1670 train_time:62607ms step_avg:88.06ms +step:712/1670 train_time:62697ms step_avg:88.06ms +step:713/1670 train_time:62785ms step_avg:88.06ms +step:714/1670 train_time:62874ms step_avg:88.06ms +step:715/1670 train_time:62964ms step_avg:88.06ms +step:716/1670 train_time:63052ms step_avg:88.06ms +step:717/1670 train_time:63142ms step_avg:88.06ms +step:718/1670 train_time:63231ms step_avg:88.07ms +step:719/1670 train_time:63322ms step_avg:88.07ms +step:720/1670 train_time:63410ms step_avg:88.07ms +step:721/1670 train_time:63500ms step_avg:88.07ms +step:722/1670 train_time:63588ms step_avg:88.07ms +step:723/1670 train_time:63677ms step_avg:88.07ms +step:724/1670 train_time:63765ms step_avg:88.07ms +step:725/1670 train_time:63855ms step_avg:88.08ms +step:726/1670 train_time:63944ms step_avg:88.08ms +step:727/1670 train_time:64033ms step_avg:88.08ms +step:728/1670 train_time:64123ms step_avg:88.08ms +step:729/1670 train_time:64212ms step_avg:88.08ms +step:730/1670 train_time:64301ms step_avg:88.08ms +step:731/1670 train_time:64389ms step_avg:88.08ms +step:732/1670 train_time:64478ms step_avg:88.08ms +step:733/1670 train_time:64567ms step_avg:88.09ms +step:734/1670 train_time:64656ms step_avg:88.09ms +step:735/1670 train_time:64744ms step_avg:88.09ms +step:736/1670 train_time:64834ms step_avg:88.09ms +step:737/1670 train_time:64923ms step_avg:88.09ms +step:738/1670 train_time:65011ms step_avg:88.09ms +step:739/1670 train_time:65101ms step_avg:88.09ms +step:740/1670 train_time:65189ms step_avg:88.09ms +step:741/1670 train_time:65279ms step_avg:88.10ms +step:742/1670 train_time:65368ms step_avg:88.10ms +step:743/1670 train_time:65456ms step_avg:88.10ms +step:744/1670 train_time:65545ms step_avg:88.10ms +step:745/1670 train_time:65635ms step_avg:88.10ms +step:746/1670 train_time:65723ms step_avg:88.10ms +step:747/1670 train_time:65812ms step_avg:88.10ms +step:748/1670 train_time:65902ms step_avg:88.10ms +step:749/1670 train_time:65990ms step_avg:88.10ms +step:750/1670 train_time:66080ms step_avg:88.11ms +step:750/1670 val_loss:3.5670 train_time:66170ms step_avg:88.23ms +step:751/1670 train_time:66189ms step_avg:88.13ms +step:752/1670 train_time:66264ms step_avg:88.12ms +step:753/1670 train_time:66356ms step_avg:88.12ms +step:754/1670 train_time:66444ms step_avg:88.12ms +step:755/1670 train_time:66534ms step_avg:88.12ms +step:756/1670 train_time:66621ms step_avg:88.12ms +step:757/1670 train_time:66709ms step_avg:88.12ms +step:758/1670 train_time:66797ms step_avg:88.12ms +step:759/1670 train_time:66885ms step_avg:88.12ms +step:760/1670 train_time:66973ms step_avg:88.12ms +step:761/1670 train_time:67061ms step_avg:88.12ms +step:762/1670 train_time:67152ms step_avg:88.13ms +step:763/1670 
train_time:67244ms step_avg:88.13ms +step:764/1670 train_time:67337ms step_avg:88.14ms +step:765/1670 train_time:67426ms step_avg:88.14ms +step:766/1670 train_time:67515ms step_avg:88.14ms +step:767/1670 train_time:67603ms step_avg:88.14ms +step:768/1670 train_time:67691ms step_avg:88.14ms +step:769/1670 train_time:67779ms step_avg:88.14ms +step:770/1670 train_time:67867ms step_avg:88.14ms +step:771/1670 train_time:67955ms step_avg:88.14ms +step:772/1670 train_time:68043ms step_avg:88.14ms +step:773/1670 train_time:68134ms step_avg:88.14ms +step:774/1670 train_time:68224ms step_avg:88.14ms +step:775/1670 train_time:68314ms step_avg:88.15ms +step:776/1670 train_time:68403ms step_avg:88.15ms +step:777/1670 train_time:68493ms step_avg:88.15ms +step:778/1670 train_time:68581ms step_avg:88.15ms +step:779/1670 train_time:68670ms step_avg:88.15ms +step:780/1670 train_time:68759ms step_avg:88.15ms +step:781/1670 train_time:68847ms step_avg:88.15ms +step:782/1670 train_time:68935ms step_avg:88.15ms +step:783/1670 train_time:69023ms step_avg:88.15ms +step:784/1670 train_time:69113ms step_avg:88.15ms +step:785/1670 train_time:69202ms step_avg:88.16ms +step:786/1670 train_time:69292ms step_avg:88.16ms +step:787/1670 train_time:69381ms step_avg:88.16ms +step:788/1670 train_time:69471ms step_avg:88.16ms +step:789/1670 train_time:69559ms step_avg:88.16ms +step:790/1670 train_time:69648ms step_avg:88.16ms +step:791/1670 train_time:69737ms step_avg:88.16ms +step:792/1670 train_time:69825ms step_avg:88.16ms +step:793/1670 train_time:69914ms step_avg:88.16ms +step:794/1670 train_time:70002ms step_avg:88.16ms +step:795/1670 train_time:70092ms step_avg:88.17ms +step:796/1670 train_time:70181ms step_avg:88.17ms +step:797/1670 train_time:70271ms step_avg:88.17ms +step:798/1670 train_time:70361ms step_avg:88.17ms +step:799/1670 train_time:70451ms step_avg:88.17ms +step:800/1670 train_time:70540ms step_avg:88.17ms +step:801/1670 train_time:70628ms step_avg:88.18ms +step:802/1670 train_time:70717ms step_avg:88.18ms +step:803/1670 train_time:70805ms step_avg:88.18ms +step:804/1670 train_time:70894ms step_avg:88.18ms +step:805/1670 train_time:70982ms step_avg:88.18ms +step:806/1670 train_time:71071ms step_avg:88.18ms +step:807/1670 train_time:71161ms step_avg:88.18ms +step:808/1670 train_time:71250ms step_avg:88.18ms +step:809/1670 train_time:71339ms step_avg:88.18ms +step:810/1670 train_time:71429ms step_avg:88.18ms +step:811/1670 train_time:71518ms step_avg:88.18ms +step:812/1670 train_time:71606ms step_avg:88.19ms +step:813/1670 train_time:71696ms step_avg:88.19ms +step:814/1670 train_time:71784ms step_avg:88.19ms +step:815/1670 train_time:71873ms step_avg:88.19ms +step:816/1670 train_time:71962ms step_avg:88.19ms +step:817/1670 train_time:72051ms step_avg:88.19ms +step:818/1670 train_time:72140ms step_avg:88.19ms +step:819/1670 train_time:72229ms step_avg:88.19ms +step:820/1670 train_time:72318ms step_avg:88.19ms +step:821/1670 train_time:72407ms step_avg:88.19ms +step:822/1670 train_time:72496ms step_avg:88.19ms +step:823/1670 train_time:72584ms step_avg:88.19ms +step:824/1670 train_time:72673ms step_avg:88.20ms +step:825/1670 train_time:72762ms step_avg:88.20ms +step:826/1670 train_time:72851ms step_avg:88.20ms +step:827/1670 train_time:72940ms step_avg:88.20ms +step:828/1670 train_time:73029ms step_avg:88.20ms +step:829/1670 train_time:73118ms step_avg:88.20ms +step:830/1670 train_time:73207ms step_avg:88.20ms +step:831/1670 train_time:73297ms step_avg:88.20ms +step:832/1670 train_time:73386ms step_avg:88.20ms 
+step:833/1670 train_time:73476ms step_avg:88.21ms +step:834/1670 train_time:73564ms step_avg:88.21ms +step:835/1670 train_time:73654ms step_avg:88.21ms +step:836/1670 train_time:73742ms step_avg:88.21ms +step:837/1670 train_time:73831ms step_avg:88.21ms +step:838/1670 train_time:73920ms step_avg:88.21ms +step:839/1670 train_time:74009ms step_avg:88.21ms +step:840/1670 train_time:74099ms step_avg:88.21ms +step:841/1670 train_time:74188ms step_avg:88.21ms +step:842/1670 train_time:74277ms step_avg:88.22ms +step:843/1670 train_time:74366ms step_avg:88.22ms +step:844/1670 train_time:74455ms step_avg:88.22ms +step:845/1670 train_time:74544ms step_avg:88.22ms +step:846/1670 train_time:74634ms step_avg:88.22ms +step:847/1670 train_time:74722ms step_avg:88.22ms +step:848/1670 train_time:74811ms step_avg:88.22ms +step:849/1670 train_time:74900ms step_avg:88.22ms +step:850/1670 train_time:74989ms step_avg:88.22ms +step:851/1670 train_time:75078ms step_avg:88.22ms +step:852/1670 train_time:75167ms step_avg:88.22ms +step:853/1670 train_time:75258ms step_avg:88.23ms +step:854/1670 train_time:75346ms step_avg:88.23ms +step:855/1670 train_time:75435ms step_avg:88.23ms +step:856/1670 train_time:75524ms step_avg:88.23ms +step:857/1670 train_time:75614ms step_avg:88.23ms +step:858/1670 train_time:75702ms step_avg:88.23ms +step:859/1670 train_time:75792ms step_avg:88.23ms +step:860/1670 train_time:75881ms step_avg:88.23ms +step:861/1670 train_time:75970ms step_avg:88.23ms +step:862/1670 train_time:76058ms step_avg:88.23ms +step:863/1670 train_time:76147ms step_avg:88.24ms +step:864/1670 train_time:76236ms step_avg:88.24ms +step:865/1670 train_time:76326ms step_avg:88.24ms +step:866/1670 train_time:76415ms step_avg:88.24ms +step:867/1670 train_time:76504ms step_avg:88.24ms +step:868/1670 train_time:76593ms step_avg:88.24ms +step:869/1670 train_time:76682ms step_avg:88.24ms +step:870/1670 train_time:76771ms step_avg:88.24ms +step:871/1670 train_time:76861ms step_avg:88.24ms +step:872/1670 train_time:76951ms step_avg:88.25ms +step:873/1670 train_time:77040ms step_avg:88.25ms +step:874/1670 train_time:77129ms step_avg:88.25ms +step:875/1670 train_time:77218ms step_avg:88.25ms +step:875/1670 val_loss:3.5175 train_time:77308ms step_avg:88.35ms +step:876/1670 train_time:77328ms step_avg:88.27ms +step:877/1670 train_time:77400ms step_avg:88.26ms +step:878/1670 train_time:77493ms step_avg:88.26ms +step:879/1670 train_time:77581ms step_avg:88.26ms +step:880/1670 train_time:77669ms step_avg:88.26ms +step:881/1670 train_time:77757ms step_avg:88.26ms +step:882/1670 train_time:77844ms step_avg:88.26ms +step:883/1670 train_time:77932ms step_avg:88.26ms +step:884/1670 train_time:78021ms step_avg:88.26ms +step:885/1670 train_time:78109ms step_avg:88.26ms +step:886/1670 train_time:78197ms step_avg:88.26ms +step:887/1670 train_time:78288ms step_avg:88.26ms +step:888/1670 train_time:78379ms step_avg:88.26ms +step:889/1670 train_time:78469ms step_avg:88.27ms +step:890/1670 train_time:78559ms step_avg:88.27ms +step:891/1670 train_time:78648ms step_avg:88.27ms +step:892/1670 train_time:78737ms step_avg:88.27ms +step:893/1670 train_time:78825ms step_avg:88.27ms +step:894/1670 train_time:78912ms step_avg:88.27ms +step:895/1670 train_time:79001ms step_avg:88.27ms +step:896/1670 train_time:79089ms step_avg:88.27ms +step:897/1670 train_time:79178ms step_avg:88.27ms +step:898/1670 train_time:79267ms step_avg:88.27ms +step:899/1670 train_time:79358ms step_avg:88.27ms +step:900/1670 train_time:79448ms step_avg:88.28ms +step:901/1670 
train_time:79538ms step_avg:88.28ms +step:902/1670 train_time:79628ms step_avg:88.28ms +step:903/1670 train_time:79718ms step_avg:88.28ms +step:904/1670 train_time:79806ms step_avg:88.28ms +step:905/1670 train_time:79895ms step_avg:88.28ms +step:906/1670 train_time:79984ms step_avg:88.28ms +step:907/1670 train_time:80072ms step_avg:88.28ms +step:908/1670 train_time:80161ms step_avg:88.28ms +step:909/1670 train_time:80250ms step_avg:88.28ms +step:910/1670 train_time:80340ms step_avg:88.29ms +step:911/1670 train_time:80429ms step_avg:88.29ms +step:912/1670 train_time:80519ms step_avg:88.29ms +step:913/1670 train_time:80609ms step_avg:88.29ms +step:914/1670 train_time:80698ms step_avg:88.29ms +step:915/1670 train_time:80786ms step_avg:88.29ms +step:916/1670 train_time:80875ms step_avg:88.29ms +step:917/1670 train_time:80963ms step_avg:88.29ms +step:918/1670 train_time:81052ms step_avg:88.29ms +step:919/1670 train_time:81140ms step_avg:88.29ms +step:920/1670 train_time:81230ms step_avg:88.29ms +step:921/1670 train_time:81319ms step_avg:88.29ms +step:922/1670 train_time:81409ms step_avg:88.30ms +step:923/1670 train_time:81499ms step_avg:88.30ms +step:924/1670 train_time:81588ms step_avg:88.30ms +step:925/1670 train_time:81677ms step_avg:88.30ms +step:926/1670 train_time:81766ms step_avg:88.30ms +step:927/1670 train_time:81856ms step_avg:88.30ms +step:928/1670 train_time:81944ms step_avg:88.30ms +step:929/1670 train_time:82032ms step_avg:88.30ms +step:930/1670 train_time:82122ms step_avg:88.30ms +step:931/1670 train_time:82211ms step_avg:88.30ms +step:932/1670 train_time:82300ms step_avg:88.31ms +step:933/1670 train_time:82390ms step_avg:88.31ms +step:934/1670 train_time:82480ms step_avg:88.31ms +step:935/1670 train_time:82570ms step_avg:88.31ms +step:936/1670 train_time:82659ms step_avg:88.31ms +step:937/1670 train_time:82748ms step_avg:88.31ms +step:938/1670 train_time:82836ms step_avg:88.31ms +step:939/1670 train_time:82925ms step_avg:88.31ms +step:940/1670 train_time:83014ms step_avg:88.31ms +step:941/1670 train_time:83102ms step_avg:88.31ms +step:942/1670 train_time:83191ms step_avg:88.31ms +step:943/1670 train_time:83281ms step_avg:88.31ms +step:944/1670 train_time:83369ms step_avg:88.31ms +step:945/1670 train_time:83458ms step_avg:88.32ms +step:946/1670 train_time:83547ms step_avg:88.32ms +step:947/1670 train_time:83637ms step_avg:88.32ms +step:948/1670 train_time:83726ms step_avg:88.32ms +step:949/1670 train_time:83815ms step_avg:88.32ms +step:950/1670 train_time:83905ms step_avg:88.32ms +step:951/1670 train_time:83994ms step_avg:88.32ms +step:952/1670 train_time:84085ms step_avg:88.32ms +step:953/1670 train_time:84173ms step_avg:88.32ms +step:954/1670 train_time:84263ms step_avg:88.33ms +step:955/1670 train_time:84351ms step_avg:88.33ms +step:956/1670 train_time:84440ms step_avg:88.33ms +step:957/1670 train_time:84529ms step_avg:88.33ms +step:958/1670 train_time:84618ms step_avg:88.33ms +step:959/1670 train_time:84708ms step_avg:88.33ms +step:960/1670 train_time:84797ms step_avg:88.33ms +step:961/1670 train_time:84887ms step_avg:88.33ms +step:962/1670 train_time:84976ms step_avg:88.33ms +step:963/1670 train_time:85065ms step_avg:88.33ms +step:964/1670 train_time:85154ms step_avg:88.33ms +step:965/1670 train_time:85244ms step_avg:88.34ms +step:966/1670 train_time:85332ms step_avg:88.34ms +step:967/1670 train_time:85421ms step_avg:88.34ms +step:968/1670 train_time:85510ms step_avg:88.34ms +step:969/1670 train_time:85601ms step_avg:88.34ms +step:970/1670 train_time:85690ms step_avg:88.34ms 
+step:971/1670 train_time:85779ms step_avg:88.34ms +step:972/1670 train_time:85867ms step_avg:88.34ms +step:973/1670 train_time:85956ms step_avg:88.34ms +step:974/1670 train_time:86045ms step_avg:88.34ms +step:975/1670 train_time:86134ms step_avg:88.34ms +step:976/1670 train_time:86224ms step_avg:88.34ms +step:977/1670 train_time:86312ms step_avg:88.34ms +step:978/1670 train_time:86401ms step_avg:88.34ms +step:979/1670 train_time:86490ms step_avg:88.35ms +step:980/1670 train_time:86580ms step_avg:88.35ms +step:981/1670 train_time:86669ms step_avg:88.35ms +step:982/1670 train_time:86758ms step_avg:88.35ms +step:983/1670 train_time:86847ms step_avg:88.35ms +step:984/1670 train_time:86936ms step_avg:88.35ms +step:985/1670 train_time:87025ms step_avg:88.35ms +step:986/1670 train_time:87113ms step_avg:88.35ms +step:987/1670 train_time:87203ms step_avg:88.35ms +step:988/1670 train_time:87291ms step_avg:88.35ms +step:989/1670 train_time:87381ms step_avg:88.35ms +step:990/1670 train_time:87470ms step_avg:88.35ms +step:991/1670 train_time:87560ms step_avg:88.36ms +step:992/1670 train_time:87649ms step_avg:88.36ms +step:993/1670 train_time:87738ms step_avg:88.36ms +step:994/1670 train_time:87828ms step_avg:88.36ms +step:995/1670 train_time:87917ms step_avg:88.36ms +step:996/1670 train_time:88006ms step_avg:88.36ms +step:997/1670 train_time:88095ms step_avg:88.36ms +step:998/1670 train_time:88185ms step_avg:88.36ms +step:999/1670 train_time:88273ms step_avg:88.36ms +step:1000/1670 train_time:88363ms step_avg:88.36ms +step:1000/1670 val_loss:3.4676 train_time:88454ms step_avg:88.45ms +step:1001/1670 train_time:88473ms step_avg:88.38ms +step:1002/1670 train_time:88547ms step_avg:88.37ms +step:1003/1670 train_time:88641ms step_avg:88.38ms +step:1004/1670 train_time:88732ms step_avg:88.38ms +step:1005/1670 train_time:88819ms step_avg:88.38ms +step:1006/1670 train_time:88907ms step_avg:88.38ms +step:1007/1670 train_time:88995ms step_avg:88.38ms +step:1008/1670 train_time:89082ms step_avg:88.37ms +step:1009/1670 train_time:89169ms step_avg:88.37ms +step:1010/1670 train_time:89257ms step_avg:88.37ms +step:1011/1670 train_time:89345ms step_avg:88.37ms +step:1012/1670 train_time:89436ms step_avg:88.38ms +step:1013/1670 train_time:89528ms step_avg:88.38ms +step:1014/1670 train_time:89620ms step_avg:88.38ms +step:1015/1670 train_time:89711ms step_avg:88.38ms +step:1016/1670 train_time:89800ms step_avg:88.39ms +step:1017/1670 train_time:89888ms step_avg:88.39ms +step:1018/1670 train_time:89976ms step_avg:88.39ms +step:1019/1670 train_time:90064ms step_avg:88.38ms +step:1020/1670 train_time:90152ms step_avg:88.38ms +step:1021/1670 train_time:90239ms step_avg:88.38ms +step:1022/1670 train_time:90327ms step_avg:88.38ms +step:1023/1670 train_time:90416ms step_avg:88.38ms +step:1024/1670 train_time:90507ms step_avg:88.39ms +step:1025/1670 train_time:90599ms step_avg:88.39ms +step:1026/1670 train_time:90690ms step_avg:88.39ms +step:1027/1670 train_time:90780ms step_avg:88.39ms +step:1028/1670 train_time:90869ms step_avg:88.39ms +step:1029/1670 train_time:90958ms step_avg:88.39ms +step:1030/1670 train_time:91046ms step_avg:88.39ms +step:1031/1670 train_time:91134ms step_avg:88.39ms +step:1032/1670 train_time:91222ms step_avg:88.39ms +step:1033/1670 train_time:91310ms step_avg:88.39ms +step:1034/1670 train_time:91399ms step_avg:88.39ms +step:1035/1670 train_time:91489ms step_avg:88.39ms +step:1036/1670 train_time:91578ms step_avg:88.40ms +step:1037/1670 train_time:91668ms step_avg:88.40ms +step:1038/1670 
train_time:91759ms step_avg:88.40ms +step:1039/1670 train_time:91849ms step_avg:88.40ms +step:1040/1670 train_time:91937ms step_avg:88.40ms +step:1041/1670 train_time:92026ms step_avg:88.40ms +step:1042/1670 train_time:92116ms step_avg:88.40ms +step:1043/1670 train_time:92204ms step_avg:88.40ms +step:1044/1670 train_time:92292ms step_avg:88.40ms +step:1045/1670 train_time:92381ms step_avg:88.40ms +step:1046/1670 train_time:92470ms step_avg:88.40ms +step:1047/1670 train_time:92561ms step_avg:88.41ms +step:1048/1670 train_time:92650ms step_avg:88.41ms +step:1049/1670 train_time:92740ms step_avg:88.41ms +step:1050/1670 train_time:92830ms step_avg:88.41ms +step:1051/1670 train_time:92919ms step_avg:88.41ms +step:1052/1670 train_time:93009ms step_avg:88.41ms +step:1053/1670 train_time:93097ms step_avg:88.41ms +step:1054/1670 train_time:93186ms step_avg:88.41ms +step:1055/1670 train_time:93274ms step_avg:88.41ms +step:1056/1670 train_time:93362ms step_avg:88.41ms +step:1057/1670 train_time:93452ms step_avg:88.41ms +step:1058/1670 train_time:93541ms step_avg:88.41ms +step:1059/1670 train_time:93631ms step_avg:88.41ms +step:1060/1670 train_time:93720ms step_avg:88.42ms +step:1061/1670 train_time:93810ms step_avg:88.42ms +step:1062/1670 train_time:93899ms step_avg:88.42ms +step:1063/1670 train_time:93989ms step_avg:88.42ms +step:1064/1670 train_time:94078ms step_avg:88.42ms +step:1065/1670 train_time:94167ms step_avg:88.42ms +step:1066/1670 train_time:94256ms step_avg:88.42ms +step:1067/1670 train_time:94344ms step_avg:88.42ms +step:1068/1670 train_time:94434ms step_avg:88.42ms +step:1069/1670 train_time:94524ms step_avg:88.42ms +step:1070/1670 train_time:94615ms step_avg:88.42ms +step:1071/1670 train_time:94702ms step_avg:88.42ms +step:1072/1670 train_time:94792ms step_avg:88.43ms +step:1073/1670 train_time:94881ms step_avg:88.43ms +step:1074/1670 train_time:94970ms step_avg:88.43ms +step:1075/1670 train_time:95059ms step_avg:88.43ms +step:1076/1670 train_time:95148ms step_avg:88.43ms +step:1077/1670 train_time:95237ms step_avg:88.43ms +step:1078/1670 train_time:95326ms step_avg:88.43ms +step:1079/1670 train_time:95416ms step_avg:88.43ms +step:1080/1670 train_time:95505ms step_avg:88.43ms +step:1081/1670 train_time:95595ms step_avg:88.43ms +step:1082/1670 train_time:95683ms step_avg:88.43ms +step:1083/1670 train_time:95773ms step_avg:88.43ms +step:1084/1670 train_time:95861ms step_avg:88.43ms +step:1085/1670 train_time:95951ms step_avg:88.43ms +step:1086/1670 train_time:96039ms step_avg:88.43ms +step:1087/1670 train_time:96128ms step_avg:88.43ms +step:1088/1670 train_time:96216ms step_avg:88.43ms +step:1089/1670 train_time:96306ms step_avg:88.43ms +step:1090/1670 train_time:96396ms step_avg:88.44ms +step:1091/1670 train_time:96486ms step_avg:88.44ms +step:1092/1670 train_time:96577ms step_avg:88.44ms +step:1093/1670 train_time:96666ms step_avg:88.44ms +step:1094/1670 train_time:96756ms step_avg:88.44ms +step:1095/1670 train_time:96846ms step_avg:88.44ms +step:1096/1670 train_time:96936ms step_avg:88.45ms +step:1097/1670 train_time:97026ms step_avg:88.45ms +step:1098/1670 train_time:97116ms step_avg:88.45ms +step:1099/1670 train_time:97205ms step_avg:88.45ms +step:1100/1670 train_time:97295ms step_avg:88.45ms +step:1101/1670 train_time:97385ms step_avg:88.45ms +step:1102/1670 train_time:97475ms step_avg:88.45ms +step:1103/1670 train_time:97564ms step_avg:88.45ms +step:1104/1670 train_time:97655ms step_avg:88.46ms +step:1105/1670 train_time:97744ms step_avg:88.46ms +step:1106/1670 train_time:97834ms 
step_avg:88.46ms +step:1107/1670 train_time:97924ms step_avg:88.46ms +step:1108/1670 train_time:98014ms step_avg:88.46ms +step:1109/1670 train_time:98103ms step_avg:88.46ms +step:1110/1670 train_time:98193ms step_avg:88.46ms +step:1111/1670 train_time:98282ms step_avg:88.46ms +step:1112/1670 train_time:98372ms step_avg:88.46ms +step:1113/1670 train_time:98462ms step_avg:88.46ms +step:1114/1670 train_time:98552ms step_avg:88.47ms +step:1115/1670 train_time:98641ms step_avg:88.47ms +step:1116/1670 train_time:98732ms step_avg:88.47ms +step:1117/1670 train_time:98822ms step_avg:88.47ms +step:1118/1670 train_time:98913ms step_avg:88.47ms +step:1119/1670 train_time:99002ms step_avg:88.47ms +step:1120/1670 train_time:99092ms step_avg:88.47ms +step:1121/1670 train_time:99180ms step_avg:88.47ms +step:1122/1670 train_time:99270ms step_avg:88.48ms +step:1123/1670 train_time:99360ms step_avg:88.48ms +step:1124/1670 train_time:99450ms step_avg:88.48ms +step:1125/1670 train_time:99540ms step_avg:88.48ms +step:1125/1670 val_loss:3.4131 train_time:99631ms step_avg:88.56ms +step:1126/1670 train_time:99650ms step_avg:88.50ms +step:1127/1670 train_time:99723ms step_avg:88.49ms +step:1128/1670 train_time:99812ms step_avg:88.49ms +step:1129/1670 train_time:99903ms step_avg:88.49ms +step:1130/1670 train_time:99992ms step_avg:88.49ms +step:1131/1670 train_time:100082ms step_avg:88.49ms +step:1132/1670 train_time:100170ms step_avg:88.49ms +step:1133/1670 train_time:100259ms step_avg:88.49ms +step:1134/1670 train_time:100348ms step_avg:88.49ms +step:1135/1670 train_time:100436ms step_avg:88.49ms +step:1136/1670 train_time:100527ms step_avg:88.49ms +step:1137/1670 train_time:100619ms step_avg:88.50ms +step:1138/1670 train_time:100711ms step_avg:88.50ms +step:1139/1670 train_time:100801ms step_avg:88.50ms +step:1140/1670 train_time:100891ms step_avg:88.50ms +step:1141/1670 train_time:100981ms step_avg:88.50ms +step:1142/1670 train_time:101071ms step_avg:88.50ms +step:1143/1670 train_time:101160ms step_avg:88.50ms +step:1144/1670 train_time:101249ms step_avg:88.50ms +step:1145/1670 train_time:101338ms step_avg:88.51ms +step:1146/1670 train_time:101427ms step_avg:88.51ms +step:1147/1670 train_time:101517ms step_avg:88.51ms +step:1148/1670 train_time:101607ms step_avg:88.51ms +step:1149/1670 train_time:101698ms step_avg:88.51ms +step:1150/1670 train_time:101788ms step_avg:88.51ms +step:1151/1670 train_time:101878ms step_avg:88.51ms +step:1152/1670 train_time:101968ms step_avg:88.51ms +step:1153/1670 train_time:102058ms step_avg:88.52ms +step:1154/1670 train_time:102148ms step_avg:88.52ms +step:1155/1670 train_time:102237ms step_avg:88.52ms +step:1156/1670 train_time:102326ms step_avg:88.52ms +step:1157/1670 train_time:102415ms step_avg:88.52ms +step:1158/1670 train_time:102505ms step_avg:88.52ms +step:1159/1670 train_time:102595ms step_avg:88.52ms +step:1160/1670 train_time:102686ms step_avg:88.52ms +step:1161/1670 train_time:102775ms step_avg:88.52ms +step:1162/1670 train_time:102867ms step_avg:88.53ms +step:1163/1670 train_time:102957ms step_avg:88.53ms +step:1164/1670 train_time:103047ms step_avg:88.53ms +step:1165/1670 train_time:103136ms step_avg:88.53ms +step:1166/1670 train_time:103226ms step_avg:88.53ms +step:1167/1670 train_time:103315ms step_avg:88.53ms +step:1168/1670 train_time:103404ms step_avg:88.53ms +step:1169/1670 train_time:103493ms step_avg:88.53ms +step:1170/1670 train_time:103583ms step_avg:88.53ms +step:1171/1670 train_time:103672ms step_avg:88.53ms +step:1172/1670 train_time:103763ms 
step_avg:88.53ms +step:1173/1670 train_time:103852ms step_avg:88.54ms +step:1174/1670 train_time:103943ms step_avg:88.54ms +step:1175/1670 train_time:104032ms step_avg:88.54ms +step:1176/1670 train_time:104123ms step_avg:88.54ms +step:1177/1670 train_time:104212ms step_avg:88.54ms +step:1178/1670 train_time:104301ms step_avg:88.54ms +step:1179/1670 train_time:104390ms step_avg:88.54ms +step:1180/1670 train_time:104479ms step_avg:88.54ms +step:1181/1670 train_time:104569ms step_avg:88.54ms +step:1182/1670 train_time:104659ms step_avg:88.54ms +step:1183/1670 train_time:104749ms step_avg:88.55ms +step:1184/1670 train_time:104841ms step_avg:88.55ms +step:1185/1670 train_time:104931ms step_avg:88.55ms +step:1186/1670 train_time:105020ms step_avg:88.55ms +step:1187/1670 train_time:105110ms step_avg:88.55ms +step:1188/1670 train_time:105200ms step_avg:88.55ms +step:1189/1670 train_time:105289ms step_avg:88.55ms +step:1190/1670 train_time:105380ms step_avg:88.55ms +step:1191/1670 train_time:105469ms step_avg:88.55ms +step:1192/1670 train_time:105559ms step_avg:88.56ms +step:1193/1670 train_time:105650ms step_avg:88.56ms +step:1194/1670 train_time:105741ms step_avg:88.56ms +step:1195/1670 train_time:105831ms step_avg:88.56ms +step:1196/1670 train_time:105920ms step_avg:88.56ms +step:1197/1670 train_time:106009ms step_avg:88.56ms +step:1198/1670 train_time:106098ms step_avg:88.56ms +step:1199/1670 train_time:106188ms step_avg:88.56ms +step:1200/1670 train_time:106278ms step_avg:88.57ms +step:1201/1670 train_time:106368ms step_avg:88.57ms +step:1202/1670 train_time:106457ms step_avg:88.57ms +step:1203/1670 train_time:106547ms step_avg:88.57ms +step:1204/1670 train_time:106637ms step_avg:88.57ms +step:1205/1670 train_time:106727ms step_avg:88.57ms +step:1206/1670 train_time:106817ms step_avg:88.57ms +step:1207/1670 train_time:106908ms step_avg:88.57ms +step:1208/1670 train_time:106997ms step_avg:88.57ms +step:1209/1670 train_time:107088ms step_avg:88.58ms +step:1210/1670 train_time:107178ms step_avg:88.58ms +step:1211/1670 train_time:107268ms step_avg:88.58ms +step:1212/1670 train_time:107359ms step_avg:88.58ms +step:1213/1670 train_time:107449ms step_avg:88.58ms +step:1214/1670 train_time:107539ms step_avg:88.58ms +step:1215/1670 train_time:107628ms step_avg:88.58ms +step:1216/1670 train_time:107718ms step_avg:88.58ms +step:1217/1670 train_time:107807ms step_avg:88.58ms +step:1218/1670 train_time:107896ms step_avg:88.58ms +step:1219/1670 train_time:107987ms step_avg:88.59ms +step:1220/1670 train_time:108077ms step_avg:88.59ms +step:1221/1670 train_time:108166ms step_avg:88.59ms +step:1222/1670 train_time:108256ms step_avg:88.59ms +step:1223/1670 train_time:108346ms step_avg:88.59ms +step:1224/1670 train_time:108436ms step_avg:88.59ms +step:1225/1670 train_time:108525ms step_avg:88.59ms +step:1226/1670 train_time:108614ms step_avg:88.59ms +step:1227/1670 train_time:108704ms step_avg:88.59ms +step:1228/1670 train_time:108793ms step_avg:88.59ms +step:1229/1670 train_time:108883ms step_avg:88.60ms +step:1230/1670 train_time:108972ms step_avg:88.60ms +step:1231/1670 train_time:109062ms step_avg:88.60ms +step:1232/1670 train_time:109151ms step_avg:88.60ms +step:1233/1670 train_time:109241ms step_avg:88.60ms +step:1234/1670 train_time:109331ms step_avg:88.60ms +step:1235/1670 train_time:109420ms step_avg:88.60ms +step:1236/1670 train_time:109509ms step_avg:88.60ms +step:1237/1670 train_time:109599ms step_avg:88.60ms +step:1238/1670 train_time:109689ms step_avg:88.60ms +step:1239/1670 train_time:109780ms 
step_avg:88.60ms +step:1240/1670 train_time:109870ms step_avg:88.60ms +step:1241/1670 train_time:109960ms step_avg:88.61ms +step:1242/1670 train_time:110050ms step_avg:88.61ms +step:1243/1670 train_time:110140ms step_avg:88.61ms +step:1244/1670 train_time:110229ms step_avg:88.61ms +step:1245/1670 train_time:110319ms step_avg:88.61ms +step:1246/1670 train_time:110408ms step_avg:88.61ms +step:1247/1670 train_time:110497ms step_avg:88.61ms +step:1248/1670 train_time:110587ms step_avg:88.61ms +step:1249/1670 train_time:110677ms step_avg:88.61ms +step:1250/1670 train_time:110767ms step_avg:88.61ms +step:1250/1670 val_loss:3.3752 train_time:110858ms step_avg:88.69ms +step:1251/1670 train_time:110879ms step_avg:88.63ms +step:1252/1670 train_time:110952ms step_avg:88.62ms +step:1253/1670 train_time:111043ms step_avg:88.62ms +step:1254/1670 train_time:111134ms step_avg:88.62ms +step:1255/1670 train_time:111222ms step_avg:88.62ms +step:1256/1670 train_time:111311ms step_avg:88.62ms +step:1257/1670 train_time:111399ms step_avg:88.62ms +step:1258/1670 train_time:111488ms step_avg:88.62ms +step:1259/1670 train_time:111576ms step_avg:88.62ms +step:1260/1670 train_time:111666ms step_avg:88.62ms +step:1261/1670 train_time:111754ms step_avg:88.62ms +step:1262/1670 train_time:111845ms step_avg:88.63ms +step:1263/1670 train_time:111936ms step_avg:88.63ms +step:1264/1670 train_time:112029ms step_avg:88.63ms +step:1265/1670 train_time:112120ms step_avg:88.63ms +step:1266/1670 train_time:112210ms step_avg:88.63ms +step:1267/1670 train_time:112300ms step_avg:88.63ms +step:1268/1670 train_time:112389ms step_avg:88.63ms +step:1269/1670 train_time:112478ms step_avg:88.64ms +step:1270/1670 train_time:112567ms step_avg:88.64ms +step:1271/1670 train_time:112656ms step_avg:88.64ms +step:1272/1670 train_time:112746ms step_avg:88.64ms +step:1273/1670 train_time:112835ms step_avg:88.64ms +step:1274/1670 train_time:112927ms step_avg:88.64ms +step:1275/1670 train_time:113017ms step_avg:88.64ms +step:1276/1670 train_time:113108ms step_avg:88.64ms +step:1277/1670 train_time:113197ms step_avg:88.64ms +step:1278/1670 train_time:113287ms step_avg:88.64ms +step:1279/1670 train_time:113376ms step_avg:88.64ms +step:1280/1670 train_time:113465ms step_avg:88.64ms +step:1281/1670 train_time:113553ms step_avg:88.64ms +step:1282/1670 train_time:113643ms step_avg:88.65ms +step:1283/1670 train_time:113732ms step_avg:88.65ms +step:1284/1670 train_time:113823ms step_avg:88.65ms +step:1285/1670 train_time:113913ms step_avg:88.65ms +step:1286/1670 train_time:114004ms step_avg:88.65ms +step:1287/1670 train_time:114093ms step_avg:88.65ms +step:1288/1670 train_time:114184ms step_avg:88.65ms +step:1289/1670 train_time:114275ms step_avg:88.65ms +step:1290/1670 train_time:114366ms step_avg:88.66ms +step:1291/1670 train_time:114454ms step_avg:88.66ms +step:1292/1670 train_time:114543ms step_avg:88.66ms +step:1293/1670 train_time:114632ms step_avg:88.66ms +step:1294/1670 train_time:114721ms step_avg:88.66ms +step:1295/1670 train_time:114812ms step_avg:88.66ms +step:1296/1670 train_time:114902ms step_avg:88.66ms +step:1297/1670 train_time:114992ms step_avg:88.66ms +step:1298/1670 train_time:115081ms step_avg:88.66ms +step:1299/1670 train_time:115171ms step_avg:88.66ms +step:1300/1670 train_time:115261ms step_avg:88.66ms +step:1301/1670 train_time:115351ms step_avg:88.66ms +step:1302/1670 train_time:115441ms step_avg:88.66ms +step:1303/1670 train_time:115531ms step_avg:88.67ms +step:1304/1670 train_time:115620ms step_avg:88.67ms +step:1305/1670 
train_time:115709ms step_avg:88.67ms +step:1306/1670 train_time:115800ms step_avg:88.67ms +step:1307/1670 train_time:115890ms step_avg:88.67ms +step:1308/1670 train_time:115980ms step_avg:88.67ms +step:1309/1670 train_time:116069ms step_avg:88.67ms +step:1310/1670 train_time:116159ms step_avg:88.67ms +step:1311/1670 train_time:116249ms step_avg:88.67ms +step:1312/1670 train_time:116339ms step_avg:88.67ms +step:1313/1670 train_time:116429ms step_avg:88.67ms +step:1314/1670 train_time:116518ms step_avg:88.67ms +step:1315/1670 train_time:116608ms step_avg:88.68ms +step:1316/1670 train_time:116697ms step_avg:88.68ms +step:1317/1670 train_time:116788ms step_avg:88.68ms +step:1318/1670 train_time:116878ms step_avg:88.68ms +step:1319/1670 train_time:116969ms step_avg:88.68ms +step:1320/1670 train_time:117059ms step_avg:88.68ms +step:1321/1670 train_time:117148ms step_avg:88.68ms +step:1322/1670 train_time:117237ms step_avg:88.68ms +step:1323/1670 train_time:117328ms step_avg:88.68ms +step:1324/1670 train_time:117418ms step_avg:88.68ms +step:1325/1670 train_time:117508ms step_avg:88.69ms +step:1326/1670 train_time:117598ms step_avg:88.69ms +step:1327/1670 train_time:117689ms step_avg:88.69ms +step:1328/1670 train_time:117778ms step_avg:88.69ms +step:1329/1670 train_time:117870ms step_avg:88.69ms +step:1330/1670 train_time:117959ms step_avg:88.69ms +step:1331/1670 train_time:118049ms step_avg:88.69ms +step:1332/1670 train_time:118139ms step_avg:88.69ms +step:1333/1670 train_time:118229ms step_avg:88.69ms +step:1334/1670 train_time:118318ms step_avg:88.69ms +step:1335/1670 train_time:118409ms step_avg:88.70ms +step:1336/1670 train_time:118498ms step_avg:88.70ms +step:1337/1670 train_time:118588ms step_avg:88.70ms +step:1338/1670 train_time:118677ms step_avg:88.70ms +step:1339/1670 train_time:118768ms step_avg:88.70ms +step:1340/1670 train_time:118857ms step_avg:88.70ms +step:1341/1670 train_time:118947ms step_avg:88.70ms +step:1342/1670 train_time:119036ms step_avg:88.70ms +step:1343/1670 train_time:119127ms step_avg:88.70ms +step:1344/1670 train_time:119216ms step_avg:88.70ms +step:1345/1670 train_time:119306ms step_avg:88.70ms +step:1346/1670 train_time:119395ms step_avg:88.70ms +step:1347/1670 train_time:119485ms step_avg:88.70ms +step:1348/1670 train_time:119575ms step_avg:88.71ms +step:1349/1670 train_time:119666ms step_avg:88.71ms +step:1350/1670 train_time:119755ms step_avg:88.71ms +step:1351/1670 train_time:119846ms step_avg:88.71ms +step:1352/1670 train_time:119935ms step_avg:88.71ms +step:1353/1670 train_time:120026ms step_avg:88.71ms +step:1354/1670 train_time:120116ms step_avg:88.71ms +step:1355/1670 train_time:120206ms step_avg:88.71ms +step:1356/1670 train_time:120295ms step_avg:88.71ms +step:1357/1670 train_time:120385ms step_avg:88.71ms +step:1358/1670 train_time:120475ms step_avg:88.71ms +step:1359/1670 train_time:120565ms step_avg:88.72ms +step:1360/1670 train_time:120655ms step_avg:88.72ms +step:1361/1670 train_time:120745ms step_avg:88.72ms +step:1362/1670 train_time:120834ms step_avg:88.72ms +step:1363/1670 train_time:120924ms step_avg:88.72ms +step:1364/1670 train_time:121013ms step_avg:88.72ms +step:1365/1670 train_time:121103ms step_avg:88.72ms +step:1366/1670 train_time:121192ms step_avg:88.72ms +step:1367/1670 train_time:121282ms step_avg:88.72ms +step:1368/1670 train_time:121371ms step_avg:88.72ms +step:1369/1670 train_time:121463ms step_avg:88.72ms +step:1370/1670 train_time:121552ms step_avg:88.72ms +step:1371/1670 train_time:121642ms step_avg:88.73ms +step:1372/1670 
train_time:121732ms step_avg:88.73ms +step:1373/1670 train_time:121822ms step_avg:88.73ms +step:1374/1670 train_time:121912ms step_avg:88.73ms +step:1375/1670 train_time:122001ms step_avg:88.73ms +step:1375/1670 val_loss:3.3403 train_time:122093ms step_avg:88.79ms +step:1376/1670 train_time:122112ms step_avg:88.74ms +step:1377/1670 train_time:122188ms step_avg:88.74ms +step:1378/1670 train_time:122282ms step_avg:88.74ms +step:1379/1670 train_time:122372ms step_avg:88.74ms +step:1380/1670 train_time:122460ms step_avg:88.74ms +step:1381/1670 train_time:122550ms step_avg:88.74ms +step:1382/1670 train_time:122637ms step_avg:88.74ms +step:1383/1670 train_time:122726ms step_avg:88.74ms +step:1384/1670 train_time:122815ms step_avg:88.74ms +step:1385/1670 train_time:122904ms step_avg:88.74ms +step:1386/1670 train_time:122994ms step_avg:88.74ms +step:1387/1670 train_time:123086ms step_avg:88.74ms +step:1388/1670 train_time:123178ms step_avg:88.74ms +step:1389/1670 train_time:123270ms step_avg:88.75ms +step:1390/1670 train_time:123361ms step_avg:88.75ms +step:1391/1670 train_time:123450ms step_avg:88.75ms +step:1392/1670 train_time:123539ms step_avg:88.75ms +step:1393/1670 train_time:123628ms step_avg:88.75ms +step:1394/1670 train_time:123717ms step_avg:88.75ms +step:1395/1670 train_time:123806ms step_avg:88.75ms +step:1396/1670 train_time:123895ms step_avg:88.75ms +step:1397/1670 train_time:123984ms step_avg:88.75ms +step:1398/1670 train_time:124074ms step_avg:88.75ms +step:1399/1670 train_time:124165ms step_avg:88.75ms +step:1400/1670 train_time:124256ms step_avg:88.75ms +step:1401/1670 train_time:124346ms step_avg:88.76ms +step:1402/1670 train_time:124436ms step_avg:88.76ms +step:1403/1670 train_time:124526ms step_avg:88.76ms +step:1404/1670 train_time:124616ms step_avg:88.76ms +step:1405/1670 train_time:124705ms step_avg:88.76ms +step:1406/1670 train_time:124794ms step_avg:88.76ms +step:1407/1670 train_time:124884ms step_avg:88.76ms +step:1408/1670 train_time:124973ms step_avg:88.76ms +step:1409/1670 train_time:125063ms step_avg:88.76ms +step:1410/1670 train_time:125153ms step_avg:88.76ms +step:1411/1670 train_time:125244ms step_avg:88.76ms +step:1412/1670 train_time:125335ms step_avg:88.76ms +step:1413/1670 train_time:125426ms step_avg:88.77ms +step:1414/1670 train_time:125516ms step_avg:88.77ms +step:1415/1670 train_time:125606ms step_avg:88.77ms +step:1416/1670 train_time:125695ms step_avg:88.77ms +step:1417/1670 train_time:125786ms step_avg:88.77ms +step:1418/1670 train_time:125874ms step_avg:88.77ms +step:1419/1670 train_time:125964ms step_avg:88.77ms +step:1420/1670 train_time:126053ms step_avg:88.77ms +step:1421/1670 train_time:126143ms step_avg:88.77ms +step:1422/1670 train_time:126233ms step_avg:88.77ms +step:1423/1670 train_time:126324ms step_avg:88.77ms +step:1424/1670 train_time:126415ms step_avg:88.77ms +step:1425/1670 train_time:126505ms step_avg:88.78ms +step:1426/1670 train_time:126594ms step_avg:88.78ms +step:1427/1670 train_time:126684ms step_avg:88.78ms +step:1428/1670 train_time:126773ms step_avg:88.78ms +step:1429/1670 train_time:126863ms step_avg:88.78ms +step:1430/1670 train_time:126952ms step_avg:88.78ms +step:1431/1670 train_time:127041ms step_avg:88.78ms +step:1432/1670 train_time:127132ms step_avg:88.78ms +step:1433/1670 train_time:127223ms step_avg:88.78ms +step:1434/1670 train_time:127314ms step_avg:88.78ms +step:1435/1670 train_time:127405ms step_avg:88.78ms +step:1436/1670 train_time:127496ms step_avg:88.79ms +step:1437/1670 train_time:127585ms step_avg:88.79ms 
+step:1438/1670 train_time:127674ms step_avg:88.79ms +step:1439/1670 train_time:127764ms step_avg:88.79ms +step:1440/1670 train_time:127854ms step_avg:88.79ms +step:1441/1670 train_time:127943ms step_avg:88.79ms +step:1442/1670 train_time:128033ms step_avg:88.79ms +step:1443/1670 train_time:128123ms step_avg:88.79ms +step:1444/1670 train_time:128214ms step_avg:88.79ms +step:1445/1670 train_time:128306ms step_avg:88.79ms +step:1446/1670 train_time:128395ms step_avg:88.79ms +step:1447/1670 train_time:128485ms step_avg:88.79ms +step:1448/1670 train_time:128575ms step_avg:88.79ms +step:1449/1670 train_time:128665ms step_avg:88.80ms +step:1450/1670 train_time:128754ms step_avg:88.80ms +step:1451/1670 train_time:128843ms step_avg:88.80ms +step:1452/1670 train_time:128933ms step_avg:88.80ms +step:1453/1670 train_time:129023ms step_avg:88.80ms +step:1454/1670 train_time:129114ms step_avg:88.80ms +step:1455/1670 train_time:129204ms step_avg:88.80ms +step:1456/1670 train_time:129296ms step_avg:88.80ms +step:1457/1670 train_time:129386ms step_avg:88.80ms +step:1458/1670 train_time:129475ms step_avg:88.80ms +step:1459/1670 train_time:129564ms step_avg:88.80ms +step:1460/1670 train_time:129654ms step_avg:88.80ms +step:1461/1670 train_time:129743ms step_avg:88.80ms +step:1462/1670 train_time:129833ms step_avg:88.80ms +step:1463/1670 train_time:129923ms step_avg:88.81ms +step:1464/1670 train_time:130014ms step_avg:88.81ms +step:1465/1670 train_time:130104ms step_avg:88.81ms +step:1466/1670 train_time:130195ms step_avg:88.81ms +step:1467/1670 train_time:130286ms step_avg:88.81ms +step:1468/1670 train_time:130376ms step_avg:88.81ms +step:1469/1670 train_time:130465ms step_avg:88.81ms +step:1470/1670 train_time:130554ms step_avg:88.81ms +step:1471/1670 train_time:130643ms step_avg:88.81ms +step:1472/1670 train_time:130734ms step_avg:88.81ms +step:1473/1670 train_time:130824ms step_avg:88.81ms +step:1474/1670 train_time:130913ms step_avg:88.81ms +step:1475/1670 train_time:131003ms step_avg:88.82ms +step:1476/1670 train_time:131094ms step_avg:88.82ms +step:1477/1670 train_time:131185ms step_avg:88.82ms +step:1478/1670 train_time:131275ms step_avg:88.82ms +step:1479/1670 train_time:131366ms step_avg:88.82ms +step:1480/1670 train_time:131455ms step_avg:88.82ms +step:1481/1670 train_time:131545ms step_avg:88.82ms +step:1482/1670 train_time:131634ms step_avg:88.82ms +step:1483/1670 train_time:131724ms step_avg:88.82ms +step:1484/1670 train_time:131814ms step_avg:88.82ms +step:1485/1670 train_time:131906ms step_avg:88.83ms +step:1486/1670 train_time:131995ms step_avg:88.83ms +step:1487/1670 train_time:132085ms step_avg:88.83ms +step:1488/1670 train_time:132176ms step_avg:88.83ms +step:1489/1670 train_time:132267ms step_avg:88.83ms +step:1490/1670 train_time:132356ms step_avg:88.83ms +step:1491/1670 train_time:132446ms step_avg:88.83ms +step:1492/1670 train_time:132535ms step_avg:88.83ms +step:1493/1670 train_time:132624ms step_avg:88.83ms +step:1494/1670 train_time:132713ms step_avg:88.83ms +step:1495/1670 train_time:132803ms step_avg:88.83ms +step:1496/1670 train_time:132894ms step_avg:88.83ms +step:1497/1670 train_time:132985ms step_avg:88.83ms +step:1498/1670 train_time:133074ms step_avg:88.83ms +step:1499/1670 train_time:133166ms step_avg:88.84ms +step:1500/1670 train_time:133255ms step_avg:88.84ms +step:1500/1670 val_loss:3.3102 train_time:133346ms step_avg:88.90ms +step:1501/1670 train_time:133365ms step_avg:88.85ms +step:1502/1670 train_time:133439ms step_avg:88.84ms +step:1503/1670 train_time:133533ms 
step_avg:88.84ms +step:1504/1670 train_time:133624ms step_avg:88.85ms +step:1505/1670 train_time:133713ms step_avg:88.85ms +step:1506/1670 train_time:133802ms step_avg:88.85ms +step:1507/1670 train_time:133891ms step_avg:88.85ms +step:1508/1670 train_time:133979ms step_avg:88.85ms +step:1509/1670 train_time:134067ms step_avg:88.84ms +step:1510/1670 train_time:134156ms step_avg:88.85ms +step:1511/1670 train_time:134245ms step_avg:88.85ms +step:1512/1670 train_time:134338ms step_avg:88.85ms +step:1513/1670 train_time:134430ms step_avg:88.85ms +step:1514/1670 train_time:134523ms step_avg:88.85ms +step:1515/1670 train_time:134613ms step_avg:88.85ms +step:1516/1670 train_time:134703ms step_avg:88.85ms +step:1517/1670 train_time:134792ms step_avg:88.85ms +step:1518/1670 train_time:134881ms step_avg:88.85ms +step:1519/1670 train_time:134969ms step_avg:88.85ms +step:1520/1670 train_time:135058ms step_avg:88.85ms +step:1521/1670 train_time:135146ms step_avg:88.85ms +step:1522/1670 train_time:135236ms step_avg:88.85ms +step:1523/1670 train_time:135327ms step_avg:88.86ms +step:1524/1670 train_time:135420ms step_avg:88.86ms +step:1525/1670 train_time:135510ms step_avg:88.86ms +step:1526/1670 train_time:135600ms step_avg:88.86ms +step:1527/1670 train_time:135690ms step_avg:88.86ms +step:1528/1670 train_time:135779ms step_avg:88.86ms +step:1529/1670 train_time:135868ms step_avg:88.86ms +step:1530/1670 train_time:135957ms step_avg:88.86ms +step:1531/1670 train_time:136046ms step_avg:88.86ms +step:1532/1670 train_time:136135ms step_avg:88.86ms +step:1533/1670 train_time:136226ms step_avg:88.86ms +step:1534/1670 train_time:136316ms step_avg:88.86ms +step:1535/1670 train_time:136407ms step_avg:88.86ms +step:1536/1670 train_time:136498ms step_avg:88.87ms +step:1537/1670 train_time:136589ms step_avg:88.87ms +step:1538/1670 train_time:136680ms step_avg:88.87ms +step:1539/1670 train_time:136769ms step_avg:88.87ms +step:1540/1670 train_time:136859ms step_avg:88.87ms +step:1541/1670 train_time:136949ms step_avg:88.87ms +step:1542/1670 train_time:137039ms step_avg:88.87ms +step:1543/1670 train_time:137128ms step_avg:88.87ms +step:1544/1670 train_time:137218ms step_avg:88.87ms +step:1545/1670 train_time:137308ms step_avg:88.87ms +step:1546/1670 train_time:137398ms step_avg:88.87ms +step:1547/1670 train_time:137488ms step_avg:88.87ms +step:1548/1670 train_time:137579ms step_avg:88.88ms +step:1549/1670 train_time:137668ms step_avg:88.88ms +step:1550/1670 train_time:137758ms step_avg:88.88ms +step:1551/1670 train_time:137847ms step_avg:88.88ms +step:1552/1670 train_time:137937ms step_avg:88.88ms +step:1553/1670 train_time:138027ms step_avg:88.88ms +step:1554/1670 train_time:138117ms step_avg:88.88ms +step:1555/1670 train_time:138207ms step_avg:88.88ms +step:1556/1670 train_time:138297ms step_avg:88.88ms +step:1557/1670 train_time:138387ms step_avg:88.88ms +step:1558/1670 train_time:138477ms step_avg:88.88ms +step:1559/1670 train_time:138567ms step_avg:88.88ms +step:1560/1670 train_time:138658ms step_avg:88.88ms +step:1561/1670 train_time:138747ms step_avg:88.88ms +step:1562/1670 train_time:138838ms step_avg:88.88ms +step:1563/1670 train_time:138927ms step_avg:88.88ms +step:1564/1670 train_time:139017ms step_avg:88.89ms +step:1565/1670 train_time:139105ms step_avg:88.89ms +step:1566/1670 train_time:139195ms step_avg:88.89ms +step:1567/1670 train_time:139285ms step_avg:88.89ms +step:1568/1670 train_time:139375ms step_avg:88.89ms +step:1569/1670 train_time:139466ms step_avg:88.89ms +step:1570/1670 train_time:139557ms 
step_avg:88.89ms +step:1571/1670 train_time:139647ms step_avg:88.89ms +step:1572/1670 train_time:139737ms step_avg:88.89ms +step:1573/1670 train_time:139826ms step_avg:88.89ms +step:1574/1670 train_time:139916ms step_avg:88.89ms +step:1575/1670 train_time:140006ms step_avg:88.89ms +step:1576/1670 train_time:140096ms step_avg:88.89ms +step:1577/1670 train_time:140186ms step_avg:88.89ms +step:1578/1670 train_time:140275ms step_avg:88.89ms +step:1579/1670 train_time:140365ms step_avg:88.89ms +step:1580/1670 train_time:140455ms step_avg:88.90ms +step:1581/1670 train_time:140545ms step_avg:88.90ms +step:1582/1670 train_time:140636ms step_avg:88.90ms +step:1583/1670 train_time:140726ms step_avg:88.90ms +step:1584/1670 train_time:140816ms step_avg:88.90ms +step:1585/1670 train_time:140905ms step_avg:88.90ms +step:1586/1670 train_time:140995ms step_avg:88.90ms +step:1587/1670 train_time:141085ms step_avg:88.90ms +step:1588/1670 train_time:141175ms step_avg:88.90ms +step:1589/1670 train_time:141264ms step_avg:88.90ms +step:1590/1670 train_time:141354ms step_avg:88.90ms +step:1591/1670 train_time:141444ms step_avg:88.90ms +step:1592/1670 train_time:141535ms step_avg:88.90ms +step:1593/1670 train_time:141625ms step_avg:88.90ms +step:1594/1670 train_time:141715ms step_avg:88.91ms +step:1595/1670 train_time:141805ms step_avg:88.91ms +step:1596/1670 train_time:141894ms step_avg:88.91ms +step:1597/1670 train_time:141984ms step_avg:88.91ms +step:1598/1670 train_time:142074ms step_avg:88.91ms +step:1599/1670 train_time:142165ms step_avg:88.91ms +step:1600/1670 train_time:142255ms step_avg:88.91ms +step:1601/1670 train_time:142345ms step_avg:88.91ms +step:1602/1670 train_time:142434ms step_avg:88.91ms +step:1603/1670 train_time:142525ms step_avg:88.91ms +step:1604/1670 train_time:142616ms step_avg:88.91ms +step:1605/1670 train_time:142705ms step_avg:88.91ms +step:1606/1670 train_time:142795ms step_avg:88.91ms +step:1607/1670 train_time:142885ms step_avg:88.91ms +step:1608/1670 train_time:142975ms step_avg:88.91ms +step:1609/1670 train_time:143064ms step_avg:88.91ms +step:1610/1670 train_time:143154ms step_avg:88.92ms +step:1611/1670 train_time:143244ms step_avg:88.92ms +step:1612/1670 train_time:143333ms step_avg:88.92ms +step:1613/1670 train_time:143424ms step_avg:88.92ms +step:1614/1670 train_time:143514ms step_avg:88.92ms +step:1615/1670 train_time:143604ms step_avg:88.92ms +step:1616/1670 train_time:143694ms step_avg:88.92ms +step:1617/1670 train_time:143783ms step_avg:88.92ms +step:1618/1670 train_time:143873ms step_avg:88.92ms +step:1619/1670 train_time:143963ms step_avg:88.92ms +step:1620/1670 train_time:144053ms step_avg:88.92ms +step:1621/1670 train_time:144144ms step_avg:88.92ms +step:1622/1670 train_time:144234ms step_avg:88.92ms +step:1623/1670 train_time:144324ms step_avg:88.92ms +step:1624/1670 train_time:144415ms step_avg:88.93ms +step:1625/1670 train_time:144505ms step_avg:88.93ms +step:1625/1670 val_loss:3.2869 train_time:144596ms step_avg:88.98ms +step:1626/1670 train_time:144615ms step_avg:88.94ms +step:1627/1670 train_time:144690ms step_avg:88.93ms +step:1628/1670 train_time:144787ms step_avg:88.94ms +step:1629/1670 train_time:144878ms step_avg:88.94ms +step:1630/1670 train_time:144967ms step_avg:88.94ms +step:1631/1670 train_time:145055ms step_avg:88.94ms +step:1632/1670 train_time:145145ms step_avg:88.94ms +step:1633/1670 train_time:145233ms step_avg:88.94ms +step:1634/1670 train_time:145321ms step_avg:88.94ms +step:1635/1670 train_time:145410ms step_avg:88.94ms +step:1636/1670 
train_time:145499ms step_avg:88.94ms +step:1637/1670 train_time:145592ms step_avg:88.94ms +step:1638/1670 train_time:145685ms step_avg:88.94ms +step:1639/1670 train_time:145776ms step_avg:88.94ms +step:1640/1670 train_time:145867ms step_avg:88.94ms +step:1641/1670 train_time:145957ms step_avg:88.94ms +step:1642/1670 train_time:146046ms step_avg:88.94ms +step:1643/1670 train_time:146135ms step_avg:88.94ms +step:1644/1670 train_time:146224ms step_avg:88.94ms +step:1645/1670 train_time:146313ms step_avg:88.94ms +step:1646/1670 train_time:146402ms step_avg:88.94ms +step:1647/1670 train_time:146492ms step_avg:88.94ms +step:1648/1670 train_time:146583ms step_avg:88.95ms +step:1649/1670 train_time:146674ms step_avg:88.95ms +step:1650/1670 train_time:146765ms step_avg:88.95ms +step:1651/1670 train_time:146855ms step_avg:88.95ms +step:1652/1670 train_time:146947ms step_avg:88.95ms +step:1653/1670 train_time:147036ms step_avg:88.95ms +step:1654/1670 train_time:147125ms step_avg:88.95ms +step:1655/1670 train_time:147214ms step_avg:88.95ms +step:1656/1670 train_time:147303ms step_avg:88.95ms +step:1657/1670 train_time:147392ms step_avg:88.95ms +step:1658/1670 train_time:147481ms step_avg:88.95ms +step:1659/1670 train_time:147571ms step_avg:88.95ms +step:1660/1670 train_time:147661ms step_avg:88.95ms +step:1661/1670 train_time:147751ms step_avg:88.95ms +step:1662/1670 train_time:147842ms step_avg:88.95ms +step:1663/1670 train_time:147932ms step_avg:88.95ms +step:1664/1670 train_time:148022ms step_avg:88.96ms +step:1665/1670 train_time:148112ms step_avg:88.96ms +step:1666/1670 train_time:148201ms step_avg:88.96ms +step:1667/1670 train_time:148290ms step_avg:88.96ms +step:1668/1670 train_time:148379ms step_avg:88.96ms +step:1669/1670 train_time:148469ms step_avg:88.96ms +step:1670/1670 train_time:148559ms step_avg:88.96ms +step:1670/1670 val_loss:3.2776 train_time:148653ms step_avg:89.01ms +peak memory allocated: 30760 MiB reserved: 45914 MiB diff --git a/records/092925_PolarExpress/3bb6c2eb-1935-46d5-9f07-40b98223cfaa.txt b/records/092925_PolarExpress/3bb6c2eb-1935-46d5-9f07-40b98223cfaa.txt new file mode 100644 index 000000000..7abff2e73 --- /dev/null +++ b/records/092925_PolarExpress/3bb6c2eb-1935-46d5-9f07-40b98223cfaa.txt @@ -0,0 +1,3252 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from kernels import get_kernel +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 
= x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[ + Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + + +mm_op.register_autograd(backward, setup_context=setup_context) + + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + + +@triton.autotune( + 
configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def XXT_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def XXT(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + XXT_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ba_plus_cAA_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: 
tl.constexpr, +): + # This is mostly duplicated from XXT_kernel, but also loads and adds a block of A + # Performance is slightly slower than XXT_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ba_plus_cAA(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ba_plus_cAA_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +# Computed for num_iters=5, safety_factor=2e-2, cushion=2 +coeffs_list = [ + (8.156554524902461, 
-22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323) +] + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def polar_express(G: torch.Tensor): + """ + Polar Express Sign Method: https://arxiv.org/pdf/2505.16932 + by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. + Code adapted from https://github.com/NoahAmsel/PolarExpress/tree/main by @varunneal. + """ + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) * (1 + 2e-2) + 1e-6) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + aX_plus_BX = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the iterations + for a, b, c in coeffs_list: + XXT(X, out=A) # A = X @ X.mT + ba_plus_cAA(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + aX_plus_BX(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + Though empirically small 1D params perform efficiently here: + NS approximately performs a magnitude normalization of the grad + This hyper-optimized class has faster execution time than the current impl of Adam for small params + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized to enable + (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param 7 padding params) + 2. reduce scatter attn_gate (10 params 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive params of type attn reshape before NS + 8. wait on 4, then compute NS of 4 and schedule all gather + 9. 
wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size() == 8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on less than 8 GPU or experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1, 10, 16, 16] + assert len(params) == sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx + size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. 
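+        # (Illustrative sharding arithmetic following the scheduling in the class docstring;
+        #  the concrete numbers are an example, not values read from this run.)
+        # For a group of 16 params on world_size == 8: padded_num_params == 16 and chunk_size == 2,
+        # so rank r owns params[2*r : 2*r + 2] of the reduce-scattered gradients, e.g. ranks 0-4
+        # each hold two attn params and ranks 5-7 two mlp params. start_idx == rank * chunk_size
+        # below selects exactly that shard.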
+ for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + original_shape = batched_update_grads.shape + # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS indepedently to Q,K,V,O + module_idx = start_idx if start_idx < len(params) else 0 + if getattr(params[module_idx], 'module', 'none') == 'attn': + for p in params[module_idx:module_idx + chunk_size]: + assert getattr(params[module_idx], 'module', 'none') == 'attn' + batch = 4 * original_shape[0] + d1 = original_shape[1] + d2 = original_shape[2] // 4 + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = polar_express(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = polar_express(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. 
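+            # (Recap as an illustrative example, not run-specific numbers: each owned param has
+            #  already been scaled by (1 - eff_weight_decay_val) above, and the loop below adds
+            #  -eff_lr_val * v_chunk[i], where eff_lr_val folds in the max(1, rows/cols) ** 0.5
+            #  shape factor and any per-param lr_mul; for a square weight with the default
+            #  lr == 0.02 and lr_mul == 1.0 this is a plain 0.02 step along the orthogonalized update.)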
+ for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, + weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append( + dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * 
(torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim // 4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim // 4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +flash_attn_interface = get_kernel('varunneal/flash-attention-3').flash_attn_interface + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: 
int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim * 4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module = 'attn' + with torch.no_grad(): + self.qkvo_w.view(4, self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4, self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4, self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, + 3 * self.num_heads, + self.head_dim).chunk( + 3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_interface.flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, + max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4, self.hdim, self.dim)[3].type_as(y)) + return y + + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module = 'mlp' + self.c_proj.module = 'mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu( + x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, + max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
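+        # (Illustrative arithmetic for the padding above: next_multiple_of_n(50257, n=128)
+        #  == 50304 == 393 * 128, the padded vocab size given to lm_head below; 50304 is also
+        #  the d_out cited in the FP8 backward comment near the top of this file.)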
+        use_fp8 = not os.environ.get("DISABLE_FP8", False)
+        self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim ** 0.5) / 448, w_s=2 ** -9,
+                                    grad_s=1 / 448)
+        self.lm_head.weight.detach().zero_()  # @Grad62304977
+        # Add learnable skip connection weights for decoder layers
+        assert num_layers % 2 == 0
+        pad = (-num_layers * 6) % dist.get_world_size()
+        self.scalars = nn.Parameter(
+            torch.cat(
+                [
+                    -1.5
+                    * torch.ones(num_layers),  # skip_weights -> σ(-1.5) ≈ 0.18
+                    *[
+                        torch.tensor([1.0, 0.0]) for _ in range(num_layers)
+                    ],  # block lambdas
+                    *[
+                        torch.tensor([0.5, 0.5]) for _ in range(num_layers)
+                    ],  # SA lambdas
+                    torch.zeros(num_layers),  # extra zeros params for smear_lambda
+                    torch.ones(pad),
+                ]
+            )
+        )
+        # set learning rates
+        for param in self.embed.parameters():
+            param.lr_mul = 75.
+        for param in self.value_embeds.parameters():
+            param.lr_mul = 75.
+        self.lm_head.weight.lr_mul = 1.0
+        self.scalars.lr_mul = 5.0
+
+    def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int):
+        assert input_seq.ndim == 1
+
+        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
+        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        short_bm = ws_short * args.block_size
+        long_bm = ws_long * args.block_size
+        bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm,
+                    long_bm]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = self.embed(input_seq)
+
+        # smear token embed forward 1 position @classiclarryd
+        smear_lambda = self.scalars[5 * len(self.blocks)]
+        smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)]))
+        x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]])
+        x = x0 = norm(x[None])
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        # skip layer zero
+        for i in range(1, len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                cos=self.yarn.cos,
+                sin=self.yarn.sin,
+                attn_scale=self.yarn.attn_scale
+            )
+            if i >= n and i < 11:
+                gate = torch.sigmoid(skip_weights[i - n])  # in (0, 1)
+                x = x + gate * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x)
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        logits_for_loss = logits.float() if not self.training else logits
+        loss = F.cross_entropy(
+            logits_for_loss.view(-1, logits_for_loss.size(-1)),
+            target_seq,
+            reduction="sum" if self.training else "mean",
+        )
+        return loss
+
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32)  # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2])  # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True)  # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy())  # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+
+BOS_ID = 50256
+
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan the first 4 million tokens, then kick off an async thread to scan the rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+
+class DataPreloader:
+    # Helper for asynchronously loading next shard and indexing bos tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1,
+                               align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning of Sequence token, sequences truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files)  # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0  # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128)  # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1  # trim the last document by one token to account for the _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens):  # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local, )
+            _targets = buf[1:].view(num_tokens_local, )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin"  # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin"  # input .bin to eval validation loss on
+    val_tokens: int = 10485760  # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1630  # number of iterations to run
+    iteration_extension = 40  # number of iterations to continue training at final cooldown and window size
+    cooldown_frac: float = 0.5  # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = 125  # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13  # increase final validation ws, used for YaRN extension and short window size @classiclarryd
+    ws_long_validate: int = 20  # extend long windows out even further
+
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0)  # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+
+
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("=" * 100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+
+def nvidia_smi():
+    import subprocess  # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+
+
+print0(nvidia_smi())
+print0("=" * 100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, (nn.Embedding, nn.Linear)):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if
+                        p.ndim >= 2 and "embed" not in n and "gate" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+gate_params = [p for n, p in model.named_parameters() if "gate" in n]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.06, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = min(0.9999, step / args.num_iterations)
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+
+def get_ws(step: int):
+    if step == args.num_iterations + args.iteration_extension:
+        return args.ws_validate // 2, args.ws_validate
+    x = min(step / (1 + args.num_iterations), 0.9999)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx]
+
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers])  # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len,
+                                          grad_accum_steps=grad_accum_steps)
+ws_long = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws_long = args.ws_schedule[
+        step % len(args.ws_schedule)]  # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws_long > ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    elif new_ws_long < ws_long:
+        model.yarn.reset()
+        ws_long = new_ws_long
+    model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+model.yarn.reset()
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del train_loader, initial_state
+
+########################################
+#       Training and validation        #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len,
+                                          grad_accum_steps=grad_accum_steps)
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+# begin training
+train_steps = args.num_iterations + args.iteration_extension
+ws_short, ws_long = get_ws(0)
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    ws_short, new_ws_long = get_ws(step)
+    if new_ws_long != ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        if last_step:
+            ws_long = args.ws_long_validate
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1,
+                                                grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = torch.zeros((), device=device, dtype=torch.float32)
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(
+            f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms",
+            console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(),
+                       optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws_short, ws_long).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1)  # momentum warmup for muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+    print0(
+        f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms",
+        console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+
+====================================================================================================
+Running Python 3.10.12 (main, Feb 4 2025, 14:57:36) [GCC 11.4.0]
+Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6
+Running Triton version 3.5.0
+Mon Sep 29 06:43:39 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 550.127.08             Driver Version: 550.127.08     CUDA Version: 12.6     |
+|-----------------------------------------+------------------------+----------------------+
+| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
+| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
+|                                         |                        |               MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 40C P0 129W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 33C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 37C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 39C P0 128W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 33C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 38C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:158ms step_avg:157.66ms +step:2/1670 train_time:179ms step_avg:89.40ms +step:3/1670 train_time:242ms step_avg:80.66ms +step:4/1670 train_time:327ms step_avg:81.86ms +step:5/1670 train_time:414ms step_avg:82.75ms +step:6/1670 train_time:500ms step_avg:83.40ms +step:7/1670 train_time:587ms step_avg:83.88ms +step:8/1670 train_time:674ms step_avg:84.23ms +step:9/1670 train_time:761ms step_avg:84.53ms +step:10/1670 train_time:847ms step_avg:84.73ms +step:11/1670 train_time:934ms step_avg:84.90ms +step:12/1670 train_time:1025ms step_avg:85.40ms +step:13/1670 train_time:1116ms step_avg:85.85ms +step:14/1670 train_time:1209ms step_avg:86.37ms +step:15/1670 train_time:1297ms step_avg:86.48ms +step:16/1670 train_time:1384ms step_avg:86.52ms +step:17/1670 train_time:1472ms step_avg:86.59ms +step:18/1670 train_time:1559ms step_avg:86.60ms +step:19/1670 train_time:1646ms step_avg:86.65ms +step:20/1670 train_time:1734ms step_avg:86.68ms +step:21/1670 train_time:1821ms step_avg:86.69ms +step:22/1670 train_time:1909ms step_avg:86.76ms +step:23/1670 
train_time:1997ms step_avg:86.81ms +step:24/1670 train_time:2086ms step_avg:86.92ms +step:25/1670 train_time:2177ms step_avg:87.07ms +step:26/1670 train_time:2267ms step_avg:87.19ms +step:27/1670 train_time:2355ms step_avg:87.22ms +step:28/1670 train_time:2442ms step_avg:87.22ms +step:29/1670 train_time:2530ms step_avg:87.23ms +step:30/1670 train_time:2617ms step_avg:87.22ms +step:31/1670 train_time:2704ms step_avg:87.23ms +step:32/1670 train_time:2791ms step_avg:87.22ms +step:33/1670 train_time:2878ms step_avg:87.22ms +step:34/1670 train_time:2966ms step_avg:87.23ms +step:35/1670 train_time:3054ms step_avg:87.26ms +step:36/1670 train_time:3144ms step_avg:87.32ms +step:37/1670 train_time:3232ms step_avg:87.36ms +step:38/1670 train_time:3321ms step_avg:87.41ms +step:39/1670 train_time:3409ms step_avg:87.42ms +step:40/1670 train_time:3497ms step_avg:87.42ms +step:41/1670 train_time:3585ms step_avg:87.44ms +step:42/1670 train_time:3673ms step_avg:87.45ms +step:43/1670 train_time:3760ms step_avg:87.44ms +step:44/1670 train_time:3848ms step_avg:87.44ms +step:45/1670 train_time:3935ms step_avg:87.44ms +step:46/1670 train_time:4022ms step_avg:87.44ms +step:47/1670 train_time:4110ms step_avg:87.46ms +step:48/1670 train_time:4199ms step_avg:87.48ms +step:49/1670 train_time:4288ms step_avg:87.50ms +step:50/1670 train_time:4376ms step_avg:87.52ms +step:51/1670 train_time:4463ms step_avg:87.52ms +step:52/1670 train_time:4551ms step_avg:87.53ms +step:53/1670 train_time:4639ms step_avg:87.52ms +step:54/1670 train_time:4726ms step_avg:87.52ms +step:55/1670 train_time:4813ms step_avg:87.51ms +step:56/1670 train_time:4901ms step_avg:87.52ms +step:57/1670 train_time:4989ms step_avg:87.52ms +step:58/1670 train_time:5076ms step_avg:87.52ms +step:59/1670 train_time:5164ms step_avg:87.53ms +step:60/1670 train_time:5252ms step_avg:87.53ms +step:61/1670 train_time:5340ms step_avg:87.54ms +step:62/1670 train_time:5428ms step_avg:87.55ms +step:63/1670 train_time:5516ms step_avg:87.56ms +step:64/1670 train_time:5603ms step_avg:87.55ms +step:65/1670 train_time:5691ms step_avg:87.55ms +step:66/1670 train_time:5778ms step_avg:87.54ms +step:67/1670 train_time:5866ms step_avg:87.55ms +step:68/1670 train_time:5953ms step_avg:87.54ms +step:69/1670 train_time:6041ms step_avg:87.55ms +step:70/1670 train_time:6130ms step_avg:87.57ms +step:71/1670 train_time:6217ms step_avg:87.57ms +step:72/1670 train_time:6305ms step_avg:87.58ms +step:73/1670 train_time:6393ms step_avg:87.58ms +step:74/1670 train_time:6481ms step_avg:87.58ms +step:75/1670 train_time:6569ms step_avg:87.58ms +step:76/1670 train_time:6656ms step_avg:87.58ms +step:77/1670 train_time:6744ms step_avg:87.58ms +step:78/1670 train_time:6832ms step_avg:87.59ms +step:79/1670 train_time:6920ms step_avg:87.59ms +step:80/1670 train_time:7009ms step_avg:87.61ms +step:81/1670 train_time:7096ms step_avg:87.61ms +step:82/1670 train_time:7184ms step_avg:87.61ms +step:83/1670 train_time:7272ms step_avg:87.62ms +step:84/1670 train_time:7360ms step_avg:87.62ms +step:85/1670 train_time:7448ms step_avg:87.62ms +step:86/1670 train_time:7535ms step_avg:87.61ms +step:87/1670 train_time:7622ms step_avg:87.61ms +step:88/1670 train_time:7710ms step_avg:87.62ms +step:89/1670 train_time:7798ms step_avg:87.62ms +step:90/1670 train_time:7886ms step_avg:87.62ms +step:91/1670 train_time:7973ms step_avg:87.62ms +step:92/1670 train_time:8061ms step_avg:87.62ms +step:93/1670 train_time:8150ms step_avg:87.63ms +step:94/1670 train_time:8238ms step_avg:87.63ms +step:95/1670 train_time:8326ms 
step_avg:87.64ms +step:96/1670 train_time:8414ms step_avg:87.64ms +step:97/1670 train_time:8502ms step_avg:87.65ms +step:98/1670 train_time:8591ms step_avg:87.66ms +step:99/1670 train_time:8678ms step_avg:87.66ms +step:100/1670 train_time:8765ms step_avg:87.65ms +step:101/1670 train_time:8853ms step_avg:87.65ms +step:102/1670 train_time:8941ms step_avg:87.65ms +step:103/1670 train_time:9029ms step_avg:87.66ms +step:104/1670 train_time:9117ms step_avg:87.66ms +step:105/1670 train_time:9205ms step_avg:87.67ms +step:106/1670 train_time:9292ms step_avg:87.66ms +step:107/1670 train_time:9381ms step_avg:87.67ms +step:108/1670 train_time:9469ms step_avg:87.67ms +step:109/1670 train_time:9556ms step_avg:87.67ms +step:110/1670 train_time:9644ms step_avg:87.67ms +step:111/1670 train_time:9731ms step_avg:87.67ms +step:112/1670 train_time:9819ms step_avg:87.67ms +step:113/1670 train_time:9907ms step_avg:87.68ms +step:114/1670 train_time:9994ms step_avg:87.67ms +step:115/1670 train_time:10082ms step_avg:87.67ms +step:116/1670 train_time:10169ms step_avg:87.67ms +step:117/1670 train_time:10257ms step_avg:87.67ms +step:118/1670 train_time:10345ms step_avg:87.67ms +step:119/1670 train_time:10433ms step_avg:87.67ms +step:120/1670 train_time:10522ms step_avg:87.68ms +step:121/1670 train_time:10609ms step_avg:87.68ms +step:122/1670 train_time:10696ms step_avg:87.68ms +step:123/1670 train_time:10785ms step_avg:87.68ms +step:124/1670 train_time:10872ms step_avg:87.68ms +step:125/1670 train_time:10960ms step_avg:87.68ms +step:125/1670 val_loss:4.3301 train_time:11050ms step_avg:88.40ms +step:126/1670 train_time:11070ms step_avg:87.86ms +step:127/1670 train_time:11142ms step_avg:87.74ms +step:128/1670 train_time:11238ms step_avg:87.79ms +step:129/1670 train_time:11327ms step_avg:87.81ms +step:130/1670 train_time:11414ms step_avg:87.80ms +step:131/1670 train_time:11501ms step_avg:87.79ms +step:132/1670 train_time:11587ms step_avg:87.78ms +step:133/1670 train_time:11673ms step_avg:87.77ms +step:134/1670 train_time:11760ms step_avg:87.76ms +step:135/1670 train_time:11846ms step_avg:87.75ms +step:136/1670 train_time:11933ms step_avg:87.74ms +step:137/1670 train_time:12020ms step_avg:87.74ms +step:138/1670 train_time:12109ms step_avg:87.75ms +step:139/1670 train_time:12201ms step_avg:87.78ms +step:140/1670 train_time:12291ms step_avg:87.79ms +step:141/1670 train_time:12381ms step_avg:87.81ms +step:142/1670 train_time:12468ms step_avg:87.80ms +step:143/1670 train_time:12555ms step_avg:87.80ms +step:144/1670 train_time:12642ms step_avg:87.79ms +step:145/1670 train_time:12729ms step_avg:87.79ms +step:146/1670 train_time:12815ms step_avg:87.78ms +step:147/1670 train_time:12903ms step_avg:87.77ms +step:148/1670 train_time:12989ms step_avg:87.76ms +step:149/1670 train_time:13077ms step_avg:87.77ms +step:150/1670 train_time:13166ms step_avg:87.77ms +step:151/1670 train_time:13255ms step_avg:87.78ms +step:152/1670 train_time:13343ms step_avg:87.79ms +step:153/1670 train_time:13431ms step_avg:87.79ms +step:154/1670 train_time:13520ms step_avg:87.79ms +step:155/1670 train_time:13607ms step_avg:87.78ms +step:156/1670 train_time:13694ms step_avg:87.78ms +step:157/1670 train_time:13782ms step_avg:87.78ms +step:158/1670 train_time:13868ms step_avg:87.78ms +step:159/1670 train_time:13956ms step_avg:87.78ms +step:160/1670 train_time:14044ms step_avg:87.77ms +step:161/1670 train_time:14132ms step_avg:87.78ms +step:162/1670 train_time:14221ms step_avg:87.78ms +step:163/1670 train_time:14309ms step_avg:87.78ms +step:164/1670 
train_time:14397ms step_avg:87.79ms +step:165/1670 train_time:14485ms step_avg:87.79ms +step:166/1670 train_time:14573ms step_avg:87.79ms +step:167/1670 train_time:14661ms step_avg:87.79ms +step:168/1670 train_time:14748ms step_avg:87.79ms +step:169/1670 train_time:14836ms step_avg:87.78ms +step:170/1670 train_time:14922ms step_avg:87.78ms +step:171/1670 train_time:15010ms step_avg:87.78ms +step:172/1670 train_time:15098ms step_avg:87.78ms +step:173/1670 train_time:15186ms step_avg:87.78ms +step:174/1670 train_time:15275ms step_avg:87.79ms +step:175/1670 train_time:15364ms step_avg:87.79ms +step:176/1670 train_time:15452ms step_avg:87.80ms +step:177/1670 train_time:15541ms step_avg:87.80ms +step:178/1670 train_time:15629ms step_avg:87.80ms +step:179/1670 train_time:15716ms step_avg:87.80ms +step:180/1670 train_time:15803ms step_avg:87.79ms +step:181/1670 train_time:15890ms step_avg:87.79ms +step:182/1670 train_time:15978ms step_avg:87.79ms +step:183/1670 train_time:16065ms step_avg:87.79ms +step:184/1670 train_time:16154ms step_avg:87.79ms +step:185/1670 train_time:16242ms step_avg:87.79ms +step:186/1670 train_time:16330ms step_avg:87.80ms +step:187/1670 train_time:16418ms step_avg:87.80ms +step:188/1670 train_time:16505ms step_avg:87.79ms +step:189/1670 train_time:16593ms step_avg:87.79ms +step:190/1670 train_time:16681ms step_avg:87.79ms +step:191/1670 train_time:16767ms step_avg:87.79ms +step:192/1670 train_time:16855ms step_avg:87.79ms +step:193/1670 train_time:16943ms step_avg:87.79ms +step:194/1670 train_time:17030ms step_avg:87.78ms +step:195/1670 train_time:17117ms step_avg:87.78ms +step:196/1670 train_time:17205ms step_avg:87.78ms +step:197/1670 train_time:17293ms step_avg:87.78ms +step:198/1670 train_time:17381ms step_avg:87.78ms +step:199/1670 train_time:17469ms step_avg:87.78ms +step:200/1670 train_time:17556ms step_avg:87.78ms +step:201/1670 train_time:17644ms step_avg:87.78ms +step:202/1670 train_time:17731ms step_avg:87.78ms +step:203/1670 train_time:17819ms step_avg:87.78ms +step:204/1670 train_time:17906ms step_avg:87.78ms +step:205/1670 train_time:17994ms step_avg:87.77ms +step:206/1670 train_time:18081ms step_avg:87.77ms +step:207/1670 train_time:18168ms step_avg:87.77ms +step:208/1670 train_time:18255ms step_avg:87.77ms +step:209/1670 train_time:18343ms step_avg:87.77ms +step:210/1670 train_time:18430ms step_avg:87.76ms +step:211/1670 train_time:18518ms step_avg:87.76ms +step:212/1670 train_time:18606ms step_avg:87.76ms +step:213/1670 train_time:18693ms step_avg:87.76ms +step:214/1670 train_time:18781ms step_avg:87.76ms +step:215/1670 train_time:18868ms step_avg:87.76ms +step:216/1670 train_time:18955ms step_avg:87.76ms +step:217/1670 train_time:19043ms step_avg:87.75ms +step:218/1670 train_time:19130ms step_avg:87.75ms +step:219/1670 train_time:19218ms step_avg:87.75ms +step:220/1670 train_time:19306ms step_avg:87.75ms +step:221/1670 train_time:19393ms step_avg:87.75ms +step:222/1670 train_time:19482ms step_avg:87.76ms +step:223/1670 train_time:19569ms step_avg:87.76ms +step:224/1670 train_time:19658ms step_avg:87.76ms +step:225/1670 train_time:19745ms step_avg:87.76ms +step:226/1670 train_time:19832ms step_avg:87.75ms +step:227/1670 train_time:19919ms step_avg:87.75ms +step:228/1670 train_time:20007ms step_avg:87.75ms +step:229/1670 train_time:20094ms step_avg:87.75ms +step:230/1670 train_time:20182ms step_avg:87.75ms +step:231/1670 train_time:20269ms step_avg:87.74ms +step:232/1670 train_time:20357ms step_avg:87.75ms +step:233/1670 train_time:20444ms step_avg:87.74ms 
+step:234/1670 train_time:20532ms step_avg:87.74ms +step:235/1670 train_time:20620ms step_avg:87.75ms +step:236/1670 train_time:20707ms step_avg:87.74ms +step:237/1670 train_time:20795ms step_avg:87.74ms +step:238/1670 train_time:20883ms step_avg:87.74ms +step:239/1670 train_time:20970ms step_avg:87.74ms +step:240/1670 train_time:21058ms step_avg:87.74ms +step:241/1670 train_time:21144ms step_avg:87.74ms +step:242/1670 train_time:21232ms step_avg:87.74ms +step:243/1670 train_time:21320ms step_avg:87.74ms +step:244/1670 train_time:21407ms step_avg:87.73ms +step:245/1670 train_time:21495ms step_avg:87.74ms +step:246/1670 train_time:21582ms step_avg:87.73ms +step:247/1670 train_time:21669ms step_avg:87.73ms +step:248/1670 train_time:21758ms step_avg:87.73ms +step:249/1670 train_time:21845ms step_avg:87.73ms +step:250/1670 train_time:21933ms step_avg:87.73ms +step:250/1670 val_loss:3.9793 train_time:22022ms step_avg:88.09ms +step:251/1670 train_time:22041ms step_avg:87.81ms +step:252/1670 train_time:22113ms step_avg:87.75ms +step:253/1670 train_time:22205ms step_avg:87.77ms +step:254/1670 train_time:22292ms step_avg:87.76ms +step:255/1670 train_time:22378ms step_avg:87.76ms +step:256/1670 train_time:22465ms step_avg:87.75ms +step:257/1670 train_time:22552ms step_avg:87.75ms +step:258/1670 train_time:22638ms step_avg:87.74ms +step:259/1670 train_time:22724ms step_avg:87.74ms +step:260/1670 train_time:22812ms step_avg:87.74ms +step:261/1670 train_time:22898ms step_avg:87.73ms +step:262/1670 train_time:22987ms step_avg:87.74ms +step:263/1670 train_time:23077ms step_avg:87.75ms +step:264/1670 train_time:23167ms step_avg:87.75ms +step:265/1670 train_time:23256ms step_avg:87.76ms +step:266/1670 train_time:23343ms step_avg:87.75ms +step:267/1670 train_time:23429ms step_avg:87.75ms +step:268/1670 train_time:23517ms step_avg:87.75ms +step:269/1670 train_time:23603ms step_avg:87.74ms +step:270/1670 train_time:23689ms step_avg:87.74ms +step:271/1670 train_time:23776ms step_avg:87.73ms +step:272/1670 train_time:23863ms step_avg:87.73ms +step:273/1670 train_time:23950ms step_avg:87.73ms +step:274/1670 train_time:24039ms step_avg:87.73ms +step:275/1670 train_time:24127ms step_avg:87.74ms +step:276/1670 train_time:24216ms step_avg:87.74ms +step:277/1670 train_time:24304ms step_avg:87.74ms +step:278/1670 train_time:24392ms step_avg:87.74ms +step:279/1670 train_time:24479ms step_avg:87.74ms +step:280/1670 train_time:24566ms step_avg:87.73ms +step:281/1670 train_time:24653ms step_avg:87.73ms +step:282/1670 train_time:24740ms step_avg:87.73ms +step:283/1670 train_time:24826ms step_avg:87.72ms +step:284/1670 train_time:24914ms step_avg:87.72ms +step:285/1670 train_time:25001ms step_avg:87.72ms +step:286/1670 train_time:25090ms step_avg:87.73ms +step:287/1670 train_time:25179ms step_avg:87.73ms +step:288/1670 train_time:25267ms step_avg:87.73ms +step:289/1670 train_time:25355ms step_avg:87.74ms +step:290/1670 train_time:25443ms step_avg:87.73ms +step:291/1670 train_time:25530ms step_avg:87.73ms +step:292/1670 train_time:25617ms step_avg:87.73ms +step:293/1670 train_time:25704ms step_avg:87.73ms +step:294/1670 train_time:25791ms step_avg:87.73ms +step:295/1670 train_time:25879ms step_avg:87.72ms +step:296/1670 train_time:25966ms step_avg:87.72ms +step:297/1670 train_time:26055ms step_avg:87.73ms +step:298/1670 train_time:26142ms step_avg:87.72ms +step:299/1670 train_time:26230ms step_avg:87.72ms +step:300/1670 train_time:26318ms step_avg:87.73ms +step:301/1670 train_time:26405ms step_avg:87.73ms +step:302/1670 
train_time:26493ms step_avg:87.73ms +step:303/1670 train_time:26581ms step_avg:87.73ms +step:304/1670 train_time:26668ms step_avg:87.72ms +step:305/1670 train_time:26755ms step_avg:87.72ms +step:306/1670 train_time:26842ms step_avg:87.72ms +step:307/1670 train_time:26929ms step_avg:87.72ms +step:308/1670 train_time:27017ms step_avg:87.72ms +step:309/1670 train_time:27105ms step_avg:87.72ms +step:310/1670 train_time:27194ms step_avg:87.72ms +step:311/1670 train_time:27280ms step_avg:87.72ms +step:312/1670 train_time:27369ms step_avg:87.72ms +step:313/1670 train_time:27456ms step_avg:87.72ms +step:314/1670 train_time:27543ms step_avg:87.72ms +step:315/1670 train_time:27630ms step_avg:87.72ms +step:316/1670 train_time:27718ms step_avg:87.72ms +step:317/1670 train_time:27805ms step_avg:87.71ms +step:318/1670 train_time:27893ms step_avg:87.71ms +step:319/1670 train_time:27981ms step_avg:87.71ms +step:320/1670 train_time:28068ms step_avg:87.71ms +step:321/1670 train_time:28156ms step_avg:87.71ms +step:322/1670 train_time:28244ms step_avg:87.71ms +step:323/1670 train_time:28332ms step_avg:87.72ms +step:324/1670 train_time:28420ms step_avg:87.72ms +step:325/1670 train_time:28508ms step_avg:87.72ms +step:326/1670 train_time:28596ms step_avg:87.72ms +step:327/1670 train_time:28683ms step_avg:87.71ms +step:328/1670 train_time:28770ms step_avg:87.71ms +step:329/1670 train_time:28857ms step_avg:87.71ms +step:330/1670 train_time:28945ms step_avg:87.71ms +step:331/1670 train_time:29032ms step_avg:87.71ms +step:332/1670 train_time:29119ms step_avg:87.71ms +step:333/1670 train_time:29207ms step_avg:87.71ms +step:334/1670 train_time:29295ms step_avg:87.71ms +step:335/1670 train_time:29382ms step_avg:87.71ms +step:336/1670 train_time:29470ms step_avg:87.71ms +step:337/1670 train_time:29558ms step_avg:87.71ms +step:338/1670 train_time:29645ms step_avg:87.71ms +step:339/1670 train_time:29733ms step_avg:87.71ms +step:340/1670 train_time:29820ms step_avg:87.71ms +step:341/1670 train_time:29907ms step_avg:87.71ms +step:342/1670 train_time:29995ms step_avg:87.70ms +step:343/1670 train_time:30082ms step_avg:87.70ms +step:344/1670 train_time:30170ms step_avg:87.70ms +step:345/1670 train_time:30258ms step_avg:87.70ms +step:346/1670 train_time:30344ms step_avg:87.70ms +step:347/1670 train_time:30432ms step_avg:87.70ms +step:348/1670 train_time:30519ms step_avg:87.70ms +step:349/1670 train_time:30607ms step_avg:87.70ms +step:350/1670 train_time:30694ms step_avg:87.70ms +step:351/1670 train_time:30781ms step_avg:87.70ms +step:352/1670 train_time:30869ms step_avg:87.70ms +step:353/1670 train_time:30958ms step_avg:87.70ms +step:354/1670 train_time:31045ms step_avg:87.70ms +step:355/1670 train_time:31133ms step_avg:87.70ms +step:356/1670 train_time:31220ms step_avg:87.70ms +step:357/1670 train_time:31308ms step_avg:87.70ms +step:358/1670 train_time:31396ms step_avg:87.70ms +step:359/1670 train_time:31484ms step_avg:87.70ms +step:360/1670 train_time:31573ms step_avg:87.70ms +step:361/1670 train_time:31660ms step_avg:87.70ms +step:362/1670 train_time:31747ms step_avg:87.70ms +step:363/1670 train_time:31835ms step_avg:87.70ms +step:364/1670 train_time:31922ms step_avg:87.70ms +step:365/1670 train_time:32010ms step_avg:87.70ms +step:366/1670 train_time:32097ms step_avg:87.70ms +step:367/1670 train_time:32185ms step_avg:87.70ms +step:368/1670 train_time:32272ms step_avg:87.70ms +step:369/1670 train_time:32360ms step_avg:87.70ms +step:370/1670 train_time:32447ms step_avg:87.70ms +step:371/1670 train_time:32535ms step_avg:87.70ms 
+step:372/1670 train_time:32623ms step_avg:87.70ms +step:373/1670 train_time:32711ms step_avg:87.70ms +step:374/1670 train_time:32798ms step_avg:87.70ms +step:375/1670 train_time:32885ms step_avg:87.69ms +step:375/1670 val_loss:3.8259 train_time:32975ms step_avg:87.93ms +step:376/1670 train_time:32995ms step_avg:87.75ms +step:377/1670 train_time:33065ms step_avg:87.71ms +step:378/1670 train_time:33155ms step_avg:87.71ms +step:379/1670 train_time:33243ms step_avg:87.71ms +step:380/1670 train_time:33330ms step_avg:87.71ms +step:381/1670 train_time:33417ms step_avg:87.71ms +step:382/1670 train_time:33503ms step_avg:87.70ms +step:383/1670 train_time:33590ms step_avg:87.70ms +step:384/1670 train_time:33677ms step_avg:87.70ms +step:385/1670 train_time:33765ms step_avg:87.70ms +step:386/1670 train_time:33852ms step_avg:87.70ms +step:387/1670 train_time:33940ms step_avg:87.70ms +step:388/1670 train_time:34031ms step_avg:87.71ms +step:389/1670 train_time:34121ms step_avg:87.72ms +step:390/1670 train_time:34209ms step_avg:87.72ms +step:391/1670 train_time:34297ms step_avg:87.72ms +step:392/1670 train_time:34385ms step_avg:87.72ms +step:393/1670 train_time:34471ms step_avg:87.71ms +step:394/1670 train_time:34558ms step_avg:87.71ms +step:395/1670 train_time:34645ms step_avg:87.71ms +step:396/1670 train_time:34732ms step_avg:87.71ms +step:397/1670 train_time:34819ms step_avg:87.71ms +step:398/1670 train_time:34908ms step_avg:87.71ms +step:399/1670 train_time:34997ms step_avg:87.71ms +step:400/1670 train_time:35087ms step_avg:87.72ms +step:401/1670 train_time:35175ms step_avg:87.72ms +step:402/1670 train_time:35263ms step_avg:87.72ms +step:403/1670 train_time:35351ms step_avg:87.72ms +step:404/1670 train_time:35438ms step_avg:87.72ms +step:405/1670 train_time:35525ms step_avg:87.72ms +step:406/1670 train_time:35612ms step_avg:87.72ms +step:407/1670 train_time:35700ms step_avg:87.72ms +step:408/1670 train_time:35787ms step_avg:87.71ms +step:409/1670 train_time:35875ms step_avg:87.71ms +step:410/1670 train_time:35962ms step_avg:87.71ms +step:411/1670 train_time:36050ms step_avg:87.71ms +step:412/1670 train_time:36139ms step_avg:87.72ms +step:413/1670 train_time:36227ms step_avg:87.72ms +step:414/1670 train_time:36315ms step_avg:87.72ms +step:415/1670 train_time:36402ms step_avg:87.72ms +step:416/1670 train_time:36489ms step_avg:87.71ms +step:417/1670 train_time:36576ms step_avg:87.71ms +step:418/1670 train_time:36663ms step_avg:87.71ms +step:419/1670 train_time:36751ms step_avg:87.71ms +step:420/1670 train_time:36839ms step_avg:87.71ms +step:421/1670 train_time:36928ms step_avg:87.71ms +step:422/1670 train_time:37015ms step_avg:87.71ms +step:423/1670 train_time:37104ms step_avg:87.72ms +step:424/1670 train_time:37192ms step_avg:87.72ms +step:425/1670 train_time:37280ms step_avg:87.72ms +step:426/1670 train_time:37369ms step_avg:87.72ms +step:427/1670 train_time:37455ms step_avg:87.72ms +step:428/1670 train_time:37542ms step_avg:87.71ms +step:429/1670 train_time:37629ms step_avg:87.71ms +step:430/1670 train_time:37716ms step_avg:87.71ms +step:431/1670 train_time:37805ms step_avg:87.71ms +step:432/1670 train_time:37892ms step_avg:87.71ms +step:433/1670 train_time:37981ms step_avg:87.72ms +step:434/1670 train_time:38069ms step_avg:87.72ms +step:435/1670 train_time:38157ms step_avg:87.72ms +step:436/1670 train_time:38246ms step_avg:87.72ms +step:437/1670 train_time:38333ms step_avg:87.72ms +step:438/1670 train_time:38421ms step_avg:87.72ms +step:439/1670 train_time:38509ms step_avg:87.72ms +step:440/1670 
train_time:38596ms step_avg:87.72ms +step:441/1670 train_time:38684ms step_avg:87.72ms +step:442/1670 train_time:38771ms step_avg:87.72ms +step:443/1670 train_time:38859ms step_avg:87.72ms +step:444/1670 train_time:38947ms step_avg:87.72ms +step:445/1670 train_time:39035ms step_avg:87.72ms +step:446/1670 train_time:39123ms step_avg:87.72ms +step:447/1670 train_time:39211ms step_avg:87.72ms +step:448/1670 train_time:39298ms step_avg:87.72ms +step:449/1670 train_time:39386ms step_avg:87.72ms +step:450/1670 train_time:39473ms step_avg:87.72ms +step:451/1670 train_time:39560ms step_avg:87.72ms +step:452/1670 train_time:39647ms step_avg:87.72ms +step:453/1670 train_time:39734ms step_avg:87.71ms +step:454/1670 train_time:39821ms step_avg:87.71ms +step:455/1670 train_time:39909ms step_avg:87.71ms +step:456/1670 train_time:39997ms step_avg:87.71ms +step:457/1670 train_time:40085ms step_avg:87.71ms +step:458/1670 train_time:40173ms step_avg:87.71ms +step:459/1670 train_time:40260ms step_avg:87.71ms +step:460/1670 train_time:40348ms step_avg:87.71ms +step:461/1670 train_time:40436ms step_avg:87.71ms +step:462/1670 train_time:40523ms step_avg:87.71ms +step:463/1670 train_time:40611ms step_avg:87.71ms +step:464/1670 train_time:40698ms step_avg:87.71ms +step:465/1670 train_time:40786ms step_avg:87.71ms +step:466/1670 train_time:40874ms step_avg:87.71ms +step:467/1670 train_time:40962ms step_avg:87.71ms +step:468/1670 train_time:41050ms step_avg:87.71ms +step:469/1670 train_time:41138ms step_avg:87.72ms +step:470/1670 train_time:41226ms step_avg:87.72ms +step:471/1670 train_time:41314ms step_avg:87.71ms +step:472/1670 train_time:41401ms step_avg:87.71ms +step:473/1670 train_time:41489ms step_avg:87.71ms +step:474/1670 train_time:41576ms step_avg:87.71ms +step:475/1670 train_time:41664ms step_avg:87.71ms +step:476/1670 train_time:41751ms step_avg:87.71ms +step:477/1670 train_time:41839ms step_avg:87.71ms +step:478/1670 train_time:41927ms step_avg:87.71ms +step:479/1670 train_time:42014ms step_avg:87.71ms +step:480/1670 train_time:42102ms step_avg:87.71ms +step:481/1670 train_time:42190ms step_avg:87.71ms +step:482/1670 train_time:42278ms step_avg:87.71ms +step:483/1670 train_time:42366ms step_avg:87.71ms +step:484/1670 train_time:42454ms step_avg:87.71ms +step:485/1670 train_time:42541ms step_avg:87.71ms +step:486/1670 train_time:42628ms step_avg:87.71ms +step:487/1670 train_time:42716ms step_avg:87.71ms +step:488/1670 train_time:42804ms step_avg:87.71ms +step:489/1670 train_time:42892ms step_avg:87.71ms +step:490/1670 train_time:42979ms step_avg:87.71ms +step:491/1670 train_time:43068ms step_avg:87.71ms +step:492/1670 train_time:43156ms step_avg:87.71ms +step:493/1670 train_time:43244ms step_avg:87.72ms +step:494/1670 train_time:43331ms step_avg:87.72ms +step:495/1670 train_time:43419ms step_avg:87.71ms +step:496/1670 train_time:43507ms step_avg:87.72ms +step:497/1670 train_time:43594ms step_avg:87.71ms +step:498/1670 train_time:43681ms step_avg:87.71ms +step:499/1670 train_time:43769ms step_avg:87.71ms +step:500/1670 train_time:43857ms step_avg:87.71ms +step:500/1670 val_loss:3.7194 train_time:43946ms step_avg:87.89ms +step:501/1670 train_time:43968ms step_avg:87.76ms +step:502/1670 train_time:44038ms step_avg:87.72ms +step:503/1670 train_time:44131ms step_avg:87.74ms +step:504/1670 train_time:44219ms step_avg:87.74ms +step:505/1670 train_time:44307ms step_avg:87.74ms +step:506/1670 train_time:44393ms step_avg:87.73ms +step:507/1670 train_time:44479ms step_avg:87.73ms +step:508/1670 train_time:44566ms 
step_avg:87.73ms +step:509/1670 train_time:44653ms step_avg:87.73ms +step:510/1670 train_time:44739ms step_avg:87.72ms +step:511/1670 train_time:44826ms step_avg:87.72ms +step:512/1670 train_time:44915ms step_avg:87.72ms +step:513/1670 train_time:45004ms step_avg:87.73ms +step:514/1670 train_time:45096ms step_avg:87.73ms +step:515/1670 train_time:45185ms step_avg:87.74ms +step:516/1670 train_time:45273ms step_avg:87.74ms +step:517/1670 train_time:45360ms step_avg:87.74ms +step:518/1670 train_time:45446ms step_avg:87.73ms +step:519/1670 train_time:45534ms step_avg:87.73ms +step:520/1670 train_time:45620ms step_avg:87.73ms +step:521/1670 train_time:45707ms step_avg:87.73ms +step:522/1670 train_time:45794ms step_avg:87.73ms +step:523/1670 train_time:45882ms step_avg:87.73ms +step:524/1670 train_time:45970ms step_avg:87.73ms +step:525/1670 train_time:46059ms step_avg:87.73ms +step:526/1670 train_time:46148ms step_avg:87.73ms +step:527/1670 train_time:46237ms step_avg:87.74ms +step:528/1670 train_time:46325ms step_avg:87.74ms +step:529/1670 train_time:46412ms step_avg:87.74ms +step:530/1670 train_time:46499ms step_avg:87.73ms +step:531/1670 train_time:46586ms step_avg:87.73ms +step:532/1670 train_time:46674ms step_avg:87.73ms +step:533/1670 train_time:46760ms step_avg:87.73ms +step:534/1670 train_time:46849ms step_avg:87.73ms +step:535/1670 train_time:46937ms step_avg:87.73ms +step:536/1670 train_time:47026ms step_avg:87.74ms +step:537/1670 train_time:47116ms step_avg:87.74ms +step:538/1670 train_time:47204ms step_avg:87.74ms +step:539/1670 train_time:47291ms step_avg:87.74ms +step:540/1670 train_time:47378ms step_avg:87.74ms +step:541/1670 train_time:47465ms step_avg:87.74ms +step:542/1670 train_time:47553ms step_avg:87.74ms +step:543/1670 train_time:47640ms step_avg:87.73ms +step:544/1670 train_time:47727ms step_avg:87.73ms +step:545/1670 train_time:47816ms step_avg:87.74ms +step:546/1670 train_time:47905ms step_avg:87.74ms +step:547/1670 train_time:47994ms step_avg:87.74ms +step:548/1670 train_time:48082ms step_avg:87.74ms +step:549/1670 train_time:48172ms step_avg:87.74ms +step:550/1670 train_time:48261ms step_avg:87.75ms +step:551/1670 train_time:48350ms step_avg:87.75ms +step:552/1670 train_time:48439ms step_avg:87.75ms +step:553/1670 train_time:48527ms step_avg:87.75ms +step:554/1670 train_time:48615ms step_avg:87.75ms +step:555/1670 train_time:48704ms step_avg:87.75ms +step:556/1670 train_time:48793ms step_avg:87.76ms +step:557/1670 train_time:48882ms step_avg:87.76ms +step:558/1670 train_time:48971ms step_avg:87.76ms +step:559/1670 train_time:49059ms step_avg:87.76ms +step:560/1670 train_time:49148ms step_avg:87.76ms +step:561/1670 train_time:49238ms step_avg:87.77ms +step:562/1670 train_time:49328ms step_avg:87.77ms +step:563/1670 train_time:49418ms step_avg:87.78ms +step:564/1670 train_time:49507ms step_avg:87.78ms +step:565/1670 train_time:49596ms step_avg:87.78ms +step:566/1670 train_time:49683ms step_avg:87.78ms +step:567/1670 train_time:49772ms step_avg:87.78ms +step:568/1670 train_time:49861ms step_avg:87.78ms +step:569/1670 train_time:49951ms step_avg:87.79ms +step:570/1670 train_time:50039ms step_avg:87.79ms +step:571/1670 train_time:50129ms step_avg:87.79ms +step:572/1670 train_time:50218ms step_avg:87.79ms +step:573/1670 train_time:50308ms step_avg:87.80ms +step:574/1670 train_time:50397ms step_avg:87.80ms +step:575/1670 train_time:50485ms step_avg:87.80ms +step:576/1670 train_time:50574ms step_avg:87.80ms +step:577/1670 train_time:50662ms step_avg:87.80ms +step:578/1670 
train_time:50750ms step_avg:87.80ms +step:579/1670 train_time:50839ms step_avg:87.81ms +step:580/1670 train_time:50928ms step_avg:87.81ms +step:581/1670 train_time:51017ms step_avg:87.81ms +step:582/1670 train_time:51106ms step_avg:87.81ms +step:583/1670 train_time:51196ms step_avg:87.81ms +step:584/1670 train_time:51286ms step_avg:87.82ms +step:585/1670 train_time:51375ms step_avg:87.82ms +step:586/1670 train_time:51464ms step_avg:87.82ms +step:587/1670 train_time:51553ms step_avg:87.82ms +step:588/1670 train_time:51641ms step_avg:87.82ms +step:589/1670 train_time:51730ms step_avg:87.83ms +step:590/1670 train_time:51819ms step_avg:87.83ms +step:591/1670 train_time:51909ms step_avg:87.83ms +step:592/1670 train_time:51998ms step_avg:87.83ms +step:593/1670 train_time:52087ms step_avg:87.84ms +step:594/1670 train_time:52176ms step_avg:87.84ms +step:595/1670 train_time:52265ms step_avg:87.84ms +step:596/1670 train_time:52355ms step_avg:87.84ms +step:597/1670 train_time:52442ms step_avg:87.84ms +step:598/1670 train_time:52531ms step_avg:87.84ms +step:599/1670 train_time:52620ms step_avg:87.85ms +step:600/1670 train_time:52709ms step_avg:87.85ms +step:601/1670 train_time:52797ms step_avg:87.85ms +step:602/1670 train_time:52886ms step_avg:87.85ms +step:603/1670 train_time:52976ms step_avg:87.85ms +step:604/1670 train_time:53065ms step_avg:87.86ms +step:605/1670 train_time:53155ms step_avg:87.86ms +step:606/1670 train_time:53244ms step_avg:87.86ms +step:607/1670 train_time:53334ms step_avg:87.86ms +step:608/1670 train_time:53422ms step_avg:87.86ms +step:609/1670 train_time:53511ms step_avg:87.87ms +step:610/1670 train_time:53600ms step_avg:87.87ms +step:611/1670 train_time:53689ms step_avg:87.87ms +step:612/1670 train_time:53778ms step_avg:87.87ms +step:613/1670 train_time:53866ms step_avg:87.87ms +step:614/1670 train_time:53955ms step_avg:87.87ms +step:615/1670 train_time:54043ms step_avg:87.88ms +step:616/1670 train_time:54132ms step_avg:87.88ms +step:617/1670 train_time:54221ms step_avg:87.88ms +step:618/1670 train_time:54309ms step_avg:87.88ms +step:619/1670 train_time:54398ms step_avg:87.88ms +step:620/1670 train_time:54487ms step_avg:87.88ms +step:621/1670 train_time:54576ms step_avg:87.88ms +step:622/1670 train_time:54664ms step_avg:87.88ms +step:623/1670 train_time:54754ms step_avg:87.89ms +step:624/1670 train_time:54842ms step_avg:87.89ms +step:625/1670 train_time:54931ms step_avg:87.89ms +step:625/1670 val_loss:3.6188 train_time:55020ms step_avg:88.03ms +step:626/1670 train_time:55040ms step_avg:87.92ms +step:627/1670 train_time:55110ms step_avg:87.89ms +step:628/1670 train_time:55200ms step_avg:87.90ms +step:629/1670 train_time:55290ms step_avg:87.90ms +step:630/1670 train_time:55378ms step_avg:87.90ms +step:631/1670 train_time:55465ms step_avg:87.90ms +step:632/1670 train_time:55552ms step_avg:87.90ms +step:633/1670 train_time:55640ms step_avg:87.90ms +step:634/1670 train_time:55728ms step_avg:87.90ms +step:635/1670 train_time:55819ms step_avg:87.90ms +step:636/1670 train_time:55907ms step_avg:87.90ms +step:637/1670 train_time:55999ms step_avg:87.91ms +step:638/1670 train_time:56090ms step_avg:87.92ms +step:639/1670 train_time:56180ms step_avg:87.92ms +step:640/1670 train_time:56270ms step_avg:87.92ms +step:641/1670 train_time:56359ms step_avg:87.92ms +step:642/1670 train_time:56446ms step_avg:87.92ms +step:643/1670 train_time:56535ms step_avg:87.92ms +step:644/1670 train_time:56622ms step_avg:87.92ms +step:645/1670 train_time:56711ms step_avg:87.92ms +step:646/1670 train_time:56800ms 
step_avg:87.93ms +step:647/1670 train_time:56889ms step_avg:87.93ms +step:648/1670 train_time:56979ms step_avg:87.93ms +step:649/1670 train_time:57069ms step_avg:87.93ms +step:650/1670 train_time:57159ms step_avg:87.94ms +step:651/1670 train_time:57249ms step_avg:87.94ms +step:652/1670 train_time:57338ms step_avg:87.94ms +step:653/1670 train_time:57426ms step_avg:87.94ms +step:654/1670 train_time:57515ms step_avg:87.94ms +step:655/1670 train_time:57602ms step_avg:87.94ms +step:656/1670 train_time:57691ms step_avg:87.94ms +step:657/1670 train_time:57778ms step_avg:87.94ms +step:658/1670 train_time:57868ms step_avg:87.94ms +step:659/1670 train_time:57958ms step_avg:87.95ms +step:660/1670 train_time:58048ms step_avg:87.95ms +step:661/1670 train_time:58138ms step_avg:87.96ms +step:662/1670 train_time:58228ms step_avg:87.96ms +step:663/1670 train_time:58317ms step_avg:87.96ms +step:664/1670 train_time:58405ms step_avg:87.96ms +step:665/1670 train_time:58494ms step_avg:87.96ms +step:666/1670 train_time:58582ms step_avg:87.96ms +step:667/1670 train_time:58671ms step_avg:87.96ms +step:668/1670 train_time:58760ms step_avg:87.96ms +step:669/1670 train_time:58848ms step_avg:87.96ms +step:670/1670 train_time:58938ms step_avg:87.97ms +step:671/1670 train_time:59027ms step_avg:87.97ms +step:672/1670 train_time:59118ms step_avg:87.97ms +step:673/1670 train_time:59206ms step_avg:87.97ms +step:674/1670 train_time:59295ms step_avg:87.97ms +step:675/1670 train_time:59383ms step_avg:87.98ms +step:676/1670 train_time:59473ms step_avg:87.98ms +step:677/1670 train_time:59562ms step_avg:87.98ms +step:678/1670 train_time:59650ms step_avg:87.98ms +step:679/1670 train_time:59739ms step_avg:87.98ms +step:680/1670 train_time:59827ms step_avg:87.98ms +step:681/1670 train_time:59916ms step_avg:87.98ms +step:682/1670 train_time:60004ms step_avg:87.98ms +step:683/1670 train_time:60095ms step_avg:87.99ms +step:684/1670 train_time:60182ms step_avg:87.99ms +step:685/1670 train_time:60272ms step_avg:87.99ms +step:686/1670 train_time:60360ms step_avg:87.99ms +step:687/1670 train_time:60449ms step_avg:87.99ms +step:688/1670 train_time:60537ms step_avg:87.99ms +step:689/1670 train_time:60626ms step_avg:87.99ms +step:690/1670 train_time:60714ms step_avg:87.99ms +step:691/1670 train_time:60802ms step_avg:87.99ms +step:692/1670 train_time:60892ms step_avg:87.99ms +step:693/1670 train_time:60981ms step_avg:88.00ms +step:694/1670 train_time:61070ms step_avg:88.00ms +step:695/1670 train_time:61159ms step_avg:88.00ms +step:696/1670 train_time:61248ms step_avg:88.00ms +step:697/1670 train_time:61337ms step_avg:88.00ms +step:698/1670 train_time:61426ms step_avg:88.00ms +step:699/1670 train_time:61515ms step_avg:88.00ms +step:700/1670 train_time:61603ms step_avg:88.00ms +step:701/1670 train_time:61692ms step_avg:88.01ms +step:702/1670 train_time:61780ms step_avg:88.01ms +step:703/1670 train_time:61869ms step_avg:88.01ms +step:704/1670 train_time:61959ms step_avg:88.01ms +step:705/1670 train_time:62048ms step_avg:88.01ms +step:706/1670 train_time:62137ms step_avg:88.01ms +step:707/1670 train_time:62226ms step_avg:88.01ms +step:708/1670 train_time:62315ms step_avg:88.02ms +step:709/1670 train_time:62403ms step_avg:88.02ms +step:710/1670 train_time:62493ms step_avg:88.02ms +step:711/1670 train_time:62582ms step_avg:88.02ms +step:712/1670 train_time:62671ms step_avg:88.02ms +step:713/1670 train_time:62759ms step_avg:88.02ms +step:714/1670 train_time:62849ms step_avg:88.02ms +step:715/1670 train_time:62938ms step_avg:88.02ms +step:716/1670 
train_time:63027ms step_avg:88.03ms +step:717/1670 train_time:63116ms step_avg:88.03ms +step:718/1670 train_time:63204ms step_avg:88.03ms +step:719/1670 train_time:63294ms step_avg:88.03ms +step:720/1670 train_time:63382ms step_avg:88.03ms +step:721/1670 train_time:63471ms step_avg:88.03ms +step:722/1670 train_time:63560ms step_avg:88.03ms +step:723/1670 train_time:63649ms step_avg:88.03ms +step:724/1670 train_time:63737ms step_avg:88.03ms +step:725/1670 train_time:63825ms step_avg:88.04ms +step:726/1670 train_time:63914ms step_avg:88.04ms +step:727/1670 train_time:64003ms step_avg:88.04ms +step:728/1670 train_time:64092ms step_avg:88.04ms +step:729/1670 train_time:64180ms step_avg:88.04ms +step:730/1670 train_time:64270ms step_avg:88.04ms +step:731/1670 train_time:64359ms step_avg:88.04ms +step:732/1670 train_time:64449ms step_avg:88.05ms +step:733/1670 train_time:64538ms step_avg:88.05ms +step:734/1670 train_time:64627ms step_avg:88.05ms +step:735/1670 train_time:64716ms step_avg:88.05ms +step:736/1670 train_time:64804ms step_avg:88.05ms +step:737/1670 train_time:64892ms step_avg:88.05ms +step:738/1670 train_time:64981ms step_avg:88.05ms +step:739/1670 train_time:65070ms step_avg:88.05ms +step:740/1670 train_time:65159ms step_avg:88.05ms +step:741/1670 train_time:65249ms step_avg:88.05ms +step:742/1670 train_time:65338ms step_avg:88.06ms +step:743/1670 train_time:65426ms step_avg:88.06ms +step:744/1670 train_time:65515ms step_avg:88.06ms +step:745/1670 train_time:65603ms step_avg:88.06ms +step:746/1670 train_time:65693ms step_avg:88.06ms +step:747/1670 train_time:65781ms step_avg:88.06ms +step:748/1670 train_time:65870ms step_avg:88.06ms +step:749/1670 train_time:65960ms step_avg:88.06ms +step:750/1670 train_time:66049ms step_avg:88.07ms +step:750/1670 val_loss:3.5666 train_time:66140ms step_avg:88.19ms +step:751/1670 train_time:66161ms step_avg:88.10ms +step:752/1670 train_time:66231ms step_avg:88.07ms +step:753/1670 train_time:66320ms step_avg:88.07ms +step:754/1670 train_time:66410ms step_avg:88.08ms +step:755/1670 train_time:66498ms step_avg:88.08ms +step:756/1670 train_time:66587ms step_avg:88.08ms +step:757/1670 train_time:66676ms step_avg:88.08ms +step:758/1670 train_time:66763ms step_avg:88.08ms +step:759/1670 train_time:66852ms step_avg:88.08ms +step:760/1670 train_time:66940ms step_avg:88.08ms +step:761/1670 train_time:67028ms step_avg:88.08ms +step:762/1670 train_time:67118ms step_avg:88.08ms +step:763/1670 train_time:67207ms step_avg:88.08ms +step:764/1670 train_time:67299ms step_avg:88.09ms +step:765/1670 train_time:67388ms step_avg:88.09ms +step:766/1670 train_time:67476ms step_avg:88.09ms +step:767/1670 train_time:67565ms step_avg:88.09ms +step:768/1670 train_time:67653ms step_avg:88.09ms +step:769/1670 train_time:67741ms step_avg:88.09ms +step:770/1670 train_time:67830ms step_avg:88.09ms +step:771/1670 train_time:67919ms step_avg:88.09ms +step:772/1670 train_time:68008ms step_avg:88.09ms +step:773/1670 train_time:68098ms step_avg:88.10ms +step:774/1670 train_time:68187ms step_avg:88.10ms +step:775/1670 train_time:68278ms step_avg:88.10ms +step:776/1670 train_time:68367ms step_avg:88.10ms +step:777/1670 train_time:68456ms step_avg:88.10ms +step:778/1670 train_time:68544ms step_avg:88.10ms +step:779/1670 train_time:68633ms step_avg:88.10ms +step:780/1670 train_time:68721ms step_avg:88.10ms +step:781/1670 train_time:68811ms step_avg:88.11ms +step:782/1670 train_time:68899ms step_avg:88.11ms +step:783/1670 train_time:68989ms step_avg:88.11ms +step:784/1670 train_time:69078ms 
step_avg:88.11ms +step:785/1670 train_time:69166ms step_avg:88.11ms +step:786/1670 train_time:69255ms step_avg:88.11ms +step:787/1670 train_time:69344ms step_avg:88.11ms +step:788/1670 train_time:69434ms step_avg:88.11ms +step:789/1670 train_time:69522ms step_avg:88.11ms +step:790/1670 train_time:69612ms step_avg:88.12ms +step:791/1670 train_time:69700ms step_avg:88.12ms +step:792/1670 train_time:69789ms step_avg:88.12ms +step:793/1670 train_time:69878ms step_avg:88.12ms +step:794/1670 train_time:69966ms step_avg:88.12ms +step:795/1670 train_time:70054ms step_avg:88.12ms +step:796/1670 train_time:70143ms step_avg:88.12ms +step:797/1670 train_time:70232ms step_avg:88.12ms +step:798/1670 train_time:70321ms step_avg:88.12ms +step:799/1670 train_time:70410ms step_avg:88.12ms +step:800/1670 train_time:70501ms step_avg:88.13ms +step:801/1670 train_time:70589ms step_avg:88.13ms +step:802/1670 train_time:70678ms step_avg:88.13ms +step:803/1670 train_time:70766ms step_avg:88.13ms +step:804/1670 train_time:70856ms step_avg:88.13ms +step:805/1670 train_time:70943ms step_avg:88.13ms +step:806/1670 train_time:71033ms step_avg:88.13ms +step:807/1670 train_time:71122ms step_avg:88.13ms +step:808/1670 train_time:71211ms step_avg:88.13ms +step:809/1670 train_time:71301ms step_avg:88.13ms +step:810/1670 train_time:71390ms step_avg:88.14ms +step:811/1670 train_time:71480ms step_avg:88.14ms +step:812/1670 train_time:71568ms step_avg:88.14ms +step:813/1670 train_time:71657ms step_avg:88.14ms +step:814/1670 train_time:71745ms step_avg:88.14ms +step:815/1670 train_time:71835ms step_avg:88.14ms +step:816/1670 train_time:71923ms step_avg:88.14ms +step:817/1670 train_time:72012ms step_avg:88.14ms +step:818/1670 train_time:72100ms step_avg:88.14ms +step:819/1670 train_time:72190ms step_avg:88.14ms +step:820/1670 train_time:72279ms step_avg:88.14ms +step:821/1670 train_time:72368ms step_avg:88.15ms +step:822/1670 train_time:72457ms step_avg:88.15ms +step:823/1670 train_time:72545ms step_avg:88.15ms +step:824/1670 train_time:72634ms step_avg:88.15ms +step:825/1670 train_time:72722ms step_avg:88.15ms +step:826/1670 train_time:72811ms step_avg:88.15ms +step:827/1670 train_time:72900ms step_avg:88.15ms +step:828/1670 train_time:72989ms step_avg:88.15ms +step:829/1670 train_time:73078ms step_avg:88.15ms +step:830/1670 train_time:73166ms step_avg:88.15ms +step:831/1670 train_time:73255ms step_avg:88.15ms +step:832/1670 train_time:73343ms step_avg:88.15ms +step:833/1670 train_time:73433ms step_avg:88.15ms +step:834/1670 train_time:73521ms step_avg:88.16ms +step:835/1670 train_time:73611ms step_avg:88.16ms +step:836/1670 train_time:73699ms step_avg:88.16ms +step:837/1670 train_time:73789ms step_avg:88.16ms +step:838/1670 train_time:73878ms step_avg:88.16ms +step:839/1670 train_time:73968ms step_avg:88.16ms +step:840/1670 train_time:74056ms step_avg:88.16ms +step:841/1670 train_time:74145ms step_avg:88.16ms +step:842/1670 train_time:74234ms step_avg:88.16ms +step:843/1670 train_time:74322ms step_avg:88.16ms +step:844/1670 train_time:74412ms step_avg:88.17ms +step:845/1670 train_time:74502ms step_avg:88.17ms +step:846/1670 train_time:74590ms step_avg:88.17ms +step:847/1670 train_time:74679ms step_avg:88.17ms +step:848/1670 train_time:74767ms step_avg:88.17ms +step:849/1670 train_time:74856ms step_avg:88.17ms +step:850/1670 train_time:74945ms step_avg:88.17ms +step:851/1670 train_time:75034ms step_avg:88.17ms +step:852/1670 train_time:75123ms step_avg:88.17ms +step:853/1670 train_time:75211ms step_avg:88.17ms +step:854/1670 
train_time:75301ms step_avg:88.17ms +step:855/1670 train_time:75390ms step_avg:88.18ms +step:856/1670 train_time:75479ms step_avg:88.18ms +step:857/1670 train_time:75567ms step_avg:88.18ms +step:858/1670 train_time:75656ms step_avg:88.18ms +step:859/1670 train_time:75744ms step_avg:88.18ms +step:860/1670 train_time:75834ms step_avg:88.18ms +step:861/1670 train_time:75922ms step_avg:88.18ms +step:862/1670 train_time:76011ms step_avg:88.18ms +step:863/1670 train_time:76100ms step_avg:88.18ms +step:864/1670 train_time:76190ms step_avg:88.18ms +step:865/1670 train_time:76279ms step_avg:88.18ms +step:866/1670 train_time:76368ms step_avg:88.19ms +step:867/1670 train_time:76458ms step_avg:88.19ms +step:868/1670 train_time:76547ms step_avg:88.19ms +step:869/1670 train_time:76636ms step_avg:88.19ms +step:870/1670 train_time:76725ms step_avg:88.19ms +step:871/1670 train_time:76814ms step_avg:88.19ms +step:872/1670 train_time:76903ms step_avg:88.19ms +step:873/1670 train_time:76992ms step_avg:88.19ms +step:874/1670 train_time:77081ms step_avg:88.19ms +step:875/1670 train_time:77170ms step_avg:88.19ms +step:875/1670 val_loss:3.5199 train_time:77260ms step_avg:88.30ms +step:876/1670 train_time:77289ms step_avg:88.23ms +step:877/1670 train_time:77355ms step_avg:88.20ms +step:878/1670 train_time:77449ms step_avg:88.21ms +step:879/1670 train_time:77538ms step_avg:88.21ms +step:880/1670 train_time:77627ms step_avg:88.21ms +step:881/1670 train_time:77713ms step_avg:88.21ms +step:882/1670 train_time:77801ms step_avg:88.21ms +step:883/1670 train_time:77889ms step_avg:88.21ms +step:884/1670 train_time:77976ms step_avg:88.21ms +step:885/1670 train_time:78064ms step_avg:88.21ms +step:886/1670 train_time:78153ms step_avg:88.21ms +step:887/1670 train_time:78245ms step_avg:88.21ms +step:888/1670 train_time:78335ms step_avg:88.22ms +step:889/1670 train_time:78428ms step_avg:88.22ms +step:890/1670 train_time:78517ms step_avg:88.22ms +step:891/1670 train_time:78606ms step_avg:88.22ms +step:892/1670 train_time:78694ms step_avg:88.22ms +step:893/1670 train_time:78783ms step_avg:88.22ms +step:894/1670 train_time:78871ms step_avg:88.22ms +step:895/1670 train_time:78959ms step_avg:88.22ms +step:896/1670 train_time:79047ms step_avg:88.22ms +step:897/1670 train_time:79135ms step_avg:88.22ms +step:898/1670 train_time:79226ms step_avg:88.22ms +step:899/1670 train_time:79316ms step_avg:88.23ms +step:900/1670 train_time:79408ms step_avg:88.23ms +step:901/1670 train_time:79498ms step_avg:88.23ms +step:902/1670 train_time:79588ms step_avg:88.23ms +step:903/1670 train_time:79677ms step_avg:88.24ms +step:904/1670 train_time:79766ms step_avg:88.24ms +step:905/1670 train_time:79854ms step_avg:88.24ms +step:906/1670 train_time:79942ms step_avg:88.24ms +step:907/1670 train_time:80030ms step_avg:88.24ms +step:908/1670 train_time:80118ms step_avg:88.24ms +step:909/1670 train_time:80208ms step_avg:88.24ms +step:910/1670 train_time:80298ms step_avg:88.24ms +step:911/1670 train_time:80388ms step_avg:88.24ms +step:912/1670 train_time:80478ms step_avg:88.24ms +step:913/1670 train_time:80567ms step_avg:88.24ms +step:914/1670 train_time:80657ms step_avg:88.25ms +step:915/1670 train_time:80746ms step_avg:88.25ms +step:916/1670 train_time:80835ms step_avg:88.25ms +step:917/1670 train_time:80925ms step_avg:88.25ms +step:918/1670 train_time:81013ms step_avg:88.25ms +step:919/1670 train_time:81102ms step_avg:88.25ms +step:920/1670 train_time:81191ms step_avg:88.25ms +step:921/1670 train_time:81279ms step_avg:88.25ms +step:922/1670 train_time:81369ms 
step_avg:88.25ms +step:923/1670 train_time:81458ms step_avg:88.25ms +step:924/1670 train_time:81548ms step_avg:88.26ms +step:925/1670 train_time:81637ms step_avg:88.26ms +step:926/1670 train_time:81727ms step_avg:88.26ms +step:927/1670 train_time:81816ms step_avg:88.26ms +step:928/1670 train_time:81906ms step_avg:88.26ms +step:929/1670 train_time:81995ms step_avg:88.26ms +step:930/1670 train_time:82084ms step_avg:88.26ms +step:931/1670 train_time:82172ms step_avg:88.26ms +step:932/1670 train_time:82261ms step_avg:88.26ms +step:933/1670 train_time:82349ms step_avg:88.26ms +step:934/1670 train_time:82439ms step_avg:88.26ms +step:935/1670 train_time:82529ms step_avg:88.27ms +step:936/1670 train_time:82619ms step_avg:88.27ms +step:937/1670 train_time:82707ms step_avg:88.27ms +step:938/1670 train_time:82796ms step_avg:88.27ms +step:939/1670 train_time:82885ms step_avg:88.27ms +step:940/1670 train_time:82973ms step_avg:88.27ms +step:941/1670 train_time:83062ms step_avg:88.27ms +step:942/1670 train_time:83151ms step_avg:88.27ms +step:943/1670 train_time:83240ms step_avg:88.27ms +step:944/1670 train_time:83328ms step_avg:88.27ms +step:945/1670 train_time:83417ms step_avg:88.27ms +step:946/1670 train_time:83507ms step_avg:88.27ms +step:947/1670 train_time:83597ms step_avg:88.28ms +step:948/1670 train_time:83686ms step_avg:88.28ms +step:949/1670 train_time:83774ms step_avg:88.28ms +step:950/1670 train_time:83864ms step_avg:88.28ms +step:951/1670 train_time:83951ms step_avg:88.28ms +step:952/1670 train_time:84040ms step_avg:88.28ms +step:953/1670 train_time:84129ms step_avg:88.28ms +step:954/1670 train_time:84218ms step_avg:88.28ms +step:955/1670 train_time:84307ms step_avg:88.28ms +step:956/1670 train_time:84395ms step_avg:88.28ms +step:957/1670 train_time:84485ms step_avg:88.28ms +step:958/1670 train_time:84573ms step_avg:88.28ms +step:959/1670 train_time:84662ms step_avg:88.28ms +step:960/1670 train_time:84751ms step_avg:88.28ms +step:961/1670 train_time:84841ms step_avg:88.28ms +step:962/1670 train_time:84929ms step_avg:88.28ms +step:963/1670 train_time:85018ms step_avg:88.28ms +step:964/1670 train_time:85107ms step_avg:88.29ms +step:965/1670 train_time:85195ms step_avg:88.29ms +step:966/1670 train_time:85285ms step_avg:88.29ms +step:967/1670 train_time:85374ms step_avg:88.29ms +step:968/1670 train_time:85463ms step_avg:88.29ms +step:969/1670 train_time:85552ms step_avg:88.29ms +step:970/1670 train_time:85641ms step_avg:88.29ms +step:971/1670 train_time:85729ms step_avg:88.29ms +step:972/1670 train_time:85818ms step_avg:88.29ms +step:973/1670 train_time:85907ms step_avg:88.29ms +step:974/1670 train_time:85996ms step_avg:88.29ms +step:975/1670 train_time:86085ms step_avg:88.29ms +step:976/1670 train_time:86173ms step_avg:88.29ms +step:977/1670 train_time:86262ms step_avg:88.29ms +step:978/1670 train_time:86351ms step_avg:88.29ms +step:979/1670 train_time:86440ms step_avg:88.29ms +step:980/1670 train_time:86529ms step_avg:88.29ms +step:981/1670 train_time:86618ms step_avg:88.30ms +step:982/1670 train_time:86708ms step_avg:88.30ms +step:983/1670 train_time:86798ms step_avg:88.30ms +step:984/1670 train_time:86887ms step_avg:88.30ms +step:985/1670 train_time:86975ms step_avg:88.30ms +step:986/1670 train_time:87065ms step_avg:88.30ms +step:987/1670 train_time:87152ms step_avg:88.30ms +step:988/1670 train_time:87241ms step_avg:88.30ms +step:989/1670 train_time:87331ms step_avg:88.30ms +step:990/1670 train_time:87419ms step_avg:88.30ms +step:991/1670 train_time:87509ms step_avg:88.30ms +step:992/1670 
train_time:87599ms step_avg:88.31ms +step:993/1670 train_time:87688ms step_avg:88.31ms +step:994/1670 train_time:87777ms step_avg:88.31ms +step:995/1670 train_time:87867ms step_avg:88.31ms +step:996/1670 train_time:87955ms step_avg:88.31ms +step:997/1670 train_time:88045ms step_avg:88.31ms +step:998/1670 train_time:88133ms step_avg:88.31ms +step:999/1670 train_time:88222ms step_avg:88.31ms +step:1000/1670 train_time:88311ms step_avg:88.31ms +step:1000/1670 val_loss:3.4692 train_time:88401ms step_avg:88.40ms +step:1001/1670 train_time:88424ms step_avg:88.34ms +step:1002/1670 train_time:88493ms step_avg:88.32ms +step:1003/1670 train_time:88585ms step_avg:88.32ms +step:1004/1670 train_time:88673ms step_avg:88.32ms +step:1005/1670 train_time:88762ms step_avg:88.32ms +step:1006/1670 train_time:88849ms step_avg:88.32ms +step:1007/1670 train_time:88937ms step_avg:88.32ms +step:1008/1670 train_time:89024ms step_avg:88.32ms +step:1009/1670 train_time:89112ms step_avg:88.32ms +step:1010/1670 train_time:89201ms step_avg:88.32ms +step:1011/1670 train_time:89289ms step_avg:88.32ms +step:1012/1670 train_time:89379ms step_avg:88.32ms +step:1013/1670 train_time:89472ms step_avg:88.32ms +step:1014/1670 train_time:89564ms step_avg:88.33ms +step:1015/1670 train_time:89653ms step_avg:88.33ms +step:1016/1670 train_time:89744ms step_avg:88.33ms +step:1017/1670 train_time:89832ms step_avg:88.33ms +step:1018/1670 train_time:89920ms step_avg:88.33ms +step:1019/1670 train_time:90009ms step_avg:88.33ms +step:1020/1670 train_time:90096ms step_avg:88.33ms +step:1021/1670 train_time:90185ms step_avg:88.33ms +step:1022/1670 train_time:90272ms step_avg:88.33ms +step:1023/1670 train_time:90362ms step_avg:88.33ms +step:1024/1670 train_time:90452ms step_avg:88.33ms +step:1025/1670 train_time:90543ms step_avg:88.33ms +step:1026/1670 train_time:90633ms step_avg:88.34ms +step:1027/1670 train_time:90722ms step_avg:88.34ms +step:1028/1670 train_time:90811ms step_avg:88.34ms +step:1029/1670 train_time:90899ms step_avg:88.34ms +step:1030/1670 train_time:90988ms step_avg:88.34ms +step:1031/1670 train_time:91076ms step_avg:88.34ms +step:1032/1670 train_time:91166ms step_avg:88.34ms +step:1033/1670 train_time:91254ms step_avg:88.34ms +step:1034/1670 train_time:91343ms step_avg:88.34ms +step:1035/1670 train_time:91432ms step_avg:88.34ms +step:1036/1670 train_time:91522ms step_avg:88.34ms +step:1037/1670 train_time:91613ms step_avg:88.34ms +step:1038/1670 train_time:91702ms step_avg:88.34ms +step:1039/1670 train_time:91791ms step_avg:88.35ms +step:1040/1670 train_time:91879ms step_avg:88.35ms +step:1041/1670 train_time:91968ms step_avg:88.35ms +step:1042/1670 train_time:92056ms step_avg:88.35ms +step:1043/1670 train_time:92145ms step_avg:88.35ms +step:1044/1670 train_time:92234ms step_avg:88.35ms +step:1045/1670 train_time:92322ms step_avg:88.35ms +step:1046/1670 train_time:92411ms step_avg:88.35ms +step:1047/1670 train_time:92501ms step_avg:88.35ms +step:1048/1670 train_time:92591ms step_avg:88.35ms +step:1049/1670 train_time:92681ms step_avg:88.35ms +step:1050/1670 train_time:92770ms step_avg:88.35ms +step:1051/1670 train_time:92859ms step_avg:88.35ms +step:1052/1670 train_time:92947ms step_avg:88.35ms +step:1053/1670 train_time:93035ms step_avg:88.35ms +step:1054/1670 train_time:93123ms step_avg:88.35ms +step:1055/1670 train_time:93212ms step_avg:88.35ms +step:1056/1670 train_time:93301ms step_avg:88.35ms +step:1057/1670 train_time:93391ms step_avg:88.35ms +step:1058/1670 train_time:93480ms step_avg:88.36ms +step:1059/1670 
train_time:93571ms step_avg:88.36ms +step:1060/1670 train_time:93660ms step_avg:88.36ms +step:1061/1670 train_time:93749ms step_avg:88.36ms +step:1062/1670 train_time:93837ms step_avg:88.36ms +step:1063/1670 train_time:93927ms step_avg:88.36ms +step:1064/1670 train_time:94015ms step_avg:88.36ms +step:1065/1670 train_time:94103ms step_avg:88.36ms +step:1066/1670 train_time:94192ms step_avg:88.36ms +step:1067/1670 train_time:94280ms step_avg:88.36ms +step:1068/1670 train_time:94369ms step_avg:88.36ms +step:1069/1670 train_time:94459ms step_avg:88.36ms +step:1070/1670 train_time:94549ms step_avg:88.36ms +step:1071/1670 train_time:94638ms step_avg:88.36ms +step:1072/1670 train_time:94727ms step_avg:88.37ms +step:1073/1670 train_time:94816ms step_avg:88.37ms +step:1074/1670 train_time:94906ms step_avg:88.37ms +step:1075/1670 train_time:94994ms step_avg:88.37ms +step:1076/1670 train_time:95083ms step_avg:88.37ms +step:1077/1670 train_time:95172ms step_avg:88.37ms +step:1078/1670 train_time:95261ms step_avg:88.37ms +step:1079/1670 train_time:95350ms step_avg:88.37ms +step:1080/1670 train_time:95439ms step_avg:88.37ms +step:1081/1670 train_time:95528ms step_avg:88.37ms +step:1082/1670 train_time:95617ms step_avg:88.37ms +step:1083/1670 train_time:95706ms step_avg:88.37ms +step:1084/1670 train_time:95795ms step_avg:88.37ms +step:1085/1670 train_time:95883ms step_avg:88.37ms +step:1086/1670 train_time:95971ms step_avg:88.37ms +step:1087/1670 train_time:96060ms step_avg:88.37ms +step:1088/1670 train_time:96150ms step_avg:88.37ms +step:1089/1670 train_time:96238ms step_avg:88.37ms +step:1090/1670 train_time:96328ms step_avg:88.37ms +step:1091/1670 train_time:96418ms step_avg:88.38ms +step:1092/1670 train_time:96508ms step_avg:88.38ms +step:1093/1670 train_time:96597ms step_avg:88.38ms +step:1094/1670 train_time:96687ms step_avg:88.38ms +step:1095/1670 train_time:96776ms step_avg:88.38ms +step:1096/1670 train_time:96867ms step_avg:88.38ms +step:1097/1670 train_time:96955ms step_avg:88.38ms +step:1098/1670 train_time:97045ms step_avg:88.38ms +step:1099/1670 train_time:97134ms step_avg:88.38ms +step:1100/1670 train_time:97224ms step_avg:88.39ms +step:1101/1670 train_time:97313ms step_avg:88.39ms +step:1102/1670 train_time:97403ms step_avg:88.39ms +step:1103/1670 train_time:97491ms step_avg:88.39ms +step:1104/1670 train_time:97581ms step_avg:88.39ms +step:1105/1670 train_time:97671ms step_avg:88.39ms +step:1106/1670 train_time:97762ms step_avg:88.39ms +step:1107/1670 train_time:97851ms step_avg:88.39ms +step:1108/1670 train_time:97940ms step_avg:88.39ms +step:1109/1670 train_time:98030ms step_avg:88.40ms +step:1110/1670 train_time:98120ms step_avg:88.40ms +step:1111/1670 train_time:98210ms step_avg:88.40ms +step:1112/1670 train_time:98300ms step_avg:88.40ms +step:1113/1670 train_time:98390ms step_avg:88.40ms +step:1114/1670 train_time:98480ms step_avg:88.40ms +step:1115/1670 train_time:98570ms step_avg:88.40ms +step:1116/1670 train_time:98660ms step_avg:88.40ms +step:1117/1670 train_time:98750ms step_avg:88.41ms +step:1118/1670 train_time:98840ms step_avg:88.41ms +step:1119/1670 train_time:98930ms step_avg:88.41ms +step:1120/1670 train_time:99020ms step_avg:88.41ms +step:1121/1670 train_time:99109ms step_avg:88.41ms +step:1122/1670 train_time:99199ms step_avg:88.41ms +step:1123/1670 train_time:99290ms step_avg:88.41ms +step:1124/1670 train_time:99380ms step_avg:88.42ms +step:1125/1670 train_time:99470ms step_avg:88.42ms +step:1125/1670 val_loss:3.4150 train_time:99561ms step_avg:88.50ms +step:1126/1670 
train_time:99585ms step_avg:88.44ms +step:1127/1670 train_time:99652ms step_avg:88.42ms +step:1128/1670 train_time:99742ms step_avg:88.42ms +step:1129/1670 train_time:99833ms step_avg:88.43ms +step:1130/1670 train_time:99922ms step_avg:88.43ms +step:1131/1670 train_time:100010ms step_avg:88.43ms +step:1132/1670 train_time:100099ms step_avg:88.43ms +step:1133/1670 train_time:100188ms step_avg:88.43ms +step:1134/1670 train_time:100276ms step_avg:88.43ms +step:1135/1670 train_time:100366ms step_avg:88.43ms +step:1136/1670 train_time:100457ms step_avg:88.43ms +step:1137/1670 train_time:100551ms step_avg:88.44ms +step:1138/1670 train_time:100641ms step_avg:88.44ms +step:1139/1670 train_time:100733ms step_avg:88.44ms +step:1140/1670 train_time:100823ms step_avg:88.44ms +step:1141/1670 train_time:100912ms step_avg:88.44ms +step:1142/1670 train_time:101001ms step_avg:88.44ms +step:1143/1670 train_time:101090ms step_avg:88.44ms +step:1144/1670 train_time:101178ms step_avg:88.44ms +step:1145/1670 train_time:101267ms step_avg:88.44ms +step:1146/1670 train_time:101355ms step_avg:88.44ms +step:1147/1670 train_time:101447ms step_avg:88.45ms +step:1148/1670 train_time:101538ms step_avg:88.45ms +step:1149/1670 train_time:101629ms step_avg:88.45ms +step:1150/1670 train_time:101721ms step_avg:88.45ms +step:1151/1670 train_time:101812ms step_avg:88.46ms +step:1152/1670 train_time:101902ms step_avg:88.46ms +step:1153/1670 train_time:101991ms step_avg:88.46ms +step:1154/1670 train_time:102080ms step_avg:88.46ms +step:1155/1670 train_time:102169ms step_avg:88.46ms +step:1156/1670 train_time:102258ms step_avg:88.46ms +step:1157/1670 train_time:102347ms step_avg:88.46ms +step:1158/1670 train_time:102437ms step_avg:88.46ms +step:1159/1670 train_time:102528ms step_avg:88.46ms +step:1160/1670 train_time:102617ms step_avg:88.46ms +step:1161/1670 train_time:102708ms step_avg:88.47ms +step:1162/1670 train_time:102798ms step_avg:88.47ms +step:1163/1670 train_time:102889ms step_avg:88.47ms +step:1164/1670 train_time:102978ms step_avg:88.47ms +step:1165/1670 train_time:103067ms step_avg:88.47ms +step:1166/1670 train_time:103156ms step_avg:88.47ms +step:1167/1670 train_time:103244ms step_avg:88.47ms +step:1168/1670 train_time:103333ms step_avg:88.47ms +step:1169/1670 train_time:103423ms step_avg:88.47ms +step:1170/1670 train_time:103513ms step_avg:88.47ms +step:1171/1670 train_time:103603ms step_avg:88.47ms +step:1172/1670 train_time:103693ms step_avg:88.48ms +step:1173/1670 train_time:103784ms step_avg:88.48ms +step:1174/1670 train_time:103873ms step_avg:88.48ms +step:1175/1670 train_time:103963ms step_avg:88.48ms +step:1176/1670 train_time:104053ms step_avg:88.48ms +step:1177/1670 train_time:104142ms step_avg:88.48ms +step:1178/1670 train_time:104231ms step_avg:88.48ms +step:1179/1670 train_time:104320ms step_avg:88.48ms +step:1180/1670 train_time:104410ms step_avg:88.48ms +step:1181/1670 train_time:104500ms step_avg:88.48ms +step:1182/1670 train_time:104590ms step_avg:88.49ms +step:1183/1670 train_time:104680ms step_avg:88.49ms +step:1184/1670 train_time:104771ms step_avg:88.49ms +step:1185/1670 train_time:104862ms step_avg:88.49ms +step:1186/1670 train_time:104951ms step_avg:88.49ms +step:1187/1670 train_time:105040ms step_avg:88.49ms +step:1188/1670 train_time:105131ms step_avg:88.49ms +step:1189/1670 train_time:105220ms step_avg:88.49ms +step:1190/1670 train_time:105309ms step_avg:88.50ms +step:1191/1670 train_time:105399ms step_avg:88.50ms +step:1192/1670 train_time:105490ms step_avg:88.50ms +step:1193/1670 
train_time:105581ms step_avg:88.50ms +step:1194/1670 train_time:105671ms step_avg:88.50ms +step:1195/1670 train_time:105761ms step_avg:88.50ms +step:1196/1670 train_time:105851ms step_avg:88.50ms +step:1197/1670 train_time:105940ms step_avg:88.50ms +step:1198/1670 train_time:106030ms step_avg:88.51ms +step:1199/1670 train_time:106119ms step_avg:88.51ms +step:1200/1670 train_time:106209ms step_avg:88.51ms +step:1201/1670 train_time:106298ms step_avg:88.51ms +step:1202/1670 train_time:106389ms step_avg:88.51ms +step:1203/1670 train_time:106478ms step_avg:88.51ms +step:1204/1670 train_time:106567ms step_avg:88.51ms +step:1205/1670 train_time:106657ms step_avg:88.51ms +step:1206/1670 train_time:106747ms step_avg:88.51ms +step:1207/1670 train_time:106837ms step_avg:88.51ms +step:1208/1670 train_time:106929ms step_avg:88.52ms +step:1209/1670 train_time:107019ms step_avg:88.52ms +step:1210/1670 train_time:107108ms step_avg:88.52ms +step:1211/1670 train_time:107198ms step_avg:88.52ms +step:1212/1670 train_time:107288ms step_avg:88.52ms +step:1213/1670 train_time:107377ms step_avg:88.52ms +step:1214/1670 train_time:107467ms step_avg:88.52ms +step:1215/1670 train_time:107556ms step_avg:88.52ms +step:1216/1670 train_time:107646ms step_avg:88.52ms +step:1217/1670 train_time:107735ms step_avg:88.53ms +step:1218/1670 train_time:107825ms step_avg:88.53ms +step:1219/1670 train_time:107915ms step_avg:88.53ms +step:1220/1670 train_time:108004ms step_avg:88.53ms +step:1221/1670 train_time:108093ms step_avg:88.53ms +step:1222/1670 train_time:108183ms step_avg:88.53ms +step:1223/1670 train_time:108272ms step_avg:88.53ms +step:1224/1670 train_time:108362ms step_avg:88.53ms +step:1225/1670 train_time:108452ms step_avg:88.53ms +step:1226/1670 train_time:108541ms step_avg:88.53ms +step:1227/1670 train_time:108630ms step_avg:88.53ms +step:1228/1670 train_time:108720ms step_avg:88.53ms +step:1229/1670 train_time:108811ms step_avg:88.54ms +step:1230/1670 train_time:108901ms step_avg:88.54ms +step:1231/1670 train_time:108990ms step_avg:88.54ms +step:1232/1670 train_time:109080ms step_avg:88.54ms +step:1233/1670 train_time:109170ms step_avg:88.54ms +step:1234/1670 train_time:109260ms step_avg:88.54ms +step:1235/1670 train_time:109350ms step_avg:88.54ms +step:1236/1670 train_time:109439ms step_avg:88.54ms +step:1237/1670 train_time:109530ms step_avg:88.54ms +step:1238/1670 train_time:109619ms step_avg:88.55ms +step:1239/1670 train_time:109709ms step_avg:88.55ms +step:1240/1670 train_time:109798ms step_avg:88.55ms +step:1241/1670 train_time:109889ms step_avg:88.55ms +step:1242/1670 train_time:109978ms step_avg:88.55ms +step:1243/1670 train_time:110068ms step_avg:88.55ms +step:1244/1670 train_time:110157ms step_avg:88.55ms +step:1245/1670 train_time:110247ms step_avg:88.55ms +step:1246/1670 train_time:110336ms step_avg:88.55ms +step:1247/1670 train_time:110426ms step_avg:88.55ms +step:1248/1670 train_time:110515ms step_avg:88.55ms +step:1249/1670 train_time:110604ms step_avg:88.55ms +step:1250/1670 train_time:110694ms step_avg:88.56ms +step:1250/1670 val_loss:3.3760 train_time:110785ms step_avg:88.63ms +step:1251/1670 train_time:110805ms step_avg:88.57ms +step:1252/1670 train_time:110880ms step_avg:88.56ms +step:1253/1670 train_time:110970ms step_avg:88.56ms +step:1254/1670 train_time:111062ms step_avg:88.57ms +step:1255/1670 train_time:111152ms step_avg:88.57ms +step:1256/1670 train_time:111240ms step_avg:88.57ms +step:1257/1670 train_time:111329ms step_avg:88.57ms +step:1258/1670 train_time:111420ms step_avg:88.57ms 
+step:1259/1670 train_time:111509ms step_avg:88.57ms +step:1260/1670 train_time:111599ms step_avg:88.57ms +step:1261/1670 train_time:111688ms step_avg:88.57ms +step:1262/1670 train_time:111779ms step_avg:88.57ms +step:1263/1670 train_time:111870ms step_avg:88.57ms +step:1264/1670 train_time:111961ms step_avg:88.58ms +step:1265/1670 train_time:112052ms step_avg:88.58ms +step:1266/1670 train_time:112142ms step_avg:88.58ms +step:1267/1670 train_time:112232ms step_avg:88.58ms +step:1268/1670 train_time:112321ms step_avg:88.58ms +step:1269/1670 train_time:112410ms step_avg:88.58ms +step:1270/1670 train_time:112499ms step_avg:88.58ms +step:1271/1670 train_time:112588ms step_avg:88.58ms +step:1272/1670 train_time:112678ms step_avg:88.58ms +step:1273/1670 train_time:112768ms step_avg:88.58ms +step:1274/1670 train_time:112860ms step_avg:88.59ms +step:1275/1670 train_time:112951ms step_avg:88.59ms +step:1276/1670 train_time:113041ms step_avg:88.59ms +step:1277/1670 train_time:113132ms step_avg:88.59ms +step:1278/1670 train_time:113222ms step_avg:88.59ms +step:1279/1670 train_time:113312ms step_avg:88.59ms +step:1280/1670 train_time:113402ms step_avg:88.59ms +step:1281/1670 train_time:113491ms step_avg:88.60ms +step:1282/1670 train_time:113581ms step_avg:88.60ms +step:1283/1670 train_time:113671ms step_avg:88.60ms +step:1284/1670 train_time:113760ms step_avg:88.60ms +step:1285/1670 train_time:113851ms step_avg:88.60ms +step:1286/1670 train_time:113940ms step_avg:88.60ms +step:1287/1670 train_time:114031ms step_avg:88.60ms +step:1288/1670 train_time:114121ms step_avg:88.60ms +step:1289/1670 train_time:114212ms step_avg:88.61ms +step:1290/1670 train_time:114301ms step_avg:88.61ms +step:1291/1670 train_time:114391ms step_avg:88.61ms +step:1292/1670 train_time:114480ms step_avg:88.61ms +step:1293/1670 train_time:114570ms step_avg:88.61ms +step:1294/1670 train_time:114660ms step_avg:88.61ms +step:1295/1670 train_time:114750ms step_avg:88.61ms +step:1296/1670 train_time:114841ms step_avg:88.61ms +step:1297/1670 train_time:114931ms step_avg:88.61ms +step:1298/1670 train_time:115021ms step_avg:88.61ms +step:1299/1670 train_time:115111ms step_avg:88.61ms +step:1300/1670 train_time:115201ms step_avg:88.62ms +step:1301/1670 train_time:115292ms step_avg:88.62ms +step:1302/1670 train_time:115381ms step_avg:88.62ms +step:1303/1670 train_time:115470ms step_avg:88.62ms +step:1304/1670 train_time:115559ms step_avg:88.62ms +step:1305/1670 train_time:115649ms step_avg:88.62ms +step:1306/1670 train_time:115739ms step_avg:88.62ms +step:1307/1670 train_time:115829ms step_avg:88.62ms +step:1308/1670 train_time:115920ms step_avg:88.62ms +step:1309/1670 train_time:116009ms step_avg:88.62ms +step:1310/1670 train_time:116100ms step_avg:88.63ms +step:1311/1670 train_time:116191ms step_avg:88.63ms +step:1312/1670 train_time:116281ms step_avg:88.63ms +step:1313/1670 train_time:116371ms step_avg:88.63ms +step:1314/1670 train_time:116461ms step_avg:88.63ms +step:1315/1670 train_time:116551ms step_avg:88.63ms +step:1316/1670 train_time:116640ms step_avg:88.63ms +step:1317/1670 train_time:116730ms step_avg:88.63ms +step:1318/1670 train_time:116821ms step_avg:88.64ms +step:1319/1670 train_time:116912ms step_avg:88.64ms +step:1320/1670 train_time:117002ms step_avg:88.64ms +step:1321/1670 train_time:117092ms step_avg:88.64ms +step:1322/1670 train_time:117181ms step_avg:88.64ms +step:1323/1670 train_time:117271ms step_avg:88.64ms +step:1324/1670 train_time:117361ms step_avg:88.64ms +step:1325/1670 train_time:117451ms step_avg:88.64ms 
+step:1326/1670 train_time:117541ms step_avg:88.64ms +step:1327/1670 train_time:117631ms step_avg:88.64ms +step:1328/1670 train_time:117721ms step_avg:88.65ms +step:1329/1670 train_time:117811ms step_avg:88.65ms +step:1330/1670 train_time:117901ms step_avg:88.65ms +step:1331/1670 train_time:117990ms step_avg:88.65ms +step:1332/1670 train_time:118080ms step_avg:88.65ms +step:1333/1670 train_time:118170ms step_avg:88.65ms +step:1334/1670 train_time:118260ms step_avg:88.65ms +step:1335/1670 train_time:118350ms step_avg:88.65ms +step:1336/1670 train_time:118442ms step_avg:88.65ms +step:1337/1670 train_time:118531ms step_avg:88.65ms +step:1338/1670 train_time:118621ms step_avg:88.66ms +step:1339/1670 train_time:118711ms step_avg:88.66ms +step:1340/1670 train_time:118800ms step_avg:88.66ms +step:1341/1670 train_time:118889ms step_avg:88.66ms +step:1342/1670 train_time:118979ms step_avg:88.66ms +step:1343/1670 train_time:119069ms step_avg:88.66ms +step:1344/1670 train_time:119159ms step_avg:88.66ms +step:1345/1670 train_time:119249ms step_avg:88.66ms +step:1346/1670 train_time:119339ms step_avg:88.66ms +step:1347/1670 train_time:119430ms step_avg:88.66ms +step:1348/1670 train_time:119522ms step_avg:88.67ms +step:1349/1670 train_time:119612ms step_avg:88.67ms +step:1350/1670 train_time:119702ms step_avg:88.67ms +step:1351/1670 train_time:119792ms step_avg:88.67ms +step:1352/1670 train_time:119881ms step_avg:88.67ms +step:1353/1670 train_time:119971ms step_avg:88.67ms +step:1354/1670 train_time:120061ms step_avg:88.67ms +step:1355/1670 train_time:120150ms step_avg:88.67ms +step:1356/1670 train_time:120240ms step_avg:88.67ms +step:1357/1670 train_time:120330ms step_avg:88.67ms +step:1358/1670 train_time:120421ms step_avg:88.68ms +step:1359/1670 train_time:120511ms step_avg:88.68ms +step:1360/1670 train_time:120600ms step_avg:88.68ms +step:1361/1670 train_time:120690ms step_avg:88.68ms +step:1362/1670 train_time:120779ms step_avg:88.68ms +step:1363/1670 train_time:120869ms step_avg:88.68ms +step:1364/1670 train_time:120958ms step_avg:88.68ms +step:1365/1670 train_time:121048ms step_avg:88.68ms +step:1366/1670 train_time:121138ms step_avg:88.68ms +step:1367/1670 train_time:121227ms step_avg:88.68ms +step:1368/1670 train_time:121317ms step_avg:88.68ms +step:1369/1670 train_time:121407ms step_avg:88.68ms +step:1370/1670 train_time:121498ms step_avg:88.68ms +step:1371/1670 train_time:121589ms step_avg:88.69ms +step:1372/1670 train_time:121678ms step_avg:88.69ms +step:1373/1670 train_time:121768ms step_avg:88.69ms +step:1374/1670 train_time:121859ms step_avg:88.69ms +step:1375/1670 train_time:121948ms step_avg:88.69ms +step:1375/1670 val_loss:3.3413 train_time:122040ms step_avg:88.76ms +step:1376/1670 train_time:122059ms step_avg:88.71ms +step:1377/1670 train_time:122133ms step_avg:88.69ms +step:1378/1670 train_time:122225ms step_avg:88.70ms +step:1379/1670 train_time:122314ms step_avg:88.70ms +step:1380/1670 train_time:122403ms step_avg:88.70ms +step:1381/1670 train_time:122492ms step_avg:88.70ms +step:1382/1670 train_time:122580ms step_avg:88.70ms +step:1383/1670 train_time:122669ms step_avg:88.70ms +step:1384/1670 train_time:122758ms step_avg:88.70ms +step:1385/1670 train_time:122847ms step_avg:88.70ms +step:1386/1670 train_time:122936ms step_avg:88.70ms +step:1387/1670 train_time:123028ms step_avg:88.70ms +step:1388/1670 train_time:123119ms step_avg:88.70ms +step:1389/1670 train_time:123210ms step_avg:88.70ms +step:1390/1670 train_time:123299ms step_avg:88.70ms +step:1391/1670 train_time:123389ms 
step_avg:88.71ms +step:1392/1670 train_time:123479ms step_avg:88.71ms +step:1393/1670 train_time:123568ms step_avg:88.71ms +step:1394/1670 train_time:123656ms step_avg:88.71ms +step:1395/1670 train_time:123746ms step_avg:88.71ms +step:1396/1670 train_time:123835ms step_avg:88.71ms +step:1397/1670 train_time:123925ms step_avg:88.71ms +step:1398/1670 train_time:124016ms step_avg:88.71ms +step:1399/1670 train_time:124109ms step_avg:88.71ms +step:1400/1670 train_time:124198ms step_avg:88.71ms +step:1401/1670 train_time:124289ms step_avg:88.71ms +step:1402/1670 train_time:124378ms step_avg:88.71ms +step:1403/1670 train_time:124468ms step_avg:88.72ms +step:1404/1670 train_time:124557ms step_avg:88.72ms +step:1405/1670 train_time:124646ms step_avg:88.72ms +step:1406/1670 train_time:124735ms step_avg:88.72ms +step:1407/1670 train_time:124824ms step_avg:88.72ms +step:1408/1670 train_time:124915ms step_avg:88.72ms +step:1409/1670 train_time:125005ms step_avg:88.72ms +step:1410/1670 train_time:125096ms step_avg:88.72ms +step:1411/1670 train_time:125187ms step_avg:88.72ms +step:1412/1670 train_time:125277ms step_avg:88.72ms +step:1413/1670 train_time:125368ms step_avg:88.72ms +step:1414/1670 train_time:125458ms step_avg:88.73ms +step:1415/1670 train_time:125548ms step_avg:88.73ms +step:1416/1670 train_time:125637ms step_avg:88.73ms +step:1417/1670 train_time:125728ms step_avg:88.73ms +step:1418/1670 train_time:125816ms step_avg:88.73ms +step:1419/1670 train_time:125905ms step_avg:88.73ms +step:1420/1670 train_time:125996ms step_avg:88.73ms +step:1421/1670 train_time:126086ms step_avg:88.73ms +step:1422/1670 train_time:126176ms step_avg:88.73ms +step:1423/1670 train_time:126267ms step_avg:88.73ms +step:1424/1670 train_time:126358ms step_avg:88.73ms +step:1425/1670 train_time:126448ms step_avg:88.74ms +step:1426/1670 train_time:126538ms step_avg:88.74ms +step:1427/1670 train_time:126627ms step_avg:88.74ms +step:1428/1670 train_time:126716ms step_avg:88.74ms +step:1429/1670 train_time:126806ms step_avg:88.74ms +step:1430/1670 train_time:126895ms step_avg:88.74ms +step:1431/1670 train_time:126985ms step_avg:88.74ms +step:1432/1670 train_time:127074ms step_avg:88.74ms +step:1433/1670 train_time:127165ms step_avg:88.74ms +step:1434/1670 train_time:127255ms step_avg:88.74ms +step:1435/1670 train_time:127346ms step_avg:88.74ms +step:1436/1670 train_time:127436ms step_avg:88.74ms +step:1437/1670 train_time:127527ms step_avg:88.75ms +step:1438/1670 train_time:127616ms step_avg:88.75ms +step:1439/1670 train_time:127706ms step_avg:88.75ms +step:1440/1670 train_time:127795ms step_avg:88.75ms +step:1441/1670 train_time:127884ms step_avg:88.75ms +step:1442/1670 train_time:127975ms step_avg:88.75ms +step:1443/1670 train_time:128065ms step_avg:88.75ms +step:1444/1670 train_time:128156ms step_avg:88.75ms +step:1445/1670 train_time:128246ms step_avg:88.75ms +step:1446/1670 train_time:128336ms step_avg:88.75ms +step:1447/1670 train_time:128427ms step_avg:88.75ms +step:1448/1670 train_time:128517ms step_avg:88.75ms +step:1449/1670 train_time:128606ms step_avg:88.76ms +step:1450/1670 train_time:128696ms step_avg:88.76ms +step:1451/1670 train_time:128785ms step_avg:88.76ms +step:1452/1670 train_time:128874ms step_avg:88.76ms +step:1453/1670 train_time:128964ms step_avg:88.76ms +step:1454/1670 train_time:129054ms step_avg:88.76ms +step:1455/1670 train_time:129145ms step_avg:88.76ms +step:1456/1670 train_time:129234ms step_avg:88.76ms +step:1457/1670 train_time:129324ms step_avg:88.76ms +step:1458/1670 train_time:129415ms 
step_avg:88.76ms +step:1459/1670 train_time:129505ms step_avg:88.76ms +step:1460/1670 train_time:129595ms step_avg:88.76ms +step:1461/1670 train_time:129685ms step_avg:88.76ms +step:1462/1670 train_time:129774ms step_avg:88.76ms +step:1463/1670 train_time:129864ms step_avg:88.77ms +step:1464/1670 train_time:129954ms step_avg:88.77ms +step:1465/1670 train_time:130044ms step_avg:88.77ms +step:1466/1670 train_time:130133ms step_avg:88.77ms +step:1467/1670 train_time:130223ms step_avg:88.77ms +step:1468/1670 train_time:130315ms step_avg:88.77ms +step:1469/1670 train_time:130405ms step_avg:88.77ms +step:1470/1670 train_time:130495ms step_avg:88.77ms +step:1471/1670 train_time:130586ms step_avg:88.77ms +step:1472/1670 train_time:130675ms step_avg:88.77ms +step:1473/1670 train_time:130765ms step_avg:88.77ms +step:1474/1670 train_time:130855ms step_avg:88.78ms +step:1475/1670 train_time:130944ms step_avg:88.78ms +step:1476/1670 train_time:131034ms step_avg:88.78ms +step:1477/1670 train_time:131124ms step_avg:88.78ms +step:1478/1670 train_time:131214ms step_avg:88.78ms +step:1479/1670 train_time:131304ms step_avg:88.78ms +step:1480/1670 train_time:131395ms step_avg:88.78ms +step:1481/1670 train_time:131485ms step_avg:88.78ms +step:1482/1670 train_time:131576ms step_avg:88.78ms +step:1483/1670 train_time:131666ms step_avg:88.78ms +step:1484/1670 train_time:131757ms step_avg:88.79ms +step:1485/1670 train_time:131847ms step_avg:88.79ms +step:1486/1670 train_time:131937ms step_avg:88.79ms +step:1487/1670 train_time:132027ms step_avg:88.79ms +step:1488/1670 train_time:132115ms step_avg:88.79ms +step:1489/1670 train_time:132205ms step_avg:88.79ms +step:1490/1670 train_time:132295ms step_avg:88.79ms +step:1491/1670 train_time:132385ms step_avg:88.79ms +step:1492/1670 train_time:132474ms step_avg:88.79ms +step:1493/1670 train_time:132564ms step_avg:88.79ms +step:1494/1670 train_time:132654ms step_avg:88.79ms +step:1495/1670 train_time:132744ms step_avg:88.79ms +step:1496/1670 train_time:132834ms step_avg:88.79ms +step:1497/1670 train_time:132924ms step_avg:88.79ms +step:1498/1670 train_time:133014ms step_avg:88.79ms +step:1499/1670 train_time:133103ms step_avg:88.79ms +step:1500/1670 train_time:133193ms step_avg:88.80ms +step:1500/1670 val_loss:3.3118 train_time:133284ms step_avg:88.86ms +step:1501/1670 train_time:133304ms step_avg:88.81ms +step:1502/1670 train_time:133376ms step_avg:88.80ms +step:1503/1670 train_time:133469ms step_avg:88.80ms +step:1504/1670 train_time:133558ms step_avg:88.80ms +step:1505/1670 train_time:133647ms step_avg:88.80ms +step:1506/1670 train_time:133735ms step_avg:88.80ms +step:1507/1670 train_time:133824ms step_avg:88.80ms +step:1508/1670 train_time:133913ms step_avg:88.80ms +step:1509/1670 train_time:134002ms step_avg:88.80ms +step:1510/1670 train_time:134091ms step_avg:88.80ms +step:1511/1670 train_time:134181ms step_avg:88.80ms +step:1512/1670 train_time:134273ms step_avg:88.80ms +step:1513/1670 train_time:134365ms step_avg:88.81ms +step:1514/1670 train_time:134456ms step_avg:88.81ms +step:1515/1670 train_time:134547ms step_avg:88.81ms +step:1516/1670 train_time:134636ms step_avg:88.81ms +step:1517/1670 train_time:134725ms step_avg:88.81ms +step:1518/1670 train_time:134813ms step_avg:88.81ms +step:1519/1670 train_time:134903ms step_avg:88.81ms +step:1520/1670 train_time:134992ms step_avg:88.81ms +step:1521/1670 train_time:135083ms step_avg:88.81ms +step:1522/1670 train_time:135172ms step_avg:88.81ms +step:1523/1670 train_time:135263ms step_avg:88.81ms +step:1524/1670 
train_time:135353ms step_avg:88.81ms +step:1525/1670 train_time:135443ms step_avg:88.82ms +step:1526/1670 train_time:135533ms step_avg:88.82ms +step:1527/1670 train_time:135624ms step_avg:88.82ms +step:1528/1670 train_time:135713ms step_avg:88.82ms +step:1529/1670 train_time:135803ms step_avg:88.82ms +step:1530/1670 train_time:135892ms step_avg:88.82ms +step:1531/1670 train_time:135982ms step_avg:88.82ms +step:1532/1670 train_time:136072ms step_avg:88.82ms +step:1533/1670 train_time:136162ms step_avg:88.82ms +step:1534/1670 train_time:136252ms step_avg:88.82ms +step:1535/1670 train_time:136342ms step_avg:88.82ms +step:1536/1670 train_time:136431ms step_avg:88.82ms +step:1537/1670 train_time:136522ms step_avg:88.82ms +step:1538/1670 train_time:136611ms step_avg:88.82ms +step:1539/1670 train_time:136702ms step_avg:88.82ms +step:1540/1670 train_time:136791ms step_avg:88.83ms +step:1541/1670 train_time:136880ms step_avg:88.83ms +step:1542/1670 train_time:136970ms step_avg:88.83ms +step:1543/1670 train_time:137060ms step_avg:88.83ms +step:1544/1670 train_time:137149ms step_avg:88.83ms +step:1545/1670 train_time:137239ms step_avg:88.83ms +step:1546/1670 train_time:137329ms step_avg:88.83ms +step:1547/1670 train_time:137419ms step_avg:88.83ms +step:1548/1670 train_time:137510ms step_avg:88.83ms +step:1549/1670 train_time:137600ms step_avg:88.83ms +step:1550/1670 train_time:137690ms step_avg:88.83ms +step:1551/1670 train_time:137780ms step_avg:88.83ms +step:1552/1670 train_time:137869ms step_avg:88.83ms +step:1553/1670 train_time:137958ms step_avg:88.83ms +step:1554/1670 train_time:138048ms step_avg:88.83ms +step:1555/1670 train_time:138138ms step_avg:88.83ms +step:1556/1670 train_time:138228ms step_avg:88.84ms +step:1557/1670 train_time:138318ms step_avg:88.84ms +step:1558/1670 train_time:138409ms step_avg:88.84ms +step:1559/1670 train_time:138499ms step_avg:88.84ms +step:1560/1670 train_time:138589ms step_avg:88.84ms +step:1561/1670 train_time:138679ms step_avg:88.84ms +step:1562/1670 train_time:138769ms step_avg:88.84ms +step:1563/1670 train_time:138859ms step_avg:88.84ms +step:1564/1670 train_time:138948ms step_avg:88.84ms +step:1565/1670 train_time:139038ms step_avg:88.84ms +step:1566/1670 train_time:139127ms step_avg:88.84ms +step:1567/1670 train_time:139217ms step_avg:88.84ms +step:1568/1670 train_time:139308ms step_avg:88.84ms +step:1569/1670 train_time:139399ms step_avg:88.85ms +step:1570/1670 train_time:139489ms step_avg:88.85ms +step:1571/1670 train_time:139579ms step_avg:88.85ms +step:1572/1670 train_time:139669ms step_avg:88.85ms +step:1573/1670 train_time:139759ms step_avg:88.85ms +step:1574/1670 train_time:139848ms step_avg:88.85ms +step:1575/1670 train_time:139937ms step_avg:88.85ms +step:1576/1670 train_time:140028ms step_avg:88.85ms +step:1577/1670 train_time:140118ms step_avg:88.85ms +step:1578/1670 train_time:140208ms step_avg:88.85ms +step:1579/1670 train_time:140298ms step_avg:88.85ms +step:1580/1670 train_time:140388ms step_avg:88.85ms +step:1581/1670 train_time:140478ms step_avg:88.85ms +step:1582/1670 train_time:140569ms step_avg:88.86ms +step:1583/1670 train_time:140659ms step_avg:88.86ms +step:1584/1670 train_time:140749ms step_avg:88.86ms +step:1585/1670 train_time:140838ms step_avg:88.86ms +step:1586/1670 train_time:140928ms step_avg:88.86ms +step:1587/1670 train_time:141017ms step_avg:88.86ms +step:1588/1670 train_time:141108ms step_avg:88.86ms +step:1589/1670 train_time:141199ms step_avg:88.86ms +step:1590/1670 train_time:141290ms step_avg:88.86ms +step:1591/1670 
train_time:141381ms step_avg:88.86ms +step:1592/1670 train_time:141469ms step_avg:88.86ms +step:1593/1670 train_time:141559ms step_avg:88.86ms +step:1594/1670 train_time:141649ms step_avg:88.86ms +step:1595/1670 train_time:141739ms step_avg:88.86ms +step:1596/1670 train_time:141829ms step_avg:88.87ms +step:1597/1670 train_time:141918ms step_avg:88.87ms +step:1598/1670 train_time:142007ms step_avg:88.87ms +step:1599/1670 train_time:142097ms step_avg:88.87ms +step:1600/1670 train_time:142188ms step_avg:88.87ms +step:1601/1670 train_time:142279ms step_avg:88.87ms +step:1602/1670 train_time:142367ms step_avg:88.87ms +step:1603/1670 train_time:142457ms step_avg:88.87ms +step:1604/1670 train_time:142546ms step_avg:88.87ms +step:1605/1670 train_time:142635ms step_avg:88.87ms +step:1606/1670 train_time:142725ms step_avg:88.87ms +step:1607/1670 train_time:142815ms step_avg:88.87ms +step:1608/1670 train_time:142905ms step_avg:88.87ms +step:1609/1670 train_time:142994ms step_avg:88.87ms +step:1610/1670 train_time:143085ms step_avg:88.87ms +step:1611/1670 train_time:143174ms step_avg:88.87ms +step:1612/1670 train_time:143264ms step_avg:88.87ms +step:1613/1670 train_time:143353ms step_avg:88.87ms +step:1614/1670 train_time:143444ms step_avg:88.87ms +step:1615/1670 train_time:143533ms step_avg:88.87ms +step:1616/1670 train_time:143623ms step_avg:88.88ms +step:1617/1670 train_time:143712ms step_avg:88.88ms +step:1618/1670 train_time:143803ms step_avg:88.88ms +step:1619/1670 train_time:143892ms step_avg:88.88ms +step:1620/1670 train_time:143982ms step_avg:88.88ms +step:1621/1670 train_time:144072ms step_avg:88.88ms +step:1622/1670 train_time:144162ms step_avg:88.88ms +step:1623/1670 train_time:144251ms step_avg:88.88ms +step:1624/1670 train_time:144341ms step_avg:88.88ms +step:1625/1670 train_time:144430ms step_avg:88.88ms +step:1625/1670 val_loss:3.2885 train_time:144522ms step_avg:88.94ms +step:1626/1670 train_time:144543ms step_avg:88.89ms +step:1627/1670 train_time:144622ms step_avg:88.89ms +step:1628/1670 train_time:144714ms step_avg:88.89ms +step:1629/1670 train_time:144803ms step_avg:88.89ms +step:1630/1670 train_time:144891ms step_avg:88.89ms +step:1631/1670 train_time:144979ms step_avg:88.89ms +step:1632/1670 train_time:145067ms step_avg:88.89ms +step:1633/1670 train_time:145157ms step_avg:88.89ms +step:1634/1670 train_time:145246ms step_avg:88.89ms +step:1635/1670 train_time:145337ms step_avg:88.89ms +step:1636/1670 train_time:145427ms step_avg:88.89ms +step:1637/1670 train_time:145521ms step_avg:88.89ms +step:1638/1670 train_time:145615ms step_avg:88.90ms +step:1639/1670 train_time:145704ms step_avg:88.90ms +step:1640/1670 train_time:145794ms step_avg:88.90ms +step:1641/1670 train_time:145883ms step_avg:88.90ms +step:1642/1670 train_time:145972ms step_avg:88.90ms +step:1643/1670 train_time:146060ms step_avg:88.90ms +step:1644/1670 train_time:146149ms step_avg:88.90ms +step:1645/1670 train_time:146238ms step_avg:88.90ms +step:1646/1670 train_time:146328ms step_avg:88.90ms +step:1647/1670 train_time:146419ms step_avg:88.90ms +step:1648/1670 train_time:146512ms step_avg:88.90ms +step:1649/1670 train_time:146602ms step_avg:88.90ms +step:1650/1670 train_time:146692ms step_avg:88.90ms +step:1651/1670 train_time:146782ms step_avg:88.90ms +step:1652/1670 train_time:146871ms step_avg:88.91ms +step:1653/1670 train_time:146960ms step_avg:88.91ms +step:1654/1670 train_time:147049ms step_avg:88.91ms +step:1655/1670 train_time:147138ms step_avg:88.91ms +step:1656/1670 train_time:147227ms step_avg:88.91ms 
+step:1657/1670 train_time:147317ms step_avg:88.91ms +step:1658/1670 train_time:147408ms step_avg:88.91ms +step:1659/1670 train_time:147500ms step_avg:88.91ms +step:1660/1670 train_time:147590ms step_avg:88.91ms +step:1661/1670 train_time:147681ms step_avg:88.91ms +step:1662/1670 train_time:147770ms step_avg:88.91ms +step:1663/1670 train_time:147860ms step_avg:88.91ms +step:1664/1670 train_time:147950ms step_avg:88.91ms +step:1665/1670 train_time:148039ms step_avg:88.91ms +step:1666/1670 train_time:148128ms step_avg:88.91ms +step:1667/1670 train_time:148218ms step_avg:88.91ms +step:1668/1670 train_time:148307ms step_avg:88.91ms +step:1669/1670 train_time:148397ms step_avg:88.91ms +step:1670/1670 train_time:148488ms step_avg:88.91ms +step:1670/1670 val_loss:3.2792 train_time:148580ms step_avg:88.97ms +peak memory allocated: 30760 MiB reserved: 45434 MiB diff --git a/records/092925_PolarExpress/730671d8-2fca-498a-819a-0bdf0f3aa76c.txt b/records/092925_PolarExpress/730671d8-2fca-498a-819a-0bdf0f3aa76c.txt new file mode 100644 index 000000000..379186672 --- /dev/null +++ b/records/092925_PolarExpress/730671d8-2fca-498a-819a-0bdf0f3aa76c.txt @@ -0,0 +1,3252 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from kernels import get_kernel +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[ + Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = 
grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + + +mm_op.register_autograd(backward, setup_context=setup_context) + + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def XXT_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = 
A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def XXT(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + XXT_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ba_plus_cAA_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from XXT_kernel, but also loads and adds a block of A + # Performance is slightly slower than XXT_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), 
dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ba_plus_cAA(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ba_plus_cAA_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +# Computed for num_iters=5, safety_factor=2e-2, cushion=2 +coeffs_list = [ + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323) +] + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def polar_express(G: torch.Tensor): + """ + Polar Express Sign Method: https://arxiv.org/pdf/2505.16932 + by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. + Code adapted from https://github.com/NoahAmsel/PolarExpress/tree/main by @varunneal. 
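+
+    A minimal worked sketch of the iteration above, assuming a square full-rank input
+    and using plain torch ops in place of the fused Triton kernels (the name
+    `rel_err` is illustrative only, not part of this file):
+
+        # Each tuple (a, b, c) applies X <- a*X + b*(X @ X.mT) @ X + c*(X @ X.mT)^2 @ X,
+        # pushing the singular values of X toward 1, so X approaches the
+        # orthogonal polar factor U @ Vh of G.
+        G = torch.randn(256, 256)
+        X = G / (G.norm() * (1 + 2e-2) + 1e-6)
+        for a, b, c in coeffs_list:
+            A = X @ X.mT
+            B = b * A + c * (A @ A)
+            X = a * X + B @ X
+        U, S, Vh = torch.linalg.svd(G)
+        rel_err = (X - U @ Vh).norm() / (U @ Vh).norm()  # small but not exact: 5 iterations give a deliberately coarse approximation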
+ """ + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) * (1 + 2e-2) + 1e-6) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + aX_plus_BX = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the iterations + for a, b, c in coeffs_list: + XXT(X, out=A) # A = X @ X.mT + ba_plus_cAA(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + aX_plus_BX(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + Though empirically small 1D params perform efficiently here: + NS approximately performs a magnitude normalization of the grad + This hyper-optimized class has faster execution time than the current impl of Adam for small params + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized to enable + (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param 7 padding params) + 2. reduce scatter attn_gate (10 params 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive params of type attn reshape before NS + 8. wait on 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size() == 8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on less than 8 GPU or experimenting with additional attn or mlp modules. 
+ Creates one param group per size, while giving attn its own param group for resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1, 10, 16, 16] + assert len(params) == sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx + size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. 
updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + original_shape = batched_update_grads.shape + # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O + module_idx = start_idx if start_idx < len(params) else 0 + if getattr(params[module_idx], 'module', 'none') == 'attn': + for p in params[module_idx:module_idx + chunk_size]: + assert getattr(p, 'module', 'none') == 'attn' + batch = 4 * original_shape[0] + d1 = original_shape[1] + d2 = original_shape[2] // 4 + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = polar_express(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = polar_express(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors.
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, + weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append( + dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim // 4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim // 4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +flash_attn_interface = get_kernel('varunneal/flash-attention-3').flash_attn_interface + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim * 4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module = 'attn' + with torch.no_grad(): + self.qkvo_w.view(4, 
self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4, self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4, self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, + 3 * self.num_heads, + self.head_dim).chunk( + 3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_interface.flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, + max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4, self.hdim, self.dim)[3].type_as(y)) + return y + + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module = 'mlp' + self.c_proj.module = 'mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu( + x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + 
self.mlp(norm(x)) + return x + + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, + max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim ** 0.5) / 448, w_s=2 ** -9, + grad_s=1 / 448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), # extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, + long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1, len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i < 11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + + +BOS_ID = 50256 + + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens = tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() 
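+ # Note: in both branches bos_idx is a sorted array of BOS token offsets; under quickload it initially
+ # covers only the first 4M tokens and is swapped for the full index by get() once the background scan
+ # finishes (next_batch triggers this after 5 batches).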
+ self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter == 5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter += 1 + return starts, ends + + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, + align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
+ tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1630 # number of iterations to run + iteration_extension = 40 # number of iterations to continue training at final cooldown and window size + cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. 
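+
+# Worked example (illustrative only; the numbers follow from the defaults above and are not used by the code):
+# with world_size=8, grad_accum_steps = 8 // 8 = 1, so each rank processes
+# train_batch_size // (grad_accum_steps * world_size) = 2048*24*8 // 8 = 49,152 tokens per micro-step,
+# split into sequences of at most train_max_seq_len = 128*16 = 2048 tokens. On a 4-GPU run,
+# grad_accum_steps = 2, and each rank still sees 49,152 tokens per micro-step but runs two micro-steps
+# per optimizer step.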
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) + + +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + + +# begin by printing this file (the Python code) +print0(code) +print0("=" * 100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout + + +print0(nvidia_smi()) +print0("=" * 100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if + p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.06, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999, step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + + +def get_ws(step: int): + if step == args.num_iterations + args.iteration_extension: + return args.ws_validate // 2, args.ws_validate + x = min(step / (1 + args.num_iterations), 0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx] + + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, + grad_accum_steps=grad_accum_steps) +ws_long = args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws_long = 
args.ws_schedule[ + step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params + if new_ws_long > ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + elif new_ws_long < ws_long: + model.yarn.reset() + ws_long = new_ws_long + model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.yarn.reset() +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, + grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations + args.iteration_extension +ws_short, ws_long = get_ws(0) +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws_short, new_ws_long = get_ws(step) + if new_ws_long != ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + if last_step: + ws_long = args.ws_long_validate + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, + grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = torch.zeros((), device=device, dtype=torch.float32) + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0( + f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms", + console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), + optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * 
(time.perf_counter() - t0) + print0( + f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms", + console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, Feb 4 2025, 14:57:36) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.5.0 +Mon Sep 29 06:22:36 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.127.08 Driver Version: 550.127.08 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 40C P0 129W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 33C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 38C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 39C P0 128W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 33C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 38C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 
train_time:0ms step_avg:0.02ms +step:1/1670 train_time:135ms step_avg:135.10ms +step:2/1670 train_time:156ms step_avg:78.05ms +step:3/1670 train_time:222ms step_avg:74.08ms +step:4/1670 train_time:308ms step_avg:77.00ms +step:5/1670 train_time:395ms step_avg:78.99ms +step:6/1670 train_time:481ms step_avg:80.21ms +step:7/1670 train_time:568ms step_avg:81.12ms +step:8/1670 train_time:655ms step_avg:81.84ms +step:9/1670 train_time:742ms step_avg:82.42ms +step:10/1670 train_time:829ms step_avg:82.88ms +step:11/1670 train_time:915ms step_avg:83.21ms +step:12/1670 train_time:1005ms step_avg:83.75ms +step:13/1670 train_time:1096ms step_avg:84.30ms +step:14/1670 train_time:1188ms step_avg:84.84ms +step:15/1670 train_time:1277ms step_avg:85.11ms +step:16/1670 train_time:1365ms step_avg:85.29ms +step:17/1670 train_time:1452ms step_avg:85.41ms +step:18/1670 train_time:1539ms step_avg:85.52ms +step:19/1670 train_time:1627ms step_avg:85.62ms +step:20/1670 train_time:1714ms step_avg:85.68ms +step:21/1670 train_time:1801ms step_avg:85.75ms +step:22/1670 train_time:1888ms step_avg:85.83ms +step:23/1670 train_time:1976ms step_avg:85.89ms +step:24/1670 train_time:2067ms step_avg:86.11ms +step:25/1670 train_time:2155ms step_avg:86.20ms +step:26/1670 train_time:2244ms step_avg:86.32ms +step:27/1670 train_time:2332ms step_avg:86.36ms +step:28/1670 train_time:2421ms step_avg:86.45ms +step:29/1670 train_time:2508ms step_avg:86.49ms +step:30/1670 train_time:2596ms step_avg:86.53ms +step:31/1670 train_time:2683ms step_avg:86.54ms +step:32/1670 train_time:2770ms step_avg:86.55ms +step:33/1670 train_time:2857ms step_avg:86.57ms +step:34/1670 train_time:2944ms step_avg:86.60ms +step:35/1670 train_time:3032ms step_avg:86.64ms +step:36/1670 train_time:3121ms step_avg:86.68ms +step:37/1670 train_time:3210ms step_avg:86.76ms +step:38/1670 train_time:3299ms step_avg:86.82ms +step:39/1670 train_time:3388ms step_avg:86.88ms +step:40/1670 train_time:3477ms step_avg:86.92ms +step:41/1670 train_time:3566ms step_avg:86.97ms +step:42/1670 train_time:3653ms step_avg:86.98ms +step:43/1670 train_time:3740ms step_avg:86.99ms +step:44/1670 train_time:3828ms step_avg:86.99ms +step:45/1670 train_time:3915ms step_avg:87.00ms +step:46/1670 train_time:4003ms step_avg:87.01ms +step:47/1670 train_time:4092ms step_avg:87.05ms +step:48/1670 train_time:4179ms step_avg:87.07ms +step:49/1670 train_time:4268ms step_avg:87.11ms +step:50/1670 train_time:4357ms step_avg:87.13ms +step:51/1670 train_time:4445ms step_avg:87.16ms +step:52/1670 train_time:4533ms step_avg:87.18ms +step:53/1670 train_time:4622ms step_avg:87.20ms +step:54/1670 train_time:4710ms step_avg:87.22ms +step:55/1670 train_time:4798ms step_avg:87.23ms +step:56/1670 train_time:4886ms step_avg:87.25ms +step:57/1670 train_time:4973ms step_avg:87.25ms +step:58/1670 train_time:5061ms step_avg:87.27ms +step:59/1670 train_time:5149ms step_avg:87.27ms +step:60/1670 train_time:5238ms step_avg:87.29ms +step:61/1670 train_time:5327ms step_avg:87.33ms +step:62/1670 train_time:5415ms step_avg:87.34ms +step:63/1670 train_time:5503ms step_avg:87.35ms +step:64/1670 train_time:5591ms step_avg:87.36ms +step:65/1670 train_time:5679ms step_avg:87.36ms +step:66/1670 train_time:5768ms step_avg:87.39ms +step:67/1670 train_time:5856ms step_avg:87.41ms +step:68/1670 train_time:5944ms step_avg:87.42ms +step:69/1670 train_time:6031ms step_avg:87.41ms +step:70/1670 train_time:6119ms step_avg:87.41ms +step:71/1670 train_time:6206ms step_avg:87.41ms +step:72/1670 train_time:6295ms step_avg:87.43ms +step:73/1670 
train_time:6383ms step_avg:87.43ms +step:74/1670 train_time:6471ms step_avg:87.44ms +step:75/1670 train_time:6559ms step_avg:87.45ms +step:76/1670 train_time:6647ms step_avg:87.47ms +step:77/1670 train_time:6736ms step_avg:87.48ms +step:78/1670 train_time:6825ms step_avg:87.50ms +step:79/1670 train_time:6912ms step_avg:87.49ms +step:80/1670 train_time:7000ms step_avg:87.49ms +step:81/1670 train_time:7087ms step_avg:87.49ms +step:82/1670 train_time:7174ms step_avg:87.48ms +step:83/1670 train_time:7261ms step_avg:87.48ms +step:84/1670 train_time:7349ms step_avg:87.49ms +step:85/1670 train_time:7437ms step_avg:87.50ms +step:86/1670 train_time:7526ms step_avg:87.51ms +step:87/1670 train_time:7614ms step_avg:87.51ms +step:88/1670 train_time:7702ms step_avg:87.52ms +step:89/1670 train_time:7790ms step_avg:87.53ms +step:90/1670 train_time:7878ms step_avg:87.53ms +step:91/1670 train_time:7966ms step_avg:87.54ms +step:92/1670 train_time:8052ms step_avg:87.53ms +step:93/1670 train_time:8139ms step_avg:87.52ms +step:94/1670 train_time:8228ms step_avg:87.53ms +step:95/1670 train_time:8315ms step_avg:87.53ms +step:96/1670 train_time:8403ms step_avg:87.53ms +step:97/1670 train_time:8491ms step_avg:87.53ms +step:98/1670 train_time:8579ms step_avg:87.54ms +step:99/1670 train_time:8668ms step_avg:87.56ms +step:100/1670 train_time:8756ms step_avg:87.56ms +step:101/1670 train_time:8843ms step_avg:87.56ms +step:102/1670 train_time:8931ms step_avg:87.56ms +step:103/1670 train_time:9019ms step_avg:87.57ms +step:104/1670 train_time:9107ms step_avg:87.57ms +step:105/1670 train_time:9195ms step_avg:87.57ms +step:106/1670 train_time:9282ms step_avg:87.57ms +step:107/1670 train_time:9370ms step_avg:87.57ms +step:108/1670 train_time:9459ms step_avg:87.58ms +step:109/1670 train_time:9546ms step_avg:87.58ms +step:110/1670 train_time:9634ms step_avg:87.58ms +step:111/1670 train_time:9722ms step_avg:87.58ms +step:112/1670 train_time:9809ms step_avg:87.58ms +step:113/1670 train_time:9898ms step_avg:87.59ms +step:114/1670 train_time:9986ms step_avg:87.60ms +step:115/1670 train_time:10074ms step_avg:87.60ms +step:116/1670 train_time:10162ms step_avg:87.61ms +step:117/1670 train_time:10249ms step_avg:87.60ms +step:118/1670 train_time:10337ms step_avg:87.60ms +step:119/1670 train_time:10426ms step_avg:87.61ms +step:120/1670 train_time:10513ms step_avg:87.61ms +step:121/1670 train_time:10601ms step_avg:87.61ms +step:122/1670 train_time:10689ms step_avg:87.61ms +step:123/1670 train_time:10776ms step_avg:87.61ms +step:124/1670 train_time:10864ms step_avg:87.61ms +step:125/1670 train_time:10951ms step_avg:87.61ms +step:125/1670 val_loss:4.3456 train_time:11041ms step_avg:88.32ms +step:126/1670 train_time:11061ms step_avg:87.78ms +step:127/1670 train_time:11133ms step_avg:87.66ms +step:128/1670 train_time:11229ms step_avg:87.73ms +step:129/1670 train_time:11319ms step_avg:87.74ms +step:130/1670 train_time:11407ms step_avg:87.75ms +step:131/1670 train_time:11494ms step_avg:87.74ms +step:132/1670 train_time:11581ms step_avg:87.73ms +step:133/1670 train_time:11667ms step_avg:87.72ms +step:134/1670 train_time:11754ms step_avg:87.72ms +step:135/1670 train_time:11841ms step_avg:87.71ms +step:136/1670 train_time:11927ms step_avg:87.70ms +step:137/1670 train_time:12014ms step_avg:87.69ms +step:138/1670 train_time:12104ms step_avg:87.71ms +step:139/1670 train_time:12195ms step_avg:87.73ms +step:140/1670 train_time:12283ms step_avg:87.74ms +step:141/1670 train_time:12372ms step_avg:87.74ms +step:142/1670 train_time:12460ms step_avg:87.74ms 
+step:143/1670 train_time:12547ms step_avg:87.74ms +step:144/1670 train_time:12634ms step_avg:87.74ms +step:145/1670 train_time:12720ms step_avg:87.73ms +step:146/1670 train_time:12807ms step_avg:87.72ms +step:147/1670 train_time:12894ms step_avg:87.71ms +step:148/1670 train_time:12981ms step_avg:87.71ms +step:149/1670 train_time:13069ms step_avg:87.71ms +step:150/1670 train_time:13159ms step_avg:87.73ms +step:151/1670 train_time:13247ms step_avg:87.73ms +step:152/1670 train_time:13335ms step_avg:87.73ms +step:153/1670 train_time:13423ms step_avg:87.73ms +step:154/1670 train_time:13510ms step_avg:87.73ms +step:155/1670 train_time:13598ms step_avg:87.73ms +step:156/1670 train_time:13684ms step_avg:87.72ms +step:157/1670 train_time:13772ms step_avg:87.72ms +step:158/1670 train_time:13859ms step_avg:87.71ms +step:159/1670 train_time:13945ms step_avg:87.70ms +step:160/1670 train_time:14032ms step_avg:87.70ms +step:161/1670 train_time:14120ms step_avg:87.70ms +step:162/1670 train_time:14208ms step_avg:87.70ms +step:163/1670 train_time:14296ms step_avg:87.71ms +step:164/1670 train_time:14384ms step_avg:87.71ms +step:165/1670 train_time:14472ms step_avg:87.71ms +step:166/1670 train_time:14560ms step_avg:87.71ms +step:167/1670 train_time:14647ms step_avg:87.71ms +step:168/1670 train_time:14733ms step_avg:87.70ms +step:169/1670 train_time:14821ms step_avg:87.70ms +step:170/1670 train_time:14907ms step_avg:87.69ms +step:171/1670 train_time:14995ms step_avg:87.69ms +step:172/1670 train_time:15083ms step_avg:87.69ms +step:173/1670 train_time:15171ms step_avg:87.69ms +step:174/1670 train_time:15259ms step_avg:87.69ms +step:175/1670 train_time:15346ms step_avg:87.69ms +step:176/1670 train_time:15434ms step_avg:87.69ms +step:177/1670 train_time:15522ms step_avg:87.70ms +step:178/1670 train_time:15609ms step_avg:87.69ms +step:179/1670 train_time:15697ms step_avg:87.69ms +step:180/1670 train_time:15783ms step_avg:87.69ms +step:181/1670 train_time:15870ms step_avg:87.68ms +step:182/1670 train_time:15957ms step_avg:87.68ms +step:183/1670 train_time:16045ms step_avg:87.68ms +step:184/1670 train_time:16133ms step_avg:87.68ms +step:185/1670 train_time:16220ms step_avg:87.68ms +step:186/1670 train_time:16308ms step_avg:87.68ms +step:187/1670 train_time:16396ms step_avg:87.68ms +step:188/1670 train_time:16483ms step_avg:87.68ms +step:189/1670 train_time:16571ms step_avg:87.68ms +step:190/1670 train_time:16659ms step_avg:87.68ms +step:191/1670 train_time:16746ms step_avg:87.68ms +step:192/1670 train_time:16834ms step_avg:87.67ms +step:193/1670 train_time:16921ms step_avg:87.67ms +step:194/1670 train_time:17008ms step_avg:87.67ms +step:195/1670 train_time:17097ms step_avg:87.67ms +step:196/1670 train_time:17184ms step_avg:87.67ms +step:197/1670 train_time:17272ms step_avg:87.68ms +step:198/1670 train_time:17360ms step_avg:87.68ms +step:199/1670 train_time:17447ms step_avg:87.67ms +step:200/1670 train_time:17535ms step_avg:87.67ms +step:201/1670 train_time:17622ms step_avg:87.67ms +step:202/1670 train_time:17709ms step_avg:87.67ms +step:203/1670 train_time:17798ms step_avg:87.68ms +step:204/1670 train_time:17885ms step_avg:87.67ms +step:205/1670 train_time:17973ms step_avg:87.67ms +step:206/1670 train_time:18061ms step_avg:87.67ms +step:207/1670 train_time:18149ms step_avg:87.67ms +step:208/1670 train_time:18236ms step_avg:87.67ms +step:209/1670 train_time:18323ms step_avg:87.67ms +step:210/1670 train_time:18411ms step_avg:87.67ms +step:211/1670 train_time:18498ms step_avg:87.67ms +step:212/1670 train_time:18585ms 
step_avg:87.67ms +step:213/1670 train_time:18673ms step_avg:87.67ms +step:214/1670 train_time:18761ms step_avg:87.67ms +step:215/1670 train_time:18848ms step_avg:87.66ms +step:216/1670 train_time:18935ms step_avg:87.66ms +step:217/1670 train_time:19022ms step_avg:87.66ms +step:218/1670 train_time:19110ms step_avg:87.66ms +step:219/1670 train_time:19198ms step_avg:87.66ms +step:220/1670 train_time:19285ms step_avg:87.66ms +step:221/1670 train_time:19372ms step_avg:87.66ms +step:222/1670 train_time:19460ms step_avg:87.66ms +step:223/1670 train_time:19547ms step_avg:87.65ms +step:224/1670 train_time:19635ms step_avg:87.66ms +step:225/1670 train_time:19722ms step_avg:87.65ms +step:226/1670 train_time:19809ms step_avg:87.65ms +step:227/1670 train_time:19897ms step_avg:87.65ms +step:228/1670 train_time:19984ms step_avg:87.65ms +step:229/1670 train_time:20072ms step_avg:87.65ms +step:230/1670 train_time:20161ms step_avg:87.66ms +step:231/1670 train_time:20249ms step_avg:87.66ms +step:232/1670 train_time:20336ms step_avg:87.66ms +step:233/1670 train_time:20423ms step_avg:87.65ms +step:234/1670 train_time:20510ms step_avg:87.65ms +step:235/1670 train_time:20599ms step_avg:87.66ms +step:236/1670 train_time:20686ms step_avg:87.65ms +step:237/1670 train_time:20774ms step_avg:87.66ms +step:238/1670 train_time:20862ms step_avg:87.65ms +step:239/1670 train_time:20949ms step_avg:87.65ms +step:240/1670 train_time:21037ms step_avg:87.65ms +step:241/1670 train_time:21125ms step_avg:87.65ms +step:242/1670 train_time:21212ms step_avg:87.65ms +step:243/1670 train_time:21301ms step_avg:87.66ms +step:244/1670 train_time:21388ms step_avg:87.66ms +step:245/1670 train_time:21475ms step_avg:87.65ms +step:246/1670 train_time:21563ms step_avg:87.65ms +step:247/1670 train_time:21650ms step_avg:87.65ms +step:248/1670 train_time:21738ms step_avg:87.65ms +step:249/1670 train_time:21826ms step_avg:87.65ms +step:250/1670 train_time:21914ms step_avg:87.65ms +step:250/1670 val_loss:3.9771 train_time:22003ms step_avg:88.01ms +step:251/1670 train_time:22024ms step_avg:87.74ms +step:252/1670 train_time:22094ms step_avg:87.67ms +step:253/1670 train_time:22183ms step_avg:87.68ms +step:254/1670 train_time:22271ms step_avg:87.68ms +step:255/1670 train_time:22357ms step_avg:87.67ms +step:256/1670 train_time:22444ms step_avg:87.67ms +step:257/1670 train_time:22531ms step_avg:87.67ms +step:258/1670 train_time:22617ms step_avg:87.66ms +step:259/1670 train_time:22703ms step_avg:87.66ms +step:260/1670 train_time:22791ms step_avg:87.66ms +step:261/1670 train_time:22877ms step_avg:87.65ms +step:262/1670 train_time:22965ms step_avg:87.65ms +step:263/1670 train_time:23057ms step_avg:87.67ms +step:264/1670 train_time:23146ms step_avg:87.68ms +step:265/1670 train_time:23236ms step_avg:87.68ms +step:266/1670 train_time:23323ms step_avg:87.68ms +step:267/1670 train_time:23411ms step_avg:87.68ms +step:268/1670 train_time:23498ms step_avg:87.68ms +step:269/1670 train_time:23585ms step_avg:87.67ms +step:270/1670 train_time:23671ms step_avg:87.67ms +step:271/1670 train_time:23758ms step_avg:87.67ms +step:272/1670 train_time:23845ms step_avg:87.67ms +step:273/1670 train_time:23933ms step_avg:87.67ms +step:274/1670 train_time:24022ms step_avg:87.67ms +step:275/1670 train_time:24111ms step_avg:87.67ms +step:276/1670 train_time:24199ms step_avg:87.68ms +step:277/1670 train_time:24287ms step_avg:87.68ms +step:278/1670 train_time:24374ms step_avg:87.68ms +step:279/1670 train_time:24461ms step_avg:87.67ms +step:280/1670 train_time:24548ms step_avg:87.67ms 
+step:281/1670 train_time:24635ms step_avg:87.67ms +step:282/1670 train_time:24722ms step_avg:87.67ms +step:283/1670 train_time:24809ms step_avg:87.66ms +step:284/1670 train_time:24896ms step_avg:87.66ms +step:285/1670 train_time:24984ms step_avg:87.66ms +step:286/1670 train_time:25074ms step_avg:87.67ms +step:287/1670 train_time:25161ms step_avg:87.67ms +step:288/1670 train_time:25249ms step_avg:87.67ms +step:289/1670 train_time:25336ms step_avg:87.67ms +step:290/1670 train_time:25423ms step_avg:87.67ms +step:291/1670 train_time:25511ms step_avg:87.67ms +step:292/1670 train_time:25598ms step_avg:87.66ms +step:293/1670 train_time:25685ms step_avg:87.66ms +step:294/1670 train_time:25773ms step_avg:87.66ms +step:295/1670 train_time:25860ms step_avg:87.66ms +step:296/1670 train_time:25947ms step_avg:87.66ms +step:297/1670 train_time:26035ms step_avg:87.66ms +step:298/1670 train_time:26125ms step_avg:87.67ms +step:299/1670 train_time:26214ms step_avg:87.67ms +step:300/1670 train_time:26301ms step_avg:87.67ms +step:301/1670 train_time:26388ms step_avg:87.67ms +step:302/1670 train_time:26475ms step_avg:87.67ms +step:303/1670 train_time:26563ms step_avg:87.67ms +step:304/1670 train_time:26651ms step_avg:87.67ms +step:305/1670 train_time:26738ms step_avg:87.66ms +step:306/1670 train_time:26825ms step_avg:87.66ms +step:307/1670 train_time:26913ms step_avg:87.67ms +step:308/1670 train_time:27000ms step_avg:87.66ms +step:309/1670 train_time:27088ms step_avg:87.66ms +step:310/1670 train_time:27176ms step_avg:87.66ms +step:311/1670 train_time:27264ms step_avg:87.66ms +step:312/1670 train_time:27353ms step_avg:87.67ms +step:313/1670 train_time:27440ms step_avg:87.67ms +step:314/1670 train_time:27528ms step_avg:87.67ms +step:315/1670 train_time:27616ms step_avg:87.67ms +step:316/1670 train_time:27703ms step_avg:87.67ms +step:317/1670 train_time:27790ms step_avg:87.67ms +step:318/1670 train_time:27877ms step_avg:87.66ms +step:319/1670 train_time:27965ms step_avg:87.66ms +step:320/1670 train_time:28054ms step_avg:87.67ms +step:321/1670 train_time:28141ms step_avg:87.67ms +step:322/1670 train_time:28229ms step_avg:87.67ms +step:323/1670 train_time:28317ms step_avg:87.67ms +step:324/1670 train_time:28404ms step_avg:87.67ms +step:325/1670 train_time:28492ms step_avg:87.67ms +step:326/1670 train_time:28579ms step_avg:87.66ms +step:327/1670 train_time:28666ms step_avg:87.66ms +step:328/1670 train_time:28754ms step_avg:87.66ms +step:329/1670 train_time:28841ms step_avg:87.66ms +step:330/1670 train_time:28928ms step_avg:87.66ms +step:331/1670 train_time:29017ms step_avg:87.66ms +step:332/1670 train_time:29104ms step_avg:87.66ms +step:333/1670 train_time:29192ms step_avg:87.66ms +step:334/1670 train_time:29280ms step_avg:87.67ms +step:335/1670 train_time:29368ms step_avg:87.66ms +step:336/1670 train_time:29456ms step_avg:87.67ms +step:337/1670 train_time:29543ms step_avg:87.67ms +step:338/1670 train_time:29631ms step_avg:87.67ms +step:339/1670 train_time:29718ms step_avg:87.66ms +step:340/1670 train_time:29805ms step_avg:87.66ms +step:341/1670 train_time:29893ms step_avg:87.66ms +step:342/1670 train_time:29981ms step_avg:87.66ms +step:343/1670 train_time:30069ms step_avg:87.67ms +step:344/1670 train_time:30157ms step_avg:87.67ms +step:345/1670 train_time:30245ms step_avg:87.67ms +step:346/1670 train_time:30332ms step_avg:87.67ms +step:347/1670 train_time:30419ms step_avg:87.66ms +step:348/1670 train_time:30507ms step_avg:87.66ms +step:349/1670 train_time:30595ms step_avg:87.66ms +step:350/1670 train_time:30682ms 
step_avg:87.66ms +step:351/1670 train_time:30770ms step_avg:87.66ms +step:352/1670 train_time:30857ms step_avg:87.66ms +step:353/1670 train_time:30945ms step_avg:87.66ms +step:354/1670 train_time:31033ms step_avg:87.66ms +step:355/1670 train_time:31120ms step_avg:87.66ms +step:356/1670 train_time:31208ms step_avg:87.66ms +step:357/1670 train_time:31296ms step_avg:87.66ms +step:358/1670 train_time:31383ms step_avg:87.66ms +step:359/1670 train_time:31471ms step_avg:87.66ms +step:360/1670 train_time:31558ms step_avg:87.66ms +step:361/1670 train_time:31645ms step_avg:87.66ms +step:362/1670 train_time:31733ms step_avg:87.66ms +step:363/1670 train_time:31820ms step_avg:87.66ms +step:364/1670 train_time:31908ms step_avg:87.66ms +step:365/1670 train_time:31996ms step_avg:87.66ms +step:366/1670 train_time:32084ms step_avg:87.66ms +step:367/1670 train_time:32172ms step_avg:87.66ms +step:368/1670 train_time:32259ms step_avg:87.66ms +step:369/1670 train_time:32347ms step_avg:87.66ms +step:370/1670 train_time:32435ms step_avg:87.66ms +step:371/1670 train_time:32523ms step_avg:87.66ms +step:372/1670 train_time:32610ms step_avg:87.66ms +step:373/1670 train_time:32698ms step_avg:87.66ms +step:374/1670 train_time:32785ms step_avg:87.66ms +step:375/1670 train_time:32873ms step_avg:87.66ms +step:375/1670 val_loss:3.8215 train_time:32961ms step_avg:87.90ms +step:376/1670 train_time:32983ms step_avg:87.72ms +step:377/1670 train_time:33051ms step_avg:87.67ms +step:378/1670 train_time:33142ms step_avg:87.68ms +step:379/1670 train_time:33229ms step_avg:87.67ms +step:380/1670 train_time:33315ms step_avg:87.67ms +step:381/1670 train_time:33403ms step_avg:87.67ms +step:382/1670 train_time:33489ms step_avg:87.67ms +step:383/1670 train_time:33576ms step_avg:87.67ms +step:384/1670 train_time:33662ms step_avg:87.66ms +step:385/1670 train_time:33750ms step_avg:87.66ms +step:386/1670 train_time:33837ms step_avg:87.66ms +step:387/1670 train_time:33927ms step_avg:87.67ms +step:388/1670 train_time:34016ms step_avg:87.67ms +step:389/1670 train_time:34107ms step_avg:87.68ms +step:390/1670 train_time:34195ms step_avg:87.68ms +step:391/1670 train_time:34282ms step_avg:87.68ms +step:392/1670 train_time:34369ms step_avg:87.68ms +step:393/1670 train_time:34456ms step_avg:87.68ms +step:394/1670 train_time:34543ms step_avg:87.67ms +step:395/1670 train_time:34630ms step_avg:87.67ms +step:396/1670 train_time:34716ms step_avg:87.67ms +step:397/1670 train_time:34803ms step_avg:87.67ms +step:398/1670 train_time:34890ms step_avg:87.66ms +step:399/1670 train_time:34979ms step_avg:87.67ms +step:400/1670 train_time:35068ms step_avg:87.67ms +step:401/1670 train_time:35156ms step_avg:87.67ms +step:402/1670 train_time:35245ms step_avg:87.68ms +step:403/1670 train_time:35333ms step_avg:87.67ms +step:404/1670 train_time:35420ms step_avg:87.67ms +step:405/1670 train_time:35508ms step_avg:87.67ms +step:406/1670 train_time:35595ms step_avg:87.67ms +step:407/1670 train_time:35682ms step_avg:87.67ms +step:408/1670 train_time:35769ms step_avg:87.67ms +step:409/1670 train_time:35857ms step_avg:87.67ms +step:410/1670 train_time:35945ms step_avg:87.67ms +step:411/1670 train_time:36033ms step_avg:87.67ms +step:412/1670 train_time:36122ms step_avg:87.67ms +step:413/1670 train_time:36210ms step_avg:87.67ms +step:414/1670 train_time:36298ms step_avg:87.68ms +step:415/1670 train_time:36386ms step_avg:87.68ms +step:416/1670 train_time:36473ms step_avg:87.68ms +step:417/1670 train_time:36561ms step_avg:87.68ms +step:418/1670 train_time:36648ms step_avg:87.67ms 
+step:419/1670 train_time:36735ms step_avg:87.67ms +step:420/1670 train_time:36822ms step_avg:87.67ms +step:421/1670 train_time:36911ms step_avg:87.67ms +step:422/1670 train_time:36998ms step_avg:87.67ms +step:423/1670 train_time:37087ms step_avg:87.68ms +step:424/1670 train_time:37175ms step_avg:87.68ms +step:425/1670 train_time:37263ms step_avg:87.68ms +step:426/1670 train_time:37350ms step_avg:87.68ms +step:427/1670 train_time:37438ms step_avg:87.68ms +step:428/1670 train_time:37525ms step_avg:87.68ms +step:429/1670 train_time:37613ms step_avg:87.68ms +step:430/1670 train_time:37701ms step_avg:87.68ms +step:431/1670 train_time:37788ms step_avg:87.68ms +step:432/1670 train_time:37876ms step_avg:87.68ms +step:433/1670 train_time:37965ms step_avg:87.68ms +step:434/1670 train_time:38052ms step_avg:87.68ms +step:435/1670 train_time:38140ms step_avg:87.68ms +step:436/1670 train_time:38227ms step_avg:87.68ms +step:437/1670 train_time:38315ms step_avg:87.68ms +step:438/1670 train_time:38403ms step_avg:87.68ms +step:439/1670 train_time:38490ms step_avg:87.68ms +step:440/1670 train_time:38577ms step_avg:87.68ms +step:441/1670 train_time:38665ms step_avg:87.68ms +step:442/1670 train_time:38752ms step_avg:87.67ms +step:443/1670 train_time:38840ms step_avg:87.68ms +step:444/1670 train_time:38928ms step_avg:87.68ms +step:445/1670 train_time:39016ms step_avg:87.68ms +step:446/1670 train_time:39104ms step_avg:87.68ms +step:447/1670 train_time:39191ms step_avg:87.68ms +step:448/1670 train_time:39279ms step_avg:87.68ms +step:449/1670 train_time:39367ms step_avg:87.68ms +step:450/1670 train_time:39454ms step_avg:87.68ms +step:451/1670 train_time:39542ms step_avg:87.68ms +step:452/1670 train_time:39629ms step_avg:87.68ms +step:453/1670 train_time:39717ms step_avg:87.67ms +step:454/1670 train_time:39805ms step_avg:87.68ms +step:455/1670 train_time:39893ms step_avg:87.68ms +step:456/1670 train_time:39980ms step_avg:87.68ms +step:457/1670 train_time:40068ms step_avg:87.68ms +step:458/1670 train_time:40155ms step_avg:87.67ms +step:459/1670 train_time:40243ms step_avg:87.68ms +step:460/1670 train_time:40330ms step_avg:87.67ms +step:461/1670 train_time:40418ms step_avg:87.67ms +step:462/1670 train_time:40506ms step_avg:87.68ms +step:463/1670 train_time:40594ms step_avg:87.68ms +step:464/1670 train_time:40682ms step_avg:87.68ms +step:465/1670 train_time:40769ms step_avg:87.68ms +step:466/1670 train_time:40857ms step_avg:87.68ms +step:467/1670 train_time:40945ms step_avg:87.68ms +step:468/1670 train_time:41032ms step_avg:87.68ms +step:469/1670 train_time:41119ms step_avg:87.67ms +step:470/1670 train_time:41207ms step_avg:87.67ms +step:471/1670 train_time:41294ms step_avg:87.67ms +step:472/1670 train_time:41381ms step_avg:87.67ms +step:473/1670 train_time:41468ms step_avg:87.67ms +step:474/1670 train_time:41556ms step_avg:87.67ms +step:475/1670 train_time:41646ms step_avg:87.68ms +step:476/1670 train_time:41733ms step_avg:87.68ms +step:477/1670 train_time:41821ms step_avg:87.68ms +step:478/1670 train_time:41909ms step_avg:87.68ms +step:479/1670 train_time:41997ms step_avg:87.68ms +step:480/1670 train_time:42086ms step_avg:87.68ms +step:481/1670 train_time:42174ms step_avg:87.68ms +step:482/1670 train_time:42261ms step_avg:87.68ms +step:483/1670 train_time:42348ms step_avg:87.68ms +step:484/1670 train_time:42436ms step_avg:87.68ms +step:485/1670 train_time:42524ms step_avg:87.68ms +step:486/1670 train_time:42611ms step_avg:87.68ms +step:487/1670 train_time:42700ms step_avg:87.68ms +step:488/1670 train_time:42787ms 
step_avg:87.68ms +step:489/1670 train_time:42875ms step_avg:87.68ms +step:490/1670 train_time:42962ms step_avg:87.68ms +step:491/1670 train_time:43049ms step_avg:87.68ms +step:492/1670 train_time:43138ms step_avg:87.68ms +step:493/1670 train_time:43225ms step_avg:87.68ms +step:494/1670 train_time:43312ms step_avg:87.68ms +step:495/1670 train_time:43400ms step_avg:87.68ms +step:496/1670 train_time:43488ms step_avg:87.68ms +step:497/1670 train_time:43576ms step_avg:87.68ms +step:498/1670 train_time:43665ms step_avg:87.68ms +step:499/1670 train_time:43752ms step_avg:87.68ms +step:500/1670 train_time:43840ms step_avg:87.68ms +step:500/1670 val_loss:3.7174 train_time:43929ms step_avg:87.86ms +step:501/1670 train_time:43948ms step_avg:87.72ms +step:502/1670 train_time:44021ms step_avg:87.69ms +step:503/1670 train_time:44114ms step_avg:87.70ms +step:504/1670 train_time:44201ms step_avg:87.70ms +step:505/1670 train_time:44288ms step_avg:87.70ms +step:506/1670 train_time:44375ms step_avg:87.70ms +step:507/1670 train_time:44461ms step_avg:87.69ms +step:508/1670 train_time:44548ms step_avg:87.69ms +step:509/1670 train_time:44635ms step_avg:87.69ms +step:510/1670 train_time:44722ms step_avg:87.69ms +step:511/1670 train_time:44809ms step_avg:87.69ms +step:512/1670 train_time:44899ms step_avg:87.69ms +step:513/1670 train_time:44988ms step_avg:87.70ms +step:514/1670 train_time:45078ms step_avg:87.70ms +step:515/1670 train_time:45167ms step_avg:87.70ms +step:516/1670 train_time:45255ms step_avg:87.70ms +step:517/1670 train_time:45342ms step_avg:87.70ms +step:518/1670 train_time:45429ms step_avg:87.70ms +step:519/1670 train_time:45517ms step_avg:87.70ms +step:520/1670 train_time:45603ms step_avg:87.70ms +step:521/1670 train_time:45690ms step_avg:87.70ms +step:522/1670 train_time:45777ms step_avg:87.70ms +step:523/1670 train_time:45865ms step_avg:87.70ms +step:524/1670 train_time:45955ms step_avg:87.70ms +step:525/1670 train_time:46043ms step_avg:87.70ms +step:526/1670 train_time:46132ms step_avg:87.70ms +step:527/1670 train_time:46220ms step_avg:87.70ms +step:528/1670 train_time:46307ms step_avg:87.70ms +step:529/1670 train_time:46395ms step_avg:87.70ms +step:530/1670 train_time:46481ms step_avg:87.70ms +step:531/1670 train_time:46568ms step_avg:87.70ms +step:532/1670 train_time:46656ms step_avg:87.70ms +step:533/1670 train_time:46742ms step_avg:87.70ms +step:534/1670 train_time:46830ms step_avg:87.70ms +step:535/1670 train_time:46919ms step_avg:87.70ms +step:536/1670 train_time:47007ms step_avg:87.70ms +step:537/1670 train_time:47096ms step_avg:87.70ms +step:538/1670 train_time:47183ms step_avg:87.70ms +step:539/1670 train_time:47271ms step_avg:87.70ms +step:540/1670 train_time:47359ms step_avg:87.70ms +step:541/1670 train_time:47447ms step_avg:87.70ms +step:542/1670 train_time:47534ms step_avg:87.70ms +step:543/1670 train_time:47621ms step_avg:87.70ms +step:544/1670 train_time:47708ms step_avg:87.70ms +step:545/1670 train_time:47796ms step_avg:87.70ms +step:546/1670 train_time:47885ms step_avg:87.70ms +step:547/1670 train_time:47974ms step_avg:87.70ms +step:548/1670 train_time:48063ms step_avg:87.71ms +step:549/1670 train_time:48152ms step_avg:87.71ms +step:550/1670 train_time:48241ms step_avg:87.71ms +step:551/1670 train_time:48330ms step_avg:87.71ms +step:552/1670 train_time:48419ms step_avg:87.72ms +step:553/1670 train_time:48508ms step_avg:87.72ms +step:554/1670 train_time:48597ms step_avg:87.72ms +step:555/1670 train_time:48687ms step_avg:87.72ms +step:556/1670 train_time:48775ms step_avg:87.73ms 
+step:557/1670 train_time:48863ms step_avg:87.73ms +step:558/1670 train_time:48953ms step_avg:87.73ms +step:559/1670 train_time:49041ms step_avg:87.73ms +step:560/1670 train_time:49131ms step_avg:87.73ms +step:561/1670 train_time:49220ms step_avg:87.74ms +step:562/1670 train_time:49310ms step_avg:87.74ms +step:563/1670 train_time:49399ms step_avg:87.74ms +step:564/1670 train_time:49488ms step_avg:87.74ms +step:565/1670 train_time:49576ms step_avg:87.75ms +step:566/1670 train_time:49666ms step_avg:87.75ms +step:567/1670 train_time:49755ms step_avg:87.75ms +step:568/1670 train_time:49844ms step_avg:87.75ms +step:569/1670 train_time:49933ms step_avg:87.76ms +step:570/1670 train_time:50021ms step_avg:87.76ms +step:571/1670 train_time:50111ms step_avg:87.76ms +step:572/1670 train_time:50200ms step_avg:87.76ms +step:573/1670 train_time:50290ms step_avg:87.77ms +step:574/1670 train_time:50378ms step_avg:87.77ms +step:575/1670 train_time:50467ms step_avg:87.77ms +step:576/1670 train_time:50557ms step_avg:87.77ms +step:577/1670 train_time:50645ms step_avg:87.77ms +step:578/1670 train_time:50734ms step_avg:87.77ms +step:579/1670 train_time:50822ms step_avg:87.78ms +step:580/1670 train_time:50911ms step_avg:87.78ms +step:581/1670 train_time:51001ms step_avg:87.78ms +step:582/1670 train_time:51090ms step_avg:87.78ms +step:583/1670 train_time:51179ms step_avg:87.78ms +step:584/1670 train_time:51268ms step_avg:87.79ms +step:585/1670 train_time:51357ms step_avg:87.79ms +step:586/1670 train_time:51445ms step_avg:87.79ms +step:587/1670 train_time:51535ms step_avg:87.79ms +step:588/1670 train_time:51623ms step_avg:87.79ms +step:589/1670 train_time:51713ms step_avg:87.80ms +step:590/1670 train_time:51801ms step_avg:87.80ms +step:591/1670 train_time:51891ms step_avg:87.80ms +step:592/1670 train_time:51980ms step_avg:87.80ms +step:593/1670 train_time:52069ms step_avg:87.81ms +step:594/1670 train_time:52158ms step_avg:87.81ms +step:595/1670 train_time:52247ms step_avg:87.81ms +step:596/1670 train_time:52336ms step_avg:87.81ms +step:597/1670 train_time:52426ms step_avg:87.82ms +step:598/1670 train_time:52516ms step_avg:87.82ms +step:599/1670 train_time:52605ms step_avg:87.82ms +step:600/1670 train_time:52694ms step_avg:87.82ms +step:601/1670 train_time:52783ms step_avg:87.82ms +step:602/1670 train_time:52872ms step_avg:87.83ms +step:603/1670 train_time:52960ms step_avg:87.83ms +step:604/1670 train_time:53049ms step_avg:87.83ms +step:605/1670 train_time:53138ms step_avg:87.83ms +step:606/1670 train_time:53227ms step_avg:87.83ms +step:607/1670 train_time:53316ms step_avg:87.83ms +step:608/1670 train_time:53404ms step_avg:87.84ms +step:609/1670 train_time:53495ms step_avg:87.84ms +step:610/1670 train_time:53584ms step_avg:87.84ms +step:611/1670 train_time:53673ms step_avg:87.84ms +step:612/1670 train_time:53761ms step_avg:87.85ms +step:613/1670 train_time:53851ms step_avg:87.85ms +step:614/1670 train_time:53939ms step_avg:87.85ms +step:615/1670 train_time:54028ms step_avg:87.85ms +step:616/1670 train_time:54117ms step_avg:87.85ms +step:617/1670 train_time:54206ms step_avg:87.85ms +step:618/1670 train_time:54295ms step_avg:87.86ms +step:619/1670 train_time:54384ms step_avg:87.86ms +step:620/1670 train_time:54474ms step_avg:87.86ms +step:621/1670 train_time:54562ms step_avg:87.86ms +step:622/1670 train_time:54651ms step_avg:87.86ms +step:623/1670 train_time:54740ms step_avg:87.86ms +step:624/1670 train_time:54829ms step_avg:87.87ms +step:625/1670 train_time:54918ms step_avg:87.87ms +step:625/1670 val_loss:3.6176 
train_time:55009ms step_avg:88.01ms +step:626/1670 train_time:55029ms step_avg:87.91ms +step:627/1670 train_time:55099ms step_avg:87.88ms +step:628/1670 train_time:55188ms step_avg:87.88ms +step:629/1670 train_time:55277ms step_avg:87.88ms +step:630/1670 train_time:55366ms step_avg:87.88ms +step:631/1670 train_time:55454ms step_avg:87.88ms +step:632/1670 train_time:55541ms step_avg:87.88ms +step:633/1670 train_time:55629ms step_avg:87.88ms +step:634/1670 train_time:55716ms step_avg:87.88ms +step:635/1670 train_time:55804ms step_avg:87.88ms +step:636/1670 train_time:55896ms step_avg:87.89ms +step:637/1670 train_time:55991ms step_avg:87.90ms +step:638/1670 train_time:56081ms step_avg:87.90ms +step:639/1670 train_time:56172ms step_avg:87.91ms +step:640/1670 train_time:56260ms step_avg:87.91ms +step:641/1670 train_time:56350ms step_avg:87.91ms +step:642/1670 train_time:56438ms step_avg:87.91ms +step:643/1670 train_time:56525ms step_avg:87.91ms +step:644/1670 train_time:56613ms step_avg:87.91ms +step:645/1670 train_time:56701ms step_avg:87.91ms +step:646/1670 train_time:56790ms step_avg:87.91ms +step:647/1670 train_time:56880ms step_avg:87.91ms +step:648/1670 train_time:56972ms step_avg:87.92ms +step:649/1670 train_time:57063ms step_avg:87.92ms +step:650/1670 train_time:57153ms step_avg:87.93ms +step:651/1670 train_time:57243ms step_avg:87.93ms +step:652/1670 train_time:57332ms step_avg:87.93ms +step:653/1670 train_time:57421ms step_avg:87.93ms +step:654/1670 train_time:57509ms step_avg:87.93ms +step:655/1670 train_time:57597ms step_avg:87.93ms +step:656/1670 train_time:57685ms step_avg:87.93ms +step:657/1670 train_time:57773ms step_avg:87.93ms +step:658/1670 train_time:57863ms step_avg:87.94ms +step:659/1670 train_time:57953ms step_avg:87.94ms +step:660/1670 train_time:58042ms step_avg:87.94ms +step:661/1670 train_time:58132ms step_avg:87.95ms +step:662/1670 train_time:58222ms step_avg:87.95ms +step:663/1670 train_time:58310ms step_avg:87.95ms +step:664/1670 train_time:58399ms step_avg:87.95ms +step:665/1670 train_time:58487ms step_avg:87.95ms +step:666/1670 train_time:58575ms step_avg:87.95ms +step:667/1670 train_time:58664ms step_avg:87.95ms +step:668/1670 train_time:58753ms step_avg:87.95ms +step:669/1670 train_time:58842ms step_avg:87.96ms +step:670/1670 train_time:58932ms step_avg:87.96ms +step:671/1670 train_time:59020ms step_avg:87.96ms +step:672/1670 train_time:59110ms step_avg:87.96ms +step:673/1670 train_time:59199ms step_avg:87.96ms +step:674/1670 train_time:59287ms step_avg:87.96ms +step:675/1670 train_time:59375ms step_avg:87.96ms +step:676/1670 train_time:59464ms step_avg:87.97ms +step:677/1670 train_time:59553ms step_avg:87.97ms +step:678/1670 train_time:59641ms step_avg:87.97ms +step:679/1670 train_time:59729ms step_avg:87.97ms +step:680/1670 train_time:59818ms step_avg:87.97ms +step:681/1670 train_time:59907ms step_avg:87.97ms +step:682/1670 train_time:59996ms step_avg:87.97ms +step:683/1670 train_time:60085ms step_avg:87.97ms +step:684/1670 train_time:60173ms step_avg:87.97ms +step:685/1670 train_time:60262ms step_avg:87.97ms +step:686/1670 train_time:60352ms step_avg:87.98ms +step:687/1670 train_time:60442ms step_avg:87.98ms +step:688/1670 train_time:60531ms step_avg:87.98ms +step:689/1670 train_time:60619ms step_avg:87.98ms +step:690/1670 train_time:60709ms step_avg:87.98ms +step:691/1670 train_time:60797ms step_avg:87.98ms +step:692/1670 train_time:60886ms step_avg:87.99ms +step:693/1670 train_time:60975ms step_avg:87.99ms +step:694/1670 train_time:61065ms step_avg:87.99ms 
+step:695/1670 train_time:61155ms step_avg:87.99ms +step:696/1670 train_time:61244ms step_avg:87.99ms +step:697/1670 train_time:61333ms step_avg:88.00ms +step:698/1670 train_time:61422ms step_avg:88.00ms +step:699/1670 train_time:61511ms step_avg:88.00ms +step:700/1670 train_time:61600ms step_avg:88.00ms +step:701/1670 train_time:61689ms step_avg:88.00ms +step:702/1670 train_time:61777ms step_avg:88.00ms +step:703/1670 train_time:61867ms step_avg:88.00ms +step:704/1670 train_time:61956ms step_avg:88.01ms +step:705/1670 train_time:62045ms step_avg:88.01ms +step:706/1670 train_time:62134ms step_avg:88.01ms +step:707/1670 train_time:62223ms step_avg:88.01ms +step:708/1670 train_time:62312ms step_avg:88.01ms +step:709/1670 train_time:62401ms step_avg:88.01ms +step:710/1670 train_time:62491ms step_avg:88.01ms +step:711/1670 train_time:62579ms step_avg:88.02ms +step:712/1670 train_time:62668ms step_avg:88.02ms +step:713/1670 train_time:62757ms step_avg:88.02ms +step:714/1670 train_time:62846ms step_avg:88.02ms +step:715/1670 train_time:62934ms step_avg:88.02ms +step:716/1670 train_time:63023ms step_avg:88.02ms +step:717/1670 train_time:63111ms step_avg:88.02ms +step:718/1670 train_time:63200ms step_avg:88.02ms +step:719/1670 train_time:63290ms step_avg:88.02ms +step:720/1670 train_time:63378ms step_avg:88.02ms +step:721/1670 train_time:63467ms step_avg:88.03ms +step:722/1670 train_time:63555ms step_avg:88.03ms +step:723/1670 train_time:63644ms step_avg:88.03ms +step:724/1670 train_time:63733ms step_avg:88.03ms +step:725/1670 train_time:63822ms step_avg:88.03ms +step:726/1670 train_time:63912ms step_avg:88.03ms +step:727/1670 train_time:64001ms step_avg:88.03ms +step:728/1670 train_time:64090ms step_avg:88.04ms +step:729/1670 train_time:64178ms step_avg:88.04ms +step:730/1670 train_time:64268ms step_avg:88.04ms +step:731/1670 train_time:64356ms step_avg:88.04ms +step:732/1670 train_time:64445ms step_avg:88.04ms +step:733/1670 train_time:64534ms step_avg:88.04ms +step:734/1670 train_time:64623ms step_avg:88.04ms +step:735/1670 train_time:64711ms step_avg:88.04ms +step:736/1670 train_time:64800ms step_avg:88.04ms +step:737/1670 train_time:64889ms step_avg:88.05ms +step:738/1670 train_time:64978ms step_avg:88.05ms +step:739/1670 train_time:65068ms step_avg:88.05ms +step:740/1670 train_time:65157ms step_avg:88.05ms +step:741/1670 train_time:65247ms step_avg:88.05ms +step:742/1670 train_time:65335ms step_avg:88.05ms +step:743/1670 train_time:65424ms step_avg:88.05ms +step:744/1670 train_time:65513ms step_avg:88.06ms +step:745/1670 train_time:65602ms step_avg:88.06ms +step:746/1670 train_time:65693ms step_avg:88.06ms +step:747/1670 train_time:65782ms step_avg:88.06ms +step:748/1670 train_time:65871ms step_avg:88.06ms +step:749/1670 train_time:65959ms step_avg:88.06ms +step:750/1670 train_time:66049ms step_avg:88.07ms +step:750/1670 val_loss:3.5665 train_time:66139ms step_avg:88.19ms +step:751/1670 train_time:66159ms step_avg:88.09ms +step:752/1670 train_time:66233ms step_avg:88.08ms +step:753/1670 train_time:66323ms step_avg:88.08ms +step:754/1670 train_time:66414ms step_avg:88.08ms +step:755/1670 train_time:66502ms step_avg:88.08ms +step:756/1670 train_time:66591ms step_avg:88.08ms +step:757/1670 train_time:66680ms step_avg:88.08ms +step:758/1670 train_time:66768ms step_avg:88.08ms +step:759/1670 train_time:66856ms step_avg:88.08ms +step:760/1670 train_time:66943ms step_avg:88.08ms +step:761/1670 train_time:67032ms step_avg:88.08ms +step:762/1670 train_time:67122ms step_avg:88.09ms +step:763/1670 
train_time:67213ms step_avg:88.09ms +step:764/1670 train_time:67304ms step_avg:88.09ms +step:765/1670 train_time:67393ms step_avg:88.10ms +step:766/1670 train_time:67482ms step_avg:88.10ms +step:767/1670 train_time:67573ms step_avg:88.10ms +step:768/1670 train_time:67660ms step_avg:88.10ms +step:769/1670 train_time:67749ms step_avg:88.10ms +step:770/1670 train_time:67837ms step_avg:88.10ms +step:771/1670 train_time:67925ms step_avg:88.10ms +step:772/1670 train_time:68014ms step_avg:88.10ms +step:773/1670 train_time:68103ms step_avg:88.10ms +step:774/1670 train_time:68194ms step_avg:88.11ms +step:775/1670 train_time:68283ms step_avg:88.11ms +step:776/1670 train_time:68374ms step_avg:88.11ms +step:777/1670 train_time:68463ms step_avg:88.11ms +step:778/1670 train_time:68553ms step_avg:88.11ms +step:779/1670 train_time:68642ms step_avg:88.12ms +step:780/1670 train_time:68730ms step_avg:88.12ms +step:781/1670 train_time:68818ms step_avg:88.12ms +step:782/1670 train_time:68906ms step_avg:88.12ms +step:783/1670 train_time:68996ms step_avg:88.12ms +step:784/1670 train_time:69084ms step_avg:88.12ms +step:785/1670 train_time:69174ms step_avg:88.12ms +step:786/1670 train_time:69263ms step_avg:88.12ms +step:787/1670 train_time:69353ms step_avg:88.12ms +step:788/1670 train_time:69441ms step_avg:88.12ms +step:789/1670 train_time:69531ms step_avg:88.13ms +step:790/1670 train_time:69619ms step_avg:88.13ms +step:791/1670 train_time:69708ms step_avg:88.13ms +step:792/1670 train_time:69796ms step_avg:88.13ms +step:793/1670 train_time:69884ms step_avg:88.13ms +step:794/1670 train_time:69974ms step_avg:88.13ms +step:795/1670 train_time:70063ms step_avg:88.13ms +step:796/1670 train_time:70153ms step_avg:88.13ms +step:797/1670 train_time:70241ms step_avg:88.13ms +step:798/1670 train_time:70330ms step_avg:88.13ms +step:799/1670 train_time:70419ms step_avg:88.13ms +step:800/1670 train_time:70510ms step_avg:88.14ms +step:801/1670 train_time:70598ms step_avg:88.14ms +step:802/1670 train_time:70687ms step_avg:88.14ms +step:803/1670 train_time:70775ms step_avg:88.14ms +step:804/1670 train_time:70863ms step_avg:88.14ms +step:805/1670 train_time:70952ms step_avg:88.14ms +step:806/1670 train_time:71041ms step_avg:88.14ms +step:807/1670 train_time:71131ms step_avg:88.14ms +step:808/1670 train_time:71219ms step_avg:88.14ms +step:809/1670 train_time:71308ms step_avg:88.14ms +step:810/1670 train_time:71398ms step_avg:88.15ms +step:811/1670 train_time:71487ms step_avg:88.15ms +step:812/1670 train_time:71576ms step_avg:88.15ms +step:813/1670 train_time:71665ms step_avg:88.15ms +step:814/1670 train_time:71754ms step_avg:88.15ms +step:815/1670 train_time:71842ms step_avg:88.15ms +step:816/1670 train_time:71931ms step_avg:88.15ms +step:817/1670 train_time:72019ms step_avg:88.15ms +step:818/1670 train_time:72108ms step_avg:88.15ms +step:819/1670 train_time:72197ms step_avg:88.15ms +step:820/1670 train_time:72286ms step_avg:88.15ms +step:821/1670 train_time:72375ms step_avg:88.15ms +step:822/1670 train_time:72464ms step_avg:88.16ms +step:823/1670 train_time:72553ms step_avg:88.16ms +step:824/1670 train_time:72641ms step_avg:88.16ms +step:825/1670 train_time:72731ms step_avg:88.16ms +step:826/1670 train_time:72819ms step_avg:88.16ms +step:827/1670 train_time:72908ms step_avg:88.16ms +step:828/1670 train_time:72997ms step_avg:88.16ms +step:829/1670 train_time:73086ms step_avg:88.16ms +step:830/1670 train_time:73174ms step_avg:88.16ms +step:831/1670 train_time:73263ms step_avg:88.16ms +step:832/1670 train_time:73352ms step_avg:88.16ms 
+step:833/1670 train_time:73440ms step_avg:88.16ms +step:834/1670 train_time:73530ms step_avg:88.17ms +step:835/1670 train_time:73618ms step_avg:88.17ms +step:836/1670 train_time:73708ms step_avg:88.17ms +step:837/1670 train_time:73797ms step_avg:88.17ms +step:838/1670 train_time:73886ms step_avg:88.17ms +step:839/1670 train_time:73975ms step_avg:88.17ms +step:840/1670 train_time:74063ms step_avg:88.17ms +step:841/1670 train_time:74153ms step_avg:88.17ms +step:842/1670 train_time:74241ms step_avg:88.17ms +step:843/1670 train_time:74331ms step_avg:88.17ms +step:844/1670 train_time:74419ms step_avg:88.17ms +step:845/1670 train_time:74508ms step_avg:88.18ms +step:846/1670 train_time:74597ms step_avg:88.18ms +step:847/1670 train_time:74686ms step_avg:88.18ms +step:848/1670 train_time:74775ms step_avg:88.18ms +step:849/1670 train_time:74864ms step_avg:88.18ms +step:850/1670 train_time:74953ms step_avg:88.18ms +step:851/1670 train_time:75042ms step_avg:88.18ms +step:852/1670 train_time:75131ms step_avg:88.18ms +step:853/1670 train_time:75220ms step_avg:88.18ms +step:854/1670 train_time:75309ms step_avg:88.18ms +step:855/1670 train_time:75397ms step_avg:88.18ms +step:856/1670 train_time:75486ms step_avg:88.18ms +step:857/1670 train_time:75576ms step_avg:88.19ms +step:858/1670 train_time:75666ms step_avg:88.19ms +step:859/1670 train_time:75754ms step_avg:88.19ms +step:860/1670 train_time:75843ms step_avg:88.19ms +step:861/1670 train_time:75932ms step_avg:88.19ms +step:862/1670 train_time:76021ms step_avg:88.19ms +step:863/1670 train_time:76111ms step_avg:88.19ms +step:864/1670 train_time:76200ms step_avg:88.19ms +step:865/1670 train_time:76290ms step_avg:88.20ms +step:866/1670 train_time:76379ms step_avg:88.20ms +step:867/1670 train_time:76468ms step_avg:88.20ms +step:868/1670 train_time:76556ms step_avg:88.20ms +step:869/1670 train_time:76645ms step_avg:88.20ms +step:870/1670 train_time:76735ms step_avg:88.20ms +step:871/1670 train_time:76823ms step_avg:88.20ms +step:872/1670 train_time:76912ms step_avg:88.20ms +step:873/1670 train_time:77000ms step_avg:88.20ms +step:874/1670 train_time:77090ms step_avg:88.20ms +step:875/1670 train_time:77178ms step_avg:88.20ms +step:875/1670 val_loss:3.5204 train_time:77269ms step_avg:88.31ms +step:876/1670 train_time:77288ms step_avg:88.23ms +step:877/1670 train_time:77359ms step_avg:88.21ms +step:878/1670 train_time:77451ms step_avg:88.21ms +step:879/1670 train_time:77541ms step_avg:88.21ms +step:880/1670 train_time:77629ms step_avg:88.21ms +step:881/1670 train_time:77716ms step_avg:88.21ms +step:882/1670 train_time:77804ms step_avg:88.21ms +step:883/1670 train_time:77892ms step_avg:88.21ms +step:884/1670 train_time:77980ms step_avg:88.21ms +step:885/1670 train_time:78068ms step_avg:88.21ms +step:886/1670 train_time:78156ms step_avg:88.21ms +step:887/1670 train_time:78247ms step_avg:88.21ms +step:888/1670 train_time:78338ms step_avg:88.22ms +step:889/1670 train_time:78430ms step_avg:88.22ms +step:890/1670 train_time:78520ms step_avg:88.22ms +step:891/1670 train_time:78608ms step_avg:88.22ms +step:892/1670 train_time:78697ms step_avg:88.23ms +step:893/1670 train_time:78785ms step_avg:88.23ms +step:894/1670 train_time:78874ms step_avg:88.23ms +step:895/1670 train_time:78962ms step_avg:88.23ms +step:896/1670 train_time:79050ms step_avg:88.23ms +step:897/1670 train_time:79138ms step_avg:88.22ms +step:898/1670 train_time:79227ms step_avg:88.23ms +step:899/1670 train_time:79316ms step_avg:88.23ms +step:900/1670 train_time:79406ms step_avg:88.23ms +step:901/1670 
train_time:79496ms step_avg:88.23ms +step:902/1670 train_time:79585ms step_avg:88.23ms +step:903/1670 train_time:79674ms step_avg:88.23ms +step:904/1670 train_time:79762ms step_avg:88.23ms +step:905/1670 train_time:79851ms step_avg:88.23ms +step:906/1670 train_time:79940ms step_avg:88.23ms +step:907/1670 train_time:80029ms step_avg:88.23ms +step:908/1670 train_time:80117ms step_avg:88.23ms +step:909/1670 train_time:80206ms step_avg:88.24ms +step:910/1670 train_time:80296ms step_avg:88.24ms +step:911/1670 train_time:80385ms step_avg:88.24ms +step:912/1670 train_time:80475ms step_avg:88.24ms +step:913/1670 train_time:80564ms step_avg:88.24ms +step:914/1670 train_time:80654ms step_avg:88.24ms +step:915/1670 train_time:80743ms step_avg:88.24ms +step:916/1670 train_time:80832ms step_avg:88.24ms +step:917/1670 train_time:80920ms step_avg:88.24ms +step:918/1670 train_time:81009ms step_avg:88.24ms +step:919/1670 train_time:81097ms step_avg:88.25ms +step:920/1670 train_time:81186ms step_avg:88.25ms +step:921/1670 train_time:81276ms step_avg:88.25ms +step:922/1670 train_time:81366ms step_avg:88.25ms +step:923/1670 train_time:81456ms step_avg:88.25ms +step:924/1670 train_time:81546ms step_avg:88.25ms +step:925/1670 train_time:81636ms step_avg:88.25ms +step:926/1670 train_time:81724ms step_avg:88.25ms +step:927/1670 train_time:81813ms step_avg:88.26ms +step:928/1670 train_time:81901ms step_avg:88.26ms +step:929/1670 train_time:81990ms step_avg:88.26ms +step:930/1670 train_time:82078ms step_avg:88.26ms +step:931/1670 train_time:82167ms step_avg:88.26ms +step:932/1670 train_time:82256ms step_avg:88.26ms +step:933/1670 train_time:82345ms step_avg:88.26ms +step:934/1670 train_time:82435ms step_avg:88.26ms +step:935/1670 train_time:82525ms step_avg:88.26ms +step:936/1670 train_time:82614ms step_avg:88.26ms +step:937/1670 train_time:82702ms step_avg:88.26ms +step:938/1670 train_time:82792ms step_avg:88.26ms +step:939/1670 train_time:82880ms step_avg:88.26ms +step:940/1670 train_time:82969ms step_avg:88.27ms +step:941/1670 train_time:83058ms step_avg:88.27ms +step:942/1670 train_time:83147ms step_avg:88.27ms +step:943/1670 train_time:83237ms step_avg:88.27ms +step:944/1670 train_time:83326ms step_avg:88.27ms +step:945/1670 train_time:83415ms step_avg:88.27ms +step:946/1670 train_time:83504ms step_avg:88.27ms +step:947/1670 train_time:83594ms step_avg:88.27ms +step:948/1670 train_time:83682ms step_avg:88.27ms +step:949/1670 train_time:83771ms step_avg:88.27ms +step:950/1670 train_time:83860ms step_avg:88.27ms +step:951/1670 train_time:83950ms step_avg:88.28ms +step:952/1670 train_time:84039ms step_avg:88.28ms +step:953/1670 train_time:84128ms step_avg:88.28ms +step:954/1670 train_time:84217ms step_avg:88.28ms +step:955/1670 train_time:84306ms step_avg:88.28ms +step:956/1670 train_time:84395ms step_avg:88.28ms +step:957/1670 train_time:84484ms step_avg:88.28ms +step:958/1670 train_time:84573ms step_avg:88.28ms +step:959/1670 train_time:84661ms step_avg:88.28ms +step:960/1670 train_time:84750ms step_avg:88.28ms +step:961/1670 train_time:84839ms step_avg:88.28ms +step:962/1670 train_time:84929ms step_avg:88.28ms +step:963/1670 train_time:85017ms step_avg:88.28ms +step:964/1670 train_time:85106ms step_avg:88.28ms +step:965/1670 train_time:85196ms step_avg:88.29ms +step:966/1670 train_time:85284ms step_avg:88.29ms +step:967/1670 train_time:85373ms step_avg:88.29ms +step:968/1670 train_time:85462ms step_avg:88.29ms +step:969/1670 train_time:85551ms step_avg:88.29ms +step:970/1670 train_time:85640ms step_avg:88.29ms 
+step:971/1670 train_time:85729ms step_avg:88.29ms +step:972/1670 train_time:85818ms step_avg:88.29ms +step:973/1670 train_time:85907ms step_avg:88.29ms +step:974/1670 train_time:85996ms step_avg:88.29ms +step:975/1670 train_time:86084ms step_avg:88.29ms +step:976/1670 train_time:86173ms step_avg:88.29ms +step:977/1670 train_time:86261ms step_avg:88.29ms +step:978/1670 train_time:86351ms step_avg:88.29ms +step:979/1670 train_time:86440ms step_avg:88.29ms +step:980/1670 train_time:86529ms step_avg:88.29ms +step:981/1670 train_time:86618ms step_avg:88.30ms +step:982/1670 train_time:86707ms step_avg:88.30ms +step:983/1670 train_time:86796ms step_avg:88.30ms +step:984/1670 train_time:86884ms step_avg:88.30ms +step:985/1670 train_time:86974ms step_avg:88.30ms +step:986/1670 train_time:87062ms step_avg:88.30ms +step:987/1670 train_time:87151ms step_avg:88.30ms +step:988/1670 train_time:87240ms step_avg:88.30ms +step:989/1670 train_time:87329ms step_avg:88.30ms +step:990/1670 train_time:87419ms step_avg:88.30ms +step:991/1670 train_time:87508ms step_avg:88.30ms +step:992/1670 train_time:87597ms step_avg:88.30ms +step:993/1670 train_time:87686ms step_avg:88.30ms +step:994/1670 train_time:87775ms step_avg:88.31ms +step:995/1670 train_time:87864ms step_avg:88.31ms +step:996/1670 train_time:87953ms step_avg:88.31ms +step:997/1670 train_time:88041ms step_avg:88.31ms +step:998/1670 train_time:88130ms step_avg:88.31ms +step:999/1670 train_time:88219ms step_avg:88.31ms +step:1000/1670 train_time:88308ms step_avg:88.31ms +step:1000/1670 val_loss:3.4692 train_time:88399ms step_avg:88.40ms +step:1001/1670 train_time:88420ms step_avg:88.33ms +step:1002/1670 train_time:88492ms step_avg:88.32ms +step:1003/1670 train_time:88583ms step_avg:88.32ms +step:1004/1670 train_time:88672ms step_avg:88.32ms +step:1005/1670 train_time:88760ms step_avg:88.32ms +step:1006/1670 train_time:88847ms step_avg:88.32ms +step:1007/1670 train_time:88935ms step_avg:88.32ms +step:1008/1670 train_time:89024ms step_avg:88.32ms +step:1009/1670 train_time:89111ms step_avg:88.32ms +step:1010/1670 train_time:89200ms step_avg:88.32ms +step:1011/1670 train_time:89288ms step_avg:88.32ms +step:1012/1670 train_time:89379ms step_avg:88.32ms +step:1013/1670 train_time:89472ms step_avg:88.32ms +step:1014/1670 train_time:89562ms step_avg:88.33ms +step:1015/1670 train_time:89651ms step_avg:88.33ms +step:1016/1670 train_time:89740ms step_avg:88.33ms +step:1017/1670 train_time:89829ms step_avg:88.33ms +step:1018/1670 train_time:89916ms step_avg:88.33ms +step:1019/1670 train_time:90004ms step_avg:88.33ms +step:1020/1670 train_time:90093ms step_avg:88.33ms +step:1021/1670 train_time:90182ms step_avg:88.33ms +step:1022/1670 train_time:90271ms step_avg:88.33ms +step:1023/1670 train_time:90361ms step_avg:88.33ms +step:1024/1670 train_time:90452ms step_avg:88.33ms +step:1025/1670 train_time:90542ms step_avg:88.33ms +step:1026/1670 train_time:90631ms step_avg:88.33ms +step:1027/1670 train_time:90721ms step_avg:88.34ms +step:1028/1670 train_time:90809ms step_avg:88.34ms +step:1029/1670 train_time:90898ms step_avg:88.34ms +step:1030/1670 train_time:90986ms step_avg:88.34ms +step:1031/1670 train_time:91074ms step_avg:88.34ms +step:1032/1670 train_time:91162ms step_avg:88.34ms +step:1033/1670 train_time:91251ms step_avg:88.34ms +step:1034/1670 train_time:91341ms step_avg:88.34ms +step:1035/1670 train_time:91431ms step_avg:88.34ms +step:1036/1670 train_time:91521ms step_avg:88.34ms +step:1037/1670 train_time:91610ms step_avg:88.34ms +step:1038/1670 
train_time:91700ms step_avg:88.34ms +step:1039/1670 train_time:91788ms step_avg:88.34ms +step:1040/1670 train_time:91878ms step_avg:88.34ms +step:1041/1670 train_time:91965ms step_avg:88.34ms +step:1042/1670 train_time:92054ms step_avg:88.34ms +step:1043/1670 train_time:92142ms step_avg:88.34ms +step:1044/1670 train_time:92231ms step_avg:88.34ms +step:1045/1670 train_time:92320ms step_avg:88.34ms +step:1046/1670 train_time:92409ms step_avg:88.35ms +step:1047/1670 train_time:92500ms step_avg:88.35ms +step:1048/1670 train_time:92589ms step_avg:88.35ms +step:1049/1670 train_time:92679ms step_avg:88.35ms +step:1050/1670 train_time:92767ms step_avg:88.35ms +step:1051/1670 train_time:92856ms step_avg:88.35ms +step:1052/1670 train_time:92944ms step_avg:88.35ms +step:1053/1670 train_time:93033ms step_avg:88.35ms +step:1054/1670 train_time:93122ms step_avg:88.35ms +step:1055/1670 train_time:93211ms step_avg:88.35ms +step:1056/1670 train_time:93300ms step_avg:88.35ms +step:1057/1670 train_time:93389ms step_avg:88.35ms +step:1058/1670 train_time:93478ms step_avg:88.35ms +step:1059/1670 train_time:93567ms step_avg:88.35ms +step:1060/1670 train_time:93656ms step_avg:88.36ms +step:1061/1670 train_time:93745ms step_avg:88.36ms +step:1062/1670 train_time:93833ms step_avg:88.36ms +step:1063/1670 train_time:93923ms step_avg:88.36ms +step:1064/1670 train_time:94013ms step_avg:88.36ms +step:1065/1670 train_time:94102ms step_avg:88.36ms +step:1066/1670 train_time:94191ms step_avg:88.36ms +step:1067/1670 train_time:94281ms step_avg:88.36ms +step:1068/1670 train_time:94370ms step_avg:88.36ms +step:1069/1670 train_time:94459ms step_avg:88.36ms +step:1070/1670 train_time:94548ms step_avg:88.36ms +step:1071/1670 train_time:94637ms step_avg:88.36ms +step:1072/1670 train_time:94726ms step_avg:88.36ms +step:1073/1670 train_time:94815ms step_avg:88.36ms +step:1074/1670 train_time:94904ms step_avg:88.36ms +step:1075/1670 train_time:94993ms step_avg:88.37ms +step:1076/1670 train_time:95083ms step_avg:88.37ms +step:1077/1670 train_time:95171ms step_avg:88.37ms +step:1078/1670 train_time:95261ms step_avg:88.37ms +step:1079/1670 train_time:95350ms step_avg:88.37ms +step:1080/1670 train_time:95439ms step_avg:88.37ms +step:1081/1670 train_time:95527ms step_avg:88.37ms +step:1082/1670 train_time:95616ms step_avg:88.37ms +step:1083/1670 train_time:95705ms step_avg:88.37ms +step:1084/1670 train_time:95794ms step_avg:88.37ms +step:1085/1670 train_time:95883ms step_avg:88.37ms +step:1086/1670 train_time:95972ms step_avg:88.37ms +step:1087/1670 train_time:96061ms step_avg:88.37ms +step:1088/1670 train_time:96150ms step_avg:88.37ms +step:1089/1670 train_time:96240ms step_avg:88.37ms +step:1090/1670 train_time:96330ms step_avg:88.38ms +step:1091/1670 train_time:96421ms step_avg:88.38ms +step:1092/1670 train_time:96511ms step_avg:88.38ms +step:1093/1670 train_time:96601ms step_avg:88.38ms +step:1094/1670 train_time:96691ms step_avg:88.38ms +step:1095/1670 train_time:96781ms step_avg:88.38ms +step:1096/1670 train_time:96871ms step_avg:88.39ms +step:1097/1670 train_time:96960ms step_avg:88.39ms +step:1098/1670 train_time:97049ms step_avg:88.39ms +step:1099/1670 train_time:97139ms step_avg:88.39ms +step:1100/1670 train_time:97228ms step_avg:88.39ms +step:1101/1670 train_time:97318ms step_avg:88.39ms +step:1102/1670 train_time:97407ms step_avg:88.39ms +step:1103/1670 train_time:97498ms step_avg:88.39ms +step:1104/1670 train_time:97588ms step_avg:88.39ms +step:1105/1670 train_time:97677ms step_avg:88.40ms +step:1106/1670 train_time:97767ms 
step_avg:88.40ms +step:1107/1670 train_time:97856ms step_avg:88.40ms +step:1108/1670 train_time:97946ms step_avg:88.40ms +step:1109/1670 train_time:98036ms step_avg:88.40ms +step:1110/1670 train_time:98125ms step_avg:88.40ms +step:1111/1670 train_time:98215ms step_avg:88.40ms +step:1112/1670 train_time:98304ms step_avg:88.40ms +step:1113/1670 train_time:98394ms step_avg:88.40ms +step:1114/1670 train_time:98484ms step_avg:88.41ms +step:1115/1670 train_time:98574ms step_avg:88.41ms +step:1116/1670 train_time:98663ms step_avg:88.41ms +step:1117/1670 train_time:98753ms step_avg:88.41ms +step:1118/1670 train_time:98842ms step_avg:88.41ms +step:1119/1670 train_time:98931ms step_avg:88.41ms +step:1120/1670 train_time:99021ms step_avg:88.41ms +step:1121/1670 train_time:99111ms step_avg:88.41ms +step:1122/1670 train_time:99201ms step_avg:88.41ms +step:1123/1670 train_time:99291ms step_avg:88.42ms +step:1124/1670 train_time:99381ms step_avg:88.42ms +step:1125/1670 train_time:99471ms step_avg:88.42ms +step:1125/1670 val_loss:3.4156 train_time:99562ms step_avg:88.50ms +step:1126/1670 train_time:99582ms step_avg:88.44ms +step:1127/1670 train_time:99656ms step_avg:88.43ms +step:1128/1670 train_time:99746ms step_avg:88.43ms +step:1129/1670 train_time:99838ms step_avg:88.43ms +step:1130/1670 train_time:99927ms step_avg:88.43ms +step:1131/1670 train_time:100015ms step_avg:88.43ms +step:1132/1670 train_time:100104ms step_avg:88.43ms +step:1133/1670 train_time:100191ms step_avg:88.43ms +step:1134/1670 train_time:100279ms step_avg:88.43ms +step:1135/1670 train_time:100368ms step_avg:88.43ms +step:1136/1670 train_time:100463ms step_avg:88.44ms +step:1137/1670 train_time:100556ms step_avg:88.44ms +step:1138/1670 train_time:100648ms step_avg:88.44ms +step:1139/1670 train_time:100738ms step_avg:88.44ms +step:1140/1670 train_time:100828ms step_avg:88.45ms +step:1141/1670 train_time:100917ms step_avg:88.45ms +step:1142/1670 train_time:101005ms step_avg:88.45ms +step:1143/1670 train_time:101094ms step_avg:88.45ms +step:1144/1670 train_time:101182ms step_avg:88.45ms +step:1145/1670 train_time:101270ms step_avg:88.45ms +step:1146/1670 train_time:101360ms step_avg:88.45ms +step:1147/1670 train_time:101452ms step_avg:88.45ms +step:1148/1670 train_time:101544ms step_avg:88.45ms +step:1149/1670 train_time:101634ms step_avg:88.45ms +step:1150/1670 train_time:101725ms step_avg:88.46ms +step:1151/1670 train_time:101815ms step_avg:88.46ms +step:1152/1670 train_time:101905ms step_avg:88.46ms +step:1153/1670 train_time:101993ms step_avg:88.46ms +step:1154/1670 train_time:102082ms step_avg:88.46ms +step:1155/1670 train_time:102170ms step_avg:88.46ms +step:1156/1670 train_time:102258ms step_avg:88.46ms +step:1157/1670 train_time:102348ms step_avg:88.46ms +step:1158/1670 train_time:102440ms step_avg:88.46ms +step:1159/1670 train_time:102529ms step_avg:88.46ms +step:1160/1670 train_time:102620ms step_avg:88.47ms +step:1161/1670 train_time:102711ms step_avg:88.47ms +step:1162/1670 train_time:102802ms step_avg:88.47ms +step:1163/1670 train_time:102890ms step_avg:88.47ms +step:1164/1670 train_time:102979ms step_avg:88.47ms +step:1165/1670 train_time:103069ms step_avg:88.47ms +step:1166/1670 train_time:103158ms step_avg:88.47ms +step:1167/1670 train_time:103246ms step_avg:88.47ms +step:1168/1670 train_time:103335ms step_avg:88.47ms +step:1169/1670 train_time:103426ms step_avg:88.47ms +step:1170/1670 train_time:103515ms step_avg:88.47ms +step:1171/1670 train_time:103607ms step_avg:88.48ms +step:1172/1670 train_time:103698ms 
step_avg:88.48ms +step:1173/1670 train_time:103788ms step_avg:88.48ms +step:1174/1670 train_time:103878ms step_avg:88.48ms +step:1175/1670 train_time:103968ms step_avg:88.48ms +step:1176/1670 train_time:104059ms step_avg:88.49ms +step:1177/1670 train_time:104147ms step_avg:88.49ms +step:1178/1670 train_time:104236ms step_avg:88.49ms +step:1179/1670 train_time:104325ms step_avg:88.49ms +step:1180/1670 train_time:104415ms step_avg:88.49ms +step:1181/1670 train_time:104506ms step_avg:88.49ms +step:1182/1670 train_time:104596ms step_avg:88.49ms +step:1183/1670 train_time:104687ms step_avg:88.49ms +step:1184/1670 train_time:104777ms step_avg:88.49ms +step:1185/1670 train_time:104867ms step_avg:88.50ms +step:1186/1670 train_time:104956ms step_avg:88.50ms +step:1187/1670 train_time:105046ms step_avg:88.50ms +step:1188/1670 train_time:105135ms step_avg:88.50ms +step:1189/1670 train_time:105225ms step_avg:88.50ms +step:1190/1670 train_time:105314ms step_avg:88.50ms +step:1191/1670 train_time:105404ms step_avg:88.50ms +step:1192/1670 train_time:105494ms step_avg:88.50ms +step:1193/1670 train_time:105585ms step_avg:88.50ms +step:1194/1670 train_time:105676ms step_avg:88.51ms +step:1195/1670 train_time:105766ms step_avg:88.51ms +step:1196/1670 train_time:105855ms step_avg:88.51ms +step:1197/1670 train_time:105945ms step_avg:88.51ms +step:1198/1670 train_time:106035ms step_avg:88.51ms +step:1199/1670 train_time:106124ms step_avg:88.51ms +step:1200/1670 train_time:106213ms step_avg:88.51ms +step:1201/1670 train_time:106303ms step_avg:88.51ms +step:1202/1670 train_time:106392ms step_avg:88.51ms +step:1203/1670 train_time:106481ms step_avg:88.51ms +step:1204/1670 train_time:106570ms step_avg:88.51ms +step:1205/1670 train_time:106661ms step_avg:88.51ms +step:1206/1670 train_time:106751ms step_avg:88.52ms +step:1207/1670 train_time:106840ms step_avg:88.52ms +step:1208/1670 train_time:106930ms step_avg:88.52ms +step:1209/1670 train_time:107020ms step_avg:88.52ms +step:1210/1670 train_time:107110ms step_avg:88.52ms +step:1211/1670 train_time:107199ms step_avg:88.52ms +step:1212/1670 train_time:107289ms step_avg:88.52ms +step:1213/1670 train_time:107379ms step_avg:88.52ms +step:1214/1670 train_time:107469ms step_avg:88.52ms +step:1215/1670 train_time:107559ms step_avg:88.53ms +step:1216/1670 train_time:107649ms step_avg:88.53ms +step:1217/1670 train_time:107739ms step_avg:88.53ms +step:1218/1670 train_time:107829ms step_avg:88.53ms +step:1219/1670 train_time:107918ms step_avg:88.53ms +step:1220/1670 train_time:108008ms step_avg:88.53ms +step:1221/1670 train_time:108098ms step_avg:88.53ms +step:1222/1670 train_time:108188ms step_avg:88.53ms +step:1223/1670 train_time:108278ms step_avg:88.53ms +step:1224/1670 train_time:108367ms step_avg:88.54ms +step:1225/1670 train_time:108457ms step_avg:88.54ms +step:1226/1670 train_time:108547ms step_avg:88.54ms +step:1227/1670 train_time:108637ms step_avg:88.54ms +step:1228/1670 train_time:108727ms step_avg:88.54ms +step:1229/1670 train_time:108816ms step_avg:88.54ms +step:1230/1670 train_time:108906ms step_avg:88.54ms +step:1231/1670 train_time:108995ms step_avg:88.54ms +step:1232/1670 train_time:109085ms step_avg:88.54ms +step:1233/1670 train_time:109174ms step_avg:88.54ms +step:1234/1670 train_time:109264ms step_avg:88.54ms +step:1235/1670 train_time:109353ms step_avg:88.54ms +step:1236/1670 train_time:109443ms step_avg:88.55ms +step:1237/1670 train_time:109533ms step_avg:88.55ms +step:1238/1670 train_time:109622ms step_avg:88.55ms +step:1239/1670 train_time:109711ms 
step_avg:88.55ms +step:1240/1670 train_time:109802ms step_avg:88.55ms +step:1241/1670 train_time:109891ms step_avg:88.55ms +step:1242/1670 train_time:109981ms step_avg:88.55ms +step:1243/1670 train_time:110071ms step_avg:88.55ms +step:1244/1670 train_time:110161ms step_avg:88.55ms +step:1245/1670 train_time:110251ms step_avg:88.55ms +step:1246/1670 train_time:110340ms step_avg:88.56ms +step:1247/1670 train_time:110429ms step_avg:88.56ms +step:1248/1670 train_time:110520ms step_avg:88.56ms +step:1249/1670 train_time:110609ms step_avg:88.56ms +step:1250/1670 train_time:110700ms step_avg:88.56ms +step:1250/1670 val_loss:3.3769 train_time:110790ms step_avg:88.63ms +step:1251/1670 train_time:110810ms step_avg:88.58ms +step:1252/1670 train_time:110883ms step_avg:88.56ms +step:1253/1670 train_time:110975ms step_avg:88.57ms +step:1254/1670 train_time:111064ms step_avg:88.57ms +step:1255/1670 train_time:111152ms step_avg:88.57ms +step:1256/1670 train_time:111241ms step_avg:88.57ms +step:1257/1670 train_time:111329ms step_avg:88.57ms +step:1258/1670 train_time:111417ms step_avg:88.57ms +step:1259/1670 train_time:111508ms step_avg:88.57ms +step:1260/1670 train_time:111597ms step_avg:88.57ms +step:1261/1670 train_time:111686ms step_avg:88.57ms +step:1262/1670 train_time:111780ms step_avg:88.57ms +step:1263/1670 train_time:111872ms step_avg:88.58ms +step:1264/1670 train_time:111963ms step_avg:88.58ms +step:1265/1670 train_time:112053ms step_avg:88.58ms +step:1266/1670 train_time:112142ms step_avg:88.58ms +step:1267/1670 train_time:112232ms step_avg:88.58ms +step:1268/1670 train_time:112320ms step_avg:88.58ms +step:1269/1670 train_time:112409ms step_avg:88.58ms +step:1270/1670 train_time:112497ms step_avg:88.58ms +step:1271/1670 train_time:112586ms step_avg:88.58ms +step:1272/1670 train_time:112676ms step_avg:88.58ms +step:1273/1670 train_time:112767ms step_avg:88.58ms +step:1274/1670 train_time:112858ms step_avg:88.59ms +step:1275/1670 train_time:112949ms step_avg:88.59ms +step:1276/1670 train_time:113039ms step_avg:88.59ms +step:1277/1670 train_time:113128ms step_avg:88.59ms +step:1278/1670 train_time:113217ms step_avg:88.59ms +step:1279/1670 train_time:113306ms step_avg:88.59ms +step:1280/1670 train_time:113396ms step_avg:88.59ms +step:1281/1670 train_time:113486ms step_avg:88.59ms +step:1282/1670 train_time:113575ms step_avg:88.59ms +step:1283/1670 train_time:113665ms step_avg:88.59ms +step:1284/1670 train_time:113755ms step_avg:88.59ms +step:1285/1670 train_time:113846ms step_avg:88.60ms +step:1286/1670 train_time:113937ms step_avg:88.60ms +step:1287/1670 train_time:114028ms step_avg:88.60ms +step:1288/1670 train_time:114117ms step_avg:88.60ms +step:1289/1670 train_time:114207ms step_avg:88.60ms +step:1290/1670 train_time:114297ms step_avg:88.60ms +step:1291/1670 train_time:114386ms step_avg:88.60ms +step:1292/1670 train_time:114476ms step_avg:88.60ms +step:1293/1670 train_time:114565ms step_avg:88.60ms +step:1294/1670 train_time:114655ms step_avg:88.60ms +step:1295/1670 train_time:114745ms step_avg:88.61ms +step:1296/1670 train_time:114836ms step_avg:88.61ms +step:1297/1670 train_time:114928ms step_avg:88.61ms +step:1298/1670 train_time:115017ms step_avg:88.61ms +step:1299/1670 train_time:115107ms step_avg:88.61ms +step:1300/1670 train_time:115197ms step_avg:88.61ms +step:1301/1670 train_time:115288ms step_avg:88.61ms +step:1302/1670 train_time:115376ms step_avg:88.61ms +step:1303/1670 train_time:115466ms step_avg:88.62ms +step:1304/1670 train_time:115555ms step_avg:88.62ms +step:1305/1670 
train_time:115644ms step_avg:88.62ms +step:1306/1670 train_time:115734ms step_avg:88.62ms +step:1307/1670 train_time:115824ms step_avg:88.62ms +step:1308/1670 train_time:115915ms step_avg:88.62ms +step:1309/1670 train_time:116006ms step_avg:88.62ms +step:1310/1670 train_time:116096ms step_avg:88.62ms +step:1311/1670 train_time:116186ms step_avg:88.62ms +step:1312/1670 train_time:116275ms step_avg:88.62ms +step:1313/1670 train_time:116365ms step_avg:88.63ms +step:1314/1670 train_time:116454ms step_avg:88.63ms +step:1315/1670 train_time:116544ms step_avg:88.63ms +step:1316/1670 train_time:116634ms step_avg:88.63ms +step:1317/1670 train_time:116723ms step_avg:88.63ms +step:1318/1670 train_time:116813ms step_avg:88.63ms +step:1319/1670 train_time:116904ms step_avg:88.63ms +step:1320/1670 train_time:116996ms step_avg:88.63ms +step:1321/1670 train_time:117087ms step_avg:88.63ms +step:1322/1670 train_time:117176ms step_avg:88.64ms +step:1323/1670 train_time:117266ms step_avg:88.64ms +step:1324/1670 train_time:117355ms step_avg:88.64ms +step:1325/1670 train_time:117444ms step_avg:88.64ms +step:1326/1670 train_time:117533ms step_avg:88.64ms +step:1327/1670 train_time:117622ms step_avg:88.64ms +step:1328/1670 train_time:117712ms step_avg:88.64ms +step:1329/1670 train_time:117802ms step_avg:88.64ms +step:1330/1670 train_time:117893ms step_avg:88.64ms +step:1331/1670 train_time:117984ms step_avg:88.64ms +step:1332/1670 train_time:118074ms step_avg:88.64ms +step:1333/1670 train_time:118164ms step_avg:88.65ms +step:1334/1670 train_time:118253ms step_avg:88.65ms +step:1335/1670 train_time:118343ms step_avg:88.65ms +step:1336/1670 train_time:118432ms step_avg:88.65ms +step:1337/1670 train_time:118522ms step_avg:88.65ms +step:1338/1670 train_time:118612ms step_avg:88.65ms +step:1339/1670 train_time:118702ms step_avg:88.65ms +step:1340/1670 train_time:118792ms step_avg:88.65ms +step:1341/1670 train_time:118881ms step_avg:88.65ms +step:1342/1670 train_time:118971ms step_avg:88.65ms +step:1343/1670 train_time:119062ms step_avg:88.65ms +step:1344/1670 train_time:119152ms step_avg:88.65ms +step:1345/1670 train_time:119242ms step_avg:88.66ms +step:1346/1670 train_time:119332ms step_avg:88.66ms +step:1347/1670 train_time:119421ms step_avg:88.66ms +step:1348/1670 train_time:119511ms step_avg:88.66ms +step:1349/1670 train_time:119600ms step_avg:88.66ms +step:1350/1670 train_time:119690ms step_avg:88.66ms +step:1351/1670 train_time:119779ms step_avg:88.66ms +step:1352/1670 train_time:119870ms step_avg:88.66ms +step:1353/1670 train_time:119961ms step_avg:88.66ms +step:1354/1670 train_time:120051ms step_avg:88.66ms +step:1355/1670 train_time:120140ms step_avg:88.66ms +step:1356/1670 train_time:120230ms step_avg:88.67ms +step:1357/1670 train_time:120319ms step_avg:88.67ms +step:1358/1670 train_time:120408ms step_avg:88.67ms +step:1359/1670 train_time:120497ms step_avg:88.67ms +step:1360/1670 train_time:120586ms step_avg:88.67ms +step:1361/1670 train_time:120675ms step_avg:88.67ms +step:1362/1670 train_time:120765ms step_avg:88.67ms +step:1363/1670 train_time:120855ms step_avg:88.67ms +step:1364/1670 train_time:120945ms step_avg:88.67ms +step:1365/1670 train_time:121035ms step_avg:88.67ms +step:1366/1670 train_time:121125ms step_avg:88.67ms +step:1367/1670 train_time:121216ms step_avg:88.67ms +step:1368/1670 train_time:121306ms step_avg:88.67ms +step:1369/1670 train_time:121396ms step_avg:88.67ms +step:1370/1670 train_time:121485ms step_avg:88.68ms +step:1371/1670 train_time:121574ms step_avg:88.68ms +step:1372/1670 
train_time:121663ms step_avg:88.68ms +step:1373/1670 train_time:121753ms step_avg:88.68ms +step:1374/1670 train_time:121843ms step_avg:88.68ms +step:1375/1670 train_time:121935ms step_avg:88.68ms +step:1375/1670 val_loss:3.3421 train_time:122026ms step_avg:88.75ms +step:1376/1670 train_time:122046ms step_avg:88.70ms +step:1377/1670 train_time:122118ms step_avg:88.68ms +step:1378/1670 train_time:122209ms step_avg:88.69ms +step:1379/1670 train_time:122298ms step_avg:88.69ms +step:1380/1670 train_time:122386ms step_avg:88.69ms +step:1381/1670 train_time:122475ms step_avg:88.69ms +step:1382/1670 train_time:122563ms step_avg:88.69ms +step:1383/1670 train_time:122654ms step_avg:88.69ms +step:1384/1670 train_time:122743ms step_avg:88.69ms +step:1385/1670 train_time:122833ms step_avg:88.69ms +step:1386/1670 train_time:122923ms step_avg:88.69ms +step:1387/1670 train_time:123016ms step_avg:88.69ms +step:1388/1670 train_time:123108ms step_avg:88.69ms +step:1389/1670 train_time:123200ms step_avg:88.70ms +step:1390/1670 train_time:123290ms step_avg:88.70ms +step:1391/1670 train_time:123379ms step_avg:88.70ms +step:1392/1670 train_time:123468ms step_avg:88.70ms +step:1393/1670 train_time:123557ms step_avg:88.70ms +step:1394/1670 train_time:123645ms step_avg:88.70ms +step:1395/1670 train_time:123735ms step_avg:88.70ms +step:1396/1670 train_time:123824ms step_avg:88.70ms +step:1397/1670 train_time:123914ms step_avg:88.70ms +step:1398/1670 train_time:124005ms step_avg:88.70ms +step:1399/1670 train_time:124096ms step_avg:88.70ms +step:1400/1670 train_time:124187ms step_avg:88.71ms +step:1401/1670 train_time:124277ms step_avg:88.71ms +step:1402/1670 train_time:124366ms step_avg:88.71ms +step:1403/1670 train_time:124456ms step_avg:88.71ms +step:1404/1670 train_time:124546ms step_avg:88.71ms +step:1405/1670 train_time:124635ms step_avg:88.71ms +step:1406/1670 train_time:124723ms step_avg:88.71ms +step:1407/1670 train_time:124814ms step_avg:88.71ms +step:1408/1670 train_time:124903ms step_avg:88.71ms +step:1409/1670 train_time:124995ms step_avg:88.71ms +step:1410/1670 train_time:125085ms step_avg:88.71ms +step:1411/1670 train_time:125175ms step_avg:88.71ms +step:1412/1670 train_time:125265ms step_avg:88.71ms +step:1413/1670 train_time:125356ms step_avg:88.72ms +step:1414/1670 train_time:125445ms step_avg:88.72ms +step:1415/1670 train_time:125535ms step_avg:88.72ms +step:1416/1670 train_time:125624ms step_avg:88.72ms +step:1417/1670 train_time:125714ms step_avg:88.72ms +step:1418/1670 train_time:125803ms step_avg:88.72ms +step:1419/1670 train_time:125894ms step_avg:88.72ms +step:1420/1670 train_time:125983ms step_avg:88.72ms +step:1421/1670 train_time:126073ms step_avg:88.72ms +step:1422/1670 train_time:126164ms step_avg:88.72ms +step:1423/1670 train_time:126255ms step_avg:88.72ms +step:1424/1670 train_time:126344ms step_avg:88.72ms +step:1425/1670 train_time:126434ms step_avg:88.73ms +step:1426/1670 train_time:126524ms step_avg:88.73ms +step:1427/1670 train_time:126614ms step_avg:88.73ms +step:1428/1670 train_time:126703ms step_avg:88.73ms +step:1429/1670 train_time:126793ms step_avg:88.73ms +step:1430/1670 train_time:126882ms step_avg:88.73ms +step:1431/1670 train_time:126972ms step_avg:88.73ms +step:1432/1670 train_time:127062ms step_avg:88.73ms +step:1433/1670 train_time:127154ms step_avg:88.73ms +step:1434/1670 train_time:127244ms step_avg:88.73ms +step:1435/1670 train_time:127335ms step_avg:88.74ms +step:1436/1670 train_time:127425ms step_avg:88.74ms +step:1437/1670 train_time:127515ms step_avg:88.74ms 
+step:1438/1670 train_time:127603ms step_avg:88.74ms +step:1439/1670 train_time:127692ms step_avg:88.74ms +step:1440/1670 train_time:127782ms step_avg:88.74ms +step:1441/1670 train_time:127871ms step_avg:88.74ms +step:1442/1670 train_time:127962ms step_avg:88.74ms +step:1443/1670 train_time:128051ms step_avg:88.74ms +step:1444/1670 train_time:128141ms step_avg:88.74ms +step:1445/1670 train_time:128232ms step_avg:88.74ms +step:1446/1670 train_time:128323ms step_avg:88.74ms +step:1447/1670 train_time:128414ms step_avg:88.74ms +step:1448/1670 train_time:128503ms step_avg:88.75ms +step:1449/1670 train_time:128593ms step_avg:88.75ms +step:1450/1670 train_time:128682ms step_avg:88.75ms +step:1451/1670 train_time:128772ms step_avg:88.75ms +step:1452/1670 train_time:128862ms step_avg:88.75ms +step:1453/1670 train_time:128952ms step_avg:88.75ms +step:1454/1670 train_time:129042ms step_avg:88.75ms +step:1455/1670 train_time:129133ms step_avg:88.75ms +step:1456/1670 train_time:129225ms step_avg:88.75ms +step:1457/1670 train_time:129316ms step_avg:88.75ms +step:1458/1670 train_time:129405ms step_avg:88.76ms +step:1459/1670 train_time:129495ms step_avg:88.76ms +step:1460/1670 train_time:129585ms step_avg:88.76ms +step:1461/1670 train_time:129674ms step_avg:88.76ms +step:1462/1670 train_time:129764ms step_avg:88.76ms +step:1463/1670 train_time:129854ms step_avg:88.76ms +step:1464/1670 train_time:129944ms step_avg:88.76ms +step:1465/1670 train_time:130034ms step_avg:88.76ms +step:1466/1670 train_time:130124ms step_avg:88.76ms +step:1467/1670 train_time:130215ms step_avg:88.76ms +step:1468/1670 train_time:130305ms step_avg:88.76ms +step:1469/1670 train_time:130395ms step_avg:88.76ms +step:1470/1670 train_time:130484ms step_avg:88.76ms +step:1471/1670 train_time:130573ms step_avg:88.77ms +step:1472/1670 train_time:130664ms step_avg:88.77ms +step:1473/1670 train_time:130754ms step_avg:88.77ms +step:1474/1670 train_time:130844ms step_avg:88.77ms +step:1475/1670 train_time:130934ms step_avg:88.77ms +step:1476/1670 train_time:131024ms step_avg:88.77ms +step:1477/1670 train_time:131114ms step_avg:88.77ms +step:1478/1670 train_time:131204ms step_avg:88.77ms +step:1479/1670 train_time:131295ms step_avg:88.77ms +step:1480/1670 train_time:131384ms step_avg:88.77ms +step:1481/1670 train_time:131473ms step_avg:88.77ms +step:1482/1670 train_time:131564ms step_avg:88.77ms +step:1483/1670 train_time:131653ms step_avg:88.78ms +step:1484/1670 train_time:131743ms step_avg:88.78ms +step:1485/1670 train_time:131833ms step_avg:88.78ms +step:1486/1670 train_time:131923ms step_avg:88.78ms +step:1487/1670 train_time:132013ms step_avg:88.78ms +step:1488/1670 train_time:132103ms step_avg:88.78ms +step:1489/1670 train_time:132194ms step_avg:88.78ms +step:1490/1670 train_time:132285ms step_avg:88.78ms +step:1491/1670 train_time:132374ms step_avg:88.78ms +step:1492/1670 train_time:132463ms step_avg:88.78ms +step:1493/1670 train_time:132554ms step_avg:88.78ms +step:1494/1670 train_time:132643ms step_avg:88.78ms +step:1495/1670 train_time:132733ms step_avg:88.78ms +step:1496/1670 train_time:132823ms step_avg:88.79ms +step:1497/1670 train_time:132913ms step_avg:88.79ms +step:1498/1670 train_time:133002ms step_avg:88.79ms +step:1499/1670 train_time:133092ms step_avg:88.79ms +step:1500/1670 train_time:133182ms step_avg:88.79ms +step:1500/1670 val_loss:3.3124 train_time:133273ms step_avg:88.85ms +step:1501/1670 train_time:133292ms step_avg:88.80ms +step:1502/1670 train_time:133367ms step_avg:88.79ms +step:1503/1670 train_time:133459ms 
step_avg:88.80ms +step:1504/1670 train_time:133550ms step_avg:88.80ms +step:1505/1670 train_time:133639ms step_avg:88.80ms +step:1506/1670 train_time:133728ms step_avg:88.80ms +step:1507/1670 train_time:133816ms step_avg:88.80ms +step:1508/1670 train_time:133904ms step_avg:88.80ms +step:1509/1670 train_time:133993ms step_avg:88.80ms +step:1510/1670 train_time:134082ms step_avg:88.80ms +step:1511/1670 train_time:134171ms step_avg:88.80ms +step:1512/1670 train_time:134263ms step_avg:88.80ms +step:1513/1670 train_time:134356ms step_avg:88.80ms +step:1514/1670 train_time:134450ms step_avg:88.80ms +step:1515/1670 train_time:134539ms step_avg:88.80ms +step:1516/1670 train_time:134629ms step_avg:88.81ms +step:1517/1670 train_time:134719ms step_avg:88.81ms +step:1518/1670 train_time:134808ms step_avg:88.81ms +step:1519/1670 train_time:134896ms step_avg:88.81ms +step:1520/1670 train_time:134985ms step_avg:88.81ms +step:1521/1670 train_time:135074ms step_avg:88.81ms +step:1522/1670 train_time:135164ms step_avg:88.81ms +step:1523/1670 train_time:135256ms step_avg:88.81ms +step:1524/1670 train_time:135348ms step_avg:88.81ms +step:1525/1670 train_time:135439ms step_avg:88.81ms +step:1526/1670 train_time:135530ms step_avg:88.81ms +step:1527/1670 train_time:135620ms step_avg:88.81ms +step:1528/1670 train_time:135710ms step_avg:88.82ms +step:1529/1670 train_time:135799ms step_avg:88.82ms +step:1530/1670 train_time:135888ms step_avg:88.82ms +step:1531/1670 train_time:135977ms step_avg:88.82ms +step:1532/1670 train_time:136066ms step_avg:88.82ms +step:1533/1670 train_time:136156ms step_avg:88.82ms +step:1534/1670 train_time:136246ms step_avg:88.82ms +step:1535/1670 train_time:136336ms step_avg:88.82ms +step:1536/1670 train_time:136427ms step_avg:88.82ms +step:1537/1670 train_time:136517ms step_avg:88.82ms +step:1538/1670 train_time:136607ms step_avg:88.82ms +step:1539/1670 train_time:136696ms step_avg:88.82ms +step:1540/1670 train_time:136786ms step_avg:88.82ms +step:1541/1670 train_time:136875ms step_avg:88.82ms +step:1542/1670 train_time:136965ms step_avg:88.82ms +step:1543/1670 train_time:137053ms step_avg:88.82ms +step:1544/1670 train_time:137143ms step_avg:88.82ms +step:1545/1670 train_time:137233ms step_avg:88.82ms +step:1546/1670 train_time:137323ms step_avg:88.82ms +step:1547/1670 train_time:137414ms step_avg:88.83ms +step:1548/1670 train_time:137505ms step_avg:88.83ms +step:1549/1670 train_time:137594ms step_avg:88.83ms +step:1550/1670 train_time:137685ms step_avg:88.83ms +step:1551/1670 train_time:137774ms step_avg:88.83ms +step:1552/1670 train_time:137864ms step_avg:88.83ms +step:1553/1670 train_time:137953ms step_avg:88.83ms +step:1554/1670 train_time:138043ms step_avg:88.83ms +step:1555/1670 train_time:138132ms step_avg:88.83ms +step:1556/1670 train_time:138222ms step_avg:88.83ms +step:1557/1670 train_time:138312ms step_avg:88.83ms +step:1558/1670 train_time:138403ms step_avg:88.83ms +step:1559/1670 train_time:138493ms step_avg:88.83ms +step:1560/1670 train_time:138584ms step_avg:88.84ms +step:1561/1670 train_time:138673ms step_avg:88.84ms +step:1562/1670 train_time:138763ms step_avg:88.84ms +step:1563/1670 train_time:138853ms step_avg:88.84ms +step:1564/1670 train_time:138942ms step_avg:88.84ms +step:1565/1670 train_time:139032ms step_avg:88.84ms +step:1566/1670 train_time:139121ms step_avg:88.84ms +step:1567/1670 train_time:139212ms step_avg:88.84ms +step:1568/1670 train_time:139302ms step_avg:88.84ms +step:1569/1670 train_time:139392ms step_avg:88.84ms +step:1570/1670 train_time:139484ms 
step_avg:88.84ms +step:1571/1670 train_time:139573ms step_avg:88.84ms +step:1572/1670 train_time:139663ms step_avg:88.84ms +step:1573/1670 train_time:139752ms step_avg:88.84ms +step:1574/1670 train_time:139843ms step_avg:88.85ms +step:1575/1670 train_time:139932ms step_avg:88.85ms +step:1576/1670 train_time:140022ms step_avg:88.85ms +step:1577/1670 train_time:140112ms step_avg:88.85ms +step:1578/1670 train_time:140201ms step_avg:88.85ms +step:1579/1670 train_time:140292ms step_avg:88.85ms +step:1580/1670 train_time:140382ms step_avg:88.85ms +step:1581/1670 train_time:140473ms step_avg:88.85ms +step:1582/1670 train_time:140564ms step_avg:88.85ms +step:1583/1670 train_time:140654ms step_avg:88.85ms +step:1584/1670 train_time:140744ms step_avg:88.85ms +step:1585/1670 train_time:140833ms step_avg:88.85ms +step:1586/1670 train_time:140923ms step_avg:88.85ms +step:1587/1670 train_time:141011ms step_avg:88.85ms +step:1588/1670 train_time:141101ms step_avg:88.85ms +step:1589/1670 train_time:141191ms step_avg:88.86ms +step:1590/1670 train_time:141280ms step_avg:88.86ms +step:1591/1670 train_time:141371ms step_avg:88.86ms +step:1592/1670 train_time:141462ms step_avg:88.86ms +step:1593/1670 train_time:141552ms step_avg:88.86ms +step:1594/1670 train_time:141642ms step_avg:88.86ms +step:1595/1670 train_time:141732ms step_avg:88.86ms +step:1596/1670 train_time:141822ms step_avg:88.86ms +step:1597/1670 train_time:141912ms step_avg:88.86ms +step:1598/1670 train_time:142001ms step_avg:88.86ms +step:1599/1670 train_time:142091ms step_avg:88.86ms +step:1600/1670 train_time:142180ms step_avg:88.86ms +step:1601/1670 train_time:142271ms step_avg:88.86ms +step:1602/1670 train_time:142361ms step_avg:88.86ms +step:1603/1670 train_time:142451ms step_avg:88.87ms +step:1604/1670 train_time:142541ms step_avg:88.87ms +step:1605/1670 train_time:142631ms step_avg:88.87ms +step:1606/1670 train_time:142721ms step_avg:88.87ms +step:1607/1670 train_time:142811ms step_avg:88.87ms +step:1608/1670 train_time:142901ms step_avg:88.87ms +step:1609/1670 train_time:142991ms step_avg:88.87ms +step:1610/1670 train_time:143081ms step_avg:88.87ms +step:1611/1670 train_time:143171ms step_avg:88.87ms +step:1612/1670 train_time:143261ms step_avg:88.87ms +step:1613/1670 train_time:143351ms step_avg:88.87ms +step:1614/1670 train_time:143442ms step_avg:88.87ms +step:1615/1670 train_time:143531ms step_avg:88.87ms +step:1616/1670 train_time:143622ms step_avg:88.87ms +step:1617/1670 train_time:143712ms step_avg:88.88ms +step:1618/1670 train_time:143802ms step_avg:88.88ms +step:1619/1670 train_time:143892ms step_avg:88.88ms +step:1620/1670 train_time:143981ms step_avg:88.88ms +step:1621/1670 train_time:144072ms step_avg:88.88ms +step:1622/1670 train_time:144161ms step_avg:88.88ms +step:1623/1670 train_time:144251ms step_avg:88.88ms +step:1624/1670 train_time:144342ms step_avg:88.88ms +step:1625/1670 train_time:144431ms step_avg:88.88ms +step:1625/1670 val_loss:3.2891 train_time:144522ms step_avg:88.94ms +step:1626/1670 train_time:144542ms step_avg:88.89ms +step:1627/1670 train_time:144617ms step_avg:88.89ms +step:1628/1670 train_time:144711ms step_avg:88.89ms +step:1629/1670 train_time:144801ms step_avg:88.89ms +step:1630/1670 train_time:144890ms step_avg:88.89ms +step:1631/1670 train_time:144979ms step_avg:88.89ms +step:1632/1670 train_time:145067ms step_avg:88.89ms +step:1633/1670 train_time:145156ms step_avg:88.89ms +step:1634/1670 train_time:145245ms step_avg:88.89ms +step:1635/1670 train_time:145335ms step_avg:88.89ms +step:1636/1670 
train_time:145423ms step_avg:88.89ms +step:1637/1670 train_time:145515ms step_avg:88.89ms +step:1638/1670 train_time:145607ms step_avg:88.89ms +step:1639/1670 train_time:145699ms step_avg:88.89ms +step:1640/1670 train_time:145790ms step_avg:88.90ms +step:1641/1670 train_time:145880ms step_avg:88.90ms +step:1642/1670 train_time:145969ms step_avg:88.90ms +step:1643/1670 train_time:146058ms step_avg:88.90ms +step:1644/1670 train_time:146147ms step_avg:88.90ms +step:1645/1670 train_time:146236ms step_avg:88.90ms +step:1646/1670 train_time:146325ms step_avg:88.90ms +step:1647/1670 train_time:146415ms step_avg:88.90ms +step:1648/1670 train_time:146505ms step_avg:88.90ms +step:1649/1670 train_time:146597ms step_avg:88.90ms +step:1650/1670 train_time:146688ms step_avg:88.90ms +step:1651/1670 train_time:146780ms step_avg:88.90ms +step:1652/1670 train_time:146871ms step_avg:88.90ms +step:1653/1670 train_time:146960ms step_avg:88.90ms +step:1654/1670 train_time:147049ms step_avg:88.91ms +step:1655/1670 train_time:147138ms step_avg:88.91ms +step:1656/1670 train_time:147227ms step_avg:88.91ms +step:1657/1670 train_time:147316ms step_avg:88.91ms +step:1658/1670 train_time:147405ms step_avg:88.91ms +step:1659/1670 train_time:147496ms step_avg:88.91ms +step:1660/1670 train_time:147586ms step_avg:88.91ms +step:1661/1670 train_time:147679ms step_avg:88.91ms +step:1662/1670 train_time:147770ms step_avg:88.91ms +step:1663/1670 train_time:147860ms step_avg:88.91ms +step:1664/1670 train_time:147950ms step_avg:88.91ms +step:1665/1670 train_time:148039ms step_avg:88.91ms +step:1666/1670 train_time:148128ms step_avg:88.91ms +step:1667/1670 train_time:148217ms step_avg:88.91ms +step:1668/1670 train_time:148306ms step_avg:88.91ms +step:1669/1670 train_time:148395ms step_avg:88.91ms +step:1670/1670 train_time:148485ms step_avg:88.91ms +step:1670/1670 val_loss:3.2797 train_time:148578ms step_avg:88.97ms +peak memory allocated: 30760 MiB reserved: 45694 MiB
diff --git a/records/092925_PolarExpress/README.md b/records/092925_PolarExpress/README.md new file mode 100644 index 000000000..7860f3108 --- /dev/null +++ b/records/092925_PolarExpress/README.md @@ -0,0 +1,85 @@
+# New record 09/29/25
+
+This submission reflects all recent WR changes up to [PR#133](https://github.com/KellerJordan/modded-nanogpt/pull/133).
+
+The main improvement in this PR is using the [Polar Express](https://arxiv.org/pdf/2505.16932)
+sign method in Muon instead of Newton-Schulz. The paper was developed with ModdedNanoGPT in mind, so it was very easy to implement;
+I refer the reader to the paper itself for details. Using Polar Express, I've reduced the number of training steps by 10.
+
+The next change in this PR is packaging Flash Attention 3 via [Huggingface's Kernels](https://huggingface.co/docs/kernels/en/index).
+This does not impact timing, but it should make development easier for anyone working on this project.
+
+## Timing and Validation
+
+This PR shortens the final training run by 10 steps, with no change in the time per step.
+
+```
+import scipy.stats
+import torch
+
+losses = [3.2789, 3.2792, 3.2796, 3.2776, 3.2797, 3.2787, 3.2792]
+times = [148.617, 148.580, 148.569, 148.653, 148.578, 148.542, 148.587]
+
+print("p=%.4f" % scipy.stats.ttest_1samp(losses, 3.28, alternative="less").pvalue)
+# p=0.0045
+
+print("losses:", torch.std_mean(torch.tensor(losses)))
+# losses: (std=0.0007057, mean=3.2789857)
+
+print("time:", torch.std_mean(torch.tensor(times)))
+# time: (std=0.0358076, mean=148.5894318)
+```
+
+You may notice that this PR shows a 0.2 second mean *increase* in timing over the result in PR#133.
+However, that PR was timed on a very fast machine. To demonstrate that this PR does in fact represent
+a decrease in train time, I timed PR#133 on the same machine as above:
+
+```
+import scipy.stats
+import torch
+
+times = [149.714, 149.676, 149.659, 149.716, 149.649, 149.569, 149.521]
+
+print("time:", torch.std_mean(torch.tensor(times)))
+# time: (std=0.0732, mean=149.6434)
+```
+
+Therefore, I believe that this PR represents at least a 1 second improvement.
+
+Thank you to Prime Intellect for compute credits, which made this PR possible.
+
+## Polar Express
+
+All credit to the original authors (Noah Amsel, David Persson, Christopher Musco, Robert M. Gower)
+for the discovery and implementation of this method. I adapted their code from https://github.com/NoahAmsel/PolarExpress/tree/main.
+
+I found the following parameters to work best:
+- `num_iters=5`: each iteration adds about a second to train time
+- `muon_lr=0.06`: bumping the Muon LR seems to perform slightly better
+- `safety_factor=1.02`: hyperparameter for the Polar Express coefficients
+
+Despite the paper explicitly referencing and showing improvements on Modded NanoGPT,
+I was unable to replicate the level of success reported there. However, it may
+be possible to tune the parameters further and achieve a better result.
+Additionally, like [Cesista 2025](https://leloykun.github.io/ponder/muon-opt-coeffs/), I believe it may be more promising on the GPT Medium track.
+
+## Flash Attention 3 Huggingface Kernel
+
+A couple of weeks ago, Flash Attention merged [ABI-stability](https://github.com/Dao-AILab/flash-attention/pull/1791)
+into the main FA3 repo. This allows builds of Flash Attention on PyTorch nightlies after 08/30 to be compatible with each other.
+Since [PR#118](https://github.com/KellerJordan/modded-nanogpt/pull/118), we have been using
+[a variant](https://github.com/Dao-AILab/flash-attention/pull/1769) of FA3 by @Guilhermeleobas that is compatible with `torch.compile`.
+I have written a [fork](https://github.com/varunneal/flash-attention/tree/stable) that combines
+these changes and uploaded its build to Huggingface at https://huggingface.co/varunneal/flash-attention-3.
+
+I have modified the training script to fetch these builds via Huggingface's `get_kernel` (see the usage sketch at
+the end of this README), so developers no longer need to build Flash Attention manually.
+
+I have packaged this kernel for both CUDA 12.6 and 12.8 for the following PyTorch versions:
+- `2.8.0`
+- `2.9` nightlies after 8/30
+- `2.10` nightlies
+
+Note that the actual build `.so` is identical for all Torch Nightly versions.
+
+This most recent record uses the same `2.10` nightly as PR#133.
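+
+As a usage sketch: the `get_kernel` handle and the `flash_attn_varlen_func` call below mirror how the training
+script consumes the prebuilt kernel; the tensor shapes, softmax scale, and window size are illustrative
+placeholders rather than the record's actual configuration.
+
+```
+import torch
+from kernels import get_kernel
+
+# Downloads the prebuilt FA3 extension from the Hub; no local compilation needed.
+flash_attn_interface = get_kernel("varunneal/flash-attention-3").flash_attn_interface
+
+T, num_heads, head_dim = 1024, 6, 128  # illustrative sizes, not the record's config
+q = torch.randn(T, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
+k = torch.randn_like(q)
+v = torch.randn_like(q)
+# varlen layout: cumulative boundaries of two packed documents of 512 tokens each
+seqlens = torch.tensor([0, 512, 1024], device="cuda", dtype=torch.int32)
+
+y = flash_attn_interface.flash_attn_varlen_func(
+    q, k, v,
+    cu_seqlens_q=seqlens, cu_seqlens_k=seqlens,
+    max_seqlen_q=512, max_seqlen_k=512,
+    causal=True,
+    softmax_scale=0.1,       # the training script passes its YaRN attn_scale here
+    window_size=(384, 0),    # sliding-window attention, as in the training script
+)
+```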
\ No newline at end of file diff --git a/records/092925_PolarExpress/df2b2147-08b3-48d4-8ac7-75d509768560.txt b/records/092925_PolarExpress/df2b2147-08b3-48d4-8ac7-75d509768560.txt new file mode 100644 index 000000000..6d872956f --- /dev/null +++ b/records/092925_PolarExpress/df2b2147-08b3-48d4-8ac7-75d509768560.txt @@ -0,0 +1,3252 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from kernels import get_kernel +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[ + Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, 
w_s, grad_s + ) + return grad_x, grad_w, None, None, None + + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + + +mm_op.register_autograd(backward, setup_context=setup_context) + + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def XXT_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * 
c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def XXT(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + XXT_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ba_plus_cAA_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from XXT_kernel, but also loads and adds a block of A + # Performance is slightly slower than XXT_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + 
out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ba_plus_cAA(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ba_plus_cAA_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +# Computed for num_iters=5, safety_factor=2e-2, cushion=2 +coeffs_list = [ + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323) +] + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def polar_express(G: torch.Tensor): + """ + Polar Express Sign Method: https://arxiv.org/pdf/2505.16932 + by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. + Code adapted from https://github.com/NoahAmsel/PolarExpress/tree/main by @varunneal. + """ + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) * (1 + 2e-2) + 1e-6) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + aX_plus_BX = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the iterations + for a, b, c in coeffs_list: + XXT(X, out=A) # A = X @ X.mT + ba_plus_cAA(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + aX_plus_BX(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. 
To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + Though empirically small 1D params perform efficiently here: + NS approximately performs a magnitude normalization of the grad + This hyper-optimized class has faster execution time than the current impl of Adam for small params + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized to enable + (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param 7 padding params) + 2. reduce scatter attn_gate (10 params 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive params of type attn reshape before NS + 8. wait on 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size() == 8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on less than 8 GPU or experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. 
+ """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1, 10, 16, 16] + assert len(params) == sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx + size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. 
+ update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + original_shape = batched_update_grads.shape + # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS indepedently to Q,K,V,O + module_idx = start_idx if start_idx < len(params) else 0 + if getattr(params[module_idx], 'module', 'none') == 'attn': + for p in params[module_idx:module_idx + chunk_size]: + assert getattr(params[module_idx], 'module', 'none') == 'attn' + batch = 4 * original_shape[0] + d1 = original_shape[1] + d2 = original_shape[2] // 4 + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = polar_express(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = polar_express(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, + weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append( + dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim // 4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim // 4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +flash_attn_interface = get_kernel('varunneal/flash-attention-3').flash_attn_interface + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim * 4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module = 'attn' + with torch.no_grad(): + self.qkvo_w.view(4, 
self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4, self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4, self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, + 3 * self.num_heads, + self.head_dim).chunk( + 3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_interface.flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, + max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4, self.hdim, self.dim)[3].type_as(y)) + return y + + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module = 'mlp' + self.c_proj.module = 'mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu( + x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + 
self.mlp(norm(x)) + return x + + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, + max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim ** 0.5) / 448, w_s=2 ** -9, + grad_s=1 / 448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), # extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, + long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1, len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i < 11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + + +BOS_ID = 50256 + + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens = tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() 
+ self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter == 5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter += 1 + return starts, ends + + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, + align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
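+                # preloader.get() blocks until the background thread has finished reading the
+                # next shard and building its BOSFinder; preloader.start() then begins
+                # prefetching the shard after that while training continues.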
+ tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1630 # number of iterations to run + iteration_extension = 40 # number of iterations to continue training at final cooldown and window size + cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. 
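+
+# With the defaults above, train_batch_size = 2048 * 24 * 8 = 393,216 tokens per optimizer step and
+# grad_accum_steps * world_size == 8, so distributed_data_generator hands each rank
+# 393,216 / 8 = 49,152 tokens per micro-step, packed into BOS-aligned sequences of at most
+# train_max_seq_len = 128 * 16 = 2,048 tokens. The global batch per optimizer step is therefore the
+# same no matter how the 8 GPUs are split between data parallelism and gradient accumulation.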
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) + + +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + + +# begin by printing this file (the Python code) +print0(code) +print0("=" * 100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout + + +print0(nvidia_smi()) +print0("=" * 100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if + p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.06, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999, step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + + +def get_ws(step: int): + if step == args.num_iterations + args.iteration_extension: + return args.ws_validate // 2, args.ws_validate + x = min(step / (1 + args.num_iterations), 0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx] + + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, + grad_accum_steps=grad_accum_steps) +ws_long = args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws_long = 
args.ws_schedule[ + step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params + if new_ws_long > ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + elif new_ws_long < ws_long: + model.yarn.reset() + ws_long = new_ws_long + model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.yarn.reset() +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, + grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations + args.iteration_extension +ws_short, ws_long = get_ws(0) +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws_short, new_ws_long = get_ws(step) + if new_ws_long != ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + if last_step: + ws_long = args.ws_long_validate + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, + grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = torch.zeros((), device=device, dtype=torch.float32) + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0( + f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms", + console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), + optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * 
(time.perf_counter() - t0) + print0( + f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms", + console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, Feb 4 2025, 14:57:36) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.5.0 +Mon Sep 29 06:39:26 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.127.08 Driver Version: 550.127.08 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 40C P0 130W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 33C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 38C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 39C P0 128W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 33C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 38C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 
train_time:0ms step_avg:0.02ms +step:1/1670 train_time:149ms step_avg:149.45ms +step:2/1670 train_time:171ms step_avg:85.25ms +step:3/1670 train_time:234ms step_avg:78.14ms +step:4/1670 train_time:321ms step_avg:80.28ms +step:5/1670 train_time:405ms step_avg:81.09ms +step:6/1670 train_time:492ms step_avg:81.97ms +step:7/1670 train_time:578ms step_avg:82.61ms +step:8/1670 train_time:665ms step_avg:83.15ms +step:9/1670 train_time:752ms step_avg:83.59ms +step:10/1670 train_time:839ms step_avg:83.89ms +step:11/1670 train_time:926ms step_avg:84.15ms +step:12/1670 train_time:1016ms step_avg:84.70ms +step:13/1670 train_time:1111ms step_avg:85.46ms +step:14/1670 train_time:1200ms step_avg:85.69ms +step:15/1670 train_time:1288ms step_avg:85.84ms +step:16/1670 train_time:1375ms step_avg:85.94ms +step:17/1670 train_time:1462ms step_avg:86.02ms +step:18/1670 train_time:1550ms step_avg:86.13ms +step:19/1670 train_time:1638ms step_avg:86.20ms +step:20/1670 train_time:1724ms step_avg:86.22ms +step:21/1670 train_time:1811ms step_avg:86.25ms +step:22/1670 train_time:1898ms step_avg:86.29ms +step:23/1670 train_time:1988ms step_avg:86.42ms +step:24/1670 train_time:2077ms step_avg:86.53ms +step:25/1670 train_time:2165ms step_avg:86.62ms +step:26/1670 train_time:2255ms step_avg:86.72ms +step:27/1670 train_time:2344ms step_avg:86.81ms +step:28/1670 train_time:2433ms step_avg:86.89ms +step:29/1670 train_time:2521ms step_avg:86.92ms +step:30/1670 train_time:2608ms step_avg:86.92ms +step:31/1670 train_time:2694ms step_avg:86.91ms +step:32/1670 train_time:2781ms step_avg:86.91ms +step:33/1670 train_time:2868ms step_avg:86.92ms +step:34/1670 train_time:2956ms step_avg:86.94ms +step:35/1670 train_time:3044ms step_avg:86.98ms +step:36/1670 train_time:3133ms step_avg:87.03ms +step:37/1670 train_time:3222ms step_avg:87.07ms +step:38/1670 train_time:3312ms step_avg:87.15ms +step:39/1670 train_time:3400ms step_avg:87.18ms +step:40/1670 train_time:3488ms step_avg:87.20ms +step:41/1670 train_time:3575ms step_avg:87.20ms +step:42/1670 train_time:3663ms step_avg:87.21ms +step:43/1670 train_time:3750ms step_avg:87.22ms +step:44/1670 train_time:3838ms step_avg:87.22ms +step:45/1670 train_time:3925ms step_avg:87.22ms +step:46/1670 train_time:4013ms step_avg:87.25ms +step:47/1670 train_time:4101ms step_avg:87.26ms +step:48/1670 train_time:4190ms step_avg:87.30ms +step:49/1670 train_time:4277ms step_avg:87.28ms +step:50/1670 train_time:4366ms step_avg:87.31ms +step:51/1670 train_time:4454ms step_avg:87.34ms +step:52/1670 train_time:4542ms step_avg:87.34ms +step:53/1670 train_time:4629ms step_avg:87.34ms +step:54/1670 train_time:4716ms step_avg:87.33ms +step:55/1670 train_time:4803ms step_avg:87.33ms +step:56/1670 train_time:4891ms step_avg:87.33ms +step:57/1670 train_time:4978ms step_avg:87.34ms +step:58/1670 train_time:5066ms step_avg:87.35ms +step:59/1670 train_time:5155ms step_avg:87.37ms +step:60/1670 train_time:5244ms step_avg:87.40ms +step:61/1670 train_time:5333ms step_avg:87.43ms +step:62/1670 train_time:5421ms step_avg:87.43ms +step:63/1670 train_time:5508ms step_avg:87.44ms +step:64/1670 train_time:5596ms step_avg:87.43ms +step:65/1670 train_time:5683ms step_avg:87.43ms +step:66/1670 train_time:5771ms step_avg:87.44ms +step:67/1670 train_time:5859ms step_avg:87.44ms +step:68/1670 train_time:5947ms step_avg:87.45ms +step:69/1670 train_time:6035ms step_avg:87.46ms +step:70/1670 train_time:6122ms step_avg:87.46ms +step:71/1670 train_time:6210ms step_avg:87.46ms +step:72/1670 train_time:6298ms step_avg:87.47ms +step:73/1670 
train_time:6387ms step_avg:87.49ms +step:74/1670 train_time:6475ms step_avg:87.50ms +step:75/1670 train_time:6563ms step_avg:87.51ms +step:76/1670 train_time:6651ms step_avg:87.51ms +step:77/1670 train_time:6739ms step_avg:87.52ms +step:78/1670 train_time:6827ms step_avg:87.52ms +step:79/1670 train_time:6914ms step_avg:87.52ms +step:80/1670 train_time:7001ms step_avg:87.51ms +step:81/1670 train_time:7089ms step_avg:87.52ms +step:82/1670 train_time:7176ms step_avg:87.52ms +step:83/1670 train_time:7264ms step_avg:87.52ms +step:84/1670 train_time:7353ms step_avg:87.53ms +step:85/1670 train_time:7440ms step_avg:87.53ms +step:86/1670 train_time:7528ms step_avg:87.53ms +step:87/1670 train_time:7615ms step_avg:87.53ms +step:88/1670 train_time:7703ms step_avg:87.53ms +step:89/1670 train_time:7791ms step_avg:87.54ms +step:90/1670 train_time:7878ms step_avg:87.54ms +step:91/1670 train_time:7966ms step_avg:87.54ms +step:92/1670 train_time:8054ms step_avg:87.54ms +step:93/1670 train_time:8142ms step_avg:87.55ms +step:94/1670 train_time:8230ms step_avg:87.55ms +step:95/1670 train_time:8318ms step_avg:87.55ms +step:96/1670 train_time:8405ms step_avg:87.55ms +step:97/1670 train_time:8494ms step_avg:87.56ms +step:98/1670 train_time:8581ms step_avg:87.56ms +step:99/1670 train_time:8669ms step_avg:87.57ms +step:100/1670 train_time:8756ms step_avg:87.56ms +step:101/1670 train_time:8844ms step_avg:87.56ms +step:102/1670 train_time:8932ms step_avg:87.57ms +step:103/1670 train_time:9019ms step_avg:87.56ms +step:104/1670 train_time:9106ms step_avg:87.56ms +step:105/1670 train_time:9194ms step_avg:87.56ms +step:106/1670 train_time:9282ms step_avg:87.56ms +step:107/1670 train_time:9370ms step_avg:87.57ms +step:108/1670 train_time:9458ms step_avg:87.57ms +step:109/1670 train_time:9545ms step_avg:87.57ms +step:110/1670 train_time:9633ms step_avg:87.58ms +step:111/1670 train_time:9721ms step_avg:87.58ms +step:112/1670 train_time:9809ms step_avg:87.58ms +step:113/1670 train_time:9896ms step_avg:87.58ms +step:114/1670 train_time:9983ms step_avg:87.57ms +step:115/1670 train_time:10071ms step_avg:87.58ms +step:116/1670 train_time:10159ms step_avg:87.58ms +step:117/1670 train_time:10246ms step_avg:87.58ms +step:118/1670 train_time:10334ms step_avg:87.58ms +step:119/1670 train_time:10422ms step_avg:87.58ms +step:120/1670 train_time:10510ms step_avg:87.58ms +step:121/1670 train_time:10597ms step_avg:87.58ms +step:122/1670 train_time:10685ms step_avg:87.58ms +step:123/1670 train_time:10773ms step_avg:87.58ms +step:124/1670 train_time:10860ms step_avg:87.58ms +step:125/1670 train_time:10947ms step_avg:87.58ms +step:125/1670 val_loss:4.3268 train_time:11036ms step_avg:88.29ms +step:126/1670 train_time:11056ms step_avg:87.74ms +step:127/1670 train_time:11123ms step_avg:87.59ms +step:128/1670 train_time:11210ms step_avg:87.58ms +step:129/1670 train_time:11300ms step_avg:87.60ms +step:130/1670 train_time:11390ms step_avg:87.62ms +step:131/1670 train_time:11478ms step_avg:87.62ms +step:132/1670 train_time:11565ms step_avg:87.61ms +step:133/1670 train_time:11651ms step_avg:87.60ms +step:134/1670 train_time:11737ms step_avg:87.59ms +step:135/1670 train_time:11823ms step_avg:87.58ms +step:136/1670 train_time:11910ms step_avg:87.57ms +step:137/1670 train_time:12002ms step_avg:87.61ms +step:138/1670 train_time:12093ms step_avg:87.63ms +step:139/1670 train_time:12181ms step_avg:87.63ms +step:140/1670 train_time:12268ms step_avg:87.63ms +step:141/1670 train_time:12356ms step_avg:87.63ms +step:142/1670 train_time:12443ms step_avg:87.63ms 
+step:143/1670 train_time:12530ms step_avg:87.62ms +step:144/1670 train_time:12617ms step_avg:87.62ms +step:145/1670 train_time:12704ms step_avg:87.62ms +step:146/1670 train_time:12791ms step_avg:87.61ms +step:147/1670 train_time:12878ms step_avg:87.61ms +step:148/1670 train_time:12967ms step_avg:87.61ms +step:149/1670 train_time:13055ms step_avg:87.62ms +step:150/1670 train_time:13144ms step_avg:87.63ms +step:151/1670 train_time:13232ms step_avg:87.63ms +step:152/1670 train_time:13319ms step_avg:87.62ms +step:153/1670 train_time:13406ms step_avg:87.62ms +step:154/1670 train_time:13493ms step_avg:87.62ms +step:155/1670 train_time:13580ms step_avg:87.61ms +step:156/1670 train_time:13668ms step_avg:87.61ms +step:157/1670 train_time:13755ms step_avg:87.61ms +step:158/1670 train_time:13842ms step_avg:87.61ms +step:159/1670 train_time:13930ms step_avg:87.61ms +step:160/1670 train_time:14018ms step_avg:87.61ms +step:161/1670 train_time:14106ms step_avg:87.61ms +step:162/1670 train_time:14194ms step_avg:87.62ms +step:163/1670 train_time:14281ms step_avg:87.62ms +step:164/1670 train_time:14369ms step_avg:87.62ms +step:165/1670 train_time:14456ms step_avg:87.61ms +step:166/1670 train_time:14544ms step_avg:87.61ms +step:167/1670 train_time:14631ms step_avg:87.61ms +step:168/1670 train_time:14719ms step_avg:87.61ms +step:169/1670 train_time:14805ms step_avg:87.60ms +step:170/1670 train_time:14893ms step_avg:87.61ms +step:171/1670 train_time:14980ms step_avg:87.60ms +step:172/1670 train_time:15069ms step_avg:87.61ms +step:173/1670 train_time:15157ms step_avg:87.61ms +step:174/1670 train_time:15245ms step_avg:87.61ms +step:175/1670 train_time:15332ms step_avg:87.61ms +step:176/1670 train_time:15419ms step_avg:87.61ms +step:177/1670 train_time:15506ms step_avg:87.61ms +step:178/1670 train_time:15594ms step_avg:87.61ms +step:179/1670 train_time:15680ms step_avg:87.60ms +step:180/1670 train_time:15768ms step_avg:87.60ms +step:181/1670 train_time:15855ms step_avg:87.60ms +step:182/1670 train_time:15943ms step_avg:87.60ms +step:183/1670 train_time:16031ms step_avg:87.60ms +step:184/1670 train_time:16120ms step_avg:87.61ms +step:185/1670 train_time:16208ms step_avg:87.61ms +step:186/1670 train_time:16296ms step_avg:87.61ms +step:187/1670 train_time:16383ms step_avg:87.61ms +step:188/1670 train_time:16470ms step_avg:87.61ms +step:189/1670 train_time:16558ms step_avg:87.61ms +step:190/1670 train_time:16645ms step_avg:87.61ms +step:191/1670 train_time:16733ms step_avg:87.61ms +step:192/1670 train_time:16820ms step_avg:87.60ms +step:193/1670 train_time:16907ms step_avg:87.60ms +step:194/1670 train_time:16995ms step_avg:87.61ms +step:195/1670 train_time:17083ms step_avg:87.61ms +step:196/1670 train_time:17171ms step_avg:87.61ms +step:197/1670 train_time:17259ms step_avg:87.61ms +step:198/1670 train_time:17346ms step_avg:87.61ms +step:199/1670 train_time:17433ms step_avg:87.60ms +step:200/1670 train_time:17520ms step_avg:87.60ms +step:201/1670 train_time:17608ms step_avg:87.60ms +step:202/1670 train_time:17695ms step_avg:87.60ms +step:203/1670 train_time:17782ms step_avg:87.60ms +step:204/1670 train_time:17870ms step_avg:87.60ms +step:205/1670 train_time:17957ms step_avg:87.59ms +step:206/1670 train_time:18045ms step_avg:87.59ms +step:207/1670 train_time:18133ms step_avg:87.60ms +step:208/1670 train_time:18220ms step_avg:87.60ms +step:209/1670 train_time:18308ms step_avg:87.60ms +step:210/1670 train_time:18396ms step_avg:87.60ms +step:211/1670 train_time:18483ms step_avg:87.60ms +step:212/1670 train_time:18571ms 
step_avg:87.60ms +step:213/1670 train_time:18659ms step_avg:87.60ms +step:214/1670 train_time:18746ms step_avg:87.60ms +step:215/1670 train_time:18833ms step_avg:87.60ms +step:216/1670 train_time:18920ms step_avg:87.59ms +step:217/1670 train_time:19007ms step_avg:87.59ms +step:218/1670 train_time:19095ms step_avg:87.59ms +step:219/1670 train_time:19182ms step_avg:87.59ms +step:220/1670 train_time:19270ms step_avg:87.59ms +step:221/1670 train_time:19358ms step_avg:87.59ms +step:222/1670 train_time:19445ms step_avg:87.59ms +step:223/1670 train_time:19533ms step_avg:87.59ms +step:224/1670 train_time:19620ms step_avg:87.59ms +step:225/1670 train_time:19707ms step_avg:87.59ms +step:226/1670 train_time:19795ms step_avg:87.59ms +step:227/1670 train_time:19883ms step_avg:87.59ms +step:228/1670 train_time:19971ms step_avg:87.59ms +step:229/1670 train_time:20058ms step_avg:87.59ms +step:230/1670 train_time:20146ms step_avg:87.59ms +step:231/1670 train_time:20235ms step_avg:87.60ms +step:232/1670 train_time:20321ms step_avg:87.59ms +step:233/1670 train_time:20410ms step_avg:87.60ms +step:234/1670 train_time:20497ms step_avg:87.59ms +step:235/1670 train_time:20585ms step_avg:87.59ms +step:236/1670 train_time:20672ms step_avg:87.59ms +step:237/1670 train_time:20759ms step_avg:87.59ms +step:238/1670 train_time:20847ms step_avg:87.59ms +step:239/1670 train_time:20934ms step_avg:87.59ms +step:240/1670 train_time:21022ms step_avg:87.59ms +step:241/1670 train_time:21111ms step_avg:87.60ms +step:242/1670 train_time:21198ms step_avg:87.60ms +step:243/1670 train_time:21286ms step_avg:87.60ms +step:244/1670 train_time:21373ms step_avg:87.60ms +step:245/1670 train_time:21461ms step_avg:87.60ms +step:246/1670 train_time:21549ms step_avg:87.60ms +step:247/1670 train_time:21636ms step_avg:87.59ms +step:248/1670 train_time:21723ms step_avg:87.59ms +step:249/1670 train_time:21811ms step_avg:87.59ms +step:250/1670 train_time:21898ms step_avg:87.59ms +step:250/1670 val_loss:3.9745 train_time:21987ms step_avg:87.95ms +step:251/1670 train_time:22009ms step_avg:87.68ms +step:252/1670 train_time:22079ms step_avg:87.61ms +step:253/1670 train_time:22172ms step_avg:87.64ms +step:254/1670 train_time:22259ms step_avg:87.64ms +step:255/1670 train_time:22346ms step_avg:87.63ms +step:256/1670 train_time:22433ms step_avg:87.63ms +step:257/1670 train_time:22519ms step_avg:87.62ms +step:258/1670 train_time:22606ms step_avg:87.62ms +step:259/1670 train_time:22692ms step_avg:87.62ms +step:260/1670 train_time:22779ms step_avg:87.61ms +step:261/1670 train_time:22866ms step_avg:87.61ms +step:262/1670 train_time:22954ms step_avg:87.61ms +step:263/1670 train_time:23044ms step_avg:87.62ms +step:264/1670 train_time:23134ms step_avg:87.63ms +step:265/1670 train_time:23222ms step_avg:87.63ms +step:266/1670 train_time:23310ms step_avg:87.63ms +step:267/1670 train_time:23398ms step_avg:87.63ms +step:268/1670 train_time:23485ms step_avg:87.63ms +step:269/1670 train_time:23572ms step_avg:87.63ms +step:270/1670 train_time:23658ms step_avg:87.62ms +step:271/1670 train_time:23745ms step_avg:87.62ms +step:272/1670 train_time:23832ms step_avg:87.62ms +step:273/1670 train_time:23919ms step_avg:87.62ms +step:274/1670 train_time:24007ms step_avg:87.62ms +step:275/1670 train_time:24096ms step_avg:87.62ms +step:276/1670 train_time:24184ms step_avg:87.62ms +step:277/1670 train_time:24272ms step_avg:87.63ms +step:278/1670 train_time:24360ms step_avg:87.63ms +step:279/1670 train_time:24448ms step_avg:87.63ms +step:280/1670 train_time:24535ms step_avg:87.62ms 
+step:281/1670 train_time:24621ms step_avg:87.62ms +step:282/1670 train_time:24708ms step_avg:87.62ms +step:283/1670 train_time:24795ms step_avg:87.62ms +step:284/1670 train_time:24883ms step_avg:87.62ms +step:285/1670 train_time:24971ms step_avg:87.62ms +step:286/1670 train_time:25059ms step_avg:87.62ms +step:287/1670 train_time:25146ms step_avg:87.62ms +step:288/1670 train_time:25234ms step_avg:87.62ms +step:289/1670 train_time:25322ms step_avg:87.62ms +step:290/1670 train_time:25410ms step_avg:87.62ms +step:291/1670 train_time:25497ms step_avg:87.62ms +step:292/1670 train_time:25585ms step_avg:87.62ms +step:293/1670 train_time:25672ms step_avg:87.62ms +step:294/1670 train_time:25760ms step_avg:87.62ms +step:295/1670 train_time:25847ms step_avg:87.62ms +step:296/1670 train_time:25934ms step_avg:87.62ms +step:297/1670 train_time:26022ms step_avg:87.61ms +step:298/1670 train_time:26110ms step_avg:87.62ms +step:299/1670 train_time:26197ms step_avg:87.62ms +step:300/1670 train_time:26285ms step_avg:87.62ms +step:301/1670 train_time:26373ms step_avg:87.62ms +step:302/1670 train_time:26461ms step_avg:87.62ms +step:303/1670 train_time:26549ms step_avg:87.62ms +step:304/1670 train_time:26636ms step_avg:87.62ms +step:305/1670 train_time:26723ms step_avg:87.62ms +step:306/1670 train_time:26811ms step_avg:87.62ms +step:307/1670 train_time:26898ms step_avg:87.61ms +step:308/1670 train_time:26985ms step_avg:87.61ms +step:309/1670 train_time:27073ms step_avg:87.61ms +step:310/1670 train_time:27160ms step_avg:87.61ms +step:311/1670 train_time:27248ms step_avg:87.61ms +step:312/1670 train_time:27335ms step_avg:87.61ms +step:313/1670 train_time:27423ms step_avg:87.61ms +step:314/1670 train_time:27513ms step_avg:87.62ms +step:315/1670 train_time:27600ms step_avg:87.62ms +step:316/1670 train_time:27687ms step_avg:87.62ms +step:317/1670 train_time:27774ms step_avg:87.61ms +step:318/1670 train_time:27861ms step_avg:87.61ms +step:319/1670 train_time:27949ms step_avg:87.61ms +step:320/1670 train_time:28036ms step_avg:87.61ms +step:321/1670 train_time:28124ms step_avg:87.61ms +step:322/1670 train_time:28212ms step_avg:87.62ms +step:323/1670 train_time:28299ms step_avg:87.61ms +step:324/1670 train_time:28389ms step_avg:87.62ms +step:325/1670 train_time:28477ms step_avg:87.62ms +step:326/1670 train_time:28565ms step_avg:87.62ms +step:327/1670 train_time:28653ms step_avg:87.62ms +step:328/1670 train_time:28740ms step_avg:87.62ms +step:329/1670 train_time:28826ms step_avg:87.62ms +step:330/1670 train_time:28914ms step_avg:87.62ms +step:331/1670 train_time:29001ms step_avg:87.62ms +step:332/1670 train_time:29089ms step_avg:87.62ms +step:333/1670 train_time:29177ms step_avg:87.62ms +step:334/1670 train_time:29264ms step_avg:87.62ms +step:335/1670 train_time:29353ms step_avg:87.62ms +step:336/1670 train_time:29442ms step_avg:87.63ms +step:337/1670 train_time:29531ms step_avg:87.63ms +step:338/1670 train_time:29618ms step_avg:87.63ms +step:339/1670 train_time:29704ms step_avg:87.62ms +step:340/1670 train_time:29793ms step_avg:87.63ms +step:341/1670 train_time:29880ms step_avg:87.63ms +step:342/1670 train_time:29968ms step_avg:87.62ms +step:343/1670 train_time:30055ms step_avg:87.62ms +step:344/1670 train_time:30142ms step_avg:87.62ms +step:345/1670 train_time:30229ms step_avg:87.62ms +step:346/1670 train_time:30317ms step_avg:87.62ms +step:347/1670 train_time:30405ms step_avg:87.62ms +step:348/1670 train_time:30493ms step_avg:87.62ms +step:349/1670 train_time:30580ms step_avg:87.62ms +step:350/1670 train_time:30668ms 
step_avg:87.62ms +step:351/1670 train_time:30755ms step_avg:87.62ms +step:352/1670 train_time:30842ms step_avg:87.62ms +step:353/1670 train_time:30930ms step_avg:87.62ms +step:354/1670 train_time:31018ms step_avg:87.62ms +step:355/1670 train_time:31105ms step_avg:87.62ms +step:356/1670 train_time:31193ms step_avg:87.62ms +step:357/1670 train_time:31280ms step_avg:87.62ms +step:358/1670 train_time:31368ms step_avg:87.62ms +step:359/1670 train_time:31456ms step_avg:87.62ms +step:360/1670 train_time:31543ms step_avg:87.62ms +step:361/1670 train_time:31631ms step_avg:87.62ms +step:362/1670 train_time:31719ms step_avg:87.62ms +step:363/1670 train_time:31807ms step_avg:87.62ms +step:364/1670 train_time:31894ms step_avg:87.62ms +step:365/1670 train_time:31982ms step_avg:87.62ms +step:366/1670 train_time:32070ms step_avg:87.62ms +step:367/1670 train_time:32157ms step_avg:87.62ms +step:368/1670 train_time:32244ms step_avg:87.62ms +step:369/1670 train_time:32333ms step_avg:87.62ms +step:370/1670 train_time:32420ms step_avg:87.62ms +step:371/1670 train_time:32508ms step_avg:87.62ms +step:372/1670 train_time:32596ms step_avg:87.62ms +step:373/1670 train_time:32683ms step_avg:87.62ms +step:374/1670 train_time:32771ms step_avg:87.62ms +step:375/1670 train_time:32859ms step_avg:87.62ms +step:375/1670 val_loss:3.8208 train_time:32948ms step_avg:87.86ms +step:376/1670 train_time:32967ms step_avg:87.68ms +step:377/1670 train_time:33038ms step_avg:87.63ms +step:378/1670 train_time:33132ms step_avg:87.65ms +step:379/1670 train_time:33220ms step_avg:87.65ms +step:380/1670 train_time:33309ms step_avg:87.65ms +step:381/1670 train_time:33395ms step_avg:87.65ms +step:382/1670 train_time:33481ms step_avg:87.65ms +step:383/1670 train_time:33568ms step_avg:87.65ms +step:384/1670 train_time:33655ms step_avg:87.64ms +step:385/1670 train_time:33741ms step_avg:87.64ms +step:386/1670 train_time:33828ms step_avg:87.64ms +step:387/1670 train_time:33915ms step_avg:87.64ms +step:388/1670 train_time:34005ms step_avg:87.64ms +step:389/1670 train_time:34097ms step_avg:87.65ms +step:390/1670 train_time:34186ms step_avg:87.66ms +step:391/1670 train_time:34275ms step_avg:87.66ms +step:392/1670 train_time:34362ms step_avg:87.66ms +step:393/1670 train_time:34448ms step_avg:87.65ms +step:394/1670 train_time:34535ms step_avg:87.65ms +step:395/1670 train_time:34621ms step_avg:87.65ms +step:396/1670 train_time:34709ms step_avg:87.65ms +step:397/1670 train_time:34796ms step_avg:87.65ms +step:398/1670 train_time:34883ms step_avg:87.65ms +step:399/1670 train_time:34971ms step_avg:87.65ms +step:400/1670 train_time:35060ms step_avg:87.65ms +step:401/1670 train_time:35149ms step_avg:87.65ms +step:402/1670 train_time:35238ms step_avg:87.66ms +step:403/1670 train_time:35326ms step_avg:87.66ms +step:404/1670 train_time:35415ms step_avg:87.66ms +step:405/1670 train_time:35501ms step_avg:87.66ms +step:406/1670 train_time:35588ms step_avg:87.65ms +step:407/1670 train_time:35675ms step_avg:87.65ms +step:408/1670 train_time:35762ms step_avg:87.65ms +step:409/1670 train_time:35849ms step_avg:87.65ms +step:410/1670 train_time:35936ms step_avg:87.65ms +step:411/1670 train_time:36024ms step_avg:87.65ms +step:412/1670 train_time:36114ms step_avg:87.65ms +step:413/1670 train_time:36202ms step_avg:87.66ms +step:414/1670 train_time:36291ms step_avg:87.66ms +step:415/1670 train_time:36378ms step_avg:87.66ms +step:416/1670 train_time:36466ms step_avg:87.66ms +step:417/1670 train_time:36553ms step_avg:87.66ms +step:418/1670 train_time:36640ms step_avg:87.66ms 
+step:419/1670 train_time:36727ms step_avg:87.65ms +step:420/1670 train_time:36816ms step_avg:87.66ms +step:421/1670 train_time:36903ms step_avg:87.66ms +step:422/1670 train_time:36991ms step_avg:87.66ms +step:423/1670 train_time:37079ms step_avg:87.66ms +step:424/1670 train_time:37167ms step_avg:87.66ms +step:425/1670 train_time:37256ms step_avg:87.66ms +step:426/1670 train_time:37344ms step_avg:87.66ms +step:427/1670 train_time:37432ms step_avg:87.66ms +step:428/1670 train_time:37519ms step_avg:87.66ms +step:429/1670 train_time:37605ms step_avg:87.66ms +step:430/1670 train_time:37693ms step_avg:87.66ms +step:431/1670 train_time:37780ms step_avg:87.66ms +step:432/1670 train_time:37868ms step_avg:87.66ms +step:433/1670 train_time:37956ms step_avg:87.66ms +step:434/1670 train_time:38044ms step_avg:87.66ms +step:435/1670 train_time:38132ms step_avg:87.66ms +step:436/1670 train_time:38220ms step_avg:87.66ms +step:437/1670 train_time:38308ms step_avg:87.66ms +step:438/1670 train_time:38396ms step_avg:87.66ms +step:439/1670 train_time:38484ms step_avg:87.66ms +step:440/1670 train_time:38571ms step_avg:87.66ms +step:441/1670 train_time:38658ms step_avg:87.66ms +step:442/1670 train_time:38746ms step_avg:87.66ms +step:443/1670 train_time:38834ms step_avg:87.66ms +step:444/1670 train_time:38921ms step_avg:87.66ms +step:445/1670 train_time:39009ms step_avg:87.66ms +step:446/1670 train_time:39097ms step_avg:87.66ms +step:447/1670 train_time:39185ms step_avg:87.66ms +step:448/1670 train_time:39273ms step_avg:87.66ms +step:449/1670 train_time:39361ms step_avg:87.66ms +step:450/1670 train_time:39448ms step_avg:87.66ms +step:451/1670 train_time:39536ms step_avg:87.66ms +step:452/1670 train_time:39623ms step_avg:87.66ms +step:453/1670 train_time:39711ms step_avg:87.66ms +step:454/1670 train_time:39798ms step_avg:87.66ms +step:455/1670 train_time:39884ms step_avg:87.66ms +step:456/1670 train_time:39972ms step_avg:87.66ms +step:457/1670 train_time:40060ms step_avg:87.66ms +step:458/1670 train_time:40147ms step_avg:87.66ms +step:459/1670 train_time:40235ms step_avg:87.66ms +step:460/1670 train_time:40323ms step_avg:87.66ms +step:461/1670 train_time:40410ms step_avg:87.66ms +step:462/1670 train_time:40498ms step_avg:87.66ms +step:463/1670 train_time:40585ms step_avg:87.66ms +step:464/1670 train_time:40673ms step_avg:87.66ms +step:465/1670 train_time:40760ms step_avg:87.66ms +step:466/1670 train_time:40848ms step_avg:87.66ms +step:467/1670 train_time:40936ms step_avg:87.66ms +step:468/1670 train_time:41022ms step_avg:87.65ms +step:469/1670 train_time:41110ms step_avg:87.65ms +step:470/1670 train_time:41198ms step_avg:87.65ms +step:471/1670 train_time:41286ms step_avg:87.66ms +step:472/1670 train_time:41374ms step_avg:87.66ms +step:473/1670 train_time:41461ms step_avg:87.66ms +step:474/1670 train_time:41549ms step_avg:87.66ms +step:475/1670 train_time:41637ms step_avg:87.66ms +step:476/1670 train_time:41724ms step_avg:87.66ms +step:477/1670 train_time:41812ms step_avg:87.66ms +step:478/1670 train_time:41899ms step_avg:87.66ms +step:479/1670 train_time:41987ms step_avg:87.66ms +step:480/1670 train_time:42076ms step_avg:87.66ms +step:481/1670 train_time:42164ms step_avg:87.66ms +step:482/1670 train_time:42251ms step_avg:87.66ms +step:483/1670 train_time:42339ms step_avg:87.66ms +step:484/1670 train_time:42426ms step_avg:87.66ms +step:485/1670 train_time:42514ms step_avg:87.66ms +step:486/1670 train_time:42601ms step_avg:87.66ms +step:487/1670 train_time:42688ms step_avg:87.66ms +step:488/1670 train_time:42776ms 
step_avg:87.66ms +step:489/1670 train_time:42863ms step_avg:87.65ms +step:490/1670 train_time:42951ms step_avg:87.66ms +step:491/1670 train_time:43039ms step_avg:87.66ms +step:492/1670 train_time:43127ms step_avg:87.66ms +step:493/1670 train_time:43215ms step_avg:87.66ms +step:494/1670 train_time:43302ms step_avg:87.66ms +step:495/1670 train_time:43390ms step_avg:87.66ms +step:496/1670 train_time:43478ms step_avg:87.66ms +step:497/1670 train_time:43565ms step_avg:87.66ms +step:498/1670 train_time:43653ms step_avg:87.66ms +step:499/1670 train_time:43740ms step_avg:87.66ms +step:500/1670 train_time:43828ms step_avg:87.66ms +step:500/1670 val_loss:3.7187 train_time:43917ms step_avg:87.83ms +step:501/1670 train_time:43937ms step_avg:87.70ms +step:502/1670 train_time:44010ms step_avg:87.67ms +step:503/1670 train_time:44105ms step_avg:87.68ms +step:504/1670 train_time:44195ms step_avg:87.69ms +step:505/1670 train_time:44283ms step_avg:87.69ms +step:506/1670 train_time:44370ms step_avg:87.69ms +step:507/1670 train_time:44456ms step_avg:87.68ms +step:508/1670 train_time:44543ms step_avg:87.68ms +step:509/1670 train_time:44628ms step_avg:87.68ms +step:510/1670 train_time:44715ms step_avg:87.68ms +step:511/1670 train_time:44802ms step_avg:87.68ms +step:512/1670 train_time:44890ms step_avg:87.67ms +step:513/1670 train_time:44980ms step_avg:87.68ms +step:514/1670 train_time:45069ms step_avg:87.68ms +step:515/1670 train_time:45159ms step_avg:87.69ms +step:516/1670 train_time:45247ms step_avg:87.69ms +step:517/1670 train_time:45335ms step_avg:87.69ms +step:518/1670 train_time:45423ms step_avg:87.69ms +step:519/1670 train_time:45509ms step_avg:87.69ms +step:520/1670 train_time:45596ms step_avg:87.68ms +step:521/1670 train_time:45682ms step_avg:87.68ms +step:522/1670 train_time:45768ms step_avg:87.68ms +step:523/1670 train_time:45856ms step_avg:87.68ms +step:524/1670 train_time:45946ms step_avg:87.68ms +step:525/1670 train_time:46035ms step_avg:87.69ms +step:526/1670 train_time:46125ms step_avg:87.69ms +step:527/1670 train_time:46214ms step_avg:87.69ms +step:528/1670 train_time:46301ms step_avg:87.69ms +step:529/1670 train_time:46388ms step_avg:87.69ms +step:530/1670 train_time:46476ms step_avg:87.69ms +step:531/1670 train_time:46563ms step_avg:87.69ms +step:532/1670 train_time:46649ms step_avg:87.69ms +step:533/1670 train_time:46736ms step_avg:87.69ms +step:534/1670 train_time:46824ms step_avg:87.69ms +step:535/1670 train_time:46911ms step_avg:87.68ms +step:536/1670 train_time:47000ms step_avg:87.69ms +step:537/1670 train_time:47088ms step_avg:87.69ms +step:538/1670 train_time:47177ms step_avg:87.69ms +step:539/1670 train_time:47265ms step_avg:87.69ms +step:540/1670 train_time:47353ms step_avg:87.69ms +step:541/1670 train_time:47441ms step_avg:87.69ms +step:542/1670 train_time:47528ms step_avg:87.69ms +step:543/1670 train_time:47615ms step_avg:87.69ms +step:544/1670 train_time:47702ms step_avg:87.69ms +step:545/1670 train_time:47789ms step_avg:87.69ms +step:546/1670 train_time:47878ms step_avg:87.69ms +step:547/1670 train_time:47967ms step_avg:87.69ms +step:548/1670 train_time:48057ms step_avg:87.69ms +step:549/1670 train_time:48147ms step_avg:87.70ms +step:550/1670 train_time:48238ms step_avg:87.70ms +step:551/1670 train_time:48327ms step_avg:87.71ms +step:552/1670 train_time:48416ms step_avg:87.71ms +step:553/1670 train_time:48505ms step_avg:87.71ms +step:554/1670 train_time:48593ms step_avg:87.71ms +step:555/1670 train_time:48682ms step_avg:87.72ms +step:556/1670 train_time:48770ms step_avg:87.72ms 
+step:557/1670 train_time:48859ms step_avg:87.72ms +step:558/1670 train_time:48948ms step_avg:87.72ms +step:559/1670 train_time:49038ms step_avg:87.72ms +step:560/1670 train_time:49127ms step_avg:87.73ms +step:561/1670 train_time:49216ms step_avg:87.73ms +step:562/1670 train_time:49305ms step_avg:87.73ms +step:563/1670 train_time:49394ms step_avg:87.73ms +step:564/1670 train_time:49485ms step_avg:87.74ms +step:565/1670 train_time:49574ms step_avg:87.74ms +step:566/1670 train_time:49663ms step_avg:87.74ms +step:567/1670 train_time:49751ms step_avg:87.74ms +step:568/1670 train_time:49840ms step_avg:87.75ms +step:569/1670 train_time:49928ms step_avg:87.75ms +step:570/1670 train_time:50017ms step_avg:87.75ms +step:571/1670 train_time:50106ms step_avg:87.75ms +step:572/1670 train_time:50194ms step_avg:87.75ms +step:573/1670 train_time:50284ms step_avg:87.76ms +step:574/1670 train_time:50373ms step_avg:87.76ms +step:575/1670 train_time:50463ms step_avg:87.76ms +step:576/1670 train_time:50552ms step_avg:87.76ms +step:577/1670 train_time:50642ms step_avg:87.77ms +step:578/1670 train_time:50730ms step_avg:87.77ms +step:579/1670 train_time:50819ms step_avg:87.77ms +step:580/1670 train_time:50907ms step_avg:87.77ms +step:581/1670 train_time:50996ms step_avg:87.77ms +step:582/1670 train_time:51085ms step_avg:87.78ms +step:583/1670 train_time:51174ms step_avg:87.78ms +step:584/1670 train_time:51263ms step_avg:87.78ms +step:585/1670 train_time:51352ms step_avg:87.78ms +step:586/1670 train_time:51441ms step_avg:87.78ms +step:587/1670 train_time:51529ms step_avg:87.78ms +step:588/1670 train_time:51618ms step_avg:87.79ms +step:589/1670 train_time:51707ms step_avg:87.79ms +step:590/1670 train_time:51795ms step_avg:87.79ms +step:591/1670 train_time:51884ms step_avg:87.79ms +step:592/1670 train_time:51973ms step_avg:87.79ms +step:593/1670 train_time:52062ms step_avg:87.79ms +step:594/1670 train_time:52151ms step_avg:87.80ms +step:595/1670 train_time:52240ms step_avg:87.80ms +step:596/1670 train_time:52328ms step_avg:87.80ms +step:597/1670 train_time:52417ms step_avg:87.80ms +step:598/1670 train_time:52506ms step_avg:87.80ms +step:599/1670 train_time:52596ms step_avg:87.81ms +step:600/1670 train_time:52685ms step_avg:87.81ms +step:601/1670 train_time:52774ms step_avg:87.81ms +step:602/1670 train_time:52863ms step_avg:87.81ms +step:603/1670 train_time:52951ms step_avg:87.81ms +step:604/1670 train_time:53041ms step_avg:87.82ms +step:605/1670 train_time:53129ms step_avg:87.82ms +step:606/1670 train_time:53218ms step_avg:87.82ms +step:607/1670 train_time:53306ms step_avg:87.82ms +step:608/1670 train_time:53395ms step_avg:87.82ms +step:609/1670 train_time:53486ms step_avg:87.83ms +step:610/1670 train_time:53575ms step_avg:87.83ms +step:611/1670 train_time:53665ms step_avg:87.83ms +step:612/1670 train_time:53754ms step_avg:87.83ms +step:613/1670 train_time:53844ms step_avg:87.84ms +step:614/1670 train_time:53933ms step_avg:87.84ms +step:615/1670 train_time:54021ms step_avg:87.84ms +step:616/1670 train_time:54109ms step_avg:87.84ms +step:617/1670 train_time:54197ms step_avg:87.84ms +step:618/1670 train_time:54286ms step_avg:87.84ms +step:619/1670 train_time:54375ms step_avg:87.84ms +step:620/1670 train_time:54464ms step_avg:87.84ms +step:621/1670 train_time:54553ms step_avg:87.85ms +step:622/1670 train_time:54643ms step_avg:87.85ms +step:623/1670 train_time:54731ms step_avg:87.85ms +step:624/1670 train_time:54819ms step_avg:87.85ms +step:625/1670 train_time:54908ms step_avg:87.85ms +step:625/1670 val_loss:3.6148 
train_time:54998ms step_avg:88.00ms +step:626/1670 train_time:55017ms step_avg:87.89ms +step:627/1670 train_time:55088ms step_avg:87.86ms +step:628/1670 train_time:55176ms step_avg:87.86ms +step:629/1670 train_time:55265ms step_avg:87.86ms +step:630/1670 train_time:55354ms step_avg:87.86ms +step:631/1670 train_time:55442ms step_avg:87.86ms +step:632/1670 train_time:55529ms step_avg:87.86ms +step:633/1670 train_time:55617ms step_avg:87.86ms +step:634/1670 train_time:55705ms step_avg:87.86ms +step:635/1670 train_time:55794ms step_avg:87.86ms +step:636/1670 train_time:55884ms step_avg:87.87ms +step:637/1670 train_time:55975ms step_avg:87.87ms +step:638/1670 train_time:56065ms step_avg:87.88ms +step:639/1670 train_time:56154ms step_avg:87.88ms +step:640/1670 train_time:56243ms step_avg:87.88ms +step:641/1670 train_time:56331ms step_avg:87.88ms +step:642/1670 train_time:56419ms step_avg:87.88ms +step:643/1670 train_time:56507ms step_avg:87.88ms +step:644/1670 train_time:56595ms step_avg:87.88ms +step:645/1670 train_time:56683ms step_avg:87.88ms +step:646/1670 train_time:56772ms step_avg:87.88ms +step:647/1670 train_time:56862ms step_avg:87.89ms +step:648/1670 train_time:56952ms step_avg:87.89ms +step:649/1670 train_time:57042ms step_avg:87.89ms +step:650/1670 train_time:57131ms step_avg:87.89ms +step:651/1670 train_time:57220ms step_avg:87.90ms +step:652/1670 train_time:57308ms step_avg:87.90ms +step:653/1670 train_time:57398ms step_avg:87.90ms +step:654/1670 train_time:57486ms step_avg:87.90ms +step:655/1670 train_time:57574ms step_avg:87.90ms +step:656/1670 train_time:57663ms step_avg:87.90ms +step:657/1670 train_time:57751ms step_avg:87.90ms +step:658/1670 train_time:57841ms step_avg:87.90ms +step:659/1670 train_time:57930ms step_avg:87.91ms +step:660/1670 train_time:58020ms step_avg:87.91ms +step:661/1670 train_time:58110ms step_avg:87.91ms +step:662/1670 train_time:58199ms step_avg:87.91ms +step:663/1670 train_time:58288ms step_avg:87.92ms +step:664/1670 train_time:58376ms step_avg:87.92ms +step:665/1670 train_time:58465ms step_avg:87.92ms +step:666/1670 train_time:58553ms step_avg:87.92ms +step:667/1670 train_time:58641ms step_avg:87.92ms +step:668/1670 train_time:58730ms step_avg:87.92ms +step:669/1670 train_time:58819ms step_avg:87.92ms +step:670/1670 train_time:58907ms step_avg:87.92ms +step:671/1670 train_time:58996ms step_avg:87.92ms +step:672/1670 train_time:59087ms step_avg:87.93ms +step:673/1670 train_time:59176ms step_avg:87.93ms +step:674/1670 train_time:59266ms step_avg:87.93ms +step:675/1670 train_time:59355ms step_avg:87.93ms +step:676/1670 train_time:59444ms step_avg:87.93ms +step:677/1670 train_time:59533ms step_avg:87.94ms +step:678/1670 train_time:59622ms step_avg:87.94ms +step:679/1670 train_time:59710ms step_avg:87.94ms +step:680/1670 train_time:59799ms step_avg:87.94ms +step:681/1670 train_time:59888ms step_avg:87.94ms +step:682/1670 train_time:59978ms step_avg:87.94ms +step:683/1670 train_time:60067ms step_avg:87.95ms +step:684/1670 train_time:60157ms step_avg:87.95ms +step:685/1670 train_time:60246ms step_avg:87.95ms +step:686/1670 train_time:60335ms step_avg:87.95ms +step:687/1670 train_time:60424ms step_avg:87.95ms +step:688/1670 train_time:60512ms step_avg:87.95ms +step:689/1670 train_time:60601ms step_avg:87.95ms +step:690/1670 train_time:60689ms step_avg:87.95ms +step:691/1670 train_time:60778ms step_avg:87.96ms +step:692/1670 train_time:60867ms step_avg:87.96ms +step:693/1670 train_time:60957ms step_avg:87.96ms +step:694/1670 train_time:61046ms step_avg:87.96ms 
+step:695/1670 train_time:61135ms step_avg:87.96ms +step:696/1670 train_time:61225ms step_avg:87.97ms +step:697/1670 train_time:61314ms step_avg:87.97ms +step:698/1670 train_time:61403ms step_avg:87.97ms +step:699/1670 train_time:61491ms step_avg:87.97ms +step:700/1670 train_time:61580ms step_avg:87.97ms +step:701/1670 train_time:61668ms step_avg:87.97ms +step:702/1670 train_time:61757ms step_avg:87.97ms +step:703/1670 train_time:61847ms step_avg:87.98ms +step:704/1670 train_time:61936ms step_avg:87.98ms +step:705/1670 train_time:62025ms step_avg:87.98ms +step:706/1670 train_time:62114ms step_avg:87.98ms +step:707/1670 train_time:62202ms step_avg:87.98ms +step:708/1670 train_time:62291ms step_avg:87.98ms +step:709/1670 train_time:62380ms step_avg:87.98ms +step:710/1670 train_time:62468ms step_avg:87.98ms +step:711/1670 train_time:62556ms step_avg:87.98ms +step:712/1670 train_time:62645ms step_avg:87.98ms +step:713/1670 train_time:62734ms step_avg:87.99ms +step:714/1670 train_time:62823ms step_avg:87.99ms +step:715/1670 train_time:62912ms step_avg:87.99ms +step:716/1670 train_time:63001ms step_avg:87.99ms +step:717/1670 train_time:63090ms step_avg:87.99ms +step:718/1670 train_time:63179ms step_avg:87.99ms +step:719/1670 train_time:63268ms step_avg:87.99ms +step:720/1670 train_time:63357ms step_avg:88.00ms +step:721/1670 train_time:63446ms step_avg:88.00ms +step:722/1670 train_time:63535ms step_avg:88.00ms +step:723/1670 train_time:63624ms step_avg:88.00ms +step:724/1670 train_time:63712ms step_avg:88.00ms +step:725/1670 train_time:63802ms step_avg:88.00ms +step:726/1670 train_time:63891ms step_avg:88.00ms +step:727/1670 train_time:63980ms step_avg:88.01ms +step:728/1670 train_time:64069ms step_avg:88.01ms +step:729/1670 train_time:64158ms step_avg:88.01ms +step:730/1670 train_time:64248ms step_avg:88.01ms +step:731/1670 train_time:64338ms step_avg:88.01ms +step:732/1670 train_time:64426ms step_avg:88.01ms +step:733/1670 train_time:64515ms step_avg:88.01ms +step:734/1670 train_time:64604ms step_avg:88.02ms +step:735/1670 train_time:64692ms step_avg:88.02ms +step:736/1670 train_time:64781ms step_avg:88.02ms +step:737/1670 train_time:64870ms step_avg:88.02ms +step:738/1670 train_time:64959ms step_avg:88.02ms +step:739/1670 train_time:65048ms step_avg:88.02ms +step:740/1670 train_time:65137ms step_avg:88.02ms +step:741/1670 train_time:65226ms step_avg:88.02ms +step:742/1670 train_time:65315ms step_avg:88.03ms +step:743/1670 train_time:65405ms step_avg:88.03ms +step:744/1670 train_time:65494ms step_avg:88.03ms +step:745/1670 train_time:65583ms step_avg:88.03ms +step:746/1670 train_time:65671ms step_avg:88.03ms +step:747/1670 train_time:65760ms step_avg:88.03ms +step:748/1670 train_time:65849ms step_avg:88.03ms +step:749/1670 train_time:65939ms step_avg:88.04ms +step:750/1670 train_time:66028ms step_avg:88.04ms +step:750/1670 val_loss:3.5672 train_time:66119ms step_avg:88.16ms +step:751/1670 train_time:66138ms step_avg:88.07ms +step:752/1670 train_time:66213ms step_avg:88.05ms +step:753/1670 train_time:66305ms step_avg:88.05ms +step:754/1670 train_time:66394ms step_avg:88.06ms +step:755/1670 train_time:66482ms step_avg:88.06ms +step:756/1670 train_time:66570ms step_avg:88.06ms +step:757/1670 train_time:66657ms step_avg:88.05ms +step:758/1670 train_time:66745ms step_avg:88.05ms +step:759/1670 train_time:66833ms step_avg:88.05ms +step:760/1670 train_time:66921ms step_avg:88.05ms +step:761/1670 train_time:67010ms step_avg:88.05ms +step:762/1670 train_time:67101ms step_avg:88.06ms +step:763/1670 
train_time:67193ms step_avg:88.06ms +step:764/1670 train_time:67284ms step_avg:88.07ms +step:765/1670 train_time:67373ms step_avg:88.07ms +step:766/1670 train_time:67463ms step_avg:88.07ms +step:767/1670 train_time:67552ms step_avg:88.07ms +step:768/1670 train_time:67639ms step_avg:88.07ms +step:769/1670 train_time:67728ms step_avg:88.07ms +step:770/1670 train_time:67816ms step_avg:88.07ms +step:771/1670 train_time:67903ms step_avg:88.07ms +step:772/1670 train_time:67992ms step_avg:88.07ms +step:773/1670 train_time:68081ms step_avg:88.07ms +step:774/1670 train_time:68171ms step_avg:88.08ms +step:775/1670 train_time:68262ms step_avg:88.08ms +step:776/1670 train_time:68353ms step_avg:88.08ms +step:777/1670 train_time:68442ms step_avg:88.08ms +step:778/1670 train_time:68530ms step_avg:88.08ms +step:779/1670 train_time:68618ms step_avg:88.09ms +step:780/1670 train_time:68707ms step_avg:88.09ms +step:781/1670 train_time:68796ms step_avg:88.09ms +step:782/1670 train_time:68883ms step_avg:88.09ms +step:783/1670 train_time:68973ms step_avg:88.09ms +step:784/1670 train_time:69061ms step_avg:88.09ms +step:785/1670 train_time:69150ms step_avg:88.09ms +step:786/1670 train_time:69241ms step_avg:88.09ms +step:787/1670 train_time:69331ms step_avg:88.10ms +step:788/1670 train_time:69420ms step_avg:88.10ms +step:789/1670 train_time:69509ms step_avg:88.10ms +step:790/1670 train_time:69598ms step_avg:88.10ms +step:791/1670 train_time:69686ms step_avg:88.10ms +step:792/1670 train_time:69774ms step_avg:88.10ms +step:793/1670 train_time:69862ms step_avg:88.10ms +step:794/1670 train_time:69951ms step_avg:88.10ms +step:795/1670 train_time:70039ms step_avg:88.10ms +step:796/1670 train_time:70128ms step_avg:88.10ms +step:797/1670 train_time:70218ms step_avg:88.10ms +step:798/1670 train_time:70307ms step_avg:88.10ms +step:799/1670 train_time:70397ms step_avg:88.11ms +step:800/1670 train_time:70486ms step_avg:88.11ms +step:801/1670 train_time:70575ms step_avg:88.11ms +step:802/1670 train_time:70663ms step_avg:88.11ms +step:803/1670 train_time:70751ms step_avg:88.11ms +step:804/1670 train_time:70839ms step_avg:88.11ms +step:805/1670 train_time:70928ms step_avg:88.11ms +step:806/1670 train_time:71017ms step_avg:88.11ms +step:807/1670 train_time:71107ms step_avg:88.11ms +step:808/1670 train_time:71197ms step_avg:88.11ms +step:809/1670 train_time:71286ms step_avg:88.12ms +step:810/1670 train_time:71375ms step_avg:88.12ms +step:811/1670 train_time:71464ms step_avg:88.12ms +step:812/1670 train_time:71553ms step_avg:88.12ms +step:813/1670 train_time:71641ms step_avg:88.12ms +step:814/1670 train_time:71730ms step_avg:88.12ms +step:815/1670 train_time:71819ms step_avg:88.12ms +step:816/1670 train_time:71908ms step_avg:88.12ms +step:817/1670 train_time:71997ms step_avg:88.12ms +step:818/1670 train_time:72086ms step_avg:88.13ms +step:819/1670 train_time:72176ms step_avg:88.13ms +step:820/1670 train_time:72264ms step_avg:88.13ms +step:821/1670 train_time:72353ms step_avg:88.13ms +step:822/1670 train_time:72443ms step_avg:88.13ms +step:823/1670 train_time:72532ms step_avg:88.13ms +step:824/1670 train_time:72620ms step_avg:88.13ms +step:825/1670 train_time:72710ms step_avg:88.13ms +step:826/1670 train_time:72799ms step_avg:88.13ms +step:827/1670 train_time:72887ms step_avg:88.13ms +step:828/1670 train_time:72976ms step_avg:88.13ms +step:829/1670 train_time:73064ms step_avg:88.14ms +step:830/1670 train_time:73154ms step_avg:88.14ms +step:831/1670 train_time:73242ms step_avg:88.14ms +step:832/1670 train_time:73331ms step_avg:88.14ms 
+step:833/1670 train_time:73420ms step_avg:88.14ms +step:834/1670 train_time:73509ms step_avg:88.14ms +step:835/1670 train_time:73599ms step_avg:88.14ms +step:836/1670 train_time:73688ms step_avg:88.14ms +step:837/1670 train_time:73776ms step_avg:88.14ms +step:838/1670 train_time:73865ms step_avg:88.14ms +step:839/1670 train_time:73955ms step_avg:88.15ms +step:840/1670 train_time:74043ms step_avg:88.15ms +step:841/1670 train_time:74132ms step_avg:88.15ms +step:842/1670 train_time:74220ms step_avg:88.15ms +step:843/1670 train_time:74309ms step_avg:88.15ms +step:844/1670 train_time:74398ms step_avg:88.15ms +step:845/1670 train_time:74488ms step_avg:88.15ms +step:846/1670 train_time:74577ms step_avg:88.15ms +step:847/1670 train_time:74668ms step_avg:88.16ms +step:848/1670 train_time:74757ms step_avg:88.16ms +step:849/1670 train_time:74845ms step_avg:88.16ms +step:850/1670 train_time:74935ms step_avg:88.16ms +step:851/1670 train_time:75023ms step_avg:88.16ms +step:852/1670 train_time:75113ms step_avg:88.16ms +step:853/1670 train_time:75202ms step_avg:88.16ms +step:854/1670 train_time:75291ms step_avg:88.16ms +step:855/1670 train_time:75381ms step_avg:88.16ms +step:856/1670 train_time:75469ms step_avg:88.16ms +step:857/1670 train_time:75558ms step_avg:88.17ms +step:858/1670 train_time:75647ms step_avg:88.17ms +step:859/1670 train_time:75736ms step_avg:88.17ms +step:860/1670 train_time:75825ms step_avg:88.17ms +step:861/1670 train_time:75915ms step_avg:88.17ms +step:862/1670 train_time:76004ms step_avg:88.17ms +step:863/1670 train_time:76094ms step_avg:88.17ms +step:864/1670 train_time:76182ms step_avg:88.17ms +step:865/1670 train_time:76271ms step_avg:88.17ms +step:866/1670 train_time:76360ms step_avg:88.18ms +step:867/1670 train_time:76449ms step_avg:88.18ms +step:868/1670 train_time:76538ms step_avg:88.18ms +step:869/1670 train_time:76627ms step_avg:88.18ms +step:870/1670 train_time:76716ms step_avg:88.18ms +step:871/1670 train_time:76805ms step_avg:88.18ms +step:872/1670 train_time:76894ms step_avg:88.18ms +step:873/1670 train_time:76982ms step_avg:88.18ms +step:874/1670 train_time:77072ms step_avg:88.18ms +step:875/1670 train_time:77160ms step_avg:88.18ms +step:875/1670 val_loss:3.5183 train_time:77250ms step_avg:88.29ms +step:876/1670 train_time:77269ms step_avg:88.21ms +step:877/1670 train_time:77342ms step_avg:88.19ms +step:878/1670 train_time:77437ms step_avg:88.20ms +step:879/1670 train_time:77527ms step_avg:88.20ms +step:880/1670 train_time:77614ms step_avg:88.20ms +step:881/1670 train_time:77703ms step_avg:88.20ms +step:882/1670 train_time:77791ms step_avg:88.20ms +step:883/1670 train_time:77878ms step_avg:88.20ms +step:884/1670 train_time:77966ms step_avg:88.20ms +step:885/1670 train_time:78054ms step_avg:88.20ms +step:886/1670 train_time:78141ms step_avg:88.20ms +step:887/1670 train_time:78231ms step_avg:88.20ms +step:888/1670 train_time:78323ms step_avg:88.20ms +step:889/1670 train_time:78414ms step_avg:88.20ms +step:890/1670 train_time:78505ms step_avg:88.21ms +step:891/1670 train_time:78593ms step_avg:88.21ms +step:892/1670 train_time:78682ms step_avg:88.21ms +step:893/1670 train_time:78771ms step_avg:88.21ms +step:894/1670 train_time:78858ms step_avg:88.21ms +step:895/1670 train_time:78946ms step_avg:88.21ms +step:896/1670 train_time:79034ms step_avg:88.21ms +step:897/1670 train_time:79122ms step_avg:88.21ms +step:898/1670 train_time:79211ms step_avg:88.21ms +step:899/1670 train_time:79301ms step_avg:88.21ms +step:900/1670 train_time:79393ms step_avg:88.21ms +step:901/1670 
train_time:79484ms step_avg:88.22ms +step:902/1670 train_time:79574ms step_avg:88.22ms +step:903/1670 train_time:79663ms step_avg:88.22ms +step:904/1670 train_time:79752ms step_avg:88.22ms +step:905/1670 train_time:79840ms step_avg:88.22ms +step:906/1670 train_time:79928ms step_avg:88.22ms +step:907/1670 train_time:80016ms step_avg:88.22ms +step:908/1670 train_time:80104ms step_avg:88.22ms +step:909/1670 train_time:80193ms step_avg:88.22ms +step:910/1670 train_time:80282ms step_avg:88.22ms +step:911/1670 train_time:80374ms step_avg:88.23ms +step:912/1670 train_time:80464ms step_avg:88.23ms +step:913/1670 train_time:80554ms step_avg:88.23ms +step:914/1670 train_time:80643ms step_avg:88.23ms +step:915/1670 train_time:80732ms step_avg:88.23ms +step:916/1670 train_time:80821ms step_avg:88.23ms +step:917/1670 train_time:80909ms step_avg:88.23ms +step:918/1670 train_time:80997ms step_avg:88.23ms +step:919/1670 train_time:81085ms step_avg:88.23ms +step:920/1670 train_time:81174ms step_avg:88.23ms +step:921/1670 train_time:81263ms step_avg:88.23ms +step:922/1670 train_time:81353ms step_avg:88.24ms +step:923/1670 train_time:81443ms step_avg:88.24ms +step:924/1670 train_time:81533ms step_avg:88.24ms +step:925/1670 train_time:81622ms step_avg:88.24ms +step:926/1670 train_time:81711ms step_avg:88.24ms +step:927/1670 train_time:81799ms step_avg:88.24ms +step:928/1670 train_time:81888ms step_avg:88.24ms +step:929/1670 train_time:81976ms step_avg:88.24ms +step:930/1670 train_time:82065ms step_avg:88.24ms +step:931/1670 train_time:82154ms step_avg:88.24ms +step:932/1670 train_time:82243ms step_avg:88.24ms +step:933/1670 train_time:82333ms step_avg:88.24ms +step:934/1670 train_time:82423ms step_avg:88.25ms +step:935/1670 train_time:82513ms step_avg:88.25ms +step:936/1670 train_time:82602ms step_avg:88.25ms +step:937/1670 train_time:82692ms step_avg:88.25ms +step:938/1670 train_time:82780ms step_avg:88.25ms +step:939/1670 train_time:82870ms step_avg:88.25ms +step:940/1670 train_time:82958ms step_avg:88.25ms +step:941/1670 train_time:83047ms step_avg:88.25ms +step:942/1670 train_time:83136ms step_avg:88.25ms +step:943/1670 train_time:83225ms step_avg:88.26ms +step:944/1670 train_time:83315ms step_avg:88.26ms +step:945/1670 train_time:83404ms step_avg:88.26ms +step:946/1670 train_time:83494ms step_avg:88.26ms +step:947/1670 train_time:83583ms step_avg:88.26ms +step:948/1670 train_time:83673ms step_avg:88.26ms +step:949/1670 train_time:83763ms step_avg:88.26ms +step:950/1670 train_time:83851ms step_avg:88.26ms +step:951/1670 train_time:83941ms step_avg:88.27ms +step:952/1670 train_time:84030ms step_avg:88.27ms +step:953/1670 train_time:84118ms step_avg:88.27ms +step:954/1670 train_time:84207ms step_avg:88.27ms +step:955/1670 train_time:84296ms step_avg:88.27ms +step:956/1670 train_time:84384ms step_avg:88.27ms +step:957/1670 train_time:84474ms step_avg:88.27ms +step:958/1670 train_time:84563ms step_avg:88.27ms +step:959/1670 train_time:84652ms step_avg:88.27ms +step:960/1670 train_time:84741ms step_avg:88.27ms +step:961/1670 train_time:84831ms step_avg:88.27ms +step:962/1670 train_time:84919ms step_avg:88.27ms +step:963/1670 train_time:85008ms step_avg:88.27ms +step:964/1670 train_time:85097ms step_avg:88.28ms +step:965/1670 train_time:85187ms step_avg:88.28ms +step:966/1670 train_time:85275ms step_avg:88.28ms +step:967/1670 train_time:85365ms step_avg:88.28ms +step:968/1670 train_time:85453ms step_avg:88.28ms +step:969/1670 train_time:85543ms step_avg:88.28ms +step:970/1670 train_time:85632ms step_avg:88.28ms 
+step:971/1670 train_time:85722ms step_avg:88.28ms +step:972/1670 train_time:85810ms step_avg:88.28ms +step:973/1670 train_time:85899ms step_avg:88.28ms +step:974/1670 train_time:85988ms step_avg:88.28ms +step:975/1670 train_time:86076ms step_avg:88.28ms +step:976/1670 train_time:86166ms step_avg:88.28ms +step:977/1670 train_time:86255ms step_avg:88.29ms +step:978/1670 train_time:86343ms step_avg:88.29ms +step:979/1670 train_time:86433ms step_avg:88.29ms +step:980/1670 train_time:86522ms step_avg:88.29ms +step:981/1670 train_time:86611ms step_avg:88.29ms +step:982/1670 train_time:86700ms step_avg:88.29ms +step:983/1670 train_time:86789ms step_avg:88.29ms +step:984/1670 train_time:86878ms step_avg:88.29ms +step:985/1670 train_time:86966ms step_avg:88.29ms +step:986/1670 train_time:87055ms step_avg:88.29ms +step:987/1670 train_time:87145ms step_avg:88.29ms +step:988/1670 train_time:87234ms step_avg:88.29ms +step:989/1670 train_time:87323ms step_avg:88.29ms +step:990/1670 train_time:87412ms step_avg:88.29ms +step:991/1670 train_time:87500ms step_avg:88.29ms +step:992/1670 train_time:87590ms step_avg:88.30ms +step:993/1670 train_time:87679ms step_avg:88.30ms +step:994/1670 train_time:87768ms step_avg:88.30ms +step:995/1670 train_time:87856ms step_avg:88.30ms +step:996/1670 train_time:87945ms step_avg:88.30ms +step:997/1670 train_time:88034ms step_avg:88.30ms +step:998/1670 train_time:88123ms step_avg:88.30ms +step:999/1670 train_time:88212ms step_avg:88.30ms +step:1000/1670 train_time:88301ms step_avg:88.30ms +step:1000/1670 val_loss:3.4677 train_time:88391ms step_avg:88.39ms +step:1001/1670 train_time:88412ms step_avg:88.32ms +step:1002/1670 train_time:88485ms step_avg:88.31ms +step:1003/1670 train_time:88578ms step_avg:88.31ms +step:1004/1670 train_time:88667ms step_avg:88.31ms +step:1005/1670 train_time:88755ms step_avg:88.31ms +step:1006/1670 train_time:88843ms step_avg:88.31ms +step:1007/1670 train_time:88931ms step_avg:88.31ms +step:1008/1670 train_time:89019ms step_avg:88.31ms +step:1009/1670 train_time:89106ms step_avg:88.31ms +step:1010/1670 train_time:89194ms step_avg:88.31ms +step:1011/1670 train_time:89282ms step_avg:88.31ms +step:1012/1670 train_time:89372ms step_avg:88.31ms +step:1013/1670 train_time:89465ms step_avg:88.32ms +step:1014/1670 train_time:89557ms step_avg:88.32ms +step:1015/1670 train_time:89645ms step_avg:88.32ms +step:1016/1670 train_time:89734ms step_avg:88.32ms +step:1017/1670 train_time:89823ms step_avg:88.32ms +step:1018/1670 train_time:89911ms step_avg:88.32ms +step:1019/1670 train_time:90000ms step_avg:88.32ms +step:1020/1670 train_time:90087ms step_avg:88.32ms +step:1021/1670 train_time:90175ms step_avg:88.32ms +step:1022/1670 train_time:90264ms step_avg:88.32ms +step:1023/1670 train_time:90353ms step_avg:88.32ms +step:1024/1670 train_time:90444ms step_avg:88.32ms +step:1025/1670 train_time:90534ms step_avg:88.33ms +step:1026/1670 train_time:90624ms step_avg:88.33ms +step:1027/1670 train_time:90714ms step_avg:88.33ms +step:1028/1670 train_time:90804ms step_avg:88.33ms +step:1029/1670 train_time:90892ms step_avg:88.33ms +step:1030/1670 train_time:90980ms step_avg:88.33ms +step:1031/1670 train_time:91069ms step_avg:88.33ms +step:1032/1670 train_time:91157ms step_avg:88.33ms +step:1033/1670 train_time:91244ms step_avg:88.33ms +step:1034/1670 train_time:91334ms step_avg:88.33ms +step:1035/1670 train_time:91424ms step_avg:88.33ms +step:1036/1670 train_time:91514ms step_avg:88.33ms +step:1037/1670 train_time:91603ms step_avg:88.33ms +step:1038/1670 
train_time:91693ms step_avg:88.34ms +step:1039/1670 train_time:91783ms step_avg:88.34ms +step:1040/1670 train_time:91872ms step_avg:88.34ms +step:1041/1670 train_time:91961ms step_avg:88.34ms +step:1042/1670 train_time:92050ms step_avg:88.34ms +step:1043/1670 train_time:92139ms step_avg:88.34ms +step:1044/1670 train_time:92227ms step_avg:88.34ms +step:1045/1670 train_time:92316ms step_avg:88.34ms +step:1046/1670 train_time:92405ms step_avg:88.34ms +step:1047/1670 train_time:92495ms step_avg:88.34ms +step:1048/1670 train_time:92584ms step_avg:88.34ms +step:1049/1670 train_time:92673ms step_avg:88.34ms +step:1050/1670 train_time:92763ms step_avg:88.35ms +step:1051/1670 train_time:92852ms step_avg:88.35ms +step:1052/1670 train_time:92941ms step_avg:88.35ms +step:1053/1670 train_time:93029ms step_avg:88.35ms +step:1054/1670 train_time:93117ms step_avg:88.35ms +step:1055/1670 train_time:93205ms step_avg:88.35ms +step:1056/1670 train_time:93294ms step_avg:88.35ms +step:1057/1670 train_time:93383ms step_avg:88.35ms +step:1058/1670 train_time:93473ms step_avg:88.35ms +step:1059/1670 train_time:93562ms step_avg:88.35ms +step:1060/1670 train_time:93652ms step_avg:88.35ms +step:1061/1670 train_time:93741ms step_avg:88.35ms +step:1062/1670 train_time:93830ms step_avg:88.35ms +step:1063/1670 train_time:93919ms step_avg:88.35ms +step:1064/1670 train_time:94007ms step_avg:88.35ms +step:1065/1670 train_time:94096ms step_avg:88.35ms +step:1066/1670 train_time:94184ms step_avg:88.35ms +step:1067/1670 train_time:94273ms step_avg:88.35ms +step:1068/1670 train_time:94363ms step_avg:88.35ms +step:1069/1670 train_time:94453ms step_avg:88.36ms +step:1070/1670 train_time:94542ms step_avg:88.36ms +step:1071/1670 train_time:94630ms step_avg:88.36ms +step:1072/1670 train_time:94719ms step_avg:88.36ms +step:1073/1670 train_time:94808ms step_avg:88.36ms +step:1074/1670 train_time:94897ms step_avg:88.36ms +step:1075/1670 train_time:94986ms step_avg:88.36ms +step:1076/1670 train_time:95075ms step_avg:88.36ms +step:1077/1670 train_time:95163ms step_avg:88.36ms +step:1078/1670 train_time:95252ms step_avg:88.36ms +step:1079/1670 train_time:95340ms step_avg:88.36ms +step:1080/1670 train_time:95429ms step_avg:88.36ms +step:1081/1670 train_time:95519ms step_avg:88.36ms +step:1082/1670 train_time:95608ms step_avg:88.36ms +step:1083/1670 train_time:95697ms step_avg:88.36ms +step:1084/1670 train_time:95785ms step_avg:88.36ms +step:1085/1670 train_time:95874ms step_avg:88.36ms +step:1086/1670 train_time:95963ms step_avg:88.36ms +step:1087/1670 train_time:96052ms step_avg:88.36ms +step:1088/1670 train_time:96141ms step_avg:88.36ms +step:1089/1670 train_time:96229ms step_avg:88.36ms +step:1090/1670 train_time:96319ms step_avg:88.37ms +step:1091/1670 train_time:96408ms step_avg:88.37ms +step:1092/1670 train_time:96497ms step_avg:88.37ms +step:1093/1670 train_time:96586ms step_avg:88.37ms +step:1094/1670 train_time:96676ms step_avg:88.37ms +step:1095/1670 train_time:96765ms step_avg:88.37ms +step:1096/1670 train_time:96855ms step_avg:88.37ms +step:1097/1670 train_time:96944ms step_avg:88.37ms +step:1098/1670 train_time:97035ms step_avg:88.37ms +step:1099/1670 train_time:97125ms step_avg:88.38ms +step:1100/1670 train_time:97214ms step_avg:88.38ms +step:1101/1670 train_time:97304ms step_avg:88.38ms +step:1102/1670 train_time:97394ms step_avg:88.38ms +step:1103/1670 train_time:97483ms step_avg:88.38ms +step:1104/1670 train_time:97573ms step_avg:88.38ms +step:1105/1670 train_time:97664ms step_avg:88.38ms +step:1106/1670 train_time:97754ms 
step_avg:88.39ms +step:1107/1670 train_time:97844ms step_avg:88.39ms +step:1108/1670 train_time:97934ms step_avg:88.39ms +step:1109/1670 train_time:98023ms step_avg:88.39ms +step:1110/1670 train_time:98112ms step_avg:88.39ms +step:1111/1670 train_time:98202ms step_avg:88.39ms +step:1112/1670 train_time:98293ms step_avg:88.39ms +step:1113/1670 train_time:98383ms step_avg:88.39ms +step:1114/1670 train_time:98473ms step_avg:88.40ms +step:1115/1670 train_time:98562ms step_avg:88.40ms +step:1116/1670 train_time:98653ms step_avg:88.40ms +step:1117/1670 train_time:98743ms step_avg:88.40ms +step:1118/1670 train_time:98833ms step_avg:88.40ms +step:1119/1670 train_time:98923ms step_avg:88.40ms +step:1120/1670 train_time:99012ms step_avg:88.40ms +step:1121/1670 train_time:99101ms step_avg:88.40ms +step:1122/1670 train_time:99191ms step_avg:88.41ms +step:1123/1670 train_time:99281ms step_avg:88.41ms +step:1124/1670 train_time:99370ms step_avg:88.41ms +step:1125/1670 train_time:99460ms step_avg:88.41ms +step:1125/1670 val_loss:3.4135 train_time:99551ms step_avg:88.49ms +step:1126/1670 train_time:99571ms step_avg:88.43ms +step:1127/1670 train_time:99642ms step_avg:88.41ms +step:1128/1670 train_time:99731ms step_avg:88.41ms +step:1129/1670 train_time:99822ms step_avg:88.42ms +step:1130/1670 train_time:99910ms step_avg:88.42ms +step:1131/1670 train_time:99999ms step_avg:88.42ms +step:1132/1670 train_time:100087ms step_avg:88.42ms +step:1133/1670 train_time:100176ms step_avg:88.42ms +step:1134/1670 train_time:100266ms step_avg:88.42ms +step:1135/1670 train_time:100354ms step_avg:88.42ms +step:1136/1670 train_time:100445ms step_avg:88.42ms +step:1137/1670 train_time:100536ms step_avg:88.42ms +step:1138/1670 train_time:100628ms step_avg:88.43ms +step:1139/1670 train_time:100721ms step_avg:88.43ms +step:1140/1670 train_time:100811ms step_avg:88.43ms +step:1141/1670 train_time:100900ms step_avg:88.43ms +step:1142/1670 train_time:100990ms step_avg:88.43ms +step:1143/1670 train_time:101079ms step_avg:88.43ms +step:1144/1670 train_time:101167ms step_avg:88.43ms +step:1145/1670 train_time:101256ms step_avg:88.43ms +step:1146/1670 train_time:101345ms step_avg:88.43ms +step:1147/1670 train_time:101434ms step_avg:88.43ms +step:1148/1670 train_time:101525ms step_avg:88.44ms +step:1149/1670 train_time:101616ms step_avg:88.44ms +step:1150/1670 train_time:101707ms step_avg:88.44ms +step:1151/1670 train_time:101797ms step_avg:88.44ms +step:1152/1670 train_time:101887ms step_avg:88.44ms +step:1153/1670 train_time:101977ms step_avg:88.44ms +step:1154/1670 train_time:102066ms step_avg:88.45ms +step:1155/1670 train_time:102155ms step_avg:88.45ms +step:1156/1670 train_time:102244ms step_avg:88.45ms +step:1157/1670 train_time:102332ms step_avg:88.45ms +step:1158/1670 train_time:102423ms step_avg:88.45ms +step:1159/1670 train_time:102512ms step_avg:88.45ms +step:1160/1670 train_time:102604ms step_avg:88.45ms +step:1161/1670 train_time:102695ms step_avg:88.45ms +step:1162/1670 train_time:102785ms step_avg:88.46ms +step:1163/1670 train_time:102875ms step_avg:88.46ms +step:1164/1670 train_time:102965ms step_avg:88.46ms +step:1165/1670 train_time:103053ms step_avg:88.46ms +step:1166/1670 train_time:103143ms step_avg:88.46ms +step:1167/1670 train_time:103231ms step_avg:88.46ms +step:1168/1670 train_time:103321ms step_avg:88.46ms +step:1169/1670 train_time:103410ms step_avg:88.46ms +step:1170/1670 train_time:103499ms step_avg:88.46ms +step:1171/1670 train_time:103591ms step_avg:88.46ms +step:1172/1670 train_time:103682ms 
step_avg:88.47ms +step:1173/1670 train_time:103771ms step_avg:88.47ms +step:1174/1670 train_time:103862ms step_avg:88.47ms +step:1175/1670 train_time:103951ms step_avg:88.47ms +step:1176/1670 train_time:104041ms step_avg:88.47ms +step:1177/1670 train_time:104130ms step_avg:88.47ms +step:1178/1670 train_time:104219ms step_avg:88.47ms +step:1179/1670 train_time:104309ms step_avg:88.47ms +step:1180/1670 train_time:104398ms step_avg:88.47ms +step:1181/1670 train_time:104488ms step_avg:88.47ms +step:1182/1670 train_time:104578ms step_avg:88.48ms +step:1183/1670 train_time:104668ms step_avg:88.48ms +step:1184/1670 train_time:104758ms step_avg:88.48ms +step:1185/1670 train_time:104848ms step_avg:88.48ms +step:1186/1670 train_time:104937ms step_avg:88.48ms +step:1187/1670 train_time:105028ms step_avg:88.48ms +step:1188/1670 train_time:105117ms step_avg:88.48ms +step:1189/1670 train_time:105206ms step_avg:88.48ms +step:1190/1670 train_time:105296ms step_avg:88.48ms +step:1191/1670 train_time:105386ms step_avg:88.49ms +step:1192/1670 train_time:105475ms step_avg:88.49ms +step:1193/1670 train_time:105566ms step_avg:88.49ms +step:1194/1670 train_time:105656ms step_avg:88.49ms +step:1195/1670 train_time:105747ms step_avg:88.49ms +step:1196/1670 train_time:105837ms step_avg:88.49ms +step:1197/1670 train_time:105927ms step_avg:88.49ms +step:1198/1670 train_time:106016ms step_avg:88.49ms +step:1199/1670 train_time:106106ms step_avg:88.50ms +step:1200/1670 train_time:106196ms step_avg:88.50ms +step:1201/1670 train_time:106286ms step_avg:88.50ms +step:1202/1670 train_time:106375ms step_avg:88.50ms +step:1203/1670 train_time:106464ms step_avg:88.50ms +step:1204/1670 train_time:106554ms step_avg:88.50ms +step:1205/1670 train_time:106644ms step_avg:88.50ms +step:1206/1670 train_time:106734ms step_avg:88.50ms +step:1207/1670 train_time:106825ms step_avg:88.50ms +step:1208/1670 train_time:106914ms step_avg:88.50ms +step:1209/1670 train_time:107003ms step_avg:88.51ms +step:1210/1670 train_time:107093ms step_avg:88.51ms +step:1211/1670 train_time:107182ms step_avg:88.51ms +step:1212/1670 train_time:107272ms step_avg:88.51ms +step:1213/1670 train_time:107361ms step_avg:88.51ms +step:1214/1670 train_time:107450ms step_avg:88.51ms +step:1215/1670 train_time:107540ms step_avg:88.51ms +step:1216/1670 train_time:107630ms step_avg:88.51ms +step:1217/1670 train_time:107721ms step_avg:88.51ms +step:1218/1670 train_time:107810ms step_avg:88.51ms +step:1219/1670 train_time:107900ms step_avg:88.52ms +step:1220/1670 train_time:107990ms step_avg:88.52ms +step:1221/1670 train_time:108080ms step_avg:88.52ms +step:1222/1670 train_time:108170ms step_avg:88.52ms +step:1223/1670 train_time:108259ms step_avg:88.52ms +step:1224/1670 train_time:108349ms step_avg:88.52ms +step:1225/1670 train_time:108438ms step_avg:88.52ms +step:1226/1670 train_time:108527ms step_avg:88.52ms +step:1227/1670 train_time:108617ms step_avg:88.52ms +step:1228/1670 train_time:108706ms step_avg:88.52ms +step:1229/1670 train_time:108796ms step_avg:88.52ms +step:1230/1670 train_time:108886ms step_avg:88.53ms +step:1231/1670 train_time:108975ms step_avg:88.53ms +step:1232/1670 train_time:109065ms step_avg:88.53ms +step:1233/1670 train_time:109155ms step_avg:88.53ms +step:1234/1670 train_time:109246ms step_avg:88.53ms +step:1235/1670 train_time:109335ms step_avg:88.53ms +step:1236/1670 train_time:109424ms step_avg:88.53ms +step:1237/1670 train_time:109513ms step_avg:88.53ms +step:1238/1670 train_time:109604ms step_avg:88.53ms +step:1239/1670 train_time:109693ms 
step_avg:88.53ms +step:1240/1670 train_time:109783ms step_avg:88.53ms +step:1241/1670 train_time:109873ms step_avg:88.54ms +step:1242/1670 train_time:109963ms step_avg:88.54ms +step:1243/1670 train_time:110052ms step_avg:88.54ms +step:1244/1670 train_time:110141ms step_avg:88.54ms +step:1245/1670 train_time:110230ms step_avg:88.54ms +step:1246/1670 train_time:110321ms step_avg:88.54ms +step:1247/1670 train_time:110409ms step_avg:88.54ms +step:1248/1670 train_time:110499ms step_avg:88.54ms +step:1249/1670 train_time:110589ms step_avg:88.54ms +step:1250/1670 train_time:110679ms step_avg:88.54ms +step:1250/1670 val_loss:3.3755 train_time:110770ms step_avg:88.62ms +step:1251/1670 train_time:110790ms step_avg:88.56ms +step:1252/1670 train_time:110867ms step_avg:88.55ms +step:1253/1670 train_time:110960ms step_avg:88.56ms +step:1254/1670 train_time:111051ms step_avg:88.56ms +step:1255/1670 train_time:111140ms step_avg:88.56ms +step:1256/1670 train_time:111228ms step_avg:88.56ms +step:1257/1670 train_time:111316ms step_avg:88.56ms +step:1258/1670 train_time:111405ms step_avg:88.56ms +step:1259/1670 train_time:111493ms step_avg:88.56ms +step:1260/1670 train_time:111582ms step_avg:88.56ms +step:1261/1670 train_time:111671ms step_avg:88.56ms +step:1262/1670 train_time:111762ms step_avg:88.56ms +step:1263/1670 train_time:111856ms step_avg:88.56ms +step:1264/1670 train_time:111949ms step_avg:88.57ms +step:1265/1670 train_time:112039ms step_avg:88.57ms +step:1266/1670 train_time:112129ms step_avg:88.57ms +step:1267/1670 train_time:112218ms step_avg:88.57ms +step:1268/1670 train_time:112307ms step_avg:88.57ms +step:1269/1670 train_time:112395ms step_avg:88.57ms +step:1270/1670 train_time:112484ms step_avg:88.57ms +step:1271/1670 train_time:112573ms step_avg:88.57ms +step:1272/1670 train_time:112661ms step_avg:88.57ms +step:1273/1670 train_time:112753ms step_avg:88.57ms +step:1274/1670 train_time:112844ms step_avg:88.57ms +step:1275/1670 train_time:112936ms step_avg:88.58ms +step:1276/1670 train_time:113027ms step_avg:88.58ms +step:1277/1670 train_time:113117ms step_avg:88.58ms +step:1278/1670 train_time:113206ms step_avg:88.58ms +step:1279/1670 train_time:113294ms step_avg:88.58ms +step:1280/1670 train_time:113383ms step_avg:88.58ms +step:1281/1670 train_time:113472ms step_avg:88.58ms +step:1282/1670 train_time:113561ms step_avg:88.58ms +step:1283/1670 train_time:113651ms step_avg:88.58ms +step:1284/1670 train_time:113740ms step_avg:88.58ms +step:1285/1670 train_time:113832ms step_avg:88.59ms +step:1286/1670 train_time:113923ms step_avg:88.59ms +step:1287/1670 train_time:114015ms step_avg:88.59ms +step:1288/1670 train_time:114107ms step_avg:88.59ms +step:1289/1670 train_time:114196ms step_avg:88.59ms +step:1290/1670 train_time:114286ms step_avg:88.59ms +step:1291/1670 train_time:114375ms step_avg:88.59ms +step:1292/1670 train_time:114464ms step_avg:88.59ms +step:1293/1670 train_time:114553ms step_avg:88.59ms +step:1294/1670 train_time:114642ms step_avg:88.59ms +step:1295/1670 train_time:114732ms step_avg:88.60ms +step:1296/1670 train_time:114823ms step_avg:88.60ms +step:1297/1670 train_time:114914ms step_avg:88.60ms +step:1298/1670 train_time:115005ms step_avg:88.60ms +step:1299/1670 train_time:115095ms step_avg:88.60ms +step:1300/1670 train_time:115185ms step_avg:88.60ms +step:1301/1670 train_time:115274ms step_avg:88.60ms +step:1302/1670 train_time:115363ms step_avg:88.60ms +step:1303/1670 train_time:115453ms step_avg:88.61ms +step:1304/1670 train_time:115542ms step_avg:88.61ms +step:1305/1670 
train_time:115631ms step_avg:88.61ms +step:1306/1670 train_time:115720ms step_avg:88.61ms +step:1307/1670 train_time:115812ms step_avg:88.61ms +step:1308/1670 train_time:115903ms step_avg:88.61ms +step:1309/1670 train_time:115993ms step_avg:88.61ms +step:1310/1670 train_time:116084ms step_avg:88.61ms +step:1311/1670 train_time:116173ms step_avg:88.61ms +step:1312/1670 train_time:116263ms step_avg:88.61ms +step:1313/1670 train_time:116352ms step_avg:88.62ms +step:1314/1670 train_time:116442ms step_avg:88.62ms +step:1315/1670 train_time:116532ms step_avg:88.62ms +step:1316/1670 train_time:116622ms step_avg:88.62ms +step:1317/1670 train_time:116712ms step_avg:88.62ms +step:1318/1670 train_time:116801ms step_avg:88.62ms +step:1319/1670 train_time:116892ms step_avg:88.62ms +step:1320/1670 train_time:116984ms step_avg:88.62ms +step:1321/1670 train_time:117074ms step_avg:88.62ms +step:1322/1670 train_time:117163ms step_avg:88.63ms +step:1323/1670 train_time:117254ms step_avg:88.63ms +step:1324/1670 train_time:117343ms step_avg:88.63ms +step:1325/1670 train_time:117432ms step_avg:88.63ms +step:1326/1670 train_time:117523ms step_avg:88.63ms +step:1327/1670 train_time:117612ms step_avg:88.63ms +step:1328/1670 train_time:117702ms step_avg:88.63ms +step:1329/1670 train_time:117792ms step_avg:88.63ms +step:1330/1670 train_time:117881ms step_avg:88.63ms +step:1331/1670 train_time:117972ms step_avg:88.63ms +step:1332/1670 train_time:118062ms step_avg:88.63ms +step:1333/1670 train_time:118152ms step_avg:88.64ms +step:1334/1670 train_time:118242ms step_avg:88.64ms +step:1335/1670 train_time:118331ms step_avg:88.64ms +step:1336/1670 train_time:118420ms step_avg:88.64ms +step:1337/1670 train_time:118510ms step_avg:88.64ms +step:1338/1670 train_time:118599ms step_avg:88.64ms +step:1339/1670 train_time:118690ms step_avg:88.64ms +step:1340/1670 train_time:118779ms step_avg:88.64ms +step:1341/1670 train_time:118869ms step_avg:88.64ms +step:1342/1670 train_time:118958ms step_avg:88.64ms +step:1343/1670 train_time:119049ms step_avg:88.64ms +step:1344/1670 train_time:119138ms step_avg:88.64ms +step:1345/1670 train_time:119228ms step_avg:88.65ms +step:1346/1670 train_time:119317ms step_avg:88.65ms +step:1347/1670 train_time:119407ms step_avg:88.65ms +step:1348/1670 train_time:119496ms step_avg:88.65ms +step:1349/1670 train_time:119586ms step_avg:88.65ms +step:1350/1670 train_time:119676ms step_avg:88.65ms +step:1351/1670 train_time:119766ms step_avg:88.65ms +step:1352/1670 train_time:119855ms step_avg:88.65ms +step:1353/1670 train_time:119946ms step_avg:88.65ms +step:1354/1670 train_time:120037ms step_avg:88.65ms +step:1355/1670 train_time:120128ms step_avg:88.66ms +step:1356/1670 train_time:120217ms step_avg:88.66ms +step:1357/1670 train_time:120307ms step_avg:88.66ms +step:1358/1670 train_time:120396ms step_avg:88.66ms +step:1359/1670 train_time:120486ms step_avg:88.66ms +step:1360/1670 train_time:120575ms step_avg:88.66ms +step:1361/1670 train_time:120664ms step_avg:88.66ms +step:1362/1670 train_time:120755ms step_avg:88.66ms +step:1363/1670 train_time:120845ms step_avg:88.66ms +step:1364/1670 train_time:120935ms step_avg:88.66ms +step:1365/1670 train_time:121025ms step_avg:88.66ms +step:1366/1670 train_time:121114ms step_avg:88.66ms +step:1367/1670 train_time:121204ms step_avg:88.66ms +step:1368/1670 train_time:121294ms step_avg:88.66ms +step:1369/1670 train_time:121383ms step_avg:88.67ms +step:1370/1670 train_time:121472ms step_avg:88.67ms +step:1371/1670 train_time:121562ms step_avg:88.67ms +step:1372/1670 
train_time:121653ms step_avg:88.67ms +step:1373/1670 train_time:121744ms step_avg:88.67ms +step:1374/1670 train_time:121833ms step_avg:88.67ms +step:1375/1670 train_time:121923ms step_avg:88.67ms +step:1375/1670 val_loss:3.3409 train_time:122014ms step_avg:88.74ms +step:1376/1670 train_time:122033ms step_avg:88.69ms +step:1377/1670 train_time:122108ms step_avg:88.68ms +step:1378/1670 train_time:122199ms step_avg:88.68ms +step:1379/1670 train_time:122289ms step_avg:88.68ms +step:1380/1670 train_time:122376ms step_avg:88.68ms +step:1381/1670 train_time:122465ms step_avg:88.68ms +step:1382/1670 train_time:122553ms step_avg:88.68ms +step:1383/1670 train_time:122641ms step_avg:88.68ms +step:1384/1670 train_time:122731ms step_avg:88.68ms +step:1385/1670 train_time:122821ms step_avg:88.68ms +step:1386/1670 train_time:122910ms step_avg:88.68ms +step:1387/1670 train_time:123002ms step_avg:88.68ms +step:1388/1670 train_time:123094ms step_avg:88.68ms +step:1389/1670 train_time:123186ms step_avg:88.69ms +step:1390/1670 train_time:123276ms step_avg:88.69ms +step:1391/1670 train_time:123366ms step_avg:88.69ms +step:1392/1670 train_time:123454ms step_avg:88.69ms +step:1393/1670 train_time:123544ms step_avg:88.69ms +step:1394/1670 train_time:123632ms step_avg:88.69ms +step:1395/1670 train_time:123722ms step_avg:88.69ms +step:1396/1670 train_time:123812ms step_avg:88.69ms +step:1397/1670 train_time:123901ms step_avg:88.69ms +step:1398/1670 train_time:123992ms step_avg:88.69ms +step:1399/1670 train_time:124084ms step_avg:88.69ms +step:1400/1670 train_time:124174ms step_avg:88.70ms +step:1401/1670 train_time:124265ms step_avg:88.70ms +step:1402/1670 train_time:124355ms step_avg:88.70ms +step:1403/1670 train_time:124445ms step_avg:88.70ms +step:1404/1670 train_time:124534ms step_avg:88.70ms +step:1405/1670 train_time:124624ms step_avg:88.70ms +step:1406/1670 train_time:124712ms step_avg:88.70ms +step:1407/1670 train_time:124801ms step_avg:88.70ms +step:1408/1670 train_time:124891ms step_avg:88.70ms +step:1409/1670 train_time:124981ms step_avg:88.70ms +step:1410/1670 train_time:125071ms step_avg:88.70ms +step:1411/1670 train_time:125161ms step_avg:88.70ms +step:1412/1670 train_time:125251ms step_avg:88.70ms +step:1413/1670 train_time:125342ms step_avg:88.71ms +step:1414/1670 train_time:125431ms step_avg:88.71ms +step:1415/1670 train_time:125521ms step_avg:88.71ms +step:1416/1670 train_time:125610ms step_avg:88.71ms +step:1417/1670 train_time:125699ms step_avg:88.71ms +step:1418/1670 train_time:125788ms step_avg:88.71ms +step:1419/1670 train_time:125878ms step_avg:88.71ms +step:1420/1670 train_time:125968ms step_avg:88.71ms +step:1421/1670 train_time:126058ms step_avg:88.71ms +step:1422/1670 train_time:126149ms step_avg:88.71ms +step:1423/1670 train_time:126239ms step_avg:88.71ms +step:1424/1670 train_time:126330ms step_avg:88.71ms +step:1425/1670 train_time:126419ms step_avg:88.71ms +step:1426/1670 train_time:126509ms step_avg:88.72ms +step:1427/1670 train_time:126598ms step_avg:88.72ms +step:1428/1670 train_time:126688ms step_avg:88.72ms +step:1429/1670 train_time:126777ms step_avg:88.72ms +step:1430/1670 train_time:126868ms step_avg:88.72ms +step:1431/1670 train_time:126956ms step_avg:88.72ms +step:1432/1670 train_time:127046ms step_avg:88.72ms +step:1433/1670 train_time:127136ms step_avg:88.72ms +step:1434/1670 train_time:127227ms step_avg:88.72ms +step:1435/1670 train_time:127317ms step_avg:88.72ms +step:1436/1670 train_time:127407ms step_avg:88.72ms +step:1437/1670 train_time:127496ms step_avg:88.72ms 
+step:1438/1670 train_time:127586ms step_avg:88.72ms +step:1439/1670 train_time:127675ms step_avg:88.73ms +step:1440/1670 train_time:127766ms step_avg:88.73ms +step:1441/1670 train_time:127855ms step_avg:88.73ms +step:1442/1670 train_time:127945ms step_avg:88.73ms +step:1443/1670 train_time:128034ms step_avg:88.73ms +step:1444/1670 train_time:128124ms step_avg:88.73ms +step:1445/1670 train_time:128215ms step_avg:88.73ms +step:1446/1670 train_time:128305ms step_avg:88.73ms +step:1447/1670 train_time:128394ms step_avg:88.73ms +step:1448/1670 train_time:128484ms step_avg:88.73ms +step:1449/1670 train_time:128573ms step_avg:88.73ms +step:1450/1670 train_time:128663ms step_avg:88.73ms +step:1451/1670 train_time:128752ms step_avg:88.73ms +step:1452/1670 train_time:128842ms step_avg:88.73ms +step:1453/1670 train_time:128932ms step_avg:88.73ms +step:1454/1670 train_time:129022ms step_avg:88.74ms +step:1455/1670 train_time:129112ms step_avg:88.74ms +step:1456/1670 train_time:129202ms step_avg:88.74ms +step:1457/1670 train_time:129292ms step_avg:88.74ms +step:1458/1670 train_time:129382ms step_avg:88.74ms +step:1459/1670 train_time:129471ms step_avg:88.74ms +step:1460/1670 train_time:129562ms step_avg:88.74ms +step:1461/1670 train_time:129651ms step_avg:88.74ms +step:1462/1670 train_time:129742ms step_avg:88.74ms +step:1463/1670 train_time:129833ms step_avg:88.74ms +step:1464/1670 train_time:129923ms step_avg:88.75ms +step:1465/1670 train_time:130013ms step_avg:88.75ms +step:1466/1670 train_time:130102ms step_avg:88.75ms +step:1467/1670 train_time:130191ms step_avg:88.75ms +step:1468/1670 train_time:130282ms step_avg:88.75ms +step:1469/1670 train_time:130372ms step_avg:88.75ms +step:1470/1670 train_time:130462ms step_avg:88.75ms +step:1471/1670 train_time:130552ms step_avg:88.75ms +step:1472/1670 train_time:130641ms step_avg:88.75ms +step:1473/1670 train_time:130731ms step_avg:88.75ms +step:1474/1670 train_time:130820ms step_avg:88.75ms +step:1475/1670 train_time:130910ms step_avg:88.75ms +step:1476/1670 train_time:130999ms step_avg:88.75ms +step:1477/1670 train_time:131089ms step_avg:88.75ms +step:1478/1670 train_time:131178ms step_avg:88.75ms +step:1479/1670 train_time:131269ms step_avg:88.75ms +step:1480/1670 train_time:131358ms step_avg:88.76ms +step:1481/1670 train_time:131448ms step_avg:88.76ms +step:1482/1670 train_time:131537ms step_avg:88.76ms +step:1483/1670 train_time:131627ms step_avg:88.76ms +step:1484/1670 train_time:131717ms step_avg:88.76ms +step:1485/1670 train_time:131806ms step_avg:88.76ms +step:1486/1670 train_time:131895ms step_avg:88.76ms +step:1487/1670 train_time:131984ms step_avg:88.76ms +step:1488/1670 train_time:132074ms step_avg:88.76ms +step:1489/1670 train_time:132163ms step_avg:88.76ms +step:1490/1670 train_time:132253ms step_avg:88.76ms +step:1491/1670 train_time:132343ms step_avg:88.76ms +step:1492/1670 train_time:132433ms step_avg:88.76ms +step:1493/1670 train_time:132523ms step_avg:88.76ms +step:1494/1670 train_time:132613ms step_avg:88.76ms +step:1495/1670 train_time:132704ms step_avg:88.77ms +step:1496/1670 train_time:132793ms step_avg:88.77ms +step:1497/1670 train_time:132882ms step_avg:88.77ms +step:1498/1670 train_time:132972ms step_avg:88.77ms +step:1499/1670 train_time:133062ms step_avg:88.77ms +step:1500/1670 train_time:133152ms step_avg:88.77ms +step:1500/1670 val_loss:3.3110 train_time:133243ms step_avg:88.83ms +step:1501/1670 train_time:133264ms step_avg:88.78ms +step:1502/1670 train_time:133340ms step_avg:88.77ms +step:1503/1670 train_time:133433ms 
step_avg:88.78ms +step:1504/1670 train_time:133523ms step_avg:88.78ms +step:1505/1670 train_time:133612ms step_avg:88.78ms +step:1506/1670 train_time:133699ms step_avg:88.78ms +step:1507/1670 train_time:133788ms step_avg:88.78ms +step:1508/1670 train_time:133876ms step_avg:88.78ms +step:1509/1670 train_time:133965ms step_avg:88.78ms +step:1510/1670 train_time:134055ms step_avg:88.78ms +step:1511/1670 train_time:134144ms step_avg:88.78ms +step:1512/1670 train_time:134235ms step_avg:88.78ms +step:1513/1670 train_time:134328ms step_avg:88.78ms +step:1514/1670 train_time:134420ms step_avg:88.78ms +step:1515/1670 train_time:134511ms step_avg:88.79ms +step:1516/1670 train_time:134600ms step_avg:88.79ms +step:1517/1670 train_time:134690ms step_avg:88.79ms +step:1518/1670 train_time:134779ms step_avg:88.79ms +step:1519/1670 train_time:134868ms step_avg:88.79ms +step:1520/1670 train_time:134957ms step_avg:88.79ms +step:1521/1670 train_time:135046ms step_avg:88.79ms +step:1522/1670 train_time:135135ms step_avg:88.79ms +step:1523/1670 train_time:135226ms step_avg:88.79ms +step:1524/1670 train_time:135318ms step_avg:88.79ms +step:1525/1670 train_time:135409ms step_avg:88.79ms +step:1526/1670 train_time:135498ms step_avg:88.79ms +step:1527/1670 train_time:135588ms step_avg:88.79ms +step:1528/1670 train_time:135678ms step_avg:88.79ms +step:1529/1670 train_time:135767ms step_avg:88.79ms +step:1530/1670 train_time:135856ms step_avg:88.80ms +step:1531/1670 train_time:135945ms step_avg:88.80ms +step:1532/1670 train_time:136034ms step_avg:88.80ms +step:1533/1670 train_time:136123ms step_avg:88.80ms +step:1534/1670 train_time:136213ms step_avg:88.80ms +step:1535/1670 train_time:136303ms step_avg:88.80ms +step:1536/1670 train_time:136394ms step_avg:88.80ms +step:1537/1670 train_time:136485ms step_avg:88.80ms +step:1538/1670 train_time:136575ms step_avg:88.80ms +step:1539/1670 train_time:136665ms step_avg:88.80ms +step:1540/1670 train_time:136754ms step_avg:88.80ms +step:1541/1670 train_time:136843ms step_avg:88.80ms +step:1542/1670 train_time:136932ms step_avg:88.80ms +step:1543/1670 train_time:137020ms step_avg:88.80ms +step:1544/1670 train_time:137109ms step_avg:88.80ms +step:1545/1670 train_time:137199ms step_avg:88.80ms +step:1546/1670 train_time:137289ms step_avg:88.80ms +step:1547/1670 train_time:137379ms step_avg:88.80ms +step:1548/1670 train_time:137469ms step_avg:88.80ms +step:1549/1670 train_time:137559ms step_avg:88.81ms +step:1550/1670 train_time:137649ms step_avg:88.81ms +step:1551/1670 train_time:137739ms step_avg:88.81ms +step:1552/1670 train_time:137828ms step_avg:88.81ms +step:1553/1670 train_time:137917ms step_avg:88.81ms +step:1554/1670 train_time:138006ms step_avg:88.81ms +step:1555/1670 train_time:138096ms step_avg:88.81ms +step:1556/1670 train_time:138187ms step_avg:88.81ms +step:1557/1670 train_time:138277ms step_avg:88.81ms +step:1558/1670 train_time:138366ms step_avg:88.81ms +step:1559/1670 train_time:138456ms step_avg:88.81ms +step:1560/1670 train_time:138547ms step_avg:88.81ms +step:1561/1670 train_time:138637ms step_avg:88.81ms +step:1562/1670 train_time:138727ms step_avg:88.81ms +step:1563/1670 train_time:138816ms step_avg:88.81ms +step:1564/1670 train_time:138905ms step_avg:88.81ms +step:1565/1670 train_time:138995ms step_avg:88.81ms +step:1566/1670 train_time:139086ms step_avg:88.82ms +step:1567/1670 train_time:139176ms step_avg:88.82ms +step:1568/1670 train_time:139267ms step_avg:88.82ms +step:1569/1670 train_time:139359ms step_avg:88.82ms +step:1570/1670 train_time:139450ms 
step_avg:88.82ms +step:1571/1670 train_time:139540ms step_avg:88.82ms +step:1572/1670 train_time:139630ms step_avg:88.82ms +step:1573/1670 train_time:139719ms step_avg:88.82ms +step:1574/1670 train_time:139809ms step_avg:88.82ms +step:1575/1670 train_time:139899ms step_avg:88.82ms +step:1576/1670 train_time:139988ms step_avg:88.83ms +step:1577/1670 train_time:140078ms step_avg:88.83ms +step:1578/1670 train_time:140168ms step_avg:88.83ms +step:1579/1670 train_time:140258ms step_avg:88.83ms +step:1580/1670 train_time:140349ms step_avg:88.83ms +step:1581/1670 train_time:140438ms step_avg:88.83ms +step:1582/1670 train_time:140528ms step_avg:88.83ms +step:1583/1670 train_time:140618ms step_avg:88.83ms +step:1584/1670 train_time:140707ms step_avg:88.83ms +step:1585/1670 train_time:140797ms step_avg:88.83ms +step:1586/1670 train_time:140887ms step_avg:88.83ms +step:1587/1670 train_time:140978ms step_avg:88.83ms +step:1588/1670 train_time:141069ms step_avg:88.83ms +step:1589/1670 train_time:141158ms step_avg:88.83ms +step:1590/1670 train_time:141248ms step_avg:88.84ms +step:1591/1670 train_time:141337ms step_avg:88.84ms +step:1592/1670 train_time:141427ms step_avg:88.84ms +step:1593/1670 train_time:141516ms step_avg:88.84ms +step:1594/1670 train_time:141607ms step_avg:88.84ms +step:1595/1670 train_time:141697ms step_avg:88.84ms +step:1596/1670 train_time:141786ms step_avg:88.84ms +step:1597/1670 train_time:141875ms step_avg:88.84ms +step:1598/1670 train_time:141965ms step_avg:88.84ms +step:1599/1670 train_time:142055ms step_avg:88.84ms +step:1600/1670 train_time:142145ms step_avg:88.84ms +step:1601/1670 train_time:142236ms step_avg:88.84ms +step:1602/1670 train_time:142326ms step_avg:88.84ms +step:1603/1670 train_time:142415ms step_avg:88.84ms +step:1604/1670 train_time:142505ms step_avg:88.84ms +step:1605/1670 train_time:142596ms step_avg:88.84ms +step:1606/1670 train_time:142686ms step_avg:88.85ms +step:1607/1670 train_time:142775ms step_avg:88.85ms +step:1608/1670 train_time:142866ms step_avg:88.85ms +step:1609/1670 train_time:142956ms step_avg:88.85ms +step:1610/1670 train_time:143047ms step_avg:88.85ms +step:1611/1670 train_time:143136ms step_avg:88.85ms +step:1612/1670 train_time:143226ms step_avg:88.85ms +step:1613/1670 train_time:143316ms step_avg:88.85ms +step:1614/1670 train_time:143406ms step_avg:88.85ms +step:1615/1670 train_time:143496ms step_avg:88.85ms +step:1616/1670 train_time:143586ms step_avg:88.85ms +step:1617/1670 train_time:143677ms step_avg:88.85ms +step:1618/1670 train_time:143767ms step_avg:88.85ms +step:1619/1670 train_time:143858ms step_avg:88.86ms +step:1620/1670 train_time:143948ms step_avg:88.86ms +step:1621/1670 train_time:144038ms step_avg:88.86ms +step:1622/1670 train_time:144127ms step_avg:88.86ms +step:1623/1670 train_time:144217ms step_avg:88.86ms +step:1624/1670 train_time:144306ms step_avg:88.86ms +step:1625/1670 train_time:144397ms step_avg:88.86ms +step:1625/1670 val_loss:3.2879 train_time:144489ms step_avg:88.92ms +step:1626/1670 train_time:144508ms step_avg:88.87ms +step:1627/1670 train_time:144584ms step_avg:88.87ms +step:1628/1670 train_time:144679ms step_avg:88.87ms +step:1629/1670 train_time:144771ms step_avg:88.87ms +step:1630/1670 train_time:144860ms step_avg:88.87ms +step:1631/1670 train_time:144949ms step_avg:88.87ms +step:1632/1670 train_time:145037ms step_avg:88.87ms +step:1633/1670 train_time:145126ms step_avg:88.87ms +step:1634/1670 train_time:145214ms step_avg:88.87ms +step:1635/1670 train_time:145302ms step_avg:88.87ms +step:1636/1670 
train_time:145391ms step_avg:88.87ms +step:1637/1670 train_time:145482ms step_avg:88.87ms +step:1638/1670 train_time:145574ms step_avg:88.87ms +step:1639/1670 train_time:145665ms step_avg:88.87ms +step:1640/1670 train_time:145757ms step_avg:88.88ms +step:1641/1670 train_time:145847ms step_avg:88.88ms +step:1642/1670 train_time:145937ms step_avg:88.88ms +step:1643/1670 train_time:146026ms step_avg:88.88ms +step:1644/1670 train_time:146116ms step_avg:88.88ms +step:1645/1670 train_time:146204ms step_avg:88.88ms +step:1646/1670 train_time:146293ms step_avg:88.88ms +step:1647/1670 train_time:146381ms step_avg:88.88ms +step:1648/1670 train_time:146472ms step_avg:88.88ms +step:1649/1670 train_time:146562ms step_avg:88.88ms +step:1650/1670 train_time:146654ms step_avg:88.88ms +step:1651/1670 train_time:146744ms step_avg:88.88ms +step:1652/1670 train_time:146835ms step_avg:88.88ms +step:1653/1670 train_time:146924ms step_avg:88.88ms +step:1654/1670 train_time:147014ms step_avg:88.88ms +step:1655/1670 train_time:147103ms step_avg:88.88ms +step:1656/1670 train_time:147193ms step_avg:88.88ms +step:1657/1670 train_time:147281ms step_avg:88.88ms +step:1658/1670 train_time:147370ms step_avg:88.88ms +step:1659/1670 train_time:147461ms step_avg:88.89ms +step:1660/1670 train_time:147550ms step_avg:88.89ms +step:1661/1670 train_time:147641ms step_avg:88.89ms +step:1662/1670 train_time:147731ms step_avg:88.89ms +step:1663/1670 train_time:147821ms step_avg:88.89ms +step:1664/1670 train_time:147911ms step_avg:88.89ms +step:1665/1670 train_time:148001ms step_avg:88.89ms +step:1666/1670 train_time:148091ms step_avg:88.89ms +step:1667/1670 train_time:148180ms step_avg:88.89ms +step:1668/1670 train_time:148270ms step_avg:88.89ms +step:1669/1670 train_time:148360ms step_avg:88.89ms +step:1670/1670 train_time:148450ms step_avg:88.89ms +step:1670/1670 val_loss:3.2787 train_time:148542ms step_avg:88.95ms +peak memory allocated: 30760 MiB reserved: 45514 MiB diff --git a/records/092925_PolarExpress/f62629d6-8c01-4154-96d1-85945920514a.txt b/records/092925_PolarExpress/f62629d6-8c01-4154-96d1-85945920514a.txt new file mode 100644 index 000000000..83713771b --- /dev/null +++ b/records/092925_PolarExpress/f62629d6-8c01-4154-96d1-85945920514a.txt @@ -0,0 +1,3252 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from kernels import get_kernel +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 
= x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[ + Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + + +mm_op.register_autograd(backward, setup_context=setup_context) + + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + + +@triton.autotune( + 
configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def XXT_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def XXT(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + XXT_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ba_plus_cAA_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: 
tl.constexpr, +): + # This is mostly duplicated from XXT_kernel, but also loads and adds a block of A + # Performance is slightly slower than XXT_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ba_plus_cAA(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ba_plus_cAA_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +# Computed for num_iters=5, safety_factor=2e-2, cushion=2 +coeffs_list = [ + (8.156554524902461, 
-22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323) +] + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def polar_express(G: torch.Tensor): + """ + Polar Express Sign Method: https://arxiv.org/pdf/2505.16932 + by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. + Code adapted from https://github.com/NoahAmsel/PolarExpress/tree/main by @varunneal. + """ + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) * (1 + 2e-2) + 1e-6) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + aX_plus_BX = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the iterations + for a, b, c in coeffs_list: + XXT(X, out=A) # A = X @ X.mT + ba_plus_cAA(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + aX_plus_BX(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + Though empirically small 1D params perform efficiently here: + NS approximately performs a magnitude normalization of the grad + This hyper-optimized class has faster execution time than the current impl of Adam for small params + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized to enable + (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param 7 padding params) + 2. reduce scatter attn_gate (10 params 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive params of type attn reshape before NS + 8. wait on 4, then compute NS of 4 and schedule all gather + 9. 
wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size() == 8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on less than 8 GPU or experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1, 10, 16, 16] + assert len(params) == sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx + size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. 
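+        # For example, with world_size == 8 the attn_gate group holds 10 params padded to 16,
+        # so chunk_size == 2 and rank r owns param slots [2r, 2r + 1]; ranks 5-7 fall past the
+        # 10 real params and only fill padding slots (handled by the padding branch below).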
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *p_example.shape),
+                dtype=p_example.dtype,
+                device=p_example.device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor as the zeropower input for the padded slot.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(p_example.grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i]  # This gradient corresponds to the current parameter p.
+                state = self.state[p]
+
+                # Initialize momentum buffer if not present
+                if not state:
+                    state["momentum_buffer"] = torch.zeros_like(grad)
+
+                momentum_buffer = state["momentum_buffer"]
+
+                # Apply momentum update directly to the persistent momentum buffer in-place.
+                momentum_buffer.lerp_(grad, 1 - group["momentum"])
+
+                # Compute the actual `update_grad` for zeropower. This creates a new tensor.
+                update_grad = grad.lerp(momentum_buffer, group["momentum"])
+                update_grads_for_zeropower.append(update_grad)
+
+                # Copy the current parameter value into the temporary buffer.
+                updated_param_chunk[i].copy_(p)
+
+                # Apply weight decay directly to the buffer.
+                updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+            original_shape = batched_update_grads.shape
+            # Reshape attn params from [hdim, dim*4] to [4, hdim, dim] to apply NS independently to Q,K,V,O
+            module_idx = start_idx if start_idx < len(params) else 0
+            if getattr(params[module_idx], 'module', 'none') == 'attn':
+                for p in params[module_idx:module_idx + chunk_size]:
+                    assert getattr(p, 'module', 'none') == 'attn'
+                batch = 4 * original_shape[0]
+                d1 = original_shape[1]
+                d2 = original_shape[2] // 4
+                batched = batched_update_grads.view(batch, d1, d2)
+                v_chunk = polar_express(batched)
+                v_chunk = v_chunk.view(original_shape)
+            else:
+                v_chunk = polar_express(batched_update_grads)
+
+            # Add the computed zeropower update to the parameters in the buffer.
+            # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer.
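+            # Net per-parameter update applied below, combining the weight decay and momentum steps above:
+            #   p <- (1 - lr * wd * wd_mul) * p
+            #        - lr * max(1, p.size(-2) / p.size(-1)) ** 0.5 * lr_mul * polar_express(grad.lerp(momentum_buffer, momentum))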
+ for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, + weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append( + dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * 
(torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim // 4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim // 4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +flash_attn_interface = get_kernel('varunneal/flash-attention-3').flash_attn_interface + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: 
int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim * 4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module = 'attn' + with torch.no_grad(): + self.qkvo_w.view(4, self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4, self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4, self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, + 3 * self.num_heads, + self.head_dim).chunk( + 3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_interface.flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, + max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4, self.hdim, self.dim)[3].type_as(y)) + return y + + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module = 'mlp' + self.c_proj.module = 'mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu( + x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, + max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
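+        # e.g. next_multiple_of_n(50257, n=128) == 50304 (= 393 * 128), so the token/value embeddings
+        # above and the lm_head below all use vocab_size == 50304.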
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim ** 0.5) / 448, w_s=2 ** -9, + grad_s=1 / 448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), # extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, + long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1, len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i < 11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert 
header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + + +BOS_ID = 50256 + + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens = tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter == 5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter += 1 + return starts, ends + + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, + align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * 
grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files)  # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0  # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128)  # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1  # last document was too long to account for _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens):  # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local, )
+            _targets = buf[1:].view(num_tokens_local, )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin"  # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin"  # input .bin to eval validation loss on
+    val_tokens: int = 10485760  # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1630  # number of iterations to run
+    iteration_extension: int = 40  # number of iterations to continue training at final cooldown and window size
+    cooldown_frac: float = 0.5  # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = 125  # every how many steps to evaluate val loss?
0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) + + +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + + +# begin by printing this file (the Python code) +print0(code) +print0("=" * 100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout + + +print0(nvidia_smi()) +print0("=" * 100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if + p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.06, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999, step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + + +def get_ws(step: int): + if step == args.num_iterations + args.iteration_extension: + return args.ws_validate // 2, args.ws_validate + x = min(step / (1 + args.num_iterations), 0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx] + + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, + grad_accum_steps=grad_accum_steps) +ws_long = args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws_long = args.ws_schedule[ + step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params + if new_ws_long > ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + elif new_ws_long < ws_long: + model.yarn.reset() + ws_long = new_ws_long + model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.yarn.reset() +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, + grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations + args.iteration_extension +ws_short, ws_long = get_ws(0) +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws_short, new_ws_long = get_ws(step) + if new_ws_long != ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + if last_step: + ws_long = args.ws_long_validate + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // 
args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, + grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = torch.zeros((), device=device, dtype=torch.float32) + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0( + f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms", + console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), + optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0( + f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms", + console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, Feb 4 2025, 14:57:36) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.5.0 +Mon Sep 29 06:18:24 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.127.08 Driver Version: 550.127.08 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 36C P0 126W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 32C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 31C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 34C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 35C P0 125W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 31C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 34C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 30C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:141ms step_avg:141.22ms +step:2/1670 train_time:162ms step_avg:80.79ms +step:3/1670 train_time:225ms step_avg:74.92ms +step:4/1670 train_time:311ms step_avg:77.67ms +step:5/1670 train_time:397ms step_avg:79.47ms +step:6/1670 train_time:484ms step_avg:80.63ms +step:7/1670 train_time:570ms step_avg:81.45ms +step:8/1670 train_time:658ms step_avg:82.20ms +step:9/1670 train_time:744ms step_avg:82.71ms +step:10/1670 train_time:831ms step_avg:83.13ms +step:11/1670 train_time:919ms step_avg:83.51ms +step:12/1670 train_time:1007ms step_avg:83.89ms +step:13/1670 train_time:1099ms step_avg:84.54ms +step:14/1670 train_time:1189ms step_avg:84.94ms +step:15/1670 train_time:1277ms step_avg:85.14ms +step:16/1670 train_time:1364ms step_avg:85.27ms +step:17/1670 train_time:1452ms step_avg:85.42ms +step:18/1670 train_time:1539ms step_avg:85.50ms +step:19/1670 train_time:1626ms step_avg:85.60ms +step:20/1670 train_time:1714ms step_avg:85.70ms +step:21/1670 train_time:1801ms step_avg:85.78ms +step:22/1670 train_time:1888ms step_avg:85.82ms +step:23/1670 
train_time:1976ms step_avg:85.93ms +step:24/1670 train_time:2066ms step_avg:86.09ms +step:25/1670 train_time:2155ms step_avg:86.19ms +step:26/1670 train_time:2243ms step_avg:86.27ms +step:27/1670 train_time:2332ms step_avg:86.38ms +step:28/1670 train_time:2421ms step_avg:86.45ms +step:29/1670 train_time:2508ms step_avg:86.47ms +step:30/1670 train_time:2595ms step_avg:86.51ms +step:31/1670 train_time:2682ms step_avg:86.53ms +step:32/1670 train_time:2770ms step_avg:86.55ms +step:33/1670 train_time:2857ms step_avg:86.57ms +step:34/1670 train_time:2944ms step_avg:86.60ms +step:35/1670 train_time:3032ms step_avg:86.64ms +step:36/1670 train_time:3121ms step_avg:86.69ms +step:37/1670 train_time:3208ms step_avg:86.72ms +step:38/1670 train_time:3297ms step_avg:86.77ms +step:39/1670 train_time:3386ms step_avg:86.81ms +step:40/1670 train_time:3473ms step_avg:86.82ms +step:41/1670 train_time:3561ms step_avg:86.86ms +step:42/1670 train_time:3648ms step_avg:86.86ms +step:43/1670 train_time:3736ms step_avg:86.88ms +step:44/1670 train_time:3823ms step_avg:86.89ms +step:45/1670 train_time:3911ms step_avg:86.91ms +step:46/1670 train_time:3999ms step_avg:86.94ms +step:47/1670 train_time:4087ms step_avg:86.95ms +step:48/1670 train_time:4175ms step_avg:86.98ms +step:49/1670 train_time:4264ms step_avg:87.02ms +step:50/1670 train_time:4352ms step_avg:87.04ms +step:51/1670 train_time:4440ms step_avg:87.05ms +step:52/1670 train_time:4528ms step_avg:87.07ms +step:53/1670 train_time:4615ms step_avg:87.07ms +step:54/1670 train_time:4702ms step_avg:87.08ms +step:55/1670 train_time:4789ms step_avg:87.08ms +step:56/1670 train_time:4877ms step_avg:87.09ms +step:57/1670 train_time:4964ms step_avg:87.09ms +step:58/1670 train_time:5052ms step_avg:87.10ms +step:59/1670 train_time:5140ms step_avg:87.12ms +step:60/1670 train_time:5228ms step_avg:87.14ms +step:61/1670 train_time:5317ms step_avg:87.16ms +step:62/1670 train_time:5404ms step_avg:87.17ms +step:63/1670 train_time:5492ms step_avg:87.18ms +step:64/1670 train_time:5580ms step_avg:87.19ms +step:65/1670 train_time:5667ms step_avg:87.19ms +step:66/1670 train_time:5755ms step_avg:87.19ms +step:67/1670 train_time:5842ms step_avg:87.20ms +step:68/1670 train_time:5930ms step_avg:87.20ms +step:69/1670 train_time:6018ms step_avg:87.21ms +step:70/1670 train_time:6105ms step_avg:87.22ms +step:71/1670 train_time:6192ms step_avg:87.22ms +step:72/1670 train_time:6281ms step_avg:87.24ms +step:73/1670 train_time:6369ms step_avg:87.24ms +step:74/1670 train_time:6457ms step_avg:87.26ms +step:75/1670 train_time:6544ms step_avg:87.25ms +step:76/1670 train_time:6632ms step_avg:87.27ms +step:77/1670 train_time:6720ms step_avg:87.27ms +step:78/1670 train_time:6807ms step_avg:87.27ms +step:79/1670 train_time:6895ms step_avg:87.28ms +step:80/1670 train_time:6982ms step_avg:87.28ms +step:81/1670 train_time:7070ms step_avg:87.28ms +step:82/1670 train_time:7157ms step_avg:87.29ms +step:83/1670 train_time:7245ms step_avg:87.29ms +step:84/1670 train_time:7333ms step_avg:87.30ms +step:85/1670 train_time:7422ms step_avg:87.32ms +step:86/1670 train_time:7510ms step_avg:87.32ms +step:87/1670 train_time:7598ms step_avg:87.33ms +step:88/1670 train_time:7685ms step_avg:87.33ms +step:89/1670 train_time:7772ms step_avg:87.33ms +step:90/1670 train_time:7859ms step_avg:87.33ms +step:91/1670 train_time:7946ms step_avg:87.32ms +step:92/1670 train_time:8034ms step_avg:87.33ms +step:93/1670 train_time:8123ms step_avg:87.34ms +step:94/1670 train_time:8210ms step_avg:87.34ms +step:95/1670 train_time:8297ms 
step_avg:87.34ms +step:96/1670 train_time:8385ms step_avg:87.34ms +step:97/1670 train_time:8473ms step_avg:87.35ms +step:98/1670 train_time:8561ms step_avg:87.36ms +step:99/1670 train_time:8649ms step_avg:87.37ms +step:100/1670 train_time:8737ms step_avg:87.37ms +step:101/1670 train_time:8825ms step_avg:87.37ms +step:102/1670 train_time:8912ms step_avg:87.37ms +step:103/1670 train_time:9000ms step_avg:87.38ms +step:104/1670 train_time:9087ms step_avg:87.38ms +step:105/1670 train_time:9175ms step_avg:87.38ms +step:106/1670 train_time:9263ms step_avg:87.39ms +step:107/1670 train_time:9351ms step_avg:87.39ms +step:108/1670 train_time:9439ms step_avg:87.40ms +step:109/1670 train_time:9527ms step_avg:87.40ms +step:110/1670 train_time:9615ms step_avg:87.41ms +step:111/1670 train_time:9703ms step_avg:87.41ms +step:112/1670 train_time:9790ms step_avg:87.41ms +step:113/1670 train_time:9877ms step_avg:87.41ms +step:114/1670 train_time:9965ms step_avg:87.41ms +step:115/1670 train_time:10053ms step_avg:87.42ms +step:116/1670 train_time:10141ms step_avg:87.42ms +step:117/1670 train_time:10228ms step_avg:87.42ms +step:118/1670 train_time:10316ms step_avg:87.42ms +step:119/1670 train_time:10403ms step_avg:87.42ms +step:120/1670 train_time:10491ms step_avg:87.43ms +step:121/1670 train_time:10579ms step_avg:87.43ms +step:122/1670 train_time:10666ms step_avg:87.43ms +step:123/1670 train_time:10754ms step_avg:87.43ms +step:124/1670 train_time:10842ms step_avg:87.44ms +step:125/1670 train_time:10930ms step_avg:87.44ms +step:125/1670 val_loss:4.3285 train_time:11020ms step_avg:88.16ms +step:126/1670 train_time:11039ms step_avg:87.61ms +step:127/1670 train_time:11109ms step_avg:87.47ms +step:128/1670 train_time:11203ms step_avg:87.52ms +step:129/1670 train_time:11294ms step_avg:87.55ms +step:130/1670 train_time:11382ms step_avg:87.55ms +step:131/1670 train_time:11469ms step_avg:87.55ms +step:132/1670 train_time:11555ms step_avg:87.54ms +step:133/1670 train_time:11641ms step_avg:87.53ms +step:134/1670 train_time:11728ms step_avg:87.52ms +step:135/1670 train_time:11814ms step_avg:87.51ms +step:136/1670 train_time:11900ms step_avg:87.50ms +step:137/1670 train_time:11987ms step_avg:87.49ms +step:138/1670 train_time:12075ms step_avg:87.50ms +step:139/1670 train_time:12165ms step_avg:87.52ms +step:140/1670 train_time:12255ms step_avg:87.53ms +step:141/1670 train_time:12343ms step_avg:87.54ms +step:142/1670 train_time:12431ms step_avg:87.54ms +step:143/1670 train_time:12518ms step_avg:87.54ms +step:144/1670 train_time:12605ms step_avg:87.53ms +step:145/1670 train_time:12692ms step_avg:87.53ms +step:146/1670 train_time:12778ms step_avg:87.52ms +step:147/1670 train_time:12865ms step_avg:87.52ms +step:148/1670 train_time:12952ms step_avg:87.51ms +step:149/1670 train_time:13039ms step_avg:87.51ms +step:150/1670 train_time:13128ms step_avg:87.52ms +step:151/1670 train_time:13217ms step_avg:87.53ms +step:152/1670 train_time:13306ms step_avg:87.54ms +step:153/1670 train_time:13393ms step_avg:87.54ms +step:154/1670 train_time:13482ms step_avg:87.55ms +step:155/1670 train_time:13569ms step_avg:87.54ms +step:156/1670 train_time:13656ms step_avg:87.54ms +step:157/1670 train_time:13742ms step_avg:87.53ms +step:158/1670 train_time:13828ms step_avg:87.52ms +step:159/1670 train_time:13915ms step_avg:87.52ms +step:160/1670 train_time:14003ms step_avg:87.52ms +step:161/1670 train_time:14091ms step_avg:87.52ms +step:162/1670 train_time:14180ms step_avg:87.53ms +step:163/1670 train_time:14269ms step_avg:87.54ms +step:164/1670 
train_time:14357ms step_avg:87.54ms +step:165/1670 train_time:14445ms step_avg:87.54ms +step:166/1670 train_time:14532ms step_avg:87.55ms +step:167/1670 train_time:14619ms step_avg:87.54ms +step:168/1670 train_time:14707ms step_avg:87.54ms +step:169/1670 train_time:14793ms step_avg:87.53ms +step:170/1670 train_time:14880ms step_avg:87.53ms +step:171/1670 train_time:14967ms step_avg:87.53ms +step:172/1670 train_time:15054ms step_avg:87.52ms +step:173/1670 train_time:15142ms step_avg:87.53ms +step:174/1670 train_time:15230ms step_avg:87.53ms +step:175/1670 train_time:15318ms step_avg:87.53ms +step:176/1670 train_time:15407ms step_avg:87.54ms +step:177/1670 train_time:15494ms step_avg:87.53ms +step:178/1670 train_time:15582ms step_avg:87.54ms +step:179/1670 train_time:15669ms step_avg:87.54ms +step:180/1670 train_time:15755ms step_avg:87.53ms +step:181/1670 train_time:15842ms step_avg:87.53ms +step:182/1670 train_time:15930ms step_avg:87.53ms +step:183/1670 train_time:16017ms step_avg:87.52ms +step:184/1670 train_time:16105ms step_avg:87.53ms +step:185/1670 train_time:16192ms step_avg:87.53ms +step:186/1670 train_time:16280ms step_avg:87.53ms +step:187/1670 train_time:16368ms step_avg:87.53ms +step:188/1670 train_time:16455ms step_avg:87.53ms +step:189/1670 train_time:16543ms step_avg:87.53ms +step:190/1670 train_time:16630ms step_avg:87.53ms +step:191/1670 train_time:16717ms step_avg:87.52ms +step:192/1670 train_time:16804ms step_avg:87.52ms +step:193/1670 train_time:16891ms step_avg:87.52ms +step:194/1670 train_time:16978ms step_avg:87.52ms +step:195/1670 train_time:17066ms step_avg:87.52ms +step:196/1670 train_time:17153ms step_avg:87.52ms +step:197/1670 train_time:17242ms step_avg:87.52ms +step:198/1670 train_time:17329ms step_avg:87.52ms +step:199/1670 train_time:17417ms step_avg:87.52ms +step:200/1670 train_time:17504ms step_avg:87.52ms +step:201/1670 train_time:17591ms step_avg:87.52ms +step:202/1670 train_time:17679ms step_avg:87.52ms +step:203/1670 train_time:17767ms step_avg:87.52ms +step:204/1670 train_time:17854ms step_avg:87.52ms +step:205/1670 train_time:17941ms step_avg:87.52ms +step:206/1670 train_time:18028ms step_avg:87.51ms +step:207/1670 train_time:18115ms step_avg:87.51ms +step:208/1670 train_time:18202ms step_avg:87.51ms +step:209/1670 train_time:18290ms step_avg:87.51ms +step:210/1670 train_time:18377ms step_avg:87.51ms +step:211/1670 train_time:18465ms step_avg:87.51ms +step:212/1670 train_time:18552ms step_avg:87.51ms +step:213/1670 train_time:18640ms step_avg:87.51ms +step:214/1670 train_time:18727ms step_avg:87.51ms +step:215/1670 train_time:18814ms step_avg:87.51ms +step:216/1670 train_time:18902ms step_avg:87.51ms +step:217/1670 train_time:18989ms step_avg:87.51ms +step:218/1670 train_time:19076ms step_avg:87.51ms +step:219/1670 train_time:19164ms step_avg:87.51ms +step:220/1670 train_time:19251ms step_avg:87.51ms +step:221/1670 train_time:19339ms step_avg:87.51ms +step:222/1670 train_time:19427ms step_avg:87.51ms +step:223/1670 train_time:19514ms step_avg:87.51ms +step:224/1670 train_time:19601ms step_avg:87.51ms +step:225/1670 train_time:19690ms step_avg:87.51ms +step:226/1670 train_time:19777ms step_avg:87.51ms +step:227/1670 train_time:19864ms step_avg:87.51ms +step:228/1670 train_time:19951ms step_avg:87.50ms +step:229/1670 train_time:20038ms step_avg:87.50ms +step:230/1670 train_time:20125ms step_avg:87.50ms +step:231/1670 train_time:20213ms step_avg:87.50ms +step:232/1670 train_time:20300ms step_avg:87.50ms +step:233/1670 train_time:20389ms step_avg:87.50ms 
+step:234/1670 train_time:20476ms step_avg:87.50ms +step:235/1670 train_time:20564ms step_avg:87.50ms +step:236/1670 train_time:20651ms step_avg:87.51ms +step:237/1670 train_time:20739ms step_avg:87.51ms +step:238/1670 train_time:20827ms step_avg:87.51ms +step:239/1670 train_time:20914ms step_avg:87.51ms +step:240/1670 train_time:21002ms step_avg:87.51ms +step:241/1670 train_time:21089ms step_avg:87.51ms +step:242/1670 train_time:21177ms step_avg:87.51ms +step:243/1670 train_time:21265ms step_avg:87.51ms +step:244/1670 train_time:21352ms step_avg:87.51ms +step:245/1670 train_time:21439ms step_avg:87.51ms +step:246/1670 train_time:21526ms step_avg:87.50ms +step:247/1670 train_time:21614ms step_avg:87.51ms +step:248/1670 train_time:21701ms step_avg:87.51ms +step:249/1670 train_time:21789ms step_avg:87.51ms +step:250/1670 train_time:21877ms step_avg:87.51ms +step:250/1670 val_loss:3.9931 train_time:21966ms step_avg:87.86ms +step:251/1670 train_time:21986ms step_avg:87.59ms +step:252/1670 train_time:22056ms step_avg:87.52ms +step:253/1670 train_time:22148ms step_avg:87.54ms +step:254/1670 train_time:22238ms step_avg:87.55ms +step:255/1670 train_time:22325ms step_avg:87.55ms +step:256/1670 train_time:22411ms step_avg:87.54ms +step:257/1670 train_time:22497ms step_avg:87.54ms +step:258/1670 train_time:22585ms step_avg:87.54ms +step:259/1670 train_time:22672ms step_avg:87.54ms +step:260/1670 train_time:22759ms step_avg:87.53ms +step:261/1670 train_time:22845ms step_avg:87.53ms +step:262/1670 train_time:22932ms step_avg:87.53ms +step:263/1670 train_time:23021ms step_avg:87.53ms +step:264/1670 train_time:23110ms step_avg:87.54ms +step:265/1670 train_time:23200ms step_avg:87.55ms +step:266/1670 train_time:23288ms step_avg:87.55ms +step:267/1670 train_time:23375ms step_avg:87.55ms +step:268/1670 train_time:23462ms step_avg:87.54ms +step:269/1670 train_time:23548ms step_avg:87.54ms +step:270/1670 train_time:23635ms step_avg:87.54ms +step:271/1670 train_time:23723ms step_avg:87.54ms +step:272/1670 train_time:23809ms step_avg:87.53ms +step:273/1670 train_time:23895ms step_avg:87.53ms +step:274/1670 train_time:23983ms step_avg:87.53ms +step:275/1670 train_time:24071ms step_avg:87.53ms +step:276/1670 train_time:24160ms step_avg:87.54ms +step:277/1670 train_time:24248ms step_avg:87.54ms +step:278/1670 train_time:24335ms step_avg:87.54ms +step:279/1670 train_time:24423ms step_avg:87.54ms +step:280/1670 train_time:24510ms step_avg:87.54ms +step:281/1670 train_time:24598ms step_avg:87.54ms +step:282/1670 train_time:24684ms step_avg:87.53ms +step:283/1670 train_time:24771ms step_avg:87.53ms +step:284/1670 train_time:24859ms step_avg:87.53ms +step:285/1670 train_time:24946ms step_avg:87.53ms +step:286/1670 train_time:25035ms step_avg:87.53ms +step:287/1670 train_time:25124ms step_avg:87.54ms +step:288/1670 train_time:25211ms step_avg:87.54ms +step:289/1670 train_time:25298ms step_avg:87.54ms +step:290/1670 train_time:25386ms step_avg:87.54ms +step:291/1670 train_time:25473ms step_avg:87.54ms +step:292/1670 train_time:25560ms step_avg:87.53ms +step:293/1670 train_time:25647ms step_avg:87.53ms +step:294/1670 train_time:25734ms step_avg:87.53ms +step:295/1670 train_time:25821ms step_avg:87.53ms +step:296/1670 train_time:25908ms step_avg:87.53ms +step:297/1670 train_time:25996ms step_avg:87.53ms +step:298/1670 train_time:26084ms step_avg:87.53ms +step:299/1670 train_time:26171ms step_avg:87.53ms +step:300/1670 train_time:26260ms step_avg:87.53ms +step:301/1670 train_time:26347ms step_avg:87.53ms +step:302/1670 
train_time:26435ms step_avg:87.53ms +step:303/1670 train_time:26523ms step_avg:87.53ms +step:304/1670 train_time:26610ms step_avg:87.53ms +step:305/1670 train_time:26698ms step_avg:87.53ms +step:306/1670 train_time:26785ms step_avg:87.53ms +step:307/1670 train_time:26872ms step_avg:87.53ms +step:308/1670 train_time:26959ms step_avg:87.53ms +step:309/1670 train_time:27047ms step_avg:87.53ms +step:310/1670 train_time:27136ms step_avg:87.53ms +step:311/1670 train_time:27224ms step_avg:87.54ms +step:312/1670 train_time:27312ms step_avg:87.54ms +step:313/1670 train_time:27399ms step_avg:87.54ms +step:314/1670 train_time:27487ms step_avg:87.54ms +step:315/1670 train_time:27574ms step_avg:87.54ms +step:316/1670 train_time:27662ms step_avg:87.54ms +step:317/1670 train_time:27748ms step_avg:87.53ms +step:318/1670 train_time:27836ms step_avg:87.53ms +step:319/1670 train_time:27924ms step_avg:87.54ms +step:320/1670 train_time:28011ms step_avg:87.53ms +step:321/1670 train_time:28099ms step_avg:87.53ms +step:322/1670 train_time:28186ms step_avg:87.53ms +step:323/1670 train_time:28274ms step_avg:87.54ms +step:324/1670 train_time:28362ms step_avg:87.54ms +step:325/1670 train_time:28449ms step_avg:87.54ms +step:326/1670 train_time:28537ms step_avg:87.54ms +step:327/1670 train_time:28625ms step_avg:87.54ms +step:328/1670 train_time:28712ms step_avg:87.54ms +step:329/1670 train_time:28800ms step_avg:87.54ms +step:330/1670 train_time:28887ms step_avg:87.54ms +step:331/1670 train_time:28975ms step_avg:87.54ms +step:332/1670 train_time:29063ms step_avg:87.54ms +step:333/1670 train_time:29150ms step_avg:87.54ms +step:334/1670 train_time:29238ms step_avg:87.54ms +step:335/1670 train_time:29326ms step_avg:87.54ms +step:336/1670 train_time:29413ms step_avg:87.54ms +step:337/1670 train_time:29501ms step_avg:87.54ms +step:338/1670 train_time:29589ms step_avg:87.54ms +step:339/1670 train_time:29675ms step_avg:87.54ms +step:340/1670 train_time:29763ms step_avg:87.54ms +step:341/1670 train_time:29850ms step_avg:87.54ms +step:342/1670 train_time:29938ms step_avg:87.54ms +step:343/1670 train_time:30026ms step_avg:87.54ms +step:344/1670 train_time:30114ms step_avg:87.54ms +step:345/1670 train_time:30202ms step_avg:87.54ms +step:346/1670 train_time:30289ms step_avg:87.54ms +step:347/1670 train_time:30376ms step_avg:87.54ms +step:348/1670 train_time:30464ms step_avg:87.54ms +step:349/1670 train_time:30550ms step_avg:87.54ms +step:350/1670 train_time:30638ms step_avg:87.54ms +step:351/1670 train_time:30727ms step_avg:87.54ms +step:352/1670 train_time:30814ms step_avg:87.54ms +step:353/1670 train_time:30902ms step_avg:87.54ms +step:354/1670 train_time:30990ms step_avg:87.54ms +step:355/1670 train_time:31077ms step_avg:87.54ms +step:356/1670 train_time:31165ms step_avg:87.54ms +step:357/1670 train_time:31252ms step_avg:87.54ms +step:358/1670 train_time:31339ms step_avg:87.54ms +step:359/1670 train_time:31427ms step_avg:87.54ms +step:360/1670 train_time:31514ms step_avg:87.54ms +step:361/1670 train_time:31602ms step_avg:87.54ms +step:362/1670 train_time:31690ms step_avg:87.54ms +step:363/1670 train_time:31777ms step_avg:87.54ms +step:364/1670 train_time:31865ms step_avg:87.54ms +step:365/1670 train_time:31953ms step_avg:87.54ms +step:366/1670 train_time:32042ms step_avg:87.55ms +step:367/1670 train_time:32129ms step_avg:87.55ms +step:368/1670 train_time:32216ms step_avg:87.54ms +step:369/1670 train_time:32304ms step_avg:87.55ms +step:370/1670 train_time:32392ms step_avg:87.55ms +step:371/1670 train_time:32479ms step_avg:87.54ms 
+step:372/1670 train_time:32567ms step_avg:87.55ms +step:373/1670 train_time:32655ms step_avg:87.55ms +step:374/1670 train_time:32743ms step_avg:87.55ms +step:375/1670 train_time:32830ms step_avg:87.55ms +step:375/1670 val_loss:3.8258 train_time:32919ms step_avg:87.78ms +step:376/1670 train_time:32938ms step_avg:87.60ms +step:377/1670 train_time:33009ms step_avg:87.56ms +step:378/1670 train_time:33099ms step_avg:87.56ms +step:379/1670 train_time:33190ms step_avg:87.57ms +step:380/1670 train_time:33278ms step_avg:87.57ms +step:381/1670 train_time:33365ms step_avg:87.57ms +step:382/1670 train_time:33451ms step_avg:87.57ms +step:383/1670 train_time:33538ms step_avg:87.57ms +step:384/1670 train_time:33624ms step_avg:87.56ms +step:385/1670 train_time:33711ms step_avg:87.56ms +step:386/1670 train_time:33798ms step_avg:87.56ms +step:387/1670 train_time:33887ms step_avg:87.56ms +step:388/1670 train_time:33976ms step_avg:87.57ms +step:389/1670 train_time:34066ms step_avg:87.57ms +step:390/1670 train_time:34155ms step_avg:87.58ms +step:391/1670 train_time:34244ms step_avg:87.58ms +step:392/1670 train_time:34331ms step_avg:87.58ms +step:393/1670 train_time:34418ms step_avg:87.58ms +step:394/1670 train_time:34506ms step_avg:87.58ms +step:395/1670 train_time:34592ms step_avg:87.58ms +step:396/1670 train_time:34680ms step_avg:87.57ms +step:397/1670 train_time:34768ms step_avg:87.58ms +step:398/1670 train_time:34854ms step_avg:87.57ms +step:399/1670 train_time:34943ms step_avg:87.58ms +step:400/1670 train_time:35031ms step_avg:87.58ms +step:401/1670 train_time:35119ms step_avg:87.58ms +step:402/1670 train_time:35208ms step_avg:87.58ms +step:403/1670 train_time:35296ms step_avg:87.58ms +step:404/1670 train_time:35384ms step_avg:87.58ms +step:405/1670 train_time:35471ms step_avg:87.58ms +step:406/1670 train_time:35558ms step_avg:87.58ms +step:407/1670 train_time:35645ms step_avg:87.58ms +step:408/1670 train_time:35732ms step_avg:87.58ms +step:409/1670 train_time:35820ms step_avg:87.58ms +step:410/1670 train_time:35907ms step_avg:87.58ms +step:411/1670 train_time:35994ms step_avg:87.58ms +step:412/1670 train_time:36083ms step_avg:87.58ms +step:413/1670 train_time:36171ms step_avg:87.58ms +step:414/1670 train_time:36259ms step_avg:87.58ms +step:415/1670 train_time:36347ms step_avg:87.58ms +step:416/1670 train_time:36434ms step_avg:87.58ms +step:417/1670 train_time:36521ms step_avg:87.58ms +step:418/1670 train_time:36609ms step_avg:87.58ms +step:419/1670 train_time:36696ms step_avg:87.58ms +step:420/1670 train_time:36784ms step_avg:87.58ms +step:421/1670 train_time:36872ms step_avg:87.58ms +step:422/1670 train_time:36960ms step_avg:87.58ms +step:423/1670 train_time:37048ms step_avg:87.58ms +step:424/1670 train_time:37136ms step_avg:87.59ms +step:425/1670 train_time:37225ms step_avg:87.59ms +step:426/1670 train_time:37312ms step_avg:87.59ms +step:427/1670 train_time:37401ms step_avg:87.59ms +step:428/1670 train_time:37487ms step_avg:87.59ms +step:429/1670 train_time:37574ms step_avg:87.59ms +step:430/1670 train_time:37662ms step_avg:87.59ms +step:431/1670 train_time:37749ms step_avg:87.58ms +step:432/1670 train_time:37836ms step_avg:87.58ms +step:433/1670 train_time:37923ms step_avg:87.58ms +step:434/1670 train_time:38012ms step_avg:87.58ms +step:435/1670 train_time:38100ms step_avg:87.59ms +step:436/1670 train_time:38188ms step_avg:87.59ms +step:437/1670 train_time:38276ms step_avg:87.59ms +step:438/1670 train_time:38364ms step_avg:87.59ms +step:439/1670 train_time:38451ms step_avg:87.59ms +step:440/1670 
train_time:38539ms step_avg:87.59ms +step:441/1670 train_time:38627ms step_avg:87.59ms +step:442/1670 train_time:38714ms step_avg:87.59ms +step:443/1670 train_time:38800ms step_avg:87.59ms +step:444/1670 train_time:38889ms step_avg:87.59ms +step:445/1670 train_time:38978ms step_avg:87.59ms +step:446/1670 train_time:39067ms step_avg:87.59ms +step:447/1670 train_time:39153ms step_avg:87.59ms +step:448/1670 train_time:39241ms step_avg:87.59ms +step:449/1670 train_time:39329ms step_avg:87.59ms +step:450/1670 train_time:39417ms step_avg:87.59ms +step:451/1670 train_time:39506ms step_avg:87.60ms +step:452/1670 train_time:39593ms step_avg:87.59ms +step:453/1670 train_time:39681ms step_avg:87.60ms +step:454/1670 train_time:39768ms step_avg:87.59ms +step:455/1670 train_time:39855ms step_avg:87.59ms +step:456/1670 train_time:39944ms step_avg:87.60ms +step:457/1670 train_time:40031ms step_avg:87.60ms +step:458/1670 train_time:40119ms step_avg:87.60ms +step:459/1670 train_time:40207ms step_avg:87.60ms +step:460/1670 train_time:40294ms step_avg:87.60ms +step:461/1670 train_time:40382ms step_avg:87.60ms +step:462/1670 train_time:40471ms step_avg:87.60ms +step:463/1670 train_time:40558ms step_avg:87.60ms +step:464/1670 train_time:40646ms step_avg:87.60ms +step:465/1670 train_time:40733ms step_avg:87.60ms +step:466/1670 train_time:40821ms step_avg:87.60ms +step:467/1670 train_time:40909ms step_avg:87.60ms +step:468/1670 train_time:40996ms step_avg:87.60ms +step:469/1670 train_time:41084ms step_avg:87.60ms +step:470/1670 train_time:41172ms step_avg:87.60ms +step:471/1670 train_time:41259ms step_avg:87.60ms +step:472/1670 train_time:41347ms step_avg:87.60ms +step:473/1670 train_time:41434ms step_avg:87.60ms +step:474/1670 train_time:41522ms step_avg:87.60ms +step:475/1670 train_time:41611ms step_avg:87.60ms +step:476/1670 train_time:41699ms step_avg:87.60ms +step:477/1670 train_time:41787ms step_avg:87.60ms +step:478/1670 train_time:41874ms step_avg:87.60ms +step:479/1670 train_time:41962ms step_avg:87.60ms +step:480/1670 train_time:42050ms step_avg:87.60ms +step:481/1670 train_time:42138ms step_avg:87.60ms +step:482/1670 train_time:42226ms step_avg:87.61ms +step:483/1670 train_time:42313ms step_avg:87.60ms +step:484/1670 train_time:42400ms step_avg:87.60ms +step:485/1670 train_time:42488ms step_avg:87.60ms +step:486/1670 train_time:42575ms step_avg:87.60ms +step:487/1670 train_time:42664ms step_avg:87.60ms +step:488/1670 train_time:42751ms step_avg:87.60ms +step:489/1670 train_time:42839ms step_avg:87.61ms +step:490/1670 train_time:42928ms step_avg:87.61ms +step:491/1670 train_time:43015ms step_avg:87.61ms +step:492/1670 train_time:43103ms step_avg:87.61ms +step:493/1670 train_time:43191ms step_avg:87.61ms +step:494/1670 train_time:43278ms step_avg:87.61ms +step:495/1670 train_time:43367ms step_avg:87.61ms +step:496/1670 train_time:43454ms step_avg:87.61ms +step:497/1670 train_time:43542ms step_avg:87.61ms +step:498/1670 train_time:43629ms step_avg:87.61ms +step:499/1670 train_time:43717ms step_avg:87.61ms +step:500/1670 train_time:43805ms step_avg:87.61ms +step:500/1670 val_loss:3.7195 train_time:43894ms step_avg:87.79ms +step:501/1670 train_time:43913ms step_avg:87.65ms +step:502/1670 train_time:43985ms step_avg:87.62ms +step:503/1670 train_time:44079ms step_avg:87.63ms +step:504/1670 train_time:44167ms step_avg:87.63ms +step:505/1670 train_time:44255ms step_avg:87.63ms +step:506/1670 train_time:44341ms step_avg:87.63ms +step:507/1670 train_time:44428ms step_avg:87.63ms +step:508/1670 train_time:44515ms 
step_avg:87.63ms +step:509/1670 train_time:44601ms step_avg:87.62ms +step:510/1670 train_time:44687ms step_avg:87.62ms +step:511/1670 train_time:44774ms step_avg:87.62ms +step:512/1670 train_time:44862ms step_avg:87.62ms +step:513/1670 train_time:44951ms step_avg:87.62ms +step:514/1670 train_time:45041ms step_avg:87.63ms +step:515/1670 train_time:45131ms step_avg:87.63ms +step:516/1670 train_time:45220ms step_avg:87.63ms +step:517/1670 train_time:45307ms step_avg:87.63ms +step:518/1670 train_time:45394ms step_avg:87.63ms +step:519/1670 train_time:45481ms step_avg:87.63ms +step:520/1670 train_time:45568ms step_avg:87.63ms +step:521/1670 train_time:45656ms step_avg:87.63ms +step:522/1670 train_time:45742ms step_avg:87.63ms +step:523/1670 train_time:45830ms step_avg:87.63ms +step:524/1670 train_time:45920ms step_avg:87.63ms +step:525/1670 train_time:46009ms step_avg:87.64ms +step:526/1670 train_time:46098ms step_avg:87.64ms +step:527/1670 train_time:46187ms step_avg:87.64ms +step:528/1670 train_time:46275ms step_avg:87.64ms +step:529/1670 train_time:46362ms step_avg:87.64ms +step:530/1670 train_time:46449ms step_avg:87.64ms +step:531/1670 train_time:46536ms step_avg:87.64ms +step:532/1670 train_time:46623ms step_avg:87.64ms +step:533/1670 train_time:46710ms step_avg:87.64ms +step:534/1670 train_time:46798ms step_avg:87.64ms +step:535/1670 train_time:46884ms step_avg:87.63ms +step:536/1670 train_time:46973ms step_avg:87.64ms +step:537/1670 train_time:47062ms step_avg:87.64ms +step:538/1670 train_time:47151ms step_avg:87.64ms +step:539/1670 train_time:47239ms step_avg:87.64ms +step:540/1670 train_time:47328ms step_avg:87.64ms +step:541/1670 train_time:47416ms step_avg:87.65ms +step:542/1670 train_time:47503ms step_avg:87.64ms +step:543/1670 train_time:47590ms step_avg:87.64ms +step:544/1670 train_time:47677ms step_avg:87.64ms +step:545/1670 train_time:47764ms step_avg:87.64ms +step:546/1670 train_time:47854ms step_avg:87.64ms +step:547/1670 train_time:47943ms step_avg:87.65ms +step:548/1670 train_time:48032ms step_avg:87.65ms +step:549/1670 train_time:48122ms step_avg:87.65ms +step:550/1670 train_time:48211ms step_avg:87.66ms +step:551/1670 train_time:48300ms step_avg:87.66ms +step:552/1670 train_time:48389ms step_avg:87.66ms +step:553/1670 train_time:48478ms step_avg:87.66ms +step:554/1670 train_time:48566ms step_avg:87.66ms +step:555/1670 train_time:48655ms step_avg:87.67ms +step:556/1670 train_time:48743ms step_avg:87.67ms +step:557/1670 train_time:48833ms step_avg:87.67ms +step:558/1670 train_time:48922ms step_avg:87.67ms +step:559/1670 train_time:49012ms step_avg:87.68ms +step:560/1670 train_time:49100ms step_avg:87.68ms +step:561/1670 train_time:49190ms step_avg:87.68ms +step:562/1670 train_time:49279ms step_avg:87.69ms +step:563/1670 train_time:49368ms step_avg:87.69ms +step:564/1670 train_time:49458ms step_avg:87.69ms +step:565/1670 train_time:49546ms step_avg:87.69ms +step:566/1670 train_time:49634ms step_avg:87.69ms +step:567/1670 train_time:49723ms step_avg:87.69ms +step:568/1670 train_time:49812ms step_avg:87.70ms +step:569/1670 train_time:49901ms step_avg:87.70ms +step:570/1670 train_time:49991ms step_avg:87.70ms +step:571/1670 train_time:50079ms step_avg:87.70ms +step:572/1670 train_time:50168ms step_avg:87.71ms +step:573/1670 train_time:50257ms step_avg:87.71ms +step:574/1670 train_time:50346ms step_avg:87.71ms +step:575/1670 train_time:50435ms step_avg:87.71ms +step:576/1670 train_time:50524ms step_avg:87.71ms +step:577/1670 train_time:50613ms step_avg:87.72ms +step:578/1670 
train_time:50701ms step_avg:87.72ms +step:579/1670 train_time:50790ms step_avg:87.72ms +step:580/1670 train_time:50879ms step_avg:87.72ms +step:581/1670 train_time:50968ms step_avg:87.73ms +step:582/1670 train_time:51058ms step_avg:87.73ms +step:583/1670 train_time:51147ms step_avg:87.73ms +step:584/1670 train_time:51237ms step_avg:87.73ms +step:585/1670 train_time:51326ms step_avg:87.74ms +step:586/1670 train_time:51415ms step_avg:87.74ms +step:587/1670 train_time:51504ms step_avg:87.74ms +step:588/1670 train_time:51593ms step_avg:87.74ms +step:589/1670 train_time:51681ms step_avg:87.74ms +step:590/1670 train_time:51770ms step_avg:87.75ms +step:591/1670 train_time:51859ms step_avg:87.75ms +step:592/1670 train_time:51949ms step_avg:87.75ms +step:593/1670 train_time:52039ms step_avg:87.76ms +step:594/1670 train_time:52128ms step_avg:87.76ms +step:595/1670 train_time:52217ms step_avg:87.76ms +step:596/1670 train_time:52306ms step_avg:87.76ms +step:597/1670 train_time:52395ms step_avg:87.76ms +step:598/1670 train_time:52483ms step_avg:87.76ms +step:599/1670 train_time:52573ms step_avg:87.77ms +step:600/1670 train_time:52661ms step_avg:87.77ms +step:601/1670 train_time:52750ms step_avg:87.77ms +step:602/1670 train_time:52838ms step_avg:87.77ms +step:603/1670 train_time:52927ms step_avg:87.77ms +step:604/1670 train_time:53017ms step_avg:87.78ms +step:605/1670 train_time:53105ms step_avg:87.78ms +step:606/1670 train_time:53194ms step_avg:87.78ms +step:607/1670 train_time:53283ms step_avg:87.78ms +step:608/1670 train_time:53372ms step_avg:87.78ms +step:609/1670 train_time:53461ms step_avg:87.79ms +step:610/1670 train_time:53551ms step_avg:87.79ms +step:611/1670 train_time:53639ms step_avg:87.79ms +step:612/1670 train_time:53727ms step_avg:87.79ms +step:613/1670 train_time:53818ms step_avg:87.80ms +step:614/1670 train_time:53907ms step_avg:87.80ms +step:615/1670 train_time:53997ms step_avg:87.80ms +step:616/1670 train_time:54085ms step_avg:87.80ms +step:617/1670 train_time:54174ms step_avg:87.80ms +step:618/1670 train_time:54263ms step_avg:87.80ms +step:619/1670 train_time:54352ms step_avg:87.81ms +step:620/1670 train_time:54441ms step_avg:87.81ms +step:621/1670 train_time:54531ms step_avg:87.81ms +step:622/1670 train_time:54619ms step_avg:87.81ms +step:623/1670 train_time:54707ms step_avg:87.81ms +step:624/1670 train_time:54797ms step_avg:87.82ms +step:625/1670 train_time:54886ms step_avg:87.82ms +step:625/1670 val_loss:3.6182 train_time:54977ms step_avg:87.96ms +step:626/1670 train_time:54997ms step_avg:87.85ms +step:627/1670 train_time:55067ms step_avg:87.83ms +step:628/1670 train_time:55156ms step_avg:87.83ms +step:629/1670 train_time:55248ms step_avg:87.83ms +step:630/1670 train_time:55335ms step_avg:87.83ms +step:631/1670 train_time:55423ms step_avg:87.83ms +step:632/1670 train_time:55510ms step_avg:87.83ms +step:633/1670 train_time:55597ms step_avg:87.83ms +step:634/1670 train_time:55686ms step_avg:87.83ms +step:635/1670 train_time:55776ms step_avg:87.84ms +step:636/1670 train_time:55866ms step_avg:87.84ms +step:637/1670 train_time:55956ms step_avg:87.84ms +step:638/1670 train_time:56048ms step_avg:87.85ms +step:639/1670 train_time:56136ms step_avg:87.85ms +step:640/1670 train_time:56225ms step_avg:87.85ms +step:641/1670 train_time:56314ms step_avg:87.85ms +step:642/1670 train_time:56402ms step_avg:87.85ms +step:643/1670 train_time:56490ms step_avg:87.85ms +step:644/1670 train_time:56578ms step_avg:87.85ms +step:645/1670 train_time:56667ms step_avg:87.86ms +step:646/1670 train_time:56755ms 
step_avg:87.86ms +step:647/1670 train_time:56845ms step_avg:87.86ms +step:648/1670 train_time:56935ms step_avg:87.86ms +step:649/1670 train_time:57025ms step_avg:87.87ms +step:650/1670 train_time:57114ms step_avg:87.87ms +step:651/1670 train_time:57203ms step_avg:87.87ms +step:652/1670 train_time:57292ms step_avg:87.87ms +step:653/1670 train_time:57381ms step_avg:87.87ms +step:654/1670 train_time:57469ms step_avg:87.87ms +step:655/1670 train_time:57557ms step_avg:87.87ms +step:656/1670 train_time:57646ms step_avg:87.88ms +step:657/1670 train_time:57735ms step_avg:87.88ms +step:658/1670 train_time:57824ms step_avg:87.88ms +step:659/1670 train_time:57913ms step_avg:87.88ms +step:660/1670 train_time:58003ms step_avg:87.88ms +step:661/1670 train_time:58092ms step_avg:87.89ms +step:662/1670 train_time:58182ms step_avg:87.89ms +step:663/1670 train_time:58271ms step_avg:87.89ms +step:664/1670 train_time:58359ms step_avg:87.89ms +step:665/1670 train_time:58448ms step_avg:87.89ms +step:666/1670 train_time:58536ms step_avg:87.89ms +step:667/1670 train_time:58624ms step_avg:87.89ms +step:668/1670 train_time:58713ms step_avg:87.89ms +step:669/1670 train_time:58801ms step_avg:87.89ms +step:670/1670 train_time:58891ms step_avg:87.90ms +step:671/1670 train_time:58980ms step_avg:87.90ms +step:672/1670 train_time:59070ms step_avg:87.90ms +step:673/1670 train_time:59159ms step_avg:87.90ms +step:674/1670 train_time:59248ms step_avg:87.90ms +step:675/1670 train_time:59336ms step_avg:87.90ms +step:676/1670 train_time:59424ms step_avg:87.91ms +step:677/1670 train_time:59513ms step_avg:87.91ms +step:678/1670 train_time:59601ms step_avg:87.91ms +step:679/1670 train_time:59690ms step_avg:87.91ms +step:680/1670 train_time:59780ms step_avg:87.91ms +step:681/1670 train_time:59869ms step_avg:87.91ms +step:682/1670 train_time:59959ms step_avg:87.92ms +step:683/1670 train_time:60050ms step_avg:87.92ms +step:684/1670 train_time:60139ms step_avg:87.92ms +step:685/1670 train_time:60228ms step_avg:87.92ms +step:686/1670 train_time:60317ms step_avg:87.93ms +step:687/1670 train_time:60406ms step_avg:87.93ms +step:688/1670 train_time:60494ms step_avg:87.93ms +step:689/1670 train_time:60582ms step_avg:87.93ms +step:690/1670 train_time:60672ms step_avg:87.93ms +step:691/1670 train_time:60761ms step_avg:87.93ms +step:692/1670 train_time:60851ms step_avg:87.93ms +step:693/1670 train_time:60940ms step_avg:87.94ms +step:694/1670 train_time:61029ms step_avg:87.94ms +step:695/1670 train_time:61118ms step_avg:87.94ms +step:696/1670 train_time:61208ms step_avg:87.94ms +step:697/1670 train_time:61296ms step_avg:87.94ms +step:698/1670 train_time:61386ms step_avg:87.95ms +step:699/1670 train_time:61474ms step_avg:87.95ms +step:700/1670 train_time:61563ms step_avg:87.95ms +step:701/1670 train_time:61652ms step_avg:87.95ms +step:702/1670 train_time:61740ms step_avg:87.95ms +step:703/1670 train_time:61829ms step_avg:87.95ms +step:704/1670 train_time:61918ms step_avg:87.95ms +step:705/1670 train_time:62007ms step_avg:87.95ms +step:706/1670 train_time:62095ms step_avg:87.95ms +step:707/1670 train_time:62184ms step_avg:87.95ms +step:708/1670 train_time:62273ms step_avg:87.96ms +step:709/1670 train_time:62362ms step_avg:87.96ms +step:710/1670 train_time:62452ms step_avg:87.96ms +step:711/1670 train_time:62541ms step_avg:87.96ms +step:712/1670 train_time:62629ms step_avg:87.96ms +step:713/1670 train_time:62718ms step_avg:87.96ms +step:714/1670 train_time:62807ms step_avg:87.97ms +step:715/1670 train_time:62896ms step_avg:87.97ms +step:716/1670 
train_time:62985ms step_avg:87.97ms +step:717/1670 train_time:63073ms step_avg:87.97ms +step:718/1670 train_time:63162ms step_avg:87.97ms +step:719/1670 train_time:63251ms step_avg:87.97ms +step:720/1670 train_time:63340ms step_avg:87.97ms +step:721/1670 train_time:63430ms step_avg:87.97ms +step:722/1670 train_time:63519ms step_avg:87.98ms +step:723/1670 train_time:63608ms step_avg:87.98ms +step:724/1670 train_time:63696ms step_avg:87.98ms +step:725/1670 train_time:63786ms step_avg:87.98ms +step:726/1670 train_time:63875ms step_avg:87.98ms +step:727/1670 train_time:63964ms step_avg:87.98ms +step:728/1670 train_time:64053ms step_avg:87.99ms +step:729/1670 train_time:64142ms step_avg:87.99ms +step:730/1670 train_time:64231ms step_avg:87.99ms +step:731/1670 train_time:64320ms step_avg:87.99ms +step:732/1670 train_time:64410ms step_avg:87.99ms +step:733/1670 train_time:64498ms step_avg:87.99ms +step:734/1670 train_time:64587ms step_avg:87.99ms +step:735/1670 train_time:64675ms step_avg:87.99ms +step:736/1670 train_time:64765ms step_avg:88.00ms +step:737/1670 train_time:64855ms step_avg:88.00ms +step:738/1670 train_time:64945ms step_avg:88.00ms +step:739/1670 train_time:65034ms step_avg:88.00ms +step:740/1670 train_time:65124ms step_avg:88.01ms +step:741/1670 train_time:65213ms step_avg:88.01ms +step:742/1670 train_time:65301ms step_avg:88.01ms +step:743/1670 train_time:65390ms step_avg:88.01ms +step:744/1670 train_time:65479ms step_avg:88.01ms +step:745/1670 train_time:65568ms step_avg:88.01ms +step:746/1670 train_time:65656ms step_avg:88.01ms +step:747/1670 train_time:65746ms step_avg:88.01ms +step:748/1670 train_time:65834ms step_avg:88.01ms +step:749/1670 train_time:65923ms step_avg:88.02ms +step:750/1670 train_time:66012ms step_avg:88.02ms +step:750/1670 val_loss:3.5684 train_time:66103ms step_avg:88.14ms +step:751/1670 train_time:66123ms step_avg:88.05ms +step:752/1670 train_time:66197ms step_avg:88.03ms +step:753/1670 train_time:66290ms step_avg:88.03ms +step:754/1670 train_time:66379ms step_avg:88.04ms +step:755/1670 train_time:66468ms step_avg:88.04ms +step:756/1670 train_time:66556ms step_avg:88.04ms +step:757/1670 train_time:66644ms step_avg:88.04ms +step:758/1670 train_time:66733ms step_avg:88.04ms +step:759/1670 train_time:66820ms step_avg:88.04ms +step:760/1670 train_time:66908ms step_avg:88.04ms +step:761/1670 train_time:66996ms step_avg:88.04ms +step:762/1670 train_time:67087ms step_avg:88.04ms +step:763/1670 train_time:67177ms step_avg:88.04ms +step:764/1670 train_time:67268ms step_avg:88.05ms +step:765/1670 train_time:67357ms step_avg:88.05ms +step:766/1670 train_time:67447ms step_avg:88.05ms +step:767/1670 train_time:67535ms step_avg:88.05ms +step:768/1670 train_time:67623ms step_avg:88.05ms +step:769/1670 train_time:67711ms step_avg:88.05ms +step:770/1670 train_time:67799ms step_avg:88.05ms +step:771/1670 train_time:67888ms step_avg:88.05ms +step:772/1670 train_time:67976ms step_avg:88.05ms +step:773/1670 train_time:68066ms step_avg:88.05ms +step:774/1670 train_time:68156ms step_avg:88.06ms +step:775/1670 train_time:68247ms step_avg:88.06ms +step:776/1670 train_time:68337ms step_avg:88.06ms +step:777/1670 train_time:68427ms step_avg:88.07ms +step:778/1670 train_time:68515ms step_avg:88.07ms +step:779/1670 train_time:68603ms step_avg:88.07ms +step:780/1670 train_time:68691ms step_avg:88.07ms +step:781/1670 train_time:68780ms step_avg:88.07ms +step:782/1670 train_time:68869ms step_avg:88.07ms +step:783/1670 train_time:68957ms step_avg:88.07ms +step:784/1670 train_time:69046ms 
step_avg:88.07ms +step:785/1670 train_time:69135ms step_avg:88.07ms +step:786/1670 train_time:69225ms step_avg:88.07ms +step:787/1670 train_time:69314ms step_avg:88.07ms +step:788/1670 train_time:69404ms step_avg:88.08ms +step:789/1670 train_time:69493ms step_avg:88.08ms +step:790/1670 train_time:69582ms step_avg:88.08ms +step:791/1670 train_time:69670ms step_avg:88.08ms +step:792/1670 train_time:69758ms step_avg:88.08ms +step:793/1670 train_time:69847ms step_avg:88.08ms +step:794/1670 train_time:69936ms step_avg:88.08ms +step:795/1670 train_time:70024ms step_avg:88.08ms +step:796/1670 train_time:70114ms step_avg:88.08ms +step:797/1670 train_time:70203ms step_avg:88.08ms +step:798/1670 train_time:70293ms step_avg:88.09ms +step:799/1670 train_time:70382ms step_avg:88.09ms +step:800/1670 train_time:70472ms step_avg:88.09ms +step:801/1670 train_time:70562ms step_avg:88.09ms +step:802/1670 train_time:70650ms step_avg:88.09ms +step:803/1670 train_time:70739ms step_avg:88.09ms +step:804/1670 train_time:70828ms step_avg:88.09ms +step:805/1670 train_time:70916ms step_avg:88.09ms +step:806/1670 train_time:71005ms step_avg:88.10ms +step:807/1670 train_time:71094ms step_avg:88.10ms +step:808/1670 train_time:71183ms step_avg:88.10ms +step:809/1670 train_time:71272ms step_avg:88.10ms +step:810/1670 train_time:71361ms step_avg:88.10ms +step:811/1670 train_time:71451ms step_avg:88.10ms +step:812/1670 train_time:71540ms step_avg:88.10ms +step:813/1670 train_time:71629ms step_avg:88.10ms +step:814/1670 train_time:71718ms step_avg:88.11ms +step:815/1670 train_time:71808ms step_avg:88.11ms +step:816/1670 train_time:71897ms step_avg:88.11ms +step:817/1670 train_time:71986ms step_avg:88.11ms +step:818/1670 train_time:72074ms step_avg:88.11ms +step:819/1670 train_time:72163ms step_avg:88.11ms +step:820/1670 train_time:72252ms step_avg:88.11ms +step:821/1670 train_time:72342ms step_avg:88.11ms +step:822/1670 train_time:72432ms step_avg:88.12ms +step:823/1670 train_time:72521ms step_avg:88.12ms +step:824/1670 train_time:72610ms step_avg:88.12ms +step:825/1670 train_time:72699ms step_avg:88.12ms +step:826/1670 train_time:72790ms step_avg:88.12ms +step:827/1670 train_time:72878ms step_avg:88.12ms +step:828/1670 train_time:72968ms step_avg:88.13ms +step:829/1670 train_time:73056ms step_avg:88.13ms +step:830/1670 train_time:73144ms step_avg:88.13ms +step:831/1670 train_time:73234ms step_avg:88.13ms +step:832/1670 train_time:73323ms step_avg:88.13ms +step:833/1670 train_time:73412ms step_avg:88.13ms +step:834/1670 train_time:73501ms step_avg:88.13ms +step:835/1670 train_time:73590ms step_avg:88.13ms +step:836/1670 train_time:73679ms step_avg:88.13ms +step:837/1670 train_time:73768ms step_avg:88.13ms +step:838/1670 train_time:73857ms step_avg:88.13ms +step:839/1670 train_time:73946ms step_avg:88.14ms +step:840/1670 train_time:74034ms step_avg:88.14ms +step:841/1670 train_time:74123ms step_avg:88.14ms +step:842/1670 train_time:74212ms step_avg:88.14ms +step:843/1670 train_time:74302ms step_avg:88.14ms +step:844/1670 train_time:74391ms step_avg:88.14ms +step:845/1670 train_time:74480ms step_avg:88.14ms +step:846/1670 train_time:74570ms step_avg:88.14ms +step:847/1670 train_time:74658ms step_avg:88.14ms +step:848/1670 train_time:74747ms step_avg:88.14ms +step:849/1670 train_time:74836ms step_avg:88.15ms +step:850/1670 train_time:74925ms step_avg:88.15ms +step:851/1670 train_time:75013ms step_avg:88.15ms +step:852/1670 train_time:75102ms step_avg:88.15ms +step:853/1670 train_time:75192ms step_avg:88.15ms +step:854/1670 
train_time:75281ms step_avg:88.15ms +step:855/1670 train_time:75370ms step_avg:88.15ms +step:856/1670 train_time:75459ms step_avg:88.15ms +step:857/1670 train_time:75548ms step_avg:88.15ms +step:858/1670 train_time:75636ms step_avg:88.15ms +step:859/1670 train_time:75726ms step_avg:88.16ms +step:860/1670 train_time:75814ms step_avg:88.16ms +step:861/1670 train_time:75904ms step_avg:88.16ms +step:862/1670 train_time:75992ms step_avg:88.16ms +step:863/1670 train_time:76081ms step_avg:88.16ms +step:864/1670 train_time:76170ms step_avg:88.16ms +step:865/1670 train_time:76259ms step_avg:88.16ms +step:866/1670 train_time:76349ms step_avg:88.16ms +step:867/1670 train_time:76438ms step_avg:88.16ms +step:868/1670 train_time:76526ms step_avg:88.16ms +step:869/1670 train_time:76615ms step_avg:88.16ms +step:870/1670 train_time:76704ms step_avg:88.17ms +step:871/1670 train_time:76793ms step_avg:88.17ms +step:872/1670 train_time:76882ms step_avg:88.17ms +step:873/1670 train_time:76971ms step_avg:88.17ms +step:874/1670 train_time:77059ms step_avg:88.17ms +step:875/1670 train_time:77148ms step_avg:88.17ms +step:875/1670 val_loss:3.5195 train_time:77238ms step_avg:88.27ms +step:876/1670 train_time:77258ms step_avg:88.19ms +step:877/1670 train_time:77332ms step_avg:88.18ms +step:878/1670 train_time:77426ms step_avg:88.18ms +step:879/1670 train_time:77515ms step_avg:88.18ms +step:880/1670 train_time:77603ms step_avg:88.19ms +step:881/1670 train_time:77690ms step_avg:88.18ms +step:882/1670 train_time:77778ms step_avg:88.18ms +step:883/1670 train_time:77867ms step_avg:88.18ms +step:884/1670 train_time:77954ms step_avg:88.18ms +step:885/1670 train_time:78042ms step_avg:88.18ms +step:886/1670 train_time:78130ms step_avg:88.18ms +step:887/1670 train_time:78220ms step_avg:88.18ms +step:888/1670 train_time:78312ms step_avg:88.19ms +step:889/1670 train_time:78404ms step_avg:88.19ms +step:890/1670 train_time:78494ms step_avg:88.20ms +step:891/1670 train_time:78583ms step_avg:88.20ms +step:892/1670 train_time:78671ms step_avg:88.20ms +step:893/1670 train_time:78760ms step_avg:88.20ms +step:894/1670 train_time:78848ms step_avg:88.20ms +step:895/1670 train_time:78935ms step_avg:88.20ms +step:896/1670 train_time:79024ms step_avg:88.20ms +step:897/1670 train_time:79112ms step_avg:88.20ms +step:898/1670 train_time:79201ms step_avg:88.20ms +step:899/1670 train_time:79292ms step_avg:88.20ms +step:900/1670 train_time:79383ms step_avg:88.20ms +step:901/1670 train_time:79473ms step_avg:88.20ms +step:902/1670 train_time:79562ms step_avg:88.21ms +step:903/1670 train_time:79651ms step_avg:88.21ms +step:904/1670 train_time:79739ms step_avg:88.21ms +step:905/1670 train_time:79829ms step_avg:88.21ms +step:906/1670 train_time:79917ms step_avg:88.21ms +step:907/1670 train_time:80006ms step_avg:88.21ms +step:908/1670 train_time:80094ms step_avg:88.21ms +step:909/1670 train_time:80183ms step_avg:88.21ms +step:910/1670 train_time:80273ms step_avg:88.21ms +step:911/1670 train_time:80363ms step_avg:88.21ms +step:912/1670 train_time:80452ms step_avg:88.21ms +step:913/1670 train_time:80541ms step_avg:88.22ms +step:914/1670 train_time:80630ms step_avg:88.22ms +step:915/1670 train_time:80720ms step_avg:88.22ms +step:916/1670 train_time:80808ms step_avg:88.22ms +step:917/1670 train_time:80896ms step_avg:88.22ms +step:918/1670 train_time:80985ms step_avg:88.22ms +step:919/1670 train_time:81073ms step_avg:88.22ms +step:920/1670 train_time:81162ms step_avg:88.22ms +step:921/1670 train_time:81252ms step_avg:88.22ms +step:922/1670 train_time:81342ms 
step_avg:88.22ms +step:923/1670 train_time:81431ms step_avg:88.22ms +step:924/1670 train_time:81522ms step_avg:88.23ms +step:925/1670 train_time:81611ms step_avg:88.23ms +step:926/1670 train_time:81700ms step_avg:88.23ms +step:927/1670 train_time:81788ms step_avg:88.23ms +step:928/1670 train_time:81877ms step_avg:88.23ms +step:929/1670 train_time:81965ms step_avg:88.23ms +step:930/1670 train_time:82053ms step_avg:88.23ms +step:931/1670 train_time:82142ms step_avg:88.23ms +step:932/1670 train_time:82232ms step_avg:88.23ms +step:933/1670 train_time:82322ms step_avg:88.23ms +step:934/1670 train_time:82411ms step_avg:88.23ms +step:935/1670 train_time:82500ms step_avg:88.24ms +step:936/1670 train_time:82590ms step_avg:88.24ms +step:937/1670 train_time:82678ms step_avg:88.24ms +step:938/1670 train_time:82768ms step_avg:88.24ms +step:939/1670 train_time:82857ms step_avg:88.24ms +step:940/1670 train_time:82946ms step_avg:88.24ms +step:941/1670 train_time:83034ms step_avg:88.24ms +step:942/1670 train_time:83123ms step_avg:88.24ms +step:943/1670 train_time:83212ms step_avg:88.24ms +step:944/1670 train_time:83301ms step_avg:88.24ms +step:945/1670 train_time:83391ms step_avg:88.24ms +step:946/1670 train_time:83479ms step_avg:88.24ms +step:947/1670 train_time:83570ms step_avg:88.25ms +step:948/1670 train_time:83659ms step_avg:88.25ms +step:949/1670 train_time:83749ms step_avg:88.25ms +step:950/1670 train_time:83838ms step_avg:88.25ms +step:951/1670 train_time:83927ms step_avg:88.25ms +step:952/1670 train_time:84016ms step_avg:88.25ms +step:953/1670 train_time:84105ms step_avg:88.25ms +step:954/1670 train_time:84194ms step_avg:88.25ms +step:955/1670 train_time:84282ms step_avg:88.25ms +step:956/1670 train_time:84371ms step_avg:88.25ms +step:957/1670 train_time:84460ms step_avg:88.26ms +step:958/1670 train_time:84550ms step_avg:88.26ms +step:959/1670 train_time:84639ms step_avg:88.26ms +step:960/1670 train_time:84730ms step_avg:88.26ms +step:961/1670 train_time:84819ms step_avg:88.26ms +step:962/1670 train_time:84908ms step_avg:88.26ms +step:963/1670 train_time:84997ms step_avg:88.26ms +step:964/1670 train_time:85087ms step_avg:88.26ms +step:965/1670 train_time:85175ms step_avg:88.26ms +step:966/1670 train_time:85265ms step_avg:88.27ms +step:967/1670 train_time:85353ms step_avg:88.27ms +step:968/1670 train_time:85443ms step_avg:88.27ms +step:969/1670 train_time:85532ms step_avg:88.27ms +step:970/1670 train_time:85621ms step_avg:88.27ms +step:971/1670 train_time:85710ms step_avg:88.27ms +step:972/1670 train_time:85799ms step_avg:88.27ms +step:973/1670 train_time:85888ms step_avg:88.27ms +step:974/1670 train_time:85976ms step_avg:88.27ms +step:975/1670 train_time:86065ms step_avg:88.27ms +step:976/1670 train_time:86153ms step_avg:88.27ms +step:977/1670 train_time:86242ms step_avg:88.27ms +step:978/1670 train_time:86331ms step_avg:88.27ms +step:979/1670 train_time:86420ms step_avg:88.27ms +step:980/1670 train_time:86508ms step_avg:88.27ms +step:981/1670 train_time:86598ms step_avg:88.27ms +step:982/1670 train_time:86688ms step_avg:88.28ms +step:983/1670 train_time:86777ms step_avg:88.28ms +step:984/1670 train_time:86866ms step_avg:88.28ms +step:985/1670 train_time:86954ms step_avg:88.28ms +step:986/1670 train_time:87044ms step_avg:88.28ms +step:987/1670 train_time:87132ms step_avg:88.28ms +step:988/1670 train_time:87222ms step_avg:88.28ms +step:989/1670 train_time:87310ms step_avg:88.28ms +step:990/1670 train_time:87399ms step_avg:88.28ms +step:991/1670 train_time:87488ms step_avg:88.28ms +step:992/1670 
train_time:87576ms step_avg:88.28ms +step:993/1670 train_time:87665ms step_avg:88.28ms +step:994/1670 train_time:87753ms step_avg:88.28ms +step:995/1670 train_time:87842ms step_avg:88.28ms +step:996/1670 train_time:87932ms step_avg:88.28ms +step:997/1670 train_time:88021ms step_avg:88.29ms +step:998/1670 train_time:88111ms step_avg:88.29ms +step:999/1670 train_time:88200ms step_avg:88.29ms +step:1000/1670 train_time:88288ms step_avg:88.29ms +step:1000/1670 val_loss:3.4680 train_time:88378ms step_avg:88.38ms +step:1001/1670 train_time:88399ms step_avg:88.31ms +step:1002/1670 train_time:88471ms step_avg:88.29ms +step:1003/1670 train_time:88569ms step_avg:88.30ms +step:1004/1670 train_time:88659ms step_avg:88.31ms +step:1005/1670 train_time:88747ms step_avg:88.31ms +step:1006/1670 train_time:88835ms step_avg:88.30ms +step:1007/1670 train_time:88923ms step_avg:88.30ms +step:1008/1670 train_time:89011ms step_avg:88.31ms +step:1009/1670 train_time:89099ms step_avg:88.30ms +step:1010/1670 train_time:89187ms step_avg:88.30ms +step:1011/1670 train_time:89275ms step_avg:88.30ms +step:1012/1670 train_time:89364ms step_avg:88.30ms +step:1013/1670 train_time:89455ms step_avg:88.31ms +step:1014/1670 train_time:89549ms step_avg:88.31ms +step:1015/1670 train_time:89638ms step_avg:88.31ms +step:1016/1670 train_time:89728ms step_avg:88.32ms +step:1017/1670 train_time:89816ms step_avg:88.32ms +step:1018/1670 train_time:89906ms step_avg:88.32ms +step:1019/1670 train_time:89993ms step_avg:88.32ms +step:1020/1670 train_time:90082ms step_avg:88.32ms +step:1021/1670 train_time:90170ms step_avg:88.32ms +step:1022/1670 train_time:90258ms step_avg:88.32ms +step:1023/1670 train_time:90348ms step_avg:88.32ms +step:1024/1670 train_time:90438ms step_avg:88.32ms +step:1025/1670 train_time:90529ms step_avg:88.32ms +step:1026/1670 train_time:90619ms step_avg:88.32ms +step:1027/1670 train_time:90711ms step_avg:88.33ms +step:1028/1670 train_time:90800ms step_avg:88.33ms +step:1029/1670 train_time:90889ms step_avg:88.33ms +step:1030/1670 train_time:90978ms step_avg:88.33ms +step:1031/1670 train_time:91067ms step_avg:88.33ms +step:1032/1670 train_time:91154ms step_avg:88.33ms +step:1033/1670 train_time:91243ms step_avg:88.33ms +step:1034/1670 train_time:91331ms step_avg:88.33ms +step:1035/1670 train_time:91421ms step_avg:88.33ms +step:1036/1670 train_time:91512ms step_avg:88.33ms +step:1037/1670 train_time:91602ms step_avg:88.33ms +step:1038/1670 train_time:91692ms step_avg:88.33ms +step:1039/1670 train_time:91781ms step_avg:88.34ms +step:1040/1670 train_time:91871ms step_avg:88.34ms +step:1041/1670 train_time:91960ms step_avg:88.34ms +step:1042/1670 train_time:92048ms step_avg:88.34ms +step:1043/1670 train_time:92136ms step_avg:88.34ms +step:1044/1670 train_time:92224ms step_avg:88.34ms +step:1045/1670 train_time:92313ms step_avg:88.34ms +step:1046/1670 train_time:92402ms step_avg:88.34ms +step:1047/1670 train_time:92492ms step_avg:88.34ms +step:1048/1670 train_time:92582ms step_avg:88.34ms +step:1049/1670 train_time:92672ms step_avg:88.34ms +step:1050/1670 train_time:92763ms step_avg:88.35ms +step:1051/1670 train_time:92851ms step_avg:88.35ms +step:1052/1670 train_time:92940ms step_avg:88.35ms +step:1053/1670 train_time:93029ms step_avg:88.35ms +step:1054/1670 train_time:93118ms step_avg:88.35ms +step:1055/1670 train_time:93207ms step_avg:88.35ms +step:1056/1670 train_time:93295ms step_avg:88.35ms +step:1057/1670 train_time:93385ms step_avg:88.35ms +step:1058/1670 train_time:93473ms step_avg:88.35ms +step:1059/1670 
train_time:93564ms step_avg:88.35ms +step:1060/1670 train_time:93653ms step_avg:88.35ms +step:1061/1670 train_time:93743ms step_avg:88.35ms +step:1062/1670 train_time:93832ms step_avg:88.35ms +step:1063/1670 train_time:93920ms step_avg:88.35ms +step:1064/1670 train_time:94009ms step_avg:88.35ms +step:1065/1670 train_time:94098ms step_avg:88.35ms +step:1066/1670 train_time:94187ms step_avg:88.36ms +step:1067/1670 train_time:94275ms step_avg:88.36ms +step:1068/1670 train_time:94365ms step_avg:88.36ms +step:1069/1670 train_time:94454ms step_avg:88.36ms +step:1070/1670 train_time:94543ms step_avg:88.36ms +step:1071/1670 train_time:94632ms step_avg:88.36ms +step:1072/1670 train_time:94722ms step_avg:88.36ms +step:1073/1670 train_time:94811ms step_avg:88.36ms +step:1074/1670 train_time:94900ms step_avg:88.36ms +step:1075/1670 train_time:94988ms step_avg:88.36ms +step:1076/1670 train_time:95077ms step_avg:88.36ms +step:1077/1670 train_time:95167ms step_avg:88.36ms +step:1078/1670 train_time:95255ms step_avg:88.36ms +step:1079/1670 train_time:95344ms step_avg:88.36ms +step:1080/1670 train_time:95433ms step_avg:88.36ms +step:1081/1670 train_time:95522ms step_avg:88.36ms +step:1082/1670 train_time:95611ms step_avg:88.36ms +step:1083/1670 train_time:95701ms step_avg:88.37ms +step:1084/1670 train_time:95790ms step_avg:88.37ms +step:1085/1670 train_time:95879ms step_avg:88.37ms +step:1086/1670 train_time:95968ms step_avg:88.37ms +step:1087/1670 train_time:96056ms step_avg:88.37ms +step:1088/1670 train_time:96145ms step_avg:88.37ms +step:1089/1670 train_time:96234ms step_avg:88.37ms +step:1090/1670 train_time:96324ms step_avg:88.37ms +step:1091/1670 train_time:96413ms step_avg:88.37ms +step:1092/1670 train_time:96503ms step_avg:88.37ms +step:1093/1670 train_time:96592ms step_avg:88.37ms +step:1094/1670 train_time:96682ms step_avg:88.37ms +step:1095/1670 train_time:96773ms step_avg:88.38ms +step:1096/1670 train_time:96863ms step_avg:88.38ms +step:1097/1670 train_time:96953ms step_avg:88.38ms +step:1098/1670 train_time:97043ms step_avg:88.38ms +step:1099/1670 train_time:97132ms step_avg:88.38ms +step:1100/1670 train_time:97222ms step_avg:88.38ms +step:1101/1670 train_time:97312ms step_avg:88.39ms +step:1102/1670 train_time:97403ms step_avg:88.39ms +step:1103/1670 train_time:97492ms step_avg:88.39ms +step:1104/1670 train_time:97582ms step_avg:88.39ms +step:1105/1670 train_time:97672ms step_avg:88.39ms +step:1106/1670 train_time:97763ms step_avg:88.39ms +step:1107/1670 train_time:97852ms step_avg:88.39ms +step:1108/1670 train_time:97942ms step_avg:88.40ms +step:1109/1670 train_time:98031ms step_avg:88.40ms +step:1110/1670 train_time:98121ms step_avg:88.40ms +step:1111/1670 train_time:98211ms step_avg:88.40ms +step:1112/1670 train_time:98300ms step_avg:88.40ms +step:1113/1670 train_time:98391ms step_avg:88.40ms +step:1114/1670 train_time:98480ms step_avg:88.40ms +step:1115/1670 train_time:98571ms step_avg:88.40ms +step:1116/1670 train_time:98660ms step_avg:88.40ms +step:1117/1670 train_time:98750ms step_avg:88.41ms +step:1118/1670 train_time:98839ms step_avg:88.41ms +step:1119/1670 train_time:98929ms step_avg:88.41ms +step:1120/1670 train_time:99019ms step_avg:88.41ms +step:1121/1670 train_time:99108ms step_avg:88.41ms +step:1122/1670 train_time:99198ms step_avg:88.41ms +step:1123/1670 train_time:99288ms step_avg:88.41ms +step:1124/1670 train_time:99378ms step_avg:88.41ms +step:1125/1670 train_time:99470ms step_avg:88.42ms +step:1125/1670 val_loss:3.4143 train_time:99561ms step_avg:88.50ms +step:1126/1670 
train_time:99581ms step_avg:88.44ms +step:1127/1670 train_time:99654ms step_avg:88.42ms +step:1128/1670 train_time:99746ms step_avg:88.43ms +step:1129/1670 train_time:99835ms step_avg:88.43ms +step:1130/1670 train_time:99924ms step_avg:88.43ms +step:1131/1670 train_time:100013ms step_avg:88.43ms +step:1132/1670 train_time:100102ms step_avg:88.43ms +step:1133/1670 train_time:100191ms step_avg:88.43ms +step:1134/1670 train_time:100279ms step_avg:88.43ms +step:1135/1670 train_time:100370ms step_avg:88.43ms +step:1136/1670 train_time:100460ms step_avg:88.43ms +step:1137/1670 train_time:100551ms step_avg:88.44ms +step:1138/1670 train_time:100643ms step_avg:88.44ms +step:1139/1670 train_time:100735ms step_avg:88.44ms +step:1140/1670 train_time:100826ms step_avg:88.44ms +step:1141/1670 train_time:100916ms step_avg:88.45ms +step:1142/1670 train_time:101005ms step_avg:88.45ms +step:1143/1670 train_time:101093ms step_avg:88.45ms +step:1144/1670 train_time:101181ms step_avg:88.45ms +step:1145/1670 train_time:101270ms step_avg:88.45ms +step:1146/1670 train_time:101359ms step_avg:88.45ms +step:1147/1670 train_time:101449ms step_avg:88.45ms +step:1148/1670 train_time:101539ms step_avg:88.45ms +step:1149/1670 train_time:101630ms step_avg:88.45ms +step:1150/1670 train_time:101721ms step_avg:88.45ms +step:1151/1670 train_time:101812ms step_avg:88.45ms +step:1152/1670 train_time:101902ms step_avg:88.46ms +step:1153/1670 train_time:101992ms step_avg:88.46ms +step:1154/1670 train_time:102081ms step_avg:88.46ms +step:1155/1670 train_time:102170ms step_avg:88.46ms +step:1156/1670 train_time:102259ms step_avg:88.46ms +step:1157/1670 train_time:102349ms step_avg:88.46ms +step:1158/1670 train_time:102438ms step_avg:88.46ms +step:1159/1670 train_time:102528ms step_avg:88.46ms +step:1160/1670 train_time:102617ms step_avg:88.46ms +step:1161/1670 train_time:102708ms step_avg:88.46ms +step:1162/1670 train_time:102798ms step_avg:88.47ms +step:1163/1670 train_time:102888ms step_avg:88.47ms +step:1164/1670 train_time:102978ms step_avg:88.47ms +step:1165/1670 train_time:103068ms step_avg:88.47ms +step:1166/1670 train_time:103156ms step_avg:88.47ms +step:1167/1670 train_time:103245ms step_avg:88.47ms +step:1168/1670 train_time:103334ms step_avg:88.47ms +step:1169/1670 train_time:103424ms step_avg:88.47ms +step:1170/1670 train_time:103513ms step_avg:88.47ms +step:1171/1670 train_time:103603ms step_avg:88.47ms +step:1172/1670 train_time:103692ms step_avg:88.47ms +step:1173/1670 train_time:103782ms step_avg:88.48ms +step:1174/1670 train_time:103873ms step_avg:88.48ms +step:1175/1670 train_time:103962ms step_avg:88.48ms +step:1176/1670 train_time:104053ms step_avg:88.48ms +step:1177/1670 train_time:104143ms step_avg:88.48ms +step:1178/1670 train_time:104231ms step_avg:88.48ms +step:1179/1670 train_time:104320ms step_avg:88.48ms +step:1180/1670 train_time:104410ms step_avg:88.48ms +step:1181/1670 train_time:104500ms step_avg:88.48ms +step:1182/1670 train_time:104589ms step_avg:88.49ms +step:1183/1670 train_time:104678ms step_avg:88.49ms +step:1184/1670 train_time:104769ms step_avg:88.49ms +step:1185/1670 train_time:104859ms step_avg:88.49ms +step:1186/1670 train_time:104949ms step_avg:88.49ms +step:1187/1670 train_time:105039ms step_avg:88.49ms +step:1188/1670 train_time:105130ms step_avg:88.49ms +step:1189/1670 train_time:105220ms step_avg:88.49ms +step:1190/1670 train_time:105310ms step_avg:88.50ms +step:1191/1670 train_time:105399ms step_avg:88.50ms +step:1192/1670 train_time:105489ms step_avg:88.50ms +step:1193/1670 
train_time:105579ms step_avg:88.50ms +step:1194/1670 train_time:105669ms step_avg:88.50ms +step:1195/1670 train_time:105758ms step_avg:88.50ms +step:1196/1670 train_time:105849ms step_avg:88.50ms +step:1197/1670 train_time:105938ms step_avg:88.50ms +step:1198/1670 train_time:106028ms step_avg:88.50ms +step:1199/1670 train_time:106118ms step_avg:88.51ms +step:1200/1670 train_time:106208ms step_avg:88.51ms +step:1201/1670 train_time:106297ms step_avg:88.51ms +step:1202/1670 train_time:106387ms step_avg:88.51ms +step:1203/1670 train_time:106476ms step_avg:88.51ms +step:1204/1670 train_time:106565ms step_avg:88.51ms +step:1205/1670 train_time:106655ms step_avg:88.51ms +step:1206/1670 train_time:106744ms step_avg:88.51ms +step:1207/1670 train_time:106834ms step_avg:88.51ms +step:1208/1670 train_time:106924ms step_avg:88.51ms +step:1209/1670 train_time:107013ms step_avg:88.51ms +step:1210/1670 train_time:107103ms step_avg:88.52ms +step:1211/1670 train_time:107193ms step_avg:88.52ms +step:1212/1670 train_time:107283ms step_avg:88.52ms +step:1213/1670 train_time:107372ms step_avg:88.52ms +step:1214/1670 train_time:107462ms step_avg:88.52ms +step:1215/1670 train_time:107553ms step_avg:88.52ms +step:1216/1670 train_time:107644ms step_avg:88.52ms +step:1217/1670 train_time:107733ms step_avg:88.52ms +step:1218/1670 train_time:107822ms step_avg:88.52ms +step:1219/1670 train_time:107912ms step_avg:88.52ms +step:1220/1670 train_time:108002ms step_avg:88.53ms +step:1221/1670 train_time:108091ms step_avg:88.53ms +step:1222/1670 train_time:108181ms step_avg:88.53ms +step:1223/1670 train_time:108271ms step_avg:88.53ms +step:1224/1670 train_time:108360ms step_avg:88.53ms +step:1225/1670 train_time:108451ms step_avg:88.53ms +step:1226/1670 train_time:108541ms step_avg:88.53ms +step:1227/1670 train_time:108631ms step_avg:88.53ms +step:1228/1670 train_time:108721ms step_avg:88.53ms +step:1229/1670 train_time:108811ms step_avg:88.54ms +step:1230/1670 train_time:108901ms step_avg:88.54ms +step:1231/1670 train_time:108990ms step_avg:88.54ms +step:1232/1670 train_time:109080ms step_avg:88.54ms +step:1233/1670 train_time:109170ms step_avg:88.54ms +step:1234/1670 train_time:109259ms step_avg:88.54ms +step:1235/1670 train_time:109350ms step_avg:88.54ms +step:1236/1670 train_time:109440ms step_avg:88.54ms +step:1237/1670 train_time:109530ms step_avg:88.55ms +step:1238/1670 train_time:109620ms step_avg:88.55ms +step:1239/1670 train_time:109710ms step_avg:88.55ms +step:1240/1670 train_time:109799ms step_avg:88.55ms +step:1241/1670 train_time:109889ms step_avg:88.55ms +step:1242/1670 train_time:109979ms step_avg:88.55ms +step:1243/1670 train_time:110069ms step_avg:88.55ms +step:1244/1670 train_time:110159ms step_avg:88.55ms +step:1245/1670 train_time:110248ms step_avg:88.55ms +step:1246/1670 train_time:110338ms step_avg:88.55ms +step:1247/1670 train_time:110428ms step_avg:88.56ms +step:1248/1670 train_time:110518ms step_avg:88.56ms +step:1249/1670 train_time:110609ms step_avg:88.56ms +step:1250/1670 train_time:110700ms step_avg:88.56ms +step:1250/1670 val_loss:3.3758 train_time:110791ms step_avg:88.63ms +step:1251/1670 train_time:110810ms step_avg:88.58ms +step:1252/1670 train_time:110884ms step_avg:88.57ms +step:1253/1670 train_time:110980ms step_avg:88.57ms +step:1254/1670 train_time:111070ms step_avg:88.57ms +step:1255/1670 train_time:111159ms step_avg:88.57ms +step:1256/1670 train_time:111247ms step_avg:88.57ms +step:1257/1670 train_time:111335ms step_avg:88.57ms +step:1258/1670 train_time:111423ms step_avg:88.57ms 
+step:1259/1670 train_time:111512ms step_avg:88.57ms +step:1260/1670 train_time:111602ms step_avg:88.57ms +step:1261/1670 train_time:111691ms step_avg:88.57ms +step:1262/1670 train_time:111785ms step_avg:88.58ms +step:1263/1670 train_time:111879ms step_avg:88.58ms +step:1264/1670 train_time:111970ms step_avg:88.58ms +step:1265/1670 train_time:112062ms step_avg:88.59ms +step:1266/1670 train_time:112151ms step_avg:88.59ms +step:1267/1670 train_time:112241ms step_avg:88.59ms +step:1268/1670 train_time:112330ms step_avg:88.59ms +step:1269/1670 train_time:112418ms step_avg:88.59ms +step:1270/1670 train_time:112507ms step_avg:88.59ms +step:1271/1670 train_time:112595ms step_avg:88.59ms +step:1272/1670 train_time:112684ms step_avg:88.59ms +step:1273/1670 train_time:112776ms step_avg:88.59ms +step:1274/1670 train_time:112866ms step_avg:88.59ms +step:1275/1670 train_time:112957ms step_avg:88.59ms +step:1276/1670 train_time:113047ms step_avg:88.59ms +step:1277/1670 train_time:113137ms step_avg:88.60ms +step:1278/1670 train_time:113225ms step_avg:88.60ms +step:1279/1670 train_time:113315ms step_avg:88.60ms +step:1280/1670 train_time:113404ms step_avg:88.60ms +step:1281/1670 train_time:113492ms step_avg:88.60ms +step:1282/1670 train_time:113583ms step_avg:88.60ms +step:1283/1670 train_time:113673ms step_avg:88.60ms +step:1284/1670 train_time:113763ms step_avg:88.60ms +step:1285/1670 train_time:113854ms step_avg:88.60ms +step:1286/1670 train_time:113945ms step_avg:88.60ms +step:1287/1670 train_time:114036ms step_avg:88.61ms +step:1288/1670 train_time:114125ms step_avg:88.61ms +step:1289/1670 train_time:114215ms step_avg:88.61ms +step:1290/1670 train_time:114304ms step_avg:88.61ms +step:1291/1670 train_time:114393ms step_avg:88.61ms +step:1292/1670 train_time:114482ms step_avg:88.61ms +step:1293/1670 train_time:114572ms step_avg:88.61ms +step:1294/1670 train_time:114662ms step_avg:88.61ms +step:1295/1670 train_time:114753ms step_avg:88.61ms +step:1296/1670 train_time:114843ms step_avg:88.61ms +step:1297/1670 train_time:114934ms step_avg:88.62ms +step:1298/1670 train_time:115024ms step_avg:88.62ms +step:1299/1670 train_time:115114ms step_avg:88.62ms +step:1300/1670 train_time:115203ms step_avg:88.62ms +step:1301/1670 train_time:115292ms step_avg:88.62ms +step:1302/1670 train_time:115382ms step_avg:88.62ms +step:1303/1670 train_time:115472ms step_avg:88.62ms +step:1304/1670 train_time:115562ms step_avg:88.62ms +step:1305/1670 train_time:115652ms step_avg:88.62ms +step:1306/1670 train_time:115742ms step_avg:88.62ms +step:1307/1670 train_time:115832ms step_avg:88.62ms +step:1308/1670 train_time:115922ms step_avg:88.63ms +step:1309/1670 train_time:116012ms step_avg:88.63ms +step:1310/1670 train_time:116103ms step_avg:88.63ms +step:1311/1670 train_time:116193ms step_avg:88.63ms +step:1312/1670 train_time:116283ms step_avg:88.63ms +step:1313/1670 train_time:116373ms step_avg:88.63ms +step:1314/1670 train_time:116462ms step_avg:88.63ms +step:1315/1670 train_time:116552ms step_avg:88.63ms +step:1316/1670 train_time:116642ms step_avg:88.63ms +step:1317/1670 train_time:116733ms step_avg:88.64ms +step:1318/1670 train_time:116823ms step_avg:88.64ms +step:1319/1670 train_time:116913ms step_avg:88.64ms +step:1320/1670 train_time:117002ms step_avg:88.64ms +step:1321/1670 train_time:117091ms step_avg:88.64ms +step:1322/1670 train_time:117181ms step_avg:88.64ms +step:1323/1670 train_time:117272ms step_avg:88.64ms +step:1324/1670 train_time:117361ms step_avg:88.64ms +step:1325/1670 train_time:117451ms step_avg:88.64ms 
+step:1326/1670 train_time:117541ms step_avg:88.64ms +step:1327/1670 train_time:117631ms step_avg:88.64ms +step:1328/1670 train_time:117721ms step_avg:88.65ms +step:1329/1670 train_time:117811ms step_avg:88.65ms +step:1330/1670 train_time:117901ms step_avg:88.65ms +step:1331/1670 train_time:117991ms step_avg:88.65ms +step:1332/1670 train_time:118080ms step_avg:88.65ms +step:1333/1670 train_time:118170ms step_avg:88.65ms +step:1334/1670 train_time:118260ms step_avg:88.65ms +step:1335/1670 train_time:118349ms step_avg:88.65ms +step:1336/1670 train_time:118440ms step_avg:88.65ms +step:1337/1670 train_time:118531ms step_avg:88.65ms +step:1338/1670 train_time:118621ms step_avg:88.66ms +step:1339/1670 train_time:118712ms step_avg:88.66ms +step:1340/1670 train_time:118801ms step_avg:88.66ms +step:1341/1670 train_time:118891ms step_avg:88.66ms +step:1342/1670 train_time:118980ms step_avg:88.66ms +step:1343/1670 train_time:119070ms step_avg:88.66ms +step:1344/1670 train_time:119159ms step_avg:88.66ms +step:1345/1670 train_time:119249ms step_avg:88.66ms +step:1346/1670 train_time:119339ms step_avg:88.66ms +step:1347/1670 train_time:119429ms step_avg:88.66ms +step:1348/1670 train_time:119519ms step_avg:88.66ms +step:1349/1670 train_time:119609ms step_avg:88.66ms +step:1350/1670 train_time:119699ms step_avg:88.67ms +step:1351/1670 train_time:119788ms step_avg:88.67ms +step:1352/1670 train_time:119878ms step_avg:88.67ms +step:1353/1670 train_time:119967ms step_avg:88.67ms +step:1354/1670 train_time:120057ms step_avg:88.67ms +step:1355/1670 train_time:120147ms step_avg:88.67ms +step:1356/1670 train_time:120236ms step_avg:88.67ms +step:1357/1670 train_time:120325ms step_avg:88.67ms +step:1358/1670 train_time:120415ms step_avg:88.67ms +step:1359/1670 train_time:120505ms step_avg:88.67ms +step:1360/1670 train_time:120595ms step_avg:88.67ms +step:1361/1670 train_time:120683ms step_avg:88.67ms +step:1362/1670 train_time:120774ms step_avg:88.67ms +step:1363/1670 train_time:120864ms step_avg:88.67ms +step:1364/1670 train_time:120954ms step_avg:88.68ms +step:1365/1670 train_time:121043ms step_avg:88.68ms +step:1366/1670 train_time:121133ms step_avg:88.68ms +step:1367/1670 train_time:121223ms step_avg:88.68ms +step:1368/1670 train_time:121312ms step_avg:88.68ms +step:1369/1670 train_time:121402ms step_avg:88.68ms +step:1370/1670 train_time:121492ms step_avg:88.68ms +step:1371/1670 train_time:121582ms step_avg:88.68ms +step:1372/1670 train_time:121672ms step_avg:88.68ms +step:1373/1670 train_time:121762ms step_avg:88.68ms +step:1374/1670 train_time:121853ms step_avg:88.68ms +step:1375/1670 train_time:121943ms step_avg:88.69ms +step:1375/1670 val_loss:3.3410 train_time:122035ms step_avg:88.75ms +step:1376/1670 train_time:122054ms step_avg:88.70ms +step:1377/1670 train_time:122129ms step_avg:88.69ms +step:1378/1670 train_time:122223ms step_avg:88.70ms +step:1379/1670 train_time:122312ms step_avg:88.70ms +step:1380/1670 train_time:122401ms step_avg:88.70ms +step:1381/1670 train_time:122489ms step_avg:88.70ms +step:1382/1670 train_time:122578ms step_avg:88.70ms +step:1383/1670 train_time:122667ms step_avg:88.70ms +step:1384/1670 train_time:122755ms step_avg:88.70ms +step:1385/1670 train_time:122845ms step_avg:88.70ms +step:1386/1670 train_time:122934ms step_avg:88.70ms +step:1387/1670 train_time:123026ms step_avg:88.70ms +step:1388/1670 train_time:123118ms step_avg:88.70ms +step:1389/1670 train_time:123212ms step_avg:88.71ms +step:1390/1670 train_time:123304ms step_avg:88.71ms +step:1391/1670 train_time:123392ms 
step_avg:88.71ms +step:1392/1670 train_time:123482ms step_avg:88.71ms +step:1393/1670 train_time:123571ms step_avg:88.71ms +step:1394/1670 train_time:123661ms step_avg:88.71ms +step:1395/1670 train_time:123750ms step_avg:88.71ms +step:1396/1670 train_time:123839ms step_avg:88.71ms +step:1397/1670 train_time:123927ms step_avg:88.71ms +step:1398/1670 train_time:124016ms step_avg:88.71ms +step:1399/1670 train_time:124108ms step_avg:88.71ms +step:1400/1670 train_time:124200ms step_avg:88.71ms +step:1401/1670 train_time:124290ms step_avg:88.72ms +step:1402/1670 train_time:124379ms step_avg:88.72ms +step:1403/1670 train_time:124469ms step_avg:88.72ms +step:1404/1670 train_time:124559ms step_avg:88.72ms +step:1405/1670 train_time:124647ms step_avg:88.72ms +step:1406/1670 train_time:124736ms step_avg:88.72ms +step:1407/1670 train_time:124826ms step_avg:88.72ms +step:1408/1670 train_time:124915ms step_avg:88.72ms +step:1409/1670 train_time:125006ms step_avg:88.72ms +step:1410/1670 train_time:125096ms step_avg:88.72ms +step:1411/1670 train_time:125187ms step_avg:88.72ms +step:1412/1670 train_time:125278ms step_avg:88.72ms +step:1413/1670 train_time:125369ms step_avg:88.73ms +step:1414/1670 train_time:125459ms step_avg:88.73ms +step:1415/1670 train_time:125548ms step_avg:88.73ms +step:1416/1670 train_time:125638ms step_avg:88.73ms +step:1417/1670 train_time:125727ms step_avg:88.73ms +step:1418/1670 train_time:125817ms step_avg:88.73ms +step:1419/1670 train_time:125908ms step_avg:88.73ms +step:1420/1670 train_time:125997ms step_avg:88.73ms +step:1421/1670 train_time:126089ms step_avg:88.73ms +step:1422/1670 train_time:126179ms step_avg:88.73ms +step:1423/1670 train_time:126271ms step_avg:88.74ms +step:1424/1670 train_time:126361ms step_avg:88.74ms +step:1425/1670 train_time:126450ms step_avg:88.74ms +step:1426/1670 train_time:126540ms step_avg:88.74ms +step:1427/1670 train_time:126629ms step_avg:88.74ms +step:1428/1670 train_time:126719ms step_avg:88.74ms +step:1429/1670 train_time:126810ms step_avg:88.74ms +step:1430/1670 train_time:126900ms step_avg:88.74ms +step:1431/1670 train_time:126990ms step_avg:88.74ms +step:1432/1670 train_time:127080ms step_avg:88.74ms +step:1433/1670 train_time:127171ms step_avg:88.74ms +step:1434/1670 train_time:127262ms step_avg:88.75ms +step:1435/1670 train_time:127352ms step_avg:88.75ms +step:1436/1670 train_time:127442ms step_avg:88.75ms +step:1437/1670 train_time:127530ms step_avg:88.75ms +step:1438/1670 train_time:127619ms step_avg:88.75ms +step:1439/1670 train_time:127708ms step_avg:88.75ms +step:1440/1670 train_time:127798ms step_avg:88.75ms +step:1441/1670 train_time:127888ms step_avg:88.75ms +step:1442/1670 train_time:127979ms step_avg:88.75ms +step:1443/1670 train_time:128069ms step_avg:88.75ms +step:1444/1670 train_time:128159ms step_avg:88.75ms +step:1445/1670 train_time:128250ms step_avg:88.75ms +step:1446/1670 train_time:128341ms step_avg:88.76ms +step:1447/1670 train_time:128430ms step_avg:88.76ms +step:1448/1670 train_time:128520ms step_avg:88.76ms +step:1449/1670 train_time:128610ms step_avg:88.76ms +step:1450/1670 train_time:128700ms step_avg:88.76ms +step:1451/1670 train_time:128789ms step_avg:88.76ms +step:1452/1670 train_time:128879ms step_avg:88.76ms +step:1453/1670 train_time:128969ms step_avg:88.76ms +step:1454/1670 train_time:129059ms step_avg:88.76ms +step:1455/1670 train_time:129149ms step_avg:88.76ms +step:1456/1670 train_time:129240ms step_avg:88.76ms +step:1457/1670 train_time:129330ms step_avg:88.76ms +step:1458/1670 train_time:129419ms 
step_avg:88.76ms +step:1459/1670 train_time:129509ms step_avg:88.77ms +step:1460/1670 train_time:129600ms step_avg:88.77ms +step:1461/1670 train_time:129689ms step_avg:88.77ms +step:1462/1670 train_time:129780ms step_avg:88.77ms +step:1463/1670 train_time:129870ms step_avg:88.77ms +step:1464/1670 train_time:129961ms step_avg:88.77ms +step:1465/1670 train_time:130050ms step_avg:88.77ms +step:1466/1670 train_time:130139ms step_avg:88.77ms +step:1467/1670 train_time:130229ms step_avg:88.77ms +step:1468/1670 train_time:130319ms step_avg:88.77ms +step:1469/1670 train_time:130409ms step_avg:88.77ms +step:1470/1670 train_time:130499ms step_avg:88.77ms +step:1471/1670 train_time:130589ms step_avg:88.78ms +step:1472/1670 train_time:130678ms step_avg:88.78ms +step:1473/1670 train_time:130769ms step_avg:88.78ms +step:1474/1670 train_time:130860ms step_avg:88.78ms +step:1475/1670 train_time:130950ms step_avg:88.78ms +step:1476/1670 train_time:131039ms step_avg:88.78ms +step:1477/1670 train_time:131129ms step_avg:88.78ms +step:1478/1670 train_time:131218ms step_avg:88.78ms +step:1479/1670 train_time:131308ms step_avg:88.78ms +step:1480/1670 train_time:131399ms step_avg:88.78ms +step:1481/1670 train_time:131488ms step_avg:88.78ms +step:1482/1670 train_time:131578ms step_avg:88.78ms +step:1483/1670 train_time:131669ms step_avg:88.79ms +step:1484/1670 train_time:131759ms step_avg:88.79ms +step:1485/1670 train_time:131849ms step_avg:88.79ms +step:1486/1670 train_time:131939ms step_avg:88.79ms +step:1487/1670 train_time:132029ms step_avg:88.79ms +step:1488/1670 train_time:132119ms step_avg:88.79ms +step:1489/1670 train_time:132210ms step_avg:88.79ms +step:1490/1670 train_time:132300ms step_avg:88.79ms +step:1491/1670 train_time:132390ms step_avg:88.79ms +step:1492/1670 train_time:132479ms step_avg:88.79ms +step:1493/1670 train_time:132569ms step_avg:88.79ms +step:1494/1670 train_time:132659ms step_avg:88.79ms +step:1495/1670 train_time:132749ms step_avg:88.80ms +step:1496/1670 train_time:132839ms step_avg:88.80ms +step:1497/1670 train_time:132928ms step_avg:88.80ms +step:1498/1670 train_time:133018ms step_avg:88.80ms +step:1499/1670 train_time:133107ms step_avg:88.80ms +step:1500/1670 train_time:133197ms step_avg:88.80ms +step:1500/1670 val_loss:3.3114 train_time:133288ms step_avg:88.86ms +step:1501/1670 train_time:133307ms step_avg:88.81ms +step:1502/1670 train_time:133381ms step_avg:88.80ms +step:1503/1670 train_time:133473ms step_avg:88.80ms +step:1504/1670 train_time:133563ms step_avg:88.80ms +step:1505/1670 train_time:133651ms step_avg:88.80ms +step:1506/1670 train_time:133739ms step_avg:88.80ms +step:1507/1670 train_time:133827ms step_avg:88.80ms +step:1508/1670 train_time:133916ms step_avg:88.80ms +step:1509/1670 train_time:134005ms step_avg:88.80ms +step:1510/1670 train_time:134096ms step_avg:88.81ms +step:1511/1670 train_time:134186ms step_avg:88.81ms +step:1512/1670 train_time:134279ms step_avg:88.81ms +step:1513/1670 train_time:134372ms step_avg:88.81ms +step:1514/1670 train_time:134463ms step_avg:88.81ms +step:1515/1670 train_time:134554ms step_avg:88.81ms +step:1516/1670 train_time:134643ms step_avg:88.81ms +step:1517/1670 train_time:134732ms step_avg:88.82ms +step:1518/1670 train_time:134821ms step_avg:88.81ms +step:1519/1670 train_time:134909ms step_avg:88.81ms +step:1520/1670 train_time:134998ms step_avg:88.81ms +step:1521/1670 train_time:135087ms step_avg:88.81ms +step:1522/1670 train_time:135177ms step_avg:88.82ms +step:1523/1670 train_time:135268ms step_avg:88.82ms +step:1524/1670 
train_time:135358ms step_avg:88.82ms +step:1525/1670 train_time:135449ms step_avg:88.82ms +step:1526/1670 train_time:135539ms step_avg:88.82ms +step:1527/1670 train_time:135629ms step_avg:88.82ms +step:1528/1670 train_time:135718ms step_avg:88.82ms +step:1529/1670 train_time:135807ms step_avg:88.82ms +step:1530/1670 train_time:135896ms step_avg:88.82ms +step:1531/1670 train_time:135986ms step_avg:88.82ms +step:1532/1670 train_time:136076ms step_avg:88.82ms +step:1533/1670 train_time:136166ms step_avg:88.82ms +step:1534/1670 train_time:136256ms step_avg:88.82ms +step:1535/1670 train_time:136346ms step_avg:88.82ms +step:1536/1670 train_time:136438ms step_avg:88.83ms +step:1537/1670 train_time:136528ms step_avg:88.83ms +step:1538/1670 train_time:136618ms step_avg:88.83ms +step:1539/1670 train_time:136708ms step_avg:88.83ms +step:1540/1670 train_time:136797ms step_avg:88.83ms +step:1541/1670 train_time:136886ms step_avg:88.83ms +step:1542/1670 train_time:136976ms step_avg:88.83ms +step:1543/1670 train_time:137065ms step_avg:88.83ms +step:1544/1670 train_time:137155ms step_avg:88.83ms +step:1545/1670 train_time:137245ms step_avg:88.83ms +step:1546/1670 train_time:137337ms step_avg:88.83ms +step:1547/1670 train_time:137428ms step_avg:88.83ms +step:1548/1670 train_time:137518ms step_avg:88.84ms +step:1549/1670 train_time:137608ms step_avg:88.84ms +step:1550/1670 train_time:137698ms step_avg:88.84ms +step:1551/1670 train_time:137787ms step_avg:88.84ms +step:1552/1670 train_time:137876ms step_avg:88.84ms +step:1553/1670 train_time:137965ms step_avg:88.84ms +step:1554/1670 train_time:138055ms step_avg:88.84ms +step:1555/1670 train_time:138144ms step_avg:88.84ms +step:1556/1670 train_time:138234ms step_avg:88.84ms +step:1557/1670 train_time:138325ms step_avg:88.84ms +step:1558/1670 train_time:138416ms step_avg:88.84ms +step:1559/1670 train_time:138506ms step_avg:88.84ms +step:1560/1670 train_time:138597ms step_avg:88.84ms +step:1561/1670 train_time:138687ms step_avg:88.84ms +step:1562/1670 train_time:138777ms step_avg:88.85ms +step:1563/1670 train_time:138866ms step_avg:88.85ms +step:1564/1670 train_time:138954ms step_avg:88.85ms +step:1565/1670 train_time:139044ms step_avg:88.85ms +step:1566/1670 train_time:139134ms step_avg:88.85ms +step:1567/1670 train_time:139223ms step_avg:88.85ms +step:1568/1670 train_time:139314ms step_avg:88.85ms +step:1569/1670 train_time:139405ms step_avg:88.85ms +step:1570/1670 train_time:139497ms step_avg:88.85ms +step:1571/1670 train_time:139588ms step_avg:88.85ms +step:1572/1670 train_time:139678ms step_avg:88.85ms +step:1573/1670 train_time:139769ms step_avg:88.85ms +step:1574/1670 train_time:139858ms step_avg:88.86ms +step:1575/1670 train_time:139947ms step_avg:88.86ms +step:1576/1670 train_time:140036ms step_avg:88.86ms +step:1577/1670 train_time:140126ms step_avg:88.86ms +step:1578/1670 train_time:140215ms step_avg:88.86ms +step:1579/1670 train_time:140305ms step_avg:88.86ms +step:1580/1670 train_time:140396ms step_avg:88.86ms +step:1581/1670 train_time:140487ms step_avg:88.86ms +step:1582/1670 train_time:140578ms step_avg:88.86ms +step:1583/1670 train_time:140668ms step_avg:88.86ms +step:1584/1670 train_time:140757ms step_avg:88.86ms +step:1585/1670 train_time:140847ms step_avg:88.86ms +step:1586/1670 train_time:140937ms step_avg:88.86ms +step:1587/1670 train_time:141027ms step_avg:88.86ms +step:1588/1670 train_time:141117ms step_avg:88.86ms +step:1589/1670 train_time:141207ms step_avg:88.87ms +step:1590/1670 train_time:141298ms step_avg:88.87ms +step:1591/1670 
train_time:141388ms step_avg:88.87ms +step:1592/1670 train_time:141477ms step_avg:88.87ms +step:1593/1670 train_time:141567ms step_avg:88.87ms +step:1594/1670 train_time:141657ms step_avg:88.87ms +step:1595/1670 train_time:141746ms step_avg:88.87ms +step:1596/1670 train_time:141835ms step_avg:88.87ms +step:1597/1670 train_time:141926ms step_avg:88.87ms +step:1598/1670 train_time:142015ms step_avg:88.87ms +step:1599/1670 train_time:142105ms step_avg:88.87ms +step:1600/1670 train_time:142195ms step_avg:88.87ms +step:1601/1670 train_time:142285ms step_avg:88.87ms +step:1602/1670 train_time:142375ms step_avg:88.87ms +step:1603/1670 train_time:142465ms step_avg:88.87ms +step:1604/1670 train_time:142555ms step_avg:88.87ms +step:1605/1670 train_time:142646ms step_avg:88.88ms +step:1606/1670 train_time:142735ms step_avg:88.88ms +step:1607/1670 train_time:142826ms step_avg:88.88ms +step:1608/1670 train_time:142916ms step_avg:88.88ms +step:1609/1670 train_time:143007ms step_avg:88.88ms +step:1610/1670 train_time:143096ms step_avg:88.88ms +step:1611/1670 train_time:143186ms step_avg:88.88ms +step:1612/1670 train_time:143276ms step_avg:88.88ms +step:1613/1670 train_time:143366ms step_avg:88.88ms +step:1614/1670 train_time:143456ms step_avg:88.88ms +step:1615/1670 train_time:143546ms step_avg:88.88ms +step:1616/1670 train_time:143637ms step_avg:88.88ms +step:1617/1670 train_time:143727ms step_avg:88.88ms +step:1618/1670 train_time:143816ms step_avg:88.89ms +step:1619/1670 train_time:143907ms step_avg:88.89ms +step:1620/1670 train_time:143997ms step_avg:88.89ms +step:1621/1670 train_time:144087ms step_avg:88.89ms +step:1622/1670 train_time:144177ms step_avg:88.89ms +step:1623/1670 train_time:144267ms step_avg:88.89ms +step:1624/1670 train_time:144357ms step_avg:88.89ms +step:1625/1670 train_time:144448ms step_avg:88.89ms +step:1625/1670 val_loss:3.2884 train_time:144538ms step_avg:88.95ms +step:1626/1670 train_time:144559ms step_avg:88.90ms +step:1627/1670 train_time:144631ms step_avg:88.89ms +step:1628/1670 train_time:144725ms step_avg:88.90ms +step:1629/1670 train_time:144816ms step_avg:88.90ms +step:1630/1670 train_time:144904ms step_avg:88.90ms +step:1631/1670 train_time:144993ms step_avg:88.90ms +step:1632/1670 train_time:145082ms step_avg:88.90ms +step:1633/1670 train_time:145170ms step_avg:88.90ms +step:1634/1670 train_time:145259ms step_avg:88.90ms +step:1635/1670 train_time:145347ms step_avg:88.90ms +step:1636/1670 train_time:145436ms step_avg:88.90ms +step:1637/1670 train_time:145528ms step_avg:88.90ms +step:1638/1670 train_time:145621ms step_avg:88.90ms +step:1639/1670 train_time:145712ms step_avg:88.90ms +step:1640/1670 train_time:145803ms step_avg:88.90ms +step:1641/1670 train_time:145893ms step_avg:88.90ms +step:1642/1670 train_time:145982ms step_avg:88.90ms +step:1643/1670 train_time:146071ms step_avg:88.90ms +step:1644/1670 train_time:146160ms step_avg:88.91ms +step:1645/1670 train_time:146249ms step_avg:88.91ms +step:1646/1670 train_time:146339ms step_avg:88.91ms +step:1647/1670 train_time:146428ms step_avg:88.91ms +step:1648/1670 train_time:146520ms step_avg:88.91ms +step:1649/1670 train_time:146611ms step_avg:88.91ms +step:1650/1670 train_time:146701ms step_avg:88.91ms +step:1651/1670 train_time:146791ms step_avg:88.91ms +step:1652/1670 train_time:146881ms step_avg:88.91ms +step:1653/1670 train_time:146970ms step_avg:88.91ms +step:1654/1670 train_time:147060ms step_avg:88.91ms +step:1655/1670 train_time:147149ms step_avg:88.91ms +step:1656/1670 train_time:147239ms step_avg:88.91ms 
+step:1657/1670 train_time:147327ms step_avg:88.91ms +step:1658/1670 train_time:147418ms step_avg:88.91ms +step:1659/1670 train_time:147507ms step_avg:88.91ms +step:1660/1670 train_time:147597ms step_avg:88.91ms +step:1661/1670 train_time:147687ms step_avg:88.91ms +step:1662/1670 train_time:147778ms step_avg:88.92ms +step:1663/1670 train_time:147868ms step_avg:88.92ms +step:1664/1670 train_time:147958ms step_avg:88.92ms +step:1665/1670 train_time:148047ms step_avg:88.92ms +step:1666/1670 train_time:148137ms step_avg:88.92ms +step:1667/1670 train_time:148226ms step_avg:88.92ms +step:1668/1670 train_time:148316ms step_avg:88.92ms +step:1669/1670 train_time:148405ms step_avg:88.92ms +step:1670/1670 train_time:148496ms step_avg:88.92ms +step:1670/1670 val_loss:3.2792 train_time:148587ms step_avg:88.97ms +peak memory allocated: 30511 MiB reserved: 45934 MiB diff --git a/train_gpt.py b/train_gpt.py index 95d474470..ae853d017 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -26,7 +26,7 @@ # torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min import triton import triton.language as tl -from flash_attn_interface import flash_attn_varlen_func +from kernels import get_kernel from torch import Tensor, nn dynamo.config.recompile_limit = 64 @@ -166,7 +166,7 @@ def _pid_to_block( key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], ) @triton.jit -def ns_line_1_kernel( +def XXT_kernel( A_ptr, C_ptr, M, K, a_stride_b, a_stride_r, a_stride_c, @@ -224,7 +224,7 @@ def ns_line_1_kernel( c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) tl.store(c_ptrs_t, output.T, mask=c_mask_t) -def ns_line_1(A: torch.Tensor, out: torch.Tensor): +def XXT(A: torch.Tensor, out: torch.Tensor): """ Launch Triton kernel to compute C = A @ A.T """ @@ -240,7 +240,7 @@ def ns_line_1(A: torch.Tensor, out: torch.Tensor): grid = lambda meta: ( batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), ) - ns_line_1_kernel[grid]( + XXT_kernel[grid]( A_ptr=A, C_ptr=out, M=M, @@ -259,7 +259,7 @@ def ns_line_1(A: torch.Tensor, out: torch.Tensor): key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], ) @triton.jit -def ns_line_2_kernel( +def ba_plus_cAA_kernel( A_ptr, C_ptr, M, a_stride_b, a_stride_r, a_stride_c, @@ -271,8 +271,8 @@ def ns_line_2_kernel( GROUP_SIZE_M: tl.constexpr, LOWER_UPPER: tl.constexpr, ): - # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A - # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + # This is mostly duplicated from XXT_kernel, but also loads and adds a block of A + # Performance is slightly slower than XXT_kernel, so we use two separate kernels pid = tl.program_id(axis=0) batch_idx, m_idx, n_idx = _pid_to_block( pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M @@ -331,7 +331,7 @@ def ns_line_2_kernel( c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) tl.store(c_ptrs_t, output.T, mask=c_mask_t) -def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): +def ba_plus_cAA(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): """ Launch Triton kernel to compute C = alpha * A @ A.T + beta * A """ @@ -348,7 +348,7 @@ def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): grid = lambda meta: ( batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), ) - ns_line_2_kernel[grid]( + 
ba_plus_cAA_kernel[grid]( A_ptr=A, C_ptr=out, M=M, @@ -363,15 +363,28 @@ def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): ) return out +# Computed for num_iters=5, safety_factor=2e-2, cushion=2 +coeffs_list = [ + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323) +] + @torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower -def newton_schulz_triton(G: torch.Tensor): - a, b, c = (3.4445, -4.7750, 2.0315) +def polar_express(G: torch.Tensor): + """ + Polar Express Sign Method: https://arxiv.org/pdf/2505.16932 + by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. + Code adapted from https://github.com/NoahAmsel/PolarExpress/tree/main by @varunneal. + """ X = G.bfloat16() if G.size(-2) > G.size(-1): X = X.mT # Ensure spectral norm is at most 1 - X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + X = X / (X.norm(dim=(-2, -1), keepdim=True) * (1 + 2e-2) + 1e-6) # Allocate buffers X = X.contiguous() @@ -379,13 +392,13 @@ def newton_schulz_triton(G: torch.Tensor): B = torch.empty_like(A) C = torch.empty_like(X) - ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + aX_plus_BX = torch.baddbmm if X.ndim > 2 else torch.addmm - # Perform the NS iterations - for _ in range(5): - ns_line_1(X, out=A) # A = X @ X.mT - ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A - ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + # Perform the iterations + for a, b, c in coeffs_list: + XXT(X, out=A) # A = X @ X.mT + ba_plus_cAA(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + aX_plus_BX(X, B, X, beta=a, out=C) # C = a * X + B @ X X, C = C, X # Swap references to avoid unnecessary copies if G.size(-2) > G.size(-1): @@ -617,10 +630,10 @@ def step(self): d1 = original_shape[1] d2 = original_shape[2] // 4 batched = batched_update_grads.view(batch, d1, d2) - v_chunk = newton_schulz_triton(batched) + v_chunk = polar_express(batched) v_chunk = v_chunk.view(original_shape) else: - v_chunk = newton_schulz_triton(batched_update_grads) + v_chunk = polar_express(batched_update_grads) # Add the computed zeropower update to the parameters in the buffer. # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. @@ -826,6 +839,8 @@ class AttnArgs: sin: torch.Tensor attn_scale: float +flash_attn_interface = get_kernel('varunneal/flash-attention-3').flash_attn_interface + class CausalSelfAttention(nn.Module): def __init__(self, dim: int, head_dim: int, num_heads: int): super().__init__() @@ -873,7 +888,7 @@ def forward(self, x: Tensor, attn_args: AttnArgs): max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng - y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + y = flash_attn_interface.flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) y = y.view(B, T, self.num_heads, self.head_dim) y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) @@ -1022,7 +1037,7 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_sho x = norm(x) logits = self.lm_head(x) # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) - logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits = 30 * torch.sigmoid(logits / 7.5) logits_for_loss = logits.float() if not self.training else logits loss = F.cross_entropy( logits_for_loss.view(-1, logits_for_loss.size(-1)), @@ -1220,7 +1235,7 @@ class Hyperparameters: train_max_seq_len: int = 128 * 16 val_batch_size: int = 4 * 64 * 1024 * 8 # optimization - num_iterations: int = 1640 # number of iterations to run + num_iterations: int = 1630 # number of iterations to run iteration_extension = 40 # number of iterations to continue training at final cooldown and window size cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate # evaluation and logging @@ -1310,7 +1325,7 @@ def nvidia_smi(): eps=1e-8, weight_decay=0.0, ) -optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.06, momentum=0.95, weight_decay=0.0) optimizers = [optimizer1, optimizer2] for opt in optimizers: for group in opt.param_groups:
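Reference note (not part of the diff above): a minimal, un-fused PyTorch sketch of the math that the polar_express routine introduced in this record computes. The coefficient list is copied verbatim from the diff; the function name polar_express_reference is invented here purely for illustration, and the record's actual implementation instead fuses the X @ X.mT and b*A + c*A@A steps into the XXT and ba_plus_cAA Triton kernels and reuses preallocated buffers.

import torch

# Coefficients from the diff, computed for num_iters=5, safety_factor=2e-2, cushion=2
coeffs_list = [
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]

def polar_express_reference(G: torch.Tensor) -> torch.Tensor:
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT  # operate on the wide orientation, as in the record
    # Bound the spectral norm by 1, with the same safety margin used in the diff
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * (1 + 2e-2) + 1e-6)
    for a, b, c in coeffs_list:
        A = X @ X.mT             # what the XXT kernel produces
        B = b * A + c * (A @ A)  # what the ba_plus_cAA kernel produces
        X = a * X + B @ X        # the addmm/baddbmm line: C = a*X + B@X
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X

Applied to each Muon update matrix, the five fixed-coefficient iterations push the singular values of X toward 1, approximating the same orthogonal polar factor that the previous newton_schulz_triton routine targeted with the repeated (3.4445, -4.7750, 2.0315) Newton-Schulz step. Separately, the logits change in the same diff is an algebraic simplification rather than a behavior change: 30 * sigmoid(z / 7.5) = 15 * tanh(z / 15) + 15, via the identity 2*sigmoid(2x) = tanh(x) + 1 noted in the surrounding comment.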