diff --git a/records/092725_BF16CE/08c0770f-17fc-44cd-971d-734a7a28a3e3.txt b/records/092725_BF16CE/08c0770f-17fc-44cd-971d-734a7a28a3e3.txt
new file mode 100644
index 000000000..493a5bb0a
--- /dev/null
+++ b/records/092725_BF16CE/08c0770f-17fc-44cd-971d-734a7a28a3e3.txt
@@ -0,0 +1,3206 @@
+import os
+import sys
+
+with open(sys.argv[0]) as f:
+    code = f.read() # read the code of this file ASAP, for logging
+import copy
+import glob
+import math
+import threading
+import time
+import uuid
+from dataclasses import dataclass
+from itertools import accumulate
+from pathlib import Path
+
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+import torch
+
+torch.empty(
+    1, device="cuda", requires_grad=True
+).backward() # prevents a bug on some systems
+import torch._dynamo as dynamo
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min
+import triton
+import triton.language as tl
+from flash_attn_interface import flash_attn_varlen_func
+from torch import Tensor, nn
+
+dynamo.config.recompile_limit = 64
+
+# -----------------------------------------------------------------------------
+# Custom operators: FP8 matmul by @YouJiacheng
+
+
+@torch.library.custom_op("nanogpt::mm", mutates_args=())
+def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
+    @torch.compile
+    def impl(x: Tensor, w: Tensor):
+        assert x.is_contiguous() and w.is_contiguous()
+        x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
+        w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
+        out = torch._scaled_mm(
+            x_f8,
+            w_f8.T,
+            out_dtype=torch.bfloat16,
+            scale_a=x.new_tensor(x_s, dtype=torch.float32),
+            scale_b=x.new_tensor(w_s, dtype=torch.float32),
+            use_fast_accum=True,
+        )
+        return out, x_f8, w_f8
+
+    return impl(x, w)
+
+@mm_op.register_fake
+def _(x: Tensor, w: Tensor, *_):
+    assert x.ndim == w.ndim == 2
+    assert x.shape[1] == w.shape[1]
+    assert x.device == w.device
+    assert x.is_contiguous() and w.is_contiguous()
+    return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)
+
+@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
+def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]:
+    @torch.compile
+    def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
+        assert grad.is_contiguous()
+        x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
+        w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
+        grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
+        grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
+        grad_x = torch._scaled_mm(
+            grad_f8,
+            w_f8.T.contiguous().T,
+            out_dtype=torch.bfloat16,
+            scale_a=grad_inv_s,
+            scale_b=w_inv_s,
+            use_fast_accum=False,
+        )
+        # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
+        grad_w = torch._scaled_mm(
+            x_f8.T.contiguous(),
+            grad_f8.T.contiguous().T,
+            out_dtype=torch.float32,
+            scale_a=x_inv_s,
+            scale_b=grad_inv_s,
+            use_fast_accum=False,
+        ).T
+        return grad_x, grad_w
+
+    return impl(g, x_f8, w_f8)
+
+@mm_backward_op.register_fake
+def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
+    return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32)
+
+def backward(ctx, grad_out: Tensor, *_):
+    x_f8, w_f8 = ctx.saved_tensors
+    x_s, w_s, grad_s = ctx.scales
+    grad_x, grad_w = torch.ops.nanogpt.mm_backward(
+        grad_out, x_f8, w_f8, x_s, w_s, grad_s
+    )
+    return grad_x, grad_w, None, None, None
+
+def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
+    *_, x_s, w_s, grad_s = inputs
+    _, x_f8, w_f8 = output
+    ctx.save_for_backward(x_f8, w_f8)
+    ctx.scales = x_s, w_s, grad_s
+    ctx.set_materialize_grads(False)
+
+mm_op.register_autograd(backward, setup_context=setup_context)
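+
+# Illustrative sketch (added commentary, not part of the original record; never
+# called during training): how the custom op pair above is exercised. The forward
+# quantizes both operands to float8_e4m3fn; the backward requantizes the incoming
+# gradient to float8_e5m2 and returns None for the three scale arguments.
+def _demo_fp8_mm():
+    x = torch.randn(128, 768, device="cuda", dtype=torch.bfloat16, requires_grad=True)
+    w = torch.randn(50304, 768, device="cuda", dtype=torch.bfloat16)
+    out, x_f8, w_f8 = torch.ops.nanogpt.mm(x, w, 1.0, 1.0, 1.0)
+    assert out.dtype == torch.bfloat16 and x_f8.dtype == torch.float8_e4m3fn
+    out.float().sum().backward() # routes through mm_backward_op
+    assert x.grad is not None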
+
+# -----------------------------------------------------------------------------
+# Triton kernel for symmetric matrix multiplication by @byronxu99
+
+def _get_autotune_configs():
+    return [
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": bm,
+                "BLOCK_SIZE_N": bn,
+                "BLOCK_SIZE_K": bk,
+                "GROUP_SIZE_M": 8,
+                "LOWER_UPPER": 1,
+            },
+            num_stages=stages,
+            num_warps=warps,
+        )
+        for bm in [64, 128]
+        for bn in [64, 128, 256]
+        for bk in [64, 128]
+        for stages, warps in [(3, 4), (3, 8), (4, 4)]
+        if bm // bn <= 2 and bn // bm <= 2
+    ]
+
+@triton.jit
+def _pid_to_block(
+    pid,
+    M,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+):
+    # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(M, BLOCK_SIZE_N)
+
+    # Map PID to a single matrix in batch
+    batch_idx = pid // (num_pid_m * num_pid_n)
+    pid = pid % (num_pid_m * num_pid_n)
+
+    # Map PID to 2D grid of blocks
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
+
+    m_idx = pid_m * BLOCK_SIZE_M
+    n_idx = pid_n * BLOCK_SIZE_N
+    return batch_idx, m_idx, n_idx
+
+@triton.autotune(
+    configs=_get_autotune_configs(),
+    key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
+)
+@triton.jit
+def ns_line_1_kernel(
+    A_ptr, C_ptr,
+    M, K,
+    a_stride_b, a_stride_r, a_stride_c,
+    c_stride_b, c_stride_r, c_stride_c,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    LOWER_UPPER: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    batch_idx, m_idx, n_idx = _pid_to_block(
+        pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
+    )
+
+    # Skip blocks that don't need to be computed
+    skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx)
+    skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx)
+    if skip_block_below_diag or skip_block_above_diag:
+        return
+
+    # Index into one matrix of batch
+    A_ptr += batch_idx * a_stride_b
+    C_ptr += batch_idx * c_stride_b
+
+    # Create pointer arrays for A and A.T
+    offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c)
+    at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r)
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+    # Accumulate over blocks of K
+    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        accumulator = tl.dot(a, at, accumulator)
+        a_ptrs += BLOCK_SIZE_K * a_stride_c
+        at_ptrs += BLOCK_SIZE_K * a_stride_c
+
+    out_dtype = C_ptr.dtype.element_ty
+    output = accumulator.to(out_dtype)
+
+    # Store block of C
+    offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c)
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
+    tl.store(c_ptrs, output, mask=c_mask)
+
+    # Store block of C mirrored across the diagonal
+    c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c)
+    c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
+    tl.store(c_ptrs_t, output.T, mask=c_mask_t)
+
+def ns_line_1(A: torch.Tensor, out: torch.Tensor):
+    """
+    Launch Triton kernel to compute C = A @ A.T
+    """
+    assert A.ndim == 2 or A.ndim == 3
+    M, K = A.shape[-2:]
+    assert out.size(-2) == M, "Output matrix has incorrect shape"
+    assert out.size(-1) == M, "Output matrix has incorrect shape"
+
+    batch_size = A.size(0) if A.ndim == 3 else 1
+    input_batch_stride = A.stride(0) if A.ndim == 3 else 0
+    output_batch_stride = out.stride(0) if out.ndim == 3 else 0
+
+    grid = lambda meta: (
+        batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]),
+    )
+    ns_line_1_kernel[grid](
+        A_ptr=A,
+        C_ptr=out,
+        M=M,
+        K=K,
+        a_stride_b=input_batch_stride,
+        a_stride_r=A.stride(-2),
+        a_stride_c=A.stride(-1),
+        c_stride_b=output_batch_stride,
+        c_stride_r=out.stride(-2),
+        c_stride_c=out.stride(-1),
+    )
+    return out
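+
+# Illustrative sketch (added commentary, not part of the original record; never
+# called during training): ns_line_1 should agree with a plain A @ A.mT up to
+# bf16 rounding; only diagonal-and-above blocks are computed, the rest mirrored.
+def _demo_ns_line_1():
+    A = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
+    C = torch.empty(256, 256, device="cuda", dtype=torch.bfloat16)
+    ns_line_1(A, out=C)
+    torch.testing.assert_close(C, A @ A.mT, rtol=0.05, atol=0.5)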
+
+@triton.autotune(
+    configs=_get_autotune_configs(),
+    key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
+)
+@triton.jit
+def ns_line_2_kernel(
+    A_ptr, C_ptr,
+    M,
+    a_stride_b, a_stride_r, a_stride_c,
+    c_stride_b, c_stride_r, c_stride_c,
+    alpha, beta,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    LOWER_UPPER: tl.constexpr,
+):
+    # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A
+    # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels
+    pid = tl.program_id(axis=0)
+    batch_idx, m_idx, n_idx = _pid_to_block(
+        pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
+    )
+
+    # Skip blocks that don't need to be computed
+    skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx)
+    skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx)
+    if skip_block_below_diag or skip_block_above_diag:
+        return
+
+    # Index into one matrix of batch
+    A_ptr += batch_idx * a_stride_b
+    C_ptr += batch_idx * c_stride_b
+
+    # Create pointer arrays for A and A.T
+    offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c)
+    at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r)
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+    # Accumulate over blocks of K
+    for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0)
+        at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0)
+        accumulator = tl.dot(a, at, accumulator)
+        a_ptrs += BLOCK_SIZE_K * a_stride_c
+        at_ptrs += BLOCK_SIZE_K * a_stride_c
+
+    # Load block of A to add (corresponds to the current block of C)
+    offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M)
+    offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N)
+    a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c)
+    a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M)
+    a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32)
+
+    # Apply alpha and beta
+    accumulator *= alpha
+    accumulator += a_add * beta
+
+    out_dtype = C_ptr.dtype.element_ty
+    output = accumulator.to(out_dtype)
+
+    # Store block of C
+    offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c)
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
+    tl.store(c_ptrs, output, mask=c_mask)
+
+    # Store block of C mirrored across the diagonal
+    c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c)
+    c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
+    tl.store(c_ptrs_t, output.T, mask=c_mask_t)
+
+def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor):
+    """
+    Launch Triton kernel to compute C = alpha * A @ A.T + beta * A
+    """
+    assert A.ndim == 2 or A.ndim == 3
+    M, K = A.shape[-2:]
+    assert M == K, "Input matrix must be square"
+    assert out.size(-2) == M
+    assert out.size(-1) == M
+
+    batch_size = A.size(0) if A.ndim == 3 else 1
+    input_batch_stride = A.stride(0) if A.ndim == 3 else 0
+    output_batch_stride = out.stride(0) if out.ndim == 3 else 0
+
+    grid = lambda meta: (
+        batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]),
+    )
+    ns_line_2_kernel[grid](
+        A_ptr=A,
+        C_ptr=out,
+        M=M,
+        a_stride_b=input_batch_stride,
+        a_stride_r=A.stride(-2),
+        a_stride_c=A.stride(-1),
+        c_stride_b=output_batch_stride,
+        c_stride_r=out.stride(-2),
+        c_stride_c=out.stride(-1),
+        alpha=alpha,
+        beta=beta,
+    )
+    return out
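+
+# Illustrative sketch (added commentary, not part of the original record; never
+# called during training): the fused kernel above should agree with its unfused
+# PyTorch equivalent.
+def _demo_ns_line_2():
+    A = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)
+    C = torch.empty_like(A)
+    ns_line_2(A, alpha=2.0, beta=-3.0, out=C)
+    torch.testing.assert_close(C, 2.0 * A @ A.mT - 3.0 * A, rtol=0.05, atol=0.5)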
+
+@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower
+def newton_schulz_triton(G: torch.Tensor):
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+
+    # Ensure spectral norm is at most 1
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
+
+    # Allocate buffers
+    X = X.contiguous()
+    A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype)
+    B = torch.empty_like(A)
+    C = torch.empty_like(X)
+
+    ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm
+
+    # Perform the NS iterations
+    for _ in range(5):
+        ns_line_1(X, out=A) # A = X @ X.mT
+        ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A
+        ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X
+        X, C = C, X # Swap references to avoid unnecessary copies
+
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+    return X
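+
+# Illustrative sketch (added commentary, not part of the original record; never
+# called during training): the loop above evaluates the quintic
+# X <- a*X + b*(X X^T)X + c*(X X^T)^2 X five times, the same computation as this
+# plain PyTorch reference (shown for wide/square G; the Frobenius-norm divide
+# bounds the spectral norm by 1).
+def _demo_newton_schulz_reference(G: torch.Tensor):
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
+    for _ in range(5):
+        A = X @ X.mT
+        B = b * A + c * A @ A
+        X = a * X + B @ X
+    return X # approximately the orthogonal factor U V^T of G's SVD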
+
+# -----------------------------------------------------------------------------
+# Muon optimizer
+
+class Muon(torch.optim.Optimizer):
+    """
+    Muon - MomentUm Orthogonalized by Newton-Schulz
+
+    https://kellerjordan.github.io/posts/muon/
+
+    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
+    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
+    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
+    the advantage that it can be stably run in bfloat16 on the GPU.
+
+    Warning: This optimizer should not be used for the embedding layer, the final fully connected layer,
+    or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
+    Empirically, though, small 1D params do perform efficiently here: NS approximately performs a
+    magnitude normalization of the grad, and this hyper-optimized class has faster execution time
+    than the current impl of Adam for small params.
+
+    Custom distributed sizing:
+    The model stores all attn and mlp weights in the same shape, and then updates the view as
+    needed on the forward pass. This enables attn and mlp weights to be contained within the same
+    dist.reduce_scatter_tensor() call. The model architecture has been customized to enable
+    (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn.
+    The scheduling is:
+    1. reduce scatter smear_gate (1 param, 7 padding params)
+    2. reduce scatter attn_gate (10 params, 6 padding params)
+    3. reduce scatter attn/mlp round 1 (10 attn params, 6 mlp params)
+    4. reduce scatter attn/mlp round 2 (16 mlp params)
+    5. wait on step 1, then compute NS of 1 and schedule all gather
+    6. wait on step 2, then compute NS of 2 and schedule all gather
+    7. wait on step 3, then compute NS of 3 and schedule all gather
+       GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP]
+       GPUs that receive params of type attn reshape before NS
+    8. wait on 4, then compute NS of 4 and schedule all gather
+    9. wait for each all gather to complete and update params
+    Empirically, leading with small params provides an additional 0.2s improvement.
+    """
+    def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
+        # custom sizing requires 8 GPUs
+        if custom_sizing and dist.get_world_size() == 8:
+            param_groups = self.generate_custom_param_groups(params)
+        else:
+            param_groups = self.generate_standard_param_groups(params)
+        super().__init__(param_groups, defaults)
+
+    def generate_standard_param_groups(self, params):
+        """
+        Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules.
+        Creates one param group per size, while giving attn its own param group for the resize op.
+        """
+        params = list(params)
+        param_groups = []
+        attn_subset = [p for p in params if p.module == 'attn']
+        non_attn_subset = [p for p in params if p.module != 'attn']
+        param_groups.append(dict(params=attn_subset))
+
+        sizes = {p.shape for p in non_attn_subset}
+        for size in sizes:
+            group_params = [p for p in non_attn_subset if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        return param_groups
+
+    def generate_custom_param_groups(self, params):
+        """
+        Implementation requires that a single GPU does not receive both attn
+        and mlp params when a param group is split across GPUs.
+        """
+        module_ranks = {
+            'smear_gate': 1, # 1 param
+            'attn_gate': 2,  # 10 params
+            'attn': 3,       # 10 params
+            'mlp': 4,        # 22 params
+        }
+        params = list(params)
+        params.sort(key=lambda x: module_ranks.get(x.module))
+        idx = 0
+        group_sizes = [1, 10, 16, 16]
+        assert len(params) == sum(group_sizes)
+        param_groups = []
+        for size in group_sizes:
+            group_params = params[idx:idx + size]
+            param_groups.append(dict(params=group_params))
+            idx += size
+        return param_groups
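+
+    # Illustrative note (added commentary, not part of the original record): with
+    # the four custom groups above, padding each group's param count up to a
+    # multiple of world_size=8 gives group sizes [1, 10, 16, 16] -> padded
+    # [8, 16, 16, 16] -> chunk sizes [1, 2, 2, 2] params per rank, matching the
+    # 7 and 6 padding params listed in the docstring schedule.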
+
+    @torch.no_grad()
+    def step(self):
+        # Efficient systems-wise implementation of step developed by @YouJiacheng,
+        # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad,
+        # @ryanyang0, and @vagrawal.
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0] # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *p_example.shape),
+                dtype=p_example.dtype,
+                device=p_example.device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: Fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input if it must be padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(p_example.grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[
+                    i
+                ] # This gradient corresponds to the current parameter p.
+                state = self.state[p]
+
+                # Initialize momentum buffer if not present
+                if not state:
+                    state["momentum_buffer"] = torch.zeros_like(grad)
+
+                momentum_buffer = state["momentum_buffer"]
+
+                # Apply momentum update directly to the persistent momentum buffer in-place.
+                momentum_buffer.lerp_(grad, 1 - group["momentum"])
+
+                # Compute the actual `update_grad` for zeropower. This creates a new tensor.
+                update_grad = grad.lerp(momentum_buffer, group["momentum"])
+                update_grads_for_zeropower.append(update_grad)
+
+                # Copy the current parameter value into the temporary buffer.
+                updated_param_chunk[i].copy_(p)
+
+                # Apply weight decay directly to the buffer.
+                updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+            original_shape = batched_update_grads.shape
+            # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O
+            module_idx = start_idx if start_idx < len(params) else 0
+            if params[module_idx].module == 'attn':
+                batched_update_grads = batched_update_grads.view(-1, original_shape[-2], original_shape[-1] // 4)
+            v_chunk = newton_schulz_triton(batched_update_grads).view(original_shape)
+
+            # Apply the batched updates back to this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+                if param_idx >= len(params): # Skip padded entries again.
+                    continue
+
+                # Add the computed zeropower update to the parameter in the buffer.
+                updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val)
+
+            stacked_params = torch.empty(
+                (info["padded_num_params"], *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+            gather_future = dist.all_gather_into_tensor(
+                stacked_params, updated_param_chunk, async_op=True
+            ).get_future()
+
+            all_gather_infos.append(
+                {
+                    "gather_future": gather_future,
+                    "stacked_params": stacked_params,
+                    "orig_params": params,
+                }
+            )
+
+        # Final pass: wait for all_gather to complete and copy results back into original parameter tensors.
+        for info in all_gather_infos:
+            info["gather_future"].wait()
+            stacked_params = info["stacked_params"]
+            orig_params = info["orig_params"]
+
+            unstacked_params = torch.unbind(stacked_params)
+            for i, p in enumerate(orig_params):
+                p.copy_(unstacked_params[i], non_blocking=True)
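+
+# Illustrative sketch (added commentary, not part of the original record; never
+# called during training): the attn reshape above relies on qkvo_w being stored
+# as (hdim, 4*dim) with a layout such that a plain view recovers the four
+# (hdim, dim) matrices Q, K, V, O (the model uses the same view in its forward).
+def _demo_attn_param_view(hdim: int = 768, dim: int = 768):
+    w = torch.randn(hdim, 4 * dim)
+    per_matrix = w.view(4, hdim, dim) # same storage, one square matrix per projection
+    assert per_matrix.shape == (4, hdim, dim)
+    batched = torch.stack([w, w]).view(-1, hdim, dim) # the optimizer's batched variant
+    assert batched.shape == (8, hdim, dim)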
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
+
+def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor):
+    assert cos.size(0) >= x_BTHD.size(-3)
+    cos, sin = (
+        cos[None, : x_BTHD.size(-3), None, :],
+        sin[None, : x_BTHD.size(-3), None, :],
+    )
+    x1, x2 = x_BTHD.chunk(2, dim=-1)
+    y1 = x1 * cos + x2 * sin
+    y2 = x1 * (-sin) + x2 * cos
+    return torch.cat((y1, y2), 3)
+
+@dataclass
+class AttnArgs:
+    ve: torch.Tensor
+    sa_lambdas: torch.Tensor
+    seqlens: torch.Tensor
+    bm_size: int
+    cos: torch.Tensor
+    sin: torch.Tensor
+    attn_scale: float
+
+class CausalSelfAttention(nn.Module):
+    def __init__(self, dim: int, head_dim: int, num_heads: int):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.dim = dim
+        self.hdim = num_heads * head_dim
+
+        assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim"
+        std = 0.5 * (self.dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
+        # https://x.com/hi_tysam/status/1879699187107033311
+        # make matrices the same shape as MLP to enable batched call in optimizer
+        self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4))
+        # label module to enable custom optimizer sizing
+        self.qkvo_w.module = 'attn'
+        with torch.no_grad():
+            self.qkvo_w.view(4, self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights
+            self.qkvo_w.view(4, self.hdim, self.dim)[3].zero_() # init output weights to zero
+
+        # sparse gated attention to enable context based no-op by @classiclarryd
+        self.attn_gate = CastedLinear(12, num_heads)
+        # label module to enable custom optimizer sizing
+        self.attn_gate.weight.module = 'attn_gate'
+        self.attn_gate.weight.detach().zero_()
+
+    def forward(self, x: Tensor, attn_args: AttnArgs):
+        B, T = x.size(0), x.size(1) # batch size, sequence length
+        assert B == 1, "varlen sequences require B == 1"
+        assert T % 16 == 0
+        # unpack attention args
+        cos, sin = attn_args.cos, attn_args.sin
+        ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas
+        seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size
+
+        q, k, v = F.linear(x, self.qkvo_w.view(4, self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
+        q, k = norm(q), norm(k) # QK norm @Grad62304977
+        q, k = rotary(q, cos, sin), rotary(k, cos, sin)
+        if ve is not None:
+            v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
+        else: # skip mid-layers token value embeddings by @YouJiacheng
+            v = sa_lambdas[0] * v
+
+        max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size))
+
+        # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng
+        y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len,
+                                   causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0))
+        y = y.view(B, T, self.num_heads, self.head_dim)
+        y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1)
+        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
+        y = F.linear(y, self.qkvo_w.view(4, self.hdim, self.dim)[3].type_as(y))
+        return y
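+
+# Illustrative sketch (added commentary, not part of the original record; never
+# called during training): rotary() applies a 2D rotation to each (x1, x2) pair,
+# so it preserves the norm of q/k; the zeroed second half of the frequencies
+# (half-truncated RoPE) acts as an identity on those dimensions.
+def _demo_rotary_preserves_norm():
+    yarn = Yarn(head_dim=128, max_seq_len=256)
+    x = torch.randn(1, 256, 6, 128, device=device, dtype=torch.bfloat16)
+    y = rotary(x, yarn.cos, yarn.sin)
+    torch.testing.assert_close(y.norm(), x.norm(), rtol=0.02, atol=0.1)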
+
+class MLP(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        hdim = 4 * dim
+        # make matrices the same shape to enable batched call in optimizer
+        self.c_fc = nn.Parameter(torch.empty(dim, hdim))
+        self.c_proj = nn.Parameter(torch.empty(dim, hdim))
+        # label modules to enable custom optimizer sizing
+        self.c_fc.module = 'mlp'
+        self.c_proj.module = 'mlp'
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        with torch.no_grad():
+            self.c_fc.uniform_(-bound, bound)
+            self.c_proj.zero_() # zero init suggested by @Grad62304977
+
+    def forward(self, x: Tensor):
+        x = F.linear(x, self.c_fc.T.type_as(x))
+        x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
+        x = F.linear(x, self.c_proj.type_as(x))
+        return x
+
+class Block(nn.Module):
+    def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int):
+        super().__init__()
+        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
+        self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None
+        # skip MLP blocks for first MLP layer by @EmelyanenkoK
+        self.mlp = MLP(dim) if layer_idx != 0 else None
+
+    def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs):
+        x = lambdas[0] * x + lambdas[1] * x0
+        if self.attn is not None:
+            x = x + self.attn(norm(x), attn_args)
+        if self.mlp is not None:
+            x = x + self.mlp(norm(x))
+        return x
+
+# -----------------------------------------------------------------------------
+# The main model
+
+def next_multiple_of_n(v: float | int, *, n: int):
+    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
+
+class GPT(nn.Module):
+    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int):
+        super().__init__()
+        vocab_size = next_multiple_of_n(vocab_size, n=128)
+        self.embed = nn.Embedding(vocab_size, model_dim)
+        self.smear_gate = CastedLinear(12, 1)
+        self.smear_gate.weight.detach().zero_()
+        # label modules to enable custom optimizer sizing
+        self.smear_gate.weight.module = 'smear_gate'
+        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
+        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
+        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
+        self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)])
+        self.yarn = Yarn(head_dim, max_seq_len)
+        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
+        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
+        use_fp8 = not os.environ.get("DISABLE_FP8", False)
+        self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448)
+        self.lm_head.weight.detach().zero_() # @Grad62304977
+        # Add learnable skip connection weights for decoder layers
+        assert num_layers % 2 == 0
+        pad = (-num_layers * 6) % dist.get_world_size()
+        self.scalars = nn.Parameter(
+            torch.cat(
+                [
+                    -1.5
+                    * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18
+                    *[
+                        torch.tensor([1.0, 0.0]) for _ in range(num_layers)
+                    ], # block lambdas
+                    *[
+                        torch.tensor([0.5, 0.5]) for _ in range(num_layers)
+                    ], # SA lambdas
+                    torch.zeros(num_layers), # extra zeros params for smear_lambda
+                    torch.ones(pad),
+                ]
+            )
+        )
+        # set learning rates
+        for param in self.embed.parameters():
+            param.lr_mul = 75.
+        for param in self.value_embeds.parameters():
+            param.lr_mul = 75.
+        self.lm_head.weight.lr_mul = 1.0
+        self.scalars.lr_mul = 5.0
+
+    def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int):
+        assert input_seq.ndim == 1
+
+        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
+        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        short_bm = ws_short * args.block_size
+        long_bm = ws_long * args.block_size
+        bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = self.embed(input_seq)
+
+        # smear token embed forward 1 position @classiclarryd
+        smear_lambda = self.scalars[5 * len(self.blocks)]
+        smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)]))
+        x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]])
+        x = x0 = norm(x[None])
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        # skip layer zero
+        for i in range(1, len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                cos=self.yarn.cos,
+                sin=self.yarn.sin,
+                attn_scale=self.yarn.attn_scale
+            )
+            if i >= n and i < 11:
+                gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1)
+                x = x + gate * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x)
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0)
+        logits_for_loss = logits.float() if not self.training else logits
+        loss = F.cross_entropy(
+            logits_for_loss.view(-1, logits_for_loss.size(-1)),
+            target_seq,
+            reduction="sum" if self.training else "mean",
+        )
+        return loss
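+
+# Illustrative sketch (added commentary, not part of the original record; never
+# called during training): the softcap above is an offset tanh cap, via the
+# identity 2*sigmoid(2*x) = tanh(x) + 1, so
+# 30*sigmoid(logits/7.5) == 15*tanh(logits/15) + 15, keeping logits in (0, 30).
+def _demo_softcap_identity():
+    x = torch.linspace(-100, 100, steps=201)
+    torch.testing.assert_close(30 * torch.sigmoid(x / 7.5), 15 * (torch.tanh(x / 15) + 1), rtol=1e-5, atol=1e-5)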
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
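+
+# Illustrative sketch (added commentary, not part of the original record; never
+# called during training): the shard layout _load_data_shard expects is a
+# 256-int32 header (magic, version, token count) followed by raw uint16 GPT-2
+# token ids. The `path` argument here is hypothetical.
+def _demo_write_shard(path: str, tokens: torch.Tensor):
+    header = torch.zeros(256, dtype=torch.int32)
+    header[0], header[1], header[2] = 20240520, 1, tokens.numel()
+    with open(path, "wb") as f:
+        f.write(header.numpy().tobytes())
+        f.write(tokens.to(torch.uint16).numpy().tobytes())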
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan first 4 million tokens, then kick off async thread to scan rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading next shard and indexing bos tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document was too long to account for _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local, )
+            _targets = buf[1:].view(num_tokens_local, )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
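+
+# Illustrative sketch (added commentary, not part of the original record; never
+# called during training): the generator protocol above: next() yields
+# (inputs, targets, cu_seqlens) already moved to GPU, and .send() can re-size
+# batches mid-run. Values shown are this run's defaults.
+def _demo_loader_protocol():
+    loader = distributed_data_generator(args.train_files, args.train_batch_size,
+                                        args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+    inputs, targets, cu_seqlens = next(loader)
+    assert inputs.dtype == torch.int32 and targets.dtype == torch.int64
+    # loader.send((new_num_tokens, new_max_seq_len, new_grad_accum_steps)) would resize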
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40 # number of iterations to continue training at final cooldown and window size
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd
+    ws_long_validate: int = 20 # extend long windows out even further
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, (nn.Embedding, nn.Linear)):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+gate_params = [p for n, p in model.named_parameters() if "gate" in n]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = min(0.9999, step / args.num_iterations)
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+def get_ws(step: int):
+    if step == args.num_iterations + args.iteration_extension:
+        return args.ws_validate // 2, args.ws_validate
+    x = min(step / (1 + args.num_iterations), 0.9999)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws_long = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws_long > ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    elif new_ws_long < ws_long:
+        model.yarn.reset()
+        ws_long = new_ws_long
+    ws_short = ws_long // 2
+    model(inputs, targets, cum_seqlens, ws_short, ws_long).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+# re-initialize the model and optimizer states after warmup
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+model.yarn.reset()
+del initial_state
+
+########################################
+#        Training and validation       #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+train_steps = args.num_iterations + args.iteration_extension
+ws_short, ws_long = get_ws(0)
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    new_ws_short, new_ws_long = get_ws(step)
+    if new_ws_long > ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+    ws_short, ws_long = new_ws_short, new_ws_long
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        if last_step:
+            ws_long = args.ws_long_validate
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = torch.zeros((), device=device, dtype=torch.float32)
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws_short, ws_long).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1) # momentum warmup for muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+
+====================================================================================================
+Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]
+Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Sat Sep 27 13:07:16 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
+|-----------------------------------------+------------------------+----------------------+
+| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
+| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
+|                                         |                        |               MIG M. |
+|=========================================+========================+======================|
+|   0  NVIDIA H100 80GB HBM3          On  |   00000000:19:00.0 Off |                    0 |
+| N/A   27C    P0            121W /  700W |    5856MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   1  NVIDIA H100 80GB HBM3          On  |   00000000:3B:00.0 Off |                    0 |
+| N/A   25C    P0            118W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   2  NVIDIA H100 80GB HBM3          On  |   00000000:4C:00.0 Off |                    0 |
+| N/A   22C    P0            115W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   3  NVIDIA H100 80GB HBM3          On  |   00000000:5D:00.0 Off |                    0 |
+| N/A   26C    P0            121W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   4  NVIDIA H100 80GB HBM3          On  |   00000000:9B:00.0 Off |                    0 |
+| N/A   27C    P0            122W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   5  NVIDIA H100 80GB HBM3          On  |   00000000:BB:00.0 Off |                    0 |
+| N/A   25C    P0            114W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   6  NVIDIA H100 80GB HBM3          On  |   00000000:CB:00.0 Off |                    0 |
+| N/A   28C    P0            120W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   7  NVIDIA H100 80GB HBM3          On  |   00000000:DB:00.0 Off |                    0 |
+| N/A   24C    P0            120W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+
++-----------------------------------------------------------------------------------------+
+| Processes:                                                                              |
+|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
+|        ID   ID                                                               Usage      |
+|=========================================================================================|
+|    0   N/A  N/A          168906      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          168907      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          168908      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          168909      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          168910      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          168911      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          168912      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          168913      C   /usr/bin/python                           0MiB |
+|    1   N/A  N/A          168907      C   /usr/bin/python                           0MiB |
+|    2   N/A  N/A          168908      C   /usr/bin/python                           0MiB |
+|    3   N/A  N/A          168909      C   /usr/bin/python                           0MiB |
+|    4   N/A  N/A          168910      C   /usr/bin/python                           0MiB |
+|    5   N/A  N/A          168911      C   /usr/bin/python                           0MiB |
+|    6   N/A  N/A          168912      C   /usr/bin/python                           0MiB |
+|    7   N/A  N/A          168913      C   /usr/bin/python                           0MiB |
++-----------------------------------------------------------------------------------------+
+
+====================================================================================================
+step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.02ms
+step:1/1680 train_time:145ms step_avg:145.22ms
+step:2/1680 train_time:166ms step_avg:82.87ms
+step:3/1680 train_time:229ms step_avg:76.30ms
+step:4/1680 train_time:314ms step_avg:78.41ms
+step:5/1680 train_time:400ms step_avg:79.90ms
+step:6/1680 train_time:486ms step_avg:80.96ms
+step:7/1680 train_time:572ms step_avg:81.69ms
+step:8/1680 train_time:659ms step_avg:82.43ms
+step:9/1680 train_time:746ms step_avg:82.88ms
+step:10/1680 train_time:833ms step_avg:83.25ms
+step:11/1680 train_time:919ms step_avg:83.53ms
+step:12/1680 train_time:1006ms step_avg:83.85ms
+step:13/1680 train_time:1096ms step_avg:84.33ms
+step:14/1680 train_time:1187ms step_avg:84.82ms
+step:15/1680 train_time:1275ms step_avg:85.02ms
+step:16/1680 train_time:1362ms step_avg:85.14ms
+step:17/1680 train_time:1450ms step_avg:85.27ms
+step:18/1680 train_time:1536ms step_avg:85.33ms
+step:19/1680 train_time:1622ms step_avg:85.35ms
+step:20/1680 train_time:1709ms step_avg:85.43ms
+step:21/1680 train_time:1795ms step_avg:85.48ms
+step:22/1680 train_time:1882ms step_avg:85.54ms
+step:23/1680 train_time:1969ms step_avg:85.62ms
+step:24/1680 train_time:2057ms step_avg:85.72ms
+step:25/1680 train_time:2145ms step_avg:85.81ms
+step:26/1680 train_time:2233ms step_avg:85.89ms
+step:27/1680 train_time:2321ms step_avg:85.95ms
+step:28/1680 train_time:2408ms step_avg:85.99ms
+step:29/1680 train_time:2494ms step_avg:86.01ms
+step:30/1680 train_time:2583ms step_avg:86.09ms
+step:31/1680 train_time:2670ms step_avg:86.12ms
+step:32/1680 train_time:2756ms step_avg:86.13ms
+step:33/1680 train_time:2843ms step_avg:86.15ms
+step:34/1680 train_time:2931ms step_avg:86.20ms
+step:35/1680 train_time:3018ms step_avg:86.24ms
+step:36/1680 train_time:3107ms step_avg:86.30ms
+step:37/1680 train_time:3194ms step_avg:86.33ms
+step:38/1680 train_time:3282ms step_avg:86.36ms
+step:39/1680 train_time:3370ms step_avg:86.41ms
+step:40/1680 train_time:3457ms step_avg:86.42ms
+step:41/1680 train_time:3544ms step_avg:86.44ms
+step:42/1680 train_time:3631ms step_avg:86.46ms
+step:43/1680 train_time:3718ms step_avg:86.46ms
+step:44/1680 train_time:3805ms step_avg:86.47ms
+step:45/1680 train_time:3892ms step_avg:86.50ms
+step:46/1680 train_time:3979ms step_avg:86.51ms
+step:47/1680 train_time:4067ms step_avg:86.53ms
train_time:4067ms step_avg:86.53ms +step:48/1680 train_time:4155ms step_avg:86.56ms +step:49/1680 train_time:4242ms step_avg:86.58ms +step:50/1680 train_time:4331ms step_avg:86.61ms +step:51/1680 train_time:4417ms step_avg:86.61ms +step:52/1680 train_time:4505ms step_avg:86.63ms +step:53/1680 train_time:4592ms step_avg:86.64ms +step:54/1680 train_time:4679ms step_avg:86.65ms +step:55/1680 train_time:4765ms step_avg:86.64ms +step:56/1680 train_time:4853ms step_avg:86.66ms +step:57/1680 train_time:4940ms step_avg:86.66ms +step:58/1680 train_time:5027ms step_avg:86.67ms +step:59/1680 train_time:5114ms step_avg:86.68ms +step:60/1680 train_time:5202ms step_avg:86.70ms +step:61/1680 train_time:5290ms step_avg:86.72ms +step:62/1680 train_time:5377ms step_avg:86.72ms +step:63/1680 train_time:5464ms step_avg:86.72ms +step:64/1680 train_time:5550ms step_avg:86.72ms +step:65/1680 train_time:5637ms step_avg:86.72ms +step:66/1680 train_time:5724ms step_avg:86.73ms +step:67/1680 train_time:5812ms step_avg:86.74ms +step:68/1680 train_time:5899ms step_avg:86.75ms +step:69/1680 train_time:5986ms step_avg:86.75ms +step:70/1680 train_time:6074ms step_avg:86.77ms +step:71/1680 train_time:6161ms step_avg:86.77ms +step:72/1680 train_time:6248ms step_avg:86.77ms +step:73/1680 train_time:6335ms step_avg:86.78ms +step:74/1680 train_time:6422ms step_avg:86.79ms +step:75/1680 train_time:6509ms step_avg:86.79ms +step:76/1680 train_time:6596ms step_avg:86.79ms +step:77/1680 train_time:6683ms step_avg:86.80ms +step:78/1680 train_time:6771ms step_avg:86.81ms +step:79/1680 train_time:6858ms step_avg:86.81ms +step:80/1680 train_time:6945ms step_avg:86.82ms +step:81/1680 train_time:7033ms step_avg:86.82ms +step:82/1680 train_time:7120ms step_avg:86.83ms +step:83/1680 train_time:7207ms step_avg:86.83ms +step:84/1680 train_time:7295ms step_avg:86.84ms +step:85/1680 train_time:7382ms step_avg:86.85ms +step:86/1680 train_time:7469ms step_avg:86.85ms +step:87/1680 train_time:7556ms step_avg:86.85ms +step:88/1680 train_time:7643ms step_avg:86.86ms +step:89/1680 train_time:7732ms step_avg:86.87ms +step:90/1680 train_time:7819ms step_avg:86.88ms +step:91/1680 train_time:7906ms step_avg:86.88ms +step:92/1680 train_time:7993ms step_avg:86.88ms +step:93/1680 train_time:8080ms step_avg:86.89ms +step:94/1680 train_time:8170ms step_avg:86.91ms +step:95/1680 train_time:8255ms step_avg:86.90ms +step:96/1680 train_time:8342ms step_avg:86.90ms +step:97/1680 train_time:8430ms step_avg:86.91ms +step:98/1680 train_time:8517ms step_avg:86.91ms +step:99/1680 train_time:8604ms step_avg:86.91ms +step:100/1680 train_time:8692ms step_avg:86.92ms +step:101/1680 train_time:8779ms step_avg:86.92ms +step:102/1680 train_time:8866ms step_avg:86.92ms +step:103/1680 train_time:8953ms step_avg:86.92ms +step:104/1680 train_time:9040ms step_avg:86.93ms +step:105/1680 train_time:9128ms step_avg:86.94ms +step:106/1680 train_time:9215ms step_avg:86.94ms +step:107/1680 train_time:9302ms step_avg:86.93ms +step:108/1680 train_time:9390ms step_avg:86.94ms +step:109/1680 train_time:9477ms step_avg:86.94ms +step:110/1680 train_time:9564ms step_avg:86.95ms +step:111/1680 train_time:9651ms step_avg:86.95ms +step:112/1680 train_time:9738ms step_avg:86.95ms +step:113/1680 train_time:9825ms step_avg:86.95ms +step:114/1680 train_time:9912ms step_avg:86.95ms +step:115/1680 train_time:9999ms step_avg:86.95ms +step:116/1680 train_time:10086ms step_avg:86.95ms +step:117/1680 train_time:10173ms step_avg:86.95ms +step:118/1680 train_time:10260ms step_avg:86.95ms +step:119/1680 
train_time:10346ms step_avg:86.95ms +step:120/1680 train_time:10434ms step_avg:86.95ms +step:121/1680 train_time:10521ms step_avg:86.95ms +step:122/1680 train_time:10609ms step_avg:86.96ms +step:123/1680 train_time:10696ms step_avg:86.96ms +step:124/1680 train_time:10783ms step_avg:86.96ms +step:125/1680 train_time:10871ms step_avg:86.97ms +step:125/1680 val_loss:4.3173 train_time:10959ms step_avg:87.67ms +step:126/1680 train_time:10982ms step_avg:87.16ms +step:127/1680 train_time:11049ms step_avg:87.00ms +step:128/1680 train_time:11143ms step_avg:87.06ms +step:129/1680 train_time:11233ms step_avg:87.07ms +step:130/1680 train_time:11320ms step_avg:87.08ms +step:131/1680 train_time:11407ms step_avg:87.08ms +step:132/1680 train_time:11493ms step_avg:87.07ms +step:133/1680 train_time:11579ms step_avg:87.06ms +step:134/1680 train_time:11665ms step_avg:87.05ms +step:135/1680 train_time:11751ms step_avg:87.04ms +step:136/1680 train_time:11836ms step_avg:87.03ms +step:137/1680 train_time:11922ms step_avg:87.02ms +step:138/1680 train_time:12009ms step_avg:87.02ms +step:139/1680 train_time:12099ms step_avg:87.04ms +step:140/1680 train_time:12188ms step_avg:87.06ms +step:141/1680 train_time:12276ms step_avg:87.06ms +step:142/1680 train_time:12363ms step_avg:87.07ms +step:143/1680 train_time:12450ms step_avg:87.07ms +step:144/1680 train_time:12537ms step_avg:87.06ms +step:145/1680 train_time:12623ms step_avg:87.06ms +step:146/1680 train_time:12710ms step_avg:87.05ms +step:147/1680 train_time:12796ms step_avg:87.05ms +step:148/1680 train_time:12882ms step_avg:87.04ms +step:149/1680 train_time:12968ms step_avg:87.04ms +step:150/1680 train_time:13056ms step_avg:87.04ms +step:151/1680 train_time:13144ms step_avg:87.05ms +step:152/1680 train_time:13232ms step_avg:87.05ms +step:153/1680 train_time:13320ms step_avg:87.06ms +step:154/1680 train_time:13408ms step_avg:87.06ms +step:155/1680 train_time:13495ms step_avg:87.06ms +step:156/1680 train_time:13582ms step_avg:87.07ms +step:157/1680 train_time:13669ms step_avg:87.06ms +step:158/1680 train_time:13755ms step_avg:87.06ms +step:159/1680 train_time:13842ms step_avg:87.06ms +step:160/1680 train_time:13928ms step_avg:87.05ms +step:161/1680 train_time:14014ms step_avg:87.04ms +step:162/1680 train_time:14102ms step_avg:87.05ms +step:163/1680 train_time:14189ms step_avg:87.05ms +step:164/1680 train_time:14277ms step_avg:87.05ms +step:165/1680 train_time:14364ms step_avg:87.05ms +step:166/1680 train_time:14451ms step_avg:87.06ms +step:167/1680 train_time:14538ms step_avg:87.05ms +step:168/1680 train_time:14625ms step_avg:87.05ms +step:169/1680 train_time:14712ms step_avg:87.05ms +step:170/1680 train_time:14799ms step_avg:87.05ms +step:171/1680 train_time:14886ms step_avg:87.05ms +step:172/1680 train_time:14972ms step_avg:87.05ms +step:173/1680 train_time:15060ms step_avg:87.05ms +step:174/1680 train_time:15147ms step_avg:87.05ms +step:175/1680 train_time:15234ms step_avg:87.05ms +step:176/1680 train_time:15322ms step_avg:87.06ms +step:177/1680 train_time:15409ms step_avg:87.06ms +step:178/1680 train_time:15497ms step_avg:87.06ms +step:179/1680 train_time:15584ms step_avg:87.06ms +step:180/1680 train_time:15670ms step_avg:87.06ms +step:181/1680 train_time:15757ms step_avg:87.05ms +step:182/1680 train_time:15844ms step_avg:87.05ms +step:183/1680 train_time:15931ms step_avg:87.06ms +step:184/1680 train_time:16020ms step_avg:87.07ms +step:185/1680 train_time:16106ms step_avg:87.06ms +step:186/1680 train_time:16193ms step_avg:87.06ms +step:187/1680 train_time:16280ms 
step_avg:87.06ms +step:188/1680 train_time:16367ms step_avg:87.06ms +step:189/1680 train_time:16454ms step_avg:87.06ms +step:190/1680 train_time:16541ms step_avg:87.06ms +step:191/1680 train_time:16628ms step_avg:87.06ms +step:192/1680 train_time:16715ms step_avg:87.06ms +step:193/1680 train_time:16802ms step_avg:87.06ms +step:194/1680 train_time:16888ms step_avg:87.05ms +step:195/1680 train_time:16976ms step_avg:87.05ms +step:196/1680 train_time:17063ms step_avg:87.05ms +step:197/1680 train_time:17150ms step_avg:87.06ms +step:198/1680 train_time:17237ms step_avg:87.06ms +step:199/1680 train_time:17324ms step_avg:87.06ms +step:200/1680 train_time:17411ms step_avg:87.06ms +step:201/1680 train_time:17498ms step_avg:87.06ms +step:202/1680 train_time:17586ms step_avg:87.06ms +step:203/1680 train_time:17672ms step_avg:87.05ms +step:204/1680 train_time:17759ms step_avg:87.05ms +step:205/1680 train_time:17846ms step_avg:87.05ms +step:206/1680 train_time:17933ms step_avg:87.05ms +step:207/1680 train_time:18020ms step_avg:87.05ms +step:208/1680 train_time:18108ms step_avg:87.06ms +step:209/1680 train_time:18195ms step_avg:87.06ms +step:210/1680 train_time:18282ms step_avg:87.06ms +step:211/1680 train_time:18369ms step_avg:87.06ms +step:212/1680 train_time:18457ms step_avg:87.06ms +step:213/1680 train_time:18544ms step_avg:87.06ms +step:214/1680 train_time:18631ms step_avg:87.06ms +step:215/1680 train_time:18718ms step_avg:87.06ms +step:216/1680 train_time:18805ms step_avg:87.06ms +step:217/1680 train_time:18893ms step_avg:87.06ms +step:218/1680 train_time:18981ms step_avg:87.07ms +step:219/1680 train_time:19068ms step_avg:87.07ms +step:220/1680 train_time:19155ms step_avg:87.07ms +step:221/1680 train_time:19241ms step_avg:87.06ms +step:222/1680 train_time:19328ms step_avg:87.06ms +step:223/1680 train_time:19416ms step_avg:87.07ms +step:224/1680 train_time:19503ms step_avg:87.07ms +step:225/1680 train_time:19590ms step_avg:87.07ms +step:226/1680 train_time:19677ms step_avg:87.07ms +step:227/1680 train_time:19764ms step_avg:87.07ms +step:228/1680 train_time:19851ms step_avg:87.07ms +step:229/1680 train_time:19938ms step_avg:87.07ms +step:230/1680 train_time:20025ms step_avg:87.07ms +step:231/1680 train_time:20112ms step_avg:87.07ms +step:232/1680 train_time:20199ms step_avg:87.07ms +step:233/1680 train_time:20286ms step_avg:87.06ms +step:234/1680 train_time:20373ms step_avg:87.06ms +step:235/1680 train_time:20460ms step_avg:87.06ms +step:236/1680 train_time:20547ms step_avg:87.06ms +step:237/1680 train_time:20635ms step_avg:87.07ms +step:238/1680 train_time:20722ms step_avg:87.07ms +step:239/1680 train_time:20808ms step_avg:87.06ms +step:240/1680 train_time:20897ms step_avg:87.07ms +step:241/1680 train_time:20984ms step_avg:87.07ms +step:242/1680 train_time:21071ms step_avg:87.07ms +step:243/1680 train_time:21157ms step_avg:87.07ms +step:244/1680 train_time:21244ms step_avg:87.07ms +step:245/1680 train_time:21331ms step_avg:87.06ms +step:246/1680 train_time:21418ms step_avg:87.06ms +step:247/1680 train_time:21505ms step_avg:87.06ms +step:248/1680 train_time:21592ms step_avg:87.07ms +step:249/1680 train_time:21680ms step_avg:87.07ms +step:250/1680 train_time:21768ms step_avg:87.07ms +step:250/1680 val_loss:3.9760 train_time:21856ms step_avg:87.42ms +step:251/1680 train_time:21878ms step_avg:87.16ms +step:252/1680 train_time:21944ms step_avg:87.08ms +step:253/1680 train_time:22034ms step_avg:87.09ms +step:254/1680 train_time:22123ms step_avg:87.10ms +step:255/1680 train_time:22209ms step_avg:87.09ms 
+step:256/1680 train_time:22295ms step_avg:87.09ms +step:257/1680 train_time:22381ms step_avg:87.09ms +step:258/1680 train_time:22468ms step_avg:87.08ms +step:259/1680 train_time:22554ms step_avg:87.08ms +step:260/1680 train_time:22639ms step_avg:87.07ms +step:261/1680 train_time:22726ms step_avg:87.07ms +step:262/1680 train_time:22815ms step_avg:87.08ms +step:263/1680 train_time:22903ms step_avg:87.08ms +step:264/1680 train_time:22992ms step_avg:87.09ms +step:265/1680 train_time:23080ms step_avg:87.09ms +step:266/1680 train_time:23167ms step_avg:87.09ms +step:267/1680 train_time:23253ms step_avg:87.09ms +step:268/1680 train_time:23341ms step_avg:87.09ms +step:269/1680 train_time:23427ms step_avg:87.09ms +step:270/1680 train_time:23513ms step_avg:87.08ms +step:271/1680 train_time:23599ms step_avg:87.08ms +step:272/1680 train_time:23686ms step_avg:87.08ms +step:273/1680 train_time:23773ms step_avg:87.08ms +step:274/1680 train_time:23860ms step_avg:87.08ms +step:275/1680 train_time:23948ms step_avg:87.09ms +step:276/1680 train_time:24036ms step_avg:87.09ms +step:277/1680 train_time:24123ms step_avg:87.09ms +step:278/1680 train_time:24210ms step_avg:87.09ms +step:279/1680 train_time:24297ms step_avg:87.09ms +step:280/1680 train_time:24384ms step_avg:87.09ms +step:281/1680 train_time:24470ms step_avg:87.08ms +step:282/1680 train_time:24557ms step_avg:87.08ms +step:283/1680 train_time:24644ms step_avg:87.08ms +step:284/1680 train_time:24730ms step_avg:87.08ms +step:285/1680 train_time:24817ms step_avg:87.08ms +step:286/1680 train_time:24905ms step_avg:87.08ms +step:287/1680 train_time:24993ms step_avg:87.08ms +step:288/1680 train_time:25079ms step_avg:87.08ms +step:289/1680 train_time:25167ms step_avg:87.08ms +step:290/1680 train_time:25254ms step_avg:87.08ms +step:291/1680 train_time:25341ms step_avg:87.08ms +step:292/1680 train_time:25429ms step_avg:87.08ms +step:293/1680 train_time:25516ms step_avg:87.09ms +step:294/1680 train_time:25602ms step_avg:87.08ms +step:295/1680 train_time:25689ms step_avg:87.08ms +step:296/1680 train_time:25777ms step_avg:87.08ms +step:297/1680 train_time:25864ms step_avg:87.08ms +step:298/1680 train_time:25951ms step_avg:87.09ms +step:299/1680 train_time:26038ms step_avg:87.09ms +step:300/1680 train_time:26126ms step_avg:87.09ms +step:301/1680 train_time:26212ms step_avg:87.08ms +step:302/1680 train_time:26300ms step_avg:87.09ms +step:303/1680 train_time:26387ms step_avg:87.09ms +step:304/1680 train_time:26474ms step_avg:87.09ms +step:305/1680 train_time:26561ms step_avg:87.08ms +step:306/1680 train_time:26648ms step_avg:87.08ms +step:307/1680 train_time:26734ms step_avg:87.08ms +step:308/1680 train_time:26822ms step_avg:87.08ms +step:309/1680 train_time:26908ms step_avg:87.08ms +step:310/1680 train_time:26995ms step_avg:87.08ms +step:311/1680 train_time:27082ms step_avg:87.08ms +step:312/1680 train_time:27169ms step_avg:87.08ms +step:313/1680 train_time:27256ms step_avg:87.08ms +step:314/1680 train_time:27344ms step_avg:87.08ms +step:315/1680 train_time:27431ms step_avg:87.08ms +step:316/1680 train_time:27518ms step_avg:87.08ms +step:317/1680 train_time:27604ms step_avg:87.08ms +step:318/1680 train_time:27692ms step_avg:87.08ms +step:319/1680 train_time:27778ms step_avg:87.08ms +step:320/1680 train_time:27866ms step_avg:87.08ms +step:321/1680 train_time:27952ms step_avg:87.08ms +step:322/1680 train_time:28040ms step_avg:87.08ms +step:323/1680 train_time:28127ms step_avg:87.08ms +step:324/1680 train_time:28214ms step_avg:87.08ms +step:325/1680 train_time:28301ms 
step_avg:87.08ms +step:326/1680 train_time:28388ms step_avg:87.08ms +step:327/1680 train_time:28476ms step_avg:87.08ms +step:328/1680 train_time:28562ms step_avg:87.08ms +step:329/1680 train_time:28649ms step_avg:87.08ms +step:330/1680 train_time:28735ms step_avg:87.08ms +step:331/1680 train_time:28822ms step_avg:87.08ms +step:332/1680 train_time:28910ms step_avg:87.08ms +step:333/1680 train_time:28997ms step_avg:87.08ms +step:334/1680 train_time:29084ms step_avg:87.08ms +step:335/1680 train_time:29171ms step_avg:87.08ms +step:336/1680 train_time:29258ms step_avg:87.08ms +step:337/1680 train_time:29346ms step_avg:87.08ms +step:338/1680 train_time:29433ms step_avg:87.08ms +step:339/1680 train_time:29520ms step_avg:87.08ms +step:340/1680 train_time:29607ms step_avg:87.08ms +step:341/1680 train_time:29694ms step_avg:87.08ms +step:342/1680 train_time:29780ms step_avg:87.08ms +step:343/1680 train_time:29868ms step_avg:87.08ms +step:344/1680 train_time:29955ms step_avg:87.08ms +step:345/1680 train_time:30042ms step_avg:87.08ms +step:346/1680 train_time:30128ms step_avg:87.08ms +step:347/1680 train_time:30216ms step_avg:87.08ms +step:348/1680 train_time:30302ms step_avg:87.08ms +step:349/1680 train_time:30389ms step_avg:87.08ms +step:350/1680 train_time:30476ms step_avg:87.07ms +step:351/1680 train_time:30563ms step_avg:87.07ms +step:352/1680 train_time:30649ms step_avg:87.07ms +step:353/1680 train_time:30736ms step_avg:87.07ms +step:354/1680 train_time:30824ms step_avg:87.07ms +step:355/1680 train_time:30910ms step_avg:87.07ms +step:356/1680 train_time:30997ms step_avg:87.07ms +step:357/1680 train_time:31084ms step_avg:87.07ms +step:358/1680 train_time:31172ms step_avg:87.07ms +step:359/1680 train_time:31259ms step_avg:87.07ms +step:360/1680 train_time:31346ms step_avg:87.07ms +step:361/1680 train_time:31432ms step_avg:87.07ms +step:362/1680 train_time:31519ms step_avg:87.07ms +step:363/1680 train_time:31606ms step_avg:87.07ms +step:364/1680 train_time:31693ms step_avg:87.07ms +step:365/1680 train_time:31780ms step_avg:87.07ms +step:366/1680 train_time:31867ms step_avg:87.07ms +step:367/1680 train_time:31953ms step_avg:87.07ms +step:368/1680 train_time:32041ms step_avg:87.07ms +step:369/1680 train_time:32128ms step_avg:87.07ms +step:370/1680 train_time:32215ms step_avg:87.07ms +step:371/1680 train_time:32302ms step_avg:87.07ms +step:372/1680 train_time:32389ms step_avg:87.07ms +step:373/1680 train_time:32476ms step_avg:87.07ms +step:374/1680 train_time:32563ms step_avg:87.07ms +step:375/1680 train_time:32650ms step_avg:87.07ms +step:375/1680 val_loss:3.8203 train_time:32739ms step_avg:87.30ms +step:376/1680 train_time:32757ms step_avg:87.12ms +step:377/1680 train_time:32831ms step_avg:87.09ms +step:378/1680 train_time:32921ms step_avg:87.09ms +step:379/1680 train_time:33010ms step_avg:87.10ms +step:380/1680 train_time:33097ms step_avg:87.10ms +step:381/1680 train_time:33183ms step_avg:87.10ms +step:382/1680 train_time:33269ms step_avg:87.09ms +step:383/1680 train_time:33355ms step_avg:87.09ms +step:384/1680 train_time:33441ms step_avg:87.09ms +step:385/1680 train_time:33527ms step_avg:87.08ms +step:386/1680 train_time:33613ms step_avg:87.08ms +step:387/1680 train_time:33701ms step_avg:87.08ms +step:388/1680 train_time:33789ms step_avg:87.09ms +step:389/1680 train_time:33879ms step_avg:87.09ms +step:390/1680 train_time:33967ms step_avg:87.10ms +step:391/1680 train_time:34055ms step_avg:87.10ms +step:392/1680 train_time:34142ms step_avg:87.10ms +step:393/1680 train_time:34230ms step_avg:87.10ms 
+step:394/1680 train_time:34315ms step_avg:87.09ms +step:395/1680 train_time:34401ms step_avg:87.09ms +step:396/1680 train_time:34488ms step_avg:87.09ms +step:397/1680 train_time:34574ms step_avg:87.09ms +step:398/1680 train_time:34660ms step_avg:87.09ms +step:399/1680 train_time:34749ms step_avg:87.09ms +step:400/1680 train_time:34837ms step_avg:87.09ms +step:401/1680 train_time:34926ms step_avg:87.10ms +step:402/1680 train_time:35014ms step_avg:87.10ms +step:403/1680 train_time:35101ms step_avg:87.10ms +step:404/1680 train_time:35187ms step_avg:87.10ms +step:405/1680 train_time:35274ms step_avg:87.10ms +step:406/1680 train_time:35360ms step_avg:87.09ms +step:407/1680 train_time:35446ms step_avg:87.09ms +step:408/1680 train_time:35533ms step_avg:87.09ms +step:409/1680 train_time:35619ms step_avg:87.09ms +step:410/1680 train_time:35706ms step_avg:87.09ms +step:411/1680 train_time:35794ms step_avg:87.09ms +step:412/1680 train_time:35882ms step_avg:87.09ms +step:413/1680 train_time:35969ms step_avg:87.09ms +step:414/1680 train_time:36057ms step_avg:87.09ms +step:415/1680 train_time:36144ms step_avg:87.10ms +step:416/1680 train_time:36231ms step_avg:87.09ms +step:417/1680 train_time:36318ms step_avg:87.09ms +step:418/1680 train_time:36405ms step_avg:87.09ms +step:419/1680 train_time:36492ms step_avg:87.09ms +step:420/1680 train_time:36578ms step_avg:87.09ms +step:421/1680 train_time:36665ms step_avg:87.09ms +step:422/1680 train_time:36752ms step_avg:87.09ms +step:423/1680 train_time:36840ms step_avg:87.09ms +step:424/1680 train_time:36928ms step_avg:87.09ms +step:425/1680 train_time:37016ms step_avg:87.10ms +step:426/1680 train_time:37103ms step_avg:87.10ms +step:427/1680 train_time:37190ms step_avg:87.10ms +step:428/1680 train_time:37277ms step_avg:87.10ms +step:429/1680 train_time:37364ms step_avg:87.10ms +step:430/1680 train_time:37450ms step_avg:87.09ms +step:431/1680 train_time:37537ms step_avg:87.09ms +step:432/1680 train_time:37623ms step_avg:87.09ms +step:433/1680 train_time:37710ms step_avg:87.09ms +step:434/1680 train_time:37798ms step_avg:87.09ms +step:435/1680 train_time:37885ms step_avg:87.09ms +step:436/1680 train_time:37972ms step_avg:87.09ms +step:437/1680 train_time:38060ms step_avg:87.09ms +step:438/1680 train_time:38148ms step_avg:87.10ms +step:439/1680 train_time:38235ms step_avg:87.10ms +step:440/1680 train_time:38322ms step_avg:87.10ms +step:441/1680 train_time:38409ms step_avg:87.10ms +step:442/1680 train_time:38496ms step_avg:87.10ms +step:443/1680 train_time:38584ms step_avg:87.10ms +step:444/1680 train_time:38670ms step_avg:87.09ms +step:445/1680 train_time:38757ms step_avg:87.09ms +step:446/1680 train_time:38844ms step_avg:87.09ms +step:447/1680 train_time:38931ms step_avg:87.09ms +step:448/1680 train_time:39019ms step_avg:87.10ms +step:449/1680 train_time:39106ms step_avg:87.10ms +step:450/1680 train_time:39193ms step_avg:87.10ms +step:451/1680 train_time:39280ms step_avg:87.09ms +step:452/1680 train_time:39367ms step_avg:87.09ms +step:453/1680 train_time:39453ms step_avg:87.09ms +step:454/1680 train_time:39541ms step_avg:87.09ms +step:455/1680 train_time:39628ms step_avg:87.09ms +step:456/1680 train_time:39714ms step_avg:87.09ms +step:457/1680 train_time:39802ms step_avg:87.09ms +step:458/1680 train_time:39889ms step_avg:87.09ms +step:459/1680 train_time:39976ms step_avg:87.09ms +step:460/1680 train_time:40063ms step_avg:87.09ms +step:461/1680 train_time:40150ms step_avg:87.09ms +step:462/1680 train_time:40237ms step_avg:87.09ms +step:463/1680 train_time:40324ms 
step_avg:87.09ms +step:464/1680 train_time:40411ms step_avg:87.09ms +step:465/1680 train_time:40499ms step_avg:87.09ms +step:466/1680 train_time:40585ms step_avg:87.09ms +step:467/1680 train_time:40672ms step_avg:87.09ms +step:468/1680 train_time:40759ms step_avg:87.09ms +step:469/1680 train_time:40847ms step_avg:87.09ms +step:470/1680 train_time:40934ms step_avg:87.09ms +step:471/1680 train_time:41021ms step_avg:87.09ms +step:472/1680 train_time:41108ms step_avg:87.09ms +step:473/1680 train_time:41195ms step_avg:87.09ms +step:474/1680 train_time:41283ms step_avg:87.09ms +step:475/1680 train_time:41371ms step_avg:87.10ms +step:476/1680 train_time:41458ms step_avg:87.10ms +step:477/1680 train_time:41545ms step_avg:87.10ms +step:478/1680 train_time:41632ms step_avg:87.10ms +step:479/1680 train_time:41719ms step_avg:87.10ms +step:480/1680 train_time:41807ms step_avg:87.10ms +step:481/1680 train_time:41894ms step_avg:87.10ms +step:482/1680 train_time:41981ms step_avg:87.10ms +step:483/1680 train_time:42069ms step_avg:87.10ms +step:484/1680 train_time:42155ms step_avg:87.10ms +step:485/1680 train_time:42242ms step_avg:87.10ms +step:486/1680 train_time:42329ms step_avg:87.10ms +step:487/1680 train_time:42416ms step_avg:87.10ms +step:488/1680 train_time:42504ms step_avg:87.10ms +step:489/1680 train_time:42590ms step_avg:87.10ms +step:490/1680 train_time:42677ms step_avg:87.10ms +step:491/1680 train_time:42764ms step_avg:87.10ms +step:492/1680 train_time:42851ms step_avg:87.10ms +step:493/1680 train_time:42939ms step_avg:87.10ms +step:494/1680 train_time:43025ms step_avg:87.10ms +step:495/1680 train_time:43112ms step_avg:87.10ms +step:496/1680 train_time:43199ms step_avg:87.09ms +step:497/1680 train_time:43286ms step_avg:87.09ms +step:498/1680 train_time:43373ms step_avg:87.09ms +step:499/1680 train_time:43460ms step_avg:87.09ms +step:500/1680 train_time:43547ms step_avg:87.09ms +step:500/1680 val_loss:3.7188 train_time:43635ms step_avg:87.27ms +step:501/1680 train_time:43654ms step_avg:87.13ms +step:502/1680 train_time:43723ms step_avg:87.10ms +step:503/1680 train_time:43817ms step_avg:87.11ms +step:504/1680 train_time:43906ms step_avg:87.12ms +step:505/1680 train_time:43993ms step_avg:87.12ms +step:506/1680 train_time:44080ms step_avg:87.11ms +step:507/1680 train_time:44166ms step_avg:87.11ms +step:508/1680 train_time:44252ms step_avg:87.11ms +step:509/1680 train_time:44338ms step_avg:87.11ms +step:510/1680 train_time:44424ms step_avg:87.11ms +step:511/1680 train_time:44511ms step_avg:87.11ms +step:512/1680 train_time:44599ms step_avg:87.11ms +step:513/1680 train_time:44686ms step_avg:87.11ms +step:514/1680 train_time:44775ms step_avg:87.11ms +step:515/1680 train_time:44864ms step_avg:87.12ms +step:516/1680 train_time:44952ms step_avg:87.12ms +step:517/1680 train_time:45039ms step_avg:87.12ms +step:518/1680 train_time:45126ms step_avg:87.12ms +step:519/1680 train_time:45212ms step_avg:87.11ms +step:520/1680 train_time:45298ms step_avg:87.11ms +step:521/1680 train_time:45385ms step_avg:87.11ms +step:522/1680 train_time:45472ms step_avg:87.11ms +step:523/1680 train_time:45559ms step_avg:87.11ms +step:524/1680 train_time:45646ms step_avg:87.11ms +step:525/1680 train_time:45734ms step_avg:87.11ms +step:526/1680 train_time:45822ms step_avg:87.11ms +step:527/1680 train_time:45911ms step_avg:87.12ms +step:528/1680 train_time:45998ms step_avg:87.12ms +step:529/1680 train_time:46085ms step_avg:87.12ms +step:530/1680 train_time:46172ms step_avg:87.12ms +step:531/1680 train_time:46258ms step_avg:87.12ms 
+step:532/1680 train_time:46345ms step_avg:87.12ms +step:533/1680 train_time:46432ms step_avg:87.11ms +step:534/1680 train_time:46518ms step_avg:87.11ms +step:535/1680 train_time:46606ms step_avg:87.11ms +step:536/1680 train_time:46694ms step_avg:87.12ms +step:537/1680 train_time:46782ms step_avg:87.12ms +step:538/1680 train_time:46871ms step_avg:87.12ms +step:539/1680 train_time:46958ms step_avg:87.12ms +step:540/1680 train_time:47045ms step_avg:87.12ms +step:541/1680 train_time:47132ms step_avg:87.12ms +step:542/1680 train_time:47218ms step_avg:87.12ms +step:543/1680 train_time:47305ms step_avg:87.12ms +step:544/1680 train_time:47392ms step_avg:87.12ms +step:545/1680 train_time:47479ms step_avg:87.12ms +step:546/1680 train_time:47565ms step_avg:87.12ms +step:547/1680 train_time:47652ms step_avg:87.12ms +step:548/1680 train_time:47739ms step_avg:87.12ms +step:549/1680 train_time:47828ms step_avg:87.12ms +step:550/1680 train_time:47917ms step_avg:87.12ms +step:551/1680 train_time:48005ms step_avg:87.12ms +step:552/1680 train_time:48094ms step_avg:87.13ms +step:553/1680 train_time:48183ms step_avg:87.13ms +step:554/1680 train_time:48272ms step_avg:87.13ms +step:555/1680 train_time:48359ms step_avg:87.13ms +step:556/1680 train_time:48446ms step_avg:87.13ms +step:557/1680 train_time:48534ms step_avg:87.13ms +step:558/1680 train_time:48623ms step_avg:87.14ms +step:559/1680 train_time:48711ms step_avg:87.14ms +step:560/1680 train_time:48799ms step_avg:87.14ms +step:561/1680 train_time:48888ms step_avg:87.14ms +step:562/1680 train_time:48976ms step_avg:87.15ms +step:563/1680 train_time:49064ms step_avg:87.15ms +step:564/1680 train_time:49152ms step_avg:87.15ms +step:565/1680 train_time:49240ms step_avg:87.15ms +step:566/1680 train_time:49328ms step_avg:87.15ms +step:567/1680 train_time:49416ms step_avg:87.15ms +step:568/1680 train_time:49504ms step_avg:87.16ms +step:569/1680 train_time:49592ms step_avg:87.16ms +step:570/1680 train_time:49680ms step_avg:87.16ms +step:571/1680 train_time:49768ms step_avg:87.16ms +step:572/1680 train_time:49856ms step_avg:87.16ms +step:573/1680 train_time:49944ms step_avg:87.16ms +step:574/1680 train_time:50032ms step_avg:87.16ms +step:575/1680 train_time:50120ms step_avg:87.17ms +step:576/1680 train_time:50209ms step_avg:87.17ms +step:577/1680 train_time:50297ms step_avg:87.17ms +step:578/1680 train_time:50385ms step_avg:87.17ms +step:579/1680 train_time:50474ms step_avg:87.17ms +step:580/1680 train_time:50563ms step_avg:87.18ms +step:581/1680 train_time:50651ms step_avg:87.18ms +step:582/1680 train_time:50739ms step_avg:87.18ms +step:583/1680 train_time:50828ms step_avg:87.18ms +step:584/1680 train_time:50916ms step_avg:87.18ms +step:585/1680 train_time:51004ms step_avg:87.19ms +step:586/1680 train_time:51093ms step_avg:87.19ms +step:587/1680 train_time:51181ms step_avg:87.19ms +step:588/1680 train_time:51269ms step_avg:87.19ms +step:589/1680 train_time:51357ms step_avg:87.19ms +step:590/1680 train_time:51446ms step_avg:87.20ms +step:591/1680 train_time:51534ms step_avg:87.20ms +step:592/1680 train_time:51624ms step_avg:87.20ms +step:593/1680 train_time:51711ms step_avg:87.20ms +step:594/1680 train_time:51799ms step_avg:87.20ms +step:595/1680 train_time:51888ms step_avg:87.21ms +step:596/1680 train_time:51976ms step_avg:87.21ms +step:597/1680 train_time:52064ms step_avg:87.21ms +step:598/1680 train_time:52153ms step_avg:87.21ms +step:599/1680 train_time:52241ms step_avg:87.21ms +step:600/1680 train_time:52329ms step_avg:87.22ms +step:601/1680 train_time:52417ms 
step_avg:87.22ms +step:602/1680 train_time:52505ms step_avg:87.22ms +step:603/1680 train_time:52594ms step_avg:87.22ms +step:604/1680 train_time:52682ms step_avg:87.22ms +step:605/1680 train_time:52771ms step_avg:87.22ms +step:606/1680 train_time:52859ms step_avg:87.23ms +step:607/1680 train_time:52947ms step_avg:87.23ms +step:608/1680 train_time:53035ms step_avg:87.23ms +step:609/1680 train_time:53123ms step_avg:87.23ms +step:610/1680 train_time:53211ms step_avg:87.23ms +step:611/1680 train_time:53299ms step_avg:87.23ms +step:612/1680 train_time:53387ms step_avg:87.23ms +step:613/1680 train_time:53475ms step_avg:87.24ms +step:614/1680 train_time:53564ms step_avg:87.24ms +step:615/1680 train_time:53652ms step_avg:87.24ms +step:616/1680 train_time:53741ms step_avg:87.24ms +step:617/1680 train_time:53829ms step_avg:87.24ms +step:618/1680 train_time:53917ms step_avg:87.24ms +step:619/1680 train_time:54005ms step_avg:87.25ms +step:620/1680 train_time:54093ms step_avg:87.25ms +step:621/1680 train_time:54182ms step_avg:87.25ms +step:622/1680 train_time:54270ms step_avg:87.25ms +step:623/1680 train_time:54358ms step_avg:87.25ms +step:624/1680 train_time:54446ms step_avg:87.25ms +step:625/1680 train_time:54534ms step_avg:87.25ms +step:625/1680 val_loss:3.6184 train_time:54624ms step_avg:87.40ms +step:626/1680 train_time:54643ms step_avg:87.29ms +step:627/1680 train_time:54712ms step_avg:87.26ms +step:628/1680 train_time:54801ms step_avg:87.26ms +step:629/1680 train_time:54892ms step_avg:87.27ms +step:630/1680 train_time:54983ms step_avg:87.28ms +step:631/1680 train_time:55070ms step_avg:87.27ms +step:632/1680 train_time:55157ms step_avg:87.27ms +step:633/1680 train_time:55244ms step_avg:87.27ms +step:634/1680 train_time:55331ms step_avg:87.27ms +step:635/1680 train_time:55418ms step_avg:87.27ms +step:636/1680 train_time:55506ms step_avg:87.27ms +step:637/1680 train_time:55597ms step_avg:87.28ms +step:638/1680 train_time:55687ms step_avg:87.28ms +step:639/1680 train_time:55776ms step_avg:87.29ms +step:640/1680 train_time:55865ms step_avg:87.29ms +step:641/1680 train_time:55953ms step_avg:87.29ms +step:642/1680 train_time:56042ms step_avg:87.29ms +step:643/1680 train_time:56131ms step_avg:87.29ms +step:644/1680 train_time:56218ms step_avg:87.29ms +step:645/1680 train_time:56305ms step_avg:87.30ms +step:646/1680 train_time:56393ms step_avg:87.30ms +step:647/1680 train_time:56480ms step_avg:87.30ms +step:648/1680 train_time:56569ms step_avg:87.30ms +step:649/1680 train_time:56658ms step_avg:87.30ms +step:650/1680 train_time:56747ms step_avg:87.30ms +step:651/1680 train_time:56836ms step_avg:87.30ms +step:652/1680 train_time:56924ms step_avg:87.31ms +step:653/1680 train_time:57013ms step_avg:87.31ms +step:654/1680 train_time:57100ms step_avg:87.31ms +step:655/1680 train_time:57188ms step_avg:87.31ms +step:656/1680 train_time:57275ms step_avg:87.31ms +step:657/1680 train_time:57363ms step_avg:87.31ms +step:658/1680 train_time:57450ms step_avg:87.31ms +step:659/1680 train_time:57538ms step_avg:87.31ms +step:660/1680 train_time:57626ms step_avg:87.31ms +step:661/1680 train_time:57715ms step_avg:87.31ms +step:662/1680 train_time:57803ms step_avg:87.32ms +step:663/1680 train_time:57892ms step_avg:87.32ms +step:664/1680 train_time:57981ms step_avg:87.32ms +step:665/1680 train_time:58069ms step_avg:87.32ms +step:666/1680 train_time:58157ms step_avg:87.32ms +step:667/1680 train_time:58244ms step_avg:87.32ms +step:668/1680 train_time:58331ms step_avg:87.32ms +step:669/1680 train_time:58418ms step_avg:87.32ms 
+step:670/1680 train_time:58507ms step_avg:87.32ms +step:671/1680 train_time:58595ms step_avg:87.32ms +step:672/1680 train_time:58683ms step_avg:87.33ms +step:673/1680 train_time:58771ms step_avg:87.33ms +step:674/1680 train_time:58860ms step_avg:87.33ms +step:675/1680 train_time:58948ms step_avg:87.33ms +step:676/1680 train_time:59035ms step_avg:87.33ms +step:677/1680 train_time:59124ms step_avg:87.33ms +step:678/1680 train_time:59212ms step_avg:87.33ms +step:679/1680 train_time:59300ms step_avg:87.33ms +step:680/1680 train_time:59388ms step_avg:87.33ms +step:681/1680 train_time:59475ms step_avg:87.34ms +step:682/1680 train_time:59563ms step_avg:87.34ms +step:683/1680 train_time:59652ms step_avg:87.34ms +step:684/1680 train_time:59740ms step_avg:87.34ms +step:685/1680 train_time:59829ms step_avg:87.34ms +step:686/1680 train_time:59917ms step_avg:87.34ms +step:687/1680 train_time:60005ms step_avg:87.34ms +step:688/1680 train_time:60093ms step_avg:87.34ms +step:689/1680 train_time:60181ms step_avg:87.35ms +step:690/1680 train_time:60270ms step_avg:87.35ms +step:691/1680 train_time:60358ms step_avg:87.35ms +step:692/1680 train_time:60445ms step_avg:87.35ms +step:693/1680 train_time:60533ms step_avg:87.35ms +step:694/1680 train_time:60621ms step_avg:87.35ms +step:695/1680 train_time:60710ms step_avg:87.35ms +step:696/1680 train_time:60798ms step_avg:87.35ms +step:697/1680 train_time:60885ms step_avg:87.35ms +step:698/1680 train_time:60974ms step_avg:87.35ms +step:699/1680 train_time:61061ms step_avg:87.36ms +step:700/1680 train_time:61150ms step_avg:87.36ms +step:701/1680 train_time:61238ms step_avg:87.36ms +step:702/1680 train_time:61327ms step_avg:87.36ms +step:703/1680 train_time:61415ms step_avg:87.36ms +step:704/1680 train_time:61502ms step_avg:87.36ms +step:705/1680 train_time:61591ms step_avg:87.36ms +step:706/1680 train_time:61679ms step_avg:87.36ms +step:707/1680 train_time:61768ms step_avg:87.37ms +step:708/1680 train_time:61855ms step_avg:87.37ms +step:709/1680 train_time:61944ms step_avg:87.37ms +step:710/1680 train_time:62032ms step_avg:87.37ms +step:711/1680 train_time:62120ms step_avg:87.37ms +step:712/1680 train_time:62207ms step_avg:87.37ms +step:713/1680 train_time:62295ms step_avg:87.37ms +step:714/1680 train_time:62384ms step_avg:87.37ms +step:715/1680 train_time:62472ms step_avg:87.37ms +step:716/1680 train_time:62561ms step_avg:87.38ms +step:717/1680 train_time:62649ms step_avg:87.38ms +step:718/1680 train_time:62737ms step_avg:87.38ms +step:719/1680 train_time:62825ms step_avg:87.38ms +step:720/1680 train_time:62913ms step_avg:87.38ms +step:721/1680 train_time:63001ms step_avg:87.38ms +step:722/1680 train_time:63089ms step_avg:87.38ms +step:723/1680 train_time:63177ms step_avg:87.38ms +step:724/1680 train_time:63265ms step_avg:87.38ms +step:725/1680 train_time:63353ms step_avg:87.38ms +step:726/1680 train_time:63442ms step_avg:87.39ms +step:727/1680 train_time:63530ms step_avg:87.39ms +step:728/1680 train_time:63618ms step_avg:87.39ms +step:729/1680 train_time:63706ms step_avg:87.39ms +step:730/1680 train_time:63794ms step_avg:87.39ms +step:731/1680 train_time:63882ms step_avg:87.39ms +step:732/1680 train_time:63970ms step_avg:87.39ms +step:733/1680 train_time:64059ms step_avg:87.39ms +step:734/1680 train_time:64147ms step_avg:87.39ms +step:735/1680 train_time:64235ms step_avg:87.39ms +step:736/1680 train_time:64323ms step_avg:87.40ms +step:737/1680 train_time:64411ms step_avg:87.40ms +step:738/1680 train_time:64499ms step_avg:87.40ms +step:739/1680 train_time:64587ms 
step_avg:87.40ms +step:740/1680 train_time:64675ms step_avg:87.40ms +step:741/1680 train_time:64763ms step_avg:87.40ms +step:742/1680 train_time:64852ms step_avg:87.40ms +step:743/1680 train_time:64941ms step_avg:87.40ms +step:744/1680 train_time:65029ms step_avg:87.40ms +step:745/1680 train_time:65117ms step_avg:87.41ms +step:746/1680 train_time:65205ms step_avg:87.41ms +step:747/1680 train_time:65293ms step_avg:87.41ms +step:748/1680 train_time:65381ms step_avg:87.41ms +step:749/1680 train_time:65469ms step_avg:87.41ms +step:750/1680 train_time:65558ms step_avg:87.41ms +step:750/1680 val_loss:3.5686 train_time:65647ms step_avg:87.53ms +step:751/1680 train_time:65665ms step_avg:87.44ms +step:752/1680 train_time:65737ms step_avg:87.42ms +step:753/1680 train_time:65829ms step_avg:87.42ms +step:754/1680 train_time:65918ms step_avg:87.42ms +step:755/1680 train_time:66006ms step_avg:87.42ms +step:756/1680 train_time:66093ms step_avg:87.43ms +step:757/1680 train_time:66181ms step_avg:87.42ms +step:758/1680 train_time:66267ms step_avg:87.42ms +step:759/1680 train_time:66354ms step_avg:87.42ms +step:760/1680 train_time:66442ms step_avg:87.42ms +step:761/1680 train_time:66530ms step_avg:87.42ms +step:762/1680 train_time:66618ms step_avg:87.42ms +step:763/1680 train_time:66708ms step_avg:87.43ms +step:764/1680 train_time:66798ms step_avg:87.43ms +step:765/1680 train_time:66887ms step_avg:87.43ms +step:766/1680 train_time:66975ms step_avg:87.43ms +step:767/1680 train_time:67063ms step_avg:87.44ms +step:768/1680 train_time:67151ms step_avg:87.44ms +step:769/1680 train_time:67239ms step_avg:87.44ms +step:770/1680 train_time:67327ms step_avg:87.44ms +step:771/1680 train_time:67414ms step_avg:87.44ms +step:772/1680 train_time:67501ms step_avg:87.44ms +step:773/1680 train_time:67589ms step_avg:87.44ms +step:774/1680 train_time:67677ms step_avg:87.44ms +step:775/1680 train_time:67766ms step_avg:87.44ms +step:776/1680 train_time:67856ms step_avg:87.44ms +step:777/1680 train_time:67944ms step_avg:87.44ms +step:778/1680 train_time:68033ms step_avg:87.45ms +step:779/1680 train_time:68121ms step_avg:87.45ms +step:780/1680 train_time:68209ms step_avg:87.45ms +step:781/1680 train_time:68296ms step_avg:87.45ms +step:782/1680 train_time:68385ms step_avg:87.45ms +step:783/1680 train_time:68472ms step_avg:87.45ms +step:784/1680 train_time:68560ms step_avg:87.45ms +step:785/1680 train_time:68648ms step_avg:87.45ms +step:786/1680 train_time:68736ms step_avg:87.45ms +step:787/1680 train_time:68827ms step_avg:87.45ms +step:788/1680 train_time:68917ms step_avg:87.46ms +step:789/1680 train_time:69005ms step_avg:87.46ms +step:790/1680 train_time:69094ms step_avg:87.46ms +step:791/1680 train_time:69182ms step_avg:87.46ms +step:792/1680 train_time:69270ms step_avg:87.46ms +step:793/1680 train_time:69357ms step_avg:87.46ms +step:794/1680 train_time:69445ms step_avg:87.46ms +step:795/1680 train_time:69533ms step_avg:87.46ms +step:796/1680 train_time:69621ms step_avg:87.46ms +step:797/1680 train_time:69708ms step_avg:87.46ms +step:798/1680 train_time:69797ms step_avg:87.46ms +step:799/1680 train_time:69886ms step_avg:87.47ms +step:800/1680 train_time:69974ms step_avg:87.47ms +step:801/1680 train_time:70062ms step_avg:87.47ms +step:802/1680 train_time:70150ms step_avg:87.47ms +step:803/1680 train_time:70238ms step_avg:87.47ms +step:804/1680 train_time:70326ms step_avg:87.47ms +step:805/1680 train_time:70414ms step_avg:87.47ms +step:806/1680 train_time:70501ms step_avg:87.47ms +step:807/1680 train_time:70589ms step_avg:87.47ms 
+step:808/1680 train_time:70677ms step_avg:87.47ms +step:809/1680 train_time:70765ms step_avg:87.47ms +step:810/1680 train_time:70853ms step_avg:87.47ms +step:811/1680 train_time:70942ms step_avg:87.47ms +step:812/1680 train_time:71030ms step_avg:87.48ms +step:813/1680 train_time:71118ms step_avg:87.48ms +step:814/1680 train_time:71206ms step_avg:87.48ms +step:815/1680 train_time:71294ms step_avg:87.48ms +step:816/1680 train_time:71382ms step_avg:87.48ms +step:817/1680 train_time:71470ms step_avg:87.48ms +step:818/1680 train_time:71559ms step_avg:87.48ms +step:819/1680 train_time:71646ms step_avg:87.48ms +step:820/1680 train_time:71735ms step_avg:87.48ms +step:821/1680 train_time:71824ms step_avg:87.48ms +step:822/1680 train_time:71911ms step_avg:87.48ms +step:823/1680 train_time:72000ms step_avg:87.48ms +step:824/1680 train_time:72088ms step_avg:87.49ms +step:825/1680 train_time:72177ms step_avg:87.49ms +step:826/1680 train_time:72264ms step_avg:87.49ms +step:827/1680 train_time:72352ms step_avg:87.49ms +step:828/1680 train_time:72439ms step_avg:87.49ms +step:829/1680 train_time:72528ms step_avg:87.49ms +step:830/1680 train_time:72616ms step_avg:87.49ms +step:831/1680 train_time:72703ms step_avg:87.49ms +step:832/1680 train_time:72791ms step_avg:87.49ms +step:833/1680 train_time:72880ms step_avg:87.49ms +step:834/1680 train_time:72968ms step_avg:87.49ms +step:835/1680 train_time:73056ms step_avg:87.49ms +step:836/1680 train_time:73144ms step_avg:87.49ms +step:837/1680 train_time:73232ms step_avg:87.49ms +step:838/1680 train_time:73320ms step_avg:87.49ms +step:839/1680 train_time:73408ms step_avg:87.50ms +step:840/1680 train_time:73496ms step_avg:87.50ms +step:841/1680 train_time:73585ms step_avg:87.50ms +step:842/1680 train_time:73673ms step_avg:87.50ms +step:843/1680 train_time:73761ms step_avg:87.50ms +step:844/1680 train_time:73849ms step_avg:87.50ms +step:845/1680 train_time:73937ms step_avg:87.50ms +step:846/1680 train_time:74026ms step_avg:87.50ms +step:847/1680 train_time:74114ms step_avg:87.50ms +step:848/1680 train_time:74202ms step_avg:87.50ms +step:849/1680 train_time:74290ms step_avg:87.50ms +step:850/1680 train_time:74378ms step_avg:87.50ms +step:851/1680 train_time:74466ms step_avg:87.50ms +step:852/1680 train_time:74554ms step_avg:87.51ms +step:853/1680 train_time:74643ms step_avg:87.51ms +step:854/1680 train_time:74730ms step_avg:87.51ms +step:855/1680 train_time:74818ms step_avg:87.51ms +step:856/1680 train_time:74907ms step_avg:87.51ms +step:857/1680 train_time:74996ms step_avg:87.51ms +step:858/1680 train_time:75085ms step_avg:87.51ms +step:859/1680 train_time:75173ms step_avg:87.51ms +step:860/1680 train_time:75261ms step_avg:87.51ms +step:861/1680 train_time:75349ms step_avg:87.51ms +step:862/1680 train_time:75438ms step_avg:87.51ms +step:863/1680 train_time:75526ms step_avg:87.52ms +step:864/1680 train_time:75614ms step_avg:87.52ms +step:865/1680 train_time:75701ms step_avg:87.52ms +step:866/1680 train_time:75790ms step_avg:87.52ms +step:867/1680 train_time:75878ms step_avg:87.52ms +step:868/1680 train_time:75966ms step_avg:87.52ms +step:869/1680 train_time:76055ms step_avg:87.52ms +step:870/1680 train_time:76143ms step_avg:87.52ms +step:871/1680 train_time:76231ms step_avg:87.52ms +step:872/1680 train_time:76318ms step_avg:87.52ms +step:873/1680 train_time:76407ms step_avg:87.52ms +step:874/1680 train_time:76495ms step_avg:87.52ms +step:875/1680 train_time:76583ms step_avg:87.52ms +step:875/1680 val_loss:3.5223 train_time:76672ms step_avg:87.63ms +step:876/1680 
train_time:76691ms step_avg:87.55ms +step:877/1680 train_time:76763ms step_avg:87.53ms +step:878/1680 train_time:76855ms step_avg:87.53ms +step:879/1680 train_time:76943ms step_avg:87.54ms +step:880/1680 train_time:77031ms step_avg:87.53ms +step:881/1680 train_time:77117ms step_avg:87.53ms +step:882/1680 train_time:77204ms step_avg:87.53ms +step:883/1680 train_time:77291ms step_avg:87.53ms +step:884/1680 train_time:77378ms step_avg:87.53ms +step:885/1680 train_time:77466ms step_avg:87.53ms +step:886/1680 train_time:77554ms step_avg:87.53ms +step:887/1680 train_time:77643ms step_avg:87.53ms +step:888/1680 train_time:77733ms step_avg:87.54ms +step:889/1680 train_time:77823ms step_avg:87.54ms +step:890/1680 train_time:77911ms step_avg:87.54ms +step:891/1680 train_time:77999ms step_avg:87.54ms +step:892/1680 train_time:78087ms step_avg:87.54ms +step:893/1680 train_time:78174ms step_avg:87.54ms +step:894/1680 train_time:78261ms step_avg:87.54ms +step:895/1680 train_time:78349ms step_avg:87.54ms +step:896/1680 train_time:78437ms step_avg:87.54ms +step:897/1680 train_time:78525ms step_avg:87.54ms +step:898/1680 train_time:78614ms step_avg:87.54ms +step:899/1680 train_time:78703ms step_avg:87.54ms +step:900/1680 train_time:78791ms step_avg:87.55ms +step:901/1680 train_time:78880ms step_avg:87.55ms +step:902/1680 train_time:78969ms step_avg:87.55ms +step:903/1680 train_time:79057ms step_avg:87.55ms +step:904/1680 train_time:79145ms step_avg:87.55ms +step:905/1680 train_time:79232ms step_avg:87.55ms +step:906/1680 train_time:79319ms step_avg:87.55ms +step:907/1680 train_time:79406ms step_avg:87.55ms +step:908/1680 train_time:79495ms step_avg:87.55ms +step:909/1680 train_time:79582ms step_avg:87.55ms +step:910/1680 train_time:79672ms step_avg:87.55ms +step:911/1680 train_time:79760ms step_avg:87.55ms +step:912/1680 train_time:79849ms step_avg:87.55ms +step:913/1680 train_time:79937ms step_avg:87.55ms +step:914/1680 train_time:80025ms step_avg:87.56ms +step:915/1680 train_time:80113ms step_avg:87.56ms +step:916/1680 train_time:80201ms step_avg:87.56ms +step:917/1680 train_time:80289ms step_avg:87.56ms +step:918/1680 train_time:80376ms step_avg:87.56ms +step:919/1680 train_time:80464ms step_avg:87.56ms +step:920/1680 train_time:80553ms step_avg:87.56ms +step:921/1680 train_time:80641ms step_avg:87.56ms +step:922/1680 train_time:80730ms step_avg:87.56ms +step:923/1680 train_time:80819ms step_avg:87.56ms +step:924/1680 train_time:80907ms step_avg:87.56ms +step:925/1680 train_time:80996ms step_avg:87.56ms +step:926/1680 train_time:81084ms step_avg:87.56ms +step:927/1680 train_time:81172ms step_avg:87.56ms +step:928/1680 train_time:81260ms step_avg:87.56ms +step:929/1680 train_time:81348ms step_avg:87.56ms +step:930/1680 train_time:81436ms step_avg:87.57ms +step:931/1680 train_time:81524ms step_avg:87.57ms +step:932/1680 train_time:81612ms step_avg:87.57ms +step:933/1680 train_time:81701ms step_avg:87.57ms +step:934/1680 train_time:81789ms step_avg:87.57ms +step:935/1680 train_time:81877ms step_avg:87.57ms +step:936/1680 train_time:81966ms step_avg:87.57ms +step:937/1680 train_time:82054ms step_avg:87.57ms +step:938/1680 train_time:82142ms step_avg:87.57ms +step:939/1680 train_time:82230ms step_avg:87.57ms +step:940/1680 train_time:82318ms step_avg:87.57ms +step:941/1680 train_time:82406ms step_avg:87.57ms +step:942/1680 train_time:82494ms step_avg:87.57ms +step:943/1680 train_time:82582ms step_avg:87.57ms +step:944/1680 train_time:82670ms step_avg:87.57ms +step:945/1680 train_time:82758ms step_avg:87.57ms 
+step:946/1680 train_time:82846ms step_avg:87.57ms +step:947/1680 train_time:82935ms step_avg:87.58ms +step:948/1680 train_time:83023ms step_avg:87.58ms +step:949/1680 train_time:83111ms step_avg:87.58ms +step:950/1680 train_time:83198ms step_avg:87.58ms +step:951/1680 train_time:83287ms step_avg:87.58ms +step:952/1680 train_time:83375ms step_avg:87.58ms +step:953/1680 train_time:83462ms step_avg:87.58ms +step:954/1680 train_time:83550ms step_avg:87.58ms +step:955/1680 train_time:83638ms step_avg:87.58ms +step:956/1680 train_time:83726ms step_avg:87.58ms +step:957/1680 train_time:83815ms step_avg:87.58ms +step:958/1680 train_time:83904ms step_avg:87.58ms +step:959/1680 train_time:83991ms step_avg:87.58ms +step:960/1680 train_time:84079ms step_avg:87.58ms +step:961/1680 train_time:84167ms step_avg:87.58ms +step:962/1680 train_time:84255ms step_avg:87.58ms +step:963/1680 train_time:84343ms step_avg:87.58ms +step:964/1680 train_time:84431ms step_avg:87.58ms +step:965/1680 train_time:84519ms step_avg:87.58ms +step:966/1680 train_time:84607ms step_avg:87.59ms +step:967/1680 train_time:84696ms step_avg:87.59ms +step:968/1680 train_time:84784ms step_avg:87.59ms +step:969/1680 train_time:84873ms step_avg:87.59ms +step:970/1680 train_time:84961ms step_avg:87.59ms +step:971/1680 train_time:85049ms step_avg:87.59ms +step:972/1680 train_time:85137ms step_avg:87.59ms +step:973/1680 train_time:85225ms step_avg:87.59ms +step:974/1680 train_time:85313ms step_avg:87.59ms +step:975/1680 train_time:85401ms step_avg:87.59ms +step:976/1680 train_time:85489ms step_avg:87.59ms +step:977/1680 train_time:85577ms step_avg:87.59ms +step:978/1680 train_time:85665ms step_avg:87.59ms +step:979/1680 train_time:85754ms step_avg:87.59ms +step:980/1680 train_time:85842ms step_avg:87.59ms +step:981/1680 train_time:85931ms step_avg:87.60ms +step:982/1680 train_time:86019ms step_avg:87.60ms +step:983/1680 train_time:86108ms step_avg:87.60ms +step:984/1680 train_time:86196ms step_avg:87.60ms +step:985/1680 train_time:86285ms step_avg:87.60ms +step:986/1680 train_time:86373ms step_avg:87.60ms +step:987/1680 train_time:86461ms step_avg:87.60ms +step:988/1680 train_time:86548ms step_avg:87.60ms +step:989/1680 train_time:86636ms step_avg:87.60ms +step:990/1680 train_time:86724ms step_avg:87.60ms +step:991/1680 train_time:86812ms step_avg:87.60ms +step:992/1680 train_time:86901ms step_avg:87.60ms +step:993/1680 train_time:86989ms step_avg:87.60ms +step:994/1680 train_time:87077ms step_avg:87.60ms +step:995/1680 train_time:87166ms step_avg:87.60ms +step:996/1680 train_time:87254ms step_avg:87.60ms +step:997/1680 train_time:87342ms step_avg:87.60ms +step:998/1680 train_time:87430ms step_avg:87.61ms +step:999/1680 train_time:87518ms step_avg:87.61ms +step:1000/1680 train_time:87605ms step_avg:87.61ms +step:1000/1680 val_loss:3.4714 train_time:87695ms step_avg:87.69ms +step:1001/1680 train_time:87713ms step_avg:87.63ms +step:1002/1680 train_time:87784ms step_avg:87.61ms +step:1003/1680 train_time:87878ms step_avg:87.61ms +step:1004/1680 train_time:87967ms step_avg:87.62ms +step:1005/1680 train_time:88054ms step_avg:87.62ms +step:1006/1680 train_time:88142ms step_avg:87.62ms +step:1007/1680 train_time:88229ms step_avg:87.62ms +step:1008/1680 train_time:88317ms step_avg:87.62ms +step:1009/1680 train_time:88405ms step_avg:87.62ms +step:1010/1680 train_time:88492ms step_avg:87.62ms +step:1011/1680 train_time:88579ms step_avg:87.62ms +step:1012/1680 train_time:88668ms step_avg:87.62ms +step:1013/1680 train_time:88757ms step_avg:87.62ms 
+step:1014/1680 train_time:88848ms step_avg:87.62ms +step:1015/1680 train_time:88938ms step_avg:87.62ms +step:1016/1680 train_time:89026ms step_avg:87.62ms +step:1017/1680 train_time:89115ms step_avg:87.63ms +step:1018/1680 train_time:89202ms step_avg:87.62ms +step:1019/1680 train_time:89290ms step_avg:87.62ms +step:1020/1680 train_time:89377ms step_avg:87.62ms +step:1021/1680 train_time:89465ms step_avg:87.63ms +step:1022/1680 train_time:89553ms step_avg:87.62ms +step:1023/1680 train_time:89641ms step_avg:87.63ms +step:1024/1680 train_time:89730ms step_avg:87.63ms +step:1025/1680 train_time:89819ms step_avg:87.63ms +step:1026/1680 train_time:89909ms step_avg:87.63ms +step:1027/1680 train_time:89998ms step_avg:87.63ms +step:1028/1680 train_time:90087ms step_avg:87.63ms +step:1029/1680 train_time:90175ms step_avg:87.63ms +step:1030/1680 train_time:90262ms step_avg:87.63ms +step:1031/1680 train_time:90349ms step_avg:87.63ms +step:1032/1680 train_time:90438ms step_avg:87.63ms +step:1033/1680 train_time:90525ms step_avg:87.63ms +step:1034/1680 train_time:90613ms step_avg:87.63ms +step:1035/1680 train_time:90701ms step_avg:87.63ms +step:1036/1680 train_time:90790ms step_avg:87.64ms +step:1037/1680 train_time:90879ms step_avg:87.64ms +step:1038/1680 train_time:90967ms step_avg:87.64ms +step:1039/1680 train_time:91056ms step_avg:87.64ms +step:1040/1680 train_time:91144ms step_avg:87.64ms +step:1041/1680 train_time:91232ms step_avg:87.64ms +step:1042/1680 train_time:91320ms step_avg:87.64ms +step:1043/1680 train_time:91408ms step_avg:87.64ms +step:1044/1680 train_time:91496ms step_avg:87.64ms +step:1045/1680 train_time:91584ms step_avg:87.64ms +step:1046/1680 train_time:91672ms step_avg:87.64ms +step:1047/1680 train_time:91761ms step_avg:87.64ms +step:1048/1680 train_time:91849ms step_avg:87.64ms +step:1049/1680 train_time:91937ms step_avg:87.64ms +step:1050/1680 train_time:92026ms step_avg:87.64ms +step:1051/1680 train_time:92114ms step_avg:87.64ms +step:1052/1680 train_time:92202ms step_avg:87.64ms +step:1053/1680 train_time:92290ms step_avg:87.64ms +step:1054/1680 train_time:92378ms step_avg:87.65ms +step:1055/1680 train_time:92466ms step_avg:87.65ms +step:1056/1680 train_time:92554ms step_avg:87.65ms +step:1057/1680 train_time:92643ms step_avg:87.65ms +step:1058/1680 train_time:92731ms step_avg:87.65ms +step:1059/1680 train_time:92819ms step_avg:87.65ms +step:1060/1680 train_time:92908ms step_avg:87.65ms +step:1061/1680 train_time:92996ms step_avg:87.65ms +step:1062/1680 train_time:93085ms step_avg:87.65ms +step:1063/1680 train_time:93173ms step_avg:87.65ms +step:1064/1680 train_time:93261ms step_avg:87.65ms +step:1065/1680 train_time:93349ms step_avg:87.65ms +step:1066/1680 train_time:93438ms step_avg:87.65ms +step:1067/1680 train_time:93526ms step_avg:87.65ms +step:1068/1680 train_time:93613ms step_avg:87.65ms +step:1069/1680 train_time:93702ms step_avg:87.65ms +step:1070/1680 train_time:93790ms step_avg:87.65ms +step:1071/1680 train_time:93878ms step_avg:87.65ms +step:1072/1680 train_time:93966ms step_avg:87.66ms +step:1073/1680 train_time:94056ms step_avg:87.66ms +step:1074/1680 train_time:94143ms step_avg:87.66ms +step:1075/1680 train_time:94231ms step_avg:87.66ms +step:1076/1680 train_time:94320ms step_avg:87.66ms +step:1077/1680 train_time:94409ms step_avg:87.66ms +step:1078/1680 train_time:94497ms step_avg:87.66ms +step:1079/1680 train_time:94586ms step_avg:87.66ms +step:1080/1680 train_time:94674ms step_avg:87.66ms +step:1081/1680 train_time:94762ms step_avg:87.66ms +step:1082/1680 
train_time:94851ms step_avg:87.66ms +step:1083/1680 train_time:94939ms step_avg:87.66ms +step:1084/1680 train_time:95028ms step_avg:87.66ms +step:1085/1680 train_time:95116ms step_avg:87.66ms +step:1086/1680 train_time:95204ms step_avg:87.67ms +step:1087/1680 train_time:95293ms step_avg:87.67ms +step:1088/1680 train_time:95381ms step_avg:87.67ms +step:1089/1680 train_time:95470ms step_avg:87.67ms +step:1090/1680 train_time:95558ms step_avg:87.67ms +step:1091/1680 train_time:95646ms step_avg:87.67ms +step:1092/1680 train_time:95734ms step_avg:87.67ms +step:1093/1680 train_time:95822ms step_avg:87.67ms +step:1094/1680 train_time:95910ms step_avg:87.67ms +step:1095/1680 train_time:95999ms step_avg:87.67ms +step:1096/1680 train_time:96088ms step_avg:87.67ms +step:1097/1680 train_time:96176ms step_avg:87.67ms +step:1098/1680 train_time:96265ms step_avg:87.67ms +step:1099/1680 train_time:96355ms step_avg:87.67ms +step:1100/1680 train_time:96443ms step_avg:87.68ms +step:1101/1680 train_time:96533ms step_avg:87.68ms +step:1102/1680 train_time:96622ms step_avg:87.68ms +step:1103/1680 train_time:96710ms step_avg:87.68ms +step:1104/1680 train_time:96799ms step_avg:87.68ms +step:1105/1680 train_time:96889ms step_avg:87.68ms +step:1106/1680 train_time:96977ms step_avg:87.68ms +step:1107/1680 train_time:97066ms step_avg:87.68ms +step:1108/1680 train_time:97154ms step_avg:87.68ms +step:1109/1680 train_time:97244ms step_avg:87.69ms +step:1110/1680 train_time:97332ms step_avg:87.69ms +step:1111/1680 train_time:97422ms step_avg:87.69ms +step:1112/1680 train_time:97511ms step_avg:87.69ms +step:1113/1680 train_time:97600ms step_avg:87.69ms +step:1114/1680 train_time:97690ms step_avg:87.69ms +step:1115/1680 train_time:97778ms step_avg:87.69ms +step:1116/1680 train_time:97867ms step_avg:87.69ms +step:1117/1680 train_time:97956ms step_avg:87.70ms +step:1118/1680 train_time:98044ms step_avg:87.70ms +step:1119/1680 train_time:98133ms step_avg:87.70ms +step:1120/1680 train_time:98222ms step_avg:87.70ms +step:1121/1680 train_time:98311ms step_avg:87.70ms +step:1122/1680 train_time:98401ms step_avg:87.70ms +step:1123/1680 train_time:98490ms step_avg:87.70ms +step:1124/1680 train_time:98579ms step_avg:87.70ms +step:1125/1680 train_time:98668ms step_avg:87.70ms +step:1125/1680 val_loss:3.4178 train_time:98758ms step_avg:87.78ms +step:1126/1680 train_time:98777ms step_avg:87.72ms +step:1127/1680 train_time:98849ms step_avg:87.71ms +step:1128/1680 train_time:98939ms step_avg:87.71ms +step:1129/1680 train_time:99031ms step_avg:87.72ms +step:1130/1680 train_time:99119ms step_avg:87.72ms +step:1131/1680 train_time:99207ms step_avg:87.72ms +step:1132/1680 train_time:99294ms step_avg:87.72ms +step:1133/1680 train_time:99382ms step_avg:87.72ms +step:1134/1680 train_time:99469ms step_avg:87.72ms +step:1135/1680 train_time:99557ms step_avg:87.72ms +step:1136/1680 train_time:99647ms step_avg:87.72ms +step:1137/1680 train_time:99739ms step_avg:87.72ms +step:1138/1680 train_time:99829ms step_avg:87.72ms +step:1139/1680 train_time:99920ms step_avg:87.73ms +step:1140/1680 train_time:100010ms step_avg:87.73ms +step:1141/1680 train_time:100099ms step_avg:87.73ms +step:1142/1680 train_time:100187ms step_avg:87.73ms +step:1143/1680 train_time:100276ms step_avg:87.73ms +step:1144/1680 train_time:100364ms step_avg:87.73ms +step:1145/1680 train_time:100453ms step_avg:87.73ms +step:1146/1680 train_time:100541ms step_avg:87.73ms +step:1147/1680 train_time:100630ms step_avg:87.73ms +step:1148/1680 train_time:100719ms step_avg:87.73ms 
+step:1149/1680 train_time:100809ms step_avg:87.74ms +step:1150/1680 train_time:100898ms step_avg:87.74ms +step:1151/1680 train_time:100988ms step_avg:87.74ms +step:1152/1680 train_time:101077ms step_avg:87.74ms +step:1153/1680 train_time:101166ms step_avg:87.74ms +step:1154/1680 train_time:101255ms step_avg:87.74ms +step:1155/1680 train_time:101344ms step_avg:87.74ms +step:1156/1680 train_time:101432ms step_avg:87.74ms +step:1157/1680 train_time:101520ms step_avg:87.74ms +step:1158/1680 train_time:101608ms step_avg:87.74ms +step:1159/1680 train_time:101697ms step_avg:87.75ms +step:1160/1680 train_time:101787ms step_avg:87.75ms +step:1161/1680 train_time:101876ms step_avg:87.75ms +step:1162/1680 train_time:101966ms step_avg:87.75ms +step:1163/1680 train_time:102055ms step_avg:87.75ms +step:1164/1680 train_time:102145ms step_avg:87.75ms +step:1165/1680 train_time:102234ms step_avg:87.75ms +step:1166/1680 train_time:102323ms step_avg:87.76ms +step:1167/1680 train_time:102411ms step_avg:87.76ms +step:1168/1680 train_time:102500ms step_avg:87.76ms +step:1169/1680 train_time:102589ms step_avg:87.76ms +step:1170/1680 train_time:102677ms step_avg:87.76ms +step:1171/1680 train_time:102766ms step_avg:87.76ms +step:1172/1680 train_time:102854ms step_avg:87.76ms +step:1173/1680 train_time:102945ms step_avg:87.76ms +step:1174/1680 train_time:103033ms step_avg:87.76ms +step:1175/1680 train_time:103122ms step_avg:87.76ms +step:1176/1680 train_time:103211ms step_avg:87.76ms +step:1177/1680 train_time:103300ms step_avg:87.77ms +step:1178/1680 train_time:103388ms step_avg:87.77ms +step:1179/1680 train_time:103476ms step_avg:87.77ms +step:1180/1680 train_time:103565ms step_avg:87.77ms +step:1181/1680 train_time:103654ms step_avg:87.77ms +step:1182/1680 train_time:103744ms step_avg:87.77ms +step:1183/1680 train_time:103834ms step_avg:87.77ms +step:1184/1680 train_time:103923ms step_avg:87.77ms +step:1185/1680 train_time:104012ms step_avg:87.77ms +step:1186/1680 train_time:104102ms step_avg:87.78ms +step:1187/1680 train_time:104190ms step_avg:87.78ms +step:1188/1680 train_time:104279ms step_avg:87.78ms +step:1189/1680 train_time:104368ms step_avg:87.78ms +step:1190/1680 train_time:104457ms step_avg:87.78ms +step:1191/1680 train_time:104545ms step_avg:87.78ms +step:1192/1680 train_time:104634ms step_avg:87.78ms +step:1193/1680 train_time:104722ms step_avg:87.78ms +step:1194/1680 train_time:104812ms step_avg:87.78ms +step:1195/1680 train_time:104901ms step_avg:87.78ms +step:1196/1680 train_time:104990ms step_avg:87.78ms +step:1197/1680 train_time:105079ms step_avg:87.78ms +step:1198/1680 train_time:105168ms step_avg:87.79ms +step:1199/1680 train_time:105257ms step_avg:87.79ms +step:1200/1680 train_time:105345ms step_avg:87.79ms +step:1201/1680 train_time:105434ms step_avg:87.79ms +step:1202/1680 train_time:105523ms step_avg:87.79ms +step:1203/1680 train_time:105612ms step_avg:87.79ms +step:1204/1680 train_time:105702ms step_avg:87.79ms +step:1205/1680 train_time:105790ms step_avg:87.79ms +step:1206/1680 train_time:105879ms step_avg:87.79ms +step:1207/1680 train_time:105968ms step_avg:87.79ms +step:1208/1680 train_time:106057ms step_avg:87.80ms +step:1209/1680 train_time:106146ms step_avg:87.80ms +step:1210/1680 train_time:106235ms step_avg:87.80ms +step:1211/1680 train_time:106324ms step_avg:87.80ms +step:1212/1680 train_time:106413ms step_avg:87.80ms +step:1213/1680 train_time:106502ms step_avg:87.80ms +step:1214/1680 train_time:106591ms step_avg:87.80ms +step:1215/1680 train_time:106680ms step_avg:87.80ms 
+step:1216/1680 train_time:106769ms step_avg:87.80ms +step:1217/1680 train_time:106858ms step_avg:87.80ms +step:1218/1680 train_time:106948ms step_avg:87.81ms +step:1219/1680 train_time:107036ms step_avg:87.81ms +step:1220/1680 train_time:107125ms step_avg:87.81ms +step:1221/1680 train_time:107213ms step_avg:87.81ms +step:1222/1680 train_time:107302ms step_avg:87.81ms +step:1223/1680 train_time:107391ms step_avg:87.81ms +step:1224/1680 train_time:107480ms step_avg:87.81ms +step:1225/1680 train_time:107569ms step_avg:87.81ms +step:1226/1680 train_time:107658ms step_avg:87.81ms +step:1227/1680 train_time:107747ms step_avg:87.81ms +step:1228/1680 train_time:107836ms step_avg:87.81ms +step:1229/1680 train_time:107925ms step_avg:87.81ms +step:1230/1680 train_time:108013ms step_avg:87.82ms +step:1231/1680 train_time:108102ms step_avg:87.82ms +step:1232/1680 train_time:108190ms step_avg:87.82ms +step:1233/1680 train_time:108280ms step_avg:87.82ms +step:1234/1680 train_time:108368ms step_avg:87.82ms +step:1235/1680 train_time:108457ms step_avg:87.82ms +step:1236/1680 train_time:108546ms step_avg:87.82ms +step:1237/1680 train_time:108635ms step_avg:87.82ms +step:1238/1680 train_time:108725ms step_avg:87.82ms +step:1239/1680 train_time:108814ms step_avg:87.82ms +step:1240/1680 train_time:108903ms step_avg:87.82ms +step:1241/1680 train_time:108991ms step_avg:87.82ms +step:1242/1680 train_time:109079ms step_avg:87.83ms +step:1243/1680 train_time:109168ms step_avg:87.83ms +step:1244/1680 train_time:109257ms step_avg:87.83ms +step:1245/1680 train_time:109345ms step_avg:87.83ms +step:1246/1680 train_time:109434ms step_avg:87.83ms +step:1247/1680 train_time:109523ms step_avg:87.83ms +step:1248/1680 train_time:109612ms step_avg:87.83ms +step:1249/1680 train_time:109701ms step_avg:87.83ms +step:1250/1680 train_time:109790ms step_avg:87.83ms +step:1250/1680 val_loss:3.3795 train_time:109880ms step_avg:87.90ms +step:1251/1680 train_time:109899ms step_avg:87.85ms +step:1252/1680 train_time:109971ms step_avg:87.84ms +step:1253/1680 train_time:110063ms step_avg:87.84ms +step:1254/1680 train_time:110153ms step_avg:87.84ms +step:1255/1680 train_time:110241ms step_avg:87.84ms +step:1256/1680 train_time:110329ms step_avg:87.84ms +step:1257/1680 train_time:110417ms step_avg:87.84ms +step:1258/1680 train_time:110504ms step_avg:87.84ms +step:1259/1680 train_time:110592ms step_avg:87.84ms +step:1260/1680 train_time:110680ms step_avg:87.84ms +step:1261/1680 train_time:110768ms step_avg:87.84ms +step:1262/1680 train_time:110858ms step_avg:87.84ms +step:1263/1680 train_time:110949ms step_avg:87.85ms +step:1264/1680 train_time:111039ms step_avg:87.85ms +step:1265/1680 train_time:111129ms step_avg:87.85ms +step:1266/1680 train_time:111218ms step_avg:87.85ms +step:1267/1680 train_time:111307ms step_avg:87.85ms +step:1268/1680 train_time:111395ms step_avg:87.85ms +step:1269/1680 train_time:111485ms step_avg:87.85ms +step:1270/1680 train_time:111572ms step_avg:87.85ms +step:1271/1680 train_time:111660ms step_avg:87.85ms +step:1272/1680 train_time:111748ms step_avg:87.85ms +step:1273/1680 train_time:111838ms step_avg:87.85ms +step:1274/1680 train_time:111927ms step_avg:87.86ms +step:1275/1680 train_time:112018ms step_avg:87.86ms +step:1276/1680 train_time:112107ms step_avg:87.86ms +step:1277/1680 train_time:112196ms step_avg:87.86ms +step:1278/1680 train_time:112285ms step_avg:87.86ms +step:1279/1680 train_time:112373ms step_avg:87.86ms +step:1280/1680 train_time:112461ms step_avg:87.86ms +step:1281/1680 train_time:112549ms 
step_avg:87.86ms +step:1282/1680 train_time:112638ms step_avg:87.86ms +step:1283/1680 train_time:112728ms step_avg:87.86ms +step:1284/1680 train_time:112817ms step_avg:87.86ms +step:1285/1680 train_time:112906ms step_avg:87.86ms +step:1286/1680 train_time:112996ms step_avg:87.87ms +step:1287/1680 train_time:113086ms step_avg:87.87ms +step:1288/1680 train_time:113175ms step_avg:87.87ms +step:1289/1680 train_time:113264ms step_avg:87.87ms +step:1290/1680 train_time:113353ms step_avg:87.87ms +step:1291/1680 train_time:113442ms step_avg:87.87ms +step:1292/1680 train_time:113530ms step_avg:87.87ms +step:1293/1680 train_time:113618ms step_avg:87.87ms +step:1294/1680 train_time:113707ms step_avg:87.87ms +step:1295/1680 train_time:113796ms step_avg:87.87ms +step:1296/1680 train_time:113885ms step_avg:87.87ms +step:1297/1680 train_time:113975ms step_avg:87.88ms +step:1298/1680 train_time:114066ms step_avg:87.88ms +step:1299/1680 train_time:114155ms step_avg:87.88ms +step:1300/1680 train_time:114243ms step_avg:87.88ms +step:1301/1680 train_time:114332ms step_avg:87.88ms +step:1302/1680 train_time:114421ms step_avg:87.88ms +step:1303/1680 train_time:114510ms step_avg:87.88ms +step:1304/1680 train_time:114599ms step_avg:87.88ms +step:1305/1680 train_time:114689ms step_avg:87.88ms +step:1306/1680 train_time:114777ms step_avg:87.88ms +step:1307/1680 train_time:114866ms step_avg:87.89ms +step:1308/1680 train_time:114955ms step_avg:87.89ms +step:1309/1680 train_time:115045ms step_avg:87.89ms +step:1310/1680 train_time:115134ms step_avg:87.89ms +step:1311/1680 train_time:115224ms step_avg:87.89ms +step:1312/1680 train_time:115314ms step_avg:87.89ms +step:1313/1680 train_time:115403ms step_avg:87.89ms +step:1314/1680 train_time:115492ms step_avg:87.89ms +step:1315/1680 train_time:115581ms step_avg:87.89ms +step:1316/1680 train_time:115670ms step_avg:87.90ms +step:1317/1680 train_time:115759ms step_avg:87.90ms +step:1318/1680 train_time:115847ms step_avg:87.90ms +step:1319/1680 train_time:115936ms step_avg:87.90ms +step:1320/1680 train_time:116025ms step_avg:87.90ms +step:1321/1680 train_time:116114ms step_avg:87.90ms +step:1322/1680 train_time:116204ms step_avg:87.90ms +step:1323/1680 train_time:116294ms step_avg:87.90ms +step:1324/1680 train_time:116382ms step_avg:87.90ms +step:1325/1680 train_time:116471ms step_avg:87.90ms +step:1326/1680 train_time:116560ms step_avg:87.90ms +step:1327/1680 train_time:116649ms step_avg:87.90ms +step:1328/1680 train_time:116737ms step_avg:87.90ms +step:1329/1680 train_time:116826ms step_avg:87.91ms +step:1330/1680 train_time:116915ms step_avg:87.91ms +step:1331/1680 train_time:117004ms step_avg:87.91ms +step:1332/1680 train_time:117094ms step_avg:87.91ms +step:1333/1680 train_time:117183ms step_avg:87.91ms +step:1334/1680 train_time:117272ms step_avg:87.91ms +step:1335/1680 train_time:117360ms step_avg:87.91ms +step:1336/1680 train_time:117450ms step_avg:87.91ms +step:1337/1680 train_time:117539ms step_avg:87.91ms +step:1338/1680 train_time:117628ms step_avg:87.91ms +step:1339/1680 train_time:117716ms step_avg:87.91ms +step:1340/1680 train_time:117806ms step_avg:87.91ms +step:1341/1680 train_time:117894ms step_avg:87.92ms +step:1342/1680 train_time:117982ms step_avg:87.92ms +step:1343/1680 train_time:118071ms step_avg:87.92ms +step:1344/1680 train_time:118161ms step_avg:87.92ms +step:1345/1680 train_time:118250ms step_avg:87.92ms +step:1346/1680 train_time:118338ms step_avg:87.92ms +step:1347/1680 train_time:118428ms step_avg:87.92ms +step:1348/1680 train_time:118517ms 
step_avg:87.92ms +step:1349/1680 train_time:118606ms step_avg:87.92ms +step:1350/1680 train_time:118695ms step_avg:87.92ms +step:1351/1680 train_time:118785ms step_avg:87.92ms +step:1352/1680 train_time:118874ms step_avg:87.92ms +step:1353/1680 train_time:118964ms step_avg:87.93ms +step:1354/1680 train_time:119053ms step_avg:87.93ms +step:1355/1680 train_time:119142ms step_avg:87.93ms +step:1356/1680 train_time:119232ms step_avg:87.93ms +step:1357/1680 train_time:119321ms step_avg:87.93ms +step:1358/1680 train_time:119410ms step_avg:87.93ms +step:1359/1680 train_time:119498ms step_avg:87.93ms +step:1360/1680 train_time:119588ms step_avg:87.93ms +step:1361/1680 train_time:119676ms step_avg:87.93ms +step:1362/1680 train_time:119765ms step_avg:87.93ms +step:1363/1680 train_time:119854ms step_avg:87.93ms +step:1364/1680 train_time:119942ms step_avg:87.93ms +step:1365/1680 train_time:120031ms step_avg:87.93ms +step:1366/1680 train_time:120120ms step_avg:87.94ms +step:1367/1680 train_time:120209ms step_avg:87.94ms +step:1368/1680 train_time:120298ms step_avg:87.94ms +step:1369/1680 train_time:120388ms step_avg:87.94ms +step:1370/1680 train_time:120476ms step_avg:87.94ms +step:1371/1680 train_time:120565ms step_avg:87.94ms +step:1372/1680 train_time:120654ms step_avg:87.94ms +step:1373/1680 train_time:120744ms step_avg:87.94ms +step:1374/1680 train_time:120833ms step_avg:87.94ms +step:1375/1680 train_time:120923ms step_avg:87.94ms +step:1375/1680 val_loss:3.3446 train_time:121013ms step_avg:88.01ms +step:1376/1680 train_time:121032ms step_avg:87.96ms +step:1377/1680 train_time:121105ms step_avg:87.95ms +step:1378/1680 train_time:121196ms step_avg:87.95ms +step:1379/1680 train_time:121285ms step_avg:87.95ms +step:1380/1680 train_time:121373ms step_avg:87.95ms +step:1381/1680 train_time:121461ms step_avg:87.95ms +step:1382/1680 train_time:121548ms step_avg:87.95ms +step:1383/1680 train_time:121636ms step_avg:87.95ms +step:1384/1680 train_time:121724ms step_avg:87.95ms +step:1385/1680 train_time:121812ms step_avg:87.95ms +step:1386/1680 train_time:121901ms step_avg:87.95ms +step:1387/1680 train_time:121990ms step_avg:87.95ms +step:1388/1680 train_time:122081ms step_avg:87.95ms +step:1389/1680 train_time:122172ms step_avg:87.96ms +step:1390/1680 train_time:122261ms step_avg:87.96ms +step:1391/1680 train_time:122350ms step_avg:87.96ms +step:1392/1680 train_time:122439ms step_avg:87.96ms +step:1393/1680 train_time:122527ms step_avg:87.96ms +step:1394/1680 train_time:122615ms step_avg:87.96ms +step:1395/1680 train_time:122703ms step_avg:87.96ms +step:1396/1680 train_time:122790ms step_avg:87.96ms +step:1397/1680 train_time:122878ms step_avg:87.96ms +step:1398/1680 train_time:122968ms step_avg:87.96ms +step:1399/1680 train_time:123058ms step_avg:87.96ms +step:1400/1680 train_time:123148ms step_avg:87.96ms +step:1401/1680 train_time:123238ms step_avg:87.96ms +step:1402/1680 train_time:123327ms step_avg:87.96ms +step:1403/1680 train_time:123417ms step_avg:87.97ms +step:1404/1680 train_time:123506ms step_avg:87.97ms +step:1405/1680 train_time:123594ms step_avg:87.97ms +step:1406/1680 train_time:123682ms step_avg:87.97ms +step:1407/1680 train_time:123770ms step_avg:87.97ms +step:1408/1680 train_time:123859ms step_avg:87.97ms +step:1409/1680 train_time:123948ms step_avg:87.97ms +step:1410/1680 train_time:124037ms step_avg:87.97ms +step:1411/1680 train_time:124126ms step_avg:87.97ms +step:1412/1680 train_time:124216ms step_avg:87.97ms +step:1413/1680 train_time:124306ms step_avg:87.97ms +step:1414/1680 
train_time:124395ms step_avg:87.97ms +step:1415/1680 train_time:124484ms step_avg:87.97ms +step:1416/1680 train_time:124573ms step_avg:87.98ms +step:1417/1680 train_time:124663ms step_avg:87.98ms +step:1418/1680 train_time:124751ms step_avg:87.98ms +step:1419/1680 train_time:124839ms step_avg:87.98ms +step:1420/1680 train_time:124928ms step_avg:87.98ms +step:1421/1680 train_time:125017ms step_avg:87.98ms +step:1422/1680 train_time:125105ms step_avg:87.98ms +step:1423/1680 train_time:125195ms step_avg:87.98ms +step:1424/1680 train_time:125284ms step_avg:87.98ms +step:1425/1680 train_time:125373ms step_avg:87.98ms +step:1426/1680 train_time:125462ms step_avg:87.98ms +step:1427/1680 train_time:125551ms step_avg:87.98ms +step:1428/1680 train_time:125640ms step_avg:87.98ms +step:1429/1680 train_time:125729ms step_avg:87.98ms +step:1430/1680 train_time:125818ms step_avg:87.98ms +step:1431/1680 train_time:125907ms step_avg:87.99ms +step:1432/1680 train_time:125995ms step_avg:87.99ms +step:1433/1680 train_time:126085ms step_avg:87.99ms +step:1434/1680 train_time:126174ms step_avg:87.99ms +step:1435/1680 train_time:126264ms step_avg:87.99ms +step:1436/1680 train_time:126354ms step_avg:87.99ms +step:1437/1680 train_time:126444ms step_avg:87.99ms +step:1438/1680 train_time:126534ms step_avg:87.99ms +step:1439/1680 train_time:126623ms step_avg:87.99ms +step:1440/1680 train_time:126712ms step_avg:87.99ms +step:1441/1680 train_time:126800ms step_avg:87.99ms +step:1442/1680 train_time:126889ms step_avg:88.00ms +step:1443/1680 train_time:126977ms step_avg:88.00ms +step:1444/1680 train_time:127067ms step_avg:88.00ms +step:1445/1680 train_time:127156ms step_avg:88.00ms +step:1446/1680 train_time:127246ms step_avg:88.00ms +step:1447/1680 train_time:127335ms step_avg:88.00ms +step:1448/1680 train_time:127424ms step_avg:88.00ms +step:1449/1680 train_time:127514ms step_avg:88.00ms +step:1450/1680 train_time:127603ms step_avg:88.00ms +step:1451/1680 train_time:127692ms step_avg:88.00ms +step:1452/1680 train_time:127780ms step_avg:88.00ms +step:1453/1680 train_time:127869ms step_avg:88.00ms +step:1454/1680 train_time:127958ms step_avg:88.00ms +step:1455/1680 train_time:128046ms step_avg:88.00ms +step:1456/1680 train_time:128136ms step_avg:88.01ms +step:1457/1680 train_time:128225ms step_avg:88.01ms +step:1458/1680 train_time:128314ms step_avg:88.01ms +step:1459/1680 train_time:128404ms step_avg:88.01ms +step:1460/1680 train_time:128493ms step_avg:88.01ms +step:1461/1680 train_time:128582ms step_avg:88.01ms +step:1462/1680 train_time:128671ms step_avg:88.01ms +step:1463/1680 train_time:128759ms step_avg:88.01ms +step:1464/1680 train_time:128848ms step_avg:88.01ms +step:1465/1680 train_time:128936ms step_avg:88.01ms +step:1466/1680 train_time:129026ms step_avg:88.01ms +step:1467/1680 train_time:129114ms step_avg:88.01ms +step:1468/1680 train_time:129203ms step_avg:88.01ms +step:1469/1680 train_time:129292ms step_avg:88.01ms +step:1470/1680 train_time:129381ms step_avg:88.01ms +step:1471/1680 train_time:129470ms step_avg:88.01ms +step:1472/1680 train_time:129559ms step_avg:88.02ms +step:1473/1680 train_time:129648ms step_avg:88.02ms +step:1474/1680 train_time:129737ms step_avg:88.02ms +step:1475/1680 train_time:129826ms step_avg:88.02ms +step:1476/1680 train_time:129915ms step_avg:88.02ms +step:1477/1680 train_time:130003ms step_avg:88.02ms +step:1478/1680 train_time:130093ms step_avg:88.02ms +step:1479/1680 train_time:130182ms step_avg:88.02ms +step:1480/1680 train_time:130272ms step_avg:88.02ms +step:1481/1680 
train_time:130361ms step_avg:88.02ms +step:1482/1680 train_time:130451ms step_avg:88.02ms +step:1483/1680 train_time:130540ms step_avg:88.02ms +step:1484/1680 train_time:130629ms step_avg:88.02ms +step:1485/1680 train_time:130718ms step_avg:88.03ms +step:1486/1680 train_time:130807ms step_avg:88.03ms +step:1487/1680 train_time:130896ms step_avg:88.03ms +step:1488/1680 train_time:130984ms step_avg:88.03ms +step:1489/1680 train_time:131073ms step_avg:88.03ms +step:1490/1680 train_time:131162ms step_avg:88.03ms +step:1491/1680 train_time:131252ms step_avg:88.03ms +step:1492/1680 train_time:131341ms step_avg:88.03ms +step:1493/1680 train_time:131430ms step_avg:88.03ms +step:1494/1680 train_time:131519ms step_avg:88.03ms +step:1495/1680 train_time:131608ms step_avg:88.03ms +step:1496/1680 train_time:131697ms step_avg:88.03ms +step:1497/1680 train_time:131786ms step_avg:88.03ms +step:1498/1680 train_time:131875ms step_avg:88.03ms +step:1499/1680 train_time:131963ms step_avg:88.03ms +step:1500/1680 train_time:132052ms step_avg:88.03ms +step:1500/1680 val_loss:3.3147 train_time:132142ms step_avg:88.09ms +step:1501/1680 train_time:132162ms step_avg:88.05ms +step:1502/1680 train_time:132235ms step_avg:88.04ms +step:1503/1680 train_time:132326ms step_avg:88.04ms +step:1504/1680 train_time:132417ms step_avg:88.04ms +step:1505/1680 train_time:132505ms step_avg:88.04ms +step:1506/1680 train_time:132593ms step_avg:88.04ms +step:1507/1680 train_time:132681ms step_avg:88.04ms +step:1508/1680 train_time:132769ms step_avg:88.04ms +step:1509/1680 train_time:132856ms step_avg:88.04ms +step:1510/1680 train_time:132944ms step_avg:88.04ms +step:1511/1680 train_time:133033ms step_avg:88.04ms +step:1512/1680 train_time:133123ms step_avg:88.04ms +step:1513/1680 train_time:133214ms step_avg:88.05ms +step:1514/1680 train_time:133305ms step_avg:88.05ms +step:1515/1680 train_time:133395ms step_avg:88.05ms +step:1516/1680 train_time:133485ms step_avg:88.05ms +step:1517/1680 train_time:133574ms step_avg:88.05ms +step:1518/1680 train_time:133662ms step_avg:88.05ms +step:1519/1680 train_time:133750ms step_avg:88.05ms +step:1520/1680 train_time:133838ms step_avg:88.05ms +step:1521/1680 train_time:133927ms step_avg:88.05ms +step:1522/1680 train_time:134015ms step_avg:88.05ms +step:1523/1680 train_time:134105ms step_avg:88.05ms +step:1524/1680 train_time:134194ms step_avg:88.05ms +step:1525/1680 train_time:134284ms step_avg:88.06ms +step:1526/1680 train_time:134375ms step_avg:88.06ms +step:1527/1680 train_time:134464ms step_avg:88.06ms +step:1528/1680 train_time:134553ms step_avg:88.06ms +step:1529/1680 train_time:134641ms step_avg:88.06ms +step:1530/1680 train_time:134730ms step_avg:88.06ms +step:1531/1680 train_time:134819ms step_avg:88.06ms +step:1532/1680 train_time:134908ms step_avg:88.06ms +step:1533/1680 train_time:134996ms step_avg:88.06ms +step:1534/1680 train_time:135084ms step_avg:88.06ms +step:1535/1680 train_time:135174ms step_avg:88.06ms +step:1536/1680 train_time:135263ms step_avg:88.06ms +step:1537/1680 train_time:135353ms step_avg:88.06ms +step:1538/1680 train_time:135443ms step_avg:88.06ms +step:1539/1680 train_time:135532ms step_avg:88.06ms +step:1540/1680 train_time:135621ms step_avg:88.07ms +step:1541/1680 train_time:135710ms step_avg:88.07ms +step:1542/1680 train_time:135798ms step_avg:88.07ms +step:1543/1680 train_time:135887ms step_avg:88.07ms +step:1544/1680 train_time:135976ms step_avg:88.07ms +step:1545/1680 train_time:136064ms step_avg:88.07ms +step:1546/1680 train_time:136153ms step_avg:88.07ms 
+step:1547/1680 train_time:136243ms step_avg:88.07ms +step:1548/1680 train_time:136332ms step_avg:88.07ms +step:1549/1680 train_time:136421ms step_avg:88.07ms +step:1550/1680 train_time:136511ms step_avg:88.07ms +step:1551/1680 train_time:136599ms step_avg:88.07ms +step:1552/1680 train_time:136688ms step_avg:88.07ms +step:1553/1680 train_time:136777ms step_avg:88.07ms +step:1554/1680 train_time:136865ms step_avg:88.07ms +step:1555/1680 train_time:136954ms step_avg:88.07ms +step:1556/1680 train_time:137043ms step_avg:88.07ms +step:1557/1680 train_time:137133ms step_avg:88.07ms +step:1558/1680 train_time:137222ms step_avg:88.08ms +step:1559/1680 train_time:137311ms step_avg:88.08ms +step:1560/1680 train_time:137401ms step_avg:88.08ms +step:1561/1680 train_time:137493ms step_avg:88.08ms +step:1562/1680 train_time:137582ms step_avg:88.08ms +step:1563/1680 train_time:137670ms step_avg:88.08ms +step:1564/1680 train_time:137759ms step_avg:88.08ms +step:1565/1680 train_time:137848ms step_avg:88.08ms +step:1566/1680 train_time:137936ms step_avg:88.08ms +step:1567/1680 train_time:138025ms step_avg:88.08ms +step:1568/1680 train_time:138114ms step_avg:88.08ms +step:1569/1680 train_time:138203ms step_avg:88.08ms +step:1570/1680 train_time:138292ms step_avg:88.08ms +step:1571/1680 train_time:138381ms step_avg:88.08ms +step:1572/1680 train_time:138470ms step_avg:88.09ms +step:1573/1680 train_time:138559ms step_avg:88.09ms +step:1574/1680 train_time:138648ms step_avg:88.09ms +step:1575/1680 train_time:138737ms step_avg:88.09ms +step:1576/1680 train_time:138826ms step_avg:88.09ms +step:1577/1680 train_time:138914ms step_avg:88.09ms +step:1578/1680 train_time:139003ms step_avg:88.09ms +step:1579/1680 train_time:139092ms step_avg:88.09ms +step:1580/1680 train_time:139181ms step_avg:88.09ms +step:1581/1680 train_time:139270ms step_avg:88.09ms +step:1582/1680 train_time:139360ms step_avg:88.09ms +step:1583/1680 train_time:139450ms step_avg:88.09ms +step:1584/1680 train_time:139539ms step_avg:88.09ms +step:1585/1680 train_time:139627ms step_avg:88.09ms +step:1586/1680 train_time:139717ms step_avg:88.09ms +step:1587/1680 train_time:139805ms step_avg:88.09ms +step:1588/1680 train_time:139894ms step_avg:88.09ms +step:1589/1680 train_time:139984ms step_avg:88.10ms +step:1590/1680 train_time:140073ms step_avg:88.10ms +step:1591/1680 train_time:140162ms step_avg:88.10ms +step:1592/1680 train_time:140250ms step_avg:88.10ms +step:1593/1680 train_time:140339ms step_avg:88.10ms +step:1594/1680 train_time:140429ms step_avg:88.10ms +step:1595/1680 train_time:140518ms step_avg:88.10ms +step:1596/1680 train_time:140607ms step_avg:88.10ms +step:1597/1680 train_time:140695ms step_avg:88.10ms +step:1598/1680 train_time:140784ms step_avg:88.10ms +step:1599/1680 train_time:140873ms step_avg:88.10ms +step:1600/1680 train_time:140961ms step_avg:88.10ms +step:1601/1680 train_time:141051ms step_avg:88.10ms +step:1602/1680 train_time:141140ms step_avg:88.10ms +step:1603/1680 train_time:141228ms step_avg:88.10ms +step:1604/1680 train_time:141318ms step_avg:88.10ms +step:1605/1680 train_time:141407ms step_avg:88.10ms +step:1606/1680 train_time:141496ms step_avg:88.10ms +step:1607/1680 train_time:141585ms step_avg:88.11ms +step:1608/1680 train_time:141674ms step_avg:88.11ms +step:1609/1680 train_time:141763ms step_avg:88.11ms +step:1610/1680 train_time:141852ms step_avg:88.11ms +step:1611/1680 train_time:141941ms step_avg:88.11ms +step:1612/1680 train_time:142030ms step_avg:88.11ms +step:1613/1680 train_time:142120ms step_avg:88.11ms 
+step:1614/1680 train_time:142209ms step_avg:88.11ms +step:1615/1680 train_time:142297ms step_avg:88.11ms +step:1616/1680 train_time:142386ms step_avg:88.11ms +step:1617/1680 train_time:142475ms step_avg:88.11ms +step:1618/1680 train_time:142564ms step_avg:88.11ms +step:1619/1680 train_time:142653ms step_avg:88.11ms +step:1620/1680 train_time:142742ms step_avg:88.11ms +step:1621/1680 train_time:142832ms step_avg:88.11ms +step:1622/1680 train_time:142921ms step_avg:88.11ms +step:1623/1680 train_time:143010ms step_avg:88.11ms +step:1624/1680 train_time:143100ms step_avg:88.12ms +step:1625/1680 train_time:143189ms step_avg:88.12ms +step:1625/1680 val_loss:3.2910 train_time:143279ms step_avg:88.17ms +step:1626/1680 train_time:143299ms step_avg:88.13ms +step:1627/1680 train_time:143370ms step_avg:88.12ms +step:1628/1680 train_time:143464ms step_avg:88.12ms +step:1629/1680 train_time:143555ms step_avg:88.12ms +step:1630/1680 train_time:143643ms step_avg:88.12ms +step:1631/1680 train_time:143732ms step_avg:88.13ms +step:1632/1680 train_time:143820ms step_avg:88.13ms +step:1633/1680 train_time:143908ms step_avg:88.12ms +step:1634/1680 train_time:143996ms step_avg:88.12ms +step:1635/1680 train_time:144084ms step_avg:88.12ms +step:1636/1680 train_time:144171ms step_avg:88.12ms +step:1637/1680 train_time:144262ms step_avg:88.13ms +step:1638/1680 train_time:144353ms step_avg:88.13ms +step:1639/1680 train_time:144444ms step_avg:88.13ms +step:1640/1680 train_time:144534ms step_avg:88.13ms +step:1641/1680 train_time:144624ms step_avg:88.13ms +step:1642/1680 train_time:144714ms step_avg:88.13ms +step:1643/1680 train_time:144803ms step_avg:88.13ms +step:1644/1680 train_time:144891ms step_avg:88.13ms +step:1645/1680 train_time:144980ms step_avg:88.13ms +step:1646/1680 train_time:145068ms step_avg:88.13ms +step:1647/1680 train_time:145156ms step_avg:88.13ms +step:1648/1680 train_time:145246ms step_avg:88.13ms +step:1649/1680 train_time:145336ms step_avg:88.14ms +step:1650/1680 train_time:145426ms step_avg:88.14ms +step:1651/1680 train_time:145515ms step_avg:88.14ms +step:1652/1680 train_time:145605ms step_avg:88.14ms +step:1653/1680 train_time:145694ms step_avg:88.14ms +step:1654/1680 train_time:145783ms step_avg:88.14ms +step:1655/1680 train_time:145871ms step_avg:88.14ms +step:1656/1680 train_time:145960ms step_avg:88.14ms +step:1657/1680 train_time:146048ms step_avg:88.14ms +step:1658/1680 train_time:146136ms step_avg:88.14ms +step:1659/1680 train_time:146225ms step_avg:88.14ms +step:1660/1680 train_time:146315ms step_avg:88.14ms +step:1661/1680 train_time:146404ms step_avg:88.14ms +step:1662/1680 train_time:146494ms step_avg:88.14ms +step:1663/1680 train_time:146583ms step_avg:88.14ms +step:1664/1680 train_time:146673ms step_avg:88.14ms +step:1665/1680 train_time:146762ms step_avg:88.15ms +step:1666/1680 train_time:146851ms step_avg:88.15ms +step:1667/1680 train_time:146939ms step_avg:88.15ms +step:1668/1680 train_time:147027ms step_avg:88.15ms +step:1669/1680 train_time:147116ms step_avg:88.15ms +step:1670/1680 train_time:147205ms step_avg:88.15ms +step:1671/1680 train_time:147294ms step_avg:88.15ms +step:1672/1680 train_time:147383ms step_avg:88.15ms +step:1673/1680 train_time:147473ms step_avg:88.15ms +step:1674/1680 train_time:147563ms step_avg:88.15ms +step:1675/1680 train_time:147652ms step_avg:88.15ms +step:1676/1680 train_time:147743ms step_avg:88.15ms +step:1677/1680 train_time:147832ms step_avg:88.15ms +step:1678/1680 train_time:147922ms step_avg:88.15ms +step:1679/1680 train_time:148011ms 
step_avg:88.15ms +step:1680/1680 train_time:148100ms step_avg:88.15ms +step:1680/1680 val_loss:3.2799 train_time:148190ms step_avg:88.21ms +peak memory allocated: 30760 MiB reserved: 45934 MiB diff --git a/records/092725_BF16CE/11dd9171-d060-4279-a6e5-5ba91fb7758e.txt b/records/092725_BF16CE/11dd9171-d060-4279-a6e5-5ba91fb7758e.txt new file mode 100644 index 000000000..bd49077fe --- /dev/null +++ b/records/092725_BF16CE/11dd9171-d060-4279-a6e5-5ba91fb7758e.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
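+ # Each tile that survives the diagonal skip above is written twice: once in place and once mirrored across the diagonal, so C = A @ A.T is computed with roughly half the tile matmuls.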
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+ Empirically, though, small 1D params perform well here: + NS approximately performs a magnitude normalization of the grad, and + this hyper-optimized class has faster execution time than the current impl of Adam for small params. + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized so that + (n_attn_layers + n_mlp_layers*2) % 4 == 0, which enables batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param, 7 padding params) + 2. reduce scatter attn_gate (10 params, 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params, 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive attn params reshape them before NS + 8. wait on step 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size()==8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for the resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1,10,16,16] + assert len(params)==sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx+size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
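+ # First pass: per param group, stack the grads (zero-padding the stack to a multiple of world_size) and launch an async reduce_scatter_tensor, so each rank receives the averaged gradients for only its chunk while communication for later groups overlaps with the compute below.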
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
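+ # Weight decay and the orthogonalized update are applied to this buffer rather than to p itself; p is only overwritten in the final pass, once the all_gather of updated chunks completes.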
+ updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + original_shape = batched_update_grads.shape + # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O + module_idx = start_idx if start_idx < len(params) else 0 + if params[module_idx].module == 'attn': + batched_update_grads = batched_update_grads.view(original_shape[0] * 4, original_shape[1], original_shape[2] // 4) + v_chunk = newton_schulz_triton(batched_update_grads) + v_chunk = v_chunk.view(original_shape) + + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = 
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module='attn' + with torch.no_grad(): + self.qkvo_w.view(4,self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4,self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4,self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4,self.hdim, self.dim)[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module='mlp' + self.c_proj.module='mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
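+ # x_s, w_s and grad_s below are the static scale factors consumed by the nanogpt::mm FP8 op; 448 is the largest normal value of float8_e4m3fn.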
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the 
data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents, by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan the first 4 million tokens, then kick off an async thread to scan the rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full BOS index after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading the next shard and indexing BOS tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with the beginning-of-sequence (BOS) token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted; switch to the preloaded next shard and retry.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # the last sequence is one token too long; trim it to account for the _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40 # number of iterations to continue training at the final cooldown and window size
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss?
0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = min(0.9999, step / args.num_iterations)
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+# window-size schedule: returns the (short, long) block-mask window sizes for a given step
+def get_ws(step: int):
+    if step == args.num_iterations + args.iteration_extension:
+        return args.ws_validate // 2, args.ws_validate
+    x = min(step / (1 + args.num_iterations), 0.9999)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws_long = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws_long > ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    elif new_ws_long < ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+# restore the initial model and optimizer state now that the kernels are warm
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del initial_state
+
+########################################
+#        Training and validation       #
+########################################
+
+train_steps = args.num_iterations + args.iteration_extension
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    ws_short, new_ws_long = get_ws(step)
+    if new_ws_long != ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        if last_step:
+            ws_long = args.ws_long_validate
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = torch.zeros((), device=device, dtype=torch.float32)
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Sep 27 13:28:19 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 27C P0 121W / 700W | 5856MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 25C P0 118W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 22C P0 115W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 27C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 27C P0 119W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 25C P0 114W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 28C P0 119W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 24C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 174977 C /usr/bin/python 0MiB | +| 0 N/A N/A 174978 C /usr/bin/python 0MiB | +| 0 N/A N/A 174979 C /usr/bin/python 0MiB | +| 0 N/A N/A 174980 C /usr/bin/python 0MiB | +| 0 N/A N/A 174981 C /usr/bin/python 0MiB | +| 0 N/A N/A 174982 C /usr/bin/python 0MiB | +| 0 N/A N/A 174983 C /usr/bin/python 0MiB | +| 0 N/A N/A 174984 C /usr/bin/python 0MiB | +| 1 N/A N/A 174978 C /usr/bin/python 0MiB | +| 2 N/A N/A 174979 C /usr/bin/python 0MiB | +| 3 N/A N/A 174980 C /usr/bin/python 0MiB | +| 4 N/A N/A 174981 C /usr/bin/python 0MiB | +| 5 N/A N/A 174982 C /usr/bin/python 0MiB | +| 6 N/A N/A 174983 C /usr/bin/python 0MiB | +| 7 N/A N/A 174984 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1680 train_time:141ms step_avg:140.61ms +step:2/1680 train_time:160ms step_avg:79.95ms +step:3/1680 train_time:224ms step_avg:74.73ms +step:4/1680 train_time:309ms step_avg:77.26ms +step:5/1680 train_time:396ms step_avg:79.14ms +step:6/1680 train_time:482ms step_avg:80.31ms +step:7/1680 train_time:568ms step_avg:81.13ms +step:8/1680 train_time:654ms step_avg:81.73ms +step:9/1680 train_time:740ms step_avg:82.18ms +step:10/1680 train_time:826ms step_avg:82.65ms +step:11/1680 train_time:913ms step_avg:82.96ms +step:12/1680 train_time:1001ms step_avg:83.42ms +step:13/1680 train_time:1093ms step_avg:84.04ms +step:14/1680 train_time:1183ms step_avg:84.51ms +step:15/1680 train_time:1271ms step_avg:84.71ms +step:16/1680 train_time:1359ms step_avg:84.93ms +step:17/1680 train_time:1445ms step_avg:85.02ms +step:18/1680 train_time:1532ms step_avg:85.12ms +step:19/1680 train_time:1619ms step_avg:85.20ms +step:20/1680 train_time:1705ms step_avg:85.25ms +step:21/1680 train_time:1791ms step_avg:85.30ms +step:22/1680 train_time:1877ms step_avg:85.34ms +step:23/1680 train_time:1965ms step_avg:85.44ms +step:24/1680 train_time:2053ms step_avg:85.54ms +step:25/1680 train_time:2142ms step_avg:85.68ms +step:26/1680 train_time:2232ms step_avg:85.83ms +step:27/1680 train_time:2319ms step_avg:85.90ms +step:28/1680 train_time:2406ms step_avg:85.93ms +step:29/1680 train_time:2494ms step_avg:85.99ms +step:30/1680 train_time:2581ms step_avg:86.04ms +step:31/1680 train_time:2667ms step_avg:86.04ms +step:32/1680 train_time:2754ms step_avg:86.07ms +step:33/1680 train_time:2841ms step_avg:86.09ms +step:34/1680 train_time:2929ms step_avg:86.14ms +step:35/1680 train_time:3016ms step_avg:86.18ms +step:36/1680 train_time:3104ms step_avg:86.22ms +step:37/1680 train_time:3192ms step_avg:86.26ms +step:38/1680 train_time:3280ms step_avg:86.31ms +step:39/1680 train_time:3367ms step_avg:86.34ms +step:40/1680 train_time:3454ms step_avg:86.35ms +step:41/1680 train_time:3541ms step_avg:86.37ms +step:42/1680 train_time:3628ms step_avg:86.38ms +step:43/1680 train_time:3715ms step_avg:86.39ms +step:44/1680 train_time:3801ms step_avg:86.39ms +step:45/1680 train_time:3888ms step_avg:86.41ms +step:46/1680 train_time:3976ms step_avg:86.43ms +step:47/1680 
train_time:4062ms step_avg:86.43ms +step:48/1680 train_time:4150ms step_avg:86.45ms +step:49/1680 train_time:4238ms step_avg:86.49ms +step:50/1680 train_time:4327ms step_avg:86.54ms +step:51/1680 train_time:4415ms step_avg:86.56ms +step:52/1680 train_time:4502ms step_avg:86.57ms +step:53/1680 train_time:4589ms step_avg:86.58ms +step:54/1680 train_time:4675ms step_avg:86.58ms +step:55/1680 train_time:4762ms step_avg:86.59ms +step:56/1680 train_time:4849ms step_avg:86.59ms +step:57/1680 train_time:4936ms step_avg:86.59ms +step:58/1680 train_time:5023ms step_avg:86.60ms +step:59/1680 train_time:5111ms step_avg:86.63ms +step:60/1680 train_time:5199ms step_avg:86.64ms +step:61/1680 train_time:5288ms step_avg:86.69ms +step:62/1680 train_time:5376ms step_avg:86.70ms +step:63/1680 train_time:5463ms step_avg:86.71ms +step:64/1680 train_time:5550ms step_avg:86.72ms +step:65/1680 train_time:5637ms step_avg:86.72ms +step:66/1680 train_time:5724ms step_avg:86.73ms +step:67/1680 train_time:5811ms step_avg:86.73ms +step:68/1680 train_time:5898ms step_avg:86.74ms +step:69/1680 train_time:5986ms step_avg:86.75ms +step:70/1680 train_time:6073ms step_avg:86.75ms +step:71/1680 train_time:6160ms step_avg:86.76ms +step:72/1680 train_time:6248ms step_avg:86.78ms +step:73/1680 train_time:6336ms step_avg:86.79ms +step:74/1680 train_time:6422ms step_avg:86.79ms +step:75/1680 train_time:6510ms step_avg:86.80ms +step:76/1680 train_time:6597ms step_avg:86.80ms +step:77/1680 train_time:6684ms step_avg:86.81ms +step:78/1680 train_time:6771ms step_avg:86.80ms +step:79/1680 train_time:6858ms step_avg:86.81ms +step:80/1680 train_time:6945ms step_avg:86.82ms +step:81/1680 train_time:7033ms step_avg:86.82ms +step:82/1680 train_time:7120ms step_avg:86.83ms +step:83/1680 train_time:7208ms step_avg:86.84ms +step:84/1680 train_time:7295ms step_avg:86.84ms +step:85/1680 train_time:7382ms step_avg:86.85ms +step:86/1680 train_time:7469ms step_avg:86.85ms +step:87/1680 train_time:7556ms step_avg:86.85ms +step:88/1680 train_time:7643ms step_avg:86.85ms +step:89/1680 train_time:7731ms step_avg:86.86ms +step:90/1680 train_time:7817ms step_avg:86.86ms +step:91/1680 train_time:7905ms step_avg:86.86ms +step:92/1680 train_time:7992ms step_avg:86.87ms +step:93/1680 train_time:8079ms step_avg:86.87ms +step:94/1680 train_time:8166ms step_avg:86.88ms +step:95/1680 train_time:8253ms step_avg:86.88ms +step:96/1680 train_time:8341ms step_avg:86.88ms +step:97/1680 train_time:8428ms step_avg:86.89ms +step:98/1680 train_time:8515ms step_avg:86.89ms +step:99/1680 train_time:8602ms step_avg:86.89ms +step:100/1680 train_time:8689ms step_avg:86.89ms +step:101/1680 train_time:8777ms step_avg:86.90ms +step:102/1680 train_time:8864ms step_avg:86.90ms +step:103/1680 train_time:8951ms step_avg:86.90ms +step:104/1680 train_time:9038ms step_avg:86.90ms +step:105/1680 train_time:9126ms step_avg:86.91ms +step:106/1680 train_time:9212ms step_avg:86.91ms +step:107/1680 train_time:9300ms step_avg:86.91ms +step:108/1680 train_time:9388ms step_avg:86.92ms +step:109/1680 train_time:9475ms step_avg:86.92ms +step:110/1680 train_time:9562ms step_avg:86.93ms +step:111/1680 train_time:9649ms step_avg:86.93ms +step:112/1680 train_time:9736ms step_avg:86.92ms +step:113/1680 train_time:9822ms step_avg:86.92ms +step:114/1680 train_time:9910ms step_avg:86.93ms +step:115/1680 train_time:9997ms step_avg:86.93ms +step:116/1680 train_time:10084ms step_avg:86.93ms +step:117/1680 train_time:10171ms step_avg:86.93ms +step:118/1680 train_time:10258ms step_avg:86.93ms +step:119/1680 
train_time:10346ms step_avg:86.94ms +step:120/1680 train_time:10433ms step_avg:86.94ms +step:121/1680 train_time:10521ms step_avg:86.95ms +step:122/1680 train_time:10609ms step_avg:86.96ms +step:123/1680 train_time:10695ms step_avg:86.95ms +step:124/1680 train_time:10783ms step_avg:86.96ms +step:125/1680 train_time:10869ms step_avg:86.96ms +step:125/1680 val_loss:4.2915 train_time:10958ms step_avg:87.66ms +step:126/1680 train_time:10977ms step_avg:87.12ms +step:127/1680 train_time:11046ms step_avg:86.98ms +step:128/1680 train_time:11144ms step_avg:87.06ms +step:129/1680 train_time:11238ms step_avg:87.12ms +step:130/1680 train_time:11326ms step_avg:87.12ms +step:131/1680 train_time:11413ms step_avg:87.12ms +step:132/1680 train_time:11499ms step_avg:87.12ms +step:133/1680 train_time:11585ms step_avg:87.11ms +step:134/1680 train_time:11671ms step_avg:87.10ms +step:135/1680 train_time:11757ms step_avg:87.09ms +step:136/1680 train_time:11843ms step_avg:87.08ms +step:137/1680 train_time:11929ms step_avg:87.07ms +step:138/1680 train_time:12015ms step_avg:87.07ms +step:139/1680 train_time:12104ms step_avg:87.08ms +step:140/1680 train_time:12194ms step_avg:87.10ms +step:141/1680 train_time:12282ms step_avg:87.10ms +step:142/1680 train_time:12369ms step_avg:87.11ms +step:143/1680 train_time:12456ms step_avg:87.11ms +step:144/1680 train_time:12543ms step_avg:87.10ms +step:145/1680 train_time:12629ms step_avg:87.10ms +step:146/1680 train_time:12716ms step_avg:87.10ms +step:147/1680 train_time:12802ms step_avg:87.09ms +step:148/1680 train_time:12888ms step_avg:87.08ms +step:149/1680 train_time:12976ms step_avg:87.08ms +step:150/1680 train_time:13063ms step_avg:87.09ms +step:151/1680 train_time:13152ms step_avg:87.10ms +step:152/1680 train_time:13240ms step_avg:87.10ms +step:153/1680 train_time:13327ms step_avg:87.10ms +step:154/1680 train_time:13414ms step_avg:87.11ms +step:155/1680 train_time:13501ms step_avg:87.10ms +step:156/1680 train_time:13587ms step_avg:87.10ms +step:157/1680 train_time:13674ms step_avg:87.10ms +step:158/1680 train_time:13761ms step_avg:87.09ms +step:159/1680 train_time:13847ms step_avg:87.09ms +step:160/1680 train_time:13935ms step_avg:87.09ms +step:161/1680 train_time:14021ms step_avg:87.09ms +step:162/1680 train_time:14109ms step_avg:87.09ms +step:163/1680 train_time:14196ms step_avg:87.09ms +step:164/1680 train_time:14283ms step_avg:87.09ms +step:165/1680 train_time:14371ms step_avg:87.09ms +step:166/1680 train_time:14458ms step_avg:87.10ms +step:167/1680 train_time:14545ms step_avg:87.10ms +step:168/1680 train_time:14633ms step_avg:87.10ms +step:169/1680 train_time:14719ms step_avg:87.09ms +step:170/1680 train_time:14805ms step_avg:87.09ms +step:171/1680 train_time:14892ms step_avg:87.09ms +step:172/1680 train_time:14979ms step_avg:87.09ms +step:173/1680 train_time:15067ms step_avg:87.09ms +step:174/1680 train_time:15154ms step_avg:87.09ms +step:175/1680 train_time:15242ms step_avg:87.09ms +step:176/1680 train_time:15330ms step_avg:87.10ms +step:177/1680 train_time:15417ms step_avg:87.10ms +step:178/1680 train_time:15504ms step_avg:87.10ms +step:179/1680 train_time:15593ms step_avg:87.11ms +step:180/1680 train_time:15680ms step_avg:87.11ms +step:181/1680 train_time:15767ms step_avg:87.11ms +step:182/1680 train_time:15853ms step_avg:87.11ms +step:183/1680 train_time:15941ms step_avg:87.11ms +step:184/1680 train_time:16028ms step_avg:87.11ms +step:185/1680 train_time:16115ms step_avg:87.11ms +step:186/1680 train_time:16203ms step_avg:87.11ms +step:187/1680 train_time:16290ms 
step_avg:87.11ms +step:188/1680 train_time:16377ms step_avg:87.11ms +step:189/1680 train_time:16464ms step_avg:87.11ms +step:190/1680 train_time:16552ms step_avg:87.11ms +step:191/1680 train_time:16639ms step_avg:87.12ms +step:192/1680 train_time:16726ms step_avg:87.11ms +step:193/1680 train_time:16813ms step_avg:87.11ms +step:194/1680 train_time:16900ms step_avg:87.11ms +step:195/1680 train_time:16986ms step_avg:87.11ms +step:196/1680 train_time:17074ms step_avg:87.11ms +step:197/1680 train_time:17161ms step_avg:87.11ms +step:198/1680 train_time:17249ms step_avg:87.12ms +step:199/1680 train_time:17337ms step_avg:87.12ms +step:200/1680 train_time:17424ms step_avg:87.12ms +step:201/1680 train_time:17511ms step_avg:87.12ms +step:202/1680 train_time:17598ms step_avg:87.12ms +step:203/1680 train_time:17685ms step_avg:87.12ms +step:204/1680 train_time:17771ms step_avg:87.11ms +step:205/1680 train_time:17859ms step_avg:87.12ms +step:206/1680 train_time:17945ms step_avg:87.11ms +step:207/1680 train_time:18032ms step_avg:87.11ms +step:208/1680 train_time:18119ms step_avg:87.11ms +step:209/1680 train_time:18207ms step_avg:87.11ms +step:210/1680 train_time:18294ms step_avg:87.12ms +step:211/1680 train_time:18381ms step_avg:87.12ms +step:212/1680 train_time:18469ms step_avg:87.12ms +step:213/1680 train_time:18556ms step_avg:87.12ms +step:214/1680 train_time:18643ms step_avg:87.12ms +step:215/1680 train_time:18730ms step_avg:87.12ms +step:216/1680 train_time:18817ms step_avg:87.12ms +step:217/1680 train_time:18904ms step_avg:87.12ms +step:218/1680 train_time:18991ms step_avg:87.11ms +step:219/1680 train_time:19077ms step_avg:87.11ms +step:220/1680 train_time:19165ms step_avg:87.11ms +step:221/1680 train_time:19252ms step_avg:87.11ms +step:222/1680 train_time:19339ms step_avg:87.11ms +step:223/1680 train_time:19426ms step_avg:87.11ms +step:224/1680 train_time:19514ms step_avg:87.12ms +step:225/1680 train_time:19601ms step_avg:87.11ms +step:226/1680 train_time:19687ms step_avg:87.11ms +step:227/1680 train_time:19776ms step_avg:87.12ms +step:228/1680 train_time:19862ms step_avg:87.12ms +step:229/1680 train_time:19950ms step_avg:87.12ms +step:230/1680 train_time:20037ms step_avg:87.12ms +step:231/1680 train_time:20123ms step_avg:87.11ms +step:232/1680 train_time:20211ms step_avg:87.11ms +step:233/1680 train_time:20298ms step_avg:87.12ms +step:234/1680 train_time:20386ms step_avg:87.12ms +step:235/1680 train_time:20473ms step_avg:87.12ms +step:236/1680 train_time:20560ms step_avg:87.12ms +step:237/1680 train_time:20648ms step_avg:87.12ms +step:238/1680 train_time:20735ms step_avg:87.12ms +step:239/1680 train_time:20822ms step_avg:87.12ms +step:240/1680 train_time:20909ms step_avg:87.12ms +step:241/1680 train_time:20996ms step_avg:87.12ms +step:242/1680 train_time:21083ms step_avg:87.12ms +step:243/1680 train_time:21170ms step_avg:87.12ms +step:244/1680 train_time:21258ms step_avg:87.12ms +step:245/1680 train_time:21345ms step_avg:87.12ms +step:246/1680 train_time:21433ms step_avg:87.12ms +step:247/1680 train_time:21520ms step_avg:87.12ms +step:248/1680 train_time:21607ms step_avg:87.13ms +step:249/1680 train_time:21694ms step_avg:87.13ms +step:250/1680 train_time:21782ms step_avg:87.13ms +step:250/1680 val_loss:3.9644 train_time:21870ms step_avg:87.48ms +step:251/1680 train_time:21888ms step_avg:87.20ms +step:252/1680 train_time:21959ms step_avg:87.14ms +step:253/1680 train_time:22049ms step_avg:87.15ms +step:254/1680 train_time:22137ms step_avg:87.15ms +step:255/1680 train_time:22224ms step_avg:87.15ms 
+step:256/1680 train_time:22311ms step_avg:87.15ms +step:257/1680 train_time:22398ms step_avg:87.15ms +step:258/1680 train_time:22484ms step_avg:87.15ms +step:259/1680 train_time:22571ms step_avg:87.15ms +step:260/1680 train_time:22657ms step_avg:87.14ms +step:261/1680 train_time:22743ms step_avg:87.14ms +step:262/1680 train_time:22831ms step_avg:87.14ms +step:263/1680 train_time:22919ms step_avg:87.14ms +step:264/1680 train_time:23008ms step_avg:87.15ms +step:265/1680 train_time:23096ms step_avg:87.15ms +step:266/1680 train_time:23182ms step_avg:87.15ms +step:267/1680 train_time:23269ms step_avg:87.15ms +step:268/1680 train_time:23356ms step_avg:87.15ms +step:269/1680 train_time:23443ms step_avg:87.15ms +step:270/1680 train_time:23529ms step_avg:87.14ms +step:271/1680 train_time:23616ms step_avg:87.14ms +step:272/1680 train_time:23702ms step_avg:87.14ms +step:273/1680 train_time:23790ms step_avg:87.14ms +step:274/1680 train_time:23878ms step_avg:87.14ms +step:275/1680 train_time:23965ms step_avg:87.14ms +step:276/1680 train_time:24053ms step_avg:87.15ms +step:277/1680 train_time:24140ms step_avg:87.15ms +step:278/1680 train_time:24227ms step_avg:87.15ms +step:279/1680 train_time:24314ms step_avg:87.15ms +step:280/1680 train_time:24400ms step_avg:87.14ms +step:281/1680 train_time:24488ms step_avg:87.14ms +step:282/1680 train_time:24574ms step_avg:87.14ms +step:283/1680 train_time:24660ms step_avg:87.14ms +step:284/1680 train_time:24747ms step_avg:87.14ms +step:285/1680 train_time:24834ms step_avg:87.14ms +step:286/1680 train_time:24922ms step_avg:87.14ms +step:287/1680 train_time:25009ms step_avg:87.14ms +step:288/1680 train_time:25096ms step_avg:87.14ms +step:289/1680 train_time:25183ms step_avg:87.14ms +step:290/1680 train_time:25270ms step_avg:87.14ms +step:291/1680 train_time:25357ms step_avg:87.14ms +step:292/1680 train_time:25445ms step_avg:87.14ms +step:293/1680 train_time:25532ms step_avg:87.14ms +step:294/1680 train_time:25618ms step_avg:87.14ms +step:295/1680 train_time:25705ms step_avg:87.14ms +step:296/1680 train_time:25792ms step_avg:87.13ms +step:297/1680 train_time:25879ms step_avg:87.13ms +step:298/1680 train_time:25966ms step_avg:87.13ms +step:299/1680 train_time:26054ms step_avg:87.14ms +step:300/1680 train_time:26140ms step_avg:87.13ms +step:301/1680 train_time:26228ms step_avg:87.14ms +step:302/1680 train_time:26315ms step_avg:87.14ms +step:303/1680 train_time:26402ms step_avg:87.14ms +step:304/1680 train_time:26489ms step_avg:87.13ms +step:305/1680 train_time:26576ms step_avg:87.13ms +step:306/1680 train_time:26662ms step_avg:87.13ms +step:307/1680 train_time:26749ms step_avg:87.13ms +step:308/1680 train_time:26836ms step_avg:87.13ms +step:309/1680 train_time:26923ms step_avg:87.13ms +step:310/1680 train_time:27011ms step_avg:87.13ms +step:311/1680 train_time:27100ms step_avg:87.14ms +step:312/1680 train_time:27187ms step_avg:87.14ms +step:313/1680 train_time:27274ms step_avg:87.14ms +step:314/1680 train_time:27360ms step_avg:87.13ms +step:315/1680 train_time:27447ms step_avg:87.13ms +step:316/1680 train_time:27535ms step_avg:87.13ms +step:317/1680 train_time:27621ms step_avg:87.13ms +step:318/1680 train_time:27708ms step_avg:87.13ms +step:319/1680 train_time:27796ms step_avg:87.13ms +step:320/1680 train_time:27883ms step_avg:87.13ms +step:321/1680 train_time:27971ms step_avg:87.14ms +step:322/1680 train_time:28058ms step_avg:87.14ms +step:323/1680 train_time:28144ms step_avg:87.13ms +step:324/1680 train_time:28232ms step_avg:87.14ms +step:325/1680 train_time:28319ms 
step_avg:87.13ms +step:326/1680 train_time:28406ms step_avg:87.13ms +step:327/1680 train_time:28494ms step_avg:87.14ms +step:328/1680 train_time:28581ms step_avg:87.14ms +step:329/1680 train_time:28668ms step_avg:87.14ms +step:330/1680 train_time:28755ms step_avg:87.14ms +step:331/1680 train_time:28842ms step_avg:87.14ms +step:332/1680 train_time:28930ms step_avg:87.14ms +step:333/1680 train_time:29017ms step_avg:87.14ms +step:334/1680 train_time:29104ms step_avg:87.14ms +step:335/1680 train_time:29191ms step_avg:87.14ms +step:336/1680 train_time:29278ms step_avg:87.14ms +step:337/1680 train_time:29365ms step_avg:87.14ms +step:338/1680 train_time:29453ms step_avg:87.14ms +step:339/1680 train_time:29540ms step_avg:87.14ms +step:340/1680 train_time:29627ms step_avg:87.14ms +step:341/1680 train_time:29714ms step_avg:87.14ms +step:342/1680 train_time:29801ms step_avg:87.14ms +step:343/1680 train_time:29889ms step_avg:87.14ms +step:344/1680 train_time:29976ms step_avg:87.14ms +step:345/1680 train_time:30063ms step_avg:87.14ms +step:346/1680 train_time:30151ms step_avg:87.14ms +step:347/1680 train_time:30238ms step_avg:87.14ms +step:348/1680 train_time:30325ms step_avg:87.14ms +step:349/1680 train_time:30412ms step_avg:87.14ms +step:350/1680 train_time:30499ms step_avg:87.14ms +step:351/1680 train_time:30586ms step_avg:87.14ms +step:352/1680 train_time:30673ms step_avg:87.14ms +step:353/1680 train_time:30760ms step_avg:87.14ms +step:354/1680 train_time:30847ms step_avg:87.14ms +step:355/1680 train_time:30935ms step_avg:87.14ms +step:356/1680 train_time:31022ms step_avg:87.14ms +step:357/1680 train_time:31110ms step_avg:87.14ms +step:358/1680 train_time:31197ms step_avg:87.14ms +step:359/1680 train_time:31283ms step_avg:87.14ms +step:360/1680 train_time:31370ms step_avg:87.14ms +step:361/1680 train_time:31457ms step_avg:87.14ms +step:362/1680 train_time:31544ms step_avg:87.14ms +step:363/1680 train_time:31631ms step_avg:87.14ms +step:364/1680 train_time:31718ms step_avg:87.14ms +step:365/1680 train_time:31805ms step_avg:87.14ms +step:366/1680 train_time:31893ms step_avg:87.14ms +step:367/1680 train_time:31979ms step_avg:87.14ms +step:368/1680 train_time:32067ms step_avg:87.14ms +step:369/1680 train_time:32154ms step_avg:87.14ms +step:370/1680 train_time:32242ms step_avg:87.14ms +step:371/1680 train_time:32329ms step_avg:87.14ms +step:372/1680 train_time:32417ms step_avg:87.14ms +step:373/1680 train_time:32504ms step_avg:87.14ms +step:374/1680 train_time:32591ms step_avg:87.14ms +step:375/1680 train_time:32678ms step_avg:87.14ms +step:375/1680 val_loss:3.8164 train_time:32767ms step_avg:87.38ms +step:376/1680 train_time:32788ms step_avg:87.20ms +step:377/1680 train_time:32855ms step_avg:87.15ms +step:378/1680 train_time:32947ms step_avg:87.16ms +step:379/1680 train_time:33035ms step_avg:87.16ms +step:380/1680 train_time:33122ms step_avg:87.16ms +step:381/1680 train_time:33209ms step_avg:87.16ms +step:382/1680 train_time:33296ms step_avg:87.16ms +step:383/1680 train_time:33382ms step_avg:87.16ms +step:384/1680 train_time:33468ms step_avg:87.16ms +step:385/1680 train_time:33555ms step_avg:87.15ms +step:386/1680 train_time:33640ms step_avg:87.15ms +step:387/1680 train_time:33727ms step_avg:87.15ms +step:388/1680 train_time:33815ms step_avg:87.15ms +step:389/1680 train_time:33904ms step_avg:87.16ms +step:390/1680 train_time:33994ms step_avg:87.16ms +step:391/1680 train_time:34082ms step_avg:87.17ms +step:392/1680 train_time:34169ms step_avg:87.17ms +step:393/1680 train_time:34256ms step_avg:87.16ms 
+step:394/1680 train_time:34342ms step_avg:87.16ms +step:395/1680 train_time:34429ms step_avg:87.16ms +step:396/1680 train_time:34516ms step_avg:87.16ms +step:397/1680 train_time:34602ms step_avg:87.16ms +step:398/1680 train_time:34689ms step_avg:87.16ms +step:399/1680 train_time:34776ms step_avg:87.16ms +step:400/1680 train_time:34865ms step_avg:87.16ms +step:401/1680 train_time:34954ms step_avg:87.17ms +step:402/1680 train_time:35041ms step_avg:87.17ms +step:403/1680 train_time:35128ms step_avg:87.17ms +step:404/1680 train_time:35215ms step_avg:87.17ms +step:405/1680 train_time:35302ms step_avg:87.17ms +step:406/1680 train_time:35388ms step_avg:87.16ms +step:407/1680 train_time:35475ms step_avg:87.16ms +step:408/1680 train_time:35562ms step_avg:87.16ms +step:409/1680 train_time:35649ms step_avg:87.16ms +step:410/1680 train_time:35735ms step_avg:87.16ms +step:411/1680 train_time:35823ms step_avg:87.16ms +step:412/1680 train_time:35911ms step_avg:87.16ms +step:413/1680 train_time:35999ms step_avg:87.16ms +step:414/1680 train_time:36086ms step_avg:87.16ms +step:415/1680 train_time:36173ms step_avg:87.16ms +step:416/1680 train_time:36260ms step_avg:87.16ms +step:417/1680 train_time:36347ms step_avg:87.16ms +step:418/1680 train_time:36434ms step_avg:87.16ms +step:419/1680 train_time:36520ms step_avg:87.16ms +step:420/1680 train_time:36607ms step_avg:87.16ms +step:421/1680 train_time:36694ms step_avg:87.16ms +step:422/1680 train_time:36781ms step_avg:87.16ms +step:423/1680 train_time:36869ms step_avg:87.16ms +step:424/1680 train_time:36957ms step_avg:87.16ms +step:425/1680 train_time:37045ms step_avg:87.16ms +step:426/1680 train_time:37132ms step_avg:87.17ms +step:427/1680 train_time:37219ms step_avg:87.16ms +step:428/1680 train_time:37306ms step_avg:87.16ms +step:429/1680 train_time:37393ms step_avg:87.16ms +step:430/1680 train_time:37480ms step_avg:87.16ms +step:431/1680 train_time:37566ms step_avg:87.16ms +step:432/1680 train_time:37653ms step_avg:87.16ms +step:433/1680 train_time:37740ms step_avg:87.16ms +step:434/1680 train_time:37827ms step_avg:87.16ms +step:435/1680 train_time:37914ms step_avg:87.16ms +step:436/1680 train_time:38001ms step_avg:87.16ms +step:437/1680 train_time:38090ms step_avg:87.16ms +step:438/1680 train_time:38176ms step_avg:87.16ms +step:439/1680 train_time:38263ms step_avg:87.16ms +step:440/1680 train_time:38350ms step_avg:87.16ms +step:441/1680 train_time:38437ms step_avg:87.16ms +step:442/1680 train_time:38523ms step_avg:87.16ms +step:443/1680 train_time:38610ms step_avg:87.16ms +step:444/1680 train_time:38697ms step_avg:87.15ms +step:445/1680 train_time:38784ms step_avg:87.15ms +step:446/1680 train_time:38871ms step_avg:87.16ms +step:447/1680 train_time:38959ms step_avg:87.16ms +step:448/1680 train_time:39047ms step_avg:87.16ms +step:449/1680 train_time:39134ms step_avg:87.16ms +step:450/1680 train_time:39221ms step_avg:87.16ms +step:451/1680 train_time:39308ms step_avg:87.16ms +step:452/1680 train_time:39395ms step_avg:87.16ms +step:453/1680 train_time:39482ms step_avg:87.16ms +step:454/1680 train_time:39569ms step_avg:87.16ms +step:455/1680 train_time:39656ms step_avg:87.16ms +step:456/1680 train_time:39743ms step_avg:87.16ms +step:457/1680 train_time:39830ms step_avg:87.16ms +step:458/1680 train_time:39917ms step_avg:87.16ms +step:459/1680 train_time:40005ms step_avg:87.16ms +step:460/1680 train_time:40092ms step_avg:87.16ms +step:461/1680 train_time:40179ms step_avg:87.16ms +step:462/1680 train_time:40267ms step_avg:87.16ms +step:463/1680 train_time:40354ms 
step_avg:87.16ms +step:464/1680 train_time:40441ms step_avg:87.16ms +step:465/1680 train_time:40528ms step_avg:87.16ms +step:466/1680 train_time:40615ms step_avg:87.16ms +step:467/1680 train_time:40702ms step_avg:87.16ms +step:468/1680 train_time:40789ms step_avg:87.16ms +step:469/1680 train_time:40876ms step_avg:87.15ms +step:470/1680 train_time:40963ms step_avg:87.15ms +step:471/1680 train_time:41050ms step_avg:87.15ms +step:472/1680 train_time:41136ms step_avg:87.15ms +step:473/1680 train_time:41224ms step_avg:87.16ms +step:474/1680 train_time:41312ms step_avg:87.16ms +step:475/1680 train_time:41399ms step_avg:87.16ms +step:476/1680 train_time:41486ms step_avg:87.16ms +step:477/1680 train_time:41573ms step_avg:87.15ms +step:478/1680 train_time:41660ms step_avg:87.15ms +step:479/1680 train_time:41747ms step_avg:87.15ms +step:480/1680 train_time:41834ms step_avg:87.15ms +step:481/1680 train_time:41921ms step_avg:87.15ms +step:482/1680 train_time:42008ms step_avg:87.15ms +step:483/1680 train_time:42095ms step_avg:87.15ms +step:484/1680 train_time:42182ms step_avg:87.15ms +step:485/1680 train_time:42270ms step_avg:87.15ms +step:486/1680 train_time:42356ms step_avg:87.15ms +step:487/1680 train_time:42444ms step_avg:87.15ms +step:488/1680 train_time:42531ms step_avg:87.15ms +step:489/1680 train_time:42617ms step_avg:87.15ms +step:490/1680 train_time:42705ms step_avg:87.15ms +step:491/1680 train_time:42793ms step_avg:87.15ms +step:492/1680 train_time:42880ms step_avg:87.15ms +step:493/1680 train_time:42967ms step_avg:87.15ms +step:494/1680 train_time:43054ms step_avg:87.15ms +step:495/1680 train_time:43141ms step_avg:87.15ms +step:496/1680 train_time:43228ms step_avg:87.15ms +step:497/1680 train_time:43315ms step_avg:87.15ms +step:498/1680 train_time:43402ms step_avg:87.15ms +step:499/1680 train_time:43490ms step_avg:87.15ms +step:500/1680 train_time:43576ms step_avg:87.15ms +step:500/1680 val_loss:3.7161 train_time:43665ms step_avg:87.33ms +step:501/1680 train_time:43684ms step_avg:87.19ms +step:502/1680 train_time:43754ms step_avg:87.16ms +step:503/1680 train_time:43849ms step_avg:87.17ms +step:504/1680 train_time:43937ms step_avg:87.18ms +step:505/1680 train_time:44023ms step_avg:87.18ms +step:506/1680 train_time:44110ms step_avg:87.17ms +step:507/1680 train_time:44196ms step_avg:87.17ms +step:508/1680 train_time:44282ms step_avg:87.17ms +step:509/1680 train_time:44367ms step_avg:87.17ms +step:510/1680 train_time:44454ms step_avg:87.17ms +step:511/1680 train_time:44541ms step_avg:87.16ms +step:512/1680 train_time:44628ms step_avg:87.16ms +step:513/1680 train_time:44716ms step_avg:87.17ms +step:514/1680 train_time:44806ms step_avg:87.17ms +step:515/1680 train_time:44894ms step_avg:87.17ms +step:516/1680 train_time:44982ms step_avg:87.17ms +step:517/1680 train_time:45069ms step_avg:87.17ms +step:518/1680 train_time:45155ms step_avg:87.17ms +step:519/1680 train_time:45242ms step_avg:87.17ms +step:520/1680 train_time:45328ms step_avg:87.17ms +step:521/1680 train_time:45414ms step_avg:87.17ms +step:522/1680 train_time:45500ms step_avg:87.17ms +step:523/1680 train_time:45587ms step_avg:87.16ms +step:524/1680 train_time:45675ms step_avg:87.17ms +step:525/1680 train_time:45763ms step_avg:87.17ms +step:526/1680 train_time:45851ms step_avg:87.17ms +step:527/1680 train_time:45939ms step_avg:87.17ms +step:528/1680 train_time:46027ms step_avg:87.17ms +step:529/1680 train_time:46114ms step_avg:87.17ms +step:530/1680 train_time:46201ms step_avg:87.17ms +step:531/1680 train_time:46287ms step_avg:87.17ms 
+step:532/1680 train_time:46373ms step_avg:87.17ms +step:533/1680 train_time:46459ms step_avg:87.17ms +step:534/1680 train_time:46547ms step_avg:87.17ms +step:535/1680 train_time:46633ms step_avg:87.16ms +step:536/1680 train_time:46721ms step_avg:87.17ms +step:537/1680 train_time:46809ms step_avg:87.17ms +step:538/1680 train_time:46897ms step_avg:87.17ms +step:539/1680 train_time:46985ms step_avg:87.17ms +step:540/1680 train_time:47072ms step_avg:87.17ms +step:541/1680 train_time:47160ms step_avg:87.17ms +step:542/1680 train_time:47247ms step_avg:87.17ms +step:543/1680 train_time:47333ms step_avg:87.17ms +step:544/1680 train_time:47419ms step_avg:87.17ms +step:545/1680 train_time:47506ms step_avg:87.17ms +step:546/1680 train_time:47593ms step_avg:87.17ms +step:547/1680 train_time:47680ms step_avg:87.17ms +step:548/1680 train_time:47768ms step_avg:87.17ms +step:549/1680 train_time:47857ms step_avg:87.17ms +step:550/1680 train_time:47947ms step_avg:87.18ms +step:551/1680 train_time:48035ms step_avg:87.18ms +step:552/1680 train_time:48123ms step_avg:87.18ms +step:553/1680 train_time:48211ms step_avg:87.18ms +step:554/1680 train_time:48299ms step_avg:87.18ms +step:555/1680 train_time:48387ms step_avg:87.18ms +step:556/1680 train_time:48475ms step_avg:87.19ms +step:557/1680 train_time:48562ms step_avg:87.19ms +step:558/1680 train_time:48650ms step_avg:87.19ms +step:559/1680 train_time:48739ms step_avg:87.19ms +step:560/1680 train_time:48828ms step_avg:87.19ms +step:561/1680 train_time:48916ms step_avg:87.19ms +step:562/1680 train_time:49005ms step_avg:87.20ms +step:563/1680 train_time:49093ms step_avg:87.20ms +step:564/1680 train_time:49180ms step_avg:87.20ms +step:565/1680 train_time:49269ms step_avg:87.20ms +step:566/1680 train_time:49357ms step_avg:87.20ms +step:567/1680 train_time:49446ms step_avg:87.21ms +step:568/1680 train_time:49533ms step_avg:87.21ms +step:569/1680 train_time:49622ms step_avg:87.21ms +step:570/1680 train_time:49711ms step_avg:87.21ms +step:571/1680 train_time:49799ms step_avg:87.21ms +step:572/1680 train_time:49888ms step_avg:87.22ms +step:573/1680 train_time:49976ms step_avg:87.22ms +step:574/1680 train_time:50064ms step_avg:87.22ms +step:575/1680 train_time:50152ms step_avg:87.22ms +step:576/1680 train_time:50241ms step_avg:87.22ms +step:577/1680 train_time:50330ms step_avg:87.23ms +step:578/1680 train_time:50418ms step_avg:87.23ms +step:579/1680 train_time:50506ms step_avg:87.23ms +step:580/1680 train_time:50594ms step_avg:87.23ms +step:581/1680 train_time:50682ms step_avg:87.23ms +step:582/1680 train_time:50771ms step_avg:87.24ms +step:583/1680 train_time:50860ms step_avg:87.24ms +step:584/1680 train_time:50949ms step_avg:87.24ms +step:585/1680 train_time:51037ms step_avg:87.24ms +step:586/1680 train_time:51125ms step_avg:87.24ms +step:587/1680 train_time:51213ms step_avg:87.25ms +step:588/1680 train_time:51302ms step_avg:87.25ms +step:589/1680 train_time:51389ms step_avg:87.25ms +step:590/1680 train_time:51477ms step_avg:87.25ms +step:591/1680 train_time:51565ms step_avg:87.25ms +step:592/1680 train_time:51653ms step_avg:87.25ms +step:593/1680 train_time:51741ms step_avg:87.25ms +step:594/1680 train_time:51830ms step_avg:87.26ms +step:595/1680 train_time:51918ms step_avg:87.26ms +step:596/1680 train_time:52008ms step_avg:87.26ms +step:597/1680 train_time:52096ms step_avg:87.26ms +step:598/1680 train_time:52185ms step_avg:87.27ms +step:599/1680 train_time:52273ms step_avg:87.27ms +step:600/1680 train_time:52361ms step_avg:87.27ms +step:601/1680 train_time:52450ms 
step_avg:87.27ms +step:602/1680 train_time:52538ms step_avg:87.27ms +step:603/1680 train_time:52627ms step_avg:87.27ms +step:604/1680 train_time:52715ms step_avg:87.28ms +step:605/1680 train_time:52803ms step_avg:87.28ms +step:606/1680 train_time:52892ms step_avg:87.28ms +step:607/1680 train_time:52980ms step_avg:87.28ms +step:608/1680 train_time:53069ms step_avg:87.28ms +step:609/1680 train_time:53157ms step_avg:87.29ms +step:610/1680 train_time:53245ms step_avg:87.29ms +step:611/1680 train_time:53333ms step_avg:87.29ms +step:612/1680 train_time:53421ms step_avg:87.29ms +step:613/1680 train_time:53510ms step_avg:87.29ms +step:614/1680 train_time:53598ms step_avg:87.29ms +step:615/1680 train_time:53685ms step_avg:87.29ms +step:616/1680 train_time:53773ms step_avg:87.29ms +step:617/1680 train_time:53862ms step_avg:87.30ms +step:618/1680 train_time:53951ms step_avg:87.30ms +step:619/1680 train_time:54038ms step_avg:87.30ms +step:620/1680 train_time:54127ms step_avg:87.30ms +step:621/1680 train_time:54215ms step_avg:87.30ms +step:622/1680 train_time:54302ms step_avg:87.30ms +step:623/1680 train_time:54390ms step_avg:87.30ms +step:624/1680 train_time:54479ms step_avg:87.31ms +step:625/1680 train_time:54567ms step_avg:87.31ms +step:625/1680 val_loss:3.6134 train_time:54656ms step_avg:87.45ms +step:626/1680 train_time:54677ms step_avg:87.34ms +step:627/1680 train_time:54746ms step_avg:87.31ms +step:628/1680 train_time:54836ms step_avg:87.32ms +step:629/1680 train_time:54928ms step_avg:87.33ms +step:630/1680 train_time:55018ms step_avg:87.33ms +step:631/1680 train_time:55105ms step_avg:87.33ms +step:632/1680 train_time:55192ms step_avg:87.33ms +step:633/1680 train_time:55279ms step_avg:87.33ms +step:634/1680 train_time:55366ms step_avg:87.33ms +step:635/1680 train_time:55453ms step_avg:87.33ms +step:636/1680 train_time:55540ms step_avg:87.33ms +step:637/1680 train_time:55633ms step_avg:87.34ms +step:638/1680 train_time:55723ms step_avg:87.34ms +step:639/1680 train_time:55811ms step_avg:87.34ms +step:640/1680 train_time:55902ms step_avg:87.35ms +step:641/1680 train_time:55990ms step_avg:87.35ms +step:642/1680 train_time:56078ms step_avg:87.35ms +step:643/1680 train_time:56165ms step_avg:87.35ms +step:644/1680 train_time:56253ms step_avg:87.35ms +step:645/1680 train_time:56340ms step_avg:87.35ms +step:646/1680 train_time:56427ms step_avg:87.35ms +step:647/1680 train_time:56515ms step_avg:87.35ms +step:648/1680 train_time:56603ms step_avg:87.35ms +step:649/1680 train_time:56692ms step_avg:87.35ms +step:650/1680 train_time:56781ms step_avg:87.36ms +step:651/1680 train_time:56870ms step_avg:87.36ms +step:652/1680 train_time:56959ms step_avg:87.36ms +step:653/1680 train_time:57048ms step_avg:87.36ms +step:654/1680 train_time:57137ms step_avg:87.36ms +step:655/1680 train_time:57225ms step_avg:87.37ms +step:656/1680 train_time:57313ms step_avg:87.37ms +step:657/1680 train_time:57402ms step_avg:87.37ms +step:658/1680 train_time:57489ms step_avg:87.37ms +step:659/1680 train_time:57579ms step_avg:87.37ms +step:660/1680 train_time:57666ms step_avg:87.37ms +step:661/1680 train_time:57754ms step_avg:87.37ms +step:662/1680 train_time:57844ms step_avg:87.38ms +step:663/1680 train_time:57932ms step_avg:87.38ms +step:664/1680 train_time:58021ms step_avg:87.38ms +step:665/1680 train_time:58109ms step_avg:87.38ms +step:666/1680 train_time:58197ms step_avg:87.38ms +step:667/1680 train_time:58285ms step_avg:87.38ms +step:668/1680 train_time:58373ms step_avg:87.38ms +step:669/1680 train_time:58460ms step_avg:87.38ms 
+step:670/1680 train_time:58548ms step_avg:87.39ms +step:671/1680 train_time:58636ms step_avg:87.39ms +step:672/1680 train_time:58724ms step_avg:87.39ms +step:673/1680 train_time:58812ms step_avg:87.39ms +step:674/1680 train_time:58901ms step_avg:87.39ms +step:675/1680 train_time:58989ms step_avg:87.39ms +step:676/1680 train_time:59078ms step_avg:87.39ms +step:677/1680 train_time:59166ms step_avg:87.39ms +step:678/1680 train_time:59254ms step_avg:87.40ms +step:679/1680 train_time:59342ms step_avg:87.40ms +step:680/1680 train_time:59430ms step_avg:87.40ms +step:681/1680 train_time:59518ms step_avg:87.40ms +step:682/1680 train_time:59607ms step_avg:87.40ms +step:683/1680 train_time:59695ms step_avg:87.40ms +step:684/1680 train_time:59784ms step_avg:87.40ms +step:685/1680 train_time:59872ms step_avg:87.40ms +step:686/1680 train_time:59959ms step_avg:87.40ms +step:687/1680 train_time:60048ms step_avg:87.41ms +step:688/1680 train_time:60136ms step_avg:87.41ms +step:689/1680 train_time:60224ms step_avg:87.41ms +step:690/1680 train_time:60312ms step_avg:87.41ms +step:691/1680 train_time:60400ms step_avg:87.41ms +step:692/1680 train_time:60488ms step_avg:87.41ms +step:693/1680 train_time:60576ms step_avg:87.41ms +step:694/1680 train_time:60664ms step_avg:87.41ms +step:695/1680 train_time:60752ms step_avg:87.41ms +step:696/1680 train_time:60840ms step_avg:87.41ms +step:697/1680 train_time:60928ms step_avg:87.41ms +step:698/1680 train_time:61017ms step_avg:87.42ms +step:699/1680 train_time:61105ms step_avg:87.42ms +step:700/1680 train_time:61194ms step_avg:87.42ms +step:701/1680 train_time:61284ms step_avg:87.42ms +step:702/1680 train_time:61372ms step_avg:87.42ms +step:703/1680 train_time:61459ms step_avg:87.42ms +step:704/1680 train_time:61547ms step_avg:87.42ms +step:705/1680 train_time:61635ms step_avg:87.43ms +step:706/1680 train_time:61724ms step_avg:87.43ms +step:707/1680 train_time:61812ms step_avg:87.43ms +step:708/1680 train_time:61900ms step_avg:87.43ms +step:709/1680 train_time:61988ms step_avg:87.43ms +step:710/1680 train_time:62076ms step_avg:87.43ms +step:711/1680 train_time:62164ms step_avg:87.43ms +step:712/1680 train_time:62252ms step_avg:87.43ms +step:713/1680 train_time:62341ms step_avg:87.43ms +step:714/1680 train_time:62428ms step_avg:87.43ms +step:715/1680 train_time:62517ms step_avg:87.44ms +step:716/1680 train_time:62605ms step_avg:87.44ms +step:717/1680 train_time:62693ms step_avg:87.44ms +step:718/1680 train_time:62782ms step_avg:87.44ms +step:719/1680 train_time:62869ms step_avg:87.44ms +step:720/1680 train_time:62957ms step_avg:87.44ms +step:721/1680 train_time:63045ms step_avg:87.44ms +step:722/1680 train_time:63133ms step_avg:87.44ms +step:723/1680 train_time:63222ms step_avg:87.44ms +step:724/1680 train_time:63310ms step_avg:87.44ms +step:725/1680 train_time:63398ms step_avg:87.45ms +step:726/1680 train_time:63486ms step_avg:87.45ms +step:727/1680 train_time:63575ms step_avg:87.45ms +step:728/1680 train_time:63662ms step_avg:87.45ms +step:729/1680 train_time:63750ms step_avg:87.45ms +step:730/1680 train_time:63838ms step_avg:87.45ms +step:731/1680 train_time:63927ms step_avg:87.45ms +step:732/1680 train_time:64015ms step_avg:87.45ms +step:733/1680 train_time:64104ms step_avg:87.45ms +step:734/1680 train_time:64192ms step_avg:87.45ms +step:735/1680 train_time:64280ms step_avg:87.46ms +step:736/1680 train_time:64368ms step_avg:87.46ms +step:737/1680 train_time:64456ms step_avg:87.46ms +step:738/1680 train_time:64544ms step_avg:87.46ms +step:739/1680 train_time:64633ms 
step_avg:87.46ms +step:740/1680 train_time:64720ms step_avg:87.46ms +step:741/1680 train_time:64808ms step_avg:87.46ms +step:742/1680 train_time:64896ms step_avg:87.46ms +step:743/1680 train_time:64985ms step_avg:87.46ms +step:744/1680 train_time:65073ms step_avg:87.46ms +step:745/1680 train_time:65161ms step_avg:87.46ms +step:746/1680 train_time:65248ms step_avg:87.46ms +step:747/1680 train_time:65337ms step_avg:87.47ms +step:748/1680 train_time:65427ms step_avg:87.47ms +step:749/1680 train_time:65515ms step_avg:87.47ms +step:750/1680 train_time:65603ms step_avg:87.47ms +step:750/1680 val_loss:3.5645 train_time:65694ms step_avg:87.59ms +step:751/1680 train_time:65712ms step_avg:87.50ms +step:752/1680 train_time:65784ms step_avg:87.48ms +step:753/1680 train_time:65881ms step_avg:87.49ms +step:754/1680 train_time:65971ms step_avg:87.50ms +step:755/1680 train_time:66059ms step_avg:87.50ms +step:756/1680 train_time:66147ms step_avg:87.50ms +step:757/1680 train_time:66234ms step_avg:87.50ms +step:758/1680 train_time:66321ms step_avg:87.50ms +step:759/1680 train_time:66408ms step_avg:87.49ms +step:760/1680 train_time:66495ms step_avg:87.49ms +step:761/1680 train_time:66582ms step_avg:87.49ms +step:762/1680 train_time:66670ms step_avg:87.49ms +step:763/1680 train_time:66759ms step_avg:87.50ms +step:764/1680 train_time:66849ms step_avg:87.50ms +step:765/1680 train_time:66939ms step_avg:87.50ms +step:766/1680 train_time:67029ms step_avg:87.50ms +step:767/1680 train_time:67118ms step_avg:87.51ms +step:768/1680 train_time:67206ms step_avg:87.51ms +step:769/1680 train_time:67294ms step_avg:87.51ms +step:770/1680 train_time:67382ms step_avg:87.51ms +step:771/1680 train_time:67468ms step_avg:87.51ms +step:772/1680 train_time:67556ms step_avg:87.51ms +step:773/1680 train_time:67644ms step_avg:87.51ms +step:774/1680 train_time:67732ms step_avg:87.51ms +step:775/1680 train_time:67821ms step_avg:87.51ms +step:776/1680 train_time:67911ms step_avg:87.51ms +step:777/1680 train_time:68000ms step_avg:87.52ms +step:778/1680 train_time:68089ms step_avg:87.52ms +step:779/1680 train_time:68178ms step_avg:87.52ms +step:780/1680 train_time:68266ms step_avg:87.52ms +step:781/1680 train_time:68354ms step_avg:87.52ms +step:782/1680 train_time:68442ms step_avg:87.52ms +step:783/1680 train_time:68530ms step_avg:87.52ms +step:784/1680 train_time:68618ms step_avg:87.52ms +step:785/1680 train_time:68706ms step_avg:87.52ms +step:786/1680 train_time:68794ms step_avg:87.52ms +step:787/1680 train_time:68884ms step_avg:87.53ms +step:788/1680 train_time:68973ms step_avg:87.53ms +step:789/1680 train_time:69061ms step_avg:87.53ms +step:790/1680 train_time:69149ms step_avg:87.53ms +step:791/1680 train_time:69237ms step_avg:87.53ms +step:792/1680 train_time:69325ms step_avg:87.53ms +step:793/1680 train_time:69413ms step_avg:87.53ms +step:794/1680 train_time:69501ms step_avg:87.53ms +step:795/1680 train_time:69589ms step_avg:87.53ms +step:796/1680 train_time:69677ms step_avg:87.53ms +step:797/1680 train_time:69766ms step_avg:87.54ms +step:798/1680 train_time:69854ms step_avg:87.54ms +step:799/1680 train_time:69942ms step_avg:87.54ms +step:800/1680 train_time:70031ms step_avg:87.54ms +step:801/1680 train_time:70119ms step_avg:87.54ms +step:802/1680 train_time:70208ms step_avg:87.54ms +step:803/1680 train_time:70296ms step_avg:87.54ms +step:804/1680 train_time:70383ms step_avg:87.54ms +step:805/1680 train_time:70472ms step_avg:87.54ms +step:806/1680 train_time:70559ms step_avg:87.54ms +step:807/1680 train_time:70647ms step_avg:87.54ms 
+step:808/1680 train_time:70736ms step_avg:87.54ms +step:809/1680 train_time:70824ms step_avg:87.55ms +step:810/1680 train_time:70913ms step_avg:87.55ms +step:811/1680 train_time:71001ms step_avg:87.55ms +step:812/1680 train_time:71089ms step_avg:87.55ms +step:813/1680 train_time:71178ms step_avg:87.55ms +step:814/1680 train_time:71266ms step_avg:87.55ms +step:815/1680 train_time:71354ms step_avg:87.55ms +step:816/1680 train_time:71443ms step_avg:87.55ms +step:817/1680 train_time:71530ms step_avg:87.55ms +step:818/1680 train_time:71618ms step_avg:87.55ms +step:819/1680 train_time:71707ms step_avg:87.55ms +step:820/1680 train_time:71795ms step_avg:87.56ms +step:821/1680 train_time:71884ms step_avg:87.56ms +step:822/1680 train_time:71972ms step_avg:87.56ms +step:823/1680 train_time:72060ms step_avg:87.56ms +step:824/1680 train_time:72149ms step_avg:87.56ms +step:825/1680 train_time:72237ms step_avg:87.56ms +step:826/1680 train_time:72325ms step_avg:87.56ms +step:827/1680 train_time:72413ms step_avg:87.56ms +step:828/1680 train_time:72501ms step_avg:87.56ms +step:829/1680 train_time:72590ms step_avg:87.56ms +step:830/1680 train_time:72678ms step_avg:87.56ms +step:831/1680 train_time:72766ms step_avg:87.56ms +step:832/1680 train_time:72854ms step_avg:87.56ms +step:833/1680 train_time:72942ms step_avg:87.57ms +step:834/1680 train_time:73030ms step_avg:87.57ms +step:835/1680 train_time:73118ms step_avg:87.57ms +step:836/1680 train_time:73207ms step_avg:87.57ms +step:837/1680 train_time:73295ms step_avg:87.57ms +step:838/1680 train_time:73383ms step_avg:87.57ms +step:839/1680 train_time:73471ms step_avg:87.57ms +step:840/1680 train_time:73559ms step_avg:87.57ms +step:841/1680 train_time:73648ms step_avg:87.57ms +step:842/1680 train_time:73736ms step_avg:87.57ms +step:843/1680 train_time:73825ms step_avg:87.57ms +step:844/1680 train_time:73913ms step_avg:87.57ms +step:845/1680 train_time:74000ms step_avg:87.57ms +step:846/1680 train_time:74088ms step_avg:87.58ms +step:847/1680 train_time:74177ms step_avg:87.58ms +step:848/1680 train_time:74266ms step_avg:87.58ms +step:849/1680 train_time:74354ms step_avg:87.58ms +step:850/1680 train_time:74442ms step_avg:87.58ms +step:851/1680 train_time:74530ms step_avg:87.58ms +step:852/1680 train_time:74618ms step_avg:87.58ms +step:853/1680 train_time:74706ms step_avg:87.58ms +step:854/1680 train_time:74794ms step_avg:87.58ms +step:855/1680 train_time:74881ms step_avg:87.58ms +step:856/1680 train_time:74971ms step_avg:87.58ms +step:857/1680 train_time:75058ms step_avg:87.58ms +step:858/1680 train_time:75147ms step_avg:87.58ms +step:859/1680 train_time:75236ms step_avg:87.59ms +step:860/1680 train_time:75325ms step_avg:87.59ms +step:861/1680 train_time:75412ms step_avg:87.59ms +step:862/1680 train_time:75501ms step_avg:87.59ms +step:863/1680 train_time:75589ms step_avg:87.59ms +step:864/1680 train_time:75677ms step_avg:87.59ms +step:865/1680 train_time:75765ms step_avg:87.59ms +step:866/1680 train_time:75854ms step_avg:87.59ms +step:867/1680 train_time:75942ms step_avg:87.59ms +step:868/1680 train_time:76029ms step_avg:87.59ms +step:869/1680 train_time:76117ms step_avg:87.59ms +step:870/1680 train_time:76206ms step_avg:87.59ms +step:871/1680 train_time:76294ms step_avg:87.59ms +step:872/1680 train_time:76382ms step_avg:87.59ms +step:873/1680 train_time:76470ms step_avg:87.59ms +step:874/1680 train_time:76559ms step_avg:87.60ms +step:875/1680 train_time:76647ms step_avg:87.60ms +step:875/1680 val_loss:3.5174 train_time:76737ms step_avg:87.70ms +step:876/1680 
train_time:76755ms step_avg:87.62ms +step:877/1680 train_time:76829ms step_avg:87.60ms +step:878/1680 train_time:76922ms step_avg:87.61ms +step:879/1680 train_time:77011ms step_avg:87.61ms +step:880/1680 train_time:77099ms step_avg:87.61ms +step:881/1680 train_time:77186ms step_avg:87.61ms +step:882/1680 train_time:77273ms step_avg:87.61ms +step:883/1680 train_time:77360ms step_avg:87.61ms +step:884/1680 train_time:77447ms step_avg:87.61ms +step:885/1680 train_time:77534ms step_avg:87.61ms +step:886/1680 train_time:77621ms step_avg:87.61ms +step:887/1680 train_time:77710ms step_avg:87.61ms +step:888/1680 train_time:77800ms step_avg:87.61ms +step:889/1680 train_time:77892ms step_avg:87.62ms +step:890/1680 train_time:77981ms step_avg:87.62ms +step:891/1680 train_time:78069ms step_avg:87.62ms +step:892/1680 train_time:78157ms step_avg:87.62ms +step:893/1680 train_time:78244ms step_avg:87.62ms +step:894/1680 train_time:78332ms step_avg:87.62ms +step:895/1680 train_time:78419ms step_avg:87.62ms +step:896/1680 train_time:78506ms step_avg:87.62ms +step:897/1680 train_time:78594ms step_avg:87.62ms +step:898/1680 train_time:78682ms step_avg:87.62ms +step:899/1680 train_time:78772ms step_avg:87.62ms +step:900/1680 train_time:78866ms step_avg:87.63ms +step:901/1680 train_time:78955ms step_avg:87.63ms +step:902/1680 train_time:79044ms step_avg:87.63ms +step:903/1680 train_time:79131ms step_avg:87.63ms +step:904/1680 train_time:79220ms step_avg:87.63ms +step:905/1680 train_time:79307ms step_avg:87.63ms +step:906/1680 train_time:79395ms step_avg:87.63ms +step:907/1680 train_time:79483ms step_avg:87.63ms +step:908/1680 train_time:79570ms step_avg:87.63ms +step:909/1680 train_time:79657ms step_avg:87.63ms +step:910/1680 train_time:79746ms step_avg:87.63ms +step:911/1680 train_time:79835ms step_avg:87.63ms +step:912/1680 train_time:79924ms step_avg:87.64ms +step:913/1680 train_time:80013ms step_avg:87.64ms +step:914/1680 train_time:80102ms step_avg:87.64ms +step:915/1680 train_time:80190ms step_avg:87.64ms +step:916/1680 train_time:80278ms step_avg:87.64ms +step:917/1680 train_time:80366ms step_avg:87.64ms +step:918/1680 train_time:80454ms step_avg:87.64ms +step:919/1680 train_time:80542ms step_avg:87.64ms +step:920/1680 train_time:80630ms step_avg:87.64ms +step:921/1680 train_time:80718ms step_avg:87.64ms +step:922/1680 train_time:80806ms step_avg:87.64ms +step:923/1680 train_time:80895ms step_avg:87.64ms +step:924/1680 train_time:80985ms step_avg:87.65ms +step:925/1680 train_time:81073ms step_avg:87.65ms +step:926/1680 train_time:81161ms step_avg:87.65ms +step:927/1680 train_time:81250ms step_avg:87.65ms +step:928/1680 train_time:81338ms step_avg:87.65ms +step:929/1680 train_time:81427ms step_avg:87.65ms +step:930/1680 train_time:81516ms step_avg:87.65ms +step:931/1680 train_time:81603ms step_avg:87.65ms +step:932/1680 train_time:81692ms step_avg:87.65ms +step:933/1680 train_time:81780ms step_avg:87.65ms +step:934/1680 train_time:81868ms step_avg:87.65ms +step:935/1680 train_time:81957ms step_avg:87.65ms +step:936/1680 train_time:82045ms step_avg:87.65ms +step:937/1680 train_time:82133ms step_avg:87.66ms +step:938/1680 train_time:82222ms step_avg:87.66ms +step:939/1680 train_time:82310ms step_avg:87.66ms +step:940/1680 train_time:82399ms step_avg:87.66ms +step:941/1680 train_time:82487ms step_avg:87.66ms +step:942/1680 train_time:82576ms step_avg:87.66ms +step:943/1680 train_time:82664ms step_avg:87.66ms +step:944/1680 train_time:82752ms step_avg:87.66ms +step:945/1680 train_time:82840ms step_avg:87.66ms 
+step:946/1680 train_time:82928ms step_avg:87.66ms +step:947/1680 train_time:83017ms step_avg:87.66ms +step:948/1680 train_time:83105ms step_avg:87.66ms +step:949/1680 train_time:83193ms step_avg:87.66ms +step:950/1680 train_time:83281ms step_avg:87.66ms +step:951/1680 train_time:83369ms step_avg:87.66ms +step:952/1680 train_time:83458ms step_avg:87.67ms +step:953/1680 train_time:83546ms step_avg:87.67ms +step:954/1680 train_time:83634ms step_avg:87.67ms +step:955/1680 train_time:83723ms step_avg:87.67ms +step:956/1680 train_time:83811ms step_avg:87.67ms +step:957/1680 train_time:83899ms step_avg:87.67ms +step:958/1680 train_time:83986ms step_avg:87.67ms +step:959/1680 train_time:84074ms step_avg:87.67ms +step:960/1680 train_time:84163ms step_avg:87.67ms +step:961/1680 train_time:84251ms step_avg:87.67ms +step:962/1680 train_time:84340ms step_avg:87.67ms +step:963/1680 train_time:84428ms step_avg:87.67ms +step:964/1680 train_time:84515ms step_avg:87.67ms +step:965/1680 train_time:84604ms step_avg:87.67ms +step:966/1680 train_time:84692ms step_avg:87.67ms +step:967/1680 train_time:84781ms step_avg:87.67ms +step:968/1680 train_time:84869ms step_avg:87.67ms +step:969/1680 train_time:84957ms step_avg:87.67ms +step:970/1680 train_time:85045ms step_avg:87.68ms +step:971/1680 train_time:85134ms step_avg:87.68ms +step:972/1680 train_time:85222ms step_avg:87.68ms +step:973/1680 train_time:85311ms step_avg:87.68ms +step:974/1680 train_time:85398ms step_avg:87.68ms +step:975/1680 train_time:85486ms step_avg:87.68ms +step:976/1680 train_time:85574ms step_avg:87.68ms +step:977/1680 train_time:85662ms step_avg:87.68ms +step:978/1680 train_time:85750ms step_avg:87.68ms +step:979/1680 train_time:85838ms step_avg:87.68ms +step:980/1680 train_time:85926ms step_avg:87.68ms +step:981/1680 train_time:86015ms step_avg:87.68ms +step:982/1680 train_time:86103ms step_avg:87.68ms +step:983/1680 train_time:86192ms step_avg:87.68ms +step:984/1680 train_time:86280ms step_avg:87.68ms +step:985/1680 train_time:86368ms step_avg:87.68ms +step:986/1680 train_time:86456ms step_avg:87.68ms +step:987/1680 train_time:86544ms step_avg:87.68ms +step:988/1680 train_time:86632ms step_avg:87.68ms +step:989/1680 train_time:86721ms step_avg:87.69ms +step:990/1680 train_time:86809ms step_avg:87.69ms +step:991/1680 train_time:86897ms step_avg:87.69ms +step:992/1680 train_time:86986ms step_avg:87.69ms +step:993/1680 train_time:87075ms step_avg:87.69ms +step:994/1680 train_time:87163ms step_avg:87.69ms +step:995/1680 train_time:87251ms step_avg:87.69ms +step:996/1680 train_time:87340ms step_avg:87.69ms +step:997/1680 train_time:87427ms step_avg:87.69ms +step:998/1680 train_time:87515ms step_avg:87.69ms +step:999/1680 train_time:87603ms step_avg:87.69ms +step:1000/1680 train_time:87691ms step_avg:87.69ms +step:1000/1680 val_loss:3.4686 train_time:87780ms step_avg:87.78ms +step:1001/1680 train_time:87799ms step_avg:87.71ms +step:1002/1680 train_time:87871ms step_avg:87.70ms +step:1003/1680 train_time:87964ms step_avg:87.70ms +step:1004/1680 train_time:88054ms step_avg:87.70ms +step:1005/1680 train_time:88142ms step_avg:87.70ms +step:1006/1680 train_time:88229ms step_avg:87.70ms +step:1007/1680 train_time:88316ms step_avg:87.70ms +step:1008/1680 train_time:88404ms step_avg:87.70ms +step:1009/1680 train_time:88491ms step_avg:87.70ms +step:1010/1680 train_time:88579ms step_avg:87.70ms +step:1011/1680 train_time:88666ms step_avg:87.70ms +step:1012/1680 train_time:88754ms step_avg:87.70ms +step:1013/1680 train_time:88844ms step_avg:87.70ms 
+step:1014/1680 train_time:88935ms step_avg:87.71ms +step:1015/1680 train_time:89025ms step_avg:87.71ms +step:1016/1680 train_time:89114ms step_avg:87.71ms +step:1017/1680 train_time:89202ms step_avg:87.71ms +step:1018/1680 train_time:89290ms step_avg:87.71ms +step:1019/1680 train_time:89377ms step_avg:87.71ms +step:1020/1680 train_time:89465ms step_avg:87.71ms +step:1021/1680 train_time:89554ms step_avg:87.71ms +step:1022/1680 train_time:89641ms step_avg:87.71ms +step:1023/1680 train_time:89730ms step_avg:87.71ms +step:1024/1680 train_time:89819ms step_avg:87.71ms +step:1025/1680 train_time:89908ms step_avg:87.71ms +step:1026/1680 train_time:89996ms step_avg:87.72ms +step:1027/1680 train_time:90085ms step_avg:87.72ms +step:1028/1680 train_time:90174ms step_avg:87.72ms +step:1029/1680 train_time:90262ms step_avg:87.72ms +step:1030/1680 train_time:90350ms step_avg:87.72ms +step:1031/1680 train_time:90437ms step_avg:87.72ms +step:1032/1680 train_time:90526ms step_avg:87.72ms +step:1033/1680 train_time:90615ms step_avg:87.72ms +step:1034/1680 train_time:90704ms step_avg:87.72ms +step:1035/1680 train_time:90793ms step_avg:87.72ms +step:1036/1680 train_time:90881ms step_avg:87.72ms +step:1037/1680 train_time:90969ms step_avg:87.72ms +step:1038/1680 train_time:91058ms step_avg:87.72ms +step:1039/1680 train_time:91146ms step_avg:87.72ms +step:1040/1680 train_time:91235ms step_avg:87.73ms +step:1041/1680 train_time:91323ms step_avg:87.73ms +step:1042/1680 train_time:91411ms step_avg:87.73ms +step:1043/1680 train_time:91498ms step_avg:87.73ms +step:1044/1680 train_time:91586ms step_avg:87.73ms +step:1045/1680 train_time:91674ms step_avg:87.73ms +step:1046/1680 train_time:91763ms step_avg:87.73ms +step:1047/1680 train_time:91851ms step_avg:87.73ms +step:1048/1680 train_time:91939ms step_avg:87.73ms +step:1049/1680 train_time:92029ms step_avg:87.73ms +step:1050/1680 train_time:92117ms step_avg:87.73ms +step:1051/1680 train_time:92205ms step_avg:87.73ms +step:1052/1680 train_time:92293ms step_avg:87.73ms +step:1053/1680 train_time:92381ms step_avg:87.73ms +step:1054/1680 train_time:92469ms step_avg:87.73ms +step:1055/1680 train_time:92557ms step_avg:87.73ms +step:1056/1680 train_time:92645ms step_avg:87.73ms +step:1057/1680 train_time:92734ms step_avg:87.73ms +step:1058/1680 train_time:92822ms step_avg:87.73ms +step:1059/1680 train_time:92910ms step_avg:87.73ms +step:1060/1680 train_time:92998ms step_avg:87.73ms +step:1061/1680 train_time:93087ms step_avg:87.74ms +step:1062/1680 train_time:93176ms step_avg:87.74ms +step:1063/1680 train_time:93264ms step_avg:87.74ms +step:1064/1680 train_time:93352ms step_avg:87.74ms +step:1065/1680 train_time:93440ms step_avg:87.74ms +step:1066/1680 train_time:93528ms step_avg:87.74ms +step:1067/1680 train_time:93616ms step_avg:87.74ms +step:1068/1680 train_time:93704ms step_avg:87.74ms +step:1069/1680 train_time:93792ms step_avg:87.74ms +step:1070/1680 train_time:93881ms step_avg:87.74ms +step:1071/1680 train_time:93969ms step_avg:87.74ms +step:1072/1680 train_time:94057ms step_avg:87.74ms +step:1073/1680 train_time:94145ms step_avg:87.74ms +step:1074/1680 train_time:94234ms step_avg:87.74ms +step:1075/1680 train_time:94323ms step_avg:87.74ms +step:1076/1680 train_time:94411ms step_avg:87.74ms +step:1077/1680 train_time:94499ms step_avg:87.74ms +step:1078/1680 train_time:94587ms step_avg:87.74ms +step:1079/1680 train_time:94675ms step_avg:87.74ms +step:1080/1680 train_time:94763ms step_avg:87.74ms +step:1081/1680 train_time:94851ms step_avg:87.74ms +step:1082/1680 
train_time:94939ms step_avg:87.74ms +step:1083/1680 train_time:95028ms step_avg:87.75ms +step:1084/1680 train_time:95117ms step_avg:87.75ms +step:1085/1680 train_time:95206ms step_avg:87.75ms +step:1086/1680 train_time:95294ms step_avg:87.75ms +step:1087/1680 train_time:95382ms step_avg:87.75ms +step:1088/1680 train_time:95472ms step_avg:87.75ms +step:1089/1680 train_time:95559ms step_avg:87.75ms +step:1090/1680 train_time:95648ms step_avg:87.75ms +step:1091/1680 train_time:95736ms step_avg:87.75ms +step:1092/1680 train_time:95824ms step_avg:87.75ms +step:1093/1680 train_time:95912ms step_avg:87.75ms +step:1094/1680 train_time:95999ms step_avg:87.75ms +step:1095/1680 train_time:96088ms step_avg:87.75ms +step:1096/1680 train_time:96177ms step_avg:87.75ms +step:1097/1680 train_time:96266ms step_avg:87.75ms +step:1098/1680 train_time:96355ms step_avg:87.75ms +step:1099/1680 train_time:96443ms step_avg:87.76ms +step:1100/1680 train_time:96533ms step_avg:87.76ms +step:1101/1680 train_time:96621ms step_avg:87.76ms +step:1102/1680 train_time:96710ms step_avg:87.76ms +step:1103/1680 train_time:96798ms step_avg:87.76ms +step:1104/1680 train_time:96887ms step_avg:87.76ms +step:1105/1680 train_time:96976ms step_avg:87.76ms +step:1106/1680 train_time:97066ms step_avg:87.76ms +step:1107/1680 train_time:97155ms step_avg:87.76ms +step:1108/1680 train_time:97245ms step_avg:87.77ms +step:1109/1680 train_time:97335ms step_avg:87.77ms +step:1110/1680 train_time:97424ms step_avg:87.77ms +step:1111/1680 train_time:97513ms step_avg:87.77ms +step:1112/1680 train_time:97602ms step_avg:87.77ms +step:1113/1680 train_time:97691ms step_avg:87.77ms +step:1114/1680 train_time:97779ms step_avg:87.77ms +step:1115/1680 train_time:97869ms step_avg:87.77ms +step:1116/1680 train_time:97957ms step_avg:87.78ms +step:1117/1680 train_time:98046ms step_avg:87.78ms +step:1118/1680 train_time:98135ms step_avg:87.78ms +step:1119/1680 train_time:98224ms step_avg:87.78ms +step:1120/1680 train_time:98313ms step_avg:87.78ms +step:1121/1680 train_time:98402ms step_avg:87.78ms +step:1122/1680 train_time:98491ms step_avg:87.78ms +step:1123/1680 train_time:98580ms step_avg:87.78ms +step:1124/1680 train_time:98669ms step_avg:87.78ms +step:1125/1680 train_time:98758ms step_avg:87.78ms +step:1125/1680 val_loss:3.4155 train_time:98849ms step_avg:87.87ms +step:1126/1680 train_time:98867ms step_avg:87.80ms +step:1127/1680 train_time:98939ms step_avg:87.79ms +step:1128/1680 train_time:99030ms step_avg:87.79ms +step:1129/1680 train_time:99122ms step_avg:87.80ms +step:1130/1680 train_time:99213ms step_avg:87.80ms +step:1131/1680 train_time:99301ms step_avg:87.80ms +step:1132/1680 train_time:99389ms step_avg:87.80ms +step:1133/1680 train_time:99477ms step_avg:87.80ms +step:1134/1680 train_time:99564ms step_avg:87.80ms +step:1135/1680 train_time:99652ms step_avg:87.80ms +step:1136/1680 train_time:99740ms step_avg:87.80ms +step:1137/1680 train_time:99831ms step_avg:87.80ms +step:1138/1680 train_time:99921ms step_avg:87.80ms +step:1139/1680 train_time:100012ms step_avg:87.81ms +step:1140/1680 train_time:100103ms step_avg:87.81ms +step:1141/1680 train_time:100193ms step_avg:87.81ms +step:1142/1680 train_time:100282ms step_avg:87.81ms +step:1143/1680 train_time:100370ms step_avg:87.81ms +step:1144/1680 train_time:100458ms step_avg:87.81ms +step:1145/1680 train_time:100546ms step_avg:87.81ms +step:1146/1680 train_time:100634ms step_avg:87.81ms +step:1147/1680 train_time:100722ms step_avg:87.81ms +step:1148/1680 train_time:100811ms step_avg:87.81ms 
+step:1149/1680 train_time:100902ms step_avg:87.82ms +step:1150/1680 train_time:100991ms step_avg:87.82ms +step:1151/1680 train_time:101081ms step_avg:87.82ms +step:1152/1680 train_time:101171ms step_avg:87.82ms +step:1153/1680 train_time:101260ms step_avg:87.82ms +step:1154/1680 train_time:101348ms step_avg:87.82ms +step:1155/1680 train_time:101438ms step_avg:87.83ms +step:1156/1680 train_time:101526ms step_avg:87.82ms +step:1157/1680 train_time:101614ms step_avg:87.83ms +step:1158/1680 train_time:101702ms step_avg:87.83ms +step:1159/1680 train_time:101791ms step_avg:87.83ms +step:1160/1680 train_time:101880ms step_avg:87.83ms +step:1161/1680 train_time:101970ms step_avg:87.83ms +step:1162/1680 train_time:102060ms step_avg:87.83ms +step:1163/1680 train_time:102150ms step_avg:87.83ms +step:1164/1680 train_time:102239ms step_avg:87.83ms +step:1165/1680 train_time:102328ms step_avg:87.83ms +step:1166/1680 train_time:102417ms step_avg:87.84ms +step:1167/1680 train_time:102505ms step_avg:87.84ms +step:1168/1680 train_time:102594ms step_avg:87.84ms +step:1169/1680 train_time:102682ms step_avg:87.84ms +step:1170/1680 train_time:102770ms step_avg:87.84ms +step:1171/1680 train_time:102859ms step_avg:87.84ms +step:1172/1680 train_time:102948ms step_avg:87.84ms +step:1173/1680 train_time:103037ms step_avg:87.84ms +step:1174/1680 train_time:103127ms step_avg:87.84ms +step:1175/1680 train_time:103216ms step_avg:87.84ms +step:1176/1680 train_time:103305ms step_avg:87.84ms +step:1177/1680 train_time:103393ms step_avg:87.84ms +step:1178/1680 train_time:103481ms step_avg:87.84ms +step:1179/1680 train_time:103570ms step_avg:87.85ms +step:1180/1680 train_time:103658ms step_avg:87.85ms +step:1181/1680 train_time:103747ms step_avg:87.85ms +step:1182/1680 train_time:103837ms step_avg:87.85ms +step:1183/1680 train_time:103925ms step_avg:87.85ms +step:1184/1680 train_time:104015ms step_avg:87.85ms +step:1185/1680 train_time:104104ms step_avg:87.85ms +step:1186/1680 train_time:104193ms step_avg:87.85ms +step:1187/1680 train_time:104282ms step_avg:87.85ms +step:1188/1680 train_time:104371ms step_avg:87.85ms +step:1189/1680 train_time:104460ms step_avg:87.85ms +step:1190/1680 train_time:104548ms step_avg:87.86ms +step:1191/1680 train_time:104638ms step_avg:87.86ms +step:1192/1680 train_time:104727ms step_avg:87.86ms +step:1193/1680 train_time:104816ms step_avg:87.86ms +step:1194/1680 train_time:104905ms step_avg:87.86ms +step:1195/1680 train_time:104994ms step_avg:87.86ms +step:1196/1680 train_time:105082ms step_avg:87.86ms +step:1197/1680 train_time:105171ms step_avg:87.86ms +step:1198/1680 train_time:105260ms step_avg:87.86ms +step:1199/1680 train_time:105349ms step_avg:87.86ms +step:1200/1680 train_time:105438ms step_avg:87.87ms +step:1201/1680 train_time:105527ms step_avg:87.87ms +step:1202/1680 train_time:105616ms step_avg:87.87ms +step:1203/1680 train_time:105704ms step_avg:87.87ms +step:1204/1680 train_time:105793ms step_avg:87.87ms +step:1205/1680 train_time:105882ms step_avg:87.87ms +step:1206/1680 train_time:105970ms step_avg:87.87ms +step:1207/1680 train_time:106059ms step_avg:87.87ms +step:1208/1680 train_time:106148ms step_avg:87.87ms +step:1209/1680 train_time:106237ms step_avg:87.87ms +step:1210/1680 train_time:106326ms step_avg:87.87ms +step:1211/1680 train_time:106415ms step_avg:87.87ms +step:1212/1680 train_time:106504ms step_avg:87.87ms +step:1213/1680 train_time:106593ms step_avg:87.88ms +step:1214/1680 train_time:106683ms step_avg:87.88ms +step:1215/1680 train_time:106772ms step_avg:87.88ms 
+step:1216/1680 train_time:106860ms step_avg:87.88ms +step:1217/1680 train_time:106948ms step_avg:87.88ms +step:1218/1680 train_time:107038ms step_avg:87.88ms +step:1219/1680 train_time:107128ms step_avg:87.88ms +step:1220/1680 train_time:107216ms step_avg:87.88ms +step:1221/1680 train_time:107304ms step_avg:87.88ms +step:1222/1680 train_time:107393ms step_avg:87.88ms +step:1223/1680 train_time:107482ms step_avg:87.88ms +step:1224/1680 train_time:107570ms step_avg:87.88ms +step:1225/1680 train_time:107659ms step_avg:87.89ms +step:1226/1680 train_time:107749ms step_avg:87.89ms +step:1227/1680 train_time:107838ms step_avg:87.89ms +step:1228/1680 train_time:107928ms step_avg:87.89ms +step:1229/1680 train_time:108017ms step_avg:87.89ms +step:1230/1680 train_time:108105ms step_avg:87.89ms +step:1231/1680 train_time:108195ms step_avg:87.89ms +step:1232/1680 train_time:108283ms step_avg:87.89ms +step:1233/1680 train_time:108372ms step_avg:87.89ms +step:1234/1680 train_time:108460ms step_avg:87.89ms +step:1235/1680 train_time:108549ms step_avg:87.89ms +step:1236/1680 train_time:108638ms step_avg:87.90ms +step:1237/1680 train_time:108728ms step_avg:87.90ms +step:1238/1680 train_time:108817ms step_avg:87.90ms +step:1239/1680 train_time:108906ms step_avg:87.90ms +step:1240/1680 train_time:108995ms step_avg:87.90ms +step:1241/1680 train_time:109083ms step_avg:87.90ms +step:1242/1680 train_time:109173ms step_avg:87.90ms +step:1243/1680 train_time:109261ms step_avg:87.90ms +step:1244/1680 train_time:109350ms step_avg:87.90ms +step:1245/1680 train_time:109440ms step_avg:87.90ms +step:1246/1680 train_time:109528ms step_avg:87.90ms +step:1247/1680 train_time:109618ms step_avg:87.91ms +step:1248/1680 train_time:109707ms step_avg:87.91ms +step:1249/1680 train_time:109797ms step_avg:87.91ms +step:1250/1680 train_time:109885ms step_avg:87.91ms +step:1250/1680 val_loss:3.3775 train_time:109975ms step_avg:87.98ms +step:1251/1680 train_time:109995ms step_avg:87.93ms +step:1252/1680 train_time:110066ms step_avg:87.91ms +step:1253/1680 train_time:110160ms step_avg:87.92ms +step:1254/1680 train_time:110250ms step_avg:87.92ms +step:1255/1680 train_time:110339ms step_avg:87.92ms +step:1256/1680 train_time:110427ms step_avg:87.92ms +step:1257/1680 train_time:110515ms step_avg:87.92ms +step:1258/1680 train_time:110603ms step_avg:87.92ms +step:1259/1680 train_time:110691ms step_avg:87.92ms +step:1260/1680 train_time:110779ms step_avg:87.92ms +step:1261/1680 train_time:110867ms step_avg:87.92ms +step:1262/1680 train_time:110957ms step_avg:87.92ms +step:1263/1680 train_time:111047ms step_avg:87.92ms +step:1264/1680 train_time:111138ms step_avg:87.93ms +step:1265/1680 train_time:111227ms step_avg:87.93ms +step:1266/1680 train_time:111317ms step_avg:87.93ms +step:1267/1680 train_time:111405ms step_avg:87.93ms +step:1268/1680 train_time:111493ms step_avg:87.93ms +step:1269/1680 train_time:111581ms step_avg:87.93ms +step:1270/1680 train_time:111670ms step_avg:87.93ms +step:1271/1680 train_time:111758ms step_avg:87.93ms +step:1272/1680 train_time:111848ms step_avg:87.93ms +step:1273/1680 train_time:111936ms step_avg:87.93ms +step:1274/1680 train_time:112025ms step_avg:87.93ms +step:1275/1680 train_time:112114ms step_avg:87.93ms +step:1276/1680 train_time:112203ms step_avg:87.93ms +step:1277/1680 train_time:112293ms step_avg:87.94ms +step:1278/1680 train_time:112381ms step_avg:87.94ms +step:1279/1680 train_time:112471ms step_avg:87.94ms +step:1280/1680 train_time:112560ms step_avg:87.94ms +step:1281/1680 train_time:112648ms 
step_avg:87.94ms +step:1282/1680 train_time:112737ms step_avg:87.94ms +step:1283/1680 train_time:112826ms step_avg:87.94ms +step:1284/1680 train_time:112914ms step_avg:87.94ms +step:1285/1680 train_time:113003ms step_avg:87.94ms +step:1286/1680 train_time:113091ms step_avg:87.94ms +step:1287/1680 train_time:113181ms step_avg:87.94ms +step:1288/1680 train_time:113270ms step_avg:87.94ms +step:1289/1680 train_time:113359ms step_avg:87.94ms +step:1290/1680 train_time:113450ms step_avg:87.95ms +step:1291/1680 train_time:113540ms step_avg:87.95ms +step:1292/1680 train_time:113628ms step_avg:87.95ms +step:1293/1680 train_time:113717ms step_avg:87.95ms +step:1294/1680 train_time:113805ms step_avg:87.95ms +step:1295/1680 train_time:113895ms step_avg:87.95ms +step:1296/1680 train_time:113983ms step_avg:87.95ms +step:1297/1680 train_time:114072ms step_avg:87.95ms +step:1298/1680 train_time:114161ms step_avg:87.95ms +step:1299/1680 train_time:114251ms step_avg:87.95ms +step:1300/1680 train_time:114340ms step_avg:87.95ms +step:1301/1680 train_time:114429ms step_avg:87.95ms +step:1302/1680 train_time:114517ms step_avg:87.96ms +step:1303/1680 train_time:114607ms step_avg:87.96ms +step:1304/1680 train_time:114695ms step_avg:87.96ms +step:1305/1680 train_time:114784ms step_avg:87.96ms +step:1306/1680 train_time:114873ms step_avg:87.96ms +step:1307/1680 train_time:114961ms step_avg:87.96ms +step:1308/1680 train_time:115050ms step_avg:87.96ms +step:1309/1680 train_time:115140ms step_avg:87.96ms +step:1310/1680 train_time:115229ms step_avg:87.96ms +step:1311/1680 train_time:115318ms step_avg:87.96ms +step:1312/1680 train_time:115408ms step_avg:87.96ms +step:1313/1680 train_time:115497ms step_avg:87.96ms +step:1314/1680 train_time:115585ms step_avg:87.96ms +step:1315/1680 train_time:115674ms step_avg:87.96ms +step:1316/1680 train_time:115762ms step_avg:87.97ms +step:1317/1680 train_time:115851ms step_avg:87.97ms +step:1318/1680 train_time:115939ms step_avg:87.97ms +step:1319/1680 train_time:116028ms step_avg:87.97ms +step:1320/1680 train_time:116117ms step_avg:87.97ms +step:1321/1680 train_time:116207ms step_avg:87.97ms +step:1322/1680 train_time:116297ms step_avg:87.97ms +step:1323/1680 train_time:116386ms step_avg:87.97ms +step:1324/1680 train_time:116474ms step_avg:87.97ms +step:1325/1680 train_time:116563ms step_avg:87.97ms +step:1326/1680 train_time:116652ms step_avg:87.97ms +step:1327/1680 train_time:116740ms step_avg:87.97ms +step:1328/1680 train_time:116828ms step_avg:87.97ms +step:1329/1680 train_time:116918ms step_avg:87.97ms +step:1330/1680 train_time:117007ms step_avg:87.97ms +step:1331/1680 train_time:117096ms step_avg:87.98ms +step:1332/1680 train_time:117185ms step_avg:87.98ms +step:1333/1680 train_time:117274ms step_avg:87.98ms +step:1334/1680 train_time:117363ms step_avg:87.98ms +step:1335/1680 train_time:117453ms step_avg:87.98ms +step:1336/1680 train_time:117542ms step_avg:87.98ms +step:1337/1680 train_time:117632ms step_avg:87.98ms +step:1338/1680 train_time:117721ms step_avg:87.98ms +step:1339/1680 train_time:117809ms step_avg:87.98ms +step:1340/1680 train_time:117898ms step_avg:87.98ms +step:1341/1680 train_time:117987ms step_avg:87.98ms +step:1342/1680 train_time:118076ms step_avg:87.98ms +step:1343/1680 train_time:118165ms step_avg:87.99ms +step:1344/1680 train_time:118254ms step_avg:87.99ms +step:1345/1680 train_time:118343ms step_avg:87.99ms +step:1346/1680 train_time:118433ms step_avg:87.99ms +step:1347/1680 train_time:118522ms step_avg:87.99ms +step:1348/1680 train_time:118610ms 
step_avg:87.99ms +step:1349/1680 train_time:118699ms step_avg:87.99ms +step:1350/1680 train_time:118787ms step_avg:87.99ms +step:1351/1680 train_time:118876ms step_avg:87.99ms +step:1352/1680 train_time:118965ms step_avg:87.99ms +step:1353/1680 train_time:119056ms step_avg:87.99ms +step:1354/1680 train_time:119145ms step_avg:87.99ms +step:1355/1680 train_time:119233ms step_avg:87.99ms +step:1356/1680 train_time:119322ms step_avg:88.00ms +step:1357/1680 train_time:119411ms step_avg:88.00ms +step:1358/1680 train_time:119500ms step_avg:88.00ms +step:1359/1680 train_time:119589ms step_avg:88.00ms +step:1360/1680 train_time:119678ms step_avg:88.00ms +step:1361/1680 train_time:119766ms step_avg:88.00ms +step:1362/1680 train_time:119855ms step_avg:88.00ms +step:1363/1680 train_time:119943ms step_avg:88.00ms +step:1364/1680 train_time:120032ms step_avg:88.00ms +step:1365/1680 train_time:120121ms step_avg:88.00ms +step:1366/1680 train_time:120211ms step_avg:88.00ms +step:1367/1680 train_time:120299ms step_avg:88.00ms +step:1368/1680 train_time:120388ms step_avg:88.00ms +step:1369/1680 train_time:120478ms step_avg:88.00ms +step:1370/1680 train_time:120566ms step_avg:88.00ms +step:1371/1680 train_time:120655ms step_avg:88.01ms +step:1372/1680 train_time:120744ms step_avg:88.01ms +step:1373/1680 train_time:120834ms step_avg:88.01ms +step:1374/1680 train_time:120922ms step_avg:88.01ms +step:1375/1680 train_time:121011ms step_avg:88.01ms +step:1375/1680 val_loss:3.3427 train_time:121101ms step_avg:88.07ms +step:1376/1680 train_time:121119ms step_avg:88.02ms +step:1377/1680 train_time:121194ms step_avg:88.01ms +step:1378/1680 train_time:121288ms step_avg:88.02ms +step:1379/1680 train_time:121378ms step_avg:88.02ms +step:1380/1680 train_time:121466ms step_avg:88.02ms +step:1381/1680 train_time:121554ms step_avg:88.02ms +step:1382/1680 train_time:121642ms step_avg:88.02ms +step:1383/1680 train_time:121729ms step_avg:88.02ms +step:1384/1680 train_time:121817ms step_avg:88.02ms +step:1385/1680 train_time:121905ms step_avg:88.02ms +step:1386/1680 train_time:121993ms step_avg:88.02ms +step:1387/1680 train_time:122084ms step_avg:88.02ms +step:1388/1680 train_time:122176ms step_avg:88.02ms +step:1389/1680 train_time:122266ms step_avg:88.02ms +step:1390/1680 train_time:122356ms step_avg:88.03ms +step:1391/1680 train_time:122445ms step_avg:88.03ms +step:1392/1680 train_time:122535ms step_avg:88.03ms +step:1393/1680 train_time:122622ms step_avg:88.03ms +step:1394/1680 train_time:122711ms step_avg:88.03ms +step:1395/1680 train_time:122799ms step_avg:88.03ms +step:1396/1680 train_time:122887ms step_avg:88.03ms +step:1397/1680 train_time:122976ms step_avg:88.03ms +step:1398/1680 train_time:123064ms step_avg:88.03ms +step:1399/1680 train_time:123153ms step_avg:88.03ms +step:1400/1680 train_time:123242ms step_avg:88.03ms +step:1401/1680 train_time:123333ms step_avg:88.03ms +step:1402/1680 train_time:123424ms step_avg:88.03ms +step:1403/1680 train_time:123513ms step_avg:88.04ms +step:1404/1680 train_time:123602ms step_avg:88.04ms +step:1405/1680 train_time:123691ms step_avg:88.04ms +step:1406/1680 train_time:123779ms step_avg:88.04ms +step:1407/1680 train_time:123867ms step_avg:88.04ms +step:1408/1680 train_time:123955ms step_avg:88.04ms +step:1409/1680 train_time:124044ms step_avg:88.04ms +step:1410/1680 train_time:124134ms step_avg:88.04ms +step:1411/1680 train_time:124224ms step_avg:88.04ms +step:1412/1680 train_time:124313ms step_avg:88.04ms +step:1413/1680 train_time:124403ms step_avg:88.04ms +step:1414/1680 
train_time:124492ms step_avg:88.04ms +step:1415/1680 train_time:124581ms step_avg:88.04ms +step:1416/1680 train_time:124670ms step_avg:88.04ms +step:1417/1680 train_time:124758ms step_avg:88.04ms +step:1418/1680 train_time:124847ms step_avg:88.04ms +step:1419/1680 train_time:124935ms step_avg:88.04ms +step:1420/1680 train_time:125023ms step_avg:88.04ms +step:1421/1680 train_time:125113ms step_avg:88.05ms +step:1422/1680 train_time:125201ms step_avg:88.05ms +step:1423/1680 train_time:125290ms step_avg:88.05ms +step:1424/1680 train_time:125380ms step_avg:88.05ms +step:1425/1680 train_time:125469ms step_avg:88.05ms +step:1426/1680 train_time:125559ms step_avg:88.05ms +step:1427/1680 train_time:125648ms step_avg:88.05ms +step:1428/1680 train_time:125738ms step_avg:88.05ms +step:1429/1680 train_time:125827ms step_avg:88.05ms +step:1430/1680 train_time:125915ms step_avg:88.05ms +step:1431/1680 train_time:126004ms step_avg:88.05ms +step:1432/1680 train_time:126093ms step_avg:88.05ms +step:1433/1680 train_time:126182ms step_avg:88.05ms +step:1434/1680 train_time:126272ms step_avg:88.06ms +step:1435/1680 train_time:126361ms step_avg:88.06ms +step:1436/1680 train_time:126450ms step_avg:88.06ms +step:1437/1680 train_time:126539ms step_avg:88.06ms +step:1438/1680 train_time:126627ms step_avg:88.06ms +step:1439/1680 train_time:126718ms step_avg:88.06ms +step:1440/1680 train_time:126807ms step_avg:88.06ms +step:1441/1680 train_time:126897ms step_avg:88.06ms +step:1442/1680 train_time:126986ms step_avg:88.06ms +step:1443/1680 train_time:127076ms step_avg:88.06ms +step:1444/1680 train_time:127165ms step_avg:88.06ms +step:1445/1680 train_time:127254ms step_avg:88.07ms +step:1446/1680 train_time:127343ms step_avg:88.07ms +step:1447/1680 train_time:127431ms step_avg:88.07ms +step:1448/1680 train_time:127521ms step_avg:88.07ms +step:1449/1680 train_time:127610ms step_avg:88.07ms +step:1450/1680 train_time:127699ms step_avg:88.07ms +step:1451/1680 train_time:127788ms step_avg:88.07ms +step:1452/1680 train_time:127877ms step_avg:88.07ms +step:1453/1680 train_time:127966ms step_avg:88.07ms +step:1454/1680 train_time:128056ms step_avg:88.07ms +step:1455/1680 train_time:128144ms step_avg:88.07ms +step:1456/1680 train_time:128233ms step_avg:88.07ms +step:1457/1680 train_time:128322ms step_avg:88.07ms +step:1458/1680 train_time:128411ms step_avg:88.07ms +step:1459/1680 train_time:128499ms step_avg:88.07ms +step:1460/1680 train_time:128588ms step_avg:88.07ms +step:1461/1680 train_time:128678ms step_avg:88.08ms +step:1462/1680 train_time:128767ms step_avg:88.08ms +step:1463/1680 train_time:128856ms step_avg:88.08ms +step:1464/1680 train_time:128945ms step_avg:88.08ms +step:1465/1680 train_time:129034ms step_avg:88.08ms +step:1466/1680 train_time:129122ms step_avg:88.08ms +step:1467/1680 train_time:129213ms step_avg:88.08ms +step:1468/1680 train_time:129301ms step_avg:88.08ms +step:1469/1680 train_time:129391ms step_avg:88.08ms +step:1470/1680 train_time:129479ms step_avg:88.08ms +step:1471/1680 train_time:129567ms step_avg:88.08ms +step:1472/1680 train_time:129657ms step_avg:88.08ms +step:1473/1680 train_time:129746ms step_avg:88.08ms +step:1474/1680 train_time:129835ms step_avg:88.08ms +step:1475/1680 train_time:129924ms step_avg:88.08ms +step:1476/1680 train_time:130014ms step_avg:88.09ms +step:1477/1680 train_time:130103ms step_avg:88.09ms +step:1478/1680 train_time:130192ms step_avg:88.09ms +step:1479/1680 train_time:130281ms step_avg:88.09ms +step:1480/1680 train_time:130371ms step_avg:88.09ms +step:1481/1680 
train_time:130460ms step_avg:88.09ms +step:1482/1680 train_time:130548ms step_avg:88.09ms +step:1483/1680 train_time:130638ms step_avg:88.09ms +step:1484/1680 train_time:130727ms step_avg:88.09ms +step:1485/1680 train_time:130816ms step_avg:88.09ms +step:1486/1680 train_time:130905ms step_avg:88.09ms +step:1487/1680 train_time:130993ms step_avg:88.09ms +step:1488/1680 train_time:131082ms step_avg:88.09ms +step:1489/1680 train_time:131172ms step_avg:88.09ms +step:1490/1680 train_time:131260ms step_avg:88.09ms +step:1491/1680 train_time:131349ms step_avg:88.09ms +step:1492/1680 train_time:131438ms step_avg:88.10ms +step:1493/1680 train_time:131527ms step_avg:88.10ms +step:1494/1680 train_time:131615ms step_avg:88.10ms +step:1495/1680 train_time:131703ms step_avg:88.10ms +step:1496/1680 train_time:131792ms step_avg:88.10ms +step:1497/1680 train_time:131881ms step_avg:88.10ms +step:1498/1680 train_time:131970ms step_avg:88.10ms +step:1499/1680 train_time:132059ms step_avg:88.10ms +step:1500/1680 train_time:132148ms step_avg:88.10ms +step:1500/1680 val_loss:3.3131 train_time:132238ms step_avg:88.16ms +step:1501/1680 train_time:132258ms step_avg:88.11ms +step:1502/1680 train_time:132331ms step_avg:88.10ms +step:1503/1680 train_time:132426ms step_avg:88.11ms +step:1504/1680 train_time:132516ms step_avg:88.11ms +step:1505/1680 train_time:132604ms step_avg:88.11ms +step:1506/1680 train_time:132692ms step_avg:88.11ms +step:1507/1680 train_time:132780ms step_avg:88.11ms +step:1508/1680 train_time:132867ms step_avg:88.11ms +step:1509/1680 train_time:132955ms step_avg:88.11ms +step:1510/1680 train_time:133043ms step_avg:88.11ms +step:1511/1680 train_time:133131ms step_avg:88.11ms +step:1512/1680 train_time:133221ms step_avg:88.11ms +step:1513/1680 train_time:133312ms step_avg:88.11ms +step:1514/1680 train_time:133404ms step_avg:88.11ms +step:1515/1680 train_time:133494ms step_avg:88.12ms +step:1516/1680 train_time:133583ms step_avg:88.12ms +step:1517/1680 train_time:133672ms step_avg:88.12ms +step:1518/1680 train_time:133760ms step_avg:88.12ms +step:1519/1680 train_time:133848ms step_avg:88.12ms +step:1520/1680 train_time:133936ms step_avg:88.12ms +step:1521/1680 train_time:134024ms step_avg:88.12ms +step:1522/1680 train_time:134112ms step_avg:88.12ms +step:1523/1680 train_time:134200ms step_avg:88.12ms +step:1524/1680 train_time:134290ms step_avg:88.12ms +step:1525/1680 train_time:134380ms step_avg:88.12ms +step:1526/1680 train_time:134471ms step_avg:88.12ms +step:1527/1680 train_time:134561ms step_avg:88.12ms +step:1528/1680 train_time:134650ms step_avg:88.12ms +step:1529/1680 train_time:134738ms step_avg:88.12ms +step:1530/1680 train_time:134827ms step_avg:88.12ms +step:1531/1680 train_time:134915ms step_avg:88.12ms +step:1532/1680 train_time:135003ms step_avg:88.12ms +step:1533/1680 train_time:135092ms step_avg:88.12ms +step:1534/1680 train_time:135180ms step_avg:88.12ms +step:1535/1680 train_time:135269ms step_avg:88.12ms +step:1536/1680 train_time:135358ms step_avg:88.12ms +step:1537/1680 train_time:135448ms step_avg:88.12ms +step:1538/1680 train_time:135537ms step_avg:88.13ms +step:1539/1680 train_time:135627ms step_avg:88.13ms +step:1540/1680 train_time:135715ms step_avg:88.13ms +step:1541/1680 train_time:135804ms step_avg:88.13ms +step:1542/1680 train_time:135892ms step_avg:88.13ms +step:1543/1680 train_time:135980ms step_avg:88.13ms +step:1544/1680 train_time:136069ms step_avg:88.13ms +step:1545/1680 train_time:136157ms step_avg:88.13ms +step:1546/1680 train_time:136246ms step_avg:88.13ms 
+step:1547/1680 train_time:136335ms step_avg:88.13ms +step:1548/1680 train_time:136424ms step_avg:88.13ms +step:1549/1680 train_time:136513ms step_avg:88.13ms +step:1550/1680 train_time:136602ms step_avg:88.13ms +step:1551/1680 train_time:136692ms step_avg:88.13ms +step:1552/1680 train_time:136780ms step_avg:88.13ms +step:1553/1680 train_time:136869ms step_avg:88.13ms +step:1554/1680 train_time:136957ms step_avg:88.13ms +step:1555/1680 train_time:137047ms step_avg:88.13ms +step:1556/1680 train_time:137136ms step_avg:88.13ms +step:1557/1680 train_time:137225ms step_avg:88.13ms +step:1558/1680 train_time:137314ms step_avg:88.13ms +step:1559/1680 train_time:137404ms step_avg:88.14ms +step:1560/1680 train_time:137493ms step_avg:88.14ms +step:1561/1680 train_time:137583ms step_avg:88.14ms +step:1562/1680 train_time:137671ms step_avg:88.14ms +step:1563/1680 train_time:137761ms step_avg:88.14ms +step:1564/1680 train_time:137850ms step_avg:88.14ms +step:1565/1680 train_time:137939ms step_avg:88.14ms +step:1566/1680 train_time:138029ms step_avg:88.14ms +step:1567/1680 train_time:138119ms step_avg:88.14ms +step:1568/1680 train_time:138207ms step_avg:88.14ms +step:1569/1680 train_time:138295ms step_avg:88.14ms +step:1570/1680 train_time:138384ms step_avg:88.14ms +step:1571/1680 train_time:138473ms step_avg:88.14ms +step:1572/1680 train_time:138562ms step_avg:88.14ms +step:1573/1680 train_time:138651ms step_avg:88.14ms +step:1574/1680 train_time:138740ms step_avg:88.14ms +step:1575/1680 train_time:138829ms step_avg:88.15ms +step:1576/1680 train_time:138918ms step_avg:88.15ms +step:1577/1680 train_time:139008ms step_avg:88.15ms +step:1578/1680 train_time:139099ms step_avg:88.15ms +step:1579/1680 train_time:139187ms step_avg:88.15ms +step:1580/1680 train_time:139275ms step_avg:88.15ms +step:1581/1680 train_time:139364ms step_avg:88.15ms +step:1582/1680 train_time:139453ms step_avg:88.15ms +step:1583/1680 train_time:139542ms step_avg:88.15ms +step:1584/1680 train_time:139631ms step_avg:88.15ms +step:1585/1680 train_time:139720ms step_avg:88.15ms +step:1586/1680 train_time:139809ms step_avg:88.15ms +step:1587/1680 train_time:139897ms step_avg:88.15ms +step:1588/1680 train_time:139987ms step_avg:88.15ms +step:1589/1680 train_time:140075ms step_avg:88.15ms +step:1590/1680 train_time:140164ms step_avg:88.15ms +step:1591/1680 train_time:140253ms step_avg:88.15ms +step:1592/1680 train_time:140342ms step_avg:88.15ms +step:1593/1680 train_time:140431ms step_avg:88.15ms +step:1594/1680 train_time:140519ms step_avg:88.16ms +step:1595/1680 train_time:140608ms step_avg:88.16ms +step:1596/1680 train_time:140697ms step_avg:88.16ms +step:1597/1680 train_time:140787ms step_avg:88.16ms +step:1598/1680 train_time:140876ms step_avg:88.16ms +step:1599/1680 train_time:140965ms step_avg:88.16ms +step:1600/1680 train_time:141054ms step_avg:88.16ms +step:1601/1680 train_time:141143ms step_avg:88.16ms +step:1602/1680 train_time:141232ms step_avg:88.16ms +step:1603/1680 train_time:141320ms step_avg:88.16ms +step:1604/1680 train_time:141409ms step_avg:88.16ms +step:1605/1680 train_time:141497ms step_avg:88.16ms +step:1606/1680 train_time:141586ms step_avg:88.16ms +step:1607/1680 train_time:141675ms step_avg:88.16ms +step:1608/1680 train_time:141764ms step_avg:88.16ms +step:1609/1680 train_time:141853ms step_avg:88.16ms +step:1610/1680 train_time:141942ms step_avg:88.16ms +step:1611/1680 train_time:142033ms step_avg:88.16ms +step:1612/1680 train_time:142122ms step_avg:88.16ms +step:1613/1680 train_time:142212ms step_avg:88.17ms 
+step:1614/1680 train_time:142301ms step_avg:88.17ms +step:1615/1680 train_time:142389ms step_avg:88.17ms +step:1616/1680 train_time:142478ms step_avg:88.17ms +step:1617/1680 train_time:142566ms step_avg:88.17ms +step:1618/1680 train_time:142655ms step_avg:88.17ms +step:1619/1680 train_time:142745ms step_avg:88.17ms +step:1620/1680 train_time:142833ms step_avg:88.17ms +step:1621/1680 train_time:142923ms step_avg:88.17ms +step:1622/1680 train_time:143013ms step_avg:88.17ms +step:1623/1680 train_time:143102ms step_avg:88.17ms +step:1624/1680 train_time:143191ms step_avg:88.17ms +step:1625/1680 train_time:143280ms step_avg:88.17ms +step:1625/1680 val_loss:3.2896 train_time:143370ms step_avg:88.23ms +step:1626/1680 train_time:143388ms step_avg:88.18ms +step:1627/1680 train_time:143463ms step_avg:88.18ms +step:1628/1680 train_time:143557ms step_avg:88.18ms +step:1629/1680 train_time:143646ms step_avg:88.18ms +step:1630/1680 train_time:143734ms step_avg:88.18ms +step:1631/1680 train_time:143823ms step_avg:88.18ms +step:1632/1680 train_time:143911ms step_avg:88.18ms +step:1633/1680 train_time:143998ms step_avg:88.18ms +step:1634/1680 train_time:144087ms step_avg:88.18ms +step:1635/1680 train_time:144175ms step_avg:88.18ms +step:1636/1680 train_time:144263ms step_avg:88.18ms +step:1637/1680 train_time:144353ms step_avg:88.18ms +step:1638/1680 train_time:144446ms step_avg:88.18ms +step:1639/1680 train_time:144538ms step_avg:88.19ms +step:1640/1680 train_time:144628ms step_avg:88.19ms +step:1641/1680 train_time:144716ms step_avg:88.19ms +step:1642/1680 train_time:144804ms step_avg:88.19ms +step:1643/1680 train_time:144894ms step_avg:88.19ms +step:1644/1680 train_time:144982ms step_avg:88.19ms +step:1645/1680 train_time:145070ms step_avg:88.19ms +step:1646/1680 train_time:145159ms step_avg:88.19ms +step:1647/1680 train_time:145247ms step_avg:88.19ms +step:1648/1680 train_time:145336ms step_avg:88.19ms +step:1649/1680 train_time:145427ms step_avg:88.19ms +step:1650/1680 train_time:145517ms step_avg:88.19ms +step:1651/1680 train_time:145607ms step_avg:88.19ms +step:1652/1680 train_time:145696ms step_avg:88.19ms +step:1653/1680 train_time:145785ms step_avg:88.19ms +step:1654/1680 train_time:145874ms step_avg:88.19ms +step:1655/1680 train_time:145962ms step_avg:88.19ms +step:1656/1680 train_time:146050ms step_avg:88.19ms +step:1657/1680 train_time:146138ms step_avg:88.19ms +step:1658/1680 train_time:146226ms step_avg:88.19ms +step:1659/1680 train_time:146315ms step_avg:88.19ms +step:1660/1680 train_time:146405ms step_avg:88.20ms +step:1661/1680 train_time:146495ms step_avg:88.20ms +step:1662/1680 train_time:146585ms step_avg:88.20ms +step:1663/1680 train_time:146674ms step_avg:88.20ms +step:1664/1680 train_time:146764ms step_avg:88.20ms +step:1665/1680 train_time:146853ms step_avg:88.20ms +step:1666/1680 train_time:146942ms step_avg:88.20ms +step:1667/1680 train_time:147030ms step_avg:88.20ms +step:1668/1680 train_time:147119ms step_avg:88.20ms +step:1669/1680 train_time:147207ms step_avg:88.20ms +step:1670/1680 train_time:147295ms step_avg:88.20ms +step:1671/1680 train_time:147384ms step_avg:88.20ms +step:1672/1680 train_time:147474ms step_avg:88.20ms +step:1673/1680 train_time:147563ms step_avg:88.20ms +step:1674/1680 train_time:147653ms step_avg:88.20ms +step:1675/1680 train_time:147742ms step_avg:88.20ms +step:1676/1680 train_time:147831ms step_avg:88.20ms +step:1677/1680 train_time:147920ms step_avg:88.21ms +step:1678/1680 train_time:148009ms step_avg:88.21ms +step:1679/1680 train_time:148098ms 
step_avg:88.21ms +step:1680/1680 train_time:148187ms step_avg:88.21ms +step:1680/1680 val_loss:3.2786 train_time:148277ms step_avg:88.26ms +peak memory allocated: 30760 MiB reserved: 46234 MiB diff --git a/records/092725_BF16CE/1316b79f-d02b-4cd6-b98a-43b48023aedf.txt b/records/092725_BF16CE/1316b79f-d02b-4cd6-b98a-43b48023aedf.txt new file mode 100644 index 000000000..250519b6c --- /dev/null +++ b/records/092725_BF16CE/1316b79f-d02b-4cd6-b98a-43b48023aedf.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
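+ # (Since C = A @ A.T is symmetric, each computed block is written twice: at
+ # (m, n) here and mirrored at (n, m) below; blocks strictly on the skipped
+ # side of the diagonal were already returned early at the top of the kernel.)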
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+    Though empirically, small 1D params also perform well here:
+    NS approximately performs a magnitude normalization of the grad, and this
+    hyper-optimized class executes faster than the current impl of Adam for small params.
+
+    Custom distributed sizing:
+    The model stores all attn and mlp weights in the same shape, and then updates the view as
+    needed on the forward pass. This enables attn and mlp weights to be contained within the same
+    dist.reduce_scatter_tensor() call. The model architecture has been customized so that
+    (n_attn_layers + n_mlp_layers*2) % 4 == 0, enabling batching across 8 GPUs with zero padding on mlp and attn.
+    The scheduling is:
+    1. reduce scatter smear_gate (1 param, 7 padding params)
+    2. reduce scatter attn_gate (10 params, 6 padding params)
+    3. reduce scatter attn/mlp round 1 (10 attn params, 6 mlp params)
+    4. reduce scatter attn/mlp round 2 (16 mlp params)
+    5. wait on step 1, then compute NS of 1 and schedule all gather
+    6. wait on step 2, then compute NS of 2 and schedule all gather
+    7. wait on step 3, then compute NS of 3 and schedule all gather
+       GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP];
+       GPUs that receive params of type attn reshape before NS
+    8. wait on step 4, then compute NS of 4 and schedule all gather
+    9. wait for each all gather to complete and update params
+    Empirically, leading with the small params provides an additional 0.2s improvement.
+    """
+    def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
+        # custom sizing requires 8 GPUs
+        if custom_sizing and dist.get_world_size()==8:
+            param_groups = self.generate_custom_param_groups(params)
+        else:
+            param_groups = self.generate_standard_param_groups(params)
+        super().__init__(param_groups, defaults)
+
+    def generate_standard_param_groups(self, params):
+        """
+        Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules.
+        Creates one param group per size, while giving attn its own param group for the resize op.
+        """
+        params = list(params)
+        param_groups = []
+        attn_subset = [p for p in params if p.module == 'attn']
+        non_attn_subset = [p for p in params if p.module != 'attn']
+        param_groups.append(dict(params=attn_subset))
+
+        sizes = {p.shape for p in non_attn_subset}
+        for size in sizes:
+            group_params = [p for p in non_attn_subset if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        return param_groups
+
+    def generate_custom_param_groups(self, params):
+        """
+        Implementation requires that a single GPU does not receive both attn
+        and mlp params when a param group is split across GPUs.
+        """
+        module_ranks = {
+            'smear_gate': 1, # 1 param
+            'attn_gate': 2, # 10 params
+            'attn': 3, # 10 params
+            'mlp': 4, # 22 params
+        }
+        params = list(params)
+        params.sort(key=lambda x: module_ranks.get(x.module))
+        idx = 0
+        group_sizes = [1, 10, 16, 16]
+        assert len(params) == sum(group_sizes)
+        param_groups = []
+        for size in group_sizes:
+            group_params = params[idx:idx+size]
+            param_groups.append(dict(params=group_params))
+            idx += size
+        return param_groups
+
+    @torch.no_grad()
+    def step(self):
+        # Efficient systems-wise implementation of step developed by @YouJiacheng,
+        # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad,
+        # @ryanyang0, and @vagrawal.
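+        # A minimal synchronous sketch of the scheme implemented below (illustrative
+        # pseudocode only, not part of the optimizer; names here do not all exist in this scope):
+        #     for each param group:
+        #         grad_chunk = reduce_scatter(stack([p.grad for p in params]))   # shard + average grads
+        #         buf.lerp_(grad_chunk, 1 - momentum)                            # SGD-momentum on the shard
+        #         v_chunk = newton_schulz_triton(grad_chunk.lerp(buf, momentum)) # orthogonalize the shard
+        #         all_gather(params, param_chunk * (1 - lr*wd) - lr * v_chunk)   # redistribute updated params
+        # The real code below issues every reduce_scatter up front and overlaps the NS
+        # computation with the in-flight collectives to hide communication latency.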
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
+                updated_param_chunk[i].copy_(p)
+
+                # Apply weight decay directly to the buffer.
+                updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+            original_shape = batched_update_grads.shape
+            # Reshape attn params from [hdim, dim*4] to [4, hdim, dim] to apply NS independently to Q,K,V,O
+            module_idx = start_idx if start_idx < len(params) else len(params) - 1
+            if getattr(params[module_idx], "module", None) == "attn":
+                batched_update_grads = batched_update_grads.view(-1, original_shape[-2], original_shape[-1] // 4)
+            v_chunk = newton_schulz_triton(batched_update_grads).view(original_shape)
+
+            # Apply the orthogonalized updates to this rank's shard of parameters.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+                if param_idx >= len(params): # Skip padded entries again.
+                    continue
+
+                # Add the computed zeropower update to the parameter in the buffer.
+                updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val)
+
+            stacked_params = torch.empty(
+                (info["padded_num_params"], *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+            gather_future = dist.all_gather_into_tensor(
+                stacked_params, updated_param_chunk, async_op=True
+            ).get_future()
+
+            all_gather_infos.append(
+                {
+                    "gather_future": gather_future,
+                    "stacked_params": stacked_params,
+                    "orig_params": params,
+                }
+            )
+
+        # Final pass: wait for all_gather to complete and copy results back into original parameter tensors.
+        for info in all_gather_infos:
+            info["gather_future"].wait()
+            stacked_params = info["stacked_params"]
+            orig_params = info["orig_params"]
+
+            unstacked_params = torch.unbind(stacked_params)
+            for i, p in enumerate(orig_params):
+                p.copy_(unstacked_params[i], non_blocking=True)
+
+
+class DistAdam(torch.optim.Optimizer):
+    def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        params = list(params)
+        sizes = {p.shape for p in params}
+        # create one param group per unique parameter size
+        param_groups = []
+        for size in sizes:
+            group_params = [p for p in params if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        super().__init__(param_groups, defaults)
+    # DistributedAdam implementation by @vagrawal
+
+    @torch.compile
+    @torch.no_grad()
+    def step(self):
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        reduce_scatter_futures: list[torch.Future] = []
+        all_gather_futures: list[torch.Future] = []
+        grad_slices = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            for base_i in range(len(params)):
+                grad = params[base_i].grad
+                rank_size = grad.shape[0] // world_size
+                grad_slice = torch.empty_like(grad[:rank_size])
+                reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
+                grad_slices.append(grad_slice)
+
+        idx = 0
+        for group in self.param_groups:
+            beta1, beta2 = group['betas']
+            eps = group['eps']
+            wd = group['weight_decay']
+            params = group['params']
+            for base in range(len(params)):
+                reduce_scatter_futures[idx].wait()
+                p = params[base]
+                rank_size = p.shape[0] // world_size
+                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
+                lr = group['lr'] * getattr(p, "lr_mul", 1.0)
+                state = self.state[p]
+                g_slice = grad_slices[idx]
+                # State init
+                if not state:
+                    state["step"] = torch.tensor(
+                        0, dtype=torch.int64, device=p.device
+                    )
+                    state["exp_avg"] = torch.zeros(
+                        p_slice.shape,
+                        dtype=torch.bfloat16,
+                        device=p_slice.device,
+                    )
+                    state["exp_avg_sq"] = torch.zeros(
+                        p_slice.shape,
+                        dtype=torch.bfloat16,
+                        device=p_slice.device,
+                    )
+                exp_avg = state["exp_avg"]
+                exp_avg_sq = 
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module='attn' + with torch.no_grad(): + self.qkvo_w.view(4,self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4,self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4,self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4,self.hdim, self.dim)[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module='mlp' + self.c_proj.module='mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
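+        # e.g., next_multiple_of_n(50257, n=128) == 50304 == 128 * 393, the padded vocab size used here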
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the 
data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan the first 4 million tokens, then kick off an async thread to scan the rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration("Insufficient BOS tokens ahead; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading the next shard and indexing its BOS tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning-of-Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # the batch carries one extra token for the target shift, so trim the final sequence by 1
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40 # number of iterations to continue training at final cooldown and window size
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 
0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = min(0.9999, step / args.num_iterations)
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+def get_ws(step: int):
+    if step == args.num_iterations + args.iteration_extension:
+        return args.ws_validate//2, args.ws_validate
+    x = min(step / (1 + args.num_iterations), 0.9999)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx]//2, args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws_long = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws_long > ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    elif new_ws_long < ws_long:
+        model.yarn.reset() # cycled back to the smallest window: restore the base YaRN state
+        ws_long = new_ws_long
+    model(inputs, targets, cum_seqlens, ws_long//2, ws_long).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+
+# restore the initial state so the warmup steps don't count as training
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+model.yarn.reset()
+del train_loader, initial_state
+
+########################################
+#        Training and validation       #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+train_steps = args.num_iterations + args.iteration_extension
+ws_short, ws_long = get_ws(0)
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+
+    # advance the sliding-window schedule, extending YaRN whenever the long window changes
+    new_ws_short, new_ws_long = get_ws(step)
+    if new_ws_long != ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+    ws_short, ws_long = new_ws_short, new_ws_long
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        if last_step:
+            ws_long = args.ws_long_validate
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = torch.zeros((), device=device, dtype=torch.float32)
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+
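+        # each micro-step consumes train_batch_size // (grad_accum_steps * world_size) tokens per rank;
+        # gradients accumulate across micro-steps, and the optimizers step once per outer iteration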
inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Sep 27 12:08:42 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 28C P0 123W / 700W | 5856MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 26C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 23C P0 116W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 26C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 27C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 27C P0 116W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 28C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 25C P0 122W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 152004 C /usr/bin/python 0MiB | +| 0 N/A N/A 152005 C /usr/bin/python 0MiB | +| 0 N/A N/A 152006 C /usr/bin/python 0MiB | +| 0 N/A N/A 152007 C /usr/bin/python 0MiB | +| 0 N/A N/A 152008 C /usr/bin/python 0MiB | +| 0 N/A N/A 152009 C /usr/bin/python 0MiB | +| 0 N/A N/A 152010 C /usr/bin/python 0MiB | +| 0 N/A N/A 152011 C /usr/bin/python 0MiB | +| 1 N/A N/A 152005 C /usr/bin/python 0MiB | +| 2 N/A N/A 152006 C /usr/bin/python 0MiB | +| 3 N/A N/A 152007 C /usr/bin/python 0MiB | +| 4 N/A N/A 152008 C /usr/bin/python 0MiB | +| 5 N/A N/A 152009 C /usr/bin/python 0MiB | +| 6 N/A N/A 152010 C /usr/bin/python 0MiB | +| 7 N/A N/A 152011 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1680 train_time:137ms step_avg:136.89ms +step:2/1680 train_time:157ms step_avg:78.59ms +step:3/1680 train_time:221ms step_avg:73.73ms +step:4/1680 train_time:307ms step_avg:76.69ms +step:5/1680 train_time:393ms step_avg:78.53ms +step:6/1680 train_time:479ms step_avg:79.85ms +step:7/1680 train_time:565ms step_avg:80.77ms +step:8/1680 train_time:651ms step_avg:81.41ms +step:9/1680 train_time:738ms step_avg:81.96ms +step:10/1680 train_time:824ms step_avg:82.42ms +step:11/1680 train_time:911ms step_avg:82.82ms +step:12/1680 train_time:998ms step_avg:83.13ms +step:13/1680 train_time:1086ms step_avg:83.56ms +step:14/1680 train_time:1178ms step_avg:84.11ms +step:15/1680 train_time:1266ms step_avg:84.42ms +step:16/1680 train_time:1354ms step_avg:84.61ms +step:17/1680 train_time:1441ms step_avg:84.79ms +step:18/1680 train_time:1528ms step_avg:84.90ms +step:19/1680 train_time:1615ms step_avg:85.01ms +step:20/1680 train_time:1703ms step_avg:85.14ms +step:21/1680 train_time:1789ms step_avg:85.19ms +step:22/1680 train_time:1876ms step_avg:85.26ms +step:23/1680 train_time:1963ms step_avg:85.35ms +step:24/1680 train_time:2051ms step_avg:85.48ms +step:25/1680 train_time:2141ms step_avg:85.63ms +step:26/1680 train_time:2230ms step_avg:85.75ms +step:27/1680 train_time:2318ms step_avg:85.86ms +step:28/1680 train_time:2406ms step_avg:85.92ms +step:29/1680 train_time:2495ms step_avg:86.02ms +step:30/1680 train_time:2582ms step_avg:86.05ms +step:31/1680 train_time:2669ms step_avg:86.10ms +step:32/1680 train_time:2755ms step_avg:86.11ms +step:33/1680 train_time:2842ms step_avg:86.13ms +step:34/1680 train_time:2930ms step_avg:86.18ms +step:35/1680 train_time:3018ms step_avg:86.22ms +step:36/1680 train_time:3105ms step_avg:86.26ms +step:37/1680 train_time:3194ms step_avg:86.32ms +step:38/1680 train_time:3282ms step_avg:86.36ms +step:39/1680 train_time:3370ms step_avg:86.40ms +step:40/1680 train_time:3457ms step_avg:86.42ms +step:41/1680 train_time:3544ms step_avg:86.43ms +step:42/1680 train_time:3631ms step_avg:86.46ms +step:43/1680 train_time:3718ms step_avg:86.47ms +step:44/1680 train_time:3805ms step_avg:86.47ms +step:45/1680 train_time:3891ms step_avg:86.47ms +step:46/1680 train_time:3978ms step_avg:86.48ms +step:47/1680 
train_time:4066ms step_avg:86.51ms +step:48/1680 train_time:4154ms step_avg:86.54ms +step:49/1680 train_time:4241ms step_avg:86.56ms +step:50/1680 train_time:4329ms step_avg:86.58ms +step:51/1680 train_time:4417ms step_avg:86.61ms +step:52/1680 train_time:4504ms step_avg:86.62ms +step:53/1680 train_time:4591ms step_avg:86.63ms +step:54/1680 train_time:4678ms step_avg:86.64ms +step:55/1680 train_time:4765ms step_avg:86.64ms +step:56/1680 train_time:4852ms step_avg:86.64ms +step:57/1680 train_time:4939ms step_avg:86.64ms +step:58/1680 train_time:5026ms step_avg:86.65ms +step:59/1680 train_time:5114ms step_avg:86.68ms +step:60/1680 train_time:5201ms step_avg:86.68ms +step:61/1680 train_time:5288ms step_avg:86.69ms +step:62/1680 train_time:5376ms step_avg:86.71ms +step:63/1680 train_time:5463ms step_avg:86.72ms +step:64/1680 train_time:5550ms step_avg:86.73ms +step:65/1680 train_time:5638ms step_avg:86.73ms +step:66/1680 train_time:5725ms step_avg:86.74ms +step:67/1680 train_time:5812ms step_avg:86.75ms +step:68/1680 train_time:5899ms step_avg:86.75ms +step:69/1680 train_time:5986ms step_avg:86.75ms +step:70/1680 train_time:6073ms step_avg:86.75ms +step:71/1680 train_time:6160ms step_avg:86.75ms +step:72/1680 train_time:6247ms step_avg:86.76ms +step:73/1680 train_time:6335ms step_avg:86.79ms +step:74/1680 train_time:6422ms step_avg:86.79ms +step:75/1680 train_time:6510ms step_avg:86.81ms +step:76/1680 train_time:6597ms step_avg:86.81ms +step:77/1680 train_time:6685ms step_avg:86.81ms +step:78/1680 train_time:6772ms step_avg:86.82ms +step:79/1680 train_time:6859ms step_avg:86.82ms +step:80/1680 train_time:6946ms step_avg:86.83ms +step:81/1680 train_time:7033ms step_avg:86.83ms +step:82/1680 train_time:7120ms step_avg:86.83ms +step:83/1680 train_time:7208ms step_avg:86.84ms +step:84/1680 train_time:7295ms step_avg:86.85ms +step:85/1680 train_time:7383ms step_avg:86.86ms +step:86/1680 train_time:7470ms step_avg:86.86ms +step:87/1680 train_time:7558ms step_avg:86.87ms +step:88/1680 train_time:7645ms step_avg:86.88ms +step:89/1680 train_time:7732ms step_avg:86.88ms +step:90/1680 train_time:7819ms step_avg:86.88ms +step:91/1680 train_time:7907ms step_avg:86.89ms +step:92/1680 train_time:7993ms step_avg:86.88ms +step:93/1680 train_time:8080ms step_avg:86.89ms +step:94/1680 train_time:8168ms step_avg:86.89ms +step:95/1680 train_time:8256ms step_avg:86.90ms +step:96/1680 train_time:8343ms step_avg:86.90ms +step:97/1680 train_time:8430ms step_avg:86.91ms +step:98/1680 train_time:8518ms step_avg:86.92ms +step:99/1680 train_time:8605ms step_avg:86.92ms +step:100/1680 train_time:8693ms step_avg:86.93ms +step:101/1680 train_time:8780ms step_avg:86.93ms +step:102/1680 train_time:8867ms step_avg:86.93ms +step:103/1680 train_time:8955ms step_avg:86.94ms +step:104/1680 train_time:9042ms step_avg:86.95ms +step:105/1680 train_time:9129ms step_avg:86.95ms +step:106/1680 train_time:9217ms step_avg:86.95ms +step:107/1680 train_time:9304ms step_avg:86.95ms +step:108/1680 train_time:9391ms step_avg:86.96ms +step:109/1680 train_time:9478ms step_avg:86.96ms +step:110/1680 train_time:9566ms step_avg:86.96ms +step:111/1680 train_time:9654ms step_avg:86.97ms +step:112/1680 train_time:9741ms step_avg:86.97ms +step:113/1680 train_time:9828ms step_avg:86.97ms +step:114/1680 train_time:9916ms step_avg:86.99ms +step:115/1680 train_time:10004ms step_avg:86.99ms +step:116/1680 train_time:10091ms step_avg:86.99ms +step:117/1680 train_time:10177ms step_avg:86.99ms +step:118/1680 train_time:10265ms step_avg:86.99ms +step:119/1680 
train_time:10352ms step_avg:86.99ms
+step:120/1680 train_time:10439ms step_avg:86.99ms
+step:121/1680 train_time:10526ms step_avg:86.99ms
+step:122/1680 train_time:10615ms step_avg:87.01ms
+step:123/1680 train_time:10702ms step_avg:87.01ms
+step:124/1680 train_time:10789ms step_avg:87.01ms
+step:125/1680 train_time:10876ms step_avg:87.01ms
+step:125/1680 val_loss:4.3028 train_time:10965ms step_avg:87.72ms
+step:126/1680 train_time:10985ms step_avg:87.19ms
+step:127/1680 train_time:11055ms step_avg:87.05ms
+step:128/1680 train_time:11152ms step_avg:87.12ms
+step:129/1680 train_time:11242ms step_avg:87.15ms
+step:130/1680 train_time:11329ms step_avg:87.15ms
+step:131/1680 train_time:11415ms step_avg:87.14ms
+step:132/1680 train_time:11502ms step_avg:87.14ms
+step:133/1680 train_time:11589ms step_avg:87.13ms
+step:134/1680 train_time:11674ms step_avg:87.12ms
+step:135/1680 train_time:11760ms step_avg:87.11ms
+step:136/1680 train_time:11846ms step_avg:87.10ms
+step:137/1680 train_time:11933ms step_avg:87.10ms
+step:138/1680 train_time:12021ms step_avg:87.11ms
+step:139/1680 train_time:12110ms step_avg:87.12ms
+step:140/1680 train_time:12199ms step_avg:87.14ms
+step:141/1680 train_time:12288ms step_avg:87.15ms
+step:142/1680 train_time:12376ms step_avg:87.15ms
+step:143/1680 train_time:12462ms step_avg:87.15ms
+step:144/1680 train_time:12549ms step_avg:87.14ms
+step:145/1680 train_time:12635ms step_avg:87.14ms
+step:146/1680 train_time:12721ms step_avg:87.13ms
+step:147/1680 train_time:12807ms step_avg:87.12ms
+step:148/1680 train_time:12893ms step_avg:87.12ms
+step:149/1680 train_time:12980ms step_avg:87.12ms
+step:150/1680 train_time:13068ms step_avg:87.12ms
+step:151/1680 train_time:13156ms step_avg:87.13ms
+step:152/1680 train_time:13245ms step_avg:87.14ms
+step:153/1680 train_time:13333ms step_avg:87.14ms
+step:154/1680 train_time:13420ms step_avg:87.15ms
+step:155/1680 train_time:13508ms step_avg:87.15ms
+step:156/1680 train_time:13595ms step_avg:87.15ms
+step:157/1680 train_time:13682ms step_avg:87.14ms
+step:158/1680 train_time:13768ms step_avg:87.14ms
+step:159/1680 train_time:13854ms step_avg:87.13ms
+step:160/1680 train_time:13942ms step_avg:87.14ms
+step:161/1680 train_time:14029ms step_avg:87.13ms
+step:162/1680 train_time:14116ms step_avg:87.14ms
+step:163/1680 train_time:14205ms step_avg:87.15ms
+step:164/1680 train_time:14293ms step_avg:87.15ms
+step:165/1680 train_time:14380ms step_avg:87.15ms
+step:166/1680 train_time:14468ms step_avg:87.16ms
+step:167/1680 train_time:14555ms step_avg:87.15ms
+step:168/1680 train_time:14641ms step_avg:87.15ms
+step:169/1680 train_time:14729ms step_avg:87.15ms
+step:170/1680 train_time:14815ms step_avg:87.15ms
+step:171/1680 train_time:14901ms step_avg:87.14ms
+step:172/1680 train_time:14988ms step_avg:87.14ms
+step:173/1680 train_time:15075ms step_avg:87.14ms
+step:174/1680 train_time:15163ms step_avg:87.14ms
+step:175/1680 train_time:15251ms step_avg:87.15ms
+step:176/1680 train_time:15338ms step_avg:87.15ms
+step:177/1680 train_time:15426ms step_avg:87.15ms
+step:178/1680 train_time:15513ms step_avg:87.15ms
+step:179/1680 train_time:15600ms step_avg:87.15ms
+step:180/1680 train_time:15687ms step_avg:87.15ms
+step:181/1680 train_time:15775ms step_avg:87.15ms
+step:182/1680 train_time:15862ms step_avg:87.15ms
+step:183/1680 train_time:15949ms step_avg:87.15ms
+step:184/1680 train_time:16036ms step_avg:87.15ms
+step:185/1680 train_time:16123ms step_avg:87.15ms
+step:186/1680 train_time:16211ms step_avg:87.15ms
+step:187/1680 train_time:16298ms step_avg:87.15ms
+step:188/1680 train_time:16385ms step_avg:87.15ms
+step:189/1680 train_time:16473ms step_avg:87.16ms
+step:190/1680 train_time:16560ms step_avg:87.16ms
+step:191/1680 train_time:16647ms step_avg:87.16ms
+step:192/1680 train_time:16734ms step_avg:87.16ms
+step:193/1680 train_time:16822ms step_avg:87.16ms
+step:194/1680 train_time:16909ms step_avg:87.16ms
+step:195/1680 train_time:16995ms step_avg:87.16ms
+step:196/1680 train_time:17083ms step_avg:87.16ms
+step:197/1680 train_time:17171ms step_avg:87.16ms
+step:198/1680 train_time:17258ms step_avg:87.16ms
+step:199/1680 train_time:17346ms step_avg:87.16ms
+step:200/1680 train_time:17433ms step_avg:87.17ms
+step:201/1680 train_time:17520ms step_avg:87.17ms
+step:202/1680 train_time:17608ms step_avg:87.17ms
+step:203/1680 train_time:17695ms step_avg:87.17ms
+step:204/1680 train_time:17782ms step_avg:87.16ms
+step:205/1680 train_time:17868ms step_avg:87.16ms
+step:206/1680 train_time:17955ms step_avg:87.16ms
+step:207/1680 train_time:18042ms step_avg:87.16ms
+step:208/1680 train_time:18128ms step_avg:87.16ms
+step:209/1680 train_time:18215ms step_avg:87.15ms
+step:210/1680 train_time:18302ms step_avg:87.15ms
+step:211/1680 train_time:18389ms step_avg:87.15ms
+step:212/1680 train_time:18476ms step_avg:87.15ms
+step:213/1680 train_time:18565ms step_avg:87.16ms
+step:214/1680 train_time:18652ms step_avg:87.16ms
+step:215/1680 train_time:18739ms step_avg:87.16ms
+step:216/1680 train_time:18826ms step_avg:87.16ms
+step:217/1680 train_time:18914ms step_avg:87.16ms
+step:218/1680 train_time:19001ms step_avg:87.16ms
+step:219/1680 train_time:19089ms step_avg:87.16ms
+step:220/1680 train_time:19175ms step_avg:87.16ms
+step:221/1680 train_time:19263ms step_avg:87.16ms
+step:222/1680 train_time:19349ms step_avg:87.16ms
+step:223/1680 train_time:19437ms step_avg:87.16ms
+step:224/1680 train_time:19524ms step_avg:87.16ms
+step:225/1680 train_time:19611ms step_avg:87.16ms
+step:226/1680 train_time:19699ms step_avg:87.16ms
+step:227/1680 train_time:19786ms step_avg:87.16ms
+step:228/1680 train_time:19873ms step_avg:87.16ms
+step:229/1680 train_time:19960ms step_avg:87.16ms
+step:230/1680 train_time:20047ms step_avg:87.16ms
+step:231/1680 train_time:20134ms step_avg:87.16ms
+step:232/1680 train_time:20221ms step_avg:87.16ms
+step:233/1680 train_time:20308ms step_avg:87.16ms
+step:234/1680 train_time:20395ms step_avg:87.16ms
+step:235/1680 train_time:20482ms step_avg:87.16ms
+step:236/1680 train_time:20570ms step_avg:87.16ms
+step:237/1680 train_time:20657ms step_avg:87.16ms
+step:238/1680 train_time:20744ms step_avg:87.16ms
+step:239/1680 train_time:20831ms step_avg:87.16ms
+step:240/1680 train_time:20918ms step_avg:87.16ms
+step:241/1680 train_time:21005ms step_avg:87.16ms
+step:242/1680 train_time:21093ms step_avg:87.16ms
+step:243/1680 train_time:21180ms step_avg:87.16ms
+step:244/1680 train_time:21267ms step_avg:87.16ms
+step:245/1680 train_time:21354ms step_avg:87.16ms
+step:246/1680 train_time:21441ms step_avg:87.16ms
+step:247/1680 train_time:21528ms step_avg:87.16ms
+step:248/1680 train_time:21615ms step_avg:87.16ms
+step:249/1680 train_time:21701ms step_avg:87.15ms
+step:250/1680 train_time:21789ms step_avg:87.16ms
+step:250/1680 val_loss:3.9662 train_time:21877ms step_avg:87.51ms
+step:251/1680 train_time:21896ms step_avg:87.23ms
+step:252/1680 train_time:21967ms step_avg:87.17ms
+step:253/1680 train_time:22058ms step_avg:87.18ms
+step:254/1680 train_time:22146ms step_avg:87.19ms
+step:255/1680 train_time:22232ms step_avg:87.18ms
+step:256/1680 train_time:22319ms step_avg:87.18ms
+step:257/1680 train_time:22405ms step_avg:87.18ms
+step:258/1680 train_time:22491ms step_avg:87.17ms
+step:259/1680 train_time:22577ms step_avg:87.17ms
+step:260/1680 train_time:22663ms step_avg:87.17ms
+step:261/1680 train_time:22750ms step_avg:87.16ms
+step:262/1680 train_time:22838ms step_avg:87.17ms
+step:263/1680 train_time:22926ms step_avg:87.17ms
+step:264/1680 train_time:23015ms step_avg:87.18ms
+step:265/1680 train_time:23103ms step_avg:87.18ms
+step:266/1680 train_time:23191ms step_avg:87.19ms
+step:267/1680 train_time:23278ms step_avg:87.18ms
+step:268/1680 train_time:23365ms step_avg:87.18ms
+step:269/1680 train_time:23452ms step_avg:87.18ms
+step:270/1680 train_time:23538ms step_avg:87.18ms
+step:271/1680 train_time:23624ms step_avg:87.17ms
+step:272/1680 train_time:23711ms step_avg:87.17ms
+step:273/1680 train_time:23798ms step_avg:87.17ms
+step:274/1680 train_time:23885ms step_avg:87.17ms
+step:275/1680 train_time:23973ms step_avg:87.18ms
+step:276/1680 train_time:24061ms step_avg:87.18ms
+step:277/1680 train_time:24149ms step_avg:87.18ms
+step:278/1680 train_time:24237ms step_avg:87.18ms
+step:279/1680 train_time:24324ms step_avg:87.18ms
+step:280/1680 train_time:24411ms step_avg:87.18ms
+step:281/1680 train_time:24497ms step_avg:87.18ms
+step:282/1680 train_time:24584ms step_avg:87.18ms
+step:283/1680 train_time:24671ms step_avg:87.18ms
+step:284/1680 train_time:24757ms step_avg:87.17ms
+step:285/1680 train_time:24844ms step_avg:87.17ms
+step:286/1680 train_time:24932ms step_avg:87.17ms
+step:287/1680 train_time:25019ms step_avg:87.18ms
+step:288/1680 train_time:25107ms step_avg:87.18ms
+step:289/1680 train_time:25196ms step_avg:87.18ms
+step:290/1680 train_time:25283ms step_avg:87.18ms
+step:291/1680 train_time:25370ms step_avg:87.18ms
+step:292/1680 train_time:25457ms step_avg:87.18ms
+step:293/1680 train_time:25544ms step_avg:87.18ms
+step:294/1680 train_time:25631ms step_avg:87.18ms
+step:295/1680 train_time:25717ms step_avg:87.18ms
+step:296/1680 train_time:25804ms step_avg:87.18ms
+step:297/1680 train_time:25892ms step_avg:87.18ms
+step:298/1680 train_time:25979ms step_avg:87.18ms
+step:299/1680 train_time:26066ms step_avg:87.18ms
+step:300/1680 train_time:26154ms step_avg:87.18ms
+step:301/1680 train_time:26242ms step_avg:87.18ms
+step:302/1680 train_time:26329ms step_avg:87.18ms
+step:303/1680 train_time:26416ms step_avg:87.18ms
+step:304/1680 train_time:26504ms step_avg:87.18ms
+step:305/1680 train_time:26591ms step_avg:87.18ms
+step:306/1680 train_time:26677ms step_avg:87.18ms
+step:307/1680 train_time:26764ms step_avg:87.18ms
+step:308/1680 train_time:26852ms step_avg:87.18ms
+step:309/1680 train_time:26939ms step_avg:87.18ms
+step:310/1680 train_time:27026ms step_avg:87.18ms
+step:311/1680 train_time:27113ms step_avg:87.18ms
+step:312/1680 train_time:27200ms step_avg:87.18ms
+step:313/1680 train_time:27287ms step_avg:87.18ms
+step:314/1680 train_time:27374ms step_avg:87.18ms
+step:315/1680 train_time:27461ms step_avg:87.18ms
+step:316/1680 train_time:27549ms step_avg:87.18ms
+step:317/1680 train_time:27636ms step_avg:87.18ms
+step:318/1680 train_time:27723ms step_avg:87.18ms
+step:319/1680 train_time:27810ms step_avg:87.18ms
+step:320/1680 train_time:27897ms step_avg:87.18ms
+step:321/1680 train_time:27984ms step_avg:87.18ms
+step:322/1680 train_time:28071ms step_avg:87.18ms
+step:323/1680 train_time:28158ms step_avg:87.18ms
+step:324/1680 train_time:28246ms step_avg:87.18ms
+step:325/1680 train_time:28333ms step_avg:87.18ms
+step:326/1680 train_time:28420ms step_avg:87.18ms
+step:327/1680 train_time:28507ms step_avg:87.18ms
+step:328/1680 train_time:28595ms step_avg:87.18ms
+step:329/1680 train_time:28682ms step_avg:87.18ms
+step:330/1680 train_time:28769ms step_avg:87.18ms
+step:331/1680 train_time:28857ms step_avg:87.18ms
+step:332/1680 train_time:28944ms step_avg:87.18ms
+step:333/1680 train_time:29031ms step_avg:87.18ms
+step:334/1680 train_time:29118ms step_avg:87.18ms
+step:335/1680 train_time:29205ms step_avg:87.18ms
+step:336/1680 train_time:29293ms step_avg:87.18ms
+step:337/1680 train_time:29380ms step_avg:87.18ms
+step:338/1680 train_time:29467ms step_avg:87.18ms
+step:339/1680 train_time:29555ms step_avg:87.18ms
+step:340/1680 train_time:29641ms step_avg:87.18ms
+step:341/1680 train_time:29728ms step_avg:87.18ms
+step:342/1680 train_time:29815ms step_avg:87.18ms
+step:343/1680 train_time:29902ms step_avg:87.18ms
+step:344/1680 train_time:29989ms step_avg:87.18ms
+step:345/1680 train_time:30077ms step_avg:87.18ms
+step:346/1680 train_time:30164ms step_avg:87.18ms
+step:347/1680 train_time:30252ms step_avg:87.18ms
+step:348/1680 train_time:30340ms step_avg:87.18ms
+step:349/1680 train_time:30426ms step_avg:87.18ms
+step:350/1680 train_time:30514ms step_avg:87.18ms
+step:351/1680 train_time:30601ms step_avg:87.18ms
+step:352/1680 train_time:30688ms step_avg:87.18ms
+step:353/1680 train_time:30775ms step_avg:87.18ms
+step:354/1680 train_time:30862ms step_avg:87.18ms
+step:355/1680 train_time:30949ms step_avg:87.18ms
+step:356/1680 train_time:31037ms step_avg:87.18ms
+step:357/1680 train_time:31124ms step_avg:87.18ms
+step:358/1680 train_time:31211ms step_avg:87.18ms
+step:359/1680 train_time:31298ms step_avg:87.18ms
+step:360/1680 train_time:31385ms step_avg:87.18ms
+step:361/1680 train_time:31472ms step_avg:87.18ms
+step:362/1680 train_time:31559ms step_avg:87.18ms
+step:363/1680 train_time:31646ms step_avg:87.18ms
+step:364/1680 train_time:31734ms step_avg:87.18ms
+step:365/1680 train_time:31821ms step_avg:87.18ms
+step:366/1680 train_time:31908ms step_avg:87.18ms
+step:367/1680 train_time:31995ms step_avg:87.18ms
+step:368/1680 train_time:32082ms step_avg:87.18ms
+step:369/1680 train_time:32169ms step_avg:87.18ms
+step:370/1680 train_time:32256ms step_avg:87.18ms
+step:371/1680 train_time:32343ms step_avg:87.18ms
+step:372/1680 train_time:32431ms step_avg:87.18ms
+step:373/1680 train_time:32518ms step_avg:87.18ms
+step:374/1680 train_time:32605ms step_avg:87.18ms
+step:375/1680 train_time:32693ms step_avg:87.18ms
+step:375/1680 val_loss:3.8138 train_time:32781ms step_avg:87.42ms
+step:376/1680 train_time:32802ms step_avg:87.24ms
+step:377/1680 train_time:32870ms step_avg:87.19ms
+step:378/1680 train_time:32961ms step_avg:87.20ms
+step:379/1680 train_time:33049ms step_avg:87.20ms
+step:380/1680 train_time:33137ms step_avg:87.20ms
+step:381/1680 train_time:33223ms step_avg:87.20ms
+step:382/1680 train_time:33309ms step_avg:87.20ms
+step:383/1680 train_time:33396ms step_avg:87.20ms
+step:384/1680 train_time:33483ms step_avg:87.20ms
+step:385/1680 train_time:33570ms step_avg:87.20ms
+step:386/1680 train_time:33657ms step_avg:87.19ms
+step:387/1680 train_time:33744ms step_avg:87.19ms
+step:388/1680 train_time:33833ms step_avg:87.20ms
+step:389/1680 train_time:33921ms step_avg:87.20ms
+step:390/1680 train_time:34009ms step_avg:87.20ms
+step:391/1680 train_time:34096ms step_avg:87.20ms
+step:392/1680 train_time:34184ms step_avg:87.20ms
+step:393/1680 train_time:34270ms step_avg:87.20ms
+step:394/1680 train_time:34357ms step_avg:87.20ms
+step:395/1680 train_time:34444ms step_avg:87.20ms
+step:396/1680 train_time:34531ms step_avg:87.20ms
+step:397/1680 train_time:34617ms step_avg:87.20ms
+step:398/1680 train_time:34705ms step_avg:87.20ms
+step:399/1680 train_time:34792ms step_avg:87.20ms
+step:400/1680 train_time:34879ms step_avg:87.20ms
+step:401/1680 train_time:34967ms step_avg:87.20ms
+step:402/1680 train_time:35055ms step_avg:87.20ms
+step:403/1680 train_time:35143ms step_avg:87.20ms
+step:404/1680 train_time:35230ms step_avg:87.20ms
+step:405/1680 train_time:35317ms step_avg:87.20ms
+step:406/1680 train_time:35404ms step_avg:87.20ms
+step:407/1680 train_time:35490ms step_avg:87.20ms
+step:408/1680 train_time:35576ms step_avg:87.20ms
+step:409/1680 train_time:35663ms step_avg:87.20ms
+step:410/1680 train_time:35750ms step_avg:87.20ms
+step:411/1680 train_time:35838ms step_avg:87.20ms
+step:412/1680 train_time:35925ms step_avg:87.20ms
+step:413/1680 train_time:36012ms step_avg:87.20ms
+step:414/1680 train_time:36101ms step_avg:87.20ms
+step:415/1680 train_time:36188ms step_avg:87.20ms
+step:416/1680 train_time:36275ms step_avg:87.20ms
+step:417/1680 train_time:36362ms step_avg:87.20ms
+step:418/1680 train_time:36448ms step_avg:87.20ms
+step:419/1680 train_time:36535ms step_avg:87.20ms
+step:420/1680 train_time:36622ms step_avg:87.20ms
+step:421/1680 train_time:36709ms step_avg:87.19ms
+step:422/1680 train_time:36796ms step_avg:87.19ms
+step:423/1680 train_time:36883ms step_avg:87.19ms
+step:424/1680 train_time:36971ms step_avg:87.19ms
+step:425/1680 train_time:37059ms step_avg:87.20ms
+step:426/1680 train_time:37147ms step_avg:87.20ms
+step:427/1680 train_time:37234ms step_avg:87.20ms
+step:428/1680 train_time:37321ms step_avg:87.20ms
+step:429/1680 train_time:37408ms step_avg:87.20ms
+step:430/1680 train_time:37495ms step_avg:87.20ms
+step:431/1680 train_time:37582ms step_avg:87.20ms
+step:432/1680 train_time:37669ms step_avg:87.20ms
+step:433/1680 train_time:37755ms step_avg:87.19ms
+step:434/1680 train_time:37843ms step_avg:87.20ms
+step:435/1680 train_time:37930ms step_avg:87.20ms
+step:436/1680 train_time:38018ms step_avg:87.20ms
+step:437/1680 train_time:38105ms step_avg:87.20ms
+step:438/1680 train_time:38193ms step_avg:87.20ms
+step:439/1680 train_time:38280ms step_avg:87.20ms
+step:440/1680 train_time:38367ms step_avg:87.20ms
+step:441/1680 train_time:38454ms step_avg:87.20ms
+step:442/1680 train_time:38541ms step_avg:87.20ms
+step:443/1680 train_time:38629ms step_avg:87.20ms
+step:444/1680 train_time:38716ms step_avg:87.20ms
+step:445/1680 train_time:38804ms step_avg:87.20ms
+step:446/1680 train_time:38891ms step_avg:87.20ms
+step:447/1680 train_time:38979ms step_avg:87.20ms
+step:448/1680 train_time:39066ms step_avg:87.20ms
+step:449/1680 train_time:39154ms step_avg:87.20ms
+step:450/1680 train_time:39241ms step_avg:87.20ms
+step:451/1680 train_time:39328ms step_avg:87.20ms
+step:452/1680 train_time:39415ms step_avg:87.20ms
+step:453/1680 train_time:39502ms step_avg:87.20ms
+step:454/1680 train_time:39589ms step_avg:87.20ms
+step:455/1680 train_time:39676ms step_avg:87.20ms
+step:456/1680 train_time:39764ms step_avg:87.20ms
+step:457/1680 train_time:39851ms step_avg:87.20ms
+step:458/1680 train_time:39938ms step_avg:87.20ms
+step:459/1680 train_time:40025ms step_avg:87.20ms
+step:460/1680 train_time:40112ms step_avg:87.20ms
+step:461/1680 train_time:40201ms step_avg:87.20ms
+step:462/1680 train_time:40288ms step_avg:87.20ms
+step:463/1680 train_time:40375ms step_avg:87.20ms
+step:464/1680 train_time:40462ms step_avg:87.20ms
+step:465/1680 train_time:40549ms step_avg:87.20ms
+step:466/1680 train_time:40637ms step_avg:87.20ms
+step:467/1680 train_time:40724ms step_avg:87.20ms
+step:468/1680 train_time:40811ms step_avg:87.20ms
+step:469/1680 train_time:40899ms step_avg:87.21ms
+step:470/1680 train_time:40986ms step_avg:87.20ms
+step:471/1680 train_time:41073ms step_avg:87.20ms
+step:472/1680 train_time:41162ms step_avg:87.21ms
+step:473/1680 train_time:41248ms step_avg:87.21ms
+step:474/1680 train_time:41336ms step_avg:87.21ms
+step:475/1680 train_time:41422ms step_avg:87.20ms
+step:476/1680 train_time:41509ms step_avg:87.20ms
+step:477/1680 train_time:41596ms step_avg:87.20ms
+step:478/1680 train_time:41683ms step_avg:87.20ms
+step:479/1680 train_time:41770ms step_avg:87.20ms
+step:480/1680 train_time:41858ms step_avg:87.20ms
+step:481/1680 train_time:41945ms step_avg:87.20ms
+step:482/1680 train_time:42032ms step_avg:87.20ms
+step:483/1680 train_time:42120ms step_avg:87.21ms
+step:484/1680 train_time:42207ms step_avg:87.20ms
+step:485/1680 train_time:42295ms step_avg:87.21ms
+step:486/1680 train_time:42382ms step_avg:87.21ms
+step:487/1680 train_time:42469ms step_avg:87.21ms
+step:488/1680 train_time:42557ms step_avg:87.21ms
+step:489/1680 train_time:42643ms step_avg:87.20ms
+step:490/1680 train_time:42730ms step_avg:87.20ms
+step:491/1680 train_time:42817ms step_avg:87.20ms
+step:492/1680 train_time:42904ms step_avg:87.20ms
+step:493/1680 train_time:42991ms step_avg:87.20ms
+step:494/1680 train_time:43079ms step_avg:87.20ms
+step:495/1680 train_time:43165ms step_avg:87.20ms
+step:496/1680 train_time:43253ms step_avg:87.20ms
+step:497/1680 train_time:43341ms step_avg:87.20ms
+step:498/1680 train_time:43428ms step_avg:87.20ms
+step:499/1680 train_time:43515ms step_avg:87.21ms
+step:500/1680 train_time:43602ms step_avg:87.20ms
+step:500/1680 val_loss:3.7153 train_time:43691ms step_avg:87.38ms
+step:501/1680 train_time:43710ms step_avg:87.24ms
+step:502/1680 train_time:43781ms step_avg:87.21ms
+step:503/1680 train_time:43873ms step_avg:87.22ms
+step:504/1680 train_time:43960ms step_avg:87.22ms
+step:505/1680 train_time:44047ms step_avg:87.22ms
+step:506/1680 train_time:44133ms step_avg:87.22ms
+step:507/1680 train_time:44219ms step_avg:87.22ms
+step:508/1680 train_time:44305ms step_avg:87.22ms
+step:509/1680 train_time:44391ms step_avg:87.21ms
+step:510/1680 train_time:44478ms step_avg:87.21ms
+step:511/1680 train_time:44564ms step_avg:87.21ms
+step:512/1680 train_time:44652ms step_avg:87.21ms
+step:513/1680 train_time:44740ms step_avg:87.21ms
+step:514/1680 train_time:44829ms step_avg:87.22ms
+step:515/1680 train_time:44918ms step_avg:87.22ms
+step:516/1680 train_time:45006ms step_avg:87.22ms
+step:517/1680 train_time:45093ms step_avg:87.22ms
+step:518/1680 train_time:45181ms step_avg:87.22ms
+step:519/1680 train_time:45267ms step_avg:87.22ms
+step:520/1680 train_time:45354ms step_avg:87.22ms
+step:521/1680 train_time:45440ms step_avg:87.22ms
+step:522/1680 train_time:45526ms step_avg:87.21ms
+step:523/1680 train_time:45612ms step_avg:87.21ms
+step:524/1680 train_time:45701ms step_avg:87.22ms
+step:525/1680 train_time:45789ms step_avg:87.22ms
+step:526/1680 train_time:45878ms step_avg:87.22ms
+step:527/1680 train_time:45965ms step_avg:87.22ms
+step:528/1680 train_time:46053ms step_avg:87.22ms
+step:529/1680 train_time:46140ms step_avg:87.22ms
+step:530/1680 train_time:46226ms step_avg:87.22ms
+step:531/1680 train_time:46314ms step_avg:87.22ms
+step:532/1680 train_time:46401ms step_avg:87.22ms
+step:533/1680 train_time:46487ms step_avg:87.22ms
+step:534/1680 train_time:46574ms step_avg:87.22ms
+step:535/1680 train_time:46662ms step_avg:87.22ms
+step:536/1680 train_time:46750ms step_avg:87.22ms
+step:537/1680 train_time:46839ms step_avg:87.22ms
+step:538/1680 train_time:46927ms step_avg:87.22ms
+step:539/1680 train_time:47014ms step_avg:87.22ms
+step:540/1680 train_time:47102ms step_avg:87.23ms
+step:541/1680 train_time:47190ms step_avg:87.23ms
+step:542/1680 train_time:47278ms step_avg:87.23ms
+step:543/1680 train_time:47364ms step_avg:87.23ms
+step:544/1680 train_time:47451ms step_avg:87.23ms
+step:545/1680 train_time:47537ms step_avg:87.22ms
+step:546/1680 train_time:47624ms step_avg:87.22ms
+step:547/1680 train_time:47712ms step_avg:87.22ms
+step:548/1680 train_time:47801ms step_avg:87.23ms
+step:549/1680 train_time:47889ms step_avg:87.23ms
+step:550/1680 train_time:47978ms step_avg:87.23ms
+step:551/1680 train_time:48066ms step_avg:87.23ms
+step:552/1680 train_time:48155ms step_avg:87.24ms
+step:553/1680 train_time:48244ms step_avg:87.24ms
+step:554/1680 train_time:48332ms step_avg:87.24ms
+step:555/1680 train_time:48420ms step_avg:87.24ms
+step:556/1680 train_time:48508ms step_avg:87.24ms
+step:557/1680 train_time:48597ms step_avg:87.25ms
+step:558/1680 train_time:48685ms step_avg:87.25ms
+step:559/1680 train_time:48774ms step_avg:87.25ms
+step:560/1680 train_time:48863ms step_avg:87.26ms
+step:561/1680 train_time:48952ms step_avg:87.26ms
+step:562/1680 train_time:49041ms step_avg:87.26ms
+step:563/1680 train_time:49129ms step_avg:87.26ms
+step:564/1680 train_time:49216ms step_avg:87.26ms
+step:565/1680 train_time:49304ms step_avg:87.26ms
+step:566/1680 train_time:49393ms step_avg:87.27ms
+step:567/1680 train_time:49481ms step_avg:87.27ms
+step:568/1680 train_time:49570ms step_avg:87.27ms
+step:569/1680 train_time:49658ms step_avg:87.27ms
+step:570/1680 train_time:49746ms step_avg:87.27ms
+step:571/1680 train_time:49835ms step_avg:87.28ms
+step:572/1680 train_time:49924ms step_avg:87.28ms
+step:573/1680 train_time:50012ms step_avg:87.28ms
+step:574/1680 train_time:50101ms step_avg:87.28ms
+step:575/1680 train_time:50190ms step_avg:87.29ms
+step:576/1680 train_time:50278ms step_avg:87.29ms
+step:577/1680 train_time:50367ms step_avg:87.29ms
+step:578/1680 train_time:50455ms step_avg:87.29ms
+step:579/1680 train_time:50543ms step_avg:87.29ms
+step:580/1680 train_time:50632ms step_avg:87.30ms
+step:581/1680 train_time:50721ms step_avg:87.30ms
+step:582/1680 train_time:50809ms step_avg:87.30ms
+step:583/1680 train_time:50898ms step_avg:87.30ms
+step:584/1680 train_time:50986ms step_avg:87.30ms
+step:585/1680 train_time:51075ms step_avg:87.31ms
+step:586/1680 train_time:51163ms step_avg:87.31ms
+step:587/1680 train_time:51251ms step_avg:87.31ms
+step:588/1680 train_time:51340ms step_avg:87.31ms
+step:589/1680 train_time:51428ms step_avg:87.31ms
+step:590/1680 train_time:51516ms step_avg:87.32ms
+step:591/1680 train_time:51605ms step_avg:87.32ms
+step:592/1680 train_time:51694ms step_avg:87.32ms
+step:593/1680 train_time:51782ms step_avg:87.32ms
+step:594/1680 train_time:51871ms step_avg:87.32ms
+step:595/1680 train_time:51960ms step_avg:87.33ms
+step:596/1680 train_time:52048ms step_avg:87.33ms
+step:597/1680 train_time:52136ms step_avg:87.33ms
+step:598/1680 train_time:52225ms step_avg:87.33ms
+step:599/1680 train_time:52313ms step_avg:87.33ms
+step:600/1680 train_time:52401ms step_avg:87.34ms
+step:601/1680 train_time:52490ms step_avg:87.34ms
+step:602/1680 train_time:52579ms step_avg:87.34ms
+step:603/1680 train_time:52667ms step_avg:87.34ms
+step:604/1680 train_time:52754ms step_avg:87.34ms
+step:605/1680 train_time:52843ms step_avg:87.34ms
+step:606/1680 train_time:52931ms step_avg:87.35ms
+step:607/1680 train_time:53020ms step_avg:87.35ms
+step:608/1680 train_time:53108ms step_avg:87.35ms
+step:609/1680 train_time:53197ms step_avg:87.35ms
+step:610/1680 train_time:53285ms step_avg:87.35ms
+step:611/1680 train_time:53374ms step_avg:87.36ms
+step:612/1680 train_time:53463ms step_avg:87.36ms
+step:613/1680 train_time:53551ms step_avg:87.36ms
+step:614/1680 train_time:53639ms step_avg:87.36ms
+step:615/1680 train_time:53727ms step_avg:87.36ms
+step:616/1680 train_time:53815ms step_avg:87.36ms
+step:617/1680 train_time:53904ms step_avg:87.36ms
+step:618/1680 train_time:53993ms step_avg:87.37ms
+step:619/1680 train_time:54081ms step_avg:87.37ms
+step:620/1680 train_time:54169ms step_avg:87.37ms
+step:621/1680 train_time:54258ms step_avg:87.37ms
+step:622/1680 train_time:54346ms step_avg:87.37ms
+step:623/1680 train_time:54435ms step_avg:87.38ms
+step:624/1680 train_time:54523ms step_avg:87.38ms
+step:625/1680 train_time:54611ms step_avg:87.38ms
+step:625/1680 val_loss:3.6142 train_time:54702ms step_avg:87.52ms
+step:626/1680 train_time:54721ms step_avg:87.41ms
+step:627/1680 train_time:54792ms step_avg:87.39ms
+step:628/1680 train_time:54882ms step_avg:87.39ms
+step:629/1680 train_time:54972ms step_avg:87.40ms
+step:630/1680 train_time:55060ms step_avg:87.40ms
+step:631/1680 train_time:55147ms step_avg:87.40ms
+step:632/1680 train_time:55235ms step_avg:87.40ms
+step:633/1680 train_time:55321ms step_avg:87.40ms
+step:634/1680 train_time:55408ms step_avg:87.39ms
+step:635/1680 train_time:55496ms step_avg:87.40ms
+step:636/1680 train_time:55583ms step_avg:87.40ms
+step:637/1680 train_time:55676ms step_avg:87.40ms
+step:638/1680 train_time:55767ms step_avg:87.41ms
+step:639/1680 train_time:55856ms step_avg:87.41ms
+step:640/1680 train_time:55944ms step_avg:87.41ms
+step:641/1680 train_time:56034ms step_avg:87.42ms
+step:642/1680 train_time:56122ms step_avg:87.42ms
+step:643/1680 train_time:56209ms step_avg:87.42ms
+step:644/1680 train_time:56297ms step_avg:87.42ms
+step:645/1680 train_time:56384ms step_avg:87.42ms
+step:646/1680 train_time:56473ms step_avg:87.42ms
+step:647/1680 train_time:56560ms step_avg:87.42ms
+step:648/1680 train_time:56649ms step_avg:87.42ms
+step:649/1680 train_time:56737ms step_avg:87.42ms
+step:650/1680 train_time:56826ms step_avg:87.43ms
+step:651/1680 train_time:56915ms step_avg:87.43ms
+step:652/1680 train_time:57004ms step_avg:87.43ms
+step:653/1680 train_time:57093ms step_avg:87.43ms
+step:654/1680 train_time:57181ms step_avg:87.43ms
+step:655/1680 train_time:57268ms step_avg:87.43ms
+step:656/1680 train_time:57356ms step_avg:87.43ms
+step:657/1680 train_time:57444ms step_avg:87.43ms
+step:658/1680 train_time:57532ms step_avg:87.43ms
+step:659/1680 train_time:57621ms step_avg:87.44ms
+step:660/1680 train_time:57709ms step_avg:87.44ms
+step:661/1680 train_time:57798ms step_avg:87.44ms
+step:662/1680 train_time:57886ms step_avg:87.44ms
+step:663/1680 train_time:57975ms step_avg:87.44ms
+step:664/1680 train_time:58063ms step_avg:87.44ms
+step:665/1680 train_time:58151ms step_avg:87.45ms
+step:666/1680 train_time:58240ms step_avg:87.45ms
+step:667/1680 train_time:58328ms step_avg:87.45ms
+step:668/1680 train_time:58416ms step_avg:87.45ms
+step:669/1680 train_time:58504ms step_avg:87.45ms
+step:670/1680 train_time:58593ms step_avg:87.45ms
+step:671/1680 train_time:58681ms step_avg:87.45ms
+step:672/1680 train_time:58769ms step_avg:87.45ms
+step:673/1680 train_time:58858ms step_avg:87.46ms
+step:674/1680 train_time:58947ms step_avg:87.46ms
+step:675/1680 train_time:59035ms step_avg:87.46ms
+step:676/1680 train_time:59124ms step_avg:87.46ms
+step:677/1680 train_time:59212ms step_avg:87.46ms
+step:678/1680 train_time:59300ms step_avg:87.46ms
+step:679/1680 train_time:59388ms step_avg:87.46ms
+step:680/1680 train_time:59476ms step_avg:87.46ms
+step:681/1680 train_time:59564ms step_avg:87.47ms
+step:682/1680 train_time:59652ms step_avg:87.47ms
+step:683/1680 train_time:59740ms step_avg:87.47ms
+step:684/1680 train_time:59829ms step_avg:87.47ms
+step:685/1680 train_time:59918ms step_avg:87.47ms
+step:686/1680 train_time:60006ms step_avg:87.47ms
+step:687/1680 train_time:60094ms step_avg:87.47ms
+step:688/1680 train_time:60183ms step_avg:87.48ms
+step:689/1680 train_time:60272ms step_avg:87.48ms
+step:690/1680 train_time:60359ms step_avg:87.48ms
+step:691/1680 train_time:60448ms step_avg:87.48ms
+step:692/1680 train_time:60537ms step_avg:87.48ms
+step:693/1680 train_time:60625ms step_avg:87.48ms
+step:694/1680 train_time:60714ms step_avg:87.48ms
+step:695/1680 train_time:60802ms step_avg:87.49ms
+step:696/1680 train_time:60891ms step_avg:87.49ms
+step:697/1680 train_time:60980ms step_avg:87.49ms
+step:698/1680 train_time:61069ms step_avg:87.49ms
+step:699/1680 train_time:61157ms step_avg:87.49ms
+step:700/1680 train_time:61245ms step_avg:87.49ms
+step:701/1680 train_time:61334ms step_avg:87.49ms
+step:702/1680 train_time:61422ms step_avg:87.50ms
+step:703/1680 train_time:61510ms step_avg:87.50ms
+step:704/1680 train_time:61598ms step_avg:87.50ms
+step:705/1680 train_time:61686ms step_avg:87.50ms
+step:706/1680 train_time:61774ms step_avg:87.50ms
+step:707/1680 train_time:61862ms step_avg:87.50ms
+step:708/1680 train_time:61952ms step_avg:87.50ms
+step:709/1680 train_time:62041ms step_avg:87.51ms
+step:710/1680 train_time:62130ms step_avg:87.51ms
+step:711/1680 train_time:62218ms step_avg:87.51ms
+step:712/1680 train_time:62306ms step_avg:87.51ms
+step:713/1680 train_time:62394ms step_avg:87.51ms
+step:714/1680 train_time:62482ms step_avg:87.51ms
+step:715/1680 train_time:62570ms step_avg:87.51ms
+step:716/1680 train_time:62658ms step_avg:87.51ms
+step:717/1680 train_time:62746ms step_avg:87.51ms
+step:718/1680 train_time:62835ms step_avg:87.51ms
+step:719/1680 train_time:62924ms step_avg:87.52ms
+step:720/1680 train_time:63012ms step_avg:87.52ms
+step:721/1680 train_time:63101ms step_avg:87.52ms
+step:722/1680 train_time:63189ms step_avg:87.52ms
+step:723/1680 train_time:63278ms step_avg:87.52ms
+step:724/1680 train_time:63366ms step_avg:87.52ms
+step:725/1680 train_time:63454ms step_avg:87.52ms
+step:726/1680 train_time:63543ms step_avg:87.52ms
+step:727/1680 train_time:63631ms step_avg:87.53ms
+step:728/1680 train_time:63720ms step_avg:87.53ms
+step:729/1680 train_time:63809ms step_avg:87.53ms
+step:730/1680 train_time:63897ms step_avg:87.53ms
+step:731/1680 train_time:63985ms step_avg:87.53ms
+step:732/1680 train_time:64074ms step_avg:87.53ms
+step:733/1680 train_time:64162ms step_avg:87.53ms
+step:734/1680 train_time:64251ms step_avg:87.53ms
+step:735/1680 train_time:64339ms step_avg:87.54ms
+step:736/1680 train_time:64427ms step_avg:87.54ms
+step:737/1680 train_time:64516ms step_avg:87.54ms
+step:738/1680 train_time:64604ms step_avg:87.54ms
+step:739/1680 train_time:64692ms step_avg:87.54ms
+step:740/1680 train_time:64781ms step_avg:87.54ms
+step:741/1680 train_time:64870ms step_avg:87.54ms
+step:742/1680 train_time:64958ms step_avg:87.54ms
+step:743/1680 train_time:65047ms step_avg:87.55ms
+step:744/1680 train_time:65135ms step_avg:87.55ms
+step:745/1680 train_time:65224ms step_avg:87.55ms
+step:746/1680 train_time:65313ms step_avg:87.55ms
+step:747/1680 train_time:65401ms step_avg:87.55ms
+step:748/1680 train_time:65489ms step_avg:87.55ms
+step:749/1680 train_time:65578ms step_avg:87.55ms
+step:750/1680 train_time:65666ms step_avg:87.55ms
+step:750/1680 val_loss:3.5628 train_time:65756ms step_avg:87.67ms
+step:751/1680 train_time:65774ms step_avg:87.58ms
+step:752/1680 train_time:65847ms step_avg:87.56ms
+step:753/1680 train_time:65939ms step_avg:87.57ms
+step:754/1680 train_time:66029ms step_avg:87.57ms
+step:755/1680 train_time:66117ms step_avg:87.57ms
+step:756/1680 train_time:66204ms step_avg:87.57ms
+step:757/1680 train_time:66292ms step_avg:87.57ms
+step:758/1680 train_time:66379ms step_avg:87.57ms
+step:759/1680 train_time:66466ms step_avg:87.57ms
+step:760/1680 train_time:66554ms step_avg:87.57ms
+step:761/1680 train_time:66641ms step_avg:87.57ms
+step:762/1680 train_time:66730ms step_avg:87.57ms
+step:763/1680 train_time:66821ms step_avg:87.58ms
+step:764/1680 train_time:66911ms step_avg:87.58ms
+step:765/1680 train_time:67000ms step_avg:87.58ms
+step:766/1680 train_time:67089ms step_avg:87.58ms
+step:767/1680 train_time:67177ms step_avg:87.58ms
+step:768/1680 train_time:67266ms step_avg:87.59ms
+step:769/1680 train_time:67353ms step_avg:87.59ms
+step:770/1680 train_time:67441ms step_avg:87.59ms
+step:771/1680 train_time:67528ms step_avg:87.58ms
+step:772/1680 train_time:67616ms step_avg:87.59ms
+step:773/1680 train_time:67705ms step_avg:87.59ms
+step:774/1680 train_time:67794ms step_avg:87.59ms
+step:775/1680 train_time:67884ms step_avg:87.59ms
+step:776/1680 train_time:67973ms step_avg:87.59ms
+step:777/1680 train_time:68062ms step_avg:87.60ms
+step:778/1680 train_time:68152ms step_avg:87.60ms
+step:779/1680 train_time:68240ms step_avg:87.60ms
+step:780/1680 train_time:68328ms step_avg:87.60ms
+step:781/1680 train_time:68416ms step_avg:87.60ms
+step:782/1680 train_time:68504ms step_avg:87.60ms
+step:783/1680 train_time:68592ms step_avg:87.60ms
+step:784/1680 train_time:68681ms step_avg:87.60ms
+step:785/1680 train_time:68769ms step_avg:87.60ms
+step:786/1680 train_time:68858ms step_avg:87.61ms
+step:787/1680 train_time:68946ms step_avg:87.61ms
+step:788/1680 train_time:69035ms step_avg:87.61ms
+step:789/1680 train_time:69124ms step_avg:87.61ms
+step:790/1680 train_time:69213ms step_avg:87.61ms
+step:791/1680 train_time:69300ms step_avg:87.61ms
+step:792/1680 train_time:69388ms step_avg:87.61ms
+step:793/1680 train_time:69476ms step_avg:87.61ms
+step:794/1680 train_time:69564ms step_avg:87.61ms
+step:795/1680 train_time:69652ms step_avg:87.61ms
+step:796/1680 train_time:69741ms step_avg:87.61ms
+step:797/1680 train_time:69829ms step_avg:87.62ms
+step:798/1680 train_time:69917ms step_avg:87.62ms
+step:799/1680 train_time:70006ms step_avg:87.62ms
+step:800/1680 train_time:70094ms step_avg:87.62ms
+step:801/1680 train_time:70183ms step_avg:87.62ms
+step:802/1680 train_time:70271ms step_avg:87.62ms
+step:803/1680 train_time:70358ms step_avg:87.62ms
+step:804/1680 train_time:70447ms step_avg:87.62ms
+step:805/1680 train_time:70535ms step_avg:87.62ms
+step:806/1680 train_time:70624ms step_avg:87.62ms
+step:807/1680 train_time:70712ms step_avg:87.62ms
+step:808/1680 train_time:70800ms step_avg:87.62ms
+step:809/1680 train_time:70890ms step_avg:87.63ms
+step:810/1680 train_time:70978ms step_avg:87.63ms
+step:811/1680 train_time:71066ms step_avg:87.63ms
+step:812/1680 train_time:71154ms step_avg:87.63ms
+step:813/1680 train_time:71242ms step_avg:87.63ms
+step:814/1680 train_time:71331ms step_avg:87.63ms
+step:815/1680 train_time:71419ms step_avg:87.63ms
+step:816/1680 train_time:71506ms step_avg:87.63ms
+step:817/1680 train_time:71594ms step_avg:87.63ms
+step:818/1680 train_time:71682ms step_avg:87.63ms
+step:819/1680 train_time:71771ms step_avg:87.63ms
+step:820/1680 train_time:71859ms step_avg:87.63ms
+step:821/1680 train_time:71948ms step_avg:87.63ms
+step:822/1680 train_time:72036ms step_avg:87.64ms
+step:823/1680 train_time:72125ms step_avg:87.64ms
+step:824/1680 train_time:72213ms step_avg:87.64ms
+step:825/1680 train_time:72302ms step_avg:87.64ms
+step:826/1680 train_time:72390ms step_avg:87.64ms
+step:827/1680 train_time:72478ms step_avg:87.64ms
+step:828/1680 train_time:72566ms step_avg:87.64ms
+step:829/1680 train_time:72655ms step_avg:87.64ms
+step:830/1680 train_time:72743ms step_avg:87.64ms
+step:831/1680 train_time:72831ms step_avg:87.64ms
+step:832/1680 train_time:72920ms step_avg:87.64ms
+step:833/1680 train_time:73009ms step_avg:87.65ms
+step:834/1680 train_time:73098ms step_avg:87.65ms
+step:835/1680 train_time:73186ms step_avg:87.65ms
+step:836/1680 train_time:73274ms step_avg:87.65ms
+step:837/1680 train_time:73363ms step_avg:87.65ms
+step:838/1680 train_time:73451ms step_avg:87.65ms
+step:839/1680 train_time:73540ms step_avg:87.65ms
+step:840/1680 train_time:73629ms step_avg:87.65ms
+step:841/1680 train_time:73718ms step_avg:87.65ms
+step:842/1680 train_time:73805ms step_avg:87.65ms
+step:843/1680 train_time:73893ms step_avg:87.66ms
+step:844/1680 train_time:73982ms step_avg:87.66ms
+step:845/1680 train_time:74070ms step_avg:87.66ms
+step:846/1680 train_time:74158ms step_avg:87.66ms
+step:847/1680 train_time:74246ms step_avg:87.66ms
+step:848/1680 train_time:74334ms step_avg:87.66ms
+step:849/1680 train_time:74423ms step_avg:87.66ms
+step:850/1680 train_time:74511ms step_avg:87.66ms
+step:851/1680 train_time:74599ms step_avg:87.66ms
+step:852/1680 train_time:74688ms step_avg:87.66ms
+step:853/1680 train_time:74777ms step_avg:87.66ms
+step:854/1680 train_time:74866ms step_avg:87.66ms
+step:855/1680 train_time:74954ms step_avg:87.67ms
+step:856/1680 train_time:75043ms step_avg:87.67ms
+step:857/1680 train_time:75131ms step_avg:87.67ms
+step:858/1680 train_time:75219ms step_avg:87.67ms
+step:859/1680 train_time:75308ms step_avg:87.67ms
+step:860/1680 train_time:75396ms step_avg:87.67ms
+step:861/1680 train_time:75485ms step_avg:87.67ms
+step:862/1680 train_time:75573ms step_avg:87.67ms
+step:863/1680 train_time:75662ms step_avg:87.67ms
+step:864/1680 train_time:75750ms step_avg:87.67ms
+step:865/1680 train_time:75838ms step_avg:87.67ms
+step:866/1680 train_time:75926ms step_avg:87.67ms
+step:867/1680 train_time:76015ms step_avg:87.68ms
+step:868/1680 train_time:76104ms step_avg:87.68ms
+step:869/1680 train_time:76192ms step_avg:87.68ms
+step:870/1680 train_time:76280ms step_avg:87.68ms
+step:871/1680 train_time:76368ms step_avg:87.68ms
+step:872/1680 train_time:76457ms step_avg:87.68ms
+step:873/1680 train_time:76545ms step_avg:87.68ms
+step:874/1680 train_time:76633ms step_avg:87.68ms
+step:875/1680 train_time:76721ms step_avg:87.68ms
+step:875/1680 val_loss:3.5171 train_time:76811ms step_avg:87.78ms
+step:876/1680 train_time:76831ms step_avg:87.71ms
+step:877/1680 train_time:76902ms step_avg:87.69ms
+step:878/1680 train_time:76993ms step_avg:87.69ms
+step:879/1680 train_time:77086ms step_avg:87.70ms
+step:880/1680 train_time:77174ms step_avg:87.70ms
+step:881/1680 train_time:77262ms step_avg:87.70ms
+step:882/1680 train_time:77349ms step_avg:87.70ms
+step:883/1680 train_time:77436ms step_avg:87.70ms
+step:884/1680 train_time:77523ms step_avg:87.70ms
+step:885/1680 train_time:77611ms step_avg:87.70ms
+step:886/1680 train_time:77698ms step_avg:87.70ms
+step:887/1680 train_time:77786ms step_avg:87.70ms
+step:888/1680 train_time:77876ms step_avg:87.70ms
+step:889/1680 train_time:77966ms step_avg:87.70ms
+step:890/1680 train_time:78056ms step_avg:87.70ms
+step:891/1680 train_time:78146ms step_avg:87.71ms
+step:892/1680 train_time:78234ms step_avg:87.71ms
+step:893/1680 train_time:78322ms step_avg:87.71ms
+step:894/1680 train_time:78410ms step_avg:87.71ms
+step:895/1680 train_time:78498ms step_avg:87.71ms
+step:896/1680 train_time:78585ms step_avg:87.71ms
+step:897/1680 train_time:78673ms step_avg:87.71ms
+step:898/1680 train_time:78761ms step_avg:87.71ms
+step:899/1680 train_time:78849ms step_avg:87.71ms
+step:900/1680 train_time:78939ms step_avg:87.71ms
+step:901/1680 train_time:79028ms step_avg:87.71ms
+step:902/1680 train_time:79117ms step_avg:87.71ms
+step:903/1680 train_time:79206ms step_avg:87.71ms
+step:904/1680 train_time:79294ms step_avg:87.71ms
+step:905/1680 train_time:79382ms step_avg:87.72ms
+step:906/1680 train_time:79470ms step_avg:87.71ms
+step:907/1680 train_time:79557ms step_avg:87.71ms
+step:908/1680 train_time:79645ms step_avg:87.72ms
+step:909/1680 train_time:79734ms step_avg:87.72ms
+step:910/1680 train_time:79822ms step_avg:87.72ms
+step:911/1680 train_time:79911ms step_avg:87.72ms
+step:912/1680 train_time:80000ms step_avg:87.72ms
+step:913/1680 train_time:80089ms step_avg:87.72ms
+step:914/1680 train_time:80178ms step_avg:87.72ms
+step:915/1680 train_time:80266ms step_avg:87.72ms
+step:916/1680 train_time:80354ms step_avg:87.72ms
+step:917/1680 train_time:80443ms step_avg:87.72ms
+step:918/1680 train_time:80530ms step_avg:87.72ms
+step:919/1680 train_time:80619ms step_avg:87.72ms
+step:920/1680 train_time:80706ms step_avg:87.72ms
+step:921/1680 train_time:80794ms step_avg:87.72ms
+step:922/1680 train_time:80882ms step_avg:87.72ms
+step:923/1680 train_time:80972ms step_avg:87.73ms
+step:924/1680 train_time:81060ms step_avg:87.73ms
+step:925/1680 train_time:81150ms step_avg:87.73ms
+step:926/1680 train_time:81239ms step_avg:87.73ms
+step:927/1680 train_time:81328ms step_avg:87.73ms
+step:928/1680 train_time:81417ms step_avg:87.73ms
+step:929/1680 train_time:81506ms step_avg:87.74ms
+step:930/1680 train_time:81593ms step_avg:87.73ms
+step:931/1680 train_time:81681ms step_avg:87.73ms
+step:932/1680 train_time:81769ms step_avg:87.74ms
+step:933/1680 train_time:81858ms step_avg:87.74ms
+step:934/1680 train_time:81946ms step_avg:87.74ms
+step:935/1680 train_time:82035ms step_avg:87.74ms
+step:936/1680 train_time:82124ms step_avg:87.74ms
+step:937/1680 train_time:82212ms step_avg:87.74ms
+step:938/1680 train_time:82300ms step_avg:87.74ms
+step:939/1680 train_time:82388ms step_avg:87.74ms
+step:940/1680 train_time:82477ms step_avg:87.74ms
+step:941/1680 train_time:82565ms step_avg:87.74ms
+step:942/1680 train_time:82653ms step_avg:87.74ms
+step:943/1680 train_time:82741ms step_avg:87.74ms
+step:944/1680 train_time:82829ms step_avg:87.74ms
+step:945/1680 train_time:82918ms step_avg:87.74ms
+step:946/1680 train_time:83007ms step_avg:87.75ms
+step:947/1680 train_time:83096ms step_avg:87.75ms
+step:948/1680 train_time:83184ms step_avg:87.75ms
+step:949/1680 train_time:83273ms step_avg:87.75ms
+step:950/1680 train_time:83362ms step_avg:87.75ms
+step:951/1680 train_time:83450ms step_avg:87.75ms
+step:952/1680 train_time:83539ms step_avg:87.75ms
+step:953/1680 train_time:83627ms step_avg:87.75ms
+step:954/1680 train_time:83716ms step_avg:87.75ms
+step:955/1680 train_time:83804ms step_avg:87.75ms
+step:956/1680 train_time:83892ms step_avg:87.75ms
+step:957/1680 train_time:83980ms step_avg:87.75ms
+step:958/1680 train_time:84069ms step_avg:87.75ms
+step:959/1680 train_time:84157ms step_avg:87.76ms
+step:960/1680 train_time:84245ms step_avg:87.76ms
+step:961/1680 train_time:84334ms step_avg:87.76ms
+step:962/1680 train_time:84423ms step_avg:87.76ms
+step:963/1680 train_time:84511ms step_avg:87.76ms
+step:964/1680 train_time:84599ms step_avg:87.76ms
+step:965/1680 train_time:84688ms step_avg:87.76ms
+step:966/1680 train_time:84776ms step_avg:87.76ms
+step:967/1680 train_time:84865ms step_avg:87.76ms
+step:968/1680 train_time:84953ms step_avg:87.76ms
+step:969/1680 train_time:85042ms step_avg:87.76ms
+step:970/1680 train_time:85130ms step_avg:87.76ms
+step:971/1680 train_time:85219ms step_avg:87.76ms
+step:972/1680 train_time:85308ms step_avg:87.76ms
+step:973/1680 train_time:85396ms step_avg:87.77ms
+step:974/1680 train_time:85484ms step_avg:87.77ms
+step:975/1680 train_time:85573ms step_avg:87.77ms
+step:976/1680 train_time:85660ms step_avg:87.77ms
+step:977/1680 train_time:85748ms step_avg:87.77ms
+step:978/1680 train_time:85837ms step_avg:87.77ms
+step:979/1680 train_time:85925ms step_avg:87.77ms
+step:980/1680 train_time:86013ms step_avg:87.77ms
+step:981/1680 train_time:86101ms step_avg:87.77ms
+step:982/1680 train_time:86190ms step_avg:87.77ms
+step:983/1680 train_time:86278ms step_avg:87.77ms
+step:984/1680 train_time:86366ms step_avg:87.77ms
+step:985/1680 train_time:86454ms step_avg:87.77ms
+step:986/1680 train_time:86543ms step_avg:87.77ms
+step:987/1680 train_time:86631ms step_avg:87.77ms
+step:988/1680 train_time:86720ms step_avg:87.77ms
+step:989/1680 train_time:86808ms step_avg:87.77ms
+step:990/1680 train_time:86897ms step_avg:87.77ms
+step:991/1680 train_time:86985ms step_avg:87.77ms
+step:992/1680 train_time:87073ms step_avg:87.77ms
+step:993/1680 train_time:87161ms step_avg:87.78ms
+step:994/1680 train_time:87249ms step_avg:87.78ms
+step:995/1680 train_time:87337ms step_avg:87.78ms
+step:996/1680 train_time:87426ms step_avg:87.78ms
+step:997/1680 train_time:87514ms step_avg:87.78ms
+step:998/1680 train_time:87603ms step_avg:87.78ms
+step:999/1680 train_time:87691ms step_avg:87.78ms
+step:1000/1680 train_time:87779ms step_avg:87.78ms
+step:1000/1680 val_loss:3.4673 train_time:87868ms step_avg:87.87ms
+step:1001/1680 train_time:87887ms step_avg:87.80ms
+step:1002/1680 train_time:87960ms step_avg:87.78ms
+step:1003/1680 train_time:88051ms step_avg:87.79ms
+step:1004/1680 train_time:88139ms step_avg:87.79ms
+step:1005/1680 train_time:88227ms step_avg:87.79ms
+step:1006/1680 train_time:88314ms step_avg:87.79ms
+step:1007/1680 train_time:88402ms step_avg:87.79ms
+step:1008/1680 train_time:88490ms step_avg:87.79ms
+step:1009/1680 train_time:88578ms step_avg:87.79ms
+step:1010/1680 train_time:88665ms step_avg:87.79ms
+step:1011/1680 train_time:88753ms step_avg:87.79ms
+step:1012/1680 train_time:88842ms step_avg:87.79ms
+step:1013/1680 train_time:88932ms step_avg:87.79ms
+step:1014/1680 train_time:89022ms step_avg:87.79ms
+step:1015/1680 train_time:89112ms step_avg:87.79ms
+step:1016/1680 train_time:89200ms step_avg:87.80ms
+step:1017/1680 train_time:89289ms step_avg:87.80ms
+step:1018/1680 train_time:89376ms step_avg:87.80ms
+step:1019/1680 train_time:89464ms step_avg:87.80ms
+step:1020/1680 train_time:89552ms step_avg:87.80ms
+step:1021/1680 train_time:89639ms step_avg:87.80ms
+step:1022/1680 train_time:89728ms step_avg:87.80ms
+step:1023/1680 train_time:89816ms step_avg:87.80ms
+step:1024/1680 train_time:89906ms step_avg:87.80ms
+step:1025/1680 train_time:89995ms step_avg:87.80ms
+step:1026/1680 train_time:90084ms step_avg:87.80ms
+step:1027/1680 train_time:90173ms step_avg:87.80ms
+step:1028/1680 train_time:90262ms step_avg:87.80ms
+step:1029/1680 train_time:90350ms step_avg:87.80ms
+step:1030/1680 train_time:90438ms step_avg:87.80ms
+step:1031/1680 train_time:90525ms step_avg:87.80ms
+step:1032/1680 train_time:90613ms step_avg:87.80ms
+step:1033/1680 train_time:90701ms step_avg:87.80ms
+step:1034/1680 train_time:90790ms step_avg:87.80ms
+step:1035/1680 train_time:90878ms step_avg:87.81ms
+step:1036/1680 train_time:90968ms step_avg:87.81ms
+step:1037/1680 train_time:91057ms step_avg:87.81ms
+step:1038/1680 train_time:91145ms step_avg:87.81ms
+step:1039/1680 train_time:91234ms step_avg:87.81ms
+step:1040/1680 train_time:91322ms step_avg:87.81ms
+step:1041/1680 train_time:91411ms step_avg:87.81ms
+step:1042/1680 train_time:91498ms step_avg:87.81ms
+step:1043/1680 train_time:91587ms step_avg:87.81ms
+step:1044/1680 train_time:91674ms step_avg:87.81ms
+step:1045/1680 train_time:91762ms step_avg:87.81ms
+step:1046/1680 train_time:91852ms step_avg:87.81ms
+step:1047/1680 train_time:91941ms step_avg:87.81ms
+step:1048/1680 train_time:92031ms step_avg:87.82ms
+step:1049/1680 train_time:92119ms step_avg:87.82ms
+step:1050/1680 train_time:92208ms step_avg:87.82ms
+step:1051/1680 train_time:92297ms step_avg:87.82ms
+step:1052/1680 train_time:92385ms step_avg:87.82ms
+step:1053/1680 train_time:92473ms step_avg:87.82ms
+step:1054/1680 train_time:92561ms step_avg:87.82ms
+step:1055/1680 train_time:92648ms step_avg:87.82ms
+step:1056/1680 train_time:92736ms step_avg:87.82ms
+step:1057/1680 train_time:92825ms step_avg:87.82ms
+step:1058/1680 train_time:92913ms step_avg:87.82ms
+step:1059/1680 train_time:93002ms step_avg:87.82ms
+step:1060/1680 train_time:93090ms step_avg:87.82ms
+step:1061/1680 train_time:93178ms step_avg:87.82ms
+step:1062/1680 train_time:93267ms step_avg:87.82ms
+step:1063/1680 train_time:93356ms step_avg:87.82ms
+step:1064/1680 train_time:93445ms step_avg:87.82ms
+step:1065/1680 train_time:93533ms step_avg:87.82ms
+step:1066/1680 train_time:93621ms step_avg:87.83ms
+step:1067/1680 train_time:93710ms step_avg:87.83ms
+step:1068/1680 train_time:93798ms step_avg:87.83ms
+step:1069/1680 train_time:93886ms step_avg:87.83ms
+step:1070/1680 train_time:93975ms step_avg:87.83ms
+step:1071/1680 train_time:94064ms step_avg:87.83ms
+step:1072/1680 train_time:94154ms step_avg:87.83ms
+step:1073/1680 train_time:94242ms step_avg:87.83ms
+step:1074/1680 train_time:94330ms step_avg:87.83ms
+step:1075/1680 train_time:94419ms step_avg:87.83ms
+step:1076/1680 train_time:94507ms step_avg:87.83ms
+step:1077/1680 train_time:94596ms step_avg:87.83ms
+step:1078/1680 train_time:94684ms step_avg:87.83ms
+step:1079/1680 train_time:94772ms step_avg:87.83ms
+step:1080/1680 train_time:94861ms step_avg:87.83ms
+step:1081/1680 train_time:94950ms step_avg:87.84ms
+step:1082/1680 train_time:95039ms step_avg:87.84ms
+step:1083/1680 train_time:95127ms step_avg:87.84ms
+step:1084/1680 train_time:95216ms step_avg:87.84ms
+step:1085/1680 train_time:95305ms step_avg:87.84ms
+step:1086/1680 train_time:95394ms step_avg:87.84ms
+step:1087/1680 train_time:95482ms step_avg:87.84ms
+step:1088/1680 train_time:95571ms step_avg:87.84ms
+step:1089/1680 train_time:95659ms step_avg:87.84ms
+step:1090/1680 train_time:95747ms step_avg:87.84ms
+step:1091/1680 train_time:95835ms step_avg:87.84ms
+step:1092/1680 train_time:95924ms step_avg:87.84ms
+step:1093/1680 train_time:96013ms step_avg:87.84ms
+step:1094/1680 train_time:96101ms step_avg:87.84ms
+step:1095/1680 train_time:96190ms step_avg:87.85ms
+step:1096/1680 train_time:96280ms step_avg:87.85ms
+step:1097/1680 train_time:96368ms step_avg:87.85ms
+step:1098/1680 train_time:96457ms step_avg:87.85ms
+step:1099/1680 train_time:96547ms step_avg:87.85ms
+step:1100/1680 train_time:96635ms step_avg:87.85ms
+step:1101/1680 train_time:96725ms step_avg:87.85ms
+step:1102/1680 train_time:96814ms step_avg:87.85ms
+step:1103/1680 train_time:96904ms step_avg:87.85ms
+step:1104/1680 train_time:96993ms step_avg:87.86ms
+step:1105/1680 train_time:97082ms step_avg:87.86ms
+step:1106/1680 train_time:97172ms step_avg:87.86ms
+step:1107/1680 train_time:97261ms step_avg:87.86ms
+step:1108/1680 train_time:97351ms step_avg:87.86ms
+step:1109/1680 train_time:97440ms step_avg:87.86ms
+step:1110/1680 train_time:97529ms step_avg:87.86ms
+step:1111/1680 train_time:97618ms step_avg:87.86ms
+step:1112/1680 train_time:97707ms step_avg:87.87ms
+step:1113/1680 train_time:97796ms step_avg:87.87ms
+step:1114/1680 train_time:97885ms step_avg:87.87ms
+step:1115/1680 train_time:97975ms step_avg:87.87ms
+step:1116/1680 train_time:98065ms step_avg:87.87ms
+step:1117/1680 train_time:98154ms step_avg:87.87ms
+step:1118/1680 train_time:98243ms step_avg:87.87ms
+step:1119/1680 train_time:98333ms step_avg:87.88ms
+step:1120/1680 train_time:98422ms step_avg:87.88ms
+step:1121/1680 train_time:98510ms step_avg:87.88ms
+step:1122/1680 train_time:98601ms step_avg:87.88ms
+step:1123/1680 train_time:98689ms step_avg:87.88ms
+step:1124/1680 train_time:98778ms step_avg:87.88ms
+step:1125/1680 train_time:98868ms step_avg:87.88ms
+step:1125/1680 val_loss:3.4137 train_time:98959ms step_avg:87.96ms
+step:1126/1680 train_time:98979ms step_avg:87.90ms
+step:1127/1680 train_time:99049ms step_avg:87.89ms
+step:1128/1680 train_time:99141ms step_avg:87.89ms
+step:1129/1680 train_time:99233ms step_avg:87.89ms
+step:1130/1680 train_time:99323ms step_avg:87.90ms
+step:1131/1680 train_time:99411ms step_avg:87.90ms
+step:1132/1680 train_time:99498ms step_avg:87.90ms
+step:1133/1680 train_time:99586ms step_avg:87.90ms
+step:1134/1680 train_time:99674ms step_avg:87.90ms
+step:1135/1680 train_time:99763ms step_avg:87.90ms
+step:1136/1680 train_time:99852ms step_avg:87.90ms
+step:1137/1680 train_time:99944ms step_avg:87.90ms
+step:1138/1680 train_time:100035ms step_avg:87.90ms
+step:1139/1680 train_time:100125ms step_avg:87.91ms
+step:1140/1680 train_time:100216ms step_avg:87.91ms
+step:1141/1680 train_time:100305ms step_avg:87.91ms
+step:1142/1680 train_time:100394ms step_avg:87.91ms
+step:1143/1680 train_time:100483ms step_avg:87.91ms
+step:1144/1680 train_time:100571ms step_avg:87.91ms
+step:1145/1680 train_time:100659ms step_avg:87.91ms
+step:1146/1680 train_time:100747ms step_avg:87.91ms
+step:1147/1680 train_time:100836ms step_avg:87.91ms
+step:1148/1680 train_time:100926ms step_avg:87.91ms
+step:1149/1680 train_time:101015ms step_avg:87.92ms
+step:1150/1680 train_time:101105ms step_avg:87.92ms
+step:1151/1680 train_time:101195ms step_avg:87.92ms
+step:1152/1680 train_time:101285ms step_avg:87.92ms
+step:1153/1680 train_time:101373ms step_avg:87.92ms
+step:1154/1680 train_time:101462ms step_avg:87.92ms
+step:1155/1680 train_time:101551ms step_avg:87.92ms
+step:1156/1680 train_time:101639ms step_avg:87.92ms
+step:1157/1680 train_time:101728ms step_avg:87.92ms
+step:1158/1680 train_time:101817ms step_avg:87.93ms
+step:1159/1680 train_time:101906ms step_avg:87.93ms
+step:1160/1680 train_time:101994ms step_avg:87.93ms
+step:1161/1680 train_time:102084ms step_avg:87.93ms
+step:1162/1680 train_time:102174ms step_avg:87.93ms
+step:1163/1680 train_time:102264ms step_avg:87.93ms
+step:1164/1680 train_time:102354ms step_avg:87.93ms
+step:1165/1680 train_time:102443ms step_avg:87.93ms
+step:1166/1680 train_time:102532ms step_avg:87.93ms
+step:1167/1680 train_time:102620ms step_avg:87.94ms
+step:1168/1680 train_time:102709ms step_avg:87.94ms
+step:1169/1680 train_time:102798ms step_avg:87.94ms
+step:1170/1680 train_time:102886ms step_avg:87.94ms
+step:1171/1680 train_time:102975ms step_avg:87.94ms
+step:1172/1680 train_time:103065ms step_avg:87.94ms
+step:1173/1680 train_time:103155ms step_avg:87.94ms
+step:1174/1680 train_time:103243ms step_avg:87.94ms
+step:1175/1680 train_time:103333ms step_avg:87.94ms
+step:1176/1680 train_time:103422ms step_avg:87.94ms
+step:1177/1680 train_time:103512ms step_avg:87.95ms
+step:1178/1680 train_time:103600ms step_avg:87.95ms
+step:1179/1680 train_time:103689ms step_avg:87.95ms
+step:1180/1680 train_time:103778ms step_avg:87.95ms
+step:1181/1680 train_time:103867ms step_avg:87.95ms
+step:1182/1680 train_time:103957ms step_avg:87.95ms
+step:1183/1680 train_time:104046ms step_avg:87.95ms
+step:1184/1680 train_time:104136ms step_avg:87.95ms
+step:1185/1680 train_time:104225ms step_avg:87.95ms
+step:1186/1680 train_time:104314ms step_avg:87.95ms
+step:1187/1680 train_time:104404ms step_avg:87.96ms
+step:1188/1680 train_time:104493ms step_avg:87.96ms
+step:1189/1680 train_time:104581ms step_avg:87.96ms
+step:1190/1680 train_time:104670ms step_avg:87.96ms
+step:1191/1680 train_time:104759ms step_avg:87.96ms
+step:1192/1680 train_time:104848ms step_avg:87.96ms
+step:1193/1680 train_time:104936ms step_avg:87.96ms
+step:1194/1680 train_time:105026ms step_avg:87.96ms
+step:1195/1680 train_time:105115ms step_avg:87.96ms
+step:1196/1680 train_time:105204ms step_avg:87.96ms
+step:1197/1680 train_time:105293ms step_avg:87.96ms
+step:1198/1680 train_time:105381ms step_avg:87.96ms
+step:1199/1680 train_time:105471ms step_avg:87.97ms
+step:1200/1680 train_time:105559ms step_avg:87.97ms
+step:1201/1680 train_time:105648ms step_avg:87.97ms
+step:1202/1680 train_time:105737ms step_avg:87.97ms
+step:1203/1680 train_time:105826ms step_avg:87.97ms
+step:1204/1680 train_time:105916ms step_avg:87.97ms
+step:1205/1680 train_time:106004ms step_avg:87.97ms
+step:1206/1680 train_time:106094ms step_avg:87.97ms
+step:1207/1680 train_time:106183ms step_avg:87.97ms
+step:1208/1680 train_time:106272ms step_avg:87.97ms
+step:1209/1680 train_time:106361ms step_avg:87.97ms
+step:1210/1680 train_time:106451ms step_avg:87.98ms
+step:1211/1680 train_time:106539ms step_avg:87.98ms
+step:1212/1680 train_time:106628ms step_avg:87.98ms
+step:1213/1680 train_time:106718ms step_avg:87.98ms
+step:1214/1680 train_time:106807ms step_avg:87.98ms
+step:1215/1680 train_time:106896ms step_avg:87.98ms
+step:1216/1680 train_time:106986ms step_avg:87.98ms
+step:1217/1680 train_time:107075ms step_avg:87.98ms
+step:1218/1680 train_time:107164ms step_avg:87.98ms
+step:1219/1680 train_time:107253ms step_avg:87.98ms
+step:1220/1680 train_time:107343ms step_avg:87.99ms
+step:1221/1680 train_time:107432ms step_avg:87.99ms
+step:1222/1680 train_time:107521ms step_avg:87.99ms
+step:1223/1680 train_time:107610ms step_avg:87.99ms
+step:1224/1680 train_time:107699ms step_avg:87.99ms
+step:1225/1680 train_time:107788ms step_avg:87.99ms
+step:1226/1680 train_time:107876ms step_avg:87.99ms
+step:1227/1680 train_time:107965ms step_avg:87.99ms
+step:1228/1680 train_time:108054ms step_avg:87.99ms
+step:1229/1680 train_time:108143ms step_avg:87.99ms
+step:1230/1680 train_time:108232ms step_avg:87.99ms
+step:1231/1680 train_time:108322ms step_avg:87.99ms
+step:1232/1680 train_time:108412ms step_avg:88.00ms
+step:1233/1680 train_time:108501ms step_avg:88.00ms
+step:1234/1680 train_time:108590ms step_avg:88.00ms
+step:1235/1680 train_time:108679ms step_avg:88.00ms
+step:1236/1680 train_time:108768ms step_avg:88.00ms
+step:1237/1680 train_time:108858ms step_avg:88.00ms
+step:1238/1680 train_time:108946ms step_avg:88.00ms
+step:1239/1680 train_time:109035ms step_avg:88.00ms
+step:1240/1680 train_time:109124ms step_avg:88.00ms
+step:1241/1680 train_time:109213ms step_avg:88.00ms
+step:1242/1680 train_time:109302ms step_avg:88.00ms
+step:1243/1680 train_time:109391ms step_avg:88.01ms
+step:1244/1680 train_time:109481ms step_avg:88.01ms
+step:1245/1680 train_time:109570ms step_avg:88.01ms
+step:1246/1680 train_time:109659ms step_avg:88.01ms
+step:1247/1680 train_time:109748ms step_avg:88.01ms
+step:1248/1680 train_time:109837ms step_avg:88.01ms
+step:1249/1680 train_time:109926ms step_avg:88.01ms
+step:1250/1680 train_time:110015ms step_avg:88.01ms
+step:1250/1680 val_loss:3.3752 train_time:110105ms step_avg:88.08ms
+step:1251/1680 train_time:110123ms step_avg:88.03ms
+step:1252/1680 train_time:110196ms step_avg:88.02ms
+step:1253/1680 train_time:110289ms step_avg:88.02ms
+step:1254/1680 train_time:110378ms step_avg:88.02ms
+step:1255/1680 train_time:110467ms step_avg:88.02ms
+step:1256/1680 train_time:110554ms step_avg:88.02ms
+step:1257/1680 train_time:110642ms step_avg:88.02ms
+step:1258/1680 train_time:110730ms step_avg:88.02ms
+step:1259/1680 train_time:110818ms step_avg:88.02ms
+step:1260/1680 train_time:110906ms step_avg:88.02ms
+step:1261/1680 train_time:110994ms step_avg:88.02ms
+step:1262/1680 train_time:111084ms step_avg:88.02ms
+step:1263/1680 train_time:111176ms step_avg:88.03ms
+step:1264/1680 train_time:111267ms step_avg:88.03ms
+step:1265/1680 train_time:111357ms step_avg:88.03ms
+step:1266/1680 train_time:111447ms step_avg:88.03ms
+step:1267/1680 train_time:111535ms step_avg:88.03ms
+step:1268/1680 train_time:111623ms step_avg:88.03ms
+step:1269/1680 train_time:111712ms step_avg:88.03ms
+step:1270/1680 train_time:111801ms step_avg:88.03ms
+step:1271/1680 train_time:111889ms step_avg:88.03ms
+step:1272/1680 train_time:111977ms step_avg:88.03ms
+step:1273/1680 train_time:112067ms step_avg:88.03ms
+step:1274/1680 train_time:112157ms step_avg:88.04ms
+step:1275/1680 train_time:112247ms step_avg:88.04ms
+step:1276/1680 train_time:112338ms step_avg:88.04ms
+step:1277/1680 train_time:112428ms step_avg:88.04ms
+step:1278/1680 train_time:112516ms step_avg:88.04ms
+step:1279/1680 train_time:112604ms step_avg:88.04ms
+step:1280/1680 train_time:112693ms step_avg:88.04ms
+step:1281/1680 train_time:112781ms step_avg:88.04ms
+step:1282/1680 train_time:112869ms step_avg:88.04ms
+step:1283/1680 train_time:112958ms step_avg:88.04ms
+step:1284/1680 train_time:113046ms step_avg:88.04ms
+step:1285/1680 train_time:113136ms step_avg:88.04ms
+step:1286/1680 train_time:113226ms step_avg:88.04ms
+step:1287/1680 train_time:113315ms step_avg:88.05ms
+step:1288/1680 train_time:113404ms step_avg:88.05ms
+step:1289/1680 train_time:113494ms step_avg:88.05ms
+step:1290/1680 train_time:113583ms step_avg:88.05ms
+step:1291/1680 train_time:113672ms step_avg:88.05ms
+step:1292/1680 train_time:113762ms step_avg:88.05ms
+step:1293/1680 train_time:113850ms step_avg:88.05ms
+step:1294/1680 train_time:113938ms step_avg:88.05ms
+step:1295/1680 train_time:114027ms step_avg:88.05ms
+step:1296/1680 train_time:114115ms step_avg:88.05ms
+step:1297/1680 train_time:114205ms step_avg:88.05ms
+step:1298/1680 train_time:114294ms step_avg:88.05ms
+step:1299/1680 train_time:114384ms step_avg:88.06ms
+step:1300/1680 train_time:114474ms step_avg:88.06ms
+step:1301/1680 train_time:114563ms step_avg:88.06ms
+step:1302/1680 train_time:114652ms step_avg:88.06ms
+step:1303/1680 train_time:114741ms step_avg:88.06ms
+step:1304/1680 train_time:114830ms step_avg:88.06ms
+step:1305/1680 train_time:114919ms step_avg:88.06ms
+step:1306/1680 train_time:115008ms step_avg:88.06ms
+step:1307/1680 train_time:115097ms step_avg:88.06ms
+step:1308/1680 train_time:115187ms step_avg:88.06ms
+step:1309/1680 train_time:115276ms step_avg:88.06ms
+step:1310/1680 train_time:115366ms step_avg:88.07ms
+step:1311/1680 train_time:115455ms step_avg:88.07ms
+step:1312/1680 train_time:115544ms step_avg:88.07ms
+step:1313/1680 train_time:115634ms step_avg:88.07ms
+step:1314/1680 train_time:115723ms step_avg:88.07ms
+step:1315/1680 train_time:115812ms step_avg:88.07ms
+step:1316/1680 train_time:115902ms step_avg:88.07ms
+step:1317/1680 train_time:115991ms step_avg:88.07ms
+step:1318/1680 train_time:116081ms step_avg:88.07ms
+step:1319/1680 train_time:116170ms step_avg:88.07ms
+step:1320/1680 train_time:116260ms step_avg:88.08ms
+step:1321/1680 train_time:116349ms step_avg:88.08ms
+step:1322/1680 train_time:116438ms step_avg:88.08ms
+step:1323/1680 train_time:116528ms step_avg:88.08ms
+step:1324/1680 train_time:116617ms step_avg:88.08ms
+step:1325/1680 train_time:116706ms step_avg:88.08ms
+step:1326/1680 train_time:116795ms step_avg:88.08ms
+step:1327/1680 train_time:116884ms step_avg:88.08ms
+step:1328/1680 train_time:116974ms step_avg:88.08ms
+step:1329/1680 train_time:117063ms step_avg:88.08ms
+step:1330/1680 train_time:117153ms step_avg:88.08ms
+step:1331/1680 train_time:117243ms step_avg:88.09ms
+step:1332/1680 train_time:117333ms step_avg:88.09ms
+step:1333/1680 train_time:117422ms step_avg:88.09ms
+step:1334/1680 train_time:117511ms step_avg:88.09ms
+step:1335/1680 train_time:117600ms step_avg:88.09ms
+step:1336/1680 train_time:117689ms step_avg:88.09ms
+step:1337/1680 train_time:117779ms step_avg:88.09ms
+step:1338/1680 train_time:117869ms step_avg:88.09ms
+step:1339/1680 train_time:117958ms step_avg:88.09ms
+step:1340/1680 train_time:118047ms step_avg:88.09ms
+step:1341/1680 train_time:118136ms step_avg:88.10ms
+step:1342/1680 train_time:118225ms step_avg:88.10ms
+step:1343/1680 train_time:118314ms step_avg:88.10ms
+step:1344/1680 train_time:118403ms step_avg:88.10ms
+step:1345/1680 train_time:118492ms step_avg:88.10ms
+step:1346/1680 train_time:118581ms step_avg:88.10ms
+step:1347/1680 train_time:118671ms step_avg:88.10ms
+step:1348/1680 train_time:118760ms step_avg:88.10ms
+step:1349/1680 train_time:118849ms step_avg:88.10ms
+step:1350/1680 train_time:118938ms step_avg:88.10ms
+step:1351/1680 train_time:119026ms step_avg:88.10ms
+step:1352/1680 train_time:119115ms step_avg:88.10ms
+step:1353/1680 train_time:119205ms step_avg:88.10ms
+step:1354/1680 train_time:119294ms step_avg:88.11ms
+step:1355/1680 train_time:119383ms step_avg:88.11ms
+step:1356/1680 train_time:119472ms step_avg:88.11ms
+step:1357/1680 train_time:119561ms step_avg:88.11ms
+step:1358/1680 train_time:119649ms step_avg:88.11ms
+step:1359/1680 train_time:119738ms step_avg:88.11ms
+step:1360/1680 train_time:119827ms step_avg:88.11ms
+step:1361/1680 train_time:119916ms step_avg:88.11ms
+step:1362/1680 train_time:120005ms step_avg:88.11ms
+step:1363/1680 train_time:120094ms step_avg:88.11ms
+step:1364/1680 train_time:120184ms step_avg:88.11ms
+step:1365/1680 train_time:120274ms step_avg:88.11ms
+step:1366/1680 train_time:120363ms step_avg:88.11ms
+step:1367/1680 train_time:120452ms step_avg:88.11ms
+step:1368/1680 train_time:120542ms step_avg:88.12ms
+step:1369/1680 train_time:120631ms step_avg:88.12ms
+step:1370/1680 train_time:120721ms step_avg:88.12ms
+step:1371/1680 train_time:120810ms step_avg:88.12ms
+step:1372/1680 train_time:120899ms step_avg:88.12ms
+step:1373/1680 train_time:120988ms step_avg:88.12ms
+step:1374/1680 train_time:121077ms step_avg:88.12ms
+step:1375/1680 train_time:121167ms step_avg:88.12ms
+step:1375/1680 val_loss:3.3413 train_time:121258ms step_avg:88.19ms
+step:1376/1680 train_time:121276ms step_avg:88.14ms
+step:1377/1680 train_time:121350ms step_avg:88.13ms
+step:1378/1680 train_time:121444ms step_avg:88.13ms
+step:1379/1680 train_time:121534ms step_avg:88.13ms
+step:1380/1680 train_time:121622ms step_avg:88.13ms
+step:1381/1680 train_time:121710ms step_avg:88.13ms
+step:1382/1680 train_time:121798ms step_avg:88.13ms
+step:1383/1680 train_time:121886ms step_avg:88.13ms
+step:1384/1680 train_time:121974ms step_avg:88.13ms
+step:1385/1680 train_time:122063ms step_avg:88.13ms
+step:1386/1680 train_time:122152ms step_avg:88.13ms
+step:1387/1680 train_time:122242ms step_avg:88.13ms
+step:1388/1680 train_time:122334ms step_avg:88.14ms
+step:1389/1680 train_time:122426ms step_avg:88.14ms
+step:1390/1680 train_time:122517ms step_avg:88.14ms
+step:1391/1680 train_time:122607ms step_avg:88.14ms
+step:1392/1680 train_time:122695ms step_avg:88.14ms
+step:1393/1680 train_time:122783ms step_avg:88.14ms
+step:1394/1680 train_time:122871ms step_avg:88.14ms
+step:1395/1680 train_time:122959ms step_avg:88.14ms
+step:1396/1680 train_time:123048ms step_avg:88.14ms
+step:1397/1680 train_time:123136ms step_avg:88.14ms
+step:1398/1680 train_time:123225ms step_avg:88.14ms
+step:1399/1680 train_time:123315ms step_avg:88.15ms
+step:1400/1680 train_time:123407ms step_avg:88.15ms
+step:1401/1680 train_time:123496ms step_avg:88.15ms
+step:1402/1680 train_time:123585ms step_avg:88.15ms
+step:1403/1680 train_time:123675ms step_avg:88.15ms
+step:1404/1680 train_time:123763ms step_avg:88.15ms
+step:1405/1680 train_time:123852ms step_avg:88.15ms
+step:1406/1680 train_time:123940ms step_avg:88.15ms
+step:1407/1680 train_time:124028ms step_avg:88.15ms
+step:1408/1680 train_time:124117ms step_avg:88.15ms
+step:1409/1680 train_time:124206ms step_avg:88.15ms
+step:1410/1680 train_time:124295ms step_avg:88.15ms
+step:1411/1680 train_time:124385ms step_avg:88.15ms
+step:1412/1680 train_time:124475ms step_avg:88.16ms
+step:1413/1680 train_time:124565ms step_avg:88.16ms
+step:1414/1680 train_time:124655ms step_avg:88.16ms
+step:1415/1680 train_time:124744ms step_avg:88.16ms
+step:1416/1680 train_time:124833ms step_avg:88.16ms
+step:1417/1680 train_time:124923ms step_avg:88.16ms
+step:1418/1680 train_time:125011ms step_avg:88.16ms
+step:1419/1680 train_time:125100ms step_avg:88.16ms
+step:1420/1680 train_time:125188ms step_avg:88.16ms
+step:1421/1680 train_time:125277ms step_avg:88.16ms
+step:1422/1680 train_time:125366ms step_avg:88.16ms
+step:1423/1680 train_time:125456ms step_avg:88.16ms
+step:1424/1680 train_time:125546ms step_avg:88.16ms
+step:1425/1680 train_time:125635ms step_avg:88.16ms
+step:1426/1680 train_time:125725ms step_avg:88.17ms
+step:1427/1680 train_time:125814ms step_avg:88.17ms
+step:1428/1680 train_time:125903ms step_avg:88.17ms
+step:1429/1680 train_time:125993ms step_avg:88.17ms
+step:1430/1680 train_time:126082ms step_avg:88.17ms
+step:1431/1680 train_time:126171ms step_avg:88.17ms
+step:1432/1680 train_time:126259ms step_avg:88.17ms
+step:1433/1680 train_time:126349ms step_avg:88.17ms
+step:1434/1680 train_time:126438ms step_avg:88.17ms
+step:1435/1680 train_time:126529ms step_avg:88.17ms
+step:1436/1680 train_time:126619ms step_avg:88.17ms
+step:1437/1680 train_time:126710ms step_avg:88.18ms
+step:1438/1680 train_time:126800ms step_avg:88.18ms
+step:1439/1680 train_time:126889ms step_avg:88.18ms
+step:1440/1680 train_time:126978ms step_avg:88.18ms
+step:1441/1680 train_time:127068ms step_avg:88.18ms
+step:1442/1680 train_time:127157ms step_avg:88.18ms
+step:1443/1680 train_time:127245ms step_avg:88.18ms
+step:1444/1680 train_time:127334ms step_avg:88.18ms
+step:1445/1680 train_time:127424ms step_avg:88.18ms
+step:1446/1680 train_time:127514ms step_avg:88.18ms
+step:1447/1680 train_time:127603ms step_avg:88.18ms
+step:1448/1680 train_time:127692ms step_avg:88.19ms
+step:1449/1680 train_time:127781ms step_avg:88.19ms
+step:1450/1680 train_time:127870ms step_avg:88.19ms
+step:1451/1680 train_time:127959ms step_avg:88.19ms
+step:1452/1680 train_time:128048ms step_avg:88.19ms
+step:1453/1680 train_time:128137ms step_avg:88.19ms
+step:1454/1680 train_time:128227ms step_avg:88.19ms
+step:1455/1680 train_time:128315ms step_avg:88.19ms
+step:1456/1680 train_time:128405ms step_avg:88.19ms
+step:1457/1680 train_time:128494ms step_avg:88.19ms
+step:1458/1680 train_time:128583ms step_avg:88.19ms
+step:1459/1680 train_time:128672ms step_avg:88.19ms
+step:1460/1680 train_time:128762ms step_avg:88.19ms
+step:1461/1680 train_time:128851ms step_avg:88.19ms
+step:1462/1680 train_time:128940ms step_avg:88.19ms
+step:1463/1680 train_time:129029ms step_avg:88.19ms
+step:1464/1680 train_time:129118ms step_avg:88.20ms
+step:1465/1680 train_time:129207ms step_avg:88.20ms
+step:1466/1680 train_time:129296ms step_avg:88.20ms
+step:1467/1680 train_time:129384ms step_avg:88.20ms
+step:1468/1680 train_time:129473ms step_avg:88.20ms
+step:1469/1680 train_time:129562ms step_avg:88.20ms
+step:1470/1680 train_time:129652ms step_avg:88.20ms
+step:1471/1680 train_time:129741ms step_avg:88.20ms
+step:1472/1680 train_time:129831ms step_avg:88.20ms
+step:1473/1680 train_time:129920ms step_avg:88.20ms
+step:1474/1680 train_time:130009ms step_avg:88.20ms
+step:1475/1680 train_time:130098ms step_avg:88.20ms
+step:1476/1680 train_time:130187ms step_avg:88.20ms
+step:1477/1680 train_time:130276ms step_avg:88.20ms
+step:1478/1680 train_time:130365ms step_avg:88.20ms
+step:1479/1680 train_time:130454ms step_avg:88.20ms
+step:1480/1680 train_time:130544ms step_avg:88.21ms
+step:1481/1680 train_time:130633ms step_avg:88.21ms
+step:1482/1680 train_time:130722ms step_avg:88.21ms
+step:1483/1680 train_time:130811ms step_avg:88.21ms
+step:1484/1680 train_time:130901ms step_avg:88.21ms
+step:1485/1680 train_time:130990ms step_avg:88.21ms
+step:1486/1680 train_time:131079ms step_avg:88.21ms
+step:1487/1680 train_time:131168ms step_avg:88.21ms
+step:1488/1680 train_time:131257ms step_avg:88.21ms
+step:1489/1680 train_time:131346ms step_avg:88.21ms
+step:1490/1680 train_time:131435ms step_avg:88.21ms
+step:1491/1680 train_time:131525ms step_avg:88.21ms
+step:1492/1680 train_time:131613ms step_avg:88.21ms
+step:1493/1680 train_time:131703ms step_avg:88.21ms
+step:1494/1680 train_time:131793ms step_avg:88.21ms
+step:1495/1680 train_time:131882ms step_avg:88.22ms
+step:1496/1680 train_time:131971ms step_avg:88.22ms
+step:1497/1680 train_time:132060ms step_avg:88.22ms
+step:1498/1680 train_time:132150ms step_avg:88.22ms
+step:1499/1680 train_time:132240ms step_avg:88.22ms
+step:1500/1680 train_time:132329ms step_avg:88.22ms
+step:1500/1680 val_loss:3.3115 train_time:132420ms step_avg:88.28ms
+step:1501/1680 train_time:132440ms step_avg:88.23ms
+step:1502/1680 train_time:132512ms step_avg:88.22ms
+step:1503/1680 train_time:132603ms step_avg:88.23ms
+step:1504/1680 train_time:132693ms step_avg:88.23ms
+step:1505/1680 train_time:132783ms step_avg:88.23ms
+step:1506/1680 train_time:132872ms step_avg:88.23ms
+step:1507/1680 train_time:132961ms step_avg:88.23ms
+step:1508/1680 train_time:133049ms step_avg:88.23ms
+step:1509/1680 train_time:133137ms step_avg:88.23ms
+step:1510/1680 train_time:133226ms step_avg:88.23ms
+step:1511/1680 train_time:133314ms step_avg:88.23ms
+step:1512/1680 train_time:133404ms step_avg:88.23ms
+step:1513/1680 train_time:133496ms step_avg:88.23ms
+step:1514/1680 train_time:133586ms step_avg:88.23ms
+step:1515/1680 train_time:133676ms step_avg:88.23ms
+step:1516/1680 train_time:133766ms step_avg:88.24ms
+step:1517/1680 train_time:133855ms step_avg:88.24ms
+step:1518/1680 train_time:133943ms step_avg:88.24ms
+step:1519/1680 train_time:134032ms step_avg:88.24ms
+step:1520/1680 train_time:134120ms step_avg:88.24ms
+step:1521/1680 train_time:134208ms step_avg:88.24ms
+step:1522/1680 train_time:134297ms step_avg:88.24ms
+step:1523/1680 train_time:134386ms step_avg:88.24ms
+step:1524/1680 train_time:134476ms step_avg:88.24ms
+step:1525/1680 train_time:134566ms step_avg:88.24ms
+step:1526/1680 train_time:134656ms step_avg:88.24ms
+step:1527/1680 train_time:134746ms step_avg:88.24ms
+step:1528/1680 train_time:134835ms step_avg:88.24ms
+step:1529/1680 train_time:134924ms step_avg:88.24ms
+step:1530/1680 train_time:135013ms step_avg:88.24ms
+step:1531/1680 train_time:135102ms step_avg:88.24ms
+step:1532/1680 train_time:135190ms step_avg:88.24ms
+step:1533/1680 train_time:135279ms step_avg:88.24ms
+step:1534/1680 train_time:135368ms step_avg:88.25ms
+step:1535/1680 train_time:135457ms step_avg:88.25ms
+step:1536/1680 train_time:135548ms step_avg:88.25ms
+step:1537/1680 train_time:135637ms step_avg:88.25ms
+step:1538/1680 train_time:135727ms step_avg:88.25ms
+step:1539/1680 train_time:135816ms step_avg:88.25ms
+step:1540/1680 train_time:135905ms step_avg:88.25ms
+step:1541/1680 train_time:135994ms step_avg:88.25ms
+step:1542/1680 train_time:136084ms step_avg:88.25ms
+step:1543/1680 train_time:136173ms step_avg:88.25ms
+step:1544/1680 train_time:136261ms step_avg:88.25ms
+step:1545/1680 train_time:136351ms step_avg:88.25ms
+step:1546/1680 train_time:136441ms step_avg:88.25ms
+step:1547/1680 train_time:136530ms step_avg:88.25ms +step:1548/1680 train_time:136619ms step_avg:88.26ms +step:1549/1680 train_time:136710ms step_avg:88.26ms +step:1550/1680 train_time:136799ms step_avg:88.26ms +step:1551/1680 train_time:136888ms step_avg:88.26ms +step:1552/1680 train_time:136976ms step_avg:88.26ms +step:1553/1680 train_time:137065ms step_avg:88.26ms +step:1554/1680 train_time:137154ms step_avg:88.26ms +step:1555/1680 train_time:137243ms step_avg:88.26ms +step:1556/1680 train_time:137333ms step_avg:88.26ms +step:1557/1680 train_time:137423ms step_avg:88.26ms +step:1558/1680 train_time:137512ms step_avg:88.26ms +step:1559/1680 train_time:137601ms step_avg:88.26ms +step:1560/1680 train_time:137691ms step_avg:88.26ms +step:1561/1680 train_time:137781ms step_avg:88.26ms +step:1562/1680 train_time:137870ms step_avg:88.27ms +step:1563/1680 train_time:137959ms step_avg:88.27ms +step:1564/1680 train_time:138048ms step_avg:88.27ms +step:1565/1680 train_time:138136ms step_avg:88.27ms +step:1566/1680 train_time:138225ms step_avg:88.27ms +step:1567/1680 train_time:138314ms step_avg:88.27ms +step:1568/1680 train_time:138403ms step_avg:88.27ms +step:1569/1680 train_time:138492ms step_avg:88.27ms +step:1570/1680 train_time:138582ms step_avg:88.27ms +step:1571/1680 train_time:138671ms step_avg:88.27ms +step:1572/1680 train_time:138760ms step_avg:88.27ms +step:1573/1680 train_time:138849ms step_avg:88.27ms +step:1574/1680 train_time:138938ms step_avg:88.27ms +step:1575/1680 train_time:139028ms step_avg:88.27ms +step:1576/1680 train_time:139117ms step_avg:88.27ms +step:1577/1680 train_time:139206ms step_avg:88.27ms +step:1578/1680 train_time:139294ms step_avg:88.27ms +step:1579/1680 train_time:139384ms step_avg:88.27ms +step:1580/1680 train_time:139473ms step_avg:88.27ms +step:1581/1680 train_time:139563ms step_avg:88.27ms +step:1582/1680 train_time:139652ms step_avg:88.28ms +step:1583/1680 train_time:139740ms step_avg:88.28ms +step:1584/1680 train_time:139830ms step_avg:88.28ms +step:1585/1680 train_time:139920ms step_avg:88.28ms +step:1586/1680 train_time:140009ms step_avg:88.28ms +step:1587/1680 train_time:140098ms step_avg:88.28ms +step:1588/1680 train_time:140186ms step_avg:88.28ms +step:1589/1680 train_time:140275ms step_avg:88.28ms +step:1590/1680 train_time:140365ms step_avg:88.28ms +step:1591/1680 train_time:140454ms step_avg:88.28ms +step:1592/1680 train_time:140544ms step_avg:88.28ms +step:1593/1680 train_time:140633ms step_avg:88.28ms +step:1594/1680 train_time:140723ms step_avg:88.28ms +step:1595/1680 train_time:140812ms step_avg:88.28ms +step:1596/1680 train_time:140901ms step_avg:88.28ms +step:1597/1680 train_time:140989ms step_avg:88.28ms +step:1598/1680 train_time:141078ms step_avg:88.28ms +step:1599/1680 train_time:141167ms step_avg:88.28ms +step:1600/1680 train_time:141256ms step_avg:88.28ms +step:1601/1680 train_time:141345ms step_avg:88.29ms +step:1602/1680 train_time:141434ms step_avg:88.29ms +step:1603/1680 train_time:141523ms step_avg:88.29ms +step:1604/1680 train_time:141612ms step_avg:88.29ms +step:1605/1680 train_time:141702ms step_avg:88.29ms +step:1606/1680 train_time:141791ms step_avg:88.29ms +step:1607/1680 train_time:141880ms step_avg:88.29ms +step:1608/1680 train_time:141969ms step_avg:88.29ms +step:1609/1680 train_time:142058ms step_avg:88.29ms +step:1610/1680 train_time:142148ms step_avg:88.29ms +step:1611/1680 train_time:142236ms step_avg:88.29ms +step:1612/1680 train_time:142326ms step_avg:88.29ms +step:1613/1680 train_time:142415ms step_avg:88.29ms 
+step:1614/1680 train_time:142504ms step_avg:88.29ms +step:1615/1680 train_time:142593ms step_avg:88.29ms +step:1616/1680 train_time:142683ms step_avg:88.29ms +step:1617/1680 train_time:142772ms step_avg:88.29ms +step:1618/1680 train_time:142863ms step_avg:88.30ms +step:1619/1680 train_time:142952ms step_avg:88.30ms +step:1620/1680 train_time:143041ms step_avg:88.30ms +step:1621/1680 train_time:143130ms step_avg:88.30ms +step:1622/1680 train_time:143220ms step_avg:88.30ms +step:1623/1680 train_time:143309ms step_avg:88.30ms +step:1624/1680 train_time:143399ms step_avg:88.30ms +step:1625/1680 train_time:143488ms step_avg:88.30ms +step:1625/1680 val_loss:3.2878 train_time:143579ms step_avg:88.36ms +step:1626/1680 train_time:143597ms step_avg:88.31ms +step:1627/1680 train_time:143672ms step_avg:88.30ms +step:1628/1680 train_time:143765ms step_avg:88.31ms +step:1629/1680 train_time:143854ms step_avg:88.31ms +step:1630/1680 train_time:143943ms step_avg:88.31ms +step:1631/1680 train_time:144032ms step_avg:88.31ms +step:1632/1680 train_time:144120ms step_avg:88.31ms +step:1633/1680 train_time:144209ms step_avg:88.31ms +step:1634/1680 train_time:144296ms step_avg:88.31ms +step:1635/1680 train_time:144385ms step_avg:88.31ms +step:1636/1680 train_time:144473ms step_avg:88.31ms +step:1637/1680 train_time:144566ms step_avg:88.31ms +step:1638/1680 train_time:144657ms step_avg:88.31ms +step:1639/1680 train_time:144748ms step_avg:88.31ms +step:1640/1680 train_time:144838ms step_avg:88.32ms +step:1641/1680 train_time:144928ms step_avg:88.32ms +step:1642/1680 train_time:145016ms step_avg:88.32ms +step:1643/1680 train_time:145105ms step_avg:88.32ms +step:1644/1680 train_time:145194ms step_avg:88.32ms +step:1645/1680 train_time:145282ms step_avg:88.32ms +step:1646/1680 train_time:145371ms step_avg:88.32ms +step:1647/1680 train_time:145459ms step_avg:88.32ms +step:1648/1680 train_time:145549ms step_avg:88.32ms +step:1649/1680 train_time:145640ms step_avg:88.32ms +step:1650/1680 train_time:145730ms step_avg:88.32ms +step:1651/1680 train_time:145819ms step_avg:88.32ms +step:1652/1680 train_time:145909ms step_avg:88.32ms +step:1653/1680 train_time:145999ms step_avg:88.32ms +step:1654/1680 train_time:146087ms step_avg:88.32ms +step:1655/1680 train_time:146176ms step_avg:88.32ms +step:1656/1680 train_time:146265ms step_avg:88.32ms +step:1657/1680 train_time:146354ms step_avg:88.32ms +step:1658/1680 train_time:146442ms step_avg:88.32ms +step:1659/1680 train_time:146532ms step_avg:88.33ms +step:1660/1680 train_time:146621ms step_avg:88.33ms +step:1661/1680 train_time:146711ms step_avg:88.33ms +step:1662/1680 train_time:146801ms step_avg:88.33ms +step:1663/1680 train_time:146891ms step_avg:88.33ms +step:1664/1680 train_time:146981ms step_avg:88.33ms +step:1665/1680 train_time:147069ms step_avg:88.33ms +step:1666/1680 train_time:147158ms step_avg:88.33ms +step:1667/1680 train_time:147246ms step_avg:88.33ms +step:1668/1680 train_time:147335ms step_avg:88.33ms +step:1669/1680 train_time:147424ms step_avg:88.33ms +step:1670/1680 train_time:147513ms step_avg:88.33ms +step:1671/1680 train_time:147602ms step_avg:88.33ms +step:1672/1680 train_time:147692ms step_avg:88.33ms +step:1673/1680 train_time:147782ms step_avg:88.33ms +step:1674/1680 train_time:147871ms step_avg:88.33ms +step:1675/1680 train_time:147960ms step_avg:88.33ms +step:1676/1680 train_time:148049ms step_avg:88.33ms +step:1677/1680 train_time:148138ms step_avg:88.34ms +step:1678/1680 train_time:148227ms step_avg:88.34ms +step:1679/1680 train_time:148316ms 
step_avg:88.34ms +step:1680/1680 train_time:148406ms step_avg:88.34ms +step:1680/1680 val_loss:3.2773 train_time:148497ms step_avg:88.39ms +peak memory allocated: 30760 MiB reserved: 46194 MiB diff --git a/records/092725_BF16CE/13de21d5-e0e9-4dab-b42d-ad13e73bc402.txt b/records/092725_BF16CE/13de21d5-e0e9-4dab-b42d-ad13e73bc402.txt new file mode 100644 index 000000000..6a1c9a66b --- /dev/null +++ b/records/092725_BF16CE/13de21d5-e0e9-4dab-b42d-ad13e73bc402.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+ Empirically, though, small 1D params perform efficiently here: + NS approximately performs a magnitude normalization of the grad, and + this hyper-optimized class has faster execution time than the current impl of Adam for small params. + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized to enable + (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param, 7 padding params) + 2. reduce scatter attn_gate (10 params, 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params, 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive params of type attn reshape before NS + 8. wait on step 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size() == 8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on fewer than 8 GPUs or when experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for the resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1, 10, 16, 16] + assert len(params) == sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx+size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal.
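(Illustrative aside, not part of the recorded script: the scheduling described in the docstring above rests on one piece of sharding arithmetic. Assuming the 8-GPU custom sizing and the group sizes [1, 10, 16, 16] listed there, each group is padded up to a multiple of the world size and split evenly, which is why every rank ends up with a same-shaped stack of matrices for Newton-Schulz.)

for n in [1, 10, 16, 16]:              # params per group
    padded = (n + 8 - 1) // 8 * 8      # pad to a multiple of world_size -> 8, 16, 16, 16
    chunk = padded // 8                # matrices per rank per group     -> 1, 2, 2, 2
    print(f"group of {n}: {padded - n} padding params, {chunk} per GPU")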
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
+ updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + original_shape = batched_update_grads.shape + # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O + module_idx = start_idx if start_idx < len(params) else 0 + if params[module_idx].module == 'attn': + batched_update_grads = batched_update_grads.view(-1, p_example.size(-2), p_example.size(-1) // 4) + v_chunk = newton_schulz_triton(batched_update_grads).view(original_shape) + + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq =
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module='attn' + with torch.no_grad(): + self.qkvo_w.view(4,self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4,self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4,self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4,self.hdim, self.dim)[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module='mlp' + self.c_proj.module='mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
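(Illustrative aside, not part of the recorded script: a tiny worked check of the rounding helper next_multiple_of_n defined above, which is what pads the 50257 GPT-2 tokens mentioned in the comment.)

assert next_multiple_of_n(50257, n=128) == 50304   # 393 * 128: 47 padding rows in the head
assert next_multiple_of_n(50304, n=128) == 50304   # already-aligned values are unchanged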
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the 
data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens = tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter == 5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter += 1 + return starts, ends + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size *
grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1640 # number of iterations to run + iteration_extension: int = 40 # number of iterations to continue training at final cooldown and window size + cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss?
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40 # number of iterations to continue training at the final cooldown and window size
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase the final validation window size, used for the YaRN extension and the short windows, by @classiclarryd
+    ws_long_validate: int = 20 # extend the long windows out even further
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, (nn.Embedding, nn.Linear)):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+gate_params = [p for n, p in model.named_parameters() if "gate" in n]
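+
+# Minimal sanity sketch, assuming the five collections above are meant to
+# partition model.parameters() between the two optimizers below; it would
+# raise if a parameter were double-counted or left unoptimized. Never called
+# by the training run.
+def _check_param_partition():
+    collections = [hidden_matrix_params, embed_params, scalar_params, head_params, gate_params]
+    collected = [id(p) for params in collections for p in params]
+    assert len(collected) == len(set(collected)), "a parameter appears in two groups"
+    assert set(collected) == {id(p) for p in model.parameters()}, "a parameter is left unoptimized"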
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable, then a linear decay to 0.1x over the final cooldown_frac of training
+def get_lr(step: int):
+    x = min(0.9999, step / args.num_iterations)
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+# window size schedule: step through ws_schedule during training, then jump to ws_validate at the very end
+def get_ws(step: int):
+    if step == args.num_iterations + args.iteration_extension:
+        return args.ws_validate // 2, args.ws_validate
+    x = min(step / (1 + args.num_iterations), 0.9999)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws_long = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws_long > ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    elif new_ws_long < ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+
+# re-initialize the state
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del initial_state
+
+########################################
+#        Training and validation       #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+train_steps = args.num_iterations + args.iteration_extension
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    ws_short, ws_long = get_ws(step)
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        if last_step:
+            ws_long = args.ws_long_validate
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = torch.zeros((), device=device, dtype=torch.float32)
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
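+
+    # Worked numbers for the validation loop above, derived from the
+    # Hyperparameters defaults rather than from the run's logs: with
+    # world_size=8 (so grad_accum_steps=1), val_batch_size = 4*64*1024*8 =
+    # 2,097,152 tokens per forward pass, and val_steps = 10,485,760 //
+    # 2,097,152 = 5, so every validation point averages the loss over
+    # exactly args.val_tokens = 10,485,760 tokens.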
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws_short, ws_long).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1) # momentum warmup for muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+
+====================================================================================================
+Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]
+Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Sat Sep 27 12:12:53 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 |
+|-----------------------------------------+------------------------+----------------------+
+| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
+| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
+| | | MIG M. |
+|=========================================+========================+======================|
+| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 |
+| N/A 30C P0 122W / 700W | 5856MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 |
+| N/A 27C P0 119W / 700W | 1518MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 |
+| N/A 24C P0 117W / 700W | 1518MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 |
+| N/A 29C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 |
+| N/A 30C P0 122W / 700W | 1518MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 |
+| N/A 28C P0 115W / 700W | 1518MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 |
+| N/A 30C P0 122W / 700W | 1518MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 |
+| N/A 27C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 153204 C /usr/bin/python 0MiB | +| 0 N/A N/A 153205 C /usr/bin/python 0MiB | +| 0 N/A N/A 153206 C /usr/bin/python 0MiB | +| 0 N/A N/A 153207 C /usr/bin/python 0MiB | +| 0 N/A N/A 153208 C /usr/bin/python 0MiB | +| 0 N/A N/A 153209 C /usr/bin/python 0MiB | +| 0 N/A N/A 153210 C /usr/bin/python 0MiB | +| 0 N/A N/A 153211 C /usr/bin/python 0MiB | +| 1 N/A N/A 153205 C /usr/bin/python 0MiB | +| 2 N/A N/A 153206 C /usr/bin/python 0MiB | +| 3 N/A N/A 153207 C /usr/bin/python 0MiB | +| 4 N/A N/A 153208 C /usr/bin/python 0MiB | +| 5 N/A N/A 153209 C /usr/bin/python 0MiB | +| 6 N/A N/A 153210 C /usr/bin/python 0MiB | +| 7 N/A N/A 153211 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1680 train_time:150ms step_avg:149.83ms +step:2/1680 train_time:170ms step_avg:85.10ms +step:3/1680 train_time:233ms step_avg:77.80ms +step:4/1680 train_time:319ms step_avg:79.70ms +step:5/1680 train_time:405ms step_avg:80.94ms +step:6/1680 train_time:491ms step_avg:81.78ms +step:7/1680 train_time:577ms step_avg:82.44ms +step:8/1680 train_time:664ms step_avg:82.96ms +step:9/1680 train_time:750ms step_avg:83.35ms +step:10/1680 train_time:836ms step_avg:83.64ms +step:11/1680 train_time:923ms step_avg:83.87ms +step:12/1680 train_time:1011ms step_avg:84.22ms +step:13/1680 train_time:1101ms step_avg:84.71ms +step:14/1680 train_time:1192ms step_avg:85.13ms +step:15/1680 train_time:1280ms step_avg:85.33ms +step:16/1680 train_time:1368ms step_avg:85.49ms +step:17/1680 train_time:1454ms step_avg:85.54ms +step:18/1680 train_time:1541ms step_avg:85.63ms +step:19/1680 train_time:1628ms step_avg:85.70ms +step:20/1680 train_time:1715ms step_avg:85.76ms +step:21/1680 train_time:1802ms step_avg:85.80ms +step:22/1680 train_time:1889ms step_avg:85.88ms +step:23/1680 train_time:1976ms step_avg:85.92ms +step:24/1680 train_time:2064ms step_avg:86.01ms +step:25/1680 train_time:2154ms step_avg:86.14ms +step:26/1680 train_time:2242ms step_avg:86.24ms +step:27/1680 train_time:2331ms step_avg:86.34ms +step:28/1680 train_time:2419ms step_avg:86.38ms +step:29/1680 train_time:2506ms step_avg:86.40ms +step:30/1680 train_time:2593ms step_avg:86.42ms +step:31/1680 train_time:2680ms step_avg:86.46ms +step:32/1680 train_time:2767ms step_avg:86.46ms +step:33/1680 train_time:2853ms step_avg:86.46ms +step:34/1680 train_time:2941ms step_avg:86.50ms +step:35/1680 train_time:3030ms step_avg:86.56ms +step:36/1680 train_time:3118ms step_avg:86.61ms +step:37/1680 train_time:3206ms step_avg:86.66ms +step:38/1680 train_time:3294ms step_avg:86.70ms +step:39/1680 train_time:3382ms step_avg:86.73ms +step:40/1680 train_time:3469ms step_avg:86.74ms +step:41/1680 train_time:3556ms step_avg:86.74ms +step:42/1680 train_time:3643ms step_avg:86.74ms +step:43/1680 train_time:3730ms step_avg:86.75ms +step:44/1680 train_time:3817ms step_avg:86.75ms +step:45/1680 train_time:3904ms step_avg:86.75ms +step:46/1680 train_time:3991ms step_avg:86.77ms +step:47/1680 
train_time:4080ms step_avg:86.80ms +step:48/1680 train_time:4168ms step_avg:86.83ms +step:49/1680 train_time:4256ms step_avg:86.85ms +step:50/1680 train_time:4344ms step_avg:86.87ms +step:51/1680 train_time:4431ms step_avg:86.88ms +step:52/1680 train_time:4518ms step_avg:86.88ms +step:53/1680 train_time:4605ms step_avg:86.90ms +step:54/1680 train_time:4693ms step_avg:86.90ms +step:55/1680 train_time:4779ms step_avg:86.90ms +step:56/1680 train_time:4866ms step_avg:86.89ms +step:57/1680 train_time:4953ms step_avg:86.90ms +step:58/1680 train_time:5041ms step_avg:86.91ms +step:59/1680 train_time:5129ms step_avg:86.94ms +step:60/1680 train_time:5217ms step_avg:86.95ms +step:61/1680 train_time:5305ms step_avg:86.96ms +step:62/1680 train_time:5393ms step_avg:86.98ms +step:63/1680 train_time:5480ms step_avg:86.99ms +step:64/1680 train_time:5567ms step_avg:86.98ms +step:65/1680 train_time:5654ms step_avg:86.99ms +step:66/1680 train_time:5742ms step_avg:87.00ms +step:67/1680 train_time:5829ms step_avg:87.00ms +step:68/1680 train_time:5916ms step_avg:87.00ms +step:69/1680 train_time:6003ms step_avg:87.00ms +step:70/1680 train_time:6091ms step_avg:87.01ms +step:71/1680 train_time:6179ms step_avg:87.03ms +step:72/1680 train_time:6267ms step_avg:87.04ms +step:73/1680 train_time:6355ms step_avg:87.05ms +step:74/1680 train_time:6442ms step_avg:87.06ms +step:75/1680 train_time:6529ms step_avg:87.06ms +step:76/1680 train_time:6616ms step_avg:87.06ms +step:77/1680 train_time:6703ms step_avg:87.06ms +step:78/1680 train_time:6791ms step_avg:87.06ms +step:79/1680 train_time:6878ms step_avg:87.07ms +step:80/1680 train_time:6966ms step_avg:87.07ms +step:81/1680 train_time:7053ms step_avg:87.08ms +step:82/1680 train_time:7141ms step_avg:87.08ms +step:83/1680 train_time:7228ms step_avg:87.09ms +step:84/1680 train_time:7315ms step_avg:87.09ms +step:85/1680 train_time:7403ms step_avg:87.09ms +step:86/1680 train_time:7490ms step_avg:87.09ms +step:87/1680 train_time:7577ms step_avg:87.09ms +step:88/1680 train_time:7665ms step_avg:87.10ms +step:89/1680 train_time:7752ms step_avg:87.10ms +step:90/1680 train_time:7840ms step_avg:87.11ms +step:91/1680 train_time:7927ms step_avg:87.11ms +step:92/1680 train_time:8014ms step_avg:87.11ms +step:93/1680 train_time:8101ms step_avg:87.11ms +step:94/1680 train_time:8189ms step_avg:87.12ms +step:95/1680 train_time:8276ms step_avg:87.12ms +step:96/1680 train_time:8364ms step_avg:87.12ms +step:97/1680 train_time:8452ms step_avg:87.13ms +step:98/1680 train_time:8539ms step_avg:87.13ms +step:99/1680 train_time:8626ms step_avg:87.13ms +step:100/1680 train_time:8713ms step_avg:87.13ms +step:101/1680 train_time:8800ms step_avg:87.13ms +step:102/1680 train_time:8888ms step_avg:87.14ms +step:103/1680 train_time:8974ms step_avg:87.13ms +step:104/1680 train_time:9062ms step_avg:87.14ms +step:105/1680 train_time:9151ms step_avg:87.15ms +step:106/1680 train_time:9238ms step_avg:87.15ms +step:107/1680 train_time:9326ms step_avg:87.16ms +step:108/1680 train_time:9413ms step_avg:87.15ms +step:109/1680 train_time:9500ms step_avg:87.16ms +step:110/1680 train_time:9588ms step_avg:87.17ms +step:111/1680 train_time:9675ms step_avg:87.16ms +step:112/1680 train_time:9762ms step_avg:87.16ms +step:113/1680 train_time:9849ms step_avg:87.16ms +step:114/1680 train_time:9936ms step_avg:87.16ms +step:115/1680 train_time:10023ms step_avg:87.16ms +step:116/1680 train_time:10110ms step_avg:87.16ms +step:117/1680 train_time:10197ms step_avg:87.16ms +step:118/1680 train_time:10285ms step_avg:87.16ms +step:119/1680 
train_time:10373ms step_avg:87.17ms +step:120/1680 train_time:10460ms step_avg:87.17ms +step:121/1680 train_time:10548ms step_avg:87.17ms +step:122/1680 train_time:10635ms step_avg:87.17ms +step:123/1680 train_time:10722ms step_avg:87.17ms +step:124/1680 train_time:10809ms step_avg:87.17ms +step:125/1680 train_time:10896ms step_avg:87.17ms +step:125/1680 val_loss:4.3221 train_time:10985ms step_avg:87.88ms +step:126/1680 train_time:11004ms step_avg:87.34ms +step:127/1680 train_time:11075ms step_avg:87.21ms +step:128/1680 train_time:11172ms step_avg:87.28ms +step:129/1680 train_time:11263ms step_avg:87.31ms +step:130/1680 train_time:11350ms step_avg:87.31ms +step:131/1680 train_time:11437ms step_avg:87.30ms +step:132/1680 train_time:11523ms step_avg:87.29ms +step:133/1680 train_time:11609ms step_avg:87.29ms +step:134/1680 train_time:11695ms step_avg:87.28ms +step:135/1680 train_time:11781ms step_avg:87.27ms +step:136/1680 train_time:11867ms step_avg:87.26ms +step:137/1680 train_time:11955ms step_avg:87.26ms +step:138/1680 train_time:12043ms step_avg:87.27ms +step:139/1680 train_time:12132ms step_avg:87.28ms +step:140/1680 train_time:12222ms step_avg:87.30ms +step:141/1680 train_time:12310ms step_avg:87.30ms +step:142/1680 train_time:12397ms step_avg:87.30ms +step:143/1680 train_time:12484ms step_avg:87.30ms +step:144/1680 train_time:12571ms step_avg:87.30ms +step:145/1680 train_time:12657ms step_avg:87.29ms +step:146/1680 train_time:12744ms step_avg:87.29ms +step:147/1680 train_time:12830ms step_avg:87.28ms +step:148/1680 train_time:12917ms step_avg:87.27ms +step:149/1680 train_time:13004ms step_avg:87.27ms +step:150/1680 train_time:13092ms step_avg:87.28ms +step:151/1680 train_time:13181ms step_avg:87.29ms +step:152/1680 train_time:13269ms step_avg:87.30ms +step:153/1680 train_time:13357ms step_avg:87.30ms +step:154/1680 train_time:13445ms step_avg:87.31ms +step:155/1680 train_time:13532ms step_avg:87.30ms +step:156/1680 train_time:13620ms step_avg:87.30ms +step:157/1680 train_time:13707ms step_avg:87.30ms +step:158/1680 train_time:13793ms step_avg:87.30ms +step:159/1680 train_time:13880ms step_avg:87.30ms +step:160/1680 train_time:13967ms step_avg:87.29ms +step:161/1680 train_time:14055ms step_avg:87.30ms +step:162/1680 train_time:14142ms step_avg:87.30ms +step:163/1680 train_time:14230ms step_avg:87.30ms +step:164/1680 train_time:14318ms step_avg:87.30ms +step:165/1680 train_time:14405ms step_avg:87.30ms +step:166/1680 train_time:14492ms step_avg:87.30ms +step:167/1680 train_time:14579ms step_avg:87.30ms +step:168/1680 train_time:14666ms step_avg:87.30ms +step:169/1680 train_time:14753ms step_avg:87.29ms +step:170/1680 train_time:14840ms step_avg:87.29ms +step:171/1680 train_time:14927ms step_avg:87.29ms +step:172/1680 train_time:15014ms step_avg:87.29ms +step:173/1680 train_time:15101ms step_avg:87.29ms +step:174/1680 train_time:15189ms step_avg:87.29ms +step:175/1680 train_time:15276ms step_avg:87.29ms +step:176/1680 train_time:15364ms step_avg:87.30ms +step:177/1680 train_time:15452ms step_avg:87.30ms +step:178/1680 train_time:15539ms step_avg:87.30ms +step:179/1680 train_time:15626ms step_avg:87.30ms +step:180/1680 train_time:15713ms step_avg:87.29ms +step:181/1680 train_time:15800ms step_avg:87.29ms +step:182/1680 train_time:15887ms step_avg:87.29ms +step:183/1680 train_time:15973ms step_avg:87.29ms +step:184/1680 train_time:16060ms step_avg:87.28ms +step:185/1680 train_time:16148ms step_avg:87.29ms +step:186/1680 train_time:16235ms step_avg:87.28ms +step:187/1680 train_time:16322ms 
step_avg:87.28ms +step:188/1680 train_time:16409ms step_avg:87.28ms +step:189/1680 train_time:16497ms step_avg:87.29ms +step:190/1680 train_time:16585ms step_avg:87.29ms +step:191/1680 train_time:16672ms step_avg:87.29ms +step:192/1680 train_time:16760ms step_avg:87.29ms +step:193/1680 train_time:16847ms step_avg:87.29ms +step:194/1680 train_time:16934ms step_avg:87.29ms +step:195/1680 train_time:17021ms step_avg:87.29ms +step:196/1680 train_time:17108ms step_avg:87.28ms +step:197/1680 train_time:17195ms step_avg:87.29ms +step:198/1680 train_time:17282ms step_avg:87.28ms +step:199/1680 train_time:17369ms step_avg:87.28ms +step:200/1680 train_time:17457ms step_avg:87.28ms +step:201/1680 train_time:17544ms step_avg:87.29ms +step:202/1680 train_time:17631ms step_avg:87.28ms +step:203/1680 train_time:17719ms step_avg:87.29ms +step:204/1680 train_time:17807ms step_avg:87.29ms +step:205/1680 train_time:17894ms step_avg:87.29ms +step:206/1680 train_time:17981ms step_avg:87.29ms +step:207/1680 train_time:18068ms step_avg:87.28ms +step:208/1680 train_time:18155ms step_avg:87.28ms +step:209/1680 train_time:18242ms step_avg:87.28ms +step:210/1680 train_time:18330ms step_avg:87.28ms +step:211/1680 train_time:18417ms step_avg:87.28ms +step:212/1680 train_time:18505ms step_avg:87.29ms +step:213/1680 train_time:18593ms step_avg:87.29ms +step:214/1680 train_time:18681ms step_avg:87.29ms +step:215/1680 train_time:18768ms step_avg:87.29ms +step:216/1680 train_time:18854ms step_avg:87.29ms +step:217/1680 train_time:18942ms step_avg:87.29ms +step:218/1680 train_time:19029ms step_avg:87.29ms +step:219/1680 train_time:19116ms step_avg:87.29ms +step:220/1680 train_time:19204ms step_avg:87.29ms +step:221/1680 train_time:19291ms step_avg:87.29ms +step:222/1680 train_time:19378ms step_avg:87.29ms +step:223/1680 train_time:19465ms step_avg:87.29ms +step:224/1680 train_time:19553ms step_avg:87.29ms +step:225/1680 train_time:19641ms step_avg:87.29ms +step:226/1680 train_time:19728ms step_avg:87.29ms +step:227/1680 train_time:19815ms step_avg:87.29ms +step:228/1680 train_time:19902ms step_avg:87.29ms +step:229/1680 train_time:19989ms step_avg:87.29ms +step:230/1680 train_time:20077ms step_avg:87.29ms +step:231/1680 train_time:20164ms step_avg:87.29ms +step:232/1680 train_time:20251ms step_avg:87.29ms +step:233/1680 train_time:20338ms step_avg:87.29ms +step:234/1680 train_time:20425ms step_avg:87.29ms +step:235/1680 train_time:20512ms step_avg:87.28ms +step:236/1680 train_time:20599ms step_avg:87.29ms +step:237/1680 train_time:20687ms step_avg:87.29ms +step:238/1680 train_time:20774ms step_avg:87.29ms +step:239/1680 train_time:20861ms step_avg:87.28ms +step:240/1680 train_time:20948ms step_avg:87.28ms +step:241/1680 train_time:21035ms step_avg:87.28ms +step:242/1680 train_time:21121ms step_avg:87.28ms +step:243/1680 train_time:21208ms step_avg:87.28ms +step:244/1680 train_time:21295ms step_avg:87.28ms +step:245/1680 train_time:21382ms step_avg:87.27ms +step:246/1680 train_time:21469ms step_avg:87.27ms +step:247/1680 train_time:21557ms step_avg:87.27ms +step:248/1680 train_time:21645ms step_avg:87.28ms +step:249/1680 train_time:21732ms step_avg:87.28ms +step:250/1680 train_time:21819ms step_avg:87.28ms +step:250/1680 val_loss:3.9705 train_time:21908ms step_avg:87.63ms +step:251/1680 train_time:21927ms step_avg:87.36ms +step:252/1680 train_time:21998ms step_avg:87.29ms +step:253/1680 train_time:22090ms step_avg:87.31ms +step:254/1680 train_time:22179ms step_avg:87.32ms +step:255/1680 train_time:22266ms step_avg:87.32ms 
+step:256/1680 train_time:22352ms step_avg:87.31ms +step:257/1680 train_time:22440ms step_avg:87.31ms +step:258/1680 train_time:22526ms step_avg:87.31ms +step:259/1680 train_time:22612ms step_avg:87.30ms +step:260/1680 train_time:22698ms step_avg:87.30ms +step:261/1680 train_time:22784ms step_avg:87.29ms +step:262/1680 train_time:22870ms step_avg:87.29ms +step:263/1680 train_time:22958ms step_avg:87.29ms +step:264/1680 train_time:23048ms step_avg:87.30ms +step:265/1680 train_time:23136ms step_avg:87.31ms +step:266/1680 train_time:23226ms step_avg:87.31ms +step:267/1680 train_time:23313ms step_avg:87.32ms +step:268/1680 train_time:23401ms step_avg:87.32ms +step:269/1680 train_time:23488ms step_avg:87.32ms +step:270/1680 train_time:23575ms step_avg:87.31ms +step:271/1680 train_time:23662ms step_avg:87.31ms +step:272/1680 train_time:23748ms step_avg:87.31ms +step:273/1680 train_time:23835ms step_avg:87.31ms +step:274/1680 train_time:23923ms step_avg:87.31ms +step:275/1680 train_time:24010ms step_avg:87.31ms +step:276/1680 train_time:24099ms step_avg:87.31ms +step:277/1680 train_time:24186ms step_avg:87.32ms +step:278/1680 train_time:24274ms step_avg:87.32ms +step:279/1680 train_time:24362ms step_avg:87.32ms +step:280/1680 train_time:24449ms step_avg:87.32ms +step:281/1680 train_time:24536ms step_avg:87.32ms +step:282/1680 train_time:24623ms step_avg:87.32ms +step:283/1680 train_time:24709ms step_avg:87.31ms +step:284/1680 train_time:24796ms step_avg:87.31ms +step:285/1680 train_time:24883ms step_avg:87.31ms +step:286/1680 train_time:24970ms step_avg:87.31ms +step:287/1680 train_time:25058ms step_avg:87.31ms +step:288/1680 train_time:25145ms step_avg:87.31ms +step:289/1680 train_time:25233ms step_avg:87.31ms +step:290/1680 train_time:25320ms step_avg:87.31ms +step:291/1680 train_time:25407ms step_avg:87.31ms +step:292/1680 train_time:25494ms step_avg:87.31ms +step:293/1680 train_time:25581ms step_avg:87.31ms +step:294/1680 train_time:25668ms step_avg:87.31ms +step:295/1680 train_time:25754ms step_avg:87.30ms +step:296/1680 train_time:25842ms step_avg:87.30ms +step:297/1680 train_time:25928ms step_avg:87.30ms +step:298/1680 train_time:26016ms step_avg:87.30ms +step:299/1680 train_time:26104ms step_avg:87.30ms +step:300/1680 train_time:26191ms step_avg:87.30ms +step:301/1680 train_time:26279ms step_avg:87.31ms +step:302/1680 train_time:26366ms step_avg:87.31ms +step:303/1680 train_time:26454ms step_avg:87.31ms +step:304/1680 train_time:26541ms step_avg:87.31ms +step:305/1680 train_time:26628ms step_avg:87.30ms +step:306/1680 train_time:26715ms step_avg:87.30ms +step:307/1680 train_time:26802ms step_avg:87.30ms +step:308/1680 train_time:26889ms step_avg:87.30ms +step:309/1680 train_time:26976ms step_avg:87.30ms +step:310/1680 train_time:27063ms step_avg:87.30ms +step:311/1680 train_time:27151ms step_avg:87.30ms +step:312/1680 train_time:27239ms step_avg:87.31ms +step:313/1680 train_time:27327ms step_avg:87.31ms +step:314/1680 train_time:27414ms step_avg:87.30ms +step:315/1680 train_time:27501ms step_avg:87.30ms +step:316/1680 train_time:27587ms step_avg:87.30ms +step:317/1680 train_time:27675ms step_avg:87.30ms +step:318/1680 train_time:27762ms step_avg:87.30ms +step:319/1680 train_time:27849ms step_avg:87.30ms +step:320/1680 train_time:27936ms step_avg:87.30ms +step:321/1680 train_time:28024ms step_avg:87.30ms +step:322/1680 train_time:28111ms step_avg:87.30ms +step:323/1680 train_time:28199ms step_avg:87.30ms +step:324/1680 train_time:28287ms step_avg:87.30ms +step:325/1680 train_time:28375ms 
step_avg:87.31ms +step:326/1680 train_time:28462ms step_avg:87.31ms +step:327/1680 train_time:28550ms step_avg:87.31ms +step:328/1680 train_time:28637ms step_avg:87.31ms +step:329/1680 train_time:28724ms step_avg:87.31ms +step:330/1680 train_time:28812ms step_avg:87.31ms +step:331/1680 train_time:28899ms step_avg:87.31ms +step:332/1680 train_time:28986ms step_avg:87.31ms +step:333/1680 train_time:29073ms step_avg:87.31ms +step:334/1680 train_time:29160ms step_avg:87.31ms +step:335/1680 train_time:29248ms step_avg:87.31ms +step:336/1680 train_time:29335ms step_avg:87.31ms +step:337/1680 train_time:29423ms step_avg:87.31ms +step:338/1680 train_time:29510ms step_avg:87.31ms +step:339/1680 train_time:29597ms step_avg:87.31ms +step:340/1680 train_time:29685ms step_avg:87.31ms +step:341/1680 train_time:29771ms step_avg:87.31ms +step:342/1680 train_time:29858ms step_avg:87.30ms +step:343/1680 train_time:29945ms step_avg:87.30ms +step:344/1680 train_time:30033ms step_avg:87.30ms +step:345/1680 train_time:30120ms step_avg:87.30ms +step:346/1680 train_time:30208ms step_avg:87.30ms +step:347/1680 train_time:30295ms step_avg:87.31ms +step:348/1680 train_time:30383ms step_avg:87.31ms +step:349/1680 train_time:30471ms step_avg:87.31ms +step:350/1680 train_time:30558ms step_avg:87.31ms +step:351/1680 train_time:30646ms step_avg:87.31ms +step:352/1680 train_time:30733ms step_avg:87.31ms +step:353/1680 train_time:30819ms step_avg:87.31ms +step:354/1680 train_time:30906ms step_avg:87.31ms +step:355/1680 train_time:30993ms step_avg:87.31ms +step:356/1680 train_time:31080ms step_avg:87.30ms +step:357/1680 train_time:31167ms step_avg:87.30ms +step:358/1680 train_time:31254ms step_avg:87.30ms +step:359/1680 train_time:31343ms step_avg:87.31ms +step:360/1680 train_time:31429ms step_avg:87.30ms +step:361/1680 train_time:31517ms step_avg:87.30ms +step:362/1680 train_time:31605ms step_avg:87.31ms +step:363/1680 train_time:31692ms step_avg:87.31ms +step:364/1680 train_time:31780ms step_avg:87.31ms +step:365/1680 train_time:31867ms step_avg:87.31ms +step:366/1680 train_time:31955ms step_avg:87.31ms +step:367/1680 train_time:32042ms step_avg:87.31ms +step:368/1680 train_time:32129ms step_avg:87.31ms +step:369/1680 train_time:32218ms step_avg:87.31ms +step:370/1680 train_time:32305ms step_avg:87.31ms +step:371/1680 train_time:32392ms step_avg:87.31ms +step:372/1680 train_time:32479ms step_avg:87.31ms +step:373/1680 train_time:32566ms step_avg:87.31ms +step:374/1680 train_time:32653ms step_avg:87.31ms +step:375/1680 train_time:32739ms step_avg:87.30ms +step:375/1680 val_loss:3.8197 train_time:32828ms step_avg:87.54ms +step:376/1680 train_time:32846ms step_avg:87.36ms +step:377/1680 train_time:32918ms step_avg:87.31ms +step:378/1680 train_time:33010ms step_avg:87.33ms +step:379/1680 train_time:33099ms step_avg:87.33ms +step:380/1680 train_time:33188ms step_avg:87.34ms +step:381/1680 train_time:33275ms step_avg:87.34ms +step:382/1680 train_time:33361ms step_avg:87.33ms +step:383/1680 train_time:33447ms step_avg:87.33ms +step:384/1680 train_time:33534ms step_avg:87.33ms +step:385/1680 train_time:33621ms step_avg:87.33ms +step:386/1680 train_time:33707ms step_avg:87.32ms +step:387/1680 train_time:33795ms step_avg:87.32ms +step:388/1680 train_time:33882ms step_avg:87.33ms +step:389/1680 train_time:33971ms step_avg:87.33ms +step:390/1680 train_time:34060ms step_avg:87.33ms +step:391/1680 train_time:34148ms step_avg:87.33ms +step:392/1680 train_time:34235ms step_avg:87.33ms +step:393/1680 train_time:34322ms step_avg:87.33ms 
+step:394/1680 train_time:34409ms step_avg:87.33ms +step:395/1680 train_time:34496ms step_avg:87.33ms +step:396/1680 train_time:34583ms step_avg:87.33ms +step:397/1680 train_time:34670ms step_avg:87.33ms +step:398/1680 train_time:34757ms step_avg:87.33ms +step:399/1680 train_time:34844ms step_avg:87.33ms +step:400/1680 train_time:34931ms step_avg:87.33ms +step:401/1680 train_time:35020ms step_avg:87.33ms +step:402/1680 train_time:35107ms step_avg:87.33ms +step:403/1680 train_time:35194ms step_avg:87.33ms +step:404/1680 train_time:35281ms step_avg:87.33ms +step:405/1680 train_time:35368ms step_avg:87.33ms +step:406/1680 train_time:35455ms step_avg:87.33ms +step:407/1680 train_time:35542ms step_avg:87.33ms +step:408/1680 train_time:35629ms step_avg:87.33ms +step:409/1680 train_time:35716ms step_avg:87.33ms +step:410/1680 train_time:35803ms step_avg:87.32ms +step:411/1680 train_time:35891ms step_avg:87.33ms +step:412/1680 train_time:35980ms step_avg:87.33ms +step:413/1680 train_time:36067ms step_avg:87.33ms +step:414/1680 train_time:36155ms step_avg:87.33ms +step:415/1680 train_time:36242ms step_avg:87.33ms +step:416/1680 train_time:36329ms step_avg:87.33ms +step:417/1680 train_time:36417ms step_avg:87.33ms +step:418/1680 train_time:36504ms step_avg:87.33ms +step:419/1680 train_time:36591ms step_avg:87.33ms +step:420/1680 train_time:36679ms step_avg:87.33ms +step:421/1680 train_time:36766ms step_avg:87.33ms +step:422/1680 train_time:36854ms step_avg:87.33ms +step:423/1680 train_time:36941ms step_avg:87.33ms +step:424/1680 train_time:37029ms step_avg:87.33ms +step:425/1680 train_time:37117ms step_avg:87.33ms +step:426/1680 train_time:37204ms step_avg:87.33ms +step:427/1680 train_time:37291ms step_avg:87.33ms +step:428/1680 train_time:37378ms step_avg:87.33ms +step:429/1680 train_time:37465ms step_avg:87.33ms +step:430/1680 train_time:37552ms step_avg:87.33ms +step:431/1680 train_time:37639ms step_avg:87.33ms +step:432/1680 train_time:37725ms step_avg:87.33ms +step:433/1680 train_time:37813ms step_avg:87.33ms +step:434/1680 train_time:37900ms step_avg:87.33ms +step:435/1680 train_time:37988ms step_avg:87.33ms +step:436/1680 train_time:38075ms step_avg:87.33ms +step:437/1680 train_time:38162ms step_avg:87.33ms +step:438/1680 train_time:38249ms step_avg:87.33ms +step:439/1680 train_time:38337ms step_avg:87.33ms +step:440/1680 train_time:38424ms step_avg:87.33ms +step:441/1680 train_time:38511ms step_avg:87.33ms +step:442/1680 train_time:38598ms step_avg:87.33ms +step:443/1680 train_time:38685ms step_avg:87.32ms +step:444/1680 train_time:38773ms step_avg:87.33ms +step:445/1680 train_time:38860ms step_avg:87.33ms +step:446/1680 train_time:38947ms step_avg:87.32ms +step:447/1680 train_time:39034ms step_avg:87.32ms +step:448/1680 train_time:39122ms step_avg:87.32ms +step:449/1680 train_time:39209ms step_avg:87.32ms +step:450/1680 train_time:39296ms step_avg:87.32ms +step:451/1680 train_time:39383ms step_avg:87.32ms +step:452/1680 train_time:39470ms step_avg:87.32ms +step:453/1680 train_time:39558ms step_avg:87.32ms +step:454/1680 train_time:39644ms step_avg:87.32ms +step:455/1680 train_time:39732ms step_avg:87.32ms +step:456/1680 train_time:39820ms step_avg:87.32ms +step:457/1680 train_time:39908ms step_avg:87.33ms +step:458/1680 train_time:39995ms step_avg:87.33ms +step:459/1680 train_time:40083ms step_avg:87.33ms +step:460/1680 train_time:40170ms step_avg:87.33ms +step:461/1680 train_time:40257ms step_avg:87.33ms +step:462/1680 train_time:40344ms step_avg:87.33ms +step:463/1680 train_time:40432ms 
step_avg:87.33ms +step:464/1680 train_time:40519ms step_avg:87.33ms +step:465/1680 train_time:40606ms step_avg:87.32ms +step:466/1680 train_time:40693ms step_avg:87.32ms +step:467/1680 train_time:40780ms step_avg:87.32ms +step:468/1680 train_time:40867ms step_avg:87.32ms +step:469/1680 train_time:40955ms step_avg:87.32ms +step:470/1680 train_time:41042ms step_avg:87.32ms +step:471/1680 train_time:41129ms step_avg:87.32ms +step:472/1680 train_time:41217ms step_avg:87.32ms +step:473/1680 train_time:41304ms step_avg:87.32ms +step:474/1680 train_time:41392ms step_avg:87.33ms +step:475/1680 train_time:41480ms step_avg:87.33ms +step:476/1680 train_time:41566ms step_avg:87.32ms +step:477/1680 train_time:41654ms step_avg:87.32ms +step:478/1680 train_time:41741ms step_avg:87.32ms +step:479/1680 train_time:41828ms step_avg:87.32ms +step:480/1680 train_time:41915ms step_avg:87.32ms +step:481/1680 train_time:42002ms step_avg:87.32ms +step:482/1680 train_time:42090ms step_avg:87.32ms +step:483/1680 train_time:42177ms step_avg:87.32ms +step:484/1680 train_time:42264ms step_avg:87.32ms +step:485/1680 train_time:42352ms step_avg:87.32ms +step:486/1680 train_time:42439ms step_avg:87.32ms +step:487/1680 train_time:42525ms step_avg:87.32ms +step:488/1680 train_time:42613ms step_avg:87.32ms +step:489/1680 train_time:42701ms step_avg:87.32ms +step:490/1680 train_time:42788ms step_avg:87.32ms +step:491/1680 train_time:42875ms step_avg:87.32ms +step:492/1680 train_time:42962ms step_avg:87.32ms +step:493/1680 train_time:43049ms step_avg:87.32ms +step:494/1680 train_time:43137ms step_avg:87.32ms +step:495/1680 train_time:43224ms step_avg:87.32ms +step:496/1680 train_time:43311ms step_avg:87.32ms +step:497/1680 train_time:43399ms step_avg:87.32ms +step:498/1680 train_time:43486ms step_avg:87.32ms +step:499/1680 train_time:43574ms step_avg:87.32ms +step:500/1680 train_time:43661ms step_avg:87.32ms +step:500/1680 val_loss:3.7186 train_time:43750ms step_avg:87.50ms +step:501/1680 train_time:43768ms step_avg:87.36ms +step:502/1680 train_time:43838ms step_avg:87.33ms +step:503/1680 train_time:43932ms step_avg:87.34ms +step:504/1680 train_time:44024ms step_avg:87.35ms +step:505/1680 train_time:44110ms step_avg:87.35ms +step:506/1680 train_time:44198ms step_avg:87.35ms +step:507/1680 train_time:44283ms step_avg:87.34ms +step:508/1680 train_time:44370ms step_avg:87.34ms +step:509/1680 train_time:44456ms step_avg:87.34ms +step:510/1680 train_time:44542ms step_avg:87.34ms +step:511/1680 train_time:44628ms step_avg:87.33ms +step:512/1680 train_time:44716ms step_avg:87.34ms +step:513/1680 train_time:44805ms step_avg:87.34ms +step:514/1680 train_time:44894ms step_avg:87.34ms +step:515/1680 train_time:44982ms step_avg:87.34ms +step:516/1680 train_time:45070ms step_avg:87.35ms +step:517/1680 train_time:45158ms step_avg:87.35ms +step:518/1680 train_time:45245ms step_avg:87.34ms +step:519/1680 train_time:45331ms step_avg:87.34ms +step:520/1680 train_time:45417ms step_avg:87.34ms +step:521/1680 train_time:45504ms step_avg:87.34ms +step:522/1680 train_time:45590ms step_avg:87.34ms +step:523/1680 train_time:45677ms step_avg:87.34ms +step:524/1680 train_time:45765ms step_avg:87.34ms +step:525/1680 train_time:45855ms step_avg:87.34ms +step:526/1680 train_time:45943ms step_avg:87.34ms +step:527/1680 train_time:46032ms step_avg:87.35ms +step:528/1680 train_time:46119ms step_avg:87.35ms +step:529/1680 train_time:46206ms step_avg:87.35ms +step:530/1680 train_time:46293ms step_avg:87.35ms +step:531/1680 train_time:46380ms step_avg:87.34ms 
+step:532/1680 train_time:46467ms step_avg:87.34ms +step:533/1680 train_time:46553ms step_avg:87.34ms +step:534/1680 train_time:46640ms step_avg:87.34ms +step:535/1680 train_time:46727ms step_avg:87.34ms +step:536/1680 train_time:46815ms step_avg:87.34ms +step:537/1680 train_time:46903ms step_avg:87.34ms +step:538/1680 train_time:46991ms step_avg:87.34ms +step:539/1680 train_time:47079ms step_avg:87.34ms +step:540/1680 train_time:47166ms step_avg:87.34ms +step:541/1680 train_time:47254ms step_avg:87.35ms +step:542/1680 train_time:47341ms step_avg:87.34ms +step:543/1680 train_time:47428ms step_avg:87.34ms +step:544/1680 train_time:47514ms step_avg:87.34ms +step:545/1680 train_time:47601ms step_avg:87.34ms +step:546/1680 train_time:47688ms step_avg:87.34ms +step:547/1680 train_time:47775ms step_avg:87.34ms +step:548/1680 train_time:47863ms step_avg:87.34ms +step:549/1680 train_time:47953ms step_avg:87.35ms +step:550/1680 train_time:48041ms step_avg:87.35ms +step:551/1680 train_time:48130ms step_avg:87.35ms +step:552/1680 train_time:48218ms step_avg:87.35ms +step:553/1680 train_time:48307ms step_avg:87.35ms +step:554/1680 train_time:48395ms step_avg:87.36ms +step:555/1680 train_time:48483ms step_avg:87.36ms +step:556/1680 train_time:48571ms step_avg:87.36ms +step:557/1680 train_time:48660ms step_avg:87.36ms +step:558/1680 train_time:48749ms step_avg:87.36ms +step:559/1680 train_time:48838ms step_avg:87.37ms +step:560/1680 train_time:48927ms step_avg:87.37ms +step:561/1680 train_time:49016ms step_avg:87.37ms +step:562/1680 train_time:49105ms step_avg:87.38ms +step:563/1680 train_time:49193ms step_avg:87.38ms +step:564/1680 train_time:49281ms step_avg:87.38ms +step:565/1680 train_time:49369ms step_avg:87.38ms +step:566/1680 train_time:49457ms step_avg:87.38ms +step:567/1680 train_time:49545ms step_avg:87.38ms +step:568/1680 train_time:49634ms step_avg:87.38ms +step:569/1680 train_time:49723ms step_avg:87.39ms +step:570/1680 train_time:49811ms step_avg:87.39ms +step:571/1680 train_time:49899ms step_avg:87.39ms +step:572/1680 train_time:49989ms step_avg:87.39ms +step:573/1680 train_time:50077ms step_avg:87.39ms +step:574/1680 train_time:50166ms step_avg:87.40ms +step:575/1680 train_time:50255ms step_avg:87.40ms +step:576/1680 train_time:50342ms step_avg:87.40ms +step:577/1680 train_time:50431ms step_avg:87.40ms +step:578/1680 train_time:50519ms step_avg:87.40ms +step:579/1680 train_time:50608ms step_avg:87.41ms +step:580/1680 train_time:50696ms step_avg:87.41ms +step:581/1680 train_time:50785ms step_avg:87.41ms +step:582/1680 train_time:50874ms step_avg:87.41ms +step:583/1680 train_time:50962ms step_avg:87.41ms +step:584/1680 train_time:51050ms step_avg:87.41ms +step:585/1680 train_time:51138ms step_avg:87.42ms +step:586/1680 train_time:51227ms step_avg:87.42ms +step:587/1680 train_time:51315ms step_avg:87.42ms +step:588/1680 train_time:51404ms step_avg:87.42ms +step:589/1680 train_time:51492ms step_avg:87.42ms +step:590/1680 train_time:51580ms step_avg:87.42ms +step:591/1680 train_time:51668ms step_avg:87.43ms +step:592/1680 train_time:51757ms step_avg:87.43ms +step:593/1680 train_time:51845ms step_avg:87.43ms +step:594/1680 train_time:51934ms step_avg:87.43ms +step:595/1680 train_time:52022ms step_avg:87.43ms +step:596/1680 train_time:52111ms step_avg:87.44ms +step:597/1680 train_time:52199ms step_avg:87.44ms +step:598/1680 train_time:52288ms step_avg:87.44ms +step:599/1680 train_time:52376ms step_avg:87.44ms +step:600/1680 train_time:52464ms step_avg:87.44ms +step:601/1680 train_time:52553ms 
step_avg:87.44ms +step:602/1680 train_time:52641ms step_avg:87.44ms +step:603/1680 train_time:52729ms step_avg:87.44ms +step:604/1680 train_time:52818ms step_avg:87.45ms +step:605/1680 train_time:52906ms step_avg:87.45ms +step:606/1680 train_time:52995ms step_avg:87.45ms +step:607/1680 train_time:53083ms step_avg:87.45ms +step:608/1680 train_time:53172ms step_avg:87.45ms +step:609/1680 train_time:53260ms step_avg:87.46ms +step:610/1680 train_time:53348ms step_avg:87.46ms +step:611/1680 train_time:53436ms step_avg:87.46ms +step:612/1680 train_time:53525ms step_avg:87.46ms +step:613/1680 train_time:53614ms step_avg:87.46ms +step:614/1680 train_time:53702ms step_avg:87.46ms +step:615/1680 train_time:53790ms step_avg:87.46ms +step:616/1680 train_time:53878ms step_avg:87.46ms +step:617/1680 train_time:53966ms step_avg:87.47ms +step:618/1680 train_time:54054ms step_avg:87.47ms +step:619/1680 train_time:54143ms step_avg:87.47ms +step:620/1680 train_time:54232ms step_avg:87.47ms +step:621/1680 train_time:54320ms step_avg:87.47ms +step:622/1680 train_time:54409ms step_avg:87.47ms +step:623/1680 train_time:54497ms step_avg:87.47ms +step:624/1680 train_time:54585ms step_avg:87.48ms +step:625/1680 train_time:54673ms step_avg:87.48ms +step:625/1680 val_loss:3.6190 train_time:54763ms step_avg:87.62ms +step:626/1680 train_time:54783ms step_avg:87.51ms +step:627/1680 train_time:54851ms step_avg:87.48ms +step:628/1680 train_time:54940ms step_avg:87.48ms +step:629/1680 train_time:55032ms step_avg:87.49ms +step:630/1680 train_time:55121ms step_avg:87.49ms +step:631/1680 train_time:55207ms step_avg:87.49ms +step:632/1680 train_time:55295ms step_avg:87.49ms +step:633/1680 train_time:55382ms step_avg:87.49ms +step:634/1680 train_time:55469ms step_avg:87.49ms +step:635/1680 train_time:55557ms step_avg:87.49ms +step:636/1680 train_time:55645ms step_avg:87.49ms +step:637/1680 train_time:55737ms step_avg:87.50ms +step:638/1680 train_time:55827ms step_avg:87.50ms +step:639/1680 train_time:55916ms step_avg:87.51ms +step:640/1680 train_time:56005ms step_avg:87.51ms +step:641/1680 train_time:56092ms step_avg:87.51ms +step:642/1680 train_time:56180ms step_avg:87.51ms +step:643/1680 train_time:56268ms step_avg:87.51ms +step:644/1680 train_time:56356ms step_avg:87.51ms +step:645/1680 train_time:56443ms step_avg:87.51ms +step:646/1680 train_time:56529ms step_avg:87.51ms +step:647/1680 train_time:56618ms step_avg:87.51ms +step:648/1680 train_time:56709ms step_avg:87.51ms +step:649/1680 train_time:56799ms step_avg:87.52ms +step:650/1680 train_time:56888ms step_avg:87.52ms +step:651/1680 train_time:56977ms step_avg:87.52ms +step:652/1680 train_time:57066ms step_avg:87.52ms +step:653/1680 train_time:57154ms step_avg:87.53ms +step:654/1680 train_time:57242ms step_avg:87.53ms +step:655/1680 train_time:57330ms step_avg:87.53ms +step:656/1680 train_time:57417ms step_avg:87.53ms +step:657/1680 train_time:57505ms step_avg:87.53ms +step:658/1680 train_time:57593ms step_avg:87.53ms +step:659/1680 train_time:57681ms step_avg:87.53ms +step:660/1680 train_time:57771ms step_avg:87.53ms +step:661/1680 train_time:57861ms step_avg:87.53ms +step:662/1680 train_time:57950ms step_avg:87.54ms +step:663/1680 train_time:58040ms step_avg:87.54ms +step:664/1680 train_time:58128ms step_avg:87.54ms +step:665/1680 train_time:58217ms step_avg:87.54ms +step:666/1680 train_time:58305ms step_avg:87.54ms +step:667/1680 train_time:58392ms step_avg:87.54ms +step:668/1680 train_time:58480ms step_avg:87.55ms +step:669/1680 train_time:58568ms step_avg:87.55ms 
+step:670/1680 train_time:58657ms step_avg:87.55ms +step:671/1680 train_time:58745ms step_avg:87.55ms +step:672/1680 train_time:58834ms step_avg:87.55ms +step:673/1680 train_time:58922ms step_avg:87.55ms +step:674/1680 train_time:59010ms step_avg:87.55ms +step:675/1680 train_time:59098ms step_avg:87.55ms +step:676/1680 train_time:59187ms step_avg:87.55ms +step:677/1680 train_time:59275ms step_avg:87.56ms +step:678/1680 train_time:59364ms step_avg:87.56ms +step:679/1680 train_time:59451ms step_avg:87.56ms +step:680/1680 train_time:59540ms step_avg:87.56ms +step:681/1680 train_time:59628ms step_avg:87.56ms +step:682/1680 train_time:59716ms step_avg:87.56ms +step:683/1680 train_time:59805ms step_avg:87.56ms +step:684/1680 train_time:59893ms step_avg:87.56ms +step:685/1680 train_time:59982ms step_avg:87.56ms +step:686/1680 train_time:60070ms step_avg:87.57ms +step:687/1680 train_time:60159ms step_avg:87.57ms +step:688/1680 train_time:60247ms step_avg:87.57ms +step:689/1680 train_time:60336ms step_avg:87.57ms +step:690/1680 train_time:60424ms step_avg:87.57ms +step:691/1680 train_time:60512ms step_avg:87.57ms +step:692/1680 train_time:60601ms step_avg:87.57ms +step:693/1680 train_time:60689ms step_avg:87.57ms +step:694/1680 train_time:60778ms step_avg:87.58ms +step:695/1680 train_time:60867ms step_avg:87.58ms +step:696/1680 train_time:60955ms step_avg:87.58ms +step:697/1680 train_time:61044ms step_avg:87.58ms +step:698/1680 train_time:61132ms step_avg:87.58ms +step:699/1680 train_time:61221ms step_avg:87.58ms +step:700/1680 train_time:61309ms step_avg:87.58ms +step:701/1680 train_time:61399ms step_avg:87.59ms +step:702/1680 train_time:61487ms step_avg:87.59ms +step:703/1680 train_time:61575ms step_avg:87.59ms +step:704/1680 train_time:61663ms step_avg:87.59ms +step:705/1680 train_time:61750ms step_avg:87.59ms +step:706/1680 train_time:61839ms step_avg:87.59ms +step:707/1680 train_time:61927ms step_avg:87.59ms +step:708/1680 train_time:62016ms step_avg:87.59ms +step:709/1680 train_time:62106ms step_avg:87.60ms +step:710/1680 train_time:62193ms step_avg:87.60ms +step:711/1680 train_time:62282ms step_avg:87.60ms +step:712/1680 train_time:62370ms step_avg:87.60ms +step:713/1680 train_time:62459ms step_avg:87.60ms +step:714/1680 train_time:62547ms step_avg:87.60ms +step:715/1680 train_time:62635ms step_avg:87.60ms +step:716/1680 train_time:62724ms step_avg:87.60ms +step:717/1680 train_time:62812ms step_avg:87.60ms +step:718/1680 train_time:62901ms step_avg:87.61ms +step:719/1680 train_time:62989ms step_avg:87.61ms +step:720/1680 train_time:63078ms step_avg:87.61ms +step:721/1680 train_time:63166ms step_avg:87.61ms +step:722/1680 train_time:63254ms step_avg:87.61ms +step:723/1680 train_time:63343ms step_avg:87.61ms +step:724/1680 train_time:63431ms step_avg:87.61ms +step:725/1680 train_time:63520ms step_avg:87.61ms +step:726/1680 train_time:63609ms step_avg:87.62ms +step:727/1680 train_time:63698ms step_avg:87.62ms +step:728/1680 train_time:63786ms step_avg:87.62ms +step:729/1680 train_time:63874ms step_avg:87.62ms +step:730/1680 train_time:63963ms step_avg:87.62ms +step:731/1680 train_time:64051ms step_avg:87.62ms +step:732/1680 train_time:64140ms step_avg:87.62ms +step:733/1680 train_time:64228ms step_avg:87.62ms +step:734/1680 train_time:64316ms step_avg:87.62ms +step:735/1680 train_time:64405ms step_avg:87.63ms +step:736/1680 train_time:64493ms step_avg:87.63ms +step:737/1680 train_time:64582ms step_avg:87.63ms +step:738/1680 train_time:64670ms step_avg:87.63ms +step:739/1680 train_time:64759ms 
step_avg:87.63ms +step:740/1680 train_time:64847ms step_avg:87.63ms +step:741/1680 train_time:64936ms step_avg:87.63ms +step:742/1680 train_time:65024ms step_avg:87.63ms +step:743/1680 train_time:65113ms step_avg:87.63ms +step:744/1680 train_time:65201ms step_avg:87.64ms +step:745/1680 train_time:65290ms step_avg:87.64ms +step:746/1680 train_time:65379ms step_avg:87.64ms +step:747/1680 train_time:65467ms step_avg:87.64ms +step:748/1680 train_time:65556ms step_avg:87.64ms +step:749/1680 train_time:65645ms step_avg:87.64ms +step:750/1680 train_time:65733ms step_avg:87.64ms +step:750/1680 val_loss:3.5680 train_time:65823ms step_avg:87.76ms +step:751/1680 train_time:65841ms step_avg:87.67ms +step:752/1680 train_time:65917ms step_avg:87.66ms +step:753/1680 train_time:66009ms step_avg:87.66ms +step:754/1680 train_time:66098ms step_avg:87.66ms +step:755/1680 train_time:66186ms step_avg:87.66ms +step:756/1680 train_time:66273ms step_avg:87.66ms +step:757/1680 train_time:66361ms step_avg:87.66ms +step:758/1680 train_time:66448ms step_avg:87.66ms +step:759/1680 train_time:66536ms step_avg:87.66ms +step:760/1680 train_time:66624ms step_avg:87.66ms +step:761/1680 train_time:66712ms step_avg:87.66ms +step:762/1680 train_time:66801ms step_avg:87.67ms +step:763/1680 train_time:66891ms step_avg:87.67ms +step:764/1680 train_time:66980ms step_avg:87.67ms +step:765/1680 train_time:67070ms step_avg:87.67ms +step:766/1680 train_time:67158ms step_avg:87.67ms +step:767/1680 train_time:67247ms step_avg:87.68ms +step:768/1680 train_time:67335ms step_avg:87.68ms +step:769/1680 train_time:67422ms step_avg:87.68ms +step:770/1680 train_time:67511ms step_avg:87.68ms +step:771/1680 train_time:67598ms step_avg:87.68ms +step:772/1680 train_time:67686ms step_avg:87.68ms +step:773/1680 train_time:67774ms step_avg:87.68ms +step:774/1680 train_time:67864ms step_avg:87.68ms +step:775/1680 train_time:67954ms step_avg:87.68ms +step:776/1680 train_time:68043ms step_avg:87.68ms +step:777/1680 train_time:68132ms step_avg:87.69ms +step:778/1680 train_time:68220ms step_avg:87.69ms +step:779/1680 train_time:68309ms step_avg:87.69ms +step:780/1680 train_time:68397ms step_avg:87.69ms +step:781/1680 train_time:68485ms step_avg:87.69ms +step:782/1680 train_time:68573ms step_avg:87.69ms +step:783/1680 train_time:68661ms step_avg:87.69ms +step:784/1680 train_time:68749ms step_avg:87.69ms +step:785/1680 train_time:68838ms step_avg:87.69ms +step:786/1680 train_time:68927ms step_avg:87.69ms +step:787/1680 train_time:69017ms step_avg:87.70ms +step:788/1680 train_time:69107ms step_avg:87.70ms +step:789/1680 train_time:69195ms step_avg:87.70ms +step:790/1680 train_time:69284ms step_avg:87.70ms +step:791/1680 train_time:69372ms step_avg:87.70ms +step:792/1680 train_time:69460ms step_avg:87.70ms +step:793/1680 train_time:69548ms step_avg:87.70ms +step:794/1680 train_time:69636ms step_avg:87.70ms +step:795/1680 train_time:69724ms step_avg:87.70ms +step:796/1680 train_time:69813ms step_avg:87.71ms +step:797/1680 train_time:69902ms step_avg:87.71ms +step:798/1680 train_time:69991ms step_avg:87.71ms +step:799/1680 train_time:70079ms step_avg:87.71ms +step:800/1680 train_time:70169ms step_avg:87.71ms +step:801/1680 train_time:70257ms step_avg:87.71ms +step:802/1680 train_time:70345ms step_avg:87.71ms +step:803/1680 train_time:70434ms step_avg:87.71ms +step:804/1680 train_time:70522ms step_avg:87.71ms +step:805/1680 train_time:70610ms step_avg:87.71ms +step:806/1680 train_time:70699ms step_avg:87.72ms +step:807/1680 train_time:70787ms step_avg:87.72ms 
+step:808/1680 train_time:70875ms step_avg:87.72ms +step:809/1680 train_time:70963ms step_avg:87.72ms +step:810/1680 train_time:71053ms step_avg:87.72ms +step:811/1680 train_time:71142ms step_avg:87.72ms +step:812/1680 train_time:71231ms step_avg:87.72ms +step:813/1680 train_time:71320ms step_avg:87.72ms +step:814/1680 train_time:71408ms step_avg:87.72ms +step:815/1680 train_time:71496ms step_avg:87.73ms +step:816/1680 train_time:71585ms step_avg:87.73ms +step:817/1680 train_time:71673ms step_avg:87.73ms +step:818/1680 train_time:71761ms step_avg:87.73ms +step:819/1680 train_time:71849ms step_avg:87.73ms +step:820/1680 train_time:71937ms step_avg:87.73ms +step:821/1680 train_time:72026ms step_avg:87.73ms +step:822/1680 train_time:72114ms step_avg:87.73ms +step:823/1680 train_time:72203ms step_avg:87.73ms +step:824/1680 train_time:72291ms step_avg:87.73ms +step:825/1680 train_time:72379ms step_avg:87.73ms +step:826/1680 train_time:72467ms step_avg:87.73ms +step:827/1680 train_time:72556ms step_avg:87.73ms +step:828/1680 train_time:72644ms step_avg:87.73ms +step:829/1680 train_time:72732ms step_avg:87.73ms +step:830/1680 train_time:72820ms step_avg:87.73ms +step:831/1680 train_time:72908ms step_avg:87.74ms +step:832/1680 train_time:72997ms step_avg:87.74ms +step:833/1680 train_time:73086ms step_avg:87.74ms +step:834/1680 train_time:73174ms step_avg:87.74ms +step:835/1680 train_time:73262ms step_avg:87.74ms +step:836/1680 train_time:73350ms step_avg:87.74ms +step:837/1680 train_time:73439ms step_avg:87.74ms +step:838/1680 train_time:73528ms step_avg:87.74ms +step:839/1680 train_time:73616ms step_avg:87.74ms +step:840/1680 train_time:73703ms step_avg:87.74ms +step:841/1680 train_time:73792ms step_avg:87.74ms +step:842/1680 train_time:73880ms step_avg:87.74ms +step:843/1680 train_time:73969ms step_avg:87.74ms +step:844/1680 train_time:74058ms step_avg:87.75ms +step:845/1680 train_time:74146ms step_avg:87.75ms +step:846/1680 train_time:74235ms step_avg:87.75ms +step:847/1680 train_time:74324ms step_avg:87.75ms +step:848/1680 train_time:74413ms step_avg:87.75ms +step:849/1680 train_time:74500ms step_avg:87.75ms +step:850/1680 train_time:74588ms step_avg:87.75ms +step:851/1680 train_time:74676ms step_avg:87.75ms +step:852/1680 train_time:74765ms step_avg:87.75ms +step:853/1680 train_time:74853ms step_avg:87.75ms +step:854/1680 train_time:74942ms step_avg:87.75ms +step:855/1680 train_time:75031ms step_avg:87.76ms +step:856/1680 train_time:75119ms step_avg:87.76ms +step:857/1680 train_time:75208ms step_avg:87.76ms +step:858/1680 train_time:75296ms step_avg:87.76ms +step:859/1680 train_time:75385ms step_avg:87.76ms +step:860/1680 train_time:75474ms step_avg:87.76ms +step:861/1680 train_time:75561ms step_avg:87.76ms +step:862/1680 train_time:75649ms step_avg:87.76ms +step:863/1680 train_time:75738ms step_avg:87.76ms +step:864/1680 train_time:75827ms step_avg:87.76ms +step:865/1680 train_time:75915ms step_avg:87.76ms +step:866/1680 train_time:76004ms step_avg:87.76ms +step:867/1680 train_time:76092ms step_avg:87.77ms +step:868/1680 train_time:76181ms step_avg:87.77ms +step:869/1680 train_time:76269ms step_avg:87.77ms +step:870/1680 train_time:76358ms step_avg:87.77ms +step:871/1680 train_time:76447ms step_avg:87.77ms +step:872/1680 train_time:76536ms step_avg:87.77ms +step:873/1680 train_time:76624ms step_avg:87.77ms +step:874/1680 train_time:76713ms step_avg:87.77ms +step:875/1680 train_time:76801ms step_avg:87.77ms +step:875/1680 val_loss:3.5210 train_time:76890ms step_avg:87.87ms +step:876/1680 
train_time:76910ms step_avg:87.80ms +step:877/1680 train_time:76982ms step_avg:87.78ms +step:878/1680 train_time:77077ms step_avg:87.79ms +step:879/1680 train_time:77167ms step_avg:87.79ms +step:880/1680 train_time:77254ms step_avg:87.79ms +step:881/1680 train_time:77342ms step_avg:87.79ms +step:882/1680 train_time:77429ms step_avg:87.79ms +step:883/1680 train_time:77516ms step_avg:87.79ms +step:884/1680 train_time:77603ms step_avg:87.79ms +step:885/1680 train_time:77691ms step_avg:87.79ms +step:886/1680 train_time:77779ms step_avg:87.79ms +step:887/1680 train_time:77869ms step_avg:87.79ms +step:888/1680 train_time:77960ms step_avg:87.79ms +step:889/1680 train_time:78049ms step_avg:87.79ms +step:890/1680 train_time:78139ms step_avg:87.80ms +step:891/1680 train_time:78227ms step_avg:87.80ms +step:892/1680 train_time:78315ms step_avg:87.80ms +step:893/1680 train_time:78403ms step_avg:87.80ms +step:894/1680 train_time:78491ms step_avg:87.80ms +step:895/1680 train_time:78578ms step_avg:87.80ms +step:896/1680 train_time:78665ms step_avg:87.80ms +step:897/1680 train_time:78754ms step_avg:87.80ms +step:898/1680 train_time:78841ms step_avg:87.80ms +step:899/1680 train_time:78932ms step_avg:87.80ms +step:900/1680 train_time:79021ms step_avg:87.80ms +step:901/1680 train_time:79111ms step_avg:87.80ms +step:902/1680 train_time:79200ms step_avg:87.80ms +step:903/1680 train_time:79288ms step_avg:87.81ms +step:904/1680 train_time:79376ms step_avg:87.81ms +step:905/1680 train_time:79464ms step_avg:87.81ms +step:906/1680 train_time:79551ms step_avg:87.81ms +step:907/1680 train_time:79639ms step_avg:87.81ms +step:908/1680 train_time:79728ms step_avg:87.81ms +step:909/1680 train_time:79816ms step_avg:87.81ms +step:910/1680 train_time:79904ms step_avg:87.81ms +step:911/1680 train_time:79995ms step_avg:87.81ms +step:912/1680 train_time:80084ms step_avg:87.81ms +step:913/1680 train_time:80173ms step_avg:87.81ms +step:914/1680 train_time:80262ms step_avg:87.81ms +step:915/1680 train_time:80351ms step_avg:87.82ms +step:916/1680 train_time:80439ms step_avg:87.82ms +step:917/1680 train_time:80527ms step_avg:87.82ms +step:918/1680 train_time:80614ms step_avg:87.82ms +step:919/1680 train_time:80703ms step_avg:87.82ms +step:920/1680 train_time:80791ms step_avg:87.82ms +step:921/1680 train_time:80878ms step_avg:87.82ms +step:922/1680 train_time:80968ms step_avg:87.82ms +step:923/1680 train_time:81056ms step_avg:87.82ms +step:924/1680 train_time:81145ms step_avg:87.82ms +step:925/1680 train_time:81235ms step_avg:87.82ms +step:926/1680 train_time:81323ms step_avg:87.82ms +step:927/1680 train_time:81411ms step_avg:87.82ms +step:928/1680 train_time:81499ms step_avg:87.82ms +step:929/1680 train_time:81588ms step_avg:87.82ms +step:930/1680 train_time:81676ms step_avg:87.82ms +step:931/1680 train_time:81764ms step_avg:87.82ms +step:932/1680 train_time:81852ms step_avg:87.82ms +step:933/1680 train_time:81940ms step_avg:87.82ms +step:934/1680 train_time:82028ms step_avg:87.82ms +step:935/1680 train_time:82117ms step_avg:87.83ms +step:936/1680 train_time:82206ms step_avg:87.83ms +step:937/1680 train_time:82295ms step_avg:87.83ms +step:938/1680 train_time:82384ms step_avg:87.83ms +step:939/1680 train_time:82473ms step_avg:87.83ms +step:940/1680 train_time:82562ms step_avg:87.83ms +step:941/1680 train_time:82650ms step_avg:87.83ms +step:942/1680 train_time:82738ms step_avg:87.83ms +step:943/1680 train_time:82826ms step_avg:87.83ms +step:944/1680 train_time:82915ms step_avg:87.83ms +step:945/1680 train_time:83004ms step_avg:87.83ms 
+step:946/1680 train_time:83092ms step_avg:87.84ms +step:947/1680 train_time:83181ms step_avg:87.84ms +step:948/1680 train_time:83270ms step_avg:87.84ms +step:949/1680 train_time:83359ms step_avg:87.84ms +step:950/1680 train_time:83447ms step_avg:87.84ms +step:951/1680 train_time:83535ms step_avg:87.84ms +step:952/1680 train_time:83623ms step_avg:87.84ms +step:953/1680 train_time:83711ms step_avg:87.84ms +step:954/1680 train_time:83800ms step_avg:87.84ms +step:955/1680 train_time:83889ms step_avg:87.84ms +step:956/1680 train_time:83977ms step_avg:87.84ms +step:957/1680 train_time:84065ms step_avg:87.84ms +step:958/1680 train_time:84154ms step_avg:87.84ms +step:959/1680 train_time:84242ms step_avg:87.84ms +step:960/1680 train_time:84331ms step_avg:87.84ms +step:961/1680 train_time:84419ms step_avg:87.85ms +step:962/1680 train_time:84508ms step_avg:87.85ms +step:963/1680 train_time:84596ms step_avg:87.85ms +step:964/1680 train_time:84684ms step_avg:87.85ms +step:965/1680 train_time:84773ms step_avg:87.85ms +step:966/1680 train_time:84862ms step_avg:87.85ms +step:967/1680 train_time:84951ms step_avg:87.85ms +step:968/1680 train_time:85039ms step_avg:87.85ms +step:969/1680 train_time:85127ms step_avg:87.85ms +step:970/1680 train_time:85215ms step_avg:87.85ms +step:971/1680 train_time:85304ms step_avg:87.85ms +step:972/1680 train_time:85393ms step_avg:87.85ms +step:973/1680 train_time:85482ms step_avg:87.85ms +step:974/1680 train_time:85571ms step_avg:87.86ms +step:975/1680 train_time:85659ms step_avg:87.86ms +step:976/1680 train_time:85748ms step_avg:87.86ms +step:977/1680 train_time:85836ms step_avg:87.86ms +step:978/1680 train_time:85925ms step_avg:87.86ms +step:979/1680 train_time:86014ms step_avg:87.86ms +step:980/1680 train_time:86102ms step_avg:87.86ms +step:981/1680 train_time:86191ms step_avg:87.86ms +step:982/1680 train_time:86279ms step_avg:87.86ms +step:983/1680 train_time:86367ms step_avg:87.86ms +step:984/1680 train_time:86456ms step_avg:87.86ms +step:985/1680 train_time:86544ms step_avg:87.86ms +step:986/1680 train_time:86633ms step_avg:87.86ms +step:987/1680 train_time:86721ms step_avg:87.86ms +step:988/1680 train_time:86810ms step_avg:87.86ms +step:989/1680 train_time:86900ms step_avg:87.87ms +step:990/1680 train_time:86988ms step_avg:87.87ms +step:991/1680 train_time:87077ms step_avg:87.87ms +step:992/1680 train_time:87164ms step_avg:87.87ms +step:993/1680 train_time:87253ms step_avg:87.87ms +step:994/1680 train_time:87341ms step_avg:87.87ms +step:995/1680 train_time:87430ms step_avg:87.87ms +step:996/1680 train_time:87518ms step_avg:87.87ms +step:997/1680 train_time:87607ms step_avg:87.87ms +step:998/1680 train_time:87695ms step_avg:87.87ms +step:999/1680 train_time:87784ms step_avg:87.87ms +step:1000/1680 train_time:87873ms step_avg:87.87ms +step:1000/1680 val_loss:3.4706 train_time:87964ms step_avg:87.96ms +step:1001/1680 train_time:87982ms step_avg:87.89ms +step:1002/1680 train_time:88055ms step_avg:87.88ms +step:1003/1680 train_time:88150ms step_avg:87.89ms +step:1004/1680 train_time:88240ms step_avg:87.89ms +step:1005/1680 train_time:88328ms step_avg:87.89ms +step:1006/1680 train_time:88415ms step_avg:87.89ms +step:1007/1680 train_time:88503ms step_avg:87.89ms +step:1008/1680 train_time:88590ms step_avg:87.89ms +step:1009/1680 train_time:88678ms step_avg:87.89ms +step:1010/1680 train_time:88766ms step_avg:87.89ms +step:1011/1680 train_time:88854ms step_avg:87.89ms +step:1012/1680 train_time:88943ms step_avg:87.89ms +step:1013/1680 train_time:89033ms step_avg:87.89ms 
+step:1014/1680 train_time:89124ms step_avg:87.89ms +step:1015/1680 train_time:89213ms step_avg:87.89ms +step:1016/1680 train_time:89302ms step_avg:87.90ms +step:1017/1680 train_time:89390ms step_avg:87.90ms +step:1018/1680 train_time:89478ms step_avg:87.90ms +step:1019/1680 train_time:89566ms step_avg:87.90ms +step:1020/1680 train_time:89654ms step_avg:87.90ms +step:1021/1680 train_time:89743ms step_avg:87.90ms +step:1022/1680 train_time:89830ms step_avg:87.90ms +step:1023/1680 train_time:89920ms step_avg:87.90ms +step:1024/1680 train_time:90009ms step_avg:87.90ms +step:1025/1680 train_time:90099ms step_avg:87.90ms +step:1026/1680 train_time:90188ms step_avg:87.90ms +step:1027/1680 train_time:90276ms step_avg:87.90ms +step:1028/1680 train_time:90364ms step_avg:87.90ms +step:1029/1680 train_time:90452ms step_avg:87.90ms +step:1030/1680 train_time:90541ms step_avg:87.90ms +step:1031/1680 train_time:90628ms step_avg:87.90ms +step:1032/1680 train_time:90717ms step_avg:87.90ms +step:1033/1680 train_time:90805ms step_avg:87.90ms +step:1034/1680 train_time:90894ms step_avg:87.90ms +step:1035/1680 train_time:90983ms step_avg:87.91ms +step:1036/1680 train_time:91071ms step_avg:87.91ms +step:1037/1680 train_time:91160ms step_avg:87.91ms +step:1038/1680 train_time:91249ms step_avg:87.91ms +step:1039/1680 train_time:91338ms step_avg:87.91ms +step:1040/1680 train_time:91426ms step_avg:87.91ms +step:1041/1680 train_time:91514ms step_avg:87.91ms +step:1042/1680 train_time:91602ms step_avg:87.91ms +step:1043/1680 train_time:91690ms step_avg:87.91ms +step:1044/1680 train_time:91779ms step_avg:87.91ms +step:1045/1680 train_time:91867ms step_avg:87.91ms +step:1046/1680 train_time:91957ms step_avg:87.91ms +step:1047/1680 train_time:92046ms step_avg:87.91ms +step:1048/1680 train_time:92135ms step_avg:87.92ms +step:1049/1680 train_time:92224ms step_avg:87.92ms +step:1050/1680 train_time:92313ms step_avg:87.92ms +step:1051/1680 train_time:92401ms step_avg:87.92ms +step:1052/1680 train_time:92489ms step_avg:87.92ms +step:1053/1680 train_time:92577ms step_avg:87.92ms +step:1054/1680 train_time:92665ms step_avg:87.92ms +step:1055/1680 train_time:92752ms step_avg:87.92ms +step:1056/1680 train_time:92840ms step_avg:87.92ms +step:1057/1680 train_time:92930ms step_avg:87.92ms +step:1058/1680 train_time:93019ms step_avg:87.92ms +step:1059/1680 train_time:93108ms step_avg:87.92ms +step:1060/1680 train_time:93197ms step_avg:87.92ms +step:1061/1680 train_time:93286ms step_avg:87.92ms +step:1062/1680 train_time:93374ms step_avg:87.92ms +step:1063/1680 train_time:93463ms step_avg:87.92ms +step:1064/1680 train_time:93551ms step_avg:87.92ms +step:1065/1680 train_time:93640ms step_avg:87.92ms +step:1066/1680 train_time:93729ms step_avg:87.93ms +step:1067/1680 train_time:93817ms step_avg:87.93ms +step:1068/1680 train_time:93906ms step_avg:87.93ms +step:1069/1680 train_time:93995ms step_avg:87.93ms +step:1070/1680 train_time:94084ms step_avg:87.93ms +step:1071/1680 train_time:94172ms step_avg:87.93ms +step:1072/1680 train_time:94261ms step_avg:87.93ms +step:1073/1680 train_time:94350ms step_avg:87.93ms +step:1074/1680 train_time:94438ms step_avg:87.93ms +step:1075/1680 train_time:94526ms step_avg:87.93ms +step:1076/1680 train_time:94614ms step_avg:87.93ms +step:1077/1680 train_time:94703ms step_avg:87.93ms +step:1078/1680 train_time:94792ms step_avg:87.93ms +step:1079/1680 train_time:94879ms step_avg:87.93ms +step:1080/1680 train_time:94967ms step_avg:87.93ms +step:1081/1680 train_time:95057ms step_avg:87.93ms +step:1082/1680 
train_time:95146ms step_avg:87.94ms +step:1083/1680 train_time:95235ms step_avg:87.94ms +step:1084/1680 train_time:95323ms step_avg:87.94ms +step:1085/1680 train_time:95413ms step_avg:87.94ms +step:1086/1680 train_time:95501ms step_avg:87.94ms +step:1087/1680 train_time:95589ms step_avg:87.94ms +step:1088/1680 train_time:95676ms step_avg:87.94ms +step:1089/1680 train_time:95765ms step_avg:87.94ms +step:1090/1680 train_time:95853ms step_avg:87.94ms +step:1091/1680 train_time:95941ms step_avg:87.94ms +step:1092/1680 train_time:96029ms step_avg:87.94ms +step:1093/1680 train_time:96119ms step_avg:87.94ms +step:1094/1680 train_time:96207ms step_avg:87.94ms +step:1095/1680 train_time:96296ms step_avg:87.94ms +step:1096/1680 train_time:96384ms step_avg:87.94ms +step:1097/1680 train_time:96474ms step_avg:87.94ms +step:1098/1680 train_time:96563ms step_avg:87.94ms +step:1099/1680 train_time:96652ms step_avg:87.95ms +step:1100/1680 train_time:96742ms step_avg:87.95ms +step:1101/1680 train_time:96830ms step_avg:87.95ms +step:1102/1680 train_time:96920ms step_avg:87.95ms +step:1103/1680 train_time:97009ms step_avg:87.95ms +step:1104/1680 train_time:97099ms step_avg:87.95ms +step:1105/1680 train_time:97188ms step_avg:87.95ms +step:1106/1680 train_time:97277ms step_avg:87.95ms +step:1107/1680 train_time:97366ms step_avg:87.95ms +step:1108/1680 train_time:97456ms step_avg:87.96ms +step:1109/1680 train_time:97546ms step_avg:87.96ms +step:1110/1680 train_time:97635ms step_avg:87.96ms +step:1111/1680 train_time:97724ms step_avg:87.96ms +step:1112/1680 train_time:97814ms step_avg:87.96ms +step:1113/1680 train_time:97904ms step_avg:87.96ms +step:1114/1680 train_time:97992ms step_avg:87.96ms +step:1115/1680 train_time:98081ms step_avg:87.96ms +step:1116/1680 train_time:98169ms step_avg:87.97ms +step:1117/1680 train_time:98260ms step_avg:87.97ms +step:1118/1680 train_time:98349ms step_avg:87.97ms +step:1119/1680 train_time:98438ms step_avg:87.97ms +step:1120/1680 train_time:98528ms step_avg:87.97ms +step:1121/1680 train_time:98618ms step_avg:87.97ms +step:1122/1680 train_time:98707ms step_avg:87.97ms +step:1123/1680 train_time:98796ms step_avg:87.97ms +step:1124/1680 train_time:98884ms step_avg:87.98ms +step:1125/1680 train_time:98973ms step_avg:87.98ms +step:1125/1680 val_loss:3.4164 train_time:99064ms step_avg:88.06ms +step:1126/1680 train_time:99083ms step_avg:88.00ms +step:1127/1680 train_time:99154ms step_avg:87.98ms +step:1128/1680 train_time:99244ms step_avg:87.98ms +step:1129/1680 train_time:99336ms step_avg:87.99ms +step:1130/1680 train_time:99427ms step_avg:87.99ms +step:1131/1680 train_time:99514ms step_avg:87.99ms +step:1132/1680 train_time:99602ms step_avg:87.99ms +step:1133/1680 train_time:99690ms step_avg:87.99ms +step:1134/1680 train_time:99779ms step_avg:87.99ms +step:1135/1680 train_time:99869ms step_avg:87.99ms +step:1136/1680 train_time:99959ms step_avg:87.99ms +step:1137/1680 train_time:100049ms step_avg:87.99ms +step:1138/1680 train_time:100141ms step_avg:88.00ms +step:1139/1680 train_time:100231ms step_avg:88.00ms +step:1140/1680 train_time:100322ms step_avg:88.00ms +step:1141/1680 train_time:100411ms step_avg:88.00ms +step:1142/1680 train_time:100501ms step_avg:88.00ms +step:1143/1680 train_time:100589ms step_avg:88.00ms +step:1144/1680 train_time:100677ms step_avg:88.00ms +step:1145/1680 train_time:100766ms step_avg:88.01ms +step:1146/1680 train_time:100855ms step_avg:88.01ms +step:1147/1680 train_time:100943ms step_avg:88.01ms +step:1148/1680 train_time:101033ms step_avg:88.01ms 
+step:1149/1680 train_time:101123ms step_avg:88.01ms +step:1150/1680 train_time:101212ms step_avg:88.01ms +step:1151/1680 train_time:101302ms step_avg:88.01ms +step:1152/1680 train_time:101391ms step_avg:88.01ms +step:1153/1680 train_time:101481ms step_avg:88.01ms +step:1154/1680 train_time:101570ms step_avg:88.02ms +step:1155/1680 train_time:101659ms step_avg:88.02ms +step:1156/1680 train_time:101748ms step_avg:88.02ms +step:1157/1680 train_time:101837ms step_avg:88.02ms +step:1158/1680 train_time:101926ms step_avg:88.02ms +step:1159/1680 train_time:102015ms step_avg:88.02ms +step:1160/1680 train_time:102106ms step_avg:88.02ms +step:1161/1680 train_time:102197ms step_avg:88.03ms +step:1162/1680 train_time:102286ms step_avg:88.03ms +step:1163/1680 train_time:102377ms step_avg:88.03ms +step:1164/1680 train_time:102466ms step_avg:88.03ms +step:1165/1680 train_time:102556ms step_avg:88.03ms +step:1166/1680 train_time:102644ms step_avg:88.03ms +step:1167/1680 train_time:102733ms step_avg:88.03ms +step:1168/1680 train_time:102821ms step_avg:88.03ms +step:1169/1680 train_time:102910ms step_avg:88.03ms +step:1170/1680 train_time:102999ms step_avg:88.03ms +step:1171/1680 train_time:103089ms step_avg:88.03ms +step:1172/1680 train_time:103180ms step_avg:88.04ms +step:1173/1680 train_time:103270ms step_avg:88.04ms +step:1174/1680 train_time:103359ms step_avg:88.04ms +step:1175/1680 train_time:103449ms step_avg:88.04ms +step:1176/1680 train_time:103537ms step_avg:88.04ms +step:1177/1680 train_time:103626ms step_avg:88.04ms +step:1178/1680 train_time:103715ms step_avg:88.04ms +step:1179/1680 train_time:103804ms step_avg:88.04ms +step:1180/1680 train_time:103893ms step_avg:88.04ms +step:1181/1680 train_time:103982ms step_avg:88.05ms +step:1182/1680 train_time:104071ms step_avg:88.05ms +step:1183/1680 train_time:104160ms step_avg:88.05ms +step:1184/1680 train_time:104249ms step_avg:88.05ms +step:1185/1680 train_time:104338ms step_avg:88.05ms +step:1186/1680 train_time:104427ms step_avg:88.05ms +step:1187/1680 train_time:104516ms step_avg:88.05ms +step:1188/1680 train_time:104605ms step_avg:88.05ms +step:1189/1680 train_time:104695ms step_avg:88.05ms +step:1190/1680 train_time:104784ms step_avg:88.05ms +step:1191/1680 train_time:104873ms step_avg:88.05ms +step:1192/1680 train_time:104962ms step_avg:88.06ms +step:1193/1680 train_time:105051ms step_avg:88.06ms +step:1194/1680 train_time:105140ms step_avg:88.06ms +step:1195/1680 train_time:105229ms step_avg:88.06ms +step:1196/1680 train_time:105319ms step_avg:88.06ms +step:1197/1680 train_time:105407ms step_avg:88.06ms +step:1198/1680 train_time:105497ms step_avg:88.06ms +step:1199/1680 train_time:105587ms step_avg:88.06ms +step:1200/1680 train_time:105677ms step_avg:88.06ms +step:1201/1680 train_time:105766ms step_avg:88.06ms +step:1202/1680 train_time:105856ms step_avg:88.07ms +step:1203/1680 train_time:105945ms step_avg:88.07ms +step:1204/1680 train_time:106035ms step_avg:88.07ms +step:1205/1680 train_time:106123ms step_avg:88.07ms +step:1206/1680 train_time:106213ms step_avg:88.07ms +step:1207/1680 train_time:106302ms step_avg:88.07ms +step:1208/1680 train_time:106391ms step_avg:88.07ms +step:1209/1680 train_time:106480ms step_avg:88.07ms +step:1210/1680 train_time:106570ms step_avg:88.07ms +step:1211/1680 train_time:106659ms step_avg:88.08ms +step:1212/1680 train_time:106749ms step_avg:88.08ms +step:1213/1680 train_time:106837ms step_avg:88.08ms +step:1214/1680 train_time:106927ms step_avg:88.08ms +step:1215/1680 train_time:107016ms step_avg:88.08ms 
+step:1216/1680 train_time:107105ms step_avg:88.08ms +step:1217/1680 train_time:107194ms step_avg:88.08ms +step:1218/1680 train_time:107283ms step_avg:88.08ms +step:1219/1680 train_time:107371ms step_avg:88.08ms +step:1220/1680 train_time:107460ms step_avg:88.08ms +step:1221/1680 train_time:107549ms step_avg:88.08ms +step:1222/1680 train_time:107639ms step_avg:88.08ms +step:1223/1680 train_time:107728ms step_avg:88.08ms +step:1224/1680 train_time:107817ms step_avg:88.09ms +step:1225/1680 train_time:107906ms step_avg:88.09ms +step:1226/1680 train_time:107997ms step_avg:88.09ms +step:1227/1680 train_time:108085ms step_avg:88.09ms +step:1228/1680 train_time:108175ms step_avg:88.09ms +step:1229/1680 train_time:108263ms step_avg:88.09ms +step:1230/1680 train_time:108353ms step_avg:88.09ms +step:1231/1680 train_time:108442ms step_avg:88.09ms +step:1232/1680 train_time:108531ms step_avg:88.09ms +step:1233/1680 train_time:108620ms step_avg:88.09ms +step:1234/1680 train_time:108710ms step_avg:88.10ms +step:1235/1680 train_time:108800ms step_avg:88.10ms +step:1236/1680 train_time:108889ms step_avg:88.10ms +step:1237/1680 train_time:108979ms step_avg:88.10ms +step:1238/1680 train_time:109069ms step_avg:88.10ms +step:1239/1680 train_time:109160ms step_avg:88.10ms +step:1240/1680 train_time:109249ms step_avg:88.10ms +step:1241/1680 train_time:109339ms step_avg:88.11ms +step:1242/1680 train_time:109427ms step_avg:88.11ms +step:1243/1680 train_time:109517ms step_avg:88.11ms +step:1244/1680 train_time:109606ms step_avg:88.11ms +step:1245/1680 train_time:109695ms step_avg:88.11ms +step:1246/1680 train_time:109784ms step_avg:88.11ms +step:1247/1680 train_time:109873ms step_avg:88.11ms +step:1248/1680 train_time:109963ms step_avg:88.11ms +step:1249/1680 train_time:110052ms step_avg:88.11ms +step:1250/1680 train_time:110142ms step_avg:88.11ms +step:1250/1680 val_loss:3.3775 train_time:110232ms step_avg:88.19ms +step:1251/1680 train_time:110250ms step_avg:88.13ms +step:1252/1680 train_time:110326ms step_avg:88.12ms +step:1253/1680 train_time:110420ms step_avg:88.12ms +step:1254/1680 train_time:110510ms step_avg:88.13ms +step:1255/1680 train_time:110599ms step_avg:88.13ms +step:1256/1680 train_time:110688ms step_avg:88.13ms +step:1257/1680 train_time:110776ms step_avg:88.13ms +step:1258/1680 train_time:110863ms step_avg:88.13ms +step:1259/1680 train_time:110951ms step_avg:88.13ms +step:1260/1680 train_time:111040ms step_avg:88.13ms +step:1261/1680 train_time:111129ms step_avg:88.13ms +step:1262/1680 train_time:111220ms step_avg:88.13ms +step:1263/1680 train_time:111311ms step_avg:88.13ms +step:1264/1680 train_time:111403ms step_avg:88.14ms +step:1265/1680 train_time:111493ms step_avg:88.14ms +step:1266/1680 train_time:111582ms step_avg:88.14ms +step:1267/1680 train_time:111671ms step_avg:88.14ms +step:1268/1680 train_time:111759ms step_avg:88.14ms +step:1269/1680 train_time:111847ms step_avg:88.14ms +step:1270/1680 train_time:111936ms step_avg:88.14ms +step:1271/1680 train_time:112023ms step_avg:88.14ms +step:1272/1680 train_time:112112ms step_avg:88.14ms +step:1273/1680 train_time:112202ms step_avg:88.14ms +step:1274/1680 train_time:112293ms step_avg:88.14ms +step:1275/1680 train_time:112384ms step_avg:88.14ms +step:1276/1680 train_time:112474ms step_avg:88.15ms +step:1277/1680 train_time:112564ms step_avg:88.15ms +step:1278/1680 train_time:112652ms step_avg:88.15ms +step:1279/1680 train_time:112741ms step_avg:88.15ms +step:1280/1680 train_time:112830ms step_avg:88.15ms +step:1281/1680 train_time:112918ms 
step_avg:88.15ms +step:1282/1680 train_time:113008ms step_avg:88.15ms +step:1283/1680 train_time:113097ms step_avg:88.15ms +step:1284/1680 train_time:113186ms step_avg:88.15ms +step:1285/1680 train_time:113276ms step_avg:88.15ms +step:1286/1680 train_time:113366ms step_avg:88.15ms +step:1287/1680 train_time:113457ms step_avg:88.16ms +step:1288/1680 train_time:113546ms step_avg:88.16ms +step:1289/1680 train_time:113636ms step_avg:88.16ms +step:1290/1680 train_time:113725ms step_avg:88.16ms +step:1291/1680 train_time:113814ms step_avg:88.16ms +step:1292/1680 train_time:113902ms step_avg:88.16ms +step:1293/1680 train_time:113991ms step_avg:88.16ms +step:1294/1680 train_time:114080ms step_avg:88.16ms +step:1295/1680 train_time:114170ms step_avg:88.16ms +step:1296/1680 train_time:114260ms step_avg:88.16ms +step:1297/1680 train_time:114350ms step_avg:88.17ms +step:1298/1680 train_time:114439ms step_avg:88.17ms +step:1299/1680 train_time:114529ms step_avg:88.17ms +step:1300/1680 train_time:114617ms step_avg:88.17ms +step:1301/1680 train_time:114707ms step_avg:88.17ms +step:1302/1680 train_time:114798ms step_avg:88.17ms +step:1303/1680 train_time:114887ms step_avg:88.17ms +step:1304/1680 train_time:114976ms step_avg:88.17ms +step:1305/1680 train_time:115064ms step_avg:88.17ms +step:1306/1680 train_time:115154ms step_avg:88.17ms +step:1307/1680 train_time:115242ms step_avg:88.17ms +step:1308/1680 train_time:115332ms step_avg:88.17ms +step:1309/1680 train_time:115421ms step_avg:88.18ms +step:1310/1680 train_time:115512ms step_avg:88.18ms +step:1311/1680 train_time:115601ms step_avg:88.18ms +step:1312/1680 train_time:115690ms step_avg:88.18ms +step:1313/1680 train_time:115779ms step_avg:88.18ms +step:1314/1680 train_time:115868ms step_avg:88.18ms +step:1315/1680 train_time:115957ms step_avg:88.18ms +step:1316/1680 train_time:116046ms step_avg:88.18ms +step:1317/1680 train_time:116135ms step_avg:88.18ms +step:1318/1680 train_time:116225ms step_avg:88.18ms +step:1319/1680 train_time:116314ms step_avg:88.18ms +step:1320/1680 train_time:116404ms step_avg:88.18ms +step:1321/1680 train_time:116494ms step_avg:88.19ms +step:1322/1680 train_time:116583ms step_avg:88.19ms +step:1323/1680 train_time:116672ms step_avg:88.19ms +step:1324/1680 train_time:116761ms step_avg:88.19ms +step:1325/1680 train_time:116850ms step_avg:88.19ms +step:1326/1680 train_time:116939ms step_avg:88.19ms +step:1327/1680 train_time:117028ms step_avg:88.19ms +step:1328/1680 train_time:117118ms step_avg:88.19ms +step:1329/1680 train_time:117207ms step_avg:88.19ms +step:1330/1680 train_time:117297ms step_avg:88.19ms +step:1331/1680 train_time:117387ms step_avg:88.19ms +step:1332/1680 train_time:117476ms step_avg:88.20ms +step:1333/1680 train_time:117565ms step_avg:88.20ms +step:1334/1680 train_time:117655ms step_avg:88.20ms +step:1335/1680 train_time:117744ms step_avg:88.20ms +step:1336/1680 train_time:117833ms step_avg:88.20ms +step:1337/1680 train_time:117923ms step_avg:88.20ms +step:1338/1680 train_time:118012ms step_avg:88.20ms +step:1339/1680 train_time:118102ms step_avg:88.20ms +step:1340/1680 train_time:118192ms step_avg:88.20ms +step:1341/1680 train_time:118281ms step_avg:88.20ms +step:1342/1680 train_time:118370ms step_avg:88.20ms +step:1343/1680 train_time:118460ms step_avg:88.21ms +step:1344/1680 train_time:118549ms step_avg:88.21ms +step:1345/1680 train_time:118638ms step_avg:88.21ms +step:1346/1680 train_time:118728ms step_avg:88.21ms +step:1347/1680 train_time:118818ms step_avg:88.21ms +step:1348/1680 train_time:118907ms 
step_avg:88.21ms +step:1349/1680 train_time:118997ms step_avg:88.21ms +step:1350/1680 train_time:119087ms step_avg:88.21ms +step:1351/1680 train_time:119177ms step_avg:88.21ms +step:1352/1680 train_time:119267ms step_avg:88.22ms +step:1353/1680 train_time:119358ms step_avg:88.22ms +step:1354/1680 train_time:119447ms step_avg:88.22ms +step:1355/1680 train_time:119536ms step_avg:88.22ms +step:1356/1680 train_time:119625ms step_avg:88.22ms +step:1357/1680 train_time:119715ms step_avg:88.22ms +step:1358/1680 train_time:119805ms step_avg:88.22ms +step:1359/1680 train_time:119893ms step_avg:88.22ms +step:1360/1680 train_time:119982ms step_avg:88.22ms +step:1361/1680 train_time:120071ms step_avg:88.22ms +step:1362/1680 train_time:120161ms step_avg:88.22ms +step:1363/1680 train_time:120250ms step_avg:88.22ms +step:1364/1680 train_time:120340ms step_avg:88.23ms +step:1365/1680 train_time:120428ms step_avg:88.23ms +step:1366/1680 train_time:120517ms step_avg:88.23ms +step:1367/1680 train_time:120607ms step_avg:88.23ms +step:1368/1680 train_time:120697ms step_avg:88.23ms +step:1369/1680 train_time:120787ms step_avg:88.23ms +step:1370/1680 train_time:120876ms step_avg:88.23ms +step:1371/1680 train_time:120965ms step_avg:88.23ms +step:1372/1680 train_time:121054ms step_avg:88.23ms +step:1373/1680 train_time:121144ms step_avg:88.23ms +step:1374/1680 train_time:121233ms step_avg:88.23ms +step:1375/1680 train_time:121323ms step_avg:88.23ms +step:1375/1680 val_loss:3.3429 train_time:121413ms step_avg:88.30ms +step:1376/1680 train_time:121431ms step_avg:88.25ms +step:1377/1680 train_time:121504ms step_avg:88.24ms +step:1378/1680 train_time:121594ms step_avg:88.24ms +step:1379/1680 train_time:121683ms step_avg:88.24ms +step:1380/1680 train_time:121771ms step_avg:88.24ms +step:1381/1680 train_time:121860ms step_avg:88.24ms +step:1382/1680 train_time:121948ms step_avg:88.24ms +step:1383/1680 train_time:122037ms step_avg:88.24ms +step:1384/1680 train_time:122126ms step_avg:88.24ms +step:1385/1680 train_time:122214ms step_avg:88.24ms +step:1386/1680 train_time:122304ms step_avg:88.24ms +step:1387/1680 train_time:122394ms step_avg:88.24ms +step:1388/1680 train_time:122485ms step_avg:88.25ms +step:1389/1680 train_time:122576ms step_avg:88.25ms +step:1390/1680 train_time:122666ms step_avg:88.25ms +step:1391/1680 train_time:122755ms step_avg:88.25ms +step:1392/1680 train_time:122844ms step_avg:88.25ms +step:1393/1680 train_time:122933ms step_avg:88.25ms +step:1394/1680 train_time:123021ms step_avg:88.25ms +step:1395/1680 train_time:123110ms step_avg:88.25ms +step:1396/1680 train_time:123198ms step_avg:88.25ms +step:1397/1680 train_time:123287ms step_avg:88.25ms +step:1398/1680 train_time:123378ms step_avg:88.25ms +step:1399/1680 train_time:123468ms step_avg:88.25ms +step:1400/1680 train_time:123558ms step_avg:88.26ms +step:1401/1680 train_time:123648ms step_avg:88.26ms +step:1402/1680 train_time:123737ms step_avg:88.26ms +step:1403/1680 train_time:123827ms step_avg:88.26ms +step:1404/1680 train_time:123916ms step_avg:88.26ms +step:1405/1680 train_time:124004ms step_avg:88.26ms +step:1406/1680 train_time:124092ms step_avg:88.26ms +step:1407/1680 train_time:124181ms step_avg:88.26ms +step:1408/1680 train_time:124270ms step_avg:88.26ms +step:1409/1680 train_time:124360ms step_avg:88.26ms +step:1410/1680 train_time:124451ms step_avg:88.26ms +step:1411/1680 train_time:124540ms step_avg:88.26ms +step:1412/1680 train_time:124631ms step_avg:88.27ms +step:1413/1680 train_time:124720ms step_avg:88.27ms +step:1414/1680 
train_time:124809ms step_avg:88.27ms +step:1415/1680 train_time:124898ms step_avg:88.27ms +step:1416/1680 train_time:124986ms step_avg:88.27ms +step:1417/1680 train_time:125075ms step_avg:88.27ms +step:1418/1680 train_time:125164ms step_avg:88.27ms +step:1419/1680 train_time:125253ms step_avg:88.27ms +step:1420/1680 train_time:125342ms step_avg:88.27ms +step:1421/1680 train_time:125431ms step_avg:88.27ms +step:1422/1680 train_time:125521ms step_avg:88.27ms +step:1423/1680 train_time:125611ms step_avg:88.27ms +step:1424/1680 train_time:125700ms step_avg:88.27ms +step:1425/1680 train_time:125790ms step_avg:88.27ms +step:1426/1680 train_time:125879ms step_avg:88.27ms +step:1427/1680 train_time:125969ms step_avg:88.28ms +step:1428/1680 train_time:126058ms step_avg:88.28ms +step:1429/1680 train_time:126147ms step_avg:88.28ms +step:1430/1680 train_time:126236ms step_avg:88.28ms +step:1431/1680 train_time:126325ms step_avg:88.28ms +step:1432/1680 train_time:126414ms step_avg:88.28ms +step:1433/1680 train_time:126504ms step_avg:88.28ms +step:1434/1680 train_time:126593ms step_avg:88.28ms +step:1435/1680 train_time:126682ms step_avg:88.28ms +step:1436/1680 train_time:126772ms step_avg:88.28ms +step:1437/1680 train_time:126862ms step_avg:88.28ms +step:1438/1680 train_time:126951ms step_avg:88.28ms +step:1439/1680 train_time:127040ms step_avg:88.28ms +step:1440/1680 train_time:127129ms step_avg:88.28ms +step:1441/1680 train_time:127218ms step_avg:88.28ms +step:1442/1680 train_time:127306ms step_avg:88.28ms +step:1443/1680 train_time:127396ms step_avg:88.29ms +step:1444/1680 train_time:127485ms step_avg:88.29ms +step:1445/1680 train_time:127574ms step_avg:88.29ms +step:1446/1680 train_time:127664ms step_avg:88.29ms +step:1447/1680 train_time:127753ms step_avg:88.29ms +step:1448/1680 train_time:127843ms step_avg:88.29ms +step:1449/1680 train_time:127933ms step_avg:88.29ms +step:1450/1680 train_time:128023ms step_avg:88.29ms +step:1451/1680 train_time:128112ms step_avg:88.29ms +step:1452/1680 train_time:128201ms step_avg:88.29ms +step:1453/1680 train_time:128290ms step_avg:88.29ms +step:1454/1680 train_time:128379ms step_avg:88.29ms +step:1455/1680 train_time:128469ms step_avg:88.29ms +step:1456/1680 train_time:128559ms step_avg:88.30ms +step:1457/1680 train_time:128648ms step_avg:88.30ms +step:1458/1680 train_time:128738ms step_avg:88.30ms +step:1459/1680 train_time:128828ms step_avg:88.30ms +step:1460/1680 train_time:128917ms step_avg:88.30ms +step:1461/1680 train_time:129007ms step_avg:88.30ms +step:1462/1680 train_time:129096ms step_avg:88.30ms +step:1463/1680 train_time:129185ms step_avg:88.30ms +step:1464/1680 train_time:129274ms step_avg:88.30ms +step:1465/1680 train_time:129363ms step_avg:88.30ms +step:1466/1680 train_time:129452ms step_avg:88.30ms +step:1467/1680 train_time:129541ms step_avg:88.30ms +step:1468/1680 train_time:129631ms step_avg:88.30ms +step:1469/1680 train_time:129721ms step_avg:88.31ms +step:1470/1680 train_time:129809ms step_avg:88.31ms +step:1471/1680 train_time:129898ms step_avg:88.31ms +step:1472/1680 train_time:129988ms step_avg:88.31ms +step:1473/1680 train_time:130077ms step_avg:88.31ms +step:1474/1680 train_time:130167ms step_avg:88.31ms +step:1475/1680 train_time:130256ms step_avg:88.31ms +step:1476/1680 train_time:130346ms step_avg:88.31ms +step:1477/1680 train_time:130435ms step_avg:88.31ms +step:1478/1680 train_time:130524ms step_avg:88.31ms +step:1479/1680 train_time:130613ms step_avg:88.31ms +step:1480/1680 train_time:130703ms step_avg:88.31ms +step:1481/1680 
train_time:130792ms step_avg:88.31ms +step:1482/1680 train_time:130881ms step_avg:88.31ms +step:1483/1680 train_time:130971ms step_avg:88.31ms +step:1484/1680 train_time:131059ms step_avg:88.31ms +step:1485/1680 train_time:131150ms step_avg:88.32ms +step:1486/1680 train_time:131240ms step_avg:88.32ms +step:1487/1680 train_time:131329ms step_avg:88.32ms +step:1488/1680 train_time:131417ms step_avg:88.32ms +step:1489/1680 train_time:131507ms step_avg:88.32ms +step:1490/1680 train_time:131595ms step_avg:88.32ms +step:1491/1680 train_time:131685ms step_avg:88.32ms +step:1492/1680 train_time:131774ms step_avg:88.32ms +step:1493/1680 train_time:131864ms step_avg:88.32ms +step:1494/1680 train_time:131953ms step_avg:88.32ms +step:1495/1680 train_time:132042ms step_avg:88.32ms +step:1496/1680 train_time:132131ms step_avg:88.32ms +step:1497/1680 train_time:132220ms step_avg:88.32ms +step:1498/1680 train_time:132310ms step_avg:88.32ms +step:1499/1680 train_time:132400ms step_avg:88.33ms +step:1500/1680 train_time:132490ms step_avg:88.33ms +step:1500/1680 val_loss:3.3130 train_time:132580ms step_avg:88.39ms +step:1501/1680 train_time:132598ms step_avg:88.34ms +step:1502/1680 train_time:132673ms step_avg:88.33ms +step:1503/1680 train_time:132766ms step_avg:88.33ms +step:1504/1680 train_time:132857ms step_avg:88.34ms +step:1505/1680 train_time:132946ms step_avg:88.34ms +step:1506/1680 train_time:133034ms step_avg:88.34ms +step:1507/1680 train_time:133122ms step_avg:88.34ms +step:1508/1680 train_time:133211ms step_avg:88.34ms +step:1509/1680 train_time:133299ms step_avg:88.34ms +step:1510/1680 train_time:133388ms step_avg:88.34ms +step:1511/1680 train_time:133477ms step_avg:88.34ms +step:1512/1680 train_time:133568ms step_avg:88.34ms +step:1513/1680 train_time:133658ms step_avg:88.34ms +step:1514/1680 train_time:133750ms step_avg:88.34ms +step:1515/1680 train_time:133841ms step_avg:88.34ms +step:1516/1680 train_time:133930ms step_avg:88.34ms +step:1517/1680 train_time:134018ms step_avg:88.34ms +step:1518/1680 train_time:134108ms step_avg:88.35ms +step:1519/1680 train_time:134197ms step_avg:88.35ms +step:1520/1680 train_time:134286ms step_avg:88.35ms +step:1521/1680 train_time:134374ms step_avg:88.35ms +step:1522/1680 train_time:134463ms step_avg:88.35ms +step:1523/1680 train_time:134552ms step_avg:88.35ms +step:1524/1680 train_time:134642ms step_avg:88.35ms +step:1525/1680 train_time:134733ms step_avg:88.35ms +step:1526/1680 train_time:134823ms step_avg:88.35ms +step:1527/1680 train_time:134913ms step_avg:88.35ms +step:1528/1680 train_time:135003ms step_avg:88.35ms +step:1529/1680 train_time:135092ms step_avg:88.35ms +step:1530/1680 train_time:135181ms step_avg:88.35ms +step:1531/1680 train_time:135270ms step_avg:88.35ms +step:1532/1680 train_time:135359ms step_avg:88.35ms +step:1533/1680 train_time:135448ms step_avg:88.35ms +step:1534/1680 train_time:135537ms step_avg:88.36ms +step:1535/1680 train_time:135627ms step_avg:88.36ms +step:1536/1680 train_time:135716ms step_avg:88.36ms +step:1537/1680 train_time:135806ms step_avg:88.36ms +step:1538/1680 train_time:135896ms step_avg:88.36ms +step:1539/1680 train_time:135985ms step_avg:88.36ms +step:1540/1680 train_time:136074ms step_avg:88.36ms +step:1541/1680 train_time:136162ms step_avg:88.36ms +step:1542/1680 train_time:136251ms step_avg:88.36ms +step:1543/1680 train_time:136340ms step_avg:88.36ms +step:1544/1680 train_time:136429ms step_avg:88.36ms +step:1545/1680 train_time:136519ms step_avg:88.36ms +step:1546/1680 train_time:136609ms step_avg:88.36ms 
+step:1547/1680 train_time:136697ms step_avg:88.36ms +step:1548/1680 train_time:136787ms step_avg:88.36ms +step:1549/1680 train_time:136876ms step_avg:88.36ms +step:1550/1680 train_time:136966ms step_avg:88.37ms +step:1551/1680 train_time:137055ms step_avg:88.37ms +step:1552/1680 train_time:137144ms step_avg:88.37ms +step:1553/1680 train_time:137233ms step_avg:88.37ms +step:1554/1680 train_time:137321ms step_avg:88.37ms +step:1555/1680 train_time:137411ms step_avg:88.37ms +step:1556/1680 train_time:137500ms step_avg:88.37ms +step:1557/1680 train_time:137590ms step_avg:88.37ms +step:1558/1680 train_time:137680ms step_avg:88.37ms +step:1559/1680 train_time:137772ms step_avg:88.37ms +step:1560/1680 train_time:137863ms step_avg:88.37ms +step:1561/1680 train_time:137952ms step_avg:88.37ms +step:1562/1680 train_time:138041ms step_avg:88.37ms +step:1563/1680 train_time:138131ms step_avg:88.38ms +step:1564/1680 train_time:138219ms step_avg:88.38ms +step:1565/1680 train_time:138308ms step_avg:88.38ms +step:1566/1680 train_time:138397ms step_avg:88.38ms +step:1567/1680 train_time:138487ms step_avg:88.38ms +step:1568/1680 train_time:138576ms step_avg:88.38ms +step:1569/1680 train_time:138665ms step_avg:88.38ms +step:1570/1680 train_time:138755ms step_avg:88.38ms +step:1571/1680 train_time:138844ms step_avg:88.38ms +step:1572/1680 train_time:138932ms step_avg:88.38ms +step:1573/1680 train_time:139022ms step_avg:88.38ms +step:1574/1680 train_time:139112ms step_avg:88.38ms +step:1575/1680 train_time:139201ms step_avg:88.38ms +step:1576/1680 train_time:139290ms step_avg:88.38ms +step:1577/1680 train_time:139380ms step_avg:88.38ms +step:1578/1680 train_time:139470ms step_avg:88.38ms +step:1579/1680 train_time:139559ms step_avg:88.38ms +step:1580/1680 train_time:139648ms step_avg:88.38ms +step:1581/1680 train_time:139737ms step_avg:88.39ms +step:1582/1680 train_time:139826ms step_avg:88.39ms +step:1583/1680 train_time:139915ms step_avg:88.39ms +step:1584/1680 train_time:140005ms step_avg:88.39ms +step:1585/1680 train_time:140094ms step_avg:88.39ms +step:1586/1680 train_time:140184ms step_avg:88.39ms +step:1587/1680 train_time:140273ms step_avg:88.39ms +step:1588/1680 train_time:140363ms step_avg:88.39ms +step:1589/1680 train_time:140453ms step_avg:88.39ms +step:1590/1680 train_time:140543ms step_avg:88.39ms +step:1591/1680 train_time:140633ms step_avg:88.39ms +step:1592/1680 train_time:140721ms step_avg:88.39ms +step:1593/1680 train_time:140812ms step_avg:88.39ms +step:1594/1680 train_time:140901ms step_avg:88.39ms +step:1595/1680 train_time:140992ms step_avg:88.40ms +step:1596/1680 train_time:141082ms step_avg:88.40ms +step:1597/1680 train_time:141172ms step_avg:88.40ms +step:1598/1680 train_time:141262ms step_avg:88.40ms +step:1599/1680 train_time:141350ms step_avg:88.40ms +step:1600/1680 train_time:141439ms step_avg:88.40ms +step:1601/1680 train_time:141528ms step_avg:88.40ms +step:1602/1680 train_time:141616ms step_avg:88.40ms +step:1603/1680 train_time:141705ms step_avg:88.40ms +step:1604/1680 train_time:141795ms step_avg:88.40ms +step:1605/1680 train_time:141884ms step_avg:88.40ms +step:1606/1680 train_time:141974ms step_avg:88.40ms +step:1607/1680 train_time:142064ms step_avg:88.40ms +step:1608/1680 train_time:142153ms step_avg:88.40ms +step:1609/1680 train_time:142243ms step_avg:88.40ms +step:1610/1680 train_time:142333ms step_avg:88.41ms +step:1611/1680 train_time:142422ms step_avg:88.41ms +step:1612/1680 train_time:142512ms step_avg:88.41ms +step:1613/1680 train_time:142601ms step_avg:88.41ms 
+step:1614/1680 train_time:142691ms step_avg:88.41ms +step:1615/1680 train_time:142780ms step_avg:88.41ms +step:1616/1680 train_time:142870ms step_avg:88.41ms +step:1617/1680 train_time:142959ms step_avg:88.41ms +step:1618/1680 train_time:143048ms step_avg:88.41ms +step:1619/1680 train_time:143137ms step_avg:88.41ms +step:1620/1680 train_time:143227ms step_avg:88.41ms +step:1621/1680 train_time:143316ms step_avg:88.41ms +step:1622/1680 train_time:143406ms step_avg:88.41ms +step:1623/1680 train_time:143496ms step_avg:88.41ms +step:1624/1680 train_time:143586ms step_avg:88.42ms +step:1625/1680 train_time:143676ms step_avg:88.42ms +step:1625/1680 val_loss:3.2891 train_time:143767ms step_avg:88.47ms +step:1626/1680 train_time:143786ms step_avg:88.43ms +step:1627/1680 train_time:143859ms step_avg:88.42ms +step:1628/1680 train_time:143952ms step_avg:88.42ms +step:1629/1680 train_time:144041ms step_avg:88.42ms +step:1630/1680 train_time:144129ms step_avg:88.42ms +step:1631/1680 train_time:144217ms step_avg:88.42ms +step:1632/1680 train_time:144306ms step_avg:88.42ms +step:1633/1680 train_time:144395ms step_avg:88.42ms +step:1634/1680 train_time:144483ms step_avg:88.42ms +step:1635/1680 train_time:144573ms step_avg:88.42ms +step:1636/1680 train_time:144662ms step_avg:88.42ms +step:1637/1680 train_time:144752ms step_avg:88.43ms +step:1638/1680 train_time:144843ms step_avg:88.43ms +step:1639/1680 train_time:144933ms step_avg:88.43ms +step:1640/1680 train_time:145024ms step_avg:88.43ms +step:1641/1680 train_time:145115ms step_avg:88.43ms +step:1642/1680 train_time:145205ms step_avg:88.43ms +step:1643/1680 train_time:145294ms step_avg:88.43ms +step:1644/1680 train_time:145383ms step_avg:88.43ms +step:1645/1680 train_time:145471ms step_avg:88.43ms +step:1646/1680 train_time:145561ms step_avg:88.43ms +step:1647/1680 train_time:145649ms step_avg:88.43ms +step:1648/1680 train_time:145739ms step_avg:88.43ms +step:1649/1680 train_time:145830ms step_avg:88.44ms +step:1650/1680 train_time:145919ms step_avg:88.44ms +step:1651/1680 train_time:146009ms step_avg:88.44ms +step:1652/1680 train_time:146099ms step_avg:88.44ms +step:1653/1680 train_time:146188ms step_avg:88.44ms +step:1654/1680 train_time:146277ms step_avg:88.44ms +step:1655/1680 train_time:146366ms step_avg:88.44ms +step:1656/1680 train_time:146455ms step_avg:88.44ms +step:1657/1680 train_time:146544ms step_avg:88.44ms +step:1658/1680 train_time:146633ms step_avg:88.44ms +step:1659/1680 train_time:146724ms step_avg:88.44ms +step:1660/1680 train_time:146814ms step_avg:88.44ms +step:1661/1680 train_time:146905ms step_avg:88.44ms +step:1662/1680 train_time:146995ms step_avg:88.44ms +step:1663/1680 train_time:147085ms step_avg:88.45ms +step:1664/1680 train_time:147174ms step_avg:88.45ms +step:1665/1680 train_time:147265ms step_avg:88.45ms +step:1666/1680 train_time:147354ms step_avg:88.45ms +step:1667/1680 train_time:147443ms step_avg:88.45ms +step:1668/1680 train_time:147532ms step_avg:88.45ms +step:1669/1680 train_time:147622ms step_avg:88.45ms +step:1670/1680 train_time:147711ms step_avg:88.45ms +step:1671/1680 train_time:147802ms step_avg:88.45ms +step:1672/1680 train_time:147892ms step_avg:88.45ms +step:1673/1680 train_time:147982ms step_avg:88.45ms +step:1674/1680 train_time:148072ms step_avg:88.45ms +step:1675/1680 train_time:148162ms step_avg:88.45ms +step:1676/1680 train_time:148251ms step_avg:88.46ms +step:1677/1680 train_time:148340ms step_avg:88.46ms +step:1678/1680 train_time:148429ms step_avg:88.46ms +step:1679/1680 train_time:148518ms 
step_avg:88.46ms +step:1680/1680 train_time:148608ms step_avg:88.46ms +step:1680/1680 val_loss:3.2782 train_time:148698ms step_avg:88.51ms +peak memory allocated: 30760 MiB reserved: 45974 MiB diff --git a/records/092725_BF16CE/1711f2cf-76af-46e0-b2df-47bd0d6bec0c.txt b/records/092725_BF16CE/1711f2cf-76af-46e0-b2df-47bd0d6bec0c.txt new file mode 100644 index 000000000..dde98eac7 --- /dev/null +++ b/records/092725_BF16CE/1711f2cf-76af-46e0-b2df-47bd0d6bec0c.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
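+ # (C = A @ A.T is symmetric: blocks strictly on the skipped side of the diagonal are never computed; instead, each computed block is also written, transposed, to its mirror position on the other side. A descriptive note added here; the logic is exactly the skip check above plus the mirrored store below.)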
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
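+ + For intuition, each Newton-Schulz step in newton_schulz_triton applies the quintic map + X = a*X + (b*A + c*A@A) @ X, where A = X @ X.T, + which drives the singular values of the update toward 1 (a sketch of one iteration; see the Triton kernels above for the exact batched computation). +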
+ Though empirically, small 1D params perform efficiently here: + NS approximately performs a magnitude normalization of the grad, and + this hyper-optimized class executes faster than the current Adam impl for small params. + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized so that + (n_attn_layers+n_mlp_layers*2)%4==0, enabling batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param, 7 padding params) + 2. reduce scatter attn_gate (10 params, 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params, 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive params of type attn reshape before NS + 8. wait on step 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size()==8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for the resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1,10,16,16] + assert len(params)==sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx+size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal.
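+ # Flow (matches the code below): (1) launch an async reduce_scatter per param group so each rank + # receives the averaged grads for its shard of the stacked params; (2) as each scatter completes, + # run one batched Newton-Schulz on the local shard and immediately schedule an async all_gather + # of the updated params; (3) wait on the gathers and copy results back into the live parameters.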
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
+                updated_param_chunk[i].copy_(p)
+
+                # Apply weight decay directly to the buffer.
+                updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+            original_shape = batched_update_grads.shape
+            # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O
+            module_idx = start_idx if start_idx < len(params) else 0
+            if params[module_idx].module == 'attn':
+                batched_update_grads = batched_update_grads.view(-1, original_shape[-2], original_shape[-1] // 4)
+            v_chunk = zeropower_via_newtonschulz5(batched_update_grads).view(original_shape)
+
+            # Apply the orthogonalized updates to this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+                if param_idx >= len(params):  # Skip padded entries again.
+                    continue
+
+                # Add the computed zeropower update to the parameter in the buffer.
+                updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val)
+
+            stacked_params = torch.empty(
+                (info["padded_num_params"], *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+            gather_future = dist.all_gather_into_tensor(
+                stacked_params, updated_param_chunk, async_op=True
+            ).get_future()
+
+            all_gather_infos.append(
+                {
+                    "gather_future": gather_future,
+                    "stacked_params": stacked_params,
+                    "orig_params": params,
+                }
+            )
+
+        # Final pass: wait for all_gather to complete and copy results back into original parameter tensors.
+        for info in all_gather_infos:
+            info["gather_future"].wait()
+            stacked_params = info["stacked_params"]
+            orig_params = info["orig_params"]
+
+            unstacked_params = torch.unbind(stacked_params)
+            for i, p in enumerate(orig_params):
+                p.copy_(unstacked_params[i], non_blocking=True)
+
+
+class DistAdam(torch.optim.Optimizer):
+    # DistributedAdam implementation by @vagrawal: optimizer state is sharded
+    # across ranks along dim 0, with async reduce_scatter of grads and
+    # all_gather of updated params.
+    def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        params = list(params)
+        sizes = {p.shape for p in params}
+        # create one param group per unique parameter size
+        param_groups = []
+        for size in sizes:
+            group_params = [p for p in params if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        super().__init__(param_groups, defaults)
+
+    @torch.compile
+    @torch.no_grad()
+    def step(self):
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        reduce_scatter_futures: list[torch.Future] = []
+        all_gather_futures: list[torch.Future] = []
+        grad_slices = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            for base_i in range(len(params)):
+                grad = params[base_i].grad
+                rank_size = grad.shape[0] // world_size
+                grad_slice = torch.empty_like(grad[:rank_size])
+                reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
+                grad_slices.append(grad_slice)
+
+        idx = 0
+        for group in self.param_groups:
+            beta1, beta2 = group['betas']
+            eps = group['eps']
+            wd = group['weight_decay']
+            params = group['params']
+            for base in range(len(params)):
+                reduce_scatter_futures[idx].wait()
+                p = params[base]
+                rank_size = p.shape[0] // world_size
+                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
+                lr = group['lr'] * getattr(p, "lr_mul", 1.0)
+                state = self.state[p]
+                g_slice = grad_slices[idx]
+                # State init
+                if not state:
+                    state["step"] = torch.tensor(
+                        0, dtype=torch.int64, device=p.device
+                    )
+                    state["exp_avg"] = torch.zeros(
+                        p_slice.shape,
+                        dtype=torch.bfloat16,
+                        device=p_slice.device,
+                    )
+                    state["exp_avg_sq"] = torch.zeros(
+                        p_slice.shape,
+                        dtype=torch.bfloat16,
+                        device=p_slice.device,
+                    )
+                exp_avg = state["exp_avg"]
+                exp_avg_sq = 
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module='attn' + with torch.no_grad(): + self.qkvo_w.view(4,self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4,self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4,self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4,self.hdim, self.dim)[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module='mlp' + self.c_proj.module='mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
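+        # Comment-only sketch of the padding above: next_multiple_of_n(50257, n=128)
+        # returns 50304 (= 128 * 393), i.e. 47 extra rows in embed / lm_head that no
+        # real GPT-2 token ever indexes; dimensions divisible by 128 keep the FP8
+        # head matmul on tensor-core-friendly tile sizes.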
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the 
data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan the first 4 million tokens, then kick off an async thread to scan the rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full BOS index after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS tokens ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading the next shard and indexing its BOS tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning-of-Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # trim the final sequence by one token to account for the _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local, )
+            _targets = buf[1:].view(num_tokens_local, )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40 # number of iterations to continue training at the final cooldown and window size
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss?
0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = min(0.9999, step / args.num_iterations)
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+def get_ws(step: int):
+    if step == args.num_iterations + args.iteration_extension:
+        return args.ws_validate // 2, args.ws_validate
+    x = min(step / (1 + args.num_iterations), 0.9999)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws_long = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws_long > ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    elif new_ws_long < ws_long:
+        model.yarn.reset() # schedule wrapped around: restore the original RoPE state
+        ws_long = new_ws_long
+    model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+
+# re-initialize everything the warmup may have touched
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+model.yarn.reset()
+del initial_state
+
+########################################
+#       Training and validation        #
+########################################
+
+train_steps = args.num_iterations + args.iteration_extension
+ws_short, ws_long = get_ws(0)
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    new_ws_short, new_ws_long = get_ws(step)
+    if new_ws_long != ws_long:
+        model.yarn.apply(ws_long, new_ws_long) # extend RoPE via YaRN when the window schedule changes
+    ws_short, ws_long = new_ws_short, new_ws_long
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        if last_step:
+            ws_long = args.ws_long_validate
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = torch.zeros((), device=device, dtype=torch.float32)
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
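+        # one forward/backward per micro-batch; grad_accum_steps = 8 // world_size,
+        # so the total tokens per optimizer step match the 8-GPU configuration on
+        # any divisor of 8 GPUs. Gradients simply accumulate in .grad across the
+        # micro-steps; the optimizers then average them across ranks via
+        # reduce_scatter before applying the update.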
inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Sep 27 12:46:24 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 28C P0 120W / 700W | 5856MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 25C P0 118W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 23C P0 116W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 27C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 27C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 26C P0 114W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 28C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 25C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 162824 C /usr/bin/python 0MiB | +| 0 N/A N/A 162825 C /usr/bin/python 0MiB | +| 0 N/A N/A 162826 C /usr/bin/python 0MiB | +| 0 N/A N/A 162827 C /usr/bin/python 0MiB | +| 0 N/A N/A 162828 C /usr/bin/python 0MiB | +| 0 N/A N/A 162829 C /usr/bin/python 0MiB | +| 0 N/A N/A 162830 C /usr/bin/python 0MiB | +| 0 N/A N/A 162831 C /usr/bin/python 0MiB | +| 1 N/A N/A 162825 C /usr/bin/python 0MiB | +| 2 N/A N/A 162826 C /usr/bin/python 0MiB | +| 3 N/A N/A 162827 C /usr/bin/python 0MiB | +| 4 N/A N/A 162828 C /usr/bin/python 0MiB | +| 5 N/A N/A 162829 C /usr/bin/python 0MiB | +| 6 N/A N/A 162830 C /usr/bin/python 0MiB | +| 7 N/A N/A 162831 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1680 train_time:144ms step_avg:144.30ms +step:2/1680 train_time:164ms step_avg:82.21ms +step:3/1680 train_time:228ms step_avg:76.15ms +step:4/1680 train_time:313ms step_avg:78.37ms +step:5/1680 train_time:399ms step_avg:79.85ms +step:6/1680 train_time:485ms step_avg:80.87ms +step:7/1680 train_time:572ms step_avg:81.76ms +step:8/1680 train_time:658ms step_avg:82.27ms +step:9/1680 train_time:744ms step_avg:82.70ms +step:10/1680 train_time:830ms step_avg:83.03ms +step:11/1680 train_time:916ms step_avg:83.32ms +step:12/1680 train_time:1004ms step_avg:83.68ms +step:13/1680 train_time:1094ms step_avg:84.18ms +step:14/1680 train_time:1185ms step_avg:84.62ms +step:15/1680 train_time:1273ms step_avg:84.84ms +step:16/1680 train_time:1359ms step_avg:84.96ms +step:17/1680 train_time:1446ms step_avg:85.07ms +step:18/1680 train_time:1533ms step_avg:85.14ms +step:19/1680 train_time:1619ms step_avg:85.23ms +step:20/1680 train_time:1706ms step_avg:85.29ms +step:21/1680 train_time:1792ms step_avg:85.35ms +step:22/1680 train_time:1879ms step_avg:85.40ms +step:23/1680 train_time:1965ms step_avg:85.45ms +step:24/1680 train_time:2054ms step_avg:85.60ms +step:25/1680 train_time:2143ms step_avg:85.71ms +step:26/1680 train_time:2231ms step_avg:85.82ms +step:27/1680 train_time:2319ms step_avg:85.90ms +step:28/1680 train_time:2407ms step_avg:85.97ms +step:29/1680 train_time:2494ms step_avg:86.01ms +step:30/1680 train_time:2581ms step_avg:86.04ms +step:31/1680 train_time:2668ms step_avg:86.05ms +step:32/1680 train_time:2754ms step_avg:86.07ms +step:33/1680 train_time:2841ms step_avg:86.10ms +step:34/1680 train_time:2928ms step_avg:86.12ms +step:35/1680 train_time:3016ms step_avg:86.17ms +step:36/1680 train_time:3104ms step_avg:86.21ms +step:37/1680 train_time:3193ms step_avg:86.30ms +step:38/1680 train_time:3280ms step_avg:86.32ms +step:39/1680 train_time:3367ms step_avg:86.34ms +step:40/1680 train_time:3456ms step_avg:86.39ms +step:41/1680 train_time:3543ms step_avg:86.41ms +step:42/1680 train_time:3630ms step_avg:86.42ms +step:43/1680 train_time:3716ms step_avg:86.43ms +step:44/1680 train_time:3803ms step_avg:86.43ms +step:45/1680 train_time:3890ms step_avg:86.44ms +step:46/1680 train_time:3977ms step_avg:86.46ms +step:47/1680 
train_time:4065ms step_avg:86.49ms +step:48/1680 train_time:4153ms step_avg:86.53ms +step:49/1680 train_time:4241ms step_avg:86.56ms +step:50/1680 train_time:4328ms step_avg:86.56ms +step:51/1680 train_time:4416ms step_avg:86.58ms +step:52/1680 train_time:4503ms step_avg:86.59ms +step:53/1680 train_time:4589ms step_avg:86.59ms +step:54/1680 train_time:4676ms step_avg:86.59ms +step:55/1680 train_time:4763ms step_avg:86.60ms +step:56/1680 train_time:4850ms step_avg:86.61ms +step:57/1680 train_time:4937ms step_avg:86.62ms +step:58/1680 train_time:5024ms step_avg:86.63ms +step:59/1680 train_time:5111ms step_avg:86.64ms +step:60/1680 train_time:5199ms step_avg:86.65ms +step:61/1680 train_time:5287ms step_avg:86.68ms +step:62/1680 train_time:5375ms step_avg:86.70ms +step:63/1680 train_time:5463ms step_avg:86.71ms +step:64/1680 train_time:5550ms step_avg:86.73ms +step:65/1680 train_time:5638ms step_avg:86.73ms +step:66/1680 train_time:5725ms step_avg:86.74ms +step:67/1680 train_time:5811ms step_avg:86.73ms +step:68/1680 train_time:5899ms step_avg:86.75ms +step:69/1680 train_time:5986ms step_avg:86.75ms +step:70/1680 train_time:6073ms step_avg:86.75ms +step:71/1680 train_time:6160ms step_avg:86.75ms +step:72/1680 train_time:6247ms step_avg:86.77ms +step:73/1680 train_time:6335ms step_avg:86.77ms +step:74/1680 train_time:6422ms step_avg:86.78ms +step:75/1680 train_time:6510ms step_avg:86.80ms +step:76/1680 train_time:6597ms step_avg:86.81ms +step:77/1680 train_time:6684ms step_avg:86.80ms +step:78/1680 train_time:6771ms step_avg:86.81ms +step:79/1680 train_time:6857ms step_avg:86.80ms +step:80/1680 train_time:6945ms step_avg:86.81ms +step:81/1680 train_time:7032ms step_avg:86.81ms +step:82/1680 train_time:7119ms step_avg:86.82ms +step:83/1680 train_time:7207ms step_avg:86.83ms +step:84/1680 train_time:7294ms step_avg:86.84ms +step:85/1680 train_time:7381ms step_avg:86.84ms +step:86/1680 train_time:7468ms step_avg:86.84ms +step:87/1680 train_time:7556ms step_avg:86.85ms +step:88/1680 train_time:7643ms step_avg:86.85ms +step:89/1680 train_time:7730ms step_avg:86.85ms +step:90/1680 train_time:7817ms step_avg:86.85ms +step:91/1680 train_time:7903ms step_avg:86.85ms +step:92/1680 train_time:7990ms step_avg:86.85ms +step:93/1680 train_time:8077ms step_avg:86.85ms +step:94/1680 train_time:8164ms step_avg:86.85ms +step:95/1680 train_time:8250ms step_avg:86.84ms +step:96/1680 train_time:8338ms step_avg:86.85ms +step:97/1680 train_time:8425ms step_avg:86.86ms +step:98/1680 train_time:8513ms step_avg:86.87ms +step:99/1680 train_time:8600ms step_avg:86.87ms +step:100/1680 train_time:8687ms step_avg:86.87ms +step:101/1680 train_time:8774ms step_avg:86.88ms +step:102/1680 train_time:8861ms step_avg:86.87ms +step:103/1680 train_time:8949ms step_avg:86.89ms +step:104/1680 train_time:9036ms step_avg:86.89ms +step:105/1680 train_time:9123ms step_avg:86.88ms +step:106/1680 train_time:9210ms step_avg:86.89ms +step:107/1680 train_time:9298ms step_avg:86.90ms +step:108/1680 train_time:9385ms step_avg:86.90ms +step:109/1680 train_time:9472ms step_avg:86.90ms +step:110/1680 train_time:9559ms step_avg:86.90ms +step:111/1680 train_time:9646ms step_avg:86.90ms +step:112/1680 train_time:9734ms step_avg:86.91ms +step:113/1680 train_time:9821ms step_avg:86.91ms +step:114/1680 train_time:9908ms step_avg:86.91ms +step:115/1680 train_time:9995ms step_avg:86.91ms +step:116/1680 train_time:10081ms step_avg:86.91ms +step:117/1680 train_time:10168ms step_avg:86.91ms +step:118/1680 train_time:10256ms step_avg:86.91ms +step:119/1680 
train_time:10342ms step_avg:86.91ms +step:120/1680 train_time:10430ms step_avg:86.91ms +step:121/1680 train_time:10517ms step_avg:86.92ms +step:122/1680 train_time:10604ms step_avg:86.92ms +step:123/1680 train_time:10691ms step_avg:86.92ms +step:124/1680 train_time:10778ms step_avg:86.92ms +step:125/1680 train_time:10866ms step_avg:86.93ms +step:125/1680 val_loss:4.3230 train_time:10955ms step_avg:87.64ms +step:126/1680 train_time:10974ms step_avg:87.10ms +step:127/1680 train_time:11044ms step_avg:86.96ms +step:128/1680 train_time:11140ms step_avg:87.03ms +step:129/1680 train_time:11232ms step_avg:87.07ms +step:130/1680 train_time:11319ms step_avg:87.07ms +step:131/1680 train_time:11405ms step_avg:87.06ms +step:132/1680 train_time:11492ms step_avg:87.06ms +step:133/1680 train_time:11578ms step_avg:87.05ms +step:134/1680 train_time:11664ms step_avg:87.04ms +step:135/1680 train_time:11750ms step_avg:87.03ms +step:136/1680 train_time:11835ms step_avg:87.02ms +step:137/1680 train_time:11921ms step_avg:87.01ms +step:138/1680 train_time:12007ms step_avg:87.01ms +step:139/1680 train_time:12096ms step_avg:87.02ms +step:140/1680 train_time:12186ms step_avg:87.04ms +step:141/1680 train_time:12274ms step_avg:87.05ms +step:142/1680 train_time:12362ms step_avg:87.05ms +step:143/1680 train_time:12448ms step_avg:87.05ms +step:144/1680 train_time:12535ms step_avg:87.05ms +step:145/1680 train_time:12622ms step_avg:87.05ms +step:146/1680 train_time:12707ms step_avg:87.04ms +step:147/1680 train_time:12794ms step_avg:87.03ms +step:148/1680 train_time:12881ms step_avg:87.03ms +step:149/1680 train_time:12967ms step_avg:87.02ms +step:150/1680 train_time:13055ms step_avg:87.03ms +step:151/1680 train_time:13143ms step_avg:87.04ms +step:152/1680 train_time:13231ms step_avg:87.05ms +step:153/1680 train_time:13318ms step_avg:87.05ms +step:154/1680 train_time:13406ms step_avg:87.05ms +step:155/1680 train_time:13493ms step_avg:87.05ms +step:156/1680 train_time:13580ms step_avg:87.05ms +step:157/1680 train_time:13666ms step_avg:87.05ms +step:158/1680 train_time:13754ms step_avg:87.05ms +step:159/1680 train_time:13840ms step_avg:87.04ms +step:160/1680 train_time:13926ms step_avg:87.04ms +step:161/1680 train_time:14013ms step_avg:87.04ms +step:162/1680 train_time:14101ms step_avg:87.04ms +step:163/1680 train_time:14189ms step_avg:87.05ms +step:164/1680 train_time:14277ms step_avg:87.05ms +step:165/1680 train_time:14364ms step_avg:87.05ms +step:166/1680 train_time:14452ms step_avg:87.06ms +step:167/1680 train_time:14538ms step_avg:87.06ms +step:168/1680 train_time:14625ms step_avg:87.05ms +step:169/1680 train_time:14712ms step_avg:87.05ms +step:170/1680 train_time:14799ms step_avg:87.05ms +step:171/1680 train_time:14886ms step_avg:87.05ms +step:172/1680 train_time:14973ms step_avg:87.05ms +step:173/1680 train_time:15060ms step_avg:87.05ms +step:174/1680 train_time:15148ms step_avg:87.06ms +step:175/1680 train_time:15236ms step_avg:87.06ms +step:176/1680 train_time:15324ms step_avg:87.07ms +step:177/1680 train_time:15412ms step_avg:87.07ms +step:178/1680 train_time:15499ms step_avg:87.07ms +step:179/1680 train_time:15586ms step_avg:87.07ms +step:180/1680 train_time:15673ms step_avg:87.07ms +step:181/1680 train_time:15759ms step_avg:87.07ms +step:182/1680 train_time:15846ms step_avg:87.07ms +step:183/1680 train_time:15933ms step_avg:87.07ms +step:184/1680 train_time:16019ms step_avg:87.06ms +step:185/1680 train_time:16107ms step_avg:87.07ms +step:186/1680 train_time:16195ms step_avg:87.07ms +step:187/1680 train_time:16282ms 
step_avg:87.07ms +step:188/1680 train_time:16369ms step_avg:87.07ms +step:189/1680 train_time:16456ms step_avg:87.07ms +step:190/1680 train_time:16543ms step_avg:87.07ms +step:191/1680 train_time:16631ms step_avg:87.07ms +step:192/1680 train_time:16717ms step_avg:87.07ms +step:193/1680 train_time:16804ms step_avg:87.07ms +step:194/1680 train_time:16890ms step_avg:87.06ms +step:195/1680 train_time:16977ms step_avg:87.06ms +step:196/1680 train_time:17064ms step_avg:87.06ms +step:197/1680 train_time:17152ms step_avg:87.07ms +step:198/1680 train_time:17239ms step_avg:87.07ms +step:199/1680 train_time:17326ms step_avg:87.07ms +step:200/1680 train_time:17414ms step_avg:87.07ms +step:201/1680 train_time:17500ms step_avg:87.06ms +step:202/1680 train_time:17587ms step_avg:87.07ms +step:203/1680 train_time:17675ms step_avg:87.07ms +step:204/1680 train_time:17761ms step_avg:87.07ms +step:205/1680 train_time:17848ms step_avg:87.06ms +step:206/1680 train_time:17935ms step_avg:87.06ms +step:207/1680 train_time:18022ms step_avg:87.06ms +step:208/1680 train_time:18110ms step_avg:87.07ms +step:209/1680 train_time:18197ms step_avg:87.07ms +step:210/1680 train_time:18284ms step_avg:87.07ms +step:211/1680 train_time:18372ms step_avg:87.07ms +step:212/1680 train_time:18458ms step_avg:87.07ms +step:213/1680 train_time:18546ms step_avg:87.07ms +step:214/1680 train_time:18633ms step_avg:87.07ms +step:215/1680 train_time:18719ms step_avg:87.07ms +step:216/1680 train_time:18807ms step_avg:87.07ms +step:217/1680 train_time:18894ms step_avg:87.07ms +step:218/1680 train_time:18980ms step_avg:87.07ms +step:219/1680 train_time:19067ms step_avg:87.06ms +step:220/1680 train_time:19154ms step_avg:87.07ms +step:221/1680 train_time:19242ms step_avg:87.07ms +step:222/1680 train_time:19329ms step_avg:87.07ms +step:223/1680 train_time:19415ms step_avg:87.06ms +step:224/1680 train_time:19502ms step_avg:87.06ms +step:225/1680 train_time:19589ms step_avg:87.06ms +step:226/1680 train_time:19676ms step_avg:87.06ms +step:227/1680 train_time:19763ms step_avg:87.06ms +step:228/1680 train_time:19850ms step_avg:87.06ms +step:229/1680 train_time:19937ms step_avg:87.06ms +step:230/1680 train_time:20024ms step_avg:87.06ms +step:231/1680 train_time:20112ms step_avg:87.06ms +step:232/1680 train_time:20199ms step_avg:87.06ms +step:233/1680 train_time:20285ms step_avg:87.06ms +step:234/1680 train_time:20373ms step_avg:87.06ms +step:235/1680 train_time:20460ms step_avg:87.06ms +step:236/1680 train_time:20546ms step_avg:87.06ms +step:237/1680 train_time:20633ms step_avg:87.06ms +step:238/1680 train_time:20721ms step_avg:87.06ms +step:239/1680 train_time:20808ms step_avg:87.06ms +step:240/1680 train_time:20895ms step_avg:87.06ms +step:241/1680 train_time:20982ms step_avg:87.06ms +step:242/1680 train_time:21069ms step_avg:87.06ms +step:243/1680 train_time:21156ms step_avg:87.06ms +step:244/1680 train_time:21243ms step_avg:87.06ms +step:245/1680 train_time:21331ms step_avg:87.06ms +step:246/1680 train_time:21417ms step_avg:87.06ms +step:247/1680 train_time:21504ms step_avg:87.06ms +step:248/1680 train_time:21592ms step_avg:87.06ms +step:249/1680 train_time:21679ms step_avg:87.06ms +step:250/1680 train_time:21767ms step_avg:87.07ms +step:250/1680 val_loss:3.9722 train_time:21856ms step_avg:87.42ms +step:251/1680 train_time:21877ms step_avg:87.16ms +step:252/1680 train_time:21945ms step_avg:87.08ms +step:253/1680 train_time:22038ms step_avg:87.11ms +step:254/1680 train_time:22126ms step_avg:87.11ms +step:255/1680 train_time:22213ms step_avg:87.11ms 
+step:256/1680 train_time:22300ms step_avg:87.11ms +step:257/1680 train_time:22386ms step_avg:87.10ms +step:258/1680 train_time:22472ms step_avg:87.10ms +step:259/1680 train_time:22558ms step_avg:87.10ms +step:260/1680 train_time:22644ms step_avg:87.09ms +step:261/1680 train_time:22730ms step_avg:87.09ms +step:262/1680 train_time:22818ms step_avg:87.09ms +step:263/1680 train_time:22906ms step_avg:87.10ms +step:264/1680 train_time:22996ms step_avg:87.11ms +step:265/1680 train_time:23085ms step_avg:87.11ms +step:266/1680 train_time:23173ms step_avg:87.12ms +step:267/1680 train_time:23260ms step_avg:87.12ms +step:268/1680 train_time:23346ms step_avg:87.11ms +step:269/1680 train_time:23433ms step_avg:87.11ms +step:270/1680 train_time:23519ms step_avg:87.11ms +step:271/1680 train_time:23605ms step_avg:87.10ms +step:272/1680 train_time:23691ms step_avg:87.10ms +step:273/1680 train_time:23778ms step_avg:87.10ms +step:274/1680 train_time:23865ms step_avg:87.10ms +step:275/1680 train_time:23953ms step_avg:87.10ms +step:276/1680 train_time:24042ms step_avg:87.11ms +step:277/1680 train_time:24131ms step_avg:87.12ms +step:278/1680 train_time:24218ms step_avg:87.11ms +step:279/1680 train_time:24305ms step_avg:87.12ms +step:280/1680 train_time:24392ms step_avg:87.11ms +step:281/1680 train_time:24479ms step_avg:87.11ms +step:282/1680 train_time:24565ms step_avg:87.11ms +step:283/1680 train_time:24652ms step_avg:87.11ms +step:284/1680 train_time:24738ms step_avg:87.11ms +step:285/1680 train_time:24825ms step_avg:87.11ms +step:286/1680 train_time:24913ms step_avg:87.11ms +step:287/1680 train_time:25000ms step_avg:87.11ms +step:288/1680 train_time:25089ms step_avg:87.11ms +step:289/1680 train_time:25176ms step_avg:87.12ms +step:290/1680 train_time:25264ms step_avg:87.12ms +step:291/1680 train_time:25350ms step_avg:87.11ms +step:292/1680 train_time:25437ms step_avg:87.11ms +step:293/1680 train_time:25523ms step_avg:87.11ms +step:294/1680 train_time:25610ms step_avg:87.11ms +step:295/1680 train_time:25696ms step_avg:87.11ms +step:296/1680 train_time:25783ms step_avg:87.10ms +step:297/1680 train_time:25870ms step_avg:87.11ms +step:298/1680 train_time:25958ms step_avg:87.11ms +step:299/1680 train_time:26046ms step_avg:87.11ms +step:300/1680 train_time:26133ms step_avg:87.11ms +step:301/1680 train_time:26220ms step_avg:87.11ms +step:302/1680 train_time:26307ms step_avg:87.11ms +step:303/1680 train_time:26394ms step_avg:87.11ms +step:304/1680 train_time:26481ms step_avg:87.11ms +step:305/1680 train_time:26569ms step_avg:87.11ms +step:306/1680 train_time:26655ms step_avg:87.11ms +step:307/1680 train_time:26742ms step_avg:87.11ms +step:308/1680 train_time:26829ms step_avg:87.11ms +step:309/1680 train_time:26917ms step_avg:87.11ms +step:310/1680 train_time:27004ms step_avg:87.11ms +step:311/1680 train_time:27092ms step_avg:87.11ms +step:312/1680 train_time:27179ms step_avg:87.11ms +step:313/1680 train_time:27266ms step_avg:87.11ms +step:314/1680 train_time:27353ms step_avg:87.11ms +step:315/1680 train_time:27440ms step_avg:87.11ms +step:316/1680 train_time:27526ms step_avg:87.11ms +step:317/1680 train_time:27613ms step_avg:87.11ms +step:318/1680 train_time:27700ms step_avg:87.11ms +step:319/1680 train_time:27788ms step_avg:87.11ms +step:320/1680 train_time:27874ms step_avg:87.11ms +step:321/1680 train_time:27961ms step_avg:87.11ms +step:322/1680 train_time:28048ms step_avg:87.11ms +step:323/1680 train_time:28136ms step_avg:87.11ms +step:324/1680 train_time:28223ms step_avg:87.11ms +step:325/1680 train_time:28311ms 
step_avg:87.11ms +step:326/1680 train_time:28398ms step_avg:87.11ms +step:327/1680 train_time:28485ms step_avg:87.11ms +step:328/1680 train_time:28572ms step_avg:87.11ms +step:329/1680 train_time:28659ms step_avg:87.11ms +step:330/1680 train_time:28746ms step_avg:87.11ms +step:331/1680 train_time:28832ms step_avg:87.11ms +step:332/1680 train_time:28920ms step_avg:87.11ms +step:333/1680 train_time:29007ms step_avg:87.11ms +step:334/1680 train_time:29094ms step_avg:87.11ms +step:335/1680 train_time:29182ms step_avg:87.11ms +step:336/1680 train_time:29269ms step_avg:87.11ms +step:337/1680 train_time:29356ms step_avg:87.11ms +step:338/1680 train_time:29442ms step_avg:87.11ms +step:339/1680 train_time:29529ms step_avg:87.11ms +step:340/1680 train_time:29616ms step_avg:87.11ms +step:341/1680 train_time:29703ms step_avg:87.11ms +step:342/1680 train_time:29790ms step_avg:87.11ms +step:343/1680 train_time:29877ms step_avg:87.11ms +step:344/1680 train_time:29964ms step_avg:87.10ms +step:345/1680 train_time:30051ms step_avg:87.11ms +step:346/1680 train_time:30138ms step_avg:87.10ms +step:347/1680 train_time:30225ms step_avg:87.10ms +step:348/1680 train_time:30313ms step_avg:87.11ms +step:349/1680 train_time:30400ms step_avg:87.11ms +step:350/1680 train_time:30487ms step_avg:87.11ms +step:351/1680 train_time:30574ms step_avg:87.10ms +step:352/1680 train_time:30661ms step_avg:87.10ms +step:353/1680 train_time:30748ms step_avg:87.11ms +step:354/1680 train_time:30836ms step_avg:87.11ms +step:355/1680 train_time:30922ms step_avg:87.10ms +step:356/1680 train_time:31010ms step_avg:87.11ms +step:357/1680 train_time:31097ms step_avg:87.11ms +step:358/1680 train_time:31184ms step_avg:87.11ms +step:359/1680 train_time:31272ms step_avg:87.11ms +step:360/1680 train_time:31359ms step_avg:87.11ms +step:361/1680 train_time:31446ms step_avg:87.11ms +step:362/1680 train_time:31533ms step_avg:87.11ms +step:363/1680 train_time:31620ms step_avg:87.11ms +step:364/1680 train_time:31708ms step_avg:87.11ms +step:365/1680 train_time:31795ms step_avg:87.11ms +step:366/1680 train_time:31882ms step_avg:87.11ms +step:367/1680 train_time:31970ms step_avg:87.11ms +step:368/1680 train_time:32057ms step_avg:87.11ms +step:369/1680 train_time:32144ms step_avg:87.11ms +step:370/1680 train_time:32232ms step_avg:87.11ms +step:371/1680 train_time:32318ms step_avg:87.11ms +step:372/1680 train_time:32404ms step_avg:87.11ms +step:373/1680 train_time:32491ms step_avg:87.11ms +step:374/1680 train_time:32577ms step_avg:87.11ms +step:375/1680 train_time:32665ms step_avg:87.11ms +step:375/1680 val_loss:3.8220 train_time:32753ms step_avg:87.34ms +step:376/1680 train_time:32772ms step_avg:87.16ms +step:377/1680 train_time:32841ms step_avg:87.11ms +step:378/1680 train_time:32930ms step_avg:87.12ms +step:379/1680 train_time:33016ms step_avg:87.11ms +step:380/1680 train_time:33103ms step_avg:87.11ms +step:381/1680 train_time:33189ms step_avg:87.11ms +step:382/1680 train_time:33274ms step_avg:87.11ms +step:383/1680 train_time:33362ms step_avg:87.11ms +step:384/1680 train_time:33448ms step_avg:87.10ms +step:385/1680 train_time:33534ms step_avg:87.10ms +step:386/1680 train_time:33621ms step_avg:87.10ms +step:387/1680 train_time:33709ms step_avg:87.10ms +step:388/1680 train_time:33798ms step_avg:87.11ms +step:389/1680 train_time:33886ms step_avg:87.11ms +step:390/1680 train_time:33973ms step_avg:87.11ms +step:391/1680 train_time:34062ms step_avg:87.11ms +step:392/1680 train_time:34148ms step_avg:87.11ms +step:393/1680 train_time:34235ms step_avg:87.11ms 
+step:394/1680 train_time:34322ms step_avg:87.11ms +step:395/1680 train_time:34408ms step_avg:87.11ms +step:396/1680 train_time:34494ms step_avg:87.11ms +step:397/1680 train_time:34581ms step_avg:87.11ms +step:398/1680 train_time:34668ms step_avg:87.10ms +step:399/1680 train_time:34756ms step_avg:87.11ms +step:400/1680 train_time:34845ms step_avg:87.11ms +step:401/1680 train_time:34933ms step_avg:87.11ms +step:402/1680 train_time:35021ms step_avg:87.12ms +step:403/1680 train_time:35109ms step_avg:87.12ms +step:404/1680 train_time:35195ms step_avg:87.12ms +step:405/1680 train_time:35282ms step_avg:87.12ms +step:406/1680 train_time:35369ms step_avg:87.12ms +step:407/1680 train_time:35455ms step_avg:87.11ms +step:408/1680 train_time:35542ms step_avg:87.11ms +step:409/1680 train_time:35629ms step_avg:87.11ms +step:410/1680 train_time:35717ms step_avg:87.11ms +step:411/1680 train_time:35804ms step_avg:87.11ms +step:412/1680 train_time:35892ms step_avg:87.12ms +step:413/1680 train_time:35979ms step_avg:87.12ms +step:414/1680 train_time:36066ms step_avg:87.12ms +step:415/1680 train_time:36154ms step_avg:87.12ms +step:416/1680 train_time:36240ms step_avg:87.12ms +step:417/1680 train_time:36327ms step_avg:87.12ms +step:418/1680 train_time:36414ms step_avg:87.11ms +step:419/1680 train_time:36501ms step_avg:87.11ms +step:420/1680 train_time:36587ms step_avg:87.11ms +step:421/1680 train_time:36675ms step_avg:87.11ms +step:422/1680 train_time:36763ms step_avg:87.11ms +step:423/1680 train_time:36850ms step_avg:87.12ms +step:424/1680 train_time:36938ms step_avg:87.12ms +step:425/1680 train_time:37025ms step_avg:87.12ms +step:426/1680 train_time:37112ms step_avg:87.12ms +step:427/1680 train_time:37199ms step_avg:87.12ms +step:428/1680 train_time:37285ms step_avg:87.12ms +step:429/1680 train_time:37372ms step_avg:87.11ms +step:430/1680 train_time:37460ms step_avg:87.12ms +step:431/1680 train_time:37546ms step_avg:87.11ms +step:432/1680 train_time:37634ms step_avg:87.11ms +step:433/1680 train_time:37721ms step_avg:87.11ms +step:434/1680 train_time:37808ms step_avg:87.11ms +step:435/1680 train_time:37896ms step_avg:87.12ms +step:436/1680 train_time:37983ms step_avg:87.12ms +step:437/1680 train_time:38070ms step_avg:87.12ms +step:438/1680 train_time:38157ms step_avg:87.12ms +step:439/1680 train_time:38244ms step_avg:87.12ms +step:440/1680 train_time:38330ms step_avg:87.11ms +step:441/1680 train_time:38417ms step_avg:87.11ms +step:442/1680 train_time:38504ms step_avg:87.11ms +step:443/1680 train_time:38590ms step_avg:87.11ms +step:444/1680 train_time:38677ms step_avg:87.11ms +step:445/1680 train_time:38764ms step_avg:87.11ms +step:446/1680 train_time:38851ms step_avg:87.11ms +step:447/1680 train_time:38939ms step_avg:87.11ms +step:448/1680 train_time:39026ms step_avg:87.11ms +step:449/1680 train_time:39113ms step_avg:87.11ms +step:450/1680 train_time:39200ms step_avg:87.11ms +step:451/1680 train_time:39286ms step_avg:87.11ms +step:452/1680 train_time:39374ms step_avg:87.11ms +step:453/1680 train_time:39461ms step_avg:87.11ms +step:454/1680 train_time:39547ms step_avg:87.11ms +step:455/1680 train_time:39635ms step_avg:87.11ms +step:456/1680 train_time:39723ms step_avg:87.11ms +step:457/1680 train_time:39810ms step_avg:87.11ms +step:458/1680 train_time:39897ms step_avg:87.11ms +step:459/1680 train_time:39984ms step_avg:87.11ms +step:460/1680 train_time:40071ms step_avg:87.11ms +step:461/1680 train_time:40158ms step_avg:87.11ms +step:462/1680 train_time:40245ms step_avg:87.11ms +step:463/1680 train_time:40332ms 
step_avg:87.11ms +step:464/1680 train_time:40420ms step_avg:87.11ms +step:465/1680 train_time:40506ms step_avg:87.11ms +step:466/1680 train_time:40594ms step_avg:87.11ms +step:467/1680 train_time:40681ms step_avg:87.11ms +step:468/1680 train_time:40768ms step_avg:87.11ms +step:469/1680 train_time:40855ms step_avg:87.11ms +step:470/1680 train_time:40943ms step_avg:87.11ms +step:471/1680 train_time:41029ms step_avg:87.11ms +step:472/1680 train_time:41116ms step_avg:87.11ms +step:473/1680 train_time:41203ms step_avg:87.11ms +step:474/1680 train_time:41290ms step_avg:87.11ms +step:475/1680 train_time:41377ms step_avg:87.11ms +step:476/1680 train_time:41464ms step_avg:87.11ms +step:477/1680 train_time:41551ms step_avg:87.11ms +step:478/1680 train_time:41638ms step_avg:87.11ms +step:479/1680 train_time:41724ms step_avg:87.11ms +step:480/1680 train_time:41811ms step_avg:87.11ms +step:481/1680 train_time:41898ms step_avg:87.11ms +step:482/1680 train_time:41985ms step_avg:87.11ms +step:483/1680 train_time:42072ms step_avg:87.11ms +step:484/1680 train_time:42159ms step_avg:87.11ms +step:485/1680 train_time:42247ms step_avg:87.11ms +step:486/1680 train_time:42334ms step_avg:87.11ms +step:487/1680 train_time:42421ms step_avg:87.11ms +step:488/1680 train_time:42508ms step_avg:87.11ms +step:489/1680 train_time:42596ms step_avg:87.11ms +step:490/1680 train_time:42683ms step_avg:87.11ms +step:491/1680 train_time:42769ms step_avg:87.11ms +step:492/1680 train_time:42857ms step_avg:87.11ms +step:493/1680 train_time:42944ms step_avg:87.11ms +step:494/1680 train_time:43030ms step_avg:87.11ms +step:495/1680 train_time:43118ms step_avg:87.11ms +step:496/1680 train_time:43205ms step_avg:87.11ms +step:497/1680 train_time:43292ms step_avg:87.11ms +step:498/1680 train_time:43379ms step_avg:87.11ms +step:499/1680 train_time:43466ms step_avg:87.11ms +step:500/1680 train_time:43554ms step_avg:87.11ms +step:500/1680 val_loss:3.7208 train_time:43643ms step_avg:87.29ms +step:501/1680 train_time:43662ms step_avg:87.15ms +step:502/1680 train_time:43732ms step_avg:87.12ms +step:503/1680 train_time:43825ms step_avg:87.13ms +step:504/1680 train_time:43915ms step_avg:87.13ms +step:505/1680 train_time:44001ms step_avg:87.13ms +step:506/1680 train_time:44088ms step_avg:87.13ms +step:507/1680 train_time:44175ms step_avg:87.13ms +step:508/1680 train_time:44261ms step_avg:87.13ms +step:509/1680 train_time:44347ms step_avg:87.13ms +step:510/1680 train_time:44433ms step_avg:87.12ms +step:511/1680 train_time:44519ms step_avg:87.12ms +step:512/1680 train_time:44606ms step_avg:87.12ms +step:513/1680 train_time:44694ms step_avg:87.12ms +step:514/1680 train_time:44783ms step_avg:87.13ms +step:515/1680 train_time:44871ms step_avg:87.13ms +step:516/1680 train_time:44959ms step_avg:87.13ms +step:517/1680 train_time:45046ms step_avg:87.13ms +step:518/1680 train_time:45132ms step_avg:87.13ms +step:519/1680 train_time:45219ms step_avg:87.13ms +step:520/1680 train_time:45306ms step_avg:87.13ms +step:521/1680 train_time:45393ms step_avg:87.13ms +step:522/1680 train_time:45480ms step_avg:87.13ms +step:523/1680 train_time:45566ms step_avg:87.12ms +step:524/1680 train_time:45654ms step_avg:87.13ms +step:525/1680 train_time:45742ms step_avg:87.13ms +step:526/1680 train_time:45830ms step_avg:87.13ms +step:527/1680 train_time:45918ms step_avg:87.13ms +step:528/1680 train_time:46005ms step_avg:87.13ms +step:529/1680 train_time:46094ms step_avg:87.13ms +step:530/1680 train_time:46180ms step_avg:87.13ms +step:531/1680 train_time:46267ms step_avg:87.13ms 
+step:532/1680 train_time:46354ms step_avg:87.13ms +step:533/1680 train_time:46440ms step_avg:87.13ms +step:534/1680 train_time:46527ms step_avg:87.13ms +step:535/1680 train_time:46613ms step_avg:87.13ms +step:536/1680 train_time:46701ms step_avg:87.13ms +step:537/1680 train_time:46788ms step_avg:87.13ms +step:538/1680 train_time:46876ms step_avg:87.13ms +step:539/1680 train_time:46963ms step_avg:87.13ms +step:540/1680 train_time:47051ms step_avg:87.13ms +step:541/1680 train_time:47139ms step_avg:87.13ms +step:542/1680 train_time:47226ms step_avg:87.13ms +step:543/1680 train_time:47313ms step_avg:87.13ms +step:544/1680 train_time:47399ms step_avg:87.13ms +step:545/1680 train_time:47486ms step_avg:87.13ms +step:546/1680 train_time:47572ms step_avg:87.13ms +step:547/1680 train_time:47659ms step_avg:87.13ms +step:548/1680 train_time:47746ms step_avg:87.13ms +step:549/1680 train_time:47835ms step_avg:87.13ms +step:550/1680 train_time:47923ms step_avg:87.13ms +step:551/1680 train_time:48012ms step_avg:87.14ms +step:552/1680 train_time:48101ms step_avg:87.14ms +step:553/1680 train_time:48190ms step_avg:87.14ms +step:554/1680 train_time:48278ms step_avg:87.14ms +step:555/1680 train_time:48366ms step_avg:87.15ms +step:556/1680 train_time:48454ms step_avg:87.15ms +step:557/1680 train_time:48543ms step_avg:87.15ms +step:558/1680 train_time:48630ms step_avg:87.15ms +step:559/1680 train_time:48718ms step_avg:87.15ms +step:560/1680 train_time:48806ms step_avg:87.15ms +step:561/1680 train_time:48895ms step_avg:87.16ms +step:562/1680 train_time:48984ms step_avg:87.16ms +step:563/1680 train_time:49073ms step_avg:87.16ms +step:564/1680 train_time:49162ms step_avg:87.17ms +step:565/1680 train_time:49250ms step_avg:87.17ms +step:566/1680 train_time:49338ms step_avg:87.17ms +step:567/1680 train_time:49426ms step_avg:87.17ms +step:568/1680 train_time:49514ms step_avg:87.17ms +step:569/1680 train_time:49602ms step_avg:87.17ms +step:570/1680 train_time:49691ms step_avg:87.18ms +step:571/1680 train_time:49779ms step_avg:87.18ms +step:572/1680 train_time:49866ms step_avg:87.18ms +step:573/1680 train_time:49955ms step_avg:87.18ms +step:574/1680 train_time:50043ms step_avg:87.18ms +step:575/1680 train_time:50132ms step_avg:87.19ms +step:576/1680 train_time:50220ms step_avg:87.19ms +step:577/1680 train_time:50309ms step_avg:87.19ms +step:578/1680 train_time:50397ms step_avg:87.19ms +step:579/1680 train_time:50486ms step_avg:87.20ms +step:580/1680 train_time:50575ms step_avg:87.20ms +step:581/1680 train_time:50662ms step_avg:87.20ms +step:582/1680 train_time:50750ms step_avg:87.20ms +step:583/1680 train_time:50838ms step_avg:87.20ms +step:584/1680 train_time:50926ms step_avg:87.20ms +step:585/1680 train_time:51015ms step_avg:87.21ms +step:586/1680 train_time:51104ms step_avg:87.21ms +step:587/1680 train_time:51193ms step_avg:87.21ms +step:588/1680 train_time:51282ms step_avg:87.21ms +step:589/1680 train_time:51370ms step_avg:87.22ms +step:590/1680 train_time:51459ms step_avg:87.22ms +step:591/1680 train_time:51547ms step_avg:87.22ms +step:592/1680 train_time:51635ms step_avg:87.22ms +step:593/1680 train_time:51723ms step_avg:87.22ms +step:594/1680 train_time:51810ms step_avg:87.22ms +step:595/1680 train_time:51898ms step_avg:87.22ms +step:596/1680 train_time:51986ms step_avg:87.23ms +step:597/1680 train_time:52075ms step_avg:87.23ms +step:598/1680 train_time:52163ms step_avg:87.23ms +step:599/1680 train_time:52252ms step_avg:87.23ms +step:600/1680 train_time:52341ms step_avg:87.23ms +step:601/1680 train_time:52429ms 
step_avg:87.24ms +step:602/1680 train_time:52517ms step_avg:87.24ms +step:603/1680 train_time:52606ms step_avg:87.24ms +step:604/1680 train_time:52694ms step_avg:87.24ms +step:605/1680 train_time:52782ms step_avg:87.24ms +step:606/1680 train_time:52871ms step_avg:87.25ms +step:607/1680 train_time:52959ms step_avg:87.25ms +step:608/1680 train_time:53047ms step_avg:87.25ms +step:609/1680 train_time:53135ms step_avg:87.25ms +step:610/1680 train_time:53223ms step_avg:87.25ms +step:611/1680 train_time:53312ms step_avg:87.25ms +step:612/1680 train_time:53400ms step_avg:87.26ms +step:613/1680 train_time:53489ms step_avg:87.26ms +step:614/1680 train_time:53577ms step_avg:87.26ms +step:615/1680 train_time:53666ms step_avg:87.26ms +step:616/1680 train_time:53754ms step_avg:87.26ms +step:617/1680 train_time:53843ms step_avg:87.26ms +step:618/1680 train_time:53931ms step_avg:87.27ms +step:619/1680 train_time:54020ms step_avg:87.27ms +step:620/1680 train_time:54108ms step_avg:87.27ms +step:621/1680 train_time:54196ms step_avg:87.27ms +step:622/1680 train_time:54284ms step_avg:87.27ms +step:623/1680 train_time:54372ms step_avg:87.27ms +step:624/1680 train_time:54461ms step_avg:87.28ms +step:625/1680 train_time:54549ms step_avg:87.28ms +step:625/1680 val_loss:3.6197 train_time:54639ms step_avg:87.42ms +step:626/1680 train_time:54665ms step_avg:87.32ms +step:627/1680 train_time:54730ms step_avg:87.29ms +step:628/1680 train_time:54817ms step_avg:87.29ms +step:629/1680 train_time:54908ms step_avg:87.29ms +step:630/1680 train_time:54994ms step_avg:87.29ms +step:631/1680 train_time:55081ms step_avg:87.29ms +step:632/1680 train_time:55168ms step_avg:87.29ms +step:633/1680 train_time:55255ms step_avg:87.29ms +step:634/1680 train_time:55342ms step_avg:87.29ms +step:635/1680 train_time:55429ms step_avg:87.29ms +step:636/1680 train_time:55517ms step_avg:87.29ms +step:637/1680 train_time:55609ms step_avg:87.30ms +step:638/1680 train_time:55701ms step_avg:87.31ms +step:639/1680 train_time:55791ms step_avg:87.31ms +step:640/1680 train_time:55879ms step_avg:87.31ms +step:641/1680 train_time:55967ms step_avg:87.31ms +step:642/1680 train_time:56055ms step_avg:87.31ms +step:643/1680 train_time:56142ms step_avg:87.31ms +step:644/1680 train_time:56229ms step_avg:87.31ms +step:645/1680 train_time:56317ms step_avg:87.31ms +step:646/1680 train_time:56404ms step_avg:87.31ms +step:647/1680 train_time:56492ms step_avg:87.31ms +step:648/1680 train_time:56582ms step_avg:87.32ms +step:649/1680 train_time:56672ms step_avg:87.32ms +step:650/1680 train_time:56760ms step_avg:87.32ms +step:651/1680 train_time:56849ms step_avg:87.33ms +step:652/1680 train_time:56937ms step_avg:87.33ms +step:653/1680 train_time:57026ms step_avg:87.33ms +step:654/1680 train_time:57113ms step_avg:87.33ms +step:655/1680 train_time:57201ms step_avg:87.33ms +step:656/1680 train_time:57288ms step_avg:87.33ms +step:657/1680 train_time:57376ms step_avg:87.33ms +step:658/1680 train_time:57464ms step_avg:87.33ms +step:659/1680 train_time:57552ms step_avg:87.33ms +step:660/1680 train_time:57641ms step_avg:87.34ms +step:661/1680 train_time:57730ms step_avg:87.34ms +step:662/1680 train_time:57819ms step_avg:87.34ms +step:663/1680 train_time:57908ms step_avg:87.34ms +step:664/1680 train_time:57996ms step_avg:87.34ms +step:665/1680 train_time:58083ms step_avg:87.34ms +step:666/1680 train_time:58171ms step_avg:87.34ms +step:667/1680 train_time:58259ms step_avg:87.35ms +step:668/1680 train_time:58347ms step_avg:87.35ms +step:669/1680 train_time:58434ms step_avg:87.35ms 
+step:670/1680 train_time:58523ms step_avg:87.35ms +step:671/1680 train_time:58612ms step_avg:87.35ms +step:672/1680 train_time:58701ms step_avg:87.35ms +step:673/1680 train_time:58789ms step_avg:87.35ms +step:674/1680 train_time:58878ms step_avg:87.36ms +step:675/1680 train_time:58966ms step_avg:87.36ms +step:676/1680 train_time:59054ms step_avg:87.36ms +step:677/1680 train_time:59143ms step_avg:87.36ms +step:678/1680 train_time:59231ms step_avg:87.36ms +step:679/1680 train_time:59318ms step_avg:87.36ms +step:680/1680 train_time:59407ms step_avg:87.36ms +step:681/1680 train_time:59495ms step_avg:87.36ms +step:682/1680 train_time:59583ms step_avg:87.36ms +step:683/1680 train_time:59671ms step_avg:87.37ms +step:684/1680 train_time:59759ms step_avg:87.37ms +step:685/1680 train_time:59847ms step_avg:87.37ms +step:686/1680 train_time:59936ms step_avg:87.37ms +step:687/1680 train_time:60023ms step_avg:87.37ms +step:688/1680 train_time:60112ms step_avg:87.37ms +step:689/1680 train_time:60200ms step_avg:87.37ms +step:690/1680 train_time:60288ms step_avg:87.37ms +step:691/1680 train_time:60375ms step_avg:87.37ms +step:692/1680 train_time:60463ms step_avg:87.37ms +step:693/1680 train_time:60551ms step_avg:87.38ms +step:694/1680 train_time:60639ms step_avg:87.38ms +step:695/1680 train_time:60728ms step_avg:87.38ms +step:696/1680 train_time:60817ms step_avg:87.38ms +step:697/1680 train_time:60905ms step_avg:87.38ms +step:698/1680 train_time:60993ms step_avg:87.38ms +step:699/1680 train_time:61082ms step_avg:87.38ms +step:700/1680 train_time:61171ms step_avg:87.39ms +step:701/1680 train_time:61259ms step_avg:87.39ms +step:702/1680 train_time:61347ms step_avg:87.39ms +step:703/1680 train_time:61435ms step_avg:87.39ms +step:704/1680 train_time:61523ms step_avg:87.39ms +step:705/1680 train_time:61611ms step_avg:87.39ms +step:706/1680 train_time:61700ms step_avg:87.39ms +step:707/1680 train_time:61788ms step_avg:87.39ms +step:708/1680 train_time:61876ms step_avg:87.40ms +step:709/1680 train_time:61965ms step_avg:87.40ms +step:710/1680 train_time:62054ms step_avg:87.40ms +step:711/1680 train_time:62142ms step_avg:87.40ms +step:712/1680 train_time:62230ms step_avg:87.40ms +step:713/1680 train_time:62319ms step_avg:87.40ms +step:714/1680 train_time:62406ms step_avg:87.40ms +step:715/1680 train_time:62495ms step_avg:87.41ms +step:716/1680 train_time:62583ms step_avg:87.41ms +step:717/1680 train_time:62672ms step_avg:87.41ms +step:718/1680 train_time:62760ms step_avg:87.41ms +step:719/1680 train_time:62849ms step_avg:87.41ms +step:720/1680 train_time:62937ms step_avg:87.41ms +step:721/1680 train_time:63025ms step_avg:87.41ms +step:722/1680 train_time:63113ms step_avg:87.41ms +step:723/1680 train_time:63202ms step_avg:87.42ms +step:724/1680 train_time:63290ms step_avg:87.42ms +step:725/1680 train_time:63379ms step_avg:87.42ms +step:726/1680 train_time:63466ms step_avg:87.42ms +step:727/1680 train_time:63555ms step_avg:87.42ms +step:728/1680 train_time:63643ms step_avg:87.42ms +step:729/1680 train_time:63730ms step_avg:87.42ms +step:730/1680 train_time:63819ms step_avg:87.42ms +step:731/1680 train_time:63907ms step_avg:87.42ms +step:732/1680 train_time:63995ms step_avg:87.43ms +step:733/1680 train_time:64084ms step_avg:87.43ms +step:734/1680 train_time:64173ms step_avg:87.43ms +step:735/1680 train_time:64261ms step_avg:87.43ms +step:736/1680 train_time:64349ms step_avg:87.43ms +step:737/1680 train_time:64437ms step_avg:87.43ms +step:738/1680 train_time:64525ms step_avg:87.43ms +step:739/1680 train_time:64613ms 
step_avg:87.43ms +step:740/1680 train_time:64702ms step_avg:87.43ms +step:741/1680 train_time:64790ms step_avg:87.44ms +step:742/1680 train_time:64879ms step_avg:87.44ms +step:743/1680 train_time:64967ms step_avg:87.44ms +step:744/1680 train_time:65055ms step_avg:87.44ms +step:745/1680 train_time:65144ms step_avg:87.44ms +step:746/1680 train_time:65232ms step_avg:87.44ms +step:747/1680 train_time:65321ms step_avg:87.44ms +step:748/1680 train_time:65409ms step_avg:87.45ms +step:749/1680 train_time:65497ms step_avg:87.45ms +step:750/1680 train_time:65585ms step_avg:87.45ms +step:750/1680 val_loss:3.5671 train_time:65675ms step_avg:87.57ms +step:751/1680 train_time:65697ms step_avg:87.48ms +step:752/1680 train_time:65766ms step_avg:87.45ms +step:753/1680 train_time:65857ms step_avg:87.46ms +step:754/1680 train_time:65946ms step_avg:87.46ms +step:755/1680 train_time:66035ms step_avg:87.46ms +step:756/1680 train_time:66123ms step_avg:87.46ms +step:757/1680 train_time:66209ms step_avg:87.46ms +step:758/1680 train_time:66296ms step_avg:87.46ms +step:759/1680 train_time:66383ms step_avg:87.46ms +step:760/1680 train_time:66472ms step_avg:87.46ms +step:761/1680 train_time:66559ms step_avg:87.46ms +step:762/1680 train_time:66648ms step_avg:87.46ms +step:763/1680 train_time:66737ms step_avg:87.47ms +step:764/1680 train_time:66827ms step_avg:87.47ms +step:765/1680 train_time:66917ms step_avg:87.47ms +step:766/1680 train_time:67006ms step_avg:87.47ms +step:767/1680 train_time:67094ms step_avg:87.48ms +step:768/1680 train_time:67181ms step_avg:87.48ms +step:769/1680 train_time:67268ms step_avg:87.48ms +step:770/1680 train_time:67356ms step_avg:87.48ms +step:771/1680 train_time:67444ms step_avg:87.48ms +step:772/1680 train_time:67532ms step_avg:87.48ms +step:773/1680 train_time:67621ms step_avg:87.48ms +step:774/1680 train_time:67709ms step_avg:87.48ms +step:775/1680 train_time:67798ms step_avg:87.48ms +step:776/1680 train_time:67887ms step_avg:87.48ms +step:777/1680 train_time:67976ms step_avg:87.48ms +step:778/1680 train_time:68064ms step_avg:87.49ms +step:779/1680 train_time:68153ms step_avg:87.49ms +step:780/1680 train_time:68240ms step_avg:87.49ms +step:781/1680 train_time:68328ms step_avg:87.49ms +step:782/1680 train_time:68416ms step_avg:87.49ms +step:783/1680 train_time:68503ms step_avg:87.49ms +step:784/1680 train_time:68591ms step_avg:87.49ms +step:785/1680 train_time:68680ms step_avg:87.49ms +step:786/1680 train_time:68769ms step_avg:87.49ms +step:787/1680 train_time:68858ms step_avg:87.49ms +step:788/1680 train_time:68946ms step_avg:87.50ms +step:789/1680 train_time:69036ms step_avg:87.50ms +step:790/1680 train_time:69124ms step_avg:87.50ms +step:791/1680 train_time:69212ms step_avg:87.50ms +step:792/1680 train_time:69299ms step_avg:87.50ms +step:793/1680 train_time:69388ms step_avg:87.50ms +step:794/1680 train_time:69475ms step_avg:87.50ms +step:795/1680 train_time:69564ms step_avg:87.50ms +step:796/1680 train_time:69652ms step_avg:87.50ms +step:797/1680 train_time:69740ms step_avg:87.50ms +step:798/1680 train_time:69829ms step_avg:87.50ms +step:799/1680 train_time:69917ms step_avg:87.51ms +step:800/1680 train_time:70006ms step_avg:87.51ms +step:801/1680 train_time:70094ms step_avg:87.51ms +step:802/1680 train_time:70183ms step_avg:87.51ms +step:803/1680 train_time:70271ms step_avg:87.51ms +step:804/1680 train_time:70359ms step_avg:87.51ms +step:805/1680 train_time:70447ms step_avg:87.51ms +step:806/1680 train_time:70535ms step_avg:87.51ms +step:807/1680 train_time:70624ms step_avg:87.51ms 
+step:808/1680 train_time:70712ms step_avg:87.52ms +step:809/1680 train_time:70800ms step_avg:87.52ms +step:810/1680 train_time:70889ms step_avg:87.52ms +step:811/1680 train_time:70978ms step_avg:87.52ms +step:812/1680 train_time:71066ms step_avg:87.52ms +step:813/1680 train_time:71154ms step_avg:87.52ms +step:814/1680 train_time:71243ms step_avg:87.52ms +step:815/1680 train_time:71331ms step_avg:87.52ms +step:816/1680 train_time:71419ms step_avg:87.52ms +step:817/1680 train_time:71508ms step_avg:87.52ms +step:818/1680 train_time:71595ms step_avg:87.52ms +step:819/1680 train_time:71683ms step_avg:87.53ms +step:820/1680 train_time:71771ms step_avg:87.53ms +step:821/1680 train_time:71860ms step_avg:87.53ms +step:822/1680 train_time:71948ms step_avg:87.53ms +step:823/1680 train_time:72036ms step_avg:87.53ms +step:824/1680 train_time:72124ms step_avg:87.53ms +step:825/1680 train_time:72213ms step_avg:87.53ms +step:826/1680 train_time:72301ms step_avg:87.53ms +step:827/1680 train_time:72389ms step_avg:87.53ms +step:828/1680 train_time:72477ms step_avg:87.53ms +step:829/1680 train_time:72565ms step_avg:87.53ms +step:830/1680 train_time:72654ms step_avg:87.53ms +step:831/1680 train_time:72742ms step_avg:87.54ms +step:832/1680 train_time:72830ms step_avg:87.54ms +step:833/1680 train_time:72918ms step_avg:87.54ms +step:834/1680 train_time:73006ms step_avg:87.54ms +step:835/1680 train_time:73094ms step_avg:87.54ms +step:836/1680 train_time:73183ms step_avg:87.54ms +step:837/1680 train_time:73271ms step_avg:87.54ms +step:838/1680 train_time:73359ms step_avg:87.54ms +step:839/1680 train_time:73447ms step_avg:87.54ms +step:840/1680 train_time:73536ms step_avg:87.54ms +step:841/1680 train_time:73625ms step_avg:87.54ms +step:842/1680 train_time:73713ms step_avg:87.54ms +step:843/1680 train_time:73801ms step_avg:87.55ms +step:844/1680 train_time:73889ms step_avg:87.55ms +step:845/1680 train_time:73977ms step_avg:87.55ms +step:846/1680 train_time:74066ms step_avg:87.55ms +step:847/1680 train_time:74155ms step_avg:87.55ms +step:848/1680 train_time:74243ms step_avg:87.55ms +step:849/1680 train_time:74332ms step_avg:87.55ms +step:850/1680 train_time:74419ms step_avg:87.55ms +step:851/1680 train_time:74507ms step_avg:87.55ms +step:852/1680 train_time:74595ms step_avg:87.55ms +step:853/1680 train_time:74683ms step_avg:87.55ms +step:854/1680 train_time:74772ms step_avg:87.55ms +step:855/1680 train_time:74860ms step_avg:87.56ms +step:856/1680 train_time:74948ms step_avg:87.56ms +step:857/1680 train_time:75037ms step_avg:87.56ms +step:858/1680 train_time:75126ms step_avg:87.56ms +step:859/1680 train_time:75216ms step_avg:87.56ms +step:860/1680 train_time:75304ms step_avg:87.56ms +step:861/1680 train_time:75393ms step_avg:87.56ms +step:862/1680 train_time:75480ms step_avg:87.56ms +step:863/1680 train_time:75568ms step_avg:87.56ms +step:864/1680 train_time:75656ms step_avg:87.56ms +step:865/1680 train_time:75745ms step_avg:87.57ms +step:866/1680 train_time:75833ms step_avg:87.57ms +step:867/1680 train_time:75921ms step_avg:87.57ms +step:868/1680 train_time:76009ms step_avg:87.57ms +step:869/1680 train_time:76097ms step_avg:87.57ms +step:870/1680 train_time:76186ms step_avg:87.57ms +step:871/1680 train_time:76274ms step_avg:87.57ms +step:872/1680 train_time:76363ms step_avg:87.57ms +step:873/1680 train_time:76450ms step_avg:87.57ms +step:874/1680 train_time:76538ms step_avg:87.57ms +step:875/1680 train_time:76627ms step_avg:87.57ms +step:875/1680 val_loss:3.5195 train_time:76716ms step_avg:87.68ms +step:876/1680 
train_time:76737ms step_avg:87.60ms +step:877/1680 train_time:76806ms step_avg:87.58ms +step:878/1680 train_time:76900ms step_avg:87.59ms +step:879/1680 train_time:76990ms step_avg:87.59ms +step:880/1680 train_time:77077ms step_avg:87.59ms +step:881/1680 train_time:77164ms step_avg:87.59ms +step:882/1680 train_time:77251ms step_avg:87.59ms +step:883/1680 train_time:77338ms step_avg:87.59ms +step:884/1680 train_time:77425ms step_avg:87.59ms +step:885/1680 train_time:77513ms step_avg:87.58ms +step:886/1680 train_time:77600ms step_avg:87.59ms +step:887/1680 train_time:77690ms step_avg:87.59ms +step:888/1680 train_time:77780ms step_avg:87.59ms +step:889/1680 train_time:77870ms step_avg:87.59ms +step:890/1680 train_time:77960ms step_avg:87.60ms +step:891/1680 train_time:78048ms step_avg:87.60ms +step:892/1680 train_time:78136ms step_avg:87.60ms +step:893/1680 train_time:78224ms step_avg:87.60ms +step:894/1680 train_time:78311ms step_avg:87.60ms +step:895/1680 train_time:78399ms step_avg:87.60ms +step:896/1680 train_time:78486ms step_avg:87.60ms +step:897/1680 train_time:78573ms step_avg:87.60ms +step:898/1680 train_time:78662ms step_avg:87.60ms +step:899/1680 train_time:78750ms step_avg:87.60ms +step:900/1680 train_time:78840ms step_avg:87.60ms +step:901/1680 train_time:78928ms step_avg:87.60ms +step:902/1680 train_time:79017ms step_avg:87.60ms +step:903/1680 train_time:79106ms step_avg:87.60ms +step:904/1680 train_time:79194ms step_avg:87.60ms +step:905/1680 train_time:79283ms step_avg:87.61ms +step:906/1680 train_time:79371ms step_avg:87.61ms +step:907/1680 train_time:79459ms step_avg:87.61ms +step:908/1680 train_time:79547ms step_avg:87.61ms +step:909/1680 train_time:79634ms step_avg:87.61ms +step:910/1680 train_time:79723ms step_avg:87.61ms +step:911/1680 train_time:79812ms step_avg:87.61ms +step:912/1680 train_time:79902ms step_avg:87.61ms +step:913/1680 train_time:79991ms step_avg:87.61ms +step:914/1680 train_time:80080ms step_avg:87.61ms +step:915/1680 train_time:80167ms step_avg:87.61ms +step:916/1680 train_time:80255ms step_avg:87.62ms +step:917/1680 train_time:80344ms step_avg:87.62ms +step:918/1680 train_time:80431ms step_avg:87.62ms +step:919/1680 train_time:80519ms step_avg:87.62ms +step:920/1680 train_time:80606ms step_avg:87.62ms +step:921/1680 train_time:80695ms step_avg:87.62ms +step:922/1680 train_time:80783ms step_avg:87.62ms +step:923/1680 train_time:80872ms step_avg:87.62ms +step:924/1680 train_time:80961ms step_avg:87.62ms +step:925/1680 train_time:81049ms step_avg:87.62ms +step:926/1680 train_time:81137ms step_avg:87.62ms +step:927/1680 train_time:81225ms step_avg:87.62ms +step:928/1680 train_time:81313ms step_avg:87.62ms +step:929/1680 train_time:81401ms step_avg:87.62ms +step:930/1680 train_time:81489ms step_avg:87.62ms +step:931/1680 train_time:81577ms step_avg:87.62ms +step:932/1680 train_time:81665ms step_avg:87.62ms +step:933/1680 train_time:81753ms step_avg:87.62ms +step:934/1680 train_time:81841ms step_avg:87.62ms +step:935/1680 train_time:81929ms step_avg:87.62ms +step:936/1680 train_time:82019ms step_avg:87.63ms +step:937/1680 train_time:82107ms step_avg:87.63ms +step:938/1680 train_time:82195ms step_avg:87.63ms +step:939/1680 train_time:82283ms step_avg:87.63ms +step:940/1680 train_time:82371ms step_avg:87.63ms +step:941/1680 train_time:82459ms step_avg:87.63ms +step:942/1680 train_time:82547ms step_avg:87.63ms +step:943/1680 train_time:82635ms step_avg:87.63ms +step:944/1680 train_time:82723ms step_avg:87.63ms +step:945/1680 train_time:82811ms step_avg:87.63ms 
+step:946/1680 train_time:82899ms step_avg:87.63ms +step:947/1680 train_time:82988ms step_avg:87.63ms +step:948/1680 train_time:83076ms step_avg:87.63ms +step:949/1680 train_time:83164ms step_avg:87.63ms +step:950/1680 train_time:83253ms step_avg:87.63ms +step:951/1680 train_time:83341ms step_avg:87.64ms +step:952/1680 train_time:83430ms step_avg:87.64ms +step:953/1680 train_time:83518ms step_avg:87.64ms +step:954/1680 train_time:83606ms step_avg:87.64ms +step:955/1680 train_time:83694ms step_avg:87.64ms +step:956/1680 train_time:83783ms step_avg:87.64ms +step:957/1680 train_time:83872ms step_avg:87.64ms +step:958/1680 train_time:83960ms step_avg:87.64ms +step:959/1680 train_time:84048ms step_avg:87.64ms +step:960/1680 train_time:84137ms step_avg:87.64ms +step:961/1680 train_time:84226ms step_avg:87.64ms +step:962/1680 train_time:84314ms step_avg:87.64ms +step:963/1680 train_time:84403ms step_avg:87.65ms +step:964/1680 train_time:84491ms step_avg:87.65ms +step:965/1680 train_time:84579ms step_avg:87.65ms +step:966/1680 train_time:84667ms step_avg:87.65ms +step:967/1680 train_time:84756ms step_avg:87.65ms +step:968/1680 train_time:84844ms step_avg:87.65ms +step:969/1680 train_time:84933ms step_avg:87.65ms +step:970/1680 train_time:85021ms step_avg:87.65ms +step:971/1680 train_time:85109ms step_avg:87.65ms +step:972/1680 train_time:85198ms step_avg:87.65ms +step:973/1680 train_time:85286ms step_avg:87.65ms +step:974/1680 train_time:85374ms step_avg:87.65ms +step:975/1680 train_time:85462ms step_avg:87.65ms +step:976/1680 train_time:85550ms step_avg:87.65ms +step:977/1680 train_time:85638ms step_avg:87.65ms +step:978/1680 train_time:85726ms step_avg:87.65ms +step:979/1680 train_time:85814ms step_avg:87.65ms +step:980/1680 train_time:85902ms step_avg:87.66ms +step:981/1680 train_time:85991ms step_avg:87.66ms +step:982/1680 train_time:86080ms step_avg:87.66ms +step:983/1680 train_time:86168ms step_avg:87.66ms +step:984/1680 train_time:86256ms step_avg:87.66ms +step:985/1680 train_time:86345ms step_avg:87.66ms +step:986/1680 train_time:86433ms step_avg:87.66ms +step:987/1680 train_time:86521ms step_avg:87.66ms +step:988/1680 train_time:86608ms step_avg:87.66ms +step:989/1680 train_time:86697ms step_avg:87.66ms +step:990/1680 train_time:86785ms step_avg:87.66ms +step:991/1680 train_time:86873ms step_avg:87.66ms +step:992/1680 train_time:86962ms step_avg:87.66ms +step:993/1680 train_time:87050ms step_avg:87.66ms +step:994/1680 train_time:87139ms step_avg:87.66ms +step:995/1680 train_time:87227ms step_avg:87.67ms +step:996/1680 train_time:87315ms step_avg:87.67ms +step:997/1680 train_time:87404ms step_avg:87.67ms +step:998/1680 train_time:87492ms step_avg:87.67ms +step:999/1680 train_time:87581ms step_avg:87.67ms +step:1000/1680 train_time:87668ms step_avg:87.67ms +step:1000/1680 val_loss:3.4690 train_time:87758ms step_avg:87.76ms +step:1001/1680 train_time:87778ms step_avg:87.69ms +step:1002/1680 train_time:87850ms step_avg:87.67ms +step:1003/1680 train_time:87944ms step_avg:87.68ms +step:1004/1680 train_time:88032ms step_avg:87.68ms +step:1005/1680 train_time:88120ms step_avg:87.68ms +step:1006/1680 train_time:88207ms step_avg:87.68ms +step:1007/1680 train_time:88295ms step_avg:87.68ms +step:1008/1680 train_time:88382ms step_avg:87.68ms +step:1009/1680 train_time:88469ms step_avg:87.68ms +step:1010/1680 train_time:88557ms step_avg:87.68ms +step:1011/1680 train_time:88644ms step_avg:87.68ms +step:1012/1680 train_time:88733ms step_avg:87.68ms +step:1013/1680 train_time:88823ms step_avg:87.68ms 
+step:1014/1680 train_time:88913ms step_avg:87.69ms +step:1015/1680 train_time:89002ms step_avg:87.69ms +step:1016/1680 train_time:89090ms step_avg:87.69ms +step:1017/1680 train_time:89179ms step_avg:87.69ms +step:1018/1680 train_time:89266ms step_avg:87.69ms +step:1019/1680 train_time:89353ms step_avg:87.69ms +step:1020/1680 train_time:89441ms step_avg:87.69ms +step:1021/1680 train_time:89529ms step_avg:87.69ms +step:1022/1680 train_time:89616ms step_avg:87.69ms +step:1023/1680 train_time:89706ms step_avg:87.69ms +step:1024/1680 train_time:89795ms step_avg:87.69ms +step:1025/1680 train_time:89884ms step_avg:87.69ms +step:1026/1680 train_time:89973ms step_avg:87.69ms +step:1027/1680 train_time:90061ms step_avg:87.69ms +step:1028/1680 train_time:90150ms step_avg:87.69ms +step:1029/1680 train_time:90238ms step_avg:87.70ms +step:1030/1680 train_time:90326ms step_avg:87.70ms +step:1031/1680 train_time:90413ms step_avg:87.69ms +step:1032/1680 train_time:90501ms step_avg:87.69ms +step:1033/1680 train_time:90589ms step_avg:87.69ms +step:1034/1680 train_time:90676ms step_avg:87.69ms +step:1035/1680 train_time:90765ms step_avg:87.70ms +step:1036/1680 train_time:90853ms step_avg:87.70ms +step:1037/1680 train_time:90943ms step_avg:87.70ms +step:1038/1680 train_time:91032ms step_avg:87.70ms +step:1039/1680 train_time:91120ms step_avg:87.70ms +step:1040/1680 train_time:91209ms step_avg:87.70ms +step:1041/1680 train_time:91297ms step_avg:87.70ms +step:1042/1680 train_time:91384ms step_avg:87.70ms +step:1043/1680 train_time:91472ms step_avg:87.70ms +step:1044/1680 train_time:91559ms step_avg:87.70ms +step:1045/1680 train_time:91647ms step_avg:87.70ms +step:1046/1680 train_time:91735ms step_avg:87.70ms +step:1047/1680 train_time:91824ms step_avg:87.70ms +step:1048/1680 train_time:91911ms step_avg:87.70ms +step:1049/1680 train_time:92001ms step_avg:87.70ms +step:1050/1680 train_time:92089ms step_avg:87.70ms +step:1051/1680 train_time:92178ms step_avg:87.71ms +step:1052/1680 train_time:92267ms step_avg:87.71ms +step:1053/1680 train_time:92355ms step_avg:87.71ms +step:1054/1680 train_time:92443ms step_avg:87.71ms +step:1055/1680 train_time:92531ms step_avg:87.71ms +step:1056/1680 train_time:92618ms step_avg:87.71ms +step:1057/1680 train_time:92706ms step_avg:87.71ms +step:1058/1680 train_time:92795ms step_avg:87.71ms +step:1059/1680 train_time:92883ms step_avg:87.71ms +step:1060/1680 train_time:92971ms step_avg:87.71ms +step:1061/1680 train_time:93060ms step_avg:87.71ms +step:1062/1680 train_time:93149ms step_avg:87.71ms +step:1063/1680 train_time:93236ms step_avg:87.71ms +step:1064/1680 train_time:93325ms step_avg:87.71ms +step:1065/1680 train_time:93413ms step_avg:87.71ms +step:1066/1680 train_time:93501ms step_avg:87.71ms +step:1067/1680 train_time:93589ms step_avg:87.71ms +step:1068/1680 train_time:93677ms step_avg:87.71ms +step:1069/1680 train_time:93765ms step_avg:87.71ms +step:1070/1680 train_time:93853ms step_avg:87.71ms +step:1071/1680 train_time:93942ms step_avg:87.71ms +step:1072/1680 train_time:94031ms step_avg:87.72ms +step:1073/1680 train_time:94120ms step_avg:87.72ms +step:1074/1680 train_time:94208ms step_avg:87.72ms +step:1075/1680 train_time:94295ms step_avg:87.72ms +step:1076/1680 train_time:94384ms step_avg:87.72ms +step:1077/1680 train_time:94472ms step_avg:87.72ms +step:1078/1680 train_time:94561ms step_avg:87.72ms +step:1079/1680 train_time:94649ms step_avg:87.72ms +step:1080/1680 train_time:94737ms step_avg:87.72ms +step:1081/1680 train_time:94825ms step_avg:87.72ms +step:1082/1680 
train_time:94913ms step_avg:87.72ms +step:1083/1680 train_time:95002ms step_avg:87.72ms +step:1084/1680 train_time:95090ms step_avg:87.72ms +step:1085/1680 train_time:95179ms step_avg:87.72ms +step:1086/1680 train_time:95267ms step_avg:87.72ms +step:1087/1680 train_time:95356ms step_avg:87.72ms +step:1088/1680 train_time:95444ms step_avg:87.72ms +step:1089/1680 train_time:95532ms step_avg:87.72ms +step:1090/1680 train_time:95620ms step_avg:87.72ms +step:1091/1680 train_time:95708ms step_avg:87.72ms +step:1092/1680 train_time:95796ms step_avg:87.73ms +step:1093/1680 train_time:95885ms step_avg:87.73ms +step:1094/1680 train_time:95973ms step_avg:87.73ms +step:1095/1680 train_time:96062ms step_avg:87.73ms +step:1096/1680 train_time:96150ms step_avg:87.73ms +step:1097/1680 train_time:96239ms step_avg:87.73ms +step:1098/1680 train_time:96328ms step_avg:87.73ms +step:1099/1680 train_time:96417ms step_avg:87.73ms +step:1100/1680 train_time:96505ms step_avg:87.73ms +step:1101/1680 train_time:96594ms step_avg:87.73ms +step:1102/1680 train_time:96682ms step_avg:87.73ms +step:1103/1680 train_time:96771ms step_avg:87.73ms +step:1104/1680 train_time:96860ms step_avg:87.74ms +step:1105/1680 train_time:96949ms step_avg:87.74ms +step:1106/1680 train_time:97038ms step_avg:87.74ms +step:1107/1680 train_time:97127ms step_avg:87.74ms +step:1108/1680 train_time:97215ms step_avg:87.74ms +step:1109/1680 train_time:97304ms step_avg:87.74ms +step:1110/1680 train_time:97393ms step_avg:87.74ms +step:1111/1680 train_time:97482ms step_avg:87.74ms +step:1112/1680 train_time:97571ms step_avg:87.74ms +step:1113/1680 train_time:97660ms step_avg:87.74ms +step:1114/1680 train_time:97748ms step_avg:87.75ms +step:1115/1680 train_time:97837ms step_avg:87.75ms +step:1116/1680 train_time:97927ms step_avg:87.75ms +step:1117/1680 train_time:98016ms step_avg:87.75ms +step:1118/1680 train_time:98104ms step_avg:87.75ms +step:1119/1680 train_time:98193ms step_avg:87.75ms +step:1120/1680 train_time:98282ms step_avg:87.75ms +step:1121/1680 train_time:98371ms step_avg:87.75ms +step:1122/1680 train_time:98459ms step_avg:87.75ms +step:1123/1680 train_time:98548ms step_avg:87.75ms +step:1124/1680 train_time:98637ms step_avg:87.76ms +step:1125/1680 train_time:98727ms step_avg:87.76ms +step:1125/1680 val_loss:3.4152 train_time:98817ms step_avg:87.84ms +step:1126/1680 train_time:98837ms step_avg:87.78ms +step:1127/1680 train_time:98906ms step_avg:87.76ms +step:1128/1680 train_time:98998ms step_avg:87.76ms +step:1129/1680 train_time:99091ms step_avg:87.77ms +step:1130/1680 train_time:99178ms step_avg:87.77ms +step:1131/1680 train_time:99266ms step_avg:87.77ms +step:1132/1680 train_time:99355ms step_avg:87.77ms +step:1133/1680 train_time:99442ms step_avg:87.77ms +step:1134/1680 train_time:99530ms step_avg:87.77ms +step:1135/1680 train_time:99618ms step_avg:87.77ms +step:1136/1680 train_time:99708ms step_avg:87.77ms +step:1137/1680 train_time:99798ms step_avg:87.77ms +step:1138/1680 train_time:99889ms step_avg:87.78ms +step:1139/1680 train_time:99980ms step_avg:87.78ms +step:1140/1680 train_time:100071ms step_avg:87.78ms +step:1141/1680 train_time:100161ms step_avg:87.78ms +step:1142/1680 train_time:100248ms step_avg:87.78ms +step:1143/1680 train_time:100337ms step_avg:87.78ms +step:1144/1680 train_time:100426ms step_avg:87.78ms +step:1145/1680 train_time:100514ms step_avg:87.78ms +step:1146/1680 train_time:100602ms step_avg:87.79ms +step:1147/1680 train_time:100691ms step_avg:87.79ms +step:1148/1680 train_time:100780ms step_avg:87.79ms 
+step:1149/1680 train_time:100870ms step_avg:87.79ms +step:1150/1680 train_time:100960ms step_avg:87.79ms +step:1151/1680 train_time:101050ms step_avg:87.79ms +step:1152/1680 train_time:101140ms step_avg:87.79ms +step:1153/1680 train_time:101228ms step_avg:87.80ms +step:1154/1680 train_time:101317ms step_avg:87.80ms +step:1155/1680 train_time:101406ms step_avg:87.80ms +step:1156/1680 train_time:101495ms step_avg:87.80ms +step:1157/1680 train_time:101583ms step_avg:87.80ms +step:1158/1680 train_time:101671ms step_avg:87.80ms +step:1159/1680 train_time:101760ms step_avg:87.80ms +step:1160/1680 train_time:101849ms step_avg:87.80ms +step:1161/1680 train_time:101939ms step_avg:87.80ms +step:1162/1680 train_time:102029ms step_avg:87.80ms +step:1163/1680 train_time:102118ms step_avg:87.81ms +step:1164/1680 train_time:102208ms step_avg:87.81ms +step:1165/1680 train_time:102297ms step_avg:87.81ms +step:1166/1680 train_time:102386ms step_avg:87.81ms +step:1167/1680 train_time:102474ms step_avg:87.81ms +step:1168/1680 train_time:102562ms step_avg:87.81ms +step:1169/1680 train_time:102651ms step_avg:87.81ms +step:1170/1680 train_time:102739ms step_avg:87.81ms +step:1171/1680 train_time:102828ms step_avg:87.81ms +step:1172/1680 train_time:102918ms step_avg:87.81ms +step:1173/1680 train_time:103008ms step_avg:87.82ms +step:1174/1680 train_time:103098ms step_avg:87.82ms +step:1175/1680 train_time:103188ms step_avg:87.82ms +step:1176/1680 train_time:103276ms step_avg:87.82ms +step:1177/1680 train_time:103364ms step_avg:87.82ms +step:1178/1680 train_time:103452ms step_avg:87.82ms +step:1179/1680 train_time:103542ms step_avg:87.82ms +step:1180/1680 train_time:103630ms step_avg:87.82ms +step:1181/1680 train_time:103719ms step_avg:87.82ms +step:1182/1680 train_time:103807ms step_avg:87.82ms +step:1183/1680 train_time:103896ms step_avg:87.82ms +step:1184/1680 train_time:103984ms step_avg:87.82ms +step:1185/1680 train_time:104073ms step_avg:87.83ms +step:1186/1680 train_time:104163ms step_avg:87.83ms +step:1187/1680 train_time:104252ms step_avg:87.83ms +step:1188/1680 train_time:104341ms step_avg:87.83ms +step:1189/1680 train_time:104430ms step_avg:87.83ms +step:1190/1680 train_time:104519ms step_avg:87.83ms +step:1191/1680 train_time:104608ms step_avg:87.83ms +step:1192/1680 train_time:104697ms step_avg:87.83ms +step:1193/1680 train_time:104787ms step_avg:87.84ms +step:1194/1680 train_time:104876ms step_avg:87.84ms +step:1195/1680 train_time:104965ms step_avg:87.84ms +step:1196/1680 train_time:105054ms step_avg:87.84ms +step:1197/1680 train_time:105143ms step_avg:87.84ms +step:1198/1680 train_time:105232ms step_avg:87.84ms +step:1199/1680 train_time:105321ms step_avg:87.84ms +step:1200/1680 train_time:105410ms step_avg:87.84ms +step:1201/1680 train_time:105499ms step_avg:87.84ms +step:1202/1680 train_time:105588ms step_avg:87.84ms +step:1203/1680 train_time:105676ms step_avg:87.84ms +step:1204/1680 train_time:105765ms step_avg:87.85ms +step:1205/1680 train_time:105854ms step_avg:87.85ms +step:1206/1680 train_time:105944ms step_avg:87.85ms +step:1207/1680 train_time:106033ms step_avg:87.85ms +step:1208/1680 train_time:106122ms step_avg:87.85ms +step:1209/1680 train_time:106211ms step_avg:87.85ms +step:1210/1680 train_time:106300ms step_avg:87.85ms +step:1211/1680 train_time:106389ms step_avg:87.85ms +step:1212/1680 train_time:106478ms step_avg:87.85ms +step:1213/1680 train_time:106567ms step_avg:87.85ms +step:1214/1680 train_time:106655ms step_avg:87.85ms +step:1215/1680 train_time:106743ms step_avg:87.85ms 
+step:1216/1680 train_time:106833ms step_avg:87.86ms +step:1217/1680 train_time:106922ms step_avg:87.86ms +step:1218/1680 train_time:107011ms step_avg:87.86ms +step:1219/1680 train_time:107102ms step_avg:87.86ms +step:1220/1680 train_time:107192ms step_avg:87.86ms +step:1221/1680 train_time:107281ms step_avg:87.86ms +step:1222/1680 train_time:107370ms step_avg:87.86ms +step:1223/1680 train_time:107459ms step_avg:87.87ms +step:1224/1680 train_time:107548ms step_avg:87.87ms +step:1225/1680 train_time:107637ms step_avg:87.87ms +step:1226/1680 train_time:107726ms step_avg:87.87ms +step:1227/1680 train_time:107815ms step_avg:87.87ms +step:1228/1680 train_time:107904ms step_avg:87.87ms +step:1229/1680 train_time:107992ms step_avg:87.87ms +step:1230/1680 train_time:108081ms step_avg:87.87ms +step:1231/1680 train_time:108170ms step_avg:87.87ms +step:1232/1680 train_time:108259ms step_avg:87.87ms +step:1233/1680 train_time:108347ms step_avg:87.87ms +step:1234/1680 train_time:108436ms step_avg:87.87ms +step:1235/1680 train_time:108525ms step_avg:87.87ms +step:1236/1680 train_time:108614ms step_avg:87.88ms +step:1237/1680 train_time:108703ms step_avg:87.88ms +step:1238/1680 train_time:108793ms step_avg:87.88ms +step:1239/1680 train_time:108883ms step_avg:87.88ms +step:1240/1680 train_time:108972ms step_avg:87.88ms +step:1241/1680 train_time:109060ms step_avg:87.88ms +step:1242/1680 train_time:109149ms step_avg:87.88ms +step:1243/1680 train_time:109237ms step_avg:87.88ms +step:1244/1680 train_time:109327ms step_avg:87.88ms +step:1245/1680 train_time:109415ms step_avg:87.88ms +step:1246/1680 train_time:109504ms step_avg:87.88ms +step:1247/1680 train_time:109593ms step_avg:87.89ms +step:1248/1680 train_time:109682ms step_avg:87.89ms +step:1249/1680 train_time:109770ms step_avg:87.89ms +step:1250/1680 train_time:109859ms step_avg:87.89ms +step:1250/1680 val_loss:3.3769 train_time:109949ms step_avg:87.96ms +step:1251/1680 train_time:109968ms step_avg:87.90ms +step:1252/1680 train_time:110045ms step_avg:87.90ms +step:1253/1680 train_time:110138ms step_avg:87.90ms +step:1254/1680 train_time:110227ms step_avg:87.90ms +step:1255/1680 train_time:110315ms step_avg:87.90ms +step:1256/1680 train_time:110402ms step_avg:87.90ms +step:1257/1680 train_time:110490ms step_avg:87.90ms +step:1258/1680 train_time:110579ms step_avg:87.90ms +step:1259/1680 train_time:110667ms step_avg:87.90ms +step:1260/1680 train_time:110755ms step_avg:87.90ms +step:1261/1680 train_time:110843ms step_avg:87.90ms +step:1262/1680 train_time:110934ms step_avg:87.90ms +step:1263/1680 train_time:111027ms step_avg:87.91ms +step:1264/1680 train_time:111117ms step_avg:87.91ms +step:1265/1680 train_time:111207ms step_avg:87.91ms +step:1266/1680 train_time:111296ms step_avg:87.91ms +step:1267/1680 train_time:111384ms step_avg:87.91ms +step:1268/1680 train_time:111472ms step_avg:87.91ms +step:1269/1680 train_time:111560ms step_avg:87.91ms +step:1270/1680 train_time:111648ms step_avg:87.91ms +step:1271/1680 train_time:111736ms step_avg:87.91ms +step:1272/1680 train_time:111824ms step_avg:87.91ms +step:1273/1680 train_time:111914ms step_avg:87.91ms +step:1274/1680 train_time:112005ms step_avg:87.92ms +step:1275/1680 train_time:112094ms step_avg:87.92ms +step:1276/1680 train_time:112184ms step_avg:87.92ms +step:1277/1680 train_time:112273ms step_avg:87.92ms +step:1278/1680 train_time:112361ms step_avg:87.92ms +step:1279/1680 train_time:112450ms step_avg:87.92ms +step:1280/1680 train_time:112539ms step_avg:87.92ms +step:1281/1680 train_time:112627ms 
step_avg:87.92ms +step:1282/1680 train_time:112715ms step_avg:87.92ms +step:1283/1680 train_time:112804ms step_avg:87.92ms +step:1284/1680 train_time:112893ms step_avg:87.92ms +step:1285/1680 train_time:112982ms step_avg:87.92ms +step:1286/1680 train_time:113072ms step_avg:87.93ms +step:1287/1680 train_time:113163ms step_avg:87.93ms +step:1288/1680 train_time:113251ms step_avg:87.93ms +step:1289/1680 train_time:113340ms step_avg:87.93ms +step:1290/1680 train_time:113429ms step_avg:87.93ms +step:1291/1680 train_time:113518ms step_avg:87.93ms +step:1292/1680 train_time:113607ms step_avg:87.93ms +step:1293/1680 train_time:113696ms step_avg:87.93ms +step:1294/1680 train_time:113784ms step_avg:87.93ms +step:1295/1680 train_time:113873ms step_avg:87.93ms +step:1296/1680 train_time:113961ms step_avg:87.93ms +step:1297/1680 train_time:114051ms step_avg:87.93ms +step:1298/1680 train_time:114141ms step_avg:87.94ms +step:1299/1680 train_time:114230ms step_avg:87.94ms +step:1300/1680 train_time:114319ms step_avg:87.94ms +step:1301/1680 train_time:114408ms step_avg:87.94ms +step:1302/1680 train_time:114496ms step_avg:87.94ms +step:1303/1680 train_time:114585ms step_avg:87.94ms +step:1304/1680 train_time:114673ms step_avg:87.94ms +step:1305/1680 train_time:114762ms step_avg:87.94ms +step:1306/1680 train_time:114851ms step_avg:87.94ms +step:1307/1680 train_time:114940ms step_avg:87.94ms +step:1308/1680 train_time:115029ms step_avg:87.94ms +step:1309/1680 train_time:115119ms step_avg:87.94ms +step:1310/1680 train_time:115208ms step_avg:87.95ms +step:1311/1680 train_time:115297ms step_avg:87.95ms +step:1312/1680 train_time:115386ms step_avg:87.95ms +step:1313/1680 train_time:115475ms step_avg:87.95ms +step:1314/1680 train_time:115564ms step_avg:87.95ms +step:1315/1680 train_time:115652ms step_avg:87.95ms +step:1316/1680 train_time:115742ms step_avg:87.95ms +step:1317/1680 train_time:115830ms step_avg:87.95ms +step:1318/1680 train_time:115919ms step_avg:87.95ms +step:1319/1680 train_time:116008ms step_avg:87.95ms +step:1320/1680 train_time:116097ms step_avg:87.95ms +step:1321/1680 train_time:116187ms step_avg:87.95ms +step:1322/1680 train_time:116276ms step_avg:87.95ms +step:1323/1680 train_time:116366ms step_avg:87.96ms +step:1324/1680 train_time:116455ms step_avg:87.96ms +step:1325/1680 train_time:116544ms step_avg:87.96ms +step:1326/1680 train_time:116633ms step_avg:87.96ms +step:1327/1680 train_time:116721ms step_avg:87.96ms +step:1328/1680 train_time:116810ms step_avg:87.96ms +step:1329/1680 train_time:116899ms step_avg:87.96ms +step:1330/1680 train_time:116989ms step_avg:87.96ms +step:1331/1680 train_time:117079ms step_avg:87.96ms +step:1332/1680 train_time:117169ms step_avg:87.97ms +step:1333/1680 train_time:117258ms step_avg:87.97ms +step:1334/1680 train_time:117348ms step_avg:87.97ms +step:1335/1680 train_time:117437ms step_avg:87.97ms +step:1336/1680 train_time:117528ms step_avg:87.97ms +step:1337/1680 train_time:117616ms step_avg:87.97ms +step:1338/1680 train_time:117704ms step_avg:87.97ms +step:1339/1680 train_time:117793ms step_avg:87.97ms +step:1340/1680 train_time:117881ms step_avg:87.97ms +step:1341/1680 train_time:117970ms step_avg:87.97ms +step:1342/1680 train_time:118059ms step_avg:87.97ms +step:1343/1680 train_time:118149ms step_avg:87.97ms +step:1344/1680 train_time:118238ms step_avg:87.97ms +step:1345/1680 train_time:118328ms step_avg:87.98ms +step:1346/1680 train_time:118417ms step_avg:87.98ms +step:1347/1680 train_time:118507ms step_avg:87.98ms +step:1348/1680 train_time:118595ms 
step_avg:87.98ms +step:1349/1680 train_time:118684ms step_avg:87.98ms +step:1350/1680 train_time:118774ms step_avg:87.98ms +step:1351/1680 train_time:118862ms step_avg:87.98ms +step:1352/1680 train_time:118951ms step_avg:87.98ms +step:1353/1680 train_time:119042ms step_avg:87.98ms +step:1354/1680 train_time:119130ms step_avg:87.98ms +step:1355/1680 train_time:119220ms step_avg:87.98ms +step:1356/1680 train_time:119308ms step_avg:87.99ms +step:1357/1680 train_time:119397ms step_avg:87.99ms +step:1358/1680 train_time:119487ms step_avg:87.99ms +step:1359/1680 train_time:119576ms step_avg:87.99ms +step:1360/1680 train_time:119665ms step_avg:87.99ms +step:1361/1680 train_time:119753ms step_avg:87.99ms +step:1362/1680 train_time:119842ms step_avg:87.99ms +step:1363/1680 train_time:119930ms step_avg:87.99ms +step:1364/1680 train_time:120019ms step_avg:87.99ms +step:1365/1680 train_time:120108ms step_avg:87.99ms +step:1366/1680 train_time:120197ms step_avg:87.99ms +step:1367/1680 train_time:120285ms step_avg:87.99ms +step:1368/1680 train_time:120373ms step_avg:87.99ms +step:1369/1680 train_time:120463ms step_avg:87.99ms +step:1370/1680 train_time:120552ms step_avg:87.99ms +step:1371/1680 train_time:120641ms step_avg:88.00ms +step:1372/1680 train_time:120731ms step_avg:88.00ms +step:1373/1680 train_time:120820ms step_avg:88.00ms +step:1374/1680 train_time:120910ms step_avg:88.00ms +step:1375/1680 train_time:120999ms step_avg:88.00ms +step:1375/1680 val_loss:3.3420 train_time:121089ms step_avg:88.06ms +step:1376/1680 train_time:121108ms step_avg:88.01ms +step:1377/1680 train_time:121179ms step_avg:88.00ms +step:1378/1680 train_time:121271ms step_avg:88.01ms +step:1379/1680 train_time:121361ms step_avg:88.01ms +step:1380/1680 train_time:121449ms step_avg:88.01ms +step:1381/1680 train_time:121537ms step_avg:88.01ms +step:1382/1680 train_time:121625ms step_avg:88.01ms +step:1383/1680 train_time:121713ms step_avg:88.01ms +step:1384/1680 train_time:121801ms step_avg:88.01ms +step:1385/1680 train_time:121891ms step_avg:88.01ms +step:1386/1680 train_time:121979ms step_avg:88.01ms +step:1387/1680 train_time:122070ms step_avg:88.01ms +step:1388/1680 train_time:122160ms step_avg:88.01ms +step:1389/1680 train_time:122251ms step_avg:88.01ms +step:1390/1680 train_time:122340ms step_avg:88.01ms +step:1391/1680 train_time:122428ms step_avg:88.01ms +step:1392/1680 train_time:122517ms step_avg:88.01ms +step:1393/1680 train_time:122606ms step_avg:88.02ms +step:1394/1680 train_time:122694ms step_avg:88.02ms +step:1395/1680 train_time:122782ms step_avg:88.02ms +step:1396/1680 train_time:122870ms step_avg:88.02ms +step:1397/1680 train_time:122959ms step_avg:88.02ms +step:1398/1680 train_time:123048ms step_avg:88.02ms +step:1399/1680 train_time:123137ms step_avg:88.02ms +step:1400/1680 train_time:123227ms step_avg:88.02ms +step:1401/1680 train_time:123317ms step_avg:88.02ms +step:1402/1680 train_time:123406ms step_avg:88.02ms +step:1403/1680 train_time:123495ms step_avg:88.02ms +step:1404/1680 train_time:123584ms step_avg:88.02ms +step:1405/1680 train_time:123672ms step_avg:88.02ms +step:1406/1680 train_time:123760ms step_avg:88.02ms +step:1407/1680 train_time:123848ms step_avg:88.02ms +step:1408/1680 train_time:123937ms step_avg:88.02ms +step:1409/1680 train_time:124026ms step_avg:88.02ms +step:1410/1680 train_time:124116ms step_avg:88.03ms +step:1411/1680 train_time:124206ms step_avg:88.03ms +step:1412/1680 train_time:124296ms step_avg:88.03ms +step:1413/1680 train_time:124387ms step_avg:88.03ms +step:1414/1680 
train_time:124475ms step_avg:88.03ms +step:1415/1680 train_time:124564ms step_avg:88.03ms +step:1416/1680 train_time:124652ms step_avg:88.03ms +step:1417/1680 train_time:124740ms step_avg:88.03ms +step:1418/1680 train_time:124829ms step_avg:88.03ms +step:1419/1680 train_time:124917ms step_avg:88.03ms +step:1420/1680 train_time:125006ms step_avg:88.03ms +step:1421/1680 train_time:125095ms step_avg:88.03ms +step:1422/1680 train_time:125184ms step_avg:88.03ms +step:1423/1680 train_time:125274ms step_avg:88.04ms +step:1424/1680 train_time:125365ms step_avg:88.04ms +step:1425/1680 train_time:125454ms step_avg:88.04ms +step:1426/1680 train_time:125543ms step_avg:88.04ms +step:1427/1680 train_time:125632ms step_avg:88.04ms +step:1428/1680 train_time:125720ms step_avg:88.04ms +step:1429/1680 train_time:125809ms step_avg:88.04ms +step:1430/1680 train_time:125899ms step_avg:88.04ms +step:1431/1680 train_time:125988ms step_avg:88.04ms +step:1432/1680 train_time:126076ms step_avg:88.04ms +step:1433/1680 train_time:126165ms step_avg:88.04ms +step:1434/1680 train_time:126255ms step_avg:88.04ms +step:1435/1680 train_time:126345ms step_avg:88.05ms +step:1436/1680 train_time:126434ms step_avg:88.05ms +step:1437/1680 train_time:126524ms step_avg:88.05ms +step:1438/1680 train_time:126613ms step_avg:88.05ms +step:1439/1680 train_time:126702ms step_avg:88.05ms +step:1440/1680 train_time:126790ms step_avg:88.05ms +step:1441/1680 train_time:126879ms step_avg:88.05ms +step:1442/1680 train_time:126968ms step_avg:88.05ms +step:1443/1680 train_time:127057ms step_avg:88.05ms +step:1444/1680 train_time:127146ms step_avg:88.05ms +step:1445/1680 train_time:127234ms step_avg:88.05ms +step:1446/1680 train_time:127324ms step_avg:88.05ms +step:1447/1680 train_time:127413ms step_avg:88.05ms +step:1448/1680 train_time:127502ms step_avg:88.05ms +step:1449/1680 train_time:127591ms step_avg:88.05ms +step:1450/1680 train_time:127679ms step_avg:88.05ms +step:1451/1680 train_time:127769ms step_avg:88.06ms +step:1452/1680 train_time:127858ms step_avg:88.06ms +step:1453/1680 train_time:127946ms step_avg:88.06ms +step:1454/1680 train_time:128034ms step_avg:88.06ms +step:1455/1680 train_time:128124ms step_avg:88.06ms +step:1456/1680 train_time:128213ms step_avg:88.06ms +step:1457/1680 train_time:128302ms step_avg:88.06ms +step:1458/1680 train_time:128392ms step_avg:88.06ms +step:1459/1680 train_time:128481ms step_avg:88.06ms +step:1460/1680 train_time:128571ms step_avg:88.06ms +step:1461/1680 train_time:128661ms step_avg:88.06ms +step:1462/1680 train_time:128749ms step_avg:88.06ms +step:1463/1680 train_time:128838ms step_avg:88.06ms +step:1464/1680 train_time:128927ms step_avg:88.06ms +step:1465/1680 train_time:129016ms step_avg:88.07ms +step:1466/1680 train_time:129105ms step_avg:88.07ms +step:1467/1680 train_time:129194ms step_avg:88.07ms +step:1468/1680 train_time:129282ms step_avg:88.07ms +step:1469/1680 train_time:129373ms step_avg:88.07ms +step:1470/1680 train_time:129462ms step_avg:88.07ms +step:1471/1680 train_time:129551ms step_avg:88.07ms +step:1472/1680 train_time:129640ms step_avg:88.07ms +step:1473/1680 train_time:129729ms step_avg:88.07ms +step:1474/1680 train_time:129818ms step_avg:88.07ms +step:1475/1680 train_time:129907ms step_avg:88.07ms +step:1476/1680 train_time:129996ms step_avg:88.07ms +step:1477/1680 train_time:130085ms step_avg:88.07ms +step:1478/1680 train_time:130175ms step_avg:88.07ms +step:1479/1680 train_time:130265ms step_avg:88.08ms +step:1480/1680 train_time:130354ms step_avg:88.08ms +step:1481/1680 
train_time:130443ms step_avg:88.08ms +step:1482/1680 train_time:130532ms step_avg:88.08ms +step:1483/1680 train_time:130621ms step_avg:88.08ms +step:1484/1680 train_time:130710ms step_avg:88.08ms +step:1485/1680 train_time:130799ms step_avg:88.08ms +step:1486/1680 train_time:130888ms step_avg:88.08ms +step:1487/1680 train_time:130976ms step_avg:88.08ms +step:1488/1680 train_time:131065ms step_avg:88.08ms +step:1489/1680 train_time:131154ms step_avg:88.08ms +step:1490/1680 train_time:131243ms step_avg:88.08ms +step:1491/1680 train_time:131332ms step_avg:88.08ms +step:1492/1680 train_time:131422ms step_avg:88.08ms +step:1493/1680 train_time:131512ms step_avg:88.09ms +step:1494/1680 train_time:131600ms step_avg:88.09ms +step:1495/1680 train_time:131690ms step_avg:88.09ms +step:1496/1680 train_time:131779ms step_avg:88.09ms +step:1497/1680 train_time:131868ms step_avg:88.09ms +step:1498/1680 train_time:131956ms step_avg:88.09ms +step:1499/1680 train_time:132045ms step_avg:88.09ms +step:1500/1680 train_time:132133ms step_avg:88.09ms +step:1500/1680 val_loss:3.3122 train_time:132224ms step_avg:88.15ms +step:1501/1680 train_time:132243ms step_avg:88.10ms +step:1502/1680 train_time:132315ms step_avg:88.09ms +step:1503/1680 train_time:132409ms step_avg:88.10ms +step:1504/1680 train_time:132499ms step_avg:88.10ms +step:1505/1680 train_time:132587ms step_avg:88.10ms +step:1506/1680 train_time:132676ms step_avg:88.10ms +step:1507/1680 train_time:132764ms step_avg:88.10ms +step:1508/1680 train_time:132854ms step_avg:88.10ms +step:1509/1680 train_time:132941ms step_avg:88.10ms +step:1510/1680 train_time:133029ms step_avg:88.10ms +step:1511/1680 train_time:133117ms step_avg:88.10ms +step:1512/1680 train_time:133206ms step_avg:88.10ms +step:1513/1680 train_time:133296ms step_avg:88.10ms +step:1514/1680 train_time:133388ms step_avg:88.10ms +step:1515/1680 train_time:133478ms step_avg:88.10ms +step:1516/1680 train_time:133567ms step_avg:88.11ms +step:1517/1680 train_time:133657ms step_avg:88.11ms +step:1518/1680 train_time:133746ms step_avg:88.11ms +step:1519/1680 train_time:133835ms step_avg:88.11ms +step:1520/1680 train_time:133924ms step_avg:88.11ms +step:1521/1680 train_time:134012ms step_avg:88.11ms +step:1522/1680 train_time:134100ms step_avg:88.11ms +step:1523/1680 train_time:134188ms step_avg:88.11ms +step:1524/1680 train_time:134278ms step_avg:88.11ms +step:1525/1680 train_time:134367ms step_avg:88.11ms +step:1526/1680 train_time:134457ms step_avg:88.11ms +step:1527/1680 train_time:134547ms step_avg:88.11ms +step:1528/1680 train_time:134636ms step_avg:88.11ms +step:1529/1680 train_time:134726ms step_avg:88.11ms +step:1530/1680 train_time:134815ms step_avg:88.11ms +step:1531/1680 train_time:134904ms step_avg:88.11ms +step:1532/1680 train_time:134993ms step_avg:88.12ms +step:1533/1680 train_time:135081ms step_avg:88.12ms +step:1534/1680 train_time:135169ms step_avg:88.12ms +step:1535/1680 train_time:135259ms step_avg:88.12ms +step:1536/1680 train_time:135349ms step_avg:88.12ms +step:1537/1680 train_time:135439ms step_avg:88.12ms +step:1538/1680 train_time:135528ms step_avg:88.12ms +step:1539/1680 train_time:135618ms step_avg:88.12ms +step:1540/1680 train_time:135707ms step_avg:88.12ms +step:1541/1680 train_time:135796ms step_avg:88.12ms +step:1542/1680 train_time:135885ms step_avg:88.12ms +step:1543/1680 train_time:135973ms step_avg:88.12ms +step:1544/1680 train_time:136062ms step_avg:88.12ms +step:1545/1680 train_time:136151ms step_avg:88.12ms +step:1546/1680 train_time:136239ms step_avg:88.12ms 
+step:1547/1680 train_time:136329ms step_avg:88.12ms +step:1548/1680 train_time:136418ms step_avg:88.13ms +step:1549/1680 train_time:136506ms step_avg:88.13ms +step:1550/1680 train_time:136596ms step_avg:88.13ms +step:1551/1680 train_time:136685ms step_avg:88.13ms +step:1552/1680 train_time:136773ms step_avg:88.13ms +step:1553/1680 train_time:136862ms step_avg:88.13ms +step:1554/1680 train_time:136951ms step_avg:88.13ms +step:1555/1680 train_time:137041ms step_avg:88.13ms +step:1556/1680 train_time:137130ms step_avg:88.13ms +step:1557/1680 train_time:137218ms step_avg:88.13ms +step:1558/1680 train_time:137308ms step_avg:88.13ms +step:1559/1680 train_time:137397ms step_avg:88.13ms +step:1560/1680 train_time:137486ms step_avg:88.13ms +step:1561/1680 train_time:137575ms step_avg:88.13ms +step:1562/1680 train_time:137664ms step_avg:88.13ms +step:1563/1680 train_time:137754ms step_avg:88.13ms +step:1564/1680 train_time:137842ms step_avg:88.13ms +step:1565/1680 train_time:137931ms step_avg:88.13ms +step:1566/1680 train_time:138020ms step_avg:88.14ms +step:1567/1680 train_time:138110ms step_avg:88.14ms +step:1568/1680 train_time:138198ms step_avg:88.14ms +step:1569/1680 train_time:138286ms step_avg:88.14ms +step:1570/1680 train_time:138375ms step_avg:88.14ms +step:1571/1680 train_time:138464ms step_avg:88.14ms +step:1572/1680 train_time:138553ms step_avg:88.14ms +step:1573/1680 train_time:138642ms step_avg:88.14ms +step:1574/1680 train_time:138731ms step_avg:88.14ms +step:1575/1680 train_time:138819ms step_avg:88.14ms +step:1576/1680 train_time:138908ms step_avg:88.14ms +step:1577/1680 train_time:138997ms step_avg:88.14ms +step:1578/1680 train_time:139087ms step_avg:88.14ms +step:1579/1680 train_time:139176ms step_avg:88.14ms +step:1580/1680 train_time:139265ms step_avg:88.14ms +step:1581/1680 train_time:139354ms step_avg:88.14ms +step:1582/1680 train_time:139443ms step_avg:88.14ms +step:1583/1680 train_time:139531ms step_avg:88.14ms +step:1584/1680 train_time:139621ms step_avg:88.14ms +step:1585/1680 train_time:139710ms step_avg:88.14ms +step:1586/1680 train_time:139798ms step_avg:88.15ms +step:1587/1680 train_time:139888ms step_avg:88.15ms +step:1588/1680 train_time:139976ms step_avg:88.15ms +step:1589/1680 train_time:140065ms step_avg:88.15ms +step:1590/1680 train_time:140154ms step_avg:88.15ms +step:1591/1680 train_time:140244ms step_avg:88.15ms +step:1592/1680 train_time:140332ms step_avg:88.15ms +step:1593/1680 train_time:140421ms step_avg:88.15ms +step:1594/1680 train_time:140510ms step_avg:88.15ms +step:1595/1680 train_time:140599ms step_avg:88.15ms +step:1596/1680 train_time:140688ms step_avg:88.15ms +step:1597/1680 train_time:140777ms step_avg:88.15ms +step:1598/1680 train_time:140866ms step_avg:88.15ms +step:1599/1680 train_time:140955ms step_avg:88.15ms +step:1600/1680 train_time:141044ms step_avg:88.15ms +step:1601/1680 train_time:141133ms step_avg:88.15ms +step:1602/1680 train_time:141221ms step_avg:88.15ms +step:1603/1680 train_time:141310ms step_avg:88.15ms +step:1604/1680 train_time:141399ms step_avg:88.15ms +step:1605/1680 train_time:141487ms step_avg:88.15ms +step:1606/1680 train_time:141577ms step_avg:88.15ms +step:1607/1680 train_time:141665ms step_avg:88.16ms +step:1608/1680 train_time:141755ms step_avg:88.16ms +step:1609/1680 train_time:141844ms step_avg:88.16ms +step:1610/1680 train_time:141933ms step_avg:88.16ms +step:1611/1680 train_time:142022ms step_avg:88.16ms +step:1612/1680 train_time:142111ms step_avg:88.16ms +step:1613/1680 train_time:142200ms step_avg:88.16ms 
+step:1614/1680 train_time:142290ms step_avg:88.16ms +step:1615/1680 train_time:142378ms step_avg:88.16ms +step:1616/1680 train_time:142467ms step_avg:88.16ms +step:1617/1680 train_time:142556ms step_avg:88.16ms +step:1618/1680 train_time:142645ms step_avg:88.16ms +step:1619/1680 train_time:142735ms step_avg:88.16ms +step:1620/1680 train_time:142824ms step_avg:88.16ms +step:1621/1680 train_time:142915ms step_avg:88.16ms +step:1622/1680 train_time:143004ms step_avg:88.17ms +step:1623/1680 train_time:143094ms step_avg:88.17ms +step:1624/1680 train_time:143182ms step_avg:88.17ms +step:1625/1680 train_time:143272ms step_avg:88.17ms +step:1625/1680 val_loss:3.2882 train_time:143362ms step_avg:88.22ms +step:1626/1680 train_time:143380ms step_avg:88.18ms +step:1627/1680 train_time:143453ms step_avg:88.17ms +step:1628/1680 train_time:143545ms step_avg:88.17ms +step:1629/1680 train_time:143634ms step_avg:88.17ms +step:1630/1680 train_time:143722ms step_avg:88.17ms +step:1631/1680 train_time:143811ms step_avg:88.17ms +step:1632/1680 train_time:143898ms step_avg:88.17ms +step:1633/1680 train_time:143986ms step_avg:88.17ms +step:1634/1680 train_time:144074ms step_avg:88.17ms +step:1635/1680 train_time:144162ms step_avg:88.17ms +step:1636/1680 train_time:144251ms step_avg:88.17ms +step:1637/1680 train_time:144341ms step_avg:88.17ms +step:1638/1680 train_time:144433ms step_avg:88.18ms +step:1639/1680 train_time:144524ms step_avg:88.18ms +step:1640/1680 train_time:144613ms step_avg:88.18ms +step:1641/1680 train_time:144702ms step_avg:88.18ms +step:1642/1680 train_time:144791ms step_avg:88.18ms +step:1643/1680 train_time:144879ms step_avg:88.18ms +step:1644/1680 train_time:144967ms step_avg:88.18ms +step:1645/1680 train_time:145056ms step_avg:88.18ms +step:1646/1680 train_time:145144ms step_avg:88.18ms +step:1647/1680 train_time:145232ms step_avg:88.18ms +step:1648/1680 train_time:145321ms step_avg:88.18ms +step:1649/1680 train_time:145411ms step_avg:88.18ms +step:1650/1680 train_time:145501ms step_avg:88.18ms +step:1651/1680 train_time:145591ms step_avg:88.18ms +step:1652/1680 train_time:145680ms step_avg:88.18ms +step:1653/1680 train_time:145769ms step_avg:88.18ms +step:1654/1680 train_time:145858ms step_avg:88.18ms +step:1655/1680 train_time:145946ms step_avg:88.18ms +step:1656/1680 train_time:146034ms step_avg:88.18ms +step:1657/1680 train_time:146122ms step_avg:88.18ms +step:1658/1680 train_time:146211ms step_avg:88.19ms +step:1659/1680 train_time:146300ms step_avg:88.19ms +step:1660/1680 train_time:146390ms step_avg:88.19ms +step:1661/1680 train_time:146480ms step_avg:88.19ms +step:1662/1680 train_time:146569ms step_avg:88.19ms +step:1663/1680 train_time:146659ms step_avg:88.19ms +step:1664/1680 train_time:146748ms step_avg:88.19ms +step:1665/1680 train_time:146837ms step_avg:88.19ms +step:1666/1680 train_time:146926ms step_avg:88.19ms +step:1667/1680 train_time:147015ms step_avg:88.19ms +step:1668/1680 train_time:147103ms step_avg:88.19ms +step:1669/1680 train_time:147193ms step_avg:88.19ms +step:1670/1680 train_time:147281ms step_avg:88.19ms +step:1671/1680 train_time:147371ms step_avg:88.19ms +step:1672/1680 train_time:147460ms step_avg:88.19ms +step:1673/1680 train_time:147550ms step_avg:88.19ms +step:1674/1680 train_time:147639ms step_avg:88.20ms +step:1675/1680 train_time:147729ms step_avg:88.20ms +step:1676/1680 train_time:147817ms step_avg:88.20ms +step:1677/1680 train_time:147906ms step_avg:88.20ms +step:1678/1680 train_time:147995ms step_avg:88.20ms +step:1679/1680 train_time:148084ms 
step_avg:88.20ms +step:1680/1680 train_time:148172ms step_avg:88.20ms +step:1680/1680 val_loss:3.2774 train_time:148262ms step_avg:88.25ms +peak memory allocated: 30760 MiB reserved: 46214 MiB diff --git a/records/092725_BF16CE/27974127-6559-494e-9941-2d88325c2e52.txt b/records/092725_BF16CE/27974127-6559-494e-9941-2d88325c2e52.txt new file mode 100644 index 000000000..be6854e19 --- /dev/null +++ b/records/092725_BF16CE/27974127-6559-494e-9941-2d88325c2e52.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
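+ # C = A @ A.T is symmetric, so each computed block is written twice: once at (m, n) and once transposed at (n, m) (see the output.T store below). Combined with the early-exit above, only about half of the output blocks are ever computed.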
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
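+ (In this file, Muon optimizes the hidden matrix and gate params, while the embedding, scalar, and lm_head params are handled by the DistAdam optimizer defined below.)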
+ Empirically, though, small 1D params perform efficiently here: + NS approximately performs a magnitude normalization of the grad, and + this hyper-optimized class has faster execution time than the current impl of Adam for small params. + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized to enable + (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param, 7 padding params) + 2. reduce scatter attn_gate (10 params, 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params, 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive params of type attn reshape before NS + 8. wait on 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size()==8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for the resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1,10,16,16] + assert len(params)==sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx+size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal.
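+ # The step runs in three phases: (1) per group, stack the grads (padded to a multiple of world_size) and launch an async reduce_scatter so each rank owns one chunk; (2) per group, wait on the reduce, run batched Newton-Schulz on the local chunk, and launch an async all_gather of the updated params; (3) wait on the gathers and copy the results back into the parameter tensors.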
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
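+ # (Weight decay is decoupled, AdamW-style: the copied weights are scaled by 1 - eff_weight_decay_val below, independent of the orthogonalized update added afterwards.)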
updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + original_shape = batched_update_grads.shape + # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O + module_idx = start_idx if start_idx < len(params) else 0 + if params[module_idx].module == 'attn': + batched_update_grads = batched_update_grads.view(-1, original_shape[-2], original_shape[-1] // 4) + v_chunk = newton_schulz_triton(batched_update_grads) + v_chunk = v_chunk.view(original_shape) + + # Second pass over the chunk: apply the orthogonalized updates. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = 
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module='attn' + with torch.no_grad(): + self.qkvo_w.view(4,self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4,self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4,self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4,self.hdim, self.dim)[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module='mlp' + self.c_proj.module='mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
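+ # FP8 scale choices below line up with float8 dynamic range: 448 is the largest normal value of e4m3fn (used for x and w in the custom op above), while grads use the wider-range e5m2 format.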
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the 
data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter==5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter+=1 + return starts, ends + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * 
grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1640 # number of iterations to run + iteration_extension = 40 # number of iterations to continue training at final cooldown and window size + cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 
0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999,step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations+args.iteration_extension: + return args.ws_validate//2, args.ws_validate + x = min(step / (1 + args.num_iterations),0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx]//2, args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws_long = args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params + if new_ws_long > ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + elif new_ws_long 0 and step % args.val_loss_every == 0): + if last_step: + ws_long = args.ws_long_validate + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = torch.zeros((), device=device, dtype=torch.float32) + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + 
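+ # Gradient accumulation: each micro-batch's backward() adds into .grad; the optimizers then average the accumulated grads across ranks via reduce_scatter (ReduceOp.AVG).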
inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Sep 27 13:11:36 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 27C P0 120W / 700W | 5856MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 25C P0 118W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 22C P0 115W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 26C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 27C P0 119W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 25C P0 117W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 28C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 24C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 170107 C /usr/bin/python 0MiB | +| 0 N/A N/A 170108 C /usr/bin/python 0MiB | +| 0 N/A N/A 170109 C /usr/bin/python 0MiB | +| 0 N/A N/A 170110 C /usr/bin/python 0MiB | +| 0 N/A N/A 170111 C /usr/bin/python 0MiB | +| 0 N/A N/A 170112 C /usr/bin/python 0MiB | +| 0 N/A N/A 170113 C /usr/bin/python 0MiB | +| 0 N/A N/A 170114 C /usr/bin/python 0MiB | +| 1 N/A N/A 170108 C /usr/bin/python 0MiB | +| 2 N/A N/A 170109 C /usr/bin/python 0MiB | +| 3 N/A N/A 170110 C /usr/bin/python 0MiB | +| 4 N/A N/A 170111 C /usr/bin/python 0MiB | +| 5 N/A N/A 170112 C /usr/bin/python 0MiB | +| 6 N/A N/A 170113 C /usr/bin/python 0MiB | +| 7 N/A N/A 170114 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.17ms +step:1/1680 train_time:145ms step_avg:144.63ms +step:2/1680 train_time:166ms step_avg:82.77ms +step:3/1680 train_time:228ms step_avg:76.13ms +step:4/1680 train_time:314ms step_avg:78.39ms +step:5/1680 train_time:399ms step_avg:79.86ms +step:6/1680 train_time:485ms step_avg:80.85ms +step:7/1680 train_time:572ms step_avg:81.70ms +step:8/1680 train_time:658ms step_avg:82.26ms +step:9/1680 train_time:744ms step_avg:82.67ms +step:10/1680 train_time:830ms step_avg:83.03ms +step:11/1680 train_time:916ms step_avg:83.30ms +step:12/1680 train_time:1003ms step_avg:83.61ms +step:13/1680 train_time:1094ms step_avg:84.17ms +step:14/1680 train_time:1185ms step_avg:84.61ms +step:15/1680 train_time:1272ms step_avg:84.83ms +step:16/1680 train_time:1360ms step_avg:84.99ms +step:17/1680 train_time:1446ms step_avg:85.08ms +step:18/1680 train_time:1533ms step_avg:85.16ms +step:19/1680 train_time:1620ms step_avg:85.26ms +step:20/1680 train_time:1706ms step_avg:85.31ms +step:21/1680 train_time:1792ms step_avg:85.35ms +step:22/1680 train_time:1879ms step_avg:85.41ms +step:23/1680 train_time:1966ms step_avg:85.49ms +step:24/1680 train_time:2055ms step_avg:85.62ms +step:25/1680 train_time:2144ms step_avg:85.76ms +step:26/1680 train_time:2233ms step_avg:85.87ms +step:27/1680 train_time:2321ms step_avg:85.95ms +step:28/1680 train_time:2408ms step_avg:86.01ms +step:29/1680 train_time:2495ms step_avg:86.04ms +step:30/1680 train_time:2582ms step_avg:86.08ms +step:31/1680 train_time:2670ms step_avg:86.12ms +step:32/1680 train_time:2757ms step_avg:86.14ms +step:33/1680 train_time:2843ms step_avg:86.16ms +step:34/1680 train_time:2930ms step_avg:86.18ms +step:35/1680 train_time:3017ms step_avg:86.20ms +step:36/1680 train_time:3105ms step_avg:86.24ms +step:37/1680 train_time:3192ms step_avg:86.28ms +step:38/1680 train_time:3282ms step_avg:86.37ms +step:39/1680 train_time:3370ms step_avg:86.40ms +step:40/1680 train_time:3457ms step_avg:86.42ms +step:41/1680 train_time:3543ms step_avg:86.42ms +step:42/1680 train_time:3631ms step_avg:86.46ms +step:43/1680 train_time:3718ms step_avg:86.48ms +step:44/1680 train_time:3805ms step_avg:86.48ms +step:45/1680 train_time:3892ms step_avg:86.48ms +step:46/1680 train_time:3979ms step_avg:86.50ms +step:47/1680 
train_time:4066ms step_avg:86.52ms +step:48/1680 train_time:4153ms step_avg:86.53ms +step:49/1680 train_time:4242ms step_avg:86.57ms +step:50/1680 train_time:4330ms step_avg:86.61ms +step:51/1680 train_time:4418ms step_avg:86.63ms +step:52/1680 train_time:4506ms step_avg:86.65ms +step:53/1680 train_time:4593ms step_avg:86.66ms +step:54/1680 train_time:4680ms step_avg:86.66ms +step:55/1680 train_time:4767ms step_avg:86.67ms +step:56/1680 train_time:4854ms step_avg:86.68ms +step:57/1680 train_time:4941ms step_avg:86.68ms +step:58/1680 train_time:5028ms step_avg:86.69ms +step:59/1680 train_time:5116ms step_avg:86.70ms +step:60/1680 train_time:5203ms step_avg:86.72ms +step:61/1680 train_time:5291ms step_avg:86.74ms +step:62/1680 train_time:5379ms step_avg:86.76ms +step:63/1680 train_time:5466ms step_avg:86.77ms +step:64/1680 train_time:5554ms step_avg:86.78ms +step:65/1680 train_time:5642ms step_avg:86.79ms +step:66/1680 train_time:5728ms step_avg:86.79ms +step:67/1680 train_time:5815ms step_avg:86.80ms +step:68/1680 train_time:5902ms step_avg:86.79ms +step:69/1680 train_time:5989ms step_avg:86.80ms +step:70/1680 train_time:6076ms step_avg:86.80ms +step:71/1680 train_time:6164ms step_avg:86.82ms +step:72/1680 train_time:6252ms step_avg:86.83ms +step:73/1680 train_time:6339ms step_avg:86.84ms +step:74/1680 train_time:6426ms step_avg:86.84ms +step:75/1680 train_time:6515ms step_avg:86.86ms +step:76/1680 train_time:6602ms step_avg:86.87ms +step:77/1680 train_time:6688ms step_avg:86.86ms +step:78/1680 train_time:6776ms step_avg:86.87ms +step:79/1680 train_time:6862ms step_avg:86.87ms +step:80/1680 train_time:6950ms step_avg:86.87ms +step:81/1680 train_time:7036ms step_avg:86.87ms +step:82/1680 train_time:7124ms step_avg:86.87ms +step:83/1680 train_time:7211ms step_avg:86.88ms +step:84/1680 train_time:7298ms step_avg:86.88ms +step:85/1680 train_time:7385ms step_avg:86.89ms +step:86/1680 train_time:7473ms step_avg:86.90ms +step:87/1680 train_time:7561ms step_avg:86.90ms +step:88/1680 train_time:7648ms step_avg:86.91ms +step:89/1680 train_time:7736ms step_avg:86.92ms +step:90/1680 train_time:7822ms step_avg:86.92ms +step:91/1680 train_time:7910ms step_avg:86.92ms +step:92/1680 train_time:7997ms step_avg:86.93ms +step:93/1680 train_time:8084ms step_avg:86.93ms +step:94/1680 train_time:8171ms step_avg:86.93ms +step:95/1680 train_time:8258ms step_avg:86.93ms +step:96/1680 train_time:8345ms step_avg:86.92ms +step:97/1680 train_time:8432ms step_avg:86.93ms +step:98/1680 train_time:8519ms step_avg:86.93ms +step:99/1680 train_time:8607ms step_avg:86.94ms +step:100/1680 train_time:8694ms step_avg:86.94ms +step:101/1680 train_time:8781ms step_avg:86.94ms +step:102/1680 train_time:8868ms step_avg:86.94ms +step:103/1680 train_time:8955ms step_avg:86.94ms +step:104/1680 train_time:9042ms step_avg:86.94ms +step:105/1680 train_time:9129ms step_avg:86.94ms +step:106/1680 train_time:9215ms step_avg:86.94ms +step:107/1680 train_time:9302ms step_avg:86.94ms +step:108/1680 train_time:9389ms step_avg:86.94ms +step:109/1680 train_time:9476ms step_avg:86.94ms +step:110/1680 train_time:9565ms step_avg:86.95ms +step:111/1680 train_time:9652ms step_avg:86.95ms +step:112/1680 train_time:9739ms step_avg:86.95ms +step:113/1680 train_time:9825ms step_avg:86.95ms +step:114/1680 train_time:9912ms step_avg:86.95ms +step:115/1680 train_time:10000ms step_avg:86.96ms +step:116/1680 train_time:10087ms step_avg:86.95ms +step:117/1680 train_time:10174ms step_avg:86.96ms +step:118/1680 train_time:10261ms step_avg:86.95ms +step:119/1680 
train_time:10348ms step_avg:86.96ms +step:120/1680 train_time:10435ms step_avg:86.96ms +step:121/1680 train_time:10522ms step_avg:86.96ms +step:122/1680 train_time:10609ms step_avg:86.96ms +step:123/1680 train_time:10696ms step_avg:86.96ms +step:124/1680 train_time:10784ms step_avg:86.97ms +step:125/1680 train_time:10871ms step_avg:86.97ms +step:125/1680 val_loss:4.3049 train_time:10960ms step_avg:87.68ms +step:126/1680 train_time:10984ms step_avg:87.17ms +step:127/1680 train_time:11050ms step_avg:87.00ms +step:128/1680 train_time:11145ms step_avg:87.07ms +step:129/1680 train_time:11239ms step_avg:87.12ms +step:130/1680 train_time:11327ms step_avg:87.13ms +step:131/1680 train_time:11414ms step_avg:87.13ms +step:132/1680 train_time:11500ms step_avg:87.12ms +step:133/1680 train_time:11586ms step_avg:87.11ms +step:134/1680 train_time:11672ms step_avg:87.11ms +step:135/1680 train_time:11758ms step_avg:87.10ms +step:136/1680 train_time:11844ms step_avg:87.09ms +step:137/1680 train_time:11930ms step_avg:87.08ms +step:138/1680 train_time:12017ms step_avg:87.08ms +step:139/1680 train_time:12105ms step_avg:87.09ms +step:140/1680 train_time:12194ms step_avg:87.10ms +step:141/1680 train_time:12283ms step_avg:87.12ms +step:142/1680 train_time:12371ms step_avg:87.12ms +step:143/1680 train_time:12459ms step_avg:87.12ms +step:144/1680 train_time:12547ms step_avg:87.13ms +step:145/1680 train_time:12633ms step_avg:87.12ms +step:146/1680 train_time:12720ms step_avg:87.12ms +step:147/1680 train_time:12806ms step_avg:87.12ms +step:148/1680 train_time:12892ms step_avg:87.11ms +step:149/1680 train_time:12979ms step_avg:87.11ms +step:150/1680 train_time:13066ms step_avg:87.11ms +step:151/1680 train_time:13154ms step_avg:87.11ms +step:152/1680 train_time:13242ms step_avg:87.12ms +step:153/1680 train_time:13330ms step_avg:87.12ms +step:154/1680 train_time:13417ms step_avg:87.12ms +step:155/1680 train_time:13504ms step_avg:87.13ms +step:156/1680 train_time:13591ms step_avg:87.12ms +step:157/1680 train_time:13677ms step_avg:87.12ms +step:158/1680 train_time:13765ms step_avg:87.12ms +step:159/1680 train_time:13851ms step_avg:87.11ms +step:160/1680 train_time:13938ms step_avg:87.11ms +step:161/1680 train_time:14025ms step_avg:87.11ms +step:162/1680 train_time:14112ms step_avg:87.11ms +step:163/1680 train_time:14199ms step_avg:87.11ms +step:164/1680 train_time:14287ms step_avg:87.11ms +step:165/1680 train_time:14374ms step_avg:87.12ms +step:166/1680 train_time:14461ms step_avg:87.12ms +step:167/1680 train_time:14548ms step_avg:87.12ms +step:168/1680 train_time:14635ms step_avg:87.11ms +step:169/1680 train_time:14721ms step_avg:87.11ms +step:170/1680 train_time:14808ms step_avg:87.10ms +step:171/1680 train_time:14895ms step_avg:87.10ms +step:172/1680 train_time:14982ms step_avg:87.10ms +step:173/1680 train_time:15068ms step_avg:87.10ms +step:174/1680 train_time:15155ms step_avg:87.10ms +step:175/1680 train_time:15244ms step_avg:87.11ms +step:176/1680 train_time:15331ms step_avg:87.11ms +step:177/1680 train_time:15418ms step_avg:87.11ms +step:178/1680 train_time:15506ms step_avg:87.11ms +step:179/1680 train_time:15593ms step_avg:87.11ms +step:180/1680 train_time:15680ms step_avg:87.11ms +step:181/1680 train_time:15766ms step_avg:87.11ms +step:182/1680 train_time:15853ms step_avg:87.10ms +step:183/1680 train_time:15940ms step_avg:87.10ms +step:184/1680 train_time:16026ms step_avg:87.10ms +step:185/1680 train_time:16114ms step_avg:87.10ms +step:186/1680 train_time:16201ms step_avg:87.10ms +step:187/1680 train_time:16289ms 
step_avg:87.11ms +step:188/1680 train_time:16376ms step_avg:87.11ms +step:189/1680 train_time:16463ms step_avg:87.11ms +step:190/1680 train_time:16550ms step_avg:87.10ms +step:191/1680 train_time:16637ms step_avg:87.10ms +step:192/1680 train_time:16724ms step_avg:87.10ms +step:193/1680 train_time:16811ms step_avg:87.10ms +step:194/1680 train_time:16898ms step_avg:87.10ms +step:195/1680 train_time:16984ms step_avg:87.10ms +step:196/1680 train_time:17071ms step_avg:87.10ms +step:197/1680 train_time:17159ms step_avg:87.10ms +step:198/1680 train_time:17246ms step_avg:87.10ms +step:199/1680 train_time:17334ms step_avg:87.10ms +step:200/1680 train_time:17421ms step_avg:87.10ms +step:201/1680 train_time:17508ms step_avg:87.10ms +step:202/1680 train_time:17595ms step_avg:87.10ms +step:203/1680 train_time:17683ms step_avg:87.11ms +step:204/1680 train_time:17769ms step_avg:87.10ms +step:205/1680 train_time:17856ms step_avg:87.10ms +step:206/1680 train_time:17943ms step_avg:87.10ms +step:207/1680 train_time:18030ms step_avg:87.10ms +step:208/1680 train_time:18116ms step_avg:87.09ms +step:209/1680 train_time:18203ms step_avg:87.09ms +step:210/1680 train_time:18289ms step_avg:87.09ms +step:211/1680 train_time:18377ms step_avg:87.09ms +step:212/1680 train_time:18464ms step_avg:87.10ms +step:213/1680 train_time:18551ms step_avg:87.09ms +step:214/1680 train_time:18638ms step_avg:87.09ms +step:215/1680 train_time:18725ms step_avg:87.09ms +step:216/1680 train_time:18813ms step_avg:87.10ms +step:217/1680 train_time:18900ms step_avg:87.10ms +step:218/1680 train_time:18986ms step_avg:87.09ms +step:219/1680 train_time:19073ms step_avg:87.09ms +step:220/1680 train_time:19159ms step_avg:87.09ms +step:221/1680 train_time:19246ms step_avg:87.09ms +step:222/1680 train_time:19333ms step_avg:87.09ms +step:223/1680 train_time:19420ms step_avg:87.09ms +step:224/1680 train_time:19507ms step_avg:87.09ms +step:225/1680 train_time:19594ms step_avg:87.08ms +step:226/1680 train_time:19681ms step_avg:87.09ms +step:227/1680 train_time:19768ms step_avg:87.09ms +step:228/1680 train_time:19856ms step_avg:87.09ms +step:229/1680 train_time:19943ms step_avg:87.09ms +step:230/1680 train_time:20030ms step_avg:87.09ms +step:231/1680 train_time:20117ms step_avg:87.09ms +step:232/1680 train_time:20205ms step_avg:87.09ms +step:233/1680 train_time:20291ms step_avg:87.09ms +step:234/1680 train_time:20379ms step_avg:87.09ms +step:235/1680 train_time:20466ms step_avg:87.09ms +step:236/1680 train_time:20553ms step_avg:87.09ms +step:237/1680 train_time:20640ms step_avg:87.09ms +step:238/1680 train_time:20727ms step_avg:87.09ms +step:239/1680 train_time:20814ms step_avg:87.09ms +step:240/1680 train_time:20901ms step_avg:87.09ms +step:241/1680 train_time:20988ms step_avg:87.09ms +step:242/1680 train_time:21075ms step_avg:87.09ms +step:243/1680 train_time:21162ms step_avg:87.09ms +step:244/1680 train_time:21248ms step_avg:87.08ms +step:245/1680 train_time:21336ms step_avg:87.08ms +step:246/1680 train_time:21423ms step_avg:87.08ms +step:247/1680 train_time:21509ms step_avg:87.08ms +step:248/1680 train_time:21596ms step_avg:87.08ms +step:249/1680 train_time:21684ms step_avg:87.08ms +step:250/1680 train_time:21771ms step_avg:87.08ms +step:250/1680 val_loss:3.9700 train_time:21860ms step_avg:87.44ms +step:251/1680 train_time:21879ms step_avg:87.17ms +step:252/1680 train_time:21948ms step_avg:87.10ms +step:253/1680 train_time:22039ms step_avg:87.11ms +step:254/1680 train_time:22127ms step_avg:87.12ms +step:255/1680 train_time:22214ms step_avg:87.11ms 
+step:256/1680 train_time:22300ms step_avg:87.11ms +step:257/1680 train_time:22386ms step_avg:87.10ms +step:258/1680 train_time:22472ms step_avg:87.10ms +step:259/1680 train_time:22558ms step_avg:87.10ms +step:260/1680 train_time:22644ms step_avg:87.09ms +step:261/1680 train_time:22730ms step_avg:87.09ms +step:262/1680 train_time:22817ms step_avg:87.09ms +step:263/1680 train_time:22905ms step_avg:87.09ms +step:264/1680 train_time:22994ms step_avg:87.10ms +step:265/1680 train_time:23083ms step_avg:87.11ms +step:266/1680 train_time:23170ms step_avg:87.11ms +step:267/1680 train_time:23257ms step_avg:87.10ms +step:268/1680 train_time:23343ms step_avg:87.10ms +step:269/1680 train_time:23430ms step_avg:87.10ms +step:270/1680 train_time:23517ms step_avg:87.10ms +step:271/1680 train_time:23603ms step_avg:87.09ms +step:272/1680 train_time:23688ms step_avg:87.09ms +step:273/1680 train_time:23775ms step_avg:87.09ms +step:274/1680 train_time:23863ms step_avg:87.09ms +step:275/1680 train_time:23950ms step_avg:87.09ms +step:276/1680 train_time:24039ms step_avg:87.10ms +step:277/1680 train_time:24127ms step_avg:87.10ms +step:278/1680 train_time:24214ms step_avg:87.10ms +step:279/1680 train_time:24300ms step_avg:87.10ms +step:280/1680 train_time:24387ms step_avg:87.10ms +step:281/1680 train_time:24473ms step_avg:87.09ms +step:282/1680 train_time:24560ms step_avg:87.09ms +step:283/1680 train_time:24646ms step_avg:87.09ms +step:284/1680 train_time:24733ms step_avg:87.09ms +step:285/1680 train_time:24819ms step_avg:87.08ms +step:286/1680 train_time:24907ms step_avg:87.09ms +step:287/1680 train_time:24995ms step_avg:87.09ms +step:288/1680 train_time:25082ms step_avg:87.09ms +step:289/1680 train_time:25170ms step_avg:87.09ms +step:290/1680 train_time:25257ms step_avg:87.09ms +step:291/1680 train_time:25344ms step_avg:87.09ms +step:292/1680 train_time:25431ms step_avg:87.09ms +step:293/1680 train_time:25517ms step_avg:87.09ms +step:294/1680 train_time:25605ms step_avg:87.09ms +step:295/1680 train_time:25691ms step_avg:87.09ms +step:296/1680 train_time:25778ms step_avg:87.09ms +step:297/1680 train_time:25865ms step_avg:87.09ms +step:298/1680 train_time:25952ms step_avg:87.09ms +step:299/1680 train_time:26040ms step_avg:87.09ms +step:300/1680 train_time:26127ms step_avg:87.09ms +step:301/1680 train_time:26213ms step_avg:87.09ms +step:302/1680 train_time:26300ms step_avg:87.09ms +step:303/1680 train_time:26388ms step_avg:87.09ms +step:304/1680 train_time:26474ms step_avg:87.09ms +step:305/1680 train_time:26561ms step_avg:87.09ms +step:306/1680 train_time:26648ms step_avg:87.09ms +step:307/1680 train_time:26735ms step_avg:87.09ms +step:308/1680 train_time:26822ms step_avg:87.08ms +step:309/1680 train_time:26909ms step_avg:87.08ms +step:310/1680 train_time:26996ms step_avg:87.08ms +step:311/1680 train_time:27083ms step_avg:87.08ms +step:312/1680 train_time:27170ms step_avg:87.08ms +step:313/1680 train_time:27257ms step_avg:87.08ms +step:314/1680 train_time:27344ms step_avg:87.08ms +step:315/1680 train_time:27431ms step_avg:87.08ms +step:316/1680 train_time:27518ms step_avg:87.08ms +step:317/1680 train_time:27605ms step_avg:87.08ms +step:318/1680 train_time:27691ms step_avg:87.08ms +step:319/1680 train_time:27778ms step_avg:87.08ms +step:320/1680 train_time:27865ms step_avg:87.08ms +step:321/1680 train_time:27951ms step_avg:87.08ms +step:322/1680 train_time:28039ms step_avg:87.08ms +step:323/1680 train_time:28125ms step_avg:87.08ms +step:324/1680 train_time:28213ms step_avg:87.08ms +step:325/1680 train_time:28300ms 
step_avg:87.08ms +step:326/1680 train_time:28387ms step_avg:87.08ms +step:327/1680 train_time:28473ms step_avg:87.07ms +step:328/1680 train_time:28561ms step_avg:87.08ms +step:329/1680 train_time:28648ms step_avg:87.07ms +step:330/1680 train_time:28734ms step_avg:87.07ms +step:331/1680 train_time:28821ms step_avg:87.07ms +step:332/1680 train_time:28909ms step_avg:87.08ms +step:333/1680 train_time:28996ms step_avg:87.07ms +step:334/1680 train_time:29083ms step_avg:87.07ms +step:335/1680 train_time:29170ms step_avg:87.07ms +step:336/1680 train_time:29257ms step_avg:87.08ms +step:337/1680 train_time:29344ms step_avg:87.08ms +step:338/1680 train_time:29431ms step_avg:87.07ms +step:339/1680 train_time:29518ms step_avg:87.07ms +step:340/1680 train_time:29605ms step_avg:87.07ms +step:341/1680 train_time:29692ms step_avg:87.07ms +step:342/1680 train_time:29779ms step_avg:87.07ms +step:343/1680 train_time:29867ms step_avg:87.08ms +step:344/1680 train_time:29953ms step_avg:87.07ms +step:345/1680 train_time:30041ms step_avg:87.07ms +step:346/1680 train_time:30128ms step_avg:87.07ms +step:347/1680 train_time:30215ms step_avg:87.07ms +step:348/1680 train_time:30302ms step_avg:87.07ms +step:349/1680 train_time:30389ms step_avg:87.07ms +step:350/1680 train_time:30475ms step_avg:87.07ms +step:351/1680 train_time:30563ms step_avg:87.07ms +step:352/1680 train_time:30650ms step_avg:87.07ms +step:353/1680 train_time:30737ms step_avg:87.07ms +step:354/1680 train_time:30824ms step_avg:87.07ms +step:355/1680 train_time:30911ms step_avg:87.07ms +step:356/1680 train_time:30998ms step_avg:87.07ms +step:357/1680 train_time:31086ms step_avg:87.07ms +step:358/1680 train_time:31172ms step_avg:87.07ms +step:359/1680 train_time:31259ms step_avg:87.07ms +step:360/1680 train_time:31346ms step_avg:87.07ms +step:361/1680 train_time:31433ms step_avg:87.07ms +step:362/1680 train_time:31520ms step_avg:87.07ms +step:363/1680 train_time:31607ms step_avg:87.07ms +step:364/1680 train_time:31694ms step_avg:87.07ms +step:365/1680 train_time:31781ms step_avg:87.07ms +step:366/1680 train_time:31869ms step_avg:87.07ms +step:367/1680 train_time:31955ms step_avg:87.07ms +step:368/1680 train_time:32043ms step_avg:87.07ms +step:369/1680 train_time:32130ms step_avg:87.07ms +step:370/1680 train_time:32217ms step_avg:87.07ms +step:371/1680 train_time:32304ms step_avg:87.07ms +step:372/1680 train_time:32391ms step_avg:87.07ms +step:373/1680 train_time:32478ms step_avg:87.07ms +step:374/1680 train_time:32565ms step_avg:87.07ms +step:375/1680 train_time:32652ms step_avg:87.07ms +step:375/1680 val_loss:3.8196 train_time:32740ms step_avg:87.31ms +step:376/1680 train_time:32760ms step_avg:87.13ms +step:377/1680 train_time:32831ms step_avg:87.08ms +step:378/1680 train_time:32919ms step_avg:87.09ms +step:379/1680 train_time:33007ms step_avg:87.09ms +step:380/1680 train_time:33093ms step_avg:87.09ms +step:381/1680 train_time:33179ms step_avg:87.09ms +step:382/1680 train_time:33266ms step_avg:87.08ms +step:383/1680 train_time:33351ms step_avg:87.08ms +step:384/1680 train_time:33437ms step_avg:87.08ms +step:385/1680 train_time:33524ms step_avg:87.07ms +step:386/1680 train_time:33609ms step_avg:87.07ms +step:387/1680 train_time:33698ms step_avg:87.07ms +step:388/1680 train_time:33786ms step_avg:87.08ms +step:389/1680 train_time:33875ms step_avg:87.08ms +step:390/1680 train_time:33963ms step_avg:87.08ms +step:391/1680 train_time:34049ms step_avg:87.08ms +step:392/1680 train_time:34136ms step_avg:87.08ms +step:393/1680 train_time:34222ms step_avg:87.08ms 
+step:394/1680 train_time:34310ms step_avg:87.08ms +step:395/1680 train_time:34396ms step_avg:87.08ms +step:396/1680 train_time:34482ms step_avg:87.08ms +step:397/1680 train_time:34570ms step_avg:87.08ms +step:398/1680 train_time:34656ms step_avg:87.08ms +step:399/1680 train_time:34745ms step_avg:87.08ms +step:400/1680 train_time:34833ms step_avg:87.08ms +step:401/1680 train_time:34920ms step_avg:87.08ms +step:402/1680 train_time:35007ms step_avg:87.08ms +step:403/1680 train_time:35094ms step_avg:87.08ms +step:404/1680 train_time:35181ms step_avg:87.08ms +step:405/1680 train_time:35268ms step_avg:87.08ms +step:406/1680 train_time:35355ms step_avg:87.08ms +step:407/1680 train_time:35441ms step_avg:87.08ms +step:408/1680 train_time:35528ms step_avg:87.08ms +step:409/1680 train_time:35614ms step_avg:87.08ms +step:410/1680 train_time:35702ms step_avg:87.08ms +step:411/1680 train_time:35789ms step_avg:87.08ms +step:412/1680 train_time:35877ms step_avg:87.08ms +step:413/1680 train_time:35964ms step_avg:87.08ms +step:414/1680 train_time:36051ms step_avg:87.08ms +step:415/1680 train_time:36138ms step_avg:87.08ms +step:416/1680 train_time:36225ms step_avg:87.08ms +step:417/1680 train_time:36312ms step_avg:87.08ms +step:418/1680 train_time:36398ms step_avg:87.08ms +step:419/1680 train_time:36485ms step_avg:87.08ms +step:420/1680 train_time:36571ms step_avg:87.07ms +step:421/1680 train_time:36658ms step_avg:87.07ms +step:422/1680 train_time:36746ms step_avg:87.07ms +step:423/1680 train_time:36834ms step_avg:87.08ms +step:424/1680 train_time:36921ms step_avg:87.08ms +step:425/1680 train_time:37009ms step_avg:87.08ms +step:426/1680 train_time:37096ms step_avg:87.08ms +step:427/1680 train_time:37182ms step_avg:87.08ms +step:428/1680 train_time:37269ms step_avg:87.08ms +step:429/1680 train_time:37356ms step_avg:87.08ms +step:430/1680 train_time:37443ms step_avg:87.08ms +step:431/1680 train_time:37531ms step_avg:87.08ms +step:432/1680 train_time:37618ms step_avg:87.08ms +step:433/1680 train_time:37704ms step_avg:87.08ms +step:434/1680 train_time:37791ms step_avg:87.08ms +step:435/1680 train_time:37878ms step_avg:87.08ms +step:436/1680 train_time:37966ms step_avg:87.08ms +step:437/1680 train_time:38054ms step_avg:87.08ms +step:438/1680 train_time:38141ms step_avg:87.08ms +step:439/1680 train_time:38228ms step_avg:87.08ms +step:440/1680 train_time:38314ms step_avg:87.08ms +step:441/1680 train_time:38401ms step_avg:87.08ms +step:442/1680 train_time:38488ms step_avg:87.08ms +step:443/1680 train_time:38575ms step_avg:87.08ms +step:444/1680 train_time:38662ms step_avg:87.08ms +step:445/1680 train_time:38749ms step_avg:87.08ms +step:446/1680 train_time:38835ms step_avg:87.07ms +step:447/1680 train_time:38923ms step_avg:87.08ms +step:448/1680 train_time:39011ms step_avg:87.08ms +step:449/1680 train_time:39098ms step_avg:87.08ms +step:450/1680 train_time:39185ms step_avg:87.08ms +step:451/1680 train_time:39271ms step_avg:87.08ms +step:452/1680 train_time:39358ms step_avg:87.08ms +step:453/1680 train_time:39446ms step_avg:87.08ms +step:454/1680 train_time:39533ms step_avg:87.08ms +step:455/1680 train_time:39620ms step_avg:87.08ms +step:456/1680 train_time:39707ms step_avg:87.08ms +step:457/1680 train_time:39794ms step_avg:87.08ms +step:458/1680 train_time:39881ms step_avg:87.08ms +step:459/1680 train_time:39969ms step_avg:87.08ms +step:460/1680 train_time:40056ms step_avg:87.08ms +step:461/1680 train_time:40144ms step_avg:87.08ms +step:462/1680 train_time:40231ms step_avg:87.08ms +step:463/1680 train_time:40318ms 
step_avg:87.08ms +step:464/1680 train_time:40405ms step_avg:87.08ms +step:465/1680 train_time:40492ms step_avg:87.08ms +step:466/1680 train_time:40579ms step_avg:87.08ms +step:467/1680 train_time:40666ms step_avg:87.08ms +step:468/1680 train_time:40753ms step_avg:87.08ms +step:469/1680 train_time:40840ms step_avg:87.08ms +step:470/1680 train_time:40928ms step_avg:87.08ms +step:471/1680 train_time:41014ms step_avg:87.08ms +step:472/1680 train_time:41101ms step_avg:87.08ms +step:473/1680 train_time:41188ms step_avg:87.08ms +step:474/1680 train_time:41275ms step_avg:87.08ms +step:475/1680 train_time:41362ms step_avg:87.08ms +step:476/1680 train_time:41450ms step_avg:87.08ms +step:477/1680 train_time:41537ms step_avg:87.08ms +step:478/1680 train_time:41623ms step_avg:87.08ms +step:479/1680 train_time:41710ms step_avg:87.08ms +step:480/1680 train_time:41797ms step_avg:87.08ms +step:481/1680 train_time:41884ms step_avg:87.08ms +step:482/1680 train_time:41971ms step_avg:87.08ms +step:483/1680 train_time:42058ms step_avg:87.08ms +step:484/1680 train_time:42146ms step_avg:87.08ms +step:485/1680 train_time:42233ms step_avg:87.08ms +step:486/1680 train_time:42320ms step_avg:87.08ms +step:487/1680 train_time:42407ms step_avg:87.08ms +step:488/1680 train_time:42494ms step_avg:87.08ms +step:489/1680 train_time:42580ms step_avg:87.08ms +step:490/1680 train_time:42668ms step_avg:87.08ms +step:491/1680 train_time:42754ms step_avg:87.08ms +step:492/1680 train_time:42841ms step_avg:87.08ms +step:493/1680 train_time:42929ms step_avg:87.08ms +step:494/1680 train_time:43015ms step_avg:87.07ms +step:495/1680 train_time:43102ms step_avg:87.07ms +step:496/1680 train_time:43189ms step_avg:87.08ms +step:497/1680 train_time:43276ms step_avg:87.07ms +step:498/1680 train_time:43363ms step_avg:87.07ms +step:499/1680 train_time:43450ms step_avg:87.08ms +step:500/1680 train_time:43537ms step_avg:87.07ms +step:500/1680 val_loss:3.7184 train_time:43626ms step_avg:87.25ms +step:501/1680 train_time:43645ms step_avg:87.12ms +step:502/1680 train_time:43715ms step_avg:87.08ms +step:503/1680 train_time:43805ms step_avg:87.09ms +step:504/1680 train_time:43892ms step_avg:87.09ms +step:505/1680 train_time:43979ms step_avg:87.09ms +step:506/1680 train_time:44065ms step_avg:87.08ms +step:507/1680 train_time:44151ms step_avg:87.08ms +step:508/1680 train_time:44237ms step_avg:87.08ms +step:509/1680 train_time:44323ms step_avg:87.08ms +step:510/1680 train_time:44410ms step_avg:87.08ms +step:511/1680 train_time:44497ms step_avg:87.08ms +step:512/1680 train_time:44585ms step_avg:87.08ms +step:513/1680 train_time:44674ms step_avg:87.08ms +step:514/1680 train_time:44762ms step_avg:87.09ms +step:515/1680 train_time:44850ms step_avg:87.09ms +step:516/1680 train_time:44937ms step_avg:87.09ms +step:517/1680 train_time:45023ms step_avg:87.09ms +step:518/1680 train_time:45109ms step_avg:87.08ms +step:519/1680 train_time:45195ms step_avg:87.08ms +step:520/1680 train_time:45282ms step_avg:87.08ms +step:521/1680 train_time:45369ms step_avg:87.08ms +step:522/1680 train_time:45455ms step_avg:87.08ms +step:523/1680 train_time:45542ms step_avg:87.08ms +step:524/1680 train_time:45630ms step_avg:87.08ms +step:525/1680 train_time:45718ms step_avg:87.08ms +step:526/1680 train_time:45806ms step_avg:87.08ms +step:527/1680 train_time:45894ms step_avg:87.09ms +step:528/1680 train_time:45981ms step_avg:87.09ms +step:529/1680 train_time:46068ms step_avg:87.08ms +step:530/1680 train_time:46154ms step_avg:87.08ms +step:531/1680 train_time:46241ms step_avg:87.08ms 
+step:532/1680 train_time:46327ms step_avg:87.08ms +step:533/1680 train_time:46414ms step_avg:87.08ms +step:534/1680 train_time:46501ms step_avg:87.08ms +step:535/1680 train_time:46588ms step_avg:87.08ms +step:536/1680 train_time:46675ms step_avg:87.08ms +step:537/1680 train_time:46763ms step_avg:87.08ms +step:538/1680 train_time:46851ms step_avg:87.08ms +step:539/1680 train_time:46938ms step_avg:87.08ms +step:540/1680 train_time:47025ms step_avg:87.08ms +step:541/1680 train_time:47113ms step_avg:87.09ms +step:542/1680 train_time:47200ms step_avg:87.08ms +step:543/1680 train_time:47287ms step_avg:87.08ms +step:544/1680 train_time:47373ms step_avg:87.08ms +step:545/1680 train_time:47460ms step_avg:87.08ms +step:546/1680 train_time:47547ms step_avg:87.08ms +step:547/1680 train_time:47634ms step_avg:87.08ms +step:548/1680 train_time:47721ms step_avg:87.08ms +step:549/1680 train_time:47810ms step_avg:87.08ms +step:550/1680 train_time:47898ms step_avg:87.09ms +step:551/1680 train_time:47987ms step_avg:87.09ms +step:552/1680 train_time:48074ms step_avg:87.09ms +step:553/1680 train_time:48162ms step_avg:87.09ms +step:554/1680 train_time:48250ms step_avg:87.09ms +step:555/1680 train_time:48338ms step_avg:87.10ms +step:556/1680 train_time:48426ms step_avg:87.10ms +step:557/1680 train_time:48514ms step_avg:87.10ms +step:558/1680 train_time:48603ms step_avg:87.10ms +step:559/1680 train_time:48691ms step_avg:87.10ms +step:560/1680 train_time:48779ms step_avg:87.11ms +step:561/1680 train_time:48868ms step_avg:87.11ms +step:562/1680 train_time:48956ms step_avg:87.11ms +step:563/1680 train_time:49044ms step_avg:87.11ms +step:564/1680 train_time:49132ms step_avg:87.11ms +step:565/1680 train_time:49220ms step_avg:87.11ms +step:566/1680 train_time:49308ms step_avg:87.12ms +step:567/1680 train_time:49396ms step_avg:87.12ms +step:568/1680 train_time:49484ms step_avg:87.12ms +step:569/1680 train_time:49572ms step_avg:87.12ms +step:570/1680 train_time:49661ms step_avg:87.12ms +step:571/1680 train_time:49748ms step_avg:87.13ms +step:572/1680 train_time:49837ms step_avg:87.13ms +step:573/1680 train_time:49926ms step_avg:87.13ms +step:574/1680 train_time:50014ms step_avg:87.13ms +step:575/1680 train_time:50103ms step_avg:87.14ms +step:576/1680 train_time:50192ms step_avg:87.14ms +step:577/1680 train_time:50280ms step_avg:87.14ms +step:578/1680 train_time:50368ms step_avg:87.14ms +step:579/1680 train_time:50456ms step_avg:87.14ms +step:580/1680 train_time:50544ms step_avg:87.14ms +step:581/1680 train_time:50633ms step_avg:87.15ms +step:582/1680 train_time:50721ms step_avg:87.15ms +step:583/1680 train_time:50809ms step_avg:87.15ms +step:584/1680 train_time:50898ms step_avg:87.15ms +step:585/1680 train_time:50986ms step_avg:87.16ms +step:586/1680 train_time:51074ms step_avg:87.16ms +step:587/1680 train_time:51162ms step_avg:87.16ms +step:588/1680 train_time:51250ms step_avg:87.16ms +step:589/1680 train_time:51338ms step_avg:87.16ms +step:590/1680 train_time:51426ms step_avg:87.16ms +step:591/1680 train_time:51514ms step_avg:87.16ms +step:592/1680 train_time:51601ms step_avg:87.16ms +step:593/1680 train_time:51690ms step_avg:87.17ms +step:594/1680 train_time:51777ms step_avg:87.17ms +step:595/1680 train_time:51867ms step_avg:87.17ms +step:596/1680 train_time:51955ms step_avg:87.17ms +step:597/1680 train_time:52043ms step_avg:87.17ms +step:598/1680 train_time:52132ms step_avg:87.18ms +step:599/1680 train_time:52220ms step_avg:87.18ms +step:600/1680 train_time:52308ms step_avg:87.18ms +step:601/1680 train_time:52396ms 
step_avg:87.18ms +step:602/1680 train_time:52484ms step_avg:87.18ms +step:603/1680 train_time:52571ms step_avg:87.18ms +step:604/1680 train_time:52660ms step_avg:87.18ms +step:605/1680 train_time:52748ms step_avg:87.19ms +step:606/1680 train_time:52836ms step_avg:87.19ms +step:607/1680 train_time:52924ms step_avg:87.19ms +step:608/1680 train_time:53013ms step_avg:87.19ms +step:609/1680 train_time:53101ms step_avg:87.19ms +step:610/1680 train_time:53191ms step_avg:87.20ms +step:611/1680 train_time:53278ms step_avg:87.20ms +step:612/1680 train_time:53366ms step_avg:87.20ms +step:613/1680 train_time:53454ms step_avg:87.20ms +step:614/1680 train_time:53542ms step_avg:87.20ms +step:615/1680 train_time:53630ms step_avg:87.20ms +step:616/1680 train_time:53718ms step_avg:87.21ms +step:617/1680 train_time:53807ms step_avg:87.21ms +step:618/1680 train_time:53895ms step_avg:87.21ms +step:619/1680 train_time:53983ms step_avg:87.21ms +step:620/1680 train_time:54072ms step_avg:87.21ms +step:621/1680 train_time:54160ms step_avg:87.21ms +step:622/1680 train_time:54248ms step_avg:87.22ms +step:623/1680 train_time:54336ms step_avg:87.22ms +step:624/1680 train_time:54424ms step_avg:87.22ms +step:625/1680 train_time:54511ms step_avg:87.22ms +step:625/1680 val_loss:3.6180 train_time:54601ms step_avg:87.36ms +step:626/1680 train_time:54626ms step_avg:87.26ms +step:627/1680 train_time:54692ms step_avg:87.23ms +step:628/1680 train_time:54781ms step_avg:87.23ms +step:629/1680 train_time:54872ms step_avg:87.24ms +step:630/1680 train_time:54962ms step_avg:87.24ms +step:631/1680 train_time:55049ms step_avg:87.24ms +step:632/1680 train_time:55135ms step_avg:87.24ms +step:633/1680 train_time:55223ms step_avg:87.24ms +step:634/1680 train_time:55310ms step_avg:87.24ms +step:635/1680 train_time:55397ms step_avg:87.24ms +step:636/1680 train_time:55484ms step_avg:87.24ms +step:637/1680 train_time:55575ms step_avg:87.25ms +step:638/1680 train_time:55665ms step_avg:87.25ms +step:639/1680 train_time:55753ms step_avg:87.25ms +step:640/1680 train_time:55842ms step_avg:87.25ms +step:641/1680 train_time:55933ms step_avg:87.26ms +step:642/1680 train_time:56021ms step_avg:87.26ms +step:643/1680 train_time:56109ms step_avg:87.26ms +step:644/1680 train_time:56196ms step_avg:87.26ms +step:645/1680 train_time:56284ms step_avg:87.26ms +step:646/1680 train_time:56371ms step_avg:87.26ms +step:647/1680 train_time:56459ms step_avg:87.26ms +step:648/1680 train_time:56547ms step_avg:87.26ms +step:649/1680 train_time:56636ms step_avg:87.27ms +step:650/1680 train_time:56724ms step_avg:87.27ms +step:651/1680 train_time:56812ms step_avg:87.27ms +step:652/1680 train_time:56902ms step_avg:87.27ms +step:653/1680 train_time:56991ms step_avg:87.28ms +step:654/1680 train_time:57079ms step_avg:87.28ms +step:655/1680 train_time:57167ms step_avg:87.28ms +step:656/1680 train_time:57254ms step_avg:87.28ms +step:657/1680 train_time:57341ms step_avg:87.28ms +step:658/1680 train_time:57429ms step_avg:87.28ms +step:659/1680 train_time:57517ms step_avg:87.28ms +step:660/1680 train_time:57606ms step_avg:87.28ms +step:661/1680 train_time:57694ms step_avg:87.28ms +step:662/1680 train_time:57783ms step_avg:87.29ms +step:663/1680 train_time:57872ms step_avg:87.29ms +step:664/1680 train_time:57961ms step_avg:87.29ms +step:665/1680 train_time:58049ms step_avg:87.29ms +step:666/1680 train_time:58136ms step_avg:87.29ms +step:667/1680 train_time:58224ms step_avg:87.29ms +step:668/1680 train_time:58311ms step_avg:87.29ms +step:669/1680 train_time:58399ms step_avg:87.29ms 
+step:670/1680 train_time:58488ms step_avg:87.30ms +step:671/1680 train_time:58575ms step_avg:87.30ms +step:672/1680 train_time:58663ms step_avg:87.30ms +step:673/1680 train_time:58751ms step_avg:87.30ms +step:674/1680 train_time:58839ms step_avg:87.30ms +step:675/1680 train_time:58928ms step_avg:87.30ms +step:676/1680 train_time:59016ms step_avg:87.30ms +step:677/1680 train_time:59105ms step_avg:87.30ms +step:678/1680 train_time:59193ms step_avg:87.31ms +step:679/1680 train_time:59281ms step_avg:87.31ms +step:680/1680 train_time:59369ms step_avg:87.31ms +step:681/1680 train_time:59456ms step_avg:87.31ms +step:682/1680 train_time:59545ms step_avg:87.31ms +step:683/1680 train_time:59633ms step_avg:87.31ms +step:684/1680 train_time:59721ms step_avg:87.31ms +step:685/1680 train_time:59810ms step_avg:87.31ms +step:686/1680 train_time:59898ms step_avg:87.32ms +step:687/1680 train_time:59987ms step_avg:87.32ms +step:688/1680 train_time:60076ms step_avg:87.32ms +step:689/1680 train_time:60164ms step_avg:87.32ms +step:690/1680 train_time:60252ms step_avg:87.32ms +step:691/1680 train_time:60340ms step_avg:87.32ms +step:692/1680 train_time:60428ms step_avg:87.32ms +step:693/1680 train_time:60516ms step_avg:87.32ms +step:694/1680 train_time:60605ms step_avg:87.33ms +step:695/1680 train_time:60692ms step_avg:87.33ms +step:696/1680 train_time:60781ms step_avg:87.33ms +step:697/1680 train_time:60870ms step_avg:87.33ms +step:698/1680 train_time:60959ms step_avg:87.33ms +step:699/1680 train_time:61048ms step_avg:87.34ms +step:700/1680 train_time:61135ms step_avg:87.34ms +step:701/1680 train_time:61224ms step_avg:87.34ms +step:702/1680 train_time:61311ms step_avg:87.34ms +step:703/1680 train_time:61399ms step_avg:87.34ms +step:704/1680 train_time:61487ms step_avg:87.34ms +step:705/1680 train_time:61574ms step_avg:87.34ms +step:706/1680 train_time:61662ms step_avg:87.34ms +step:707/1680 train_time:61751ms step_avg:87.34ms +step:708/1680 train_time:61839ms step_avg:87.34ms +step:709/1680 train_time:61928ms step_avg:87.35ms +step:710/1680 train_time:62017ms step_avg:87.35ms +step:711/1680 train_time:62105ms step_avg:87.35ms +step:712/1680 train_time:62193ms step_avg:87.35ms +step:713/1680 train_time:62281ms step_avg:87.35ms +step:714/1680 train_time:62369ms step_avg:87.35ms +step:715/1680 train_time:62457ms step_avg:87.35ms +step:716/1680 train_time:62544ms step_avg:87.35ms +step:717/1680 train_time:62632ms step_avg:87.35ms +step:718/1680 train_time:62720ms step_avg:87.35ms +step:719/1680 train_time:62808ms step_avg:87.35ms +step:720/1680 train_time:62897ms step_avg:87.36ms +step:721/1680 train_time:62985ms step_avg:87.36ms +step:722/1680 train_time:63073ms step_avg:87.36ms +step:723/1680 train_time:63162ms step_avg:87.36ms +step:724/1680 train_time:63250ms step_avg:87.36ms +step:725/1680 train_time:63338ms step_avg:87.36ms +step:726/1680 train_time:63426ms step_avg:87.36ms +step:727/1680 train_time:63514ms step_avg:87.36ms +step:728/1680 train_time:63602ms step_avg:87.37ms +step:729/1680 train_time:63691ms step_avg:87.37ms +step:730/1680 train_time:63779ms step_avg:87.37ms +step:731/1680 train_time:63868ms step_avg:87.37ms +step:732/1680 train_time:63956ms step_avg:87.37ms +step:733/1680 train_time:64044ms step_avg:87.37ms +step:734/1680 train_time:64133ms step_avg:87.37ms +step:735/1680 train_time:64220ms step_avg:87.37ms +step:736/1680 train_time:64309ms step_avg:87.38ms +step:737/1680 train_time:64397ms step_avg:87.38ms +step:738/1680 train_time:64485ms step_avg:87.38ms +step:739/1680 train_time:64573ms 
step_avg:87.38ms +step:740/1680 train_time:64661ms step_avg:87.38ms +step:741/1680 train_time:64749ms step_avg:87.38ms +step:742/1680 train_time:64837ms step_avg:87.38ms +step:743/1680 train_time:64926ms step_avg:87.38ms +step:744/1680 train_time:65013ms step_avg:87.38ms +step:745/1680 train_time:65102ms step_avg:87.38ms +step:746/1680 train_time:65190ms step_avg:87.39ms +step:747/1680 train_time:65278ms step_avg:87.39ms +step:748/1680 train_time:65366ms step_avg:87.39ms +step:749/1680 train_time:65454ms step_avg:87.39ms +step:750/1680 train_time:65542ms step_avg:87.39ms +step:750/1680 val_loss:3.5654 train_time:65632ms step_avg:87.51ms +step:751/1680 train_time:65653ms step_avg:87.42ms +step:752/1680 train_time:65722ms step_avg:87.40ms +step:753/1680 train_time:65811ms step_avg:87.40ms +step:754/1680 train_time:65899ms step_avg:87.40ms +step:755/1680 train_time:65986ms step_avg:87.40ms +step:756/1680 train_time:66074ms step_avg:87.40ms +step:757/1680 train_time:66162ms step_avg:87.40ms +step:758/1680 train_time:66249ms step_avg:87.40ms +step:759/1680 train_time:66336ms step_avg:87.40ms +step:760/1680 train_time:66425ms step_avg:87.40ms +step:761/1680 train_time:66513ms step_avg:87.40ms +step:762/1680 train_time:66603ms step_avg:87.41ms +step:763/1680 train_time:66693ms step_avg:87.41ms +step:764/1680 train_time:66783ms step_avg:87.41ms +step:765/1680 train_time:66871ms step_avg:87.41ms +step:766/1680 train_time:66958ms step_avg:87.41ms +step:767/1680 train_time:67047ms step_avg:87.41ms +step:768/1680 train_time:67135ms step_avg:87.42ms +step:769/1680 train_time:67222ms step_avg:87.42ms +step:770/1680 train_time:67310ms step_avg:87.42ms +step:771/1680 train_time:67398ms step_avg:87.42ms +step:772/1680 train_time:67485ms step_avg:87.42ms +step:773/1680 train_time:67574ms step_avg:87.42ms +step:774/1680 train_time:67664ms step_avg:87.42ms +step:775/1680 train_time:67754ms step_avg:87.42ms +step:776/1680 train_time:67842ms step_avg:87.43ms +step:777/1680 train_time:67930ms step_avg:87.43ms +step:778/1680 train_time:68018ms step_avg:87.43ms +step:779/1680 train_time:68106ms step_avg:87.43ms +step:780/1680 train_time:68194ms step_avg:87.43ms +step:781/1680 train_time:68282ms step_avg:87.43ms +step:782/1680 train_time:68370ms step_avg:87.43ms +step:783/1680 train_time:68459ms step_avg:87.43ms +step:784/1680 train_time:68548ms step_avg:87.43ms +step:785/1680 train_time:68637ms step_avg:87.44ms +step:786/1680 train_time:68726ms step_avg:87.44ms +step:787/1680 train_time:68815ms step_avg:87.44ms +step:788/1680 train_time:68903ms step_avg:87.44ms +step:789/1680 train_time:68992ms step_avg:87.44ms +step:790/1680 train_time:69080ms step_avg:87.44ms +step:791/1680 train_time:69168ms step_avg:87.44ms +step:792/1680 train_time:69255ms step_avg:87.44ms +step:793/1680 train_time:69343ms step_avg:87.44ms +step:794/1680 train_time:69432ms step_avg:87.45ms +step:795/1680 train_time:69519ms step_avg:87.45ms +step:796/1680 train_time:69608ms step_avg:87.45ms +step:797/1680 train_time:69696ms step_avg:87.45ms +step:798/1680 train_time:69785ms step_avg:87.45ms +step:799/1680 train_time:69873ms step_avg:87.45ms +step:800/1680 train_time:69961ms step_avg:87.45ms +step:801/1680 train_time:70049ms step_avg:87.45ms +step:802/1680 train_time:70137ms step_avg:87.45ms +step:803/1680 train_time:70225ms step_avg:87.45ms +step:804/1680 train_time:70313ms step_avg:87.45ms +step:805/1680 train_time:70400ms step_avg:87.45ms +step:806/1680 train_time:70488ms step_avg:87.45ms +step:807/1680 train_time:70576ms step_avg:87.46ms 
+step:808/1680 train_time:70665ms step_avg:87.46ms +step:809/1680 train_time:70753ms step_avg:87.46ms +step:810/1680 train_time:70842ms step_avg:87.46ms +step:811/1680 train_time:70930ms step_avg:87.46ms +step:812/1680 train_time:71018ms step_avg:87.46ms +step:813/1680 train_time:71107ms step_avg:87.46ms +step:814/1680 train_time:71195ms step_avg:87.46ms +step:815/1680 train_time:71283ms step_avg:87.46ms +step:816/1680 train_time:71371ms step_avg:87.46ms +step:817/1680 train_time:71459ms step_avg:87.46ms +step:818/1680 train_time:71547ms step_avg:87.47ms +step:819/1680 train_time:71636ms step_avg:87.47ms +step:820/1680 train_time:71724ms step_avg:87.47ms +step:821/1680 train_time:71812ms step_avg:87.47ms +step:822/1680 train_time:71900ms step_avg:87.47ms +step:823/1680 train_time:71988ms step_avg:87.47ms +step:824/1680 train_time:72076ms step_avg:87.47ms +step:825/1680 train_time:72165ms step_avg:87.47ms +step:826/1680 train_time:72253ms step_avg:87.47ms +step:827/1680 train_time:72341ms step_avg:87.47ms +step:828/1680 train_time:72429ms step_avg:87.47ms +step:829/1680 train_time:72516ms step_avg:87.47ms +step:830/1680 train_time:72605ms step_avg:87.48ms +step:831/1680 train_time:72693ms step_avg:87.48ms +step:832/1680 train_time:72781ms step_avg:87.48ms +step:833/1680 train_time:72869ms step_avg:87.48ms +step:834/1680 train_time:72957ms step_avg:87.48ms +step:835/1680 train_time:73045ms step_avg:87.48ms +step:836/1680 train_time:73134ms step_avg:87.48ms +step:837/1680 train_time:73222ms step_avg:87.48ms +step:838/1680 train_time:73310ms step_avg:87.48ms +step:839/1680 train_time:73398ms step_avg:87.48ms +step:840/1680 train_time:73486ms step_avg:87.48ms +step:841/1680 train_time:73575ms step_avg:87.48ms +step:842/1680 train_time:73663ms step_avg:87.49ms +step:843/1680 train_time:73753ms step_avg:87.49ms +step:844/1680 train_time:73842ms step_avg:87.49ms +step:845/1680 train_time:73930ms step_avg:87.49ms +step:846/1680 train_time:74018ms step_avg:87.49ms +step:847/1680 train_time:74105ms step_avg:87.49ms +step:848/1680 train_time:74195ms step_avg:87.49ms +step:849/1680 train_time:74284ms step_avg:87.50ms +step:850/1680 train_time:74371ms step_avg:87.50ms +step:851/1680 train_time:74459ms step_avg:87.50ms +step:852/1680 train_time:74547ms step_avg:87.50ms +step:853/1680 train_time:74636ms step_avg:87.50ms +step:854/1680 train_time:74725ms step_avg:87.50ms +step:855/1680 train_time:74813ms step_avg:87.50ms +step:856/1680 train_time:74902ms step_avg:87.50ms +step:857/1680 train_time:74990ms step_avg:87.50ms +step:858/1680 train_time:75078ms step_avg:87.50ms +step:859/1680 train_time:75166ms step_avg:87.50ms +step:860/1680 train_time:75255ms step_avg:87.51ms +step:861/1680 train_time:75344ms step_avg:87.51ms +step:862/1680 train_time:75432ms step_avg:87.51ms +step:863/1680 train_time:75520ms step_avg:87.51ms +step:864/1680 train_time:75608ms step_avg:87.51ms +step:865/1680 train_time:75696ms step_avg:87.51ms +step:866/1680 train_time:75785ms step_avg:87.51ms +step:867/1680 train_time:75873ms step_avg:87.51ms +step:868/1680 train_time:75961ms step_avg:87.51ms +step:869/1680 train_time:76049ms step_avg:87.51ms +step:870/1680 train_time:76137ms step_avg:87.51ms +step:871/1680 train_time:76225ms step_avg:87.51ms +step:872/1680 train_time:76313ms step_avg:87.52ms +step:873/1680 train_time:76401ms step_avg:87.52ms +step:874/1680 train_time:76490ms step_avg:87.52ms +step:875/1680 train_time:76577ms step_avg:87.52ms +step:875/1680 val_loss:3.5189 train_time:76666ms step_avg:87.62ms +step:876/1680 
train_time:76687ms step_avg:87.54ms +step:877/1680 train_time:76759ms step_avg:87.52ms +step:878/1680 train_time:76848ms step_avg:87.53ms +step:879/1680 train_time:76936ms step_avg:87.53ms +step:880/1680 train_time:77023ms step_avg:87.53ms +step:881/1680 train_time:77110ms step_avg:87.53ms +step:882/1680 train_time:77197ms step_avg:87.53ms +step:883/1680 train_time:77285ms step_avg:87.53ms +step:884/1680 train_time:77372ms step_avg:87.53ms +step:885/1680 train_time:77460ms step_avg:87.53ms +step:886/1680 train_time:77548ms step_avg:87.53ms +step:887/1680 train_time:77639ms step_avg:87.53ms +step:888/1680 train_time:77729ms step_avg:87.53ms +step:889/1680 train_time:77818ms step_avg:87.53ms +step:890/1680 train_time:77907ms step_avg:87.54ms +step:891/1680 train_time:77995ms step_avg:87.54ms +step:892/1680 train_time:78083ms step_avg:87.54ms +step:893/1680 train_time:78170ms step_avg:87.54ms +step:894/1680 train_time:78258ms step_avg:87.54ms +step:895/1680 train_time:78344ms step_avg:87.54ms +step:896/1680 train_time:78431ms step_avg:87.53ms +step:897/1680 train_time:78519ms step_avg:87.54ms +step:898/1680 train_time:78608ms step_avg:87.54ms +step:899/1680 train_time:78697ms step_avg:87.54ms +step:900/1680 train_time:78786ms step_avg:87.54ms +step:901/1680 train_time:78876ms step_avg:87.54ms +step:902/1680 train_time:78964ms step_avg:87.54ms +step:903/1680 train_time:79053ms step_avg:87.54ms +step:904/1680 train_time:79140ms step_avg:87.54ms +step:905/1680 train_time:79227ms step_avg:87.54ms +step:906/1680 train_time:79315ms step_avg:87.54ms +step:907/1680 train_time:79402ms step_avg:87.54ms +step:908/1680 train_time:79490ms step_avg:87.54ms +step:909/1680 train_time:79579ms step_avg:87.55ms +step:910/1680 train_time:79668ms step_avg:87.55ms +step:911/1680 train_time:79756ms step_avg:87.55ms +step:912/1680 train_time:79845ms step_avg:87.55ms +step:913/1680 train_time:79934ms step_avg:87.55ms +step:914/1680 train_time:80022ms step_avg:87.55ms +step:915/1680 train_time:80109ms step_avg:87.55ms +step:916/1680 train_time:80197ms step_avg:87.55ms +step:917/1680 train_time:80285ms step_avg:87.55ms +step:918/1680 train_time:80373ms step_avg:87.55ms +step:919/1680 train_time:80461ms step_avg:87.55ms +step:920/1680 train_time:80549ms step_avg:87.55ms +step:921/1680 train_time:80637ms step_avg:87.55ms +step:922/1680 train_time:80725ms step_avg:87.55ms +step:923/1680 train_time:80815ms step_avg:87.56ms +step:924/1680 train_time:80903ms step_avg:87.56ms +step:925/1680 train_time:80991ms step_avg:87.56ms +step:926/1680 train_time:81080ms step_avg:87.56ms +step:927/1680 train_time:81167ms step_avg:87.56ms +step:928/1680 train_time:81256ms step_avg:87.56ms +step:929/1680 train_time:81344ms step_avg:87.56ms +step:930/1680 train_time:81432ms step_avg:87.56ms +step:931/1680 train_time:81519ms step_avg:87.56ms +step:932/1680 train_time:81607ms step_avg:87.56ms +step:933/1680 train_time:81696ms step_avg:87.56ms +step:934/1680 train_time:81784ms step_avg:87.56ms +step:935/1680 train_time:81874ms step_avg:87.57ms +step:936/1680 train_time:81962ms step_avg:87.57ms +step:937/1680 train_time:82049ms step_avg:87.57ms +step:938/1680 train_time:82137ms step_avg:87.57ms +step:939/1680 train_time:82225ms step_avg:87.57ms +step:940/1680 train_time:82314ms step_avg:87.57ms +step:941/1680 train_time:82402ms step_avg:87.57ms +step:942/1680 train_time:82491ms step_avg:87.57ms +step:943/1680 train_time:82579ms step_avg:87.57ms +step:944/1680 train_time:82667ms step_avg:87.57ms +step:945/1680 train_time:82756ms step_avg:87.57ms 
+step:946/1680 train_time:82844ms step_avg:87.57ms +step:947/1680 train_time:82932ms step_avg:87.57ms +step:948/1680 train_time:83020ms step_avg:87.57ms +step:949/1680 train_time:83108ms step_avg:87.57ms +step:950/1680 train_time:83195ms step_avg:87.57ms +step:951/1680 train_time:83283ms step_avg:87.57ms +step:952/1680 train_time:83372ms step_avg:87.58ms +step:953/1680 train_time:83460ms step_avg:87.58ms +step:954/1680 train_time:83548ms step_avg:87.58ms +step:955/1680 train_time:83636ms step_avg:87.58ms +step:956/1680 train_time:83724ms step_avg:87.58ms +step:957/1680 train_time:83813ms step_avg:87.58ms +step:958/1680 train_time:83901ms step_avg:87.58ms +step:959/1680 train_time:83990ms step_avg:87.58ms +step:960/1680 train_time:84078ms step_avg:87.58ms +step:961/1680 train_time:84166ms step_avg:87.58ms +step:962/1680 train_time:84254ms step_avg:87.58ms +step:963/1680 train_time:84341ms step_avg:87.58ms +step:964/1680 train_time:84429ms step_avg:87.58ms +step:965/1680 train_time:84518ms step_avg:87.58ms +step:966/1680 train_time:84606ms step_avg:87.58ms +step:967/1680 train_time:84695ms step_avg:87.59ms +step:968/1680 train_time:84783ms step_avg:87.59ms +step:969/1680 train_time:84872ms step_avg:87.59ms +step:970/1680 train_time:84960ms step_avg:87.59ms +step:971/1680 train_time:85048ms step_avg:87.59ms +step:972/1680 train_time:85137ms step_avg:87.59ms +step:973/1680 train_time:85224ms step_avg:87.59ms +step:974/1680 train_time:85313ms step_avg:87.59ms +step:975/1680 train_time:85401ms step_avg:87.59ms +step:976/1680 train_time:85488ms step_avg:87.59ms +step:977/1680 train_time:85577ms step_avg:87.59ms +step:978/1680 train_time:85666ms step_avg:87.59ms +step:979/1680 train_time:85754ms step_avg:87.59ms +step:980/1680 train_time:85842ms step_avg:87.59ms +step:981/1680 train_time:85931ms step_avg:87.60ms +step:982/1680 train_time:86019ms step_avg:87.60ms +step:983/1680 train_time:86107ms step_avg:87.60ms +step:984/1680 train_time:86196ms step_avg:87.60ms +step:985/1680 train_time:86284ms step_avg:87.60ms +step:986/1680 train_time:86371ms step_avg:87.60ms +step:987/1680 train_time:86459ms step_avg:87.60ms +step:988/1680 train_time:86548ms step_avg:87.60ms +step:989/1680 train_time:86636ms step_avg:87.60ms +step:990/1680 train_time:86726ms step_avg:87.60ms +step:991/1680 train_time:86814ms step_avg:87.60ms +step:992/1680 train_time:86902ms step_avg:87.60ms +step:993/1680 train_time:86991ms step_avg:87.60ms +step:994/1680 train_time:87079ms step_avg:87.60ms +step:995/1680 train_time:87167ms step_avg:87.61ms +step:996/1680 train_time:87256ms step_avg:87.61ms +step:997/1680 train_time:87344ms step_avg:87.61ms +step:998/1680 train_time:87433ms step_avg:87.61ms +step:999/1680 train_time:87520ms step_avg:87.61ms +step:1000/1680 train_time:87608ms step_avg:87.61ms +step:1000/1680 val_loss:3.4696 train_time:87698ms step_avg:87.70ms +step:1001/1680 train_time:87718ms step_avg:87.63ms +step:1002/1680 train_time:87790ms step_avg:87.61ms +step:1003/1680 train_time:87882ms step_avg:87.62ms +step:1004/1680 train_time:87971ms step_avg:87.62ms +step:1005/1680 train_time:88059ms step_avg:87.62ms +step:1006/1680 train_time:88147ms step_avg:87.62ms +step:1007/1680 train_time:88234ms step_avg:87.62ms +step:1008/1680 train_time:88322ms step_avg:87.62ms +step:1009/1680 train_time:88409ms step_avg:87.62ms +step:1010/1680 train_time:88497ms step_avg:87.62ms +step:1011/1680 train_time:88584ms step_avg:87.62ms +step:1012/1680 train_time:88674ms step_avg:87.62ms +step:1013/1680 train_time:88764ms step_avg:87.62ms 
+step:1014/1680 train_time:88853ms step_avg:87.63ms +step:1015/1680 train_time:88943ms step_avg:87.63ms +step:1016/1680 train_time:89031ms step_avg:87.63ms +step:1017/1680 train_time:89120ms step_avg:87.63ms +step:1018/1680 train_time:89207ms step_avg:87.63ms +step:1019/1680 train_time:89294ms step_avg:87.63ms +step:1020/1680 train_time:89382ms step_avg:87.63ms +step:1021/1680 train_time:89469ms step_avg:87.63ms +step:1022/1680 train_time:89557ms step_avg:87.63ms +step:1023/1680 train_time:89646ms step_avg:87.63ms +step:1024/1680 train_time:89735ms step_avg:87.63ms +step:1025/1680 train_time:89825ms step_avg:87.63ms +step:1026/1680 train_time:89913ms step_avg:87.63ms +step:1027/1680 train_time:90001ms step_avg:87.64ms +step:1028/1680 train_time:90090ms step_avg:87.64ms +step:1029/1680 train_time:90178ms step_avg:87.64ms +step:1030/1680 train_time:90266ms step_avg:87.64ms +step:1031/1680 train_time:90353ms step_avg:87.64ms +step:1032/1680 train_time:90441ms step_avg:87.64ms +step:1033/1680 train_time:90528ms step_avg:87.64ms +step:1034/1680 train_time:90616ms step_avg:87.64ms +step:1035/1680 train_time:90704ms step_avg:87.64ms +step:1036/1680 train_time:90792ms step_avg:87.64ms +step:1037/1680 train_time:90881ms step_avg:87.64ms +step:1038/1680 train_time:90970ms step_avg:87.64ms +step:1039/1680 train_time:91058ms step_avg:87.64ms +step:1040/1680 train_time:91147ms step_avg:87.64ms +step:1041/1680 train_time:91234ms step_avg:87.64ms +step:1042/1680 train_time:91322ms step_avg:87.64ms +step:1043/1680 train_time:91410ms step_avg:87.64ms +step:1044/1680 train_time:91497ms step_avg:87.64ms +step:1045/1680 train_time:91585ms step_avg:87.64ms +step:1046/1680 train_time:91673ms step_avg:87.64ms +step:1047/1680 train_time:91761ms step_avg:87.64ms +step:1048/1680 train_time:91850ms step_avg:87.64ms +step:1049/1680 train_time:91938ms step_avg:87.64ms +step:1050/1680 train_time:92027ms step_avg:87.64ms +step:1051/1680 train_time:92116ms step_avg:87.65ms +step:1052/1680 train_time:92204ms step_avg:87.65ms +step:1053/1680 train_time:92292ms step_avg:87.65ms +step:1054/1680 train_time:92380ms step_avg:87.65ms +step:1055/1680 train_time:92468ms step_avg:87.65ms +step:1056/1680 train_time:92556ms step_avg:87.65ms +step:1057/1680 train_time:92645ms step_avg:87.65ms +step:1058/1680 train_time:92733ms step_avg:87.65ms +step:1059/1680 train_time:92821ms step_avg:87.65ms +step:1060/1680 train_time:92909ms step_avg:87.65ms +step:1061/1680 train_time:92997ms step_avg:87.65ms +step:1062/1680 train_time:93085ms step_avg:87.65ms +step:1063/1680 train_time:93174ms step_avg:87.65ms +step:1064/1680 train_time:93262ms step_avg:87.65ms +step:1065/1680 train_time:93350ms step_avg:87.65ms +step:1066/1680 train_time:93439ms step_avg:87.65ms +step:1067/1680 train_time:93527ms step_avg:87.65ms +step:1068/1680 train_time:93616ms step_avg:87.66ms +step:1069/1680 train_time:93705ms step_avg:87.66ms +step:1070/1680 train_time:93792ms step_avg:87.66ms +step:1071/1680 train_time:93881ms step_avg:87.66ms +step:1072/1680 train_time:93969ms step_avg:87.66ms +step:1073/1680 train_time:94057ms step_avg:87.66ms +step:1074/1680 train_time:94145ms step_avg:87.66ms +step:1075/1680 train_time:94233ms step_avg:87.66ms +step:1076/1680 train_time:94322ms step_avg:87.66ms +step:1077/1680 train_time:94410ms step_avg:87.66ms +step:1078/1680 train_time:94498ms step_avg:87.66ms +step:1079/1680 train_time:94586ms step_avg:87.66ms +step:1080/1680 train_time:94674ms step_avg:87.66ms +step:1081/1680 train_time:94762ms step_avg:87.66ms +step:1082/1680 
train_time:94851ms step_avg:87.66ms +step:1083/1680 train_time:94939ms step_avg:87.66ms +step:1084/1680 train_time:95028ms step_avg:87.66ms +step:1085/1680 train_time:95117ms step_avg:87.67ms +step:1086/1680 train_time:95205ms step_avg:87.67ms +step:1087/1680 train_time:95293ms step_avg:87.67ms +step:1088/1680 train_time:95382ms step_avg:87.67ms +step:1089/1680 train_time:95470ms step_avg:87.67ms +step:1090/1680 train_time:95558ms step_avg:87.67ms +step:1091/1680 train_time:95646ms step_avg:87.67ms +step:1092/1680 train_time:95734ms step_avg:87.67ms +step:1093/1680 train_time:95822ms step_avg:87.67ms +step:1094/1680 train_time:95910ms step_avg:87.67ms +step:1095/1680 train_time:95998ms step_avg:87.67ms +step:1096/1680 train_time:96087ms step_avg:87.67ms +step:1097/1680 train_time:96177ms step_avg:87.67ms +step:1098/1680 train_time:96266ms step_avg:87.67ms +step:1099/1680 train_time:96354ms step_avg:87.67ms +step:1100/1680 train_time:96443ms step_avg:87.68ms +step:1101/1680 train_time:96531ms step_avg:87.68ms +step:1102/1680 train_time:96621ms step_avg:87.68ms +step:1103/1680 train_time:96709ms step_avg:87.68ms +step:1104/1680 train_time:96799ms step_avg:87.68ms +step:1105/1680 train_time:96888ms step_avg:87.68ms +step:1106/1680 train_time:96978ms step_avg:87.68ms +step:1107/1680 train_time:97067ms step_avg:87.68ms +step:1108/1680 train_time:97156ms step_avg:87.69ms +step:1109/1680 train_time:97245ms step_avg:87.69ms +step:1110/1680 train_time:97333ms step_avg:87.69ms +step:1111/1680 train_time:97422ms step_avg:87.69ms +step:1112/1680 train_time:97511ms step_avg:87.69ms +step:1113/1680 train_time:97600ms step_avg:87.69ms +step:1114/1680 train_time:97688ms step_avg:87.69ms +step:1115/1680 train_time:97777ms step_avg:87.69ms +step:1116/1680 train_time:97867ms step_avg:87.69ms +step:1117/1680 train_time:97957ms step_avg:87.70ms +step:1118/1680 train_time:98045ms step_avg:87.70ms +step:1119/1680 train_time:98135ms step_avg:87.70ms +step:1120/1680 train_time:98223ms step_avg:87.70ms +step:1121/1680 train_time:98312ms step_avg:87.70ms +step:1122/1680 train_time:98401ms step_avg:87.70ms +step:1123/1680 train_time:98490ms step_avg:87.70ms +step:1124/1680 train_time:98578ms step_avg:87.70ms +step:1125/1680 train_time:98667ms step_avg:87.70ms +step:1125/1680 val_loss:3.4157 train_time:98757ms step_avg:87.78ms +step:1126/1680 train_time:98777ms step_avg:87.72ms +step:1127/1680 train_time:98845ms step_avg:87.71ms +step:1128/1680 train_time:98936ms step_avg:87.71ms +step:1129/1680 train_time:99026ms step_avg:87.71ms +step:1130/1680 train_time:99115ms step_avg:87.71ms +step:1131/1680 train_time:99203ms step_avg:87.71ms +step:1132/1680 train_time:99291ms step_avg:87.71ms +step:1133/1680 train_time:99379ms step_avg:87.71ms +step:1134/1680 train_time:99467ms step_avg:87.71ms +step:1135/1680 train_time:99557ms step_avg:87.71ms +step:1136/1680 train_time:99646ms step_avg:87.72ms +step:1137/1680 train_time:99740ms step_avg:87.72ms +step:1138/1680 train_time:99830ms step_avg:87.72ms +step:1139/1680 train_time:99919ms step_avg:87.73ms +step:1140/1680 train_time:100009ms step_avg:87.73ms +step:1141/1680 train_time:100098ms step_avg:87.73ms +step:1142/1680 train_time:100187ms step_avg:87.73ms +step:1143/1680 train_time:100276ms step_avg:87.73ms +step:1144/1680 train_time:100365ms step_avg:87.73ms +step:1145/1680 train_time:100454ms step_avg:87.73ms +step:1146/1680 train_time:100541ms step_avg:87.73ms +step:1147/1680 train_time:100630ms step_avg:87.73ms +step:1148/1680 train_time:100720ms step_avg:87.74ms 
+step:1149/1680 train_time:100809ms step_avg:87.74ms +step:1150/1680 train_time:100899ms step_avg:87.74ms +step:1151/1680 train_time:100989ms step_avg:87.74ms +step:1152/1680 train_time:101077ms step_avg:87.74ms +step:1153/1680 train_time:101166ms step_avg:87.74ms +step:1154/1680 train_time:101255ms step_avg:87.74ms +step:1155/1680 train_time:101343ms step_avg:87.74ms +step:1156/1680 train_time:101432ms step_avg:87.74ms +step:1157/1680 train_time:101520ms step_avg:87.74ms +step:1158/1680 train_time:101609ms step_avg:87.75ms +step:1159/1680 train_time:101699ms step_avg:87.75ms +step:1160/1680 train_time:101789ms step_avg:87.75ms +step:1161/1680 train_time:101878ms step_avg:87.75ms +step:1162/1680 train_time:101967ms step_avg:87.75ms +step:1163/1680 train_time:102057ms step_avg:87.75ms +step:1164/1680 train_time:102146ms step_avg:87.75ms +step:1165/1680 train_time:102235ms step_avg:87.76ms +step:1166/1680 train_time:102323ms step_avg:87.76ms +step:1167/1680 train_time:102412ms step_avg:87.76ms +step:1168/1680 train_time:102500ms step_avg:87.76ms +step:1169/1680 train_time:102589ms step_avg:87.76ms +step:1170/1680 train_time:102678ms step_avg:87.76ms +step:1171/1680 train_time:102767ms step_avg:87.76ms +step:1172/1680 train_time:102856ms step_avg:87.76ms +step:1173/1680 train_time:102944ms step_avg:87.76ms +step:1174/1680 train_time:103033ms step_avg:87.76ms +step:1175/1680 train_time:103122ms step_avg:87.76ms +step:1176/1680 train_time:103211ms step_avg:87.76ms +step:1177/1680 train_time:103300ms step_avg:87.77ms +step:1178/1680 train_time:103389ms step_avg:87.77ms +step:1179/1680 train_time:103477ms step_avg:87.77ms +step:1180/1680 train_time:103566ms step_avg:87.77ms +step:1181/1680 train_time:103655ms step_avg:87.77ms +step:1182/1680 train_time:103744ms step_avg:87.77ms +step:1183/1680 train_time:103834ms step_avg:87.77ms +step:1184/1680 train_time:103922ms step_avg:87.77ms +step:1185/1680 train_time:104011ms step_avg:87.77ms +step:1186/1680 train_time:104101ms step_avg:87.77ms +step:1187/1680 train_time:104189ms step_avg:87.77ms +step:1188/1680 train_time:104278ms step_avg:87.78ms +step:1189/1680 train_time:104367ms step_avg:87.78ms +step:1190/1680 train_time:104456ms step_avg:87.78ms +step:1191/1680 train_time:104545ms step_avg:87.78ms +step:1192/1680 train_time:104636ms step_avg:87.78ms +step:1193/1680 train_time:104725ms step_avg:87.78ms +step:1194/1680 train_time:104814ms step_avg:87.78ms +step:1195/1680 train_time:104902ms step_avg:87.78ms +step:1196/1680 train_time:104992ms step_avg:87.79ms +step:1197/1680 train_time:105080ms step_avg:87.79ms +step:1198/1680 train_time:105169ms step_avg:87.79ms +step:1199/1680 train_time:105258ms step_avg:87.79ms +step:1200/1680 train_time:105346ms step_avg:87.79ms +step:1201/1680 train_time:105434ms step_avg:87.79ms +step:1202/1680 train_time:105524ms step_avg:87.79ms +step:1203/1680 train_time:105612ms step_avg:87.79ms +step:1204/1680 train_time:105701ms step_avg:87.79ms +step:1205/1680 train_time:105790ms step_avg:87.79ms +step:1206/1680 train_time:105879ms step_avg:87.79ms +step:1207/1680 train_time:105968ms step_avg:87.79ms +step:1208/1680 train_time:106058ms step_avg:87.80ms +step:1209/1680 train_time:106147ms step_avg:87.80ms +step:1210/1680 train_time:106235ms step_avg:87.80ms +step:1211/1680 train_time:106324ms step_avg:87.80ms +step:1212/1680 train_time:106413ms step_avg:87.80ms +step:1213/1680 train_time:106502ms step_avg:87.80ms +step:1214/1680 train_time:106591ms step_avg:87.80ms +step:1215/1680 train_time:106680ms step_avg:87.80ms 
+step:1216/1680 train_time:106770ms step_avg:87.80ms +step:1217/1680 train_time:106858ms step_avg:87.80ms +step:1218/1680 train_time:106947ms step_avg:87.81ms +step:1219/1680 train_time:107037ms step_avg:87.81ms +step:1220/1680 train_time:107126ms step_avg:87.81ms +step:1221/1680 train_time:107215ms step_avg:87.81ms +step:1222/1680 train_time:107304ms step_avg:87.81ms +step:1223/1680 train_time:107392ms step_avg:87.81ms +step:1224/1680 train_time:107481ms step_avg:87.81ms +step:1225/1680 train_time:107570ms step_avg:87.81ms +step:1226/1680 train_time:107659ms step_avg:87.81ms +step:1227/1680 train_time:107748ms step_avg:87.81ms +step:1228/1680 train_time:107837ms step_avg:87.82ms +step:1229/1680 train_time:107926ms step_avg:87.82ms +step:1230/1680 train_time:108015ms step_avg:87.82ms +step:1231/1680 train_time:108104ms step_avg:87.82ms +step:1232/1680 train_time:108194ms step_avg:87.82ms +step:1233/1680 train_time:108282ms step_avg:87.82ms +step:1234/1680 train_time:108372ms step_avg:87.82ms +step:1235/1680 train_time:108460ms step_avg:87.82ms +step:1236/1680 train_time:108550ms step_avg:87.82ms +step:1237/1680 train_time:108639ms step_avg:87.82ms +step:1238/1680 train_time:108727ms step_avg:87.82ms +step:1239/1680 train_time:108816ms step_avg:87.83ms +step:1240/1680 train_time:108905ms step_avg:87.83ms +step:1241/1680 train_time:108994ms step_avg:87.83ms +step:1242/1680 train_time:109083ms step_avg:87.83ms +step:1243/1680 train_time:109172ms step_avg:87.83ms +step:1244/1680 train_time:109261ms step_avg:87.83ms +step:1245/1680 train_time:109350ms step_avg:87.83ms +step:1246/1680 train_time:109439ms step_avg:87.83ms +step:1247/1680 train_time:109528ms step_avg:87.83ms +step:1248/1680 train_time:109617ms step_avg:87.83ms +step:1249/1680 train_time:109705ms step_avg:87.83ms +step:1250/1680 train_time:109794ms step_avg:87.84ms +step:1250/1680 val_loss:3.3777 train_time:109884ms step_avg:87.91ms +step:1251/1680 train_time:109904ms step_avg:87.85ms +step:1252/1680 train_time:109977ms step_avg:87.84ms +step:1253/1680 train_time:110069ms step_avg:87.84ms +step:1254/1680 train_time:110159ms step_avg:87.85ms +step:1255/1680 train_time:110248ms step_avg:87.85ms +step:1256/1680 train_time:110337ms step_avg:87.85ms +step:1257/1680 train_time:110425ms step_avg:87.85ms +step:1258/1680 train_time:110513ms step_avg:87.85ms +step:1259/1680 train_time:110602ms step_avg:87.85ms +step:1260/1680 train_time:110689ms step_avg:87.85ms +step:1261/1680 train_time:110777ms step_avg:87.85ms +step:1262/1680 train_time:110866ms step_avg:87.85ms +step:1263/1680 train_time:110957ms step_avg:87.85ms +step:1264/1680 train_time:111048ms step_avg:87.85ms +step:1265/1680 train_time:111138ms step_avg:87.86ms +step:1266/1680 train_time:111227ms step_avg:87.86ms +step:1267/1680 train_time:111316ms step_avg:87.86ms +step:1268/1680 train_time:111405ms step_avg:87.86ms +step:1269/1680 train_time:111493ms step_avg:87.86ms +step:1270/1680 train_time:111581ms step_avg:87.86ms +step:1271/1680 train_time:111669ms step_avg:87.86ms +step:1272/1680 train_time:111757ms step_avg:87.86ms +step:1273/1680 train_time:111847ms step_avg:87.86ms +step:1274/1680 train_time:111938ms step_avg:87.86ms +step:1275/1680 train_time:112029ms step_avg:87.87ms +step:1276/1680 train_time:112119ms step_avg:87.87ms +step:1277/1680 train_time:112208ms step_avg:87.87ms +step:1278/1680 train_time:112298ms step_avg:87.87ms +step:1279/1680 train_time:112386ms step_avg:87.87ms +step:1280/1680 train_time:112474ms step_avg:87.87ms +step:1281/1680 train_time:112562ms 
step_avg:87.87ms +step:1282/1680 train_time:112651ms step_avg:87.87ms +step:1283/1680 train_time:112740ms step_avg:87.87ms +step:1284/1680 train_time:112829ms step_avg:87.87ms +step:1285/1680 train_time:112918ms step_avg:87.87ms +step:1286/1680 train_time:113009ms step_avg:87.88ms +step:1287/1680 train_time:113098ms step_avg:87.88ms +step:1288/1680 train_time:113187ms step_avg:87.88ms +step:1289/1680 train_time:113276ms step_avg:87.88ms +step:1290/1680 train_time:113365ms step_avg:87.88ms +step:1291/1680 train_time:113454ms step_avg:87.88ms +step:1292/1680 train_time:113542ms step_avg:87.88ms +step:1293/1680 train_time:113630ms step_avg:87.88ms +step:1294/1680 train_time:113719ms step_avg:87.88ms +step:1295/1680 train_time:113809ms step_avg:87.88ms +step:1296/1680 train_time:113898ms step_avg:87.88ms +step:1297/1680 train_time:113987ms step_avg:87.88ms +step:1298/1680 train_time:114076ms step_avg:87.89ms +step:1299/1680 train_time:114165ms step_avg:87.89ms +step:1300/1680 train_time:114254ms step_avg:87.89ms +step:1301/1680 train_time:114343ms step_avg:87.89ms +step:1302/1680 train_time:114432ms step_avg:87.89ms +step:1303/1680 train_time:114522ms step_avg:87.89ms +step:1304/1680 train_time:114611ms step_avg:87.89ms +step:1305/1680 train_time:114700ms step_avg:87.89ms +step:1306/1680 train_time:114789ms step_avg:87.89ms +step:1307/1680 train_time:114877ms step_avg:87.89ms +step:1308/1680 train_time:114966ms step_avg:87.89ms +step:1309/1680 train_time:115056ms step_avg:87.90ms +step:1310/1680 train_time:115145ms step_avg:87.90ms +step:1311/1680 train_time:115234ms step_avg:87.90ms +step:1312/1680 train_time:115324ms step_avg:87.90ms +step:1313/1680 train_time:115414ms step_avg:87.90ms +step:1314/1680 train_time:115503ms step_avg:87.90ms +step:1315/1680 train_time:115591ms step_avg:87.90ms +step:1316/1680 train_time:115680ms step_avg:87.90ms +step:1317/1680 train_time:115769ms step_avg:87.90ms +step:1318/1680 train_time:115858ms step_avg:87.90ms +step:1319/1680 train_time:115947ms step_avg:87.90ms +step:1320/1680 train_time:116036ms step_avg:87.91ms +step:1321/1680 train_time:116125ms step_avg:87.91ms +step:1322/1680 train_time:116214ms step_avg:87.91ms +step:1323/1680 train_time:116304ms step_avg:87.91ms +step:1324/1680 train_time:116393ms step_avg:87.91ms +step:1325/1680 train_time:116482ms step_avg:87.91ms +step:1326/1680 train_time:116571ms step_avg:87.91ms +step:1327/1680 train_time:116659ms step_avg:87.91ms +step:1328/1680 train_time:116749ms step_avg:87.91ms +step:1329/1680 train_time:116839ms step_avg:87.91ms +step:1330/1680 train_time:116927ms step_avg:87.91ms +step:1331/1680 train_time:117017ms step_avg:87.92ms +step:1332/1680 train_time:117106ms step_avg:87.92ms +step:1333/1680 train_time:117195ms step_avg:87.92ms +step:1334/1680 train_time:117284ms step_avg:87.92ms +step:1335/1680 train_time:117372ms step_avg:87.92ms +step:1336/1680 train_time:117462ms step_avg:87.92ms +step:1337/1680 train_time:117551ms step_avg:87.92ms +step:1338/1680 train_time:117640ms step_avg:87.92ms +step:1339/1680 train_time:117729ms step_avg:87.92ms +step:1340/1680 train_time:117819ms step_avg:87.92ms +step:1341/1680 train_time:117908ms step_avg:87.93ms +step:1342/1680 train_time:117996ms step_avg:87.93ms +step:1343/1680 train_time:118085ms step_avg:87.93ms +step:1344/1680 train_time:118174ms step_avg:87.93ms +step:1345/1680 train_time:118262ms step_avg:87.93ms +step:1346/1680 train_time:118351ms step_avg:87.93ms +step:1347/1680 train_time:118440ms step_avg:87.93ms +step:1348/1680 train_time:118529ms 
step_avg:87.93ms +step:1349/1680 train_time:118619ms step_avg:87.93ms +step:1350/1680 train_time:118708ms step_avg:87.93ms +step:1351/1680 train_time:118798ms step_avg:87.93ms +step:1352/1680 train_time:118887ms step_avg:87.93ms +step:1353/1680 train_time:118977ms step_avg:87.94ms +step:1354/1680 train_time:119066ms step_avg:87.94ms +step:1355/1680 train_time:119154ms step_avg:87.94ms +step:1356/1680 train_time:119243ms step_avg:87.94ms +step:1357/1680 train_time:119331ms step_avg:87.94ms +step:1358/1680 train_time:119419ms step_avg:87.94ms +step:1359/1680 train_time:119509ms step_avg:87.94ms +step:1360/1680 train_time:119598ms step_avg:87.94ms +step:1361/1680 train_time:119686ms step_avg:87.94ms +step:1362/1680 train_time:119775ms step_avg:87.94ms +step:1363/1680 train_time:119864ms step_avg:87.94ms +step:1364/1680 train_time:119953ms step_avg:87.94ms +step:1365/1680 train_time:120042ms step_avg:87.94ms +step:1366/1680 train_time:120131ms step_avg:87.94ms +step:1367/1680 train_time:120220ms step_avg:87.94ms +step:1368/1680 train_time:120310ms step_avg:87.95ms +step:1369/1680 train_time:120399ms step_avg:87.95ms +step:1370/1680 train_time:120487ms step_avg:87.95ms +step:1371/1680 train_time:120576ms step_avg:87.95ms +step:1372/1680 train_time:120665ms step_avg:87.95ms +step:1373/1680 train_time:120754ms step_avg:87.95ms +step:1374/1680 train_time:120843ms step_avg:87.95ms +step:1375/1680 train_time:120932ms step_avg:87.95ms +step:1375/1680 val_loss:3.3423 train_time:121023ms step_avg:88.02ms +step:1376/1680 train_time:121041ms step_avg:87.97ms +step:1377/1680 train_time:121113ms step_avg:87.95ms +step:1378/1680 train_time:121206ms step_avg:87.96ms +step:1379/1680 train_time:121296ms step_avg:87.96ms +step:1380/1680 train_time:121384ms step_avg:87.96ms +step:1381/1680 train_time:121472ms step_avg:87.96ms +step:1382/1680 train_time:121560ms step_avg:87.96ms +step:1383/1680 train_time:121648ms step_avg:87.96ms +step:1384/1680 train_time:121736ms step_avg:87.96ms +step:1385/1680 train_time:121825ms step_avg:87.96ms +step:1386/1680 train_time:121913ms step_avg:87.96ms +step:1387/1680 train_time:122002ms step_avg:87.96ms +step:1388/1680 train_time:122093ms step_avg:87.96ms +step:1389/1680 train_time:122184ms step_avg:87.97ms +step:1390/1680 train_time:122273ms step_avg:87.97ms +step:1391/1680 train_time:122362ms step_avg:87.97ms +step:1392/1680 train_time:122451ms step_avg:87.97ms +step:1393/1680 train_time:122540ms step_avg:87.97ms +step:1394/1680 train_time:122628ms step_avg:87.97ms +step:1395/1680 train_time:122716ms step_avg:87.97ms +step:1396/1680 train_time:122804ms step_avg:87.97ms +step:1397/1680 train_time:122892ms step_avg:87.97ms +step:1398/1680 train_time:122981ms step_avg:87.97ms +step:1399/1680 train_time:123071ms step_avg:87.97ms +step:1400/1680 train_time:123161ms step_avg:87.97ms +step:1401/1680 train_time:123250ms step_avg:87.97ms +step:1402/1680 train_time:123339ms step_avg:87.97ms +step:1403/1680 train_time:123428ms step_avg:87.97ms +step:1404/1680 train_time:123517ms step_avg:87.97ms +step:1405/1680 train_time:123605ms step_avg:87.97ms +step:1406/1680 train_time:123693ms step_avg:87.98ms +step:1407/1680 train_time:123782ms step_avg:87.98ms +step:1408/1680 train_time:123870ms step_avg:87.98ms +step:1409/1680 train_time:123960ms step_avg:87.98ms +step:1410/1680 train_time:124050ms step_avg:87.98ms +step:1411/1680 train_time:124140ms step_avg:87.98ms +step:1412/1680 train_time:124229ms step_avg:87.98ms +step:1413/1680 train_time:124319ms step_avg:87.98ms +step:1414/1680 
train_time:124408ms step_avg:87.98ms +step:1415/1680 train_time:124497ms step_avg:87.98ms +step:1416/1680 train_time:124586ms step_avg:87.98ms +step:1417/1680 train_time:124674ms step_avg:87.98ms +step:1418/1680 train_time:124763ms step_avg:87.98ms +step:1419/1680 train_time:124851ms step_avg:87.99ms +step:1420/1680 train_time:124940ms step_avg:87.99ms +step:1421/1680 train_time:125030ms step_avg:87.99ms +step:1422/1680 train_time:125120ms step_avg:87.99ms +step:1423/1680 train_time:125209ms step_avg:87.99ms +step:1424/1680 train_time:125298ms step_avg:87.99ms +step:1425/1680 train_time:125388ms step_avg:87.99ms +step:1426/1680 train_time:125477ms step_avg:87.99ms +step:1427/1680 train_time:125566ms step_avg:87.99ms +step:1428/1680 train_time:125655ms step_avg:87.99ms +step:1429/1680 train_time:125744ms step_avg:87.99ms +step:1430/1680 train_time:125832ms step_avg:87.99ms +step:1431/1680 train_time:125921ms step_avg:88.00ms +step:1432/1680 train_time:126010ms step_avg:88.00ms +step:1433/1680 train_time:126099ms step_avg:88.00ms +step:1434/1680 train_time:126188ms step_avg:88.00ms +step:1435/1680 train_time:126277ms step_avg:88.00ms +step:1436/1680 train_time:126367ms step_avg:88.00ms +step:1437/1680 train_time:126456ms step_avg:88.00ms +step:1438/1680 train_time:126545ms step_avg:88.00ms +step:1439/1680 train_time:126633ms step_avg:88.00ms +step:1440/1680 train_time:126722ms step_avg:88.00ms +step:1441/1680 train_time:126810ms step_avg:88.00ms +step:1442/1680 train_time:126899ms step_avg:88.00ms +step:1443/1680 train_time:126988ms step_avg:88.00ms +step:1444/1680 train_time:127078ms step_avg:88.00ms +step:1445/1680 train_time:127167ms step_avg:88.01ms +step:1446/1680 train_time:127256ms step_avg:88.01ms +step:1447/1680 train_time:127345ms step_avg:88.01ms +step:1448/1680 train_time:127433ms step_avg:88.01ms +step:1449/1680 train_time:127522ms step_avg:88.01ms +step:1450/1680 train_time:127612ms step_avg:88.01ms +step:1451/1680 train_time:127700ms step_avg:88.01ms +step:1452/1680 train_time:127788ms step_avg:88.01ms +step:1453/1680 train_time:127877ms step_avg:88.01ms +step:1454/1680 train_time:127967ms step_avg:88.01ms +step:1455/1680 train_time:128056ms step_avg:88.01ms +step:1456/1680 train_time:128146ms step_avg:88.01ms +step:1457/1680 train_time:128234ms step_avg:88.01ms +step:1458/1680 train_time:128323ms step_avg:88.01ms +step:1459/1680 train_time:128412ms step_avg:88.01ms +step:1460/1680 train_time:128501ms step_avg:88.01ms +step:1461/1680 train_time:128590ms step_avg:88.02ms +step:1462/1680 train_time:128679ms step_avg:88.02ms +step:1463/1680 train_time:128768ms step_avg:88.02ms +step:1464/1680 train_time:128857ms step_avg:88.02ms +step:1465/1680 train_time:128947ms step_avg:88.02ms +step:1466/1680 train_time:129036ms step_avg:88.02ms +step:1467/1680 train_time:129126ms step_avg:88.02ms +step:1468/1680 train_time:129215ms step_avg:88.02ms +step:1469/1680 train_time:129304ms step_avg:88.02ms +step:1470/1680 train_time:129392ms step_avg:88.02ms +step:1471/1680 train_time:129481ms step_avg:88.02ms +step:1472/1680 train_time:129571ms step_avg:88.02ms +step:1473/1680 train_time:129659ms step_avg:88.02ms +step:1474/1680 train_time:129748ms step_avg:88.02ms +step:1475/1680 train_time:129837ms step_avg:88.03ms +step:1476/1680 train_time:129927ms step_avg:88.03ms +step:1477/1680 train_time:130016ms step_avg:88.03ms +step:1478/1680 train_time:130105ms step_avg:88.03ms +step:1479/1680 train_time:130193ms step_avg:88.03ms +step:1480/1680 train_time:130283ms step_avg:88.03ms +step:1481/1680 
train_time:130372ms step_avg:88.03ms +step:1482/1680 train_time:130461ms step_avg:88.03ms +step:1483/1680 train_time:130550ms step_avg:88.03ms +step:1484/1680 train_time:130639ms step_avg:88.03ms +step:1485/1680 train_time:130728ms step_avg:88.03ms +step:1486/1680 train_time:130816ms step_avg:88.03ms +step:1487/1680 train_time:130905ms step_avg:88.03ms +step:1488/1680 train_time:130994ms step_avg:88.03ms +step:1489/1680 train_time:131083ms step_avg:88.03ms +step:1490/1680 train_time:131173ms step_avg:88.04ms +step:1491/1680 train_time:131262ms step_avg:88.04ms +step:1492/1680 train_time:131351ms step_avg:88.04ms +step:1493/1680 train_time:131440ms step_avg:88.04ms +step:1494/1680 train_time:131529ms step_avg:88.04ms +step:1495/1680 train_time:131618ms step_avg:88.04ms +step:1496/1680 train_time:131706ms step_avg:88.04ms +step:1497/1680 train_time:131795ms step_avg:88.04ms +step:1498/1680 train_time:131883ms step_avg:88.04ms +step:1499/1680 train_time:131972ms step_avg:88.04ms +step:1500/1680 train_time:132061ms step_avg:88.04ms +step:1500/1680 val_loss:3.3127 train_time:132151ms step_avg:88.10ms +step:1501/1680 train_time:132169ms step_avg:88.05ms +step:1502/1680 train_time:132241ms step_avg:88.04ms +step:1503/1680 train_time:132338ms step_avg:88.05ms +step:1504/1680 train_time:132428ms step_avg:88.05ms +step:1505/1680 train_time:132516ms step_avg:88.05ms +step:1506/1680 train_time:132604ms step_avg:88.05ms +step:1507/1680 train_time:132693ms step_avg:88.05ms +step:1508/1680 train_time:132781ms step_avg:88.05ms +step:1509/1680 train_time:132869ms step_avg:88.05ms +step:1510/1680 train_time:132957ms step_avg:88.05ms +step:1511/1680 train_time:133044ms step_avg:88.05ms +step:1512/1680 train_time:133134ms step_avg:88.05ms +step:1513/1680 train_time:133224ms step_avg:88.05ms +step:1514/1680 train_time:133316ms step_avg:88.06ms +step:1515/1680 train_time:133407ms step_avg:88.06ms +step:1516/1680 train_time:133496ms step_avg:88.06ms +step:1517/1680 train_time:133585ms step_avg:88.06ms +step:1518/1680 train_time:133673ms step_avg:88.06ms +step:1519/1680 train_time:133762ms step_avg:88.06ms +step:1520/1680 train_time:133851ms step_avg:88.06ms +step:1521/1680 train_time:133939ms step_avg:88.06ms +step:1522/1680 train_time:134028ms step_avg:88.06ms +step:1523/1680 train_time:134117ms step_avg:88.06ms +step:1524/1680 train_time:134207ms step_avg:88.06ms +step:1525/1680 train_time:134297ms step_avg:88.06ms +step:1526/1680 train_time:134387ms step_avg:88.06ms +step:1527/1680 train_time:134476ms step_avg:88.07ms +step:1528/1680 train_time:134565ms step_avg:88.07ms +step:1529/1680 train_time:134653ms step_avg:88.07ms +step:1530/1680 train_time:134741ms step_avg:88.07ms +step:1531/1680 train_time:134830ms step_avg:88.07ms +step:1532/1680 train_time:134919ms step_avg:88.07ms +step:1533/1680 train_time:135007ms step_avg:88.07ms +step:1534/1680 train_time:135096ms step_avg:88.07ms +step:1535/1680 train_time:135186ms step_avg:88.07ms +step:1536/1680 train_time:135276ms step_avg:88.07ms +step:1537/1680 train_time:135366ms step_avg:88.07ms +step:1538/1680 train_time:135456ms step_avg:88.07ms +step:1539/1680 train_time:135545ms step_avg:88.07ms +step:1540/1680 train_time:135634ms step_avg:88.07ms +step:1541/1680 train_time:135723ms step_avg:88.07ms +step:1542/1680 train_time:135811ms step_avg:88.07ms +step:1543/1680 train_time:135900ms step_avg:88.08ms +step:1544/1680 train_time:135989ms step_avg:88.08ms +step:1545/1680 train_time:136077ms step_avg:88.08ms +step:1546/1680 train_time:136166ms step_avg:88.08ms 
+step:1547/1680 train_time:136255ms step_avg:88.08ms +step:1548/1680 train_time:136344ms step_avg:88.08ms +step:1549/1680 train_time:136434ms step_avg:88.08ms +step:1550/1680 train_time:136523ms step_avg:88.08ms +step:1551/1680 train_time:136612ms step_avg:88.08ms +step:1552/1680 train_time:136700ms step_avg:88.08ms +step:1553/1680 train_time:136789ms step_avg:88.08ms +step:1554/1680 train_time:136877ms step_avg:88.08ms +step:1555/1680 train_time:136966ms step_avg:88.08ms +step:1556/1680 train_time:137055ms step_avg:88.08ms +step:1557/1680 train_time:137144ms step_avg:88.08ms +step:1558/1680 train_time:137235ms step_avg:88.08ms +step:1559/1680 train_time:137325ms step_avg:88.09ms +step:1560/1680 train_time:137414ms step_avg:88.09ms +step:1561/1680 train_time:137504ms step_avg:88.09ms +step:1562/1680 train_time:137594ms step_avg:88.09ms +step:1563/1680 train_time:137683ms step_avg:88.09ms +step:1564/1680 train_time:137771ms step_avg:88.09ms +step:1565/1680 train_time:137860ms step_avg:88.09ms +step:1566/1680 train_time:137948ms step_avg:88.09ms +step:1567/1680 train_time:138038ms step_avg:88.09ms +step:1568/1680 train_time:138127ms step_avg:88.09ms +step:1569/1680 train_time:138216ms step_avg:88.09ms +step:1570/1680 train_time:138306ms step_avg:88.09ms +step:1571/1680 train_time:138395ms step_avg:88.09ms +step:1572/1680 train_time:138485ms step_avg:88.09ms +step:1573/1680 train_time:138574ms step_avg:88.10ms +step:1574/1680 train_time:138663ms step_avg:88.10ms +step:1575/1680 train_time:138751ms step_avg:88.10ms +step:1576/1680 train_time:138840ms step_avg:88.10ms +step:1577/1680 train_time:138929ms step_avg:88.10ms +step:1578/1680 train_time:139018ms step_avg:88.10ms +step:1579/1680 train_time:139106ms step_avg:88.10ms +step:1580/1680 train_time:139196ms step_avg:88.10ms +step:1581/1680 train_time:139285ms step_avg:88.10ms +step:1582/1680 train_time:139373ms step_avg:88.10ms +step:1583/1680 train_time:139463ms step_avg:88.10ms +step:1584/1680 train_time:139552ms step_avg:88.10ms +step:1585/1680 train_time:139641ms step_avg:88.10ms +step:1586/1680 train_time:139730ms step_avg:88.10ms +step:1587/1680 train_time:139818ms step_avg:88.10ms +step:1588/1680 train_time:139907ms step_avg:88.10ms +step:1589/1680 train_time:139997ms step_avg:88.10ms +step:1590/1680 train_time:140086ms step_avg:88.10ms +step:1591/1680 train_time:140175ms step_avg:88.11ms +step:1592/1680 train_time:140264ms step_avg:88.11ms +step:1593/1680 train_time:140353ms step_avg:88.11ms +step:1594/1680 train_time:140442ms step_avg:88.11ms +step:1595/1680 train_time:140531ms step_avg:88.11ms +step:1596/1680 train_time:140619ms step_avg:88.11ms +step:1597/1680 train_time:140708ms step_avg:88.11ms +step:1598/1680 train_time:140796ms step_avg:88.11ms +step:1599/1680 train_time:140885ms step_avg:88.11ms +step:1600/1680 train_time:140974ms step_avg:88.11ms +step:1601/1680 train_time:141063ms step_avg:88.11ms +step:1602/1680 train_time:141152ms step_avg:88.11ms +step:1603/1680 train_time:141242ms step_avg:88.11ms +step:1604/1680 train_time:141332ms step_avg:88.11ms +step:1605/1680 train_time:141421ms step_avg:88.11ms +step:1606/1680 train_time:141511ms step_avg:88.11ms +step:1607/1680 train_time:141599ms step_avg:88.11ms +step:1608/1680 train_time:141688ms step_avg:88.11ms +step:1609/1680 train_time:141778ms step_avg:88.12ms +step:1610/1680 train_time:141867ms step_avg:88.12ms +step:1611/1680 train_time:141956ms step_avg:88.12ms +step:1612/1680 train_time:142044ms step_avg:88.12ms +step:1613/1680 train_time:142134ms step_avg:88.12ms 
+step:1614/1680 train_time:142223ms step_avg:88.12ms +step:1615/1680 train_time:142313ms step_avg:88.12ms +step:1616/1680 train_time:142402ms step_avg:88.12ms +step:1617/1680 train_time:142492ms step_avg:88.12ms +step:1618/1680 train_time:142580ms step_avg:88.12ms +step:1619/1680 train_time:142669ms step_avg:88.12ms +step:1620/1680 train_time:142759ms step_avg:88.12ms +step:1621/1680 train_time:142848ms step_avg:88.12ms +step:1622/1680 train_time:142937ms step_avg:88.12ms +step:1623/1680 train_time:143025ms step_avg:88.12ms +step:1624/1680 train_time:143114ms step_avg:88.12ms +step:1625/1680 train_time:143203ms step_avg:88.12ms +step:1625/1680 val_loss:3.2887 train_time:143294ms step_avg:88.18ms +step:1626/1680 train_time:143312ms step_avg:88.14ms +step:1627/1680 train_time:143385ms step_avg:88.13ms +step:1628/1680 train_time:143478ms step_avg:88.13ms +step:1629/1680 train_time:143567ms step_avg:88.13ms +step:1630/1680 train_time:143655ms step_avg:88.13ms +step:1631/1680 train_time:143743ms step_avg:88.13ms +step:1632/1680 train_time:143831ms step_avg:88.13ms +step:1633/1680 train_time:143919ms step_avg:88.13ms +step:1634/1680 train_time:144007ms step_avg:88.13ms +step:1635/1680 train_time:144095ms step_avg:88.13ms +step:1636/1680 train_time:144184ms step_avg:88.13ms +step:1637/1680 train_time:144274ms step_avg:88.13ms +step:1638/1680 train_time:144365ms step_avg:88.13ms +step:1639/1680 train_time:144456ms step_avg:88.14ms +step:1640/1680 train_time:144546ms step_avg:88.14ms +step:1641/1680 train_time:144637ms step_avg:88.14ms +step:1642/1680 train_time:144725ms step_avg:88.14ms +step:1643/1680 train_time:144813ms step_avg:88.14ms +step:1644/1680 train_time:144901ms step_avg:88.14ms +step:1645/1680 train_time:144989ms step_avg:88.14ms +step:1646/1680 train_time:145078ms step_avg:88.14ms +step:1647/1680 train_time:145167ms step_avg:88.14ms +step:1648/1680 train_time:145256ms step_avg:88.14ms +step:1649/1680 train_time:145346ms step_avg:88.14ms +step:1650/1680 train_time:145436ms step_avg:88.14ms +step:1651/1680 train_time:145526ms step_avg:88.14ms +step:1652/1680 train_time:145615ms step_avg:88.14ms +step:1653/1680 train_time:145704ms step_avg:88.15ms +step:1654/1680 train_time:145793ms step_avg:88.15ms +step:1655/1680 train_time:145881ms step_avg:88.15ms +step:1656/1680 train_time:145969ms step_avg:88.15ms +step:1657/1680 train_time:146058ms step_avg:88.15ms +step:1658/1680 train_time:146148ms step_avg:88.15ms +step:1659/1680 train_time:146237ms step_avg:88.15ms +step:1660/1680 train_time:146327ms step_avg:88.15ms +step:1661/1680 train_time:146416ms step_avg:88.15ms +step:1662/1680 train_time:146506ms step_avg:88.15ms +step:1663/1680 train_time:146596ms step_avg:88.15ms +step:1664/1680 train_time:146685ms step_avg:88.15ms +step:1665/1680 train_time:146774ms step_avg:88.15ms +step:1666/1680 train_time:146863ms step_avg:88.15ms +step:1667/1680 train_time:146952ms step_avg:88.15ms +step:1668/1680 train_time:147040ms step_avg:88.15ms +step:1669/1680 train_time:147129ms step_avg:88.15ms +step:1670/1680 train_time:147217ms step_avg:88.15ms +step:1671/1680 train_time:147306ms step_avg:88.15ms +step:1672/1680 train_time:147397ms step_avg:88.16ms +step:1673/1680 train_time:147486ms step_avg:88.16ms +step:1674/1680 train_time:147576ms step_avg:88.16ms +step:1675/1680 train_time:147666ms step_avg:88.16ms +step:1676/1680 train_time:147756ms step_avg:88.16ms +step:1677/1680 train_time:147844ms step_avg:88.16ms +step:1678/1680 train_time:147933ms step_avg:88.16ms +step:1679/1680 train_time:148021ms 
step_avg:88.16ms +step:1680/1680 train_time:148110ms step_avg:88.16ms +step:1680/1680 val_loss:3.2787 train_time:148201ms step_avg:88.22ms +peak memory allocated: 30760 MiB reserved: 46014 MiB diff --git a/records/092725_BF16CE/351728bd-3438-40d2-a006-41ed492e139f.txt b/records/092725_BF16CE/351728bd-3438-40d2-a006-41ed492e139f.txt new file mode 100644 index 000000000..ac2e0d793 --- /dev/null +++ b/records/092725_BF16CE/351728bd-3438-40d2-a006-41ed492e139f.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
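+    # (Only blocks on or overlapping one side of the diagonal reach this point; the
+    # computed block is written twice below, once as-is and once transposed across the
+    # diagonal, so the symmetric product C = A @ A.T does roughly half the tl.dot
+    # work of a full GEMM.)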
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+ In practice, though, small 1D params perform well here: + NS approximately performs a magnitude normalization of the grad, and + this hyper-optimized class executes faster than the current Adam impl for small params. + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized to enable + (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param, 7 padding params) + 2. reduce scatter attn_gate (10 params, 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params, 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive attn params reshape them before NS + 8. wait on step 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with the small params provides an additional 0.2s improvement. + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size()==8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for the resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1,10,16,16] + assert len(params)==sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx+size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal.
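+        # For reference: newton_schulz_triton above applies, per iteration, the quintic
+        #     X <- a*X + (b*A + c*A@A) @ X,   A = X @ X.mT,   (a, b, c) = (3.4445, -4.7750, 2.0315)
+        # which pushes the singular values of the momentum-smoothed gradient toward 1
+        # (approximate orthogonalization). The fused ns_line_1/2/3 calls map onto it as
+        #     A = X @ X.mT            # ns_line_1
+        #     B = b*A + c*A@A         # ns_line_2
+        #     X = a*X + B @ X         # ns_line_3 (addmm/baddbmm)
+        # Below, momentum_buffer.lerp_(grad, 1-mu) followed by grad.lerp(momentum_buffer, mu)
+        # yields update = (1-mu)*grad + mu*buffer, i.e. Nesterov-style momentum.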
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
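+                # (The staging buffer, rather than p itself, receives the weight decay
+                # and the NS update: each rank finishes only its own shard, and the
+                # all_gather below broadcasts the shards as one contiguous tensor.)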
+ updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + original_shape = batched_update_grads.shape + # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O + # (assumed reconstruction: run NS per Q/K/V/O matrix, then restore the stacked shape) + module_idx = start_idx if start_idx < len(params) else 0 + if getattr(params[module_idx], "module", None) == 'attn': + batched_update_grads = batched_update_grads.view(-1, original_shape[-2], original_shape[-1] // 4) + v_chunk = newton_schulz_triton(batched_update_grads).view(original_shape) + + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq =
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module='attn' + with torch.no_grad(): + self.qkvo_w.view(4,self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4,self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4,self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4,self.hdim, self.dim)[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module='mlp' + self.c_proj.module='mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
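+        # Editorial aside, not part of the original record: a worked example of the
+        # padding mentioned above. next_multiple_of_n(50257, n=128) returns the first
+        # multiple of 128 that is >= 50257, i.e. 50304 = 393 * 128, which keeps the
+        # lm_head matmul shapes tensor-core friendly. The assert just restates that
+        # arithmetic and is a no-op at runtime.
+        assert next_multiple_of_n(50257, n=128) == 50304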
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the 
data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter==5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter+=1 + return starts, ends + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * 
grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1640 # number of iterations to run + iteration_extension = 40 # number of iterations to continue training at final cooldown and window size + cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 
0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable, then decay linearly to 0.1x over the final cooldown_frac of training
+def get_lr(step: int):
+    x = min(0.9999, step / args.num_iterations)
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+# window-size schedule: step through ws_schedule during training, then jump to ws_validate at the very end
+def get_ws(step: int):
+    if step == args.num_iterations + args.iteration_extension:
+        return args.ws_validate // 2, args.ws_validate
+    x = min(step / (1 + args.num_iterations), 0.9999)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws_long = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws_long > ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    elif new_ws_long < ws_long:
+        model.yarn.reset() # the warmup schedule only ever shrinks back to its base window
+        ws_long = new_ws_long
+    model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+# restore the initial state now that every graph has been compiled
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+model.yarn.reset()
+del train_loader, initial_state
+
+########################################
+#        Training and validation       #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+train_steps = args.num_iterations + args.iteration_extension
+ws_short, ws_long = get_ws(0)
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    new_ws_short, new_ws_long = get_ws(step)
+    if new_ws_long > ws_long:
+        model.yarn.apply(ws_long, new_ws_long) # YaRN-extend RoPE whenever the long window grows
+    ws_short, ws_long = new_ws_short, new_ws_long
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        if last_step:
+            ws_long = args.ws_long_validate
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = torch.zeros((), device=device, dtype=torch.float32)
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
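+        # Editorial note, not part of the original record: each of the
+        # grad_accum_steps passes below sees train_batch_size /
+        # (grad_accum_steps * world_size) tokens per rank, and .backward()
+        # accumulates gradients across the passes, so the optimizer steps that
+        # follow act on the full global batch.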
inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Sep 27 12:54:48 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 27C P0 121W / 700W | 5856MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 25C P0 118W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 22C P0 115W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 27C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 27C P0 119W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 25C P0 114W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 28C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 24C P0 119W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 165264 C /usr/bin/python 0MiB | +| 0 N/A N/A 165265 C /usr/bin/python 0MiB | +| 0 N/A N/A 165266 C /usr/bin/python 0MiB | +| 0 N/A N/A 165267 C /usr/bin/python 0MiB | +| 0 N/A N/A 165268 C /usr/bin/python 0MiB | +| 0 N/A N/A 165269 C /usr/bin/python 0MiB | +| 0 N/A N/A 165270 C /usr/bin/python 0MiB | +| 0 N/A N/A 165271 C /usr/bin/python 0MiB | +| 1 N/A N/A 165265 C /usr/bin/python 0MiB | +| 2 N/A N/A 165266 C /usr/bin/python 0MiB | +| 3 N/A N/A 165267 C /usr/bin/python 0MiB | +| 4 N/A N/A 165268 C /usr/bin/python 0MiB | +| 5 N/A N/A 165269 C /usr/bin/python 0MiB | +| 6 N/A N/A 165270 C /usr/bin/python 0MiB | +| 7 N/A N/A 165271 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1680 train_time:145ms step_avg:144.94ms +step:2/1680 train_time:165ms step_avg:82.25ms +step:3/1680 train_time:229ms step_avg:76.32ms +step:4/1680 train_time:314ms step_avg:78.42ms +step:5/1680 train_time:399ms step_avg:79.87ms +step:6/1680 train_time:485ms step_avg:80.89ms +step:7/1680 train_time:571ms step_avg:81.64ms +step:8/1680 train_time:658ms step_avg:82.20ms +step:9/1680 train_time:744ms step_avg:82.64ms +step:10/1680 train_time:830ms step_avg:83.01ms +step:11/1680 train_time:916ms step_avg:83.30ms +step:12/1680 train_time:1004ms step_avg:83.68ms +step:13/1680 train_time:1094ms step_avg:84.19ms +step:14/1680 train_time:1185ms step_avg:84.62ms +step:15/1680 train_time:1273ms step_avg:84.87ms +step:16/1680 train_time:1361ms step_avg:85.03ms +step:17/1680 train_time:1448ms step_avg:85.15ms +step:18/1680 train_time:1534ms step_avg:85.22ms +step:19/1680 train_time:1621ms step_avg:85.30ms +step:20/1680 train_time:1707ms step_avg:85.36ms +step:21/1680 train_time:1794ms step_avg:85.41ms +step:22/1680 train_time:1880ms step_avg:85.46ms +step:23/1680 train_time:1967ms step_avg:85.53ms +step:24/1680 train_time:2056ms step_avg:85.66ms +step:25/1680 train_time:2145ms step_avg:85.80ms +step:26/1680 train_time:2233ms step_avg:85.90ms +step:27/1680 train_time:2321ms step_avg:85.97ms +step:28/1680 train_time:2409ms step_avg:86.02ms +step:29/1680 train_time:2495ms step_avg:86.05ms +step:30/1680 train_time:2582ms step_avg:86.07ms +step:31/1680 train_time:2668ms step_avg:86.08ms +step:32/1680 train_time:2755ms step_avg:86.09ms +step:33/1680 train_time:2842ms step_avg:86.12ms +step:34/1680 train_time:2929ms step_avg:86.15ms +step:35/1680 train_time:3017ms step_avg:86.21ms +step:36/1680 train_time:3106ms step_avg:86.28ms +step:37/1680 train_time:3194ms step_avg:86.32ms +step:38/1680 train_time:3281ms step_avg:86.35ms +step:39/1680 train_time:3369ms step_avg:86.38ms +step:40/1680 train_time:3456ms step_avg:86.40ms +step:41/1680 train_time:3544ms step_avg:86.44ms +step:42/1680 train_time:3631ms step_avg:86.46ms +step:43/1680 train_time:3718ms step_avg:86.46ms +step:44/1680 train_time:3805ms step_avg:86.47ms +step:45/1680 train_time:3892ms step_avg:86.48ms +step:46/1680 train_time:3979ms step_avg:86.50ms +step:47/1680 
train_time:4066ms step_avg:86.52ms +step:48/1680 train_time:4154ms step_avg:86.55ms +step:49/1680 train_time:4242ms step_avg:86.58ms +step:50/1680 train_time:4330ms step_avg:86.59ms +step:51/1680 train_time:4417ms step_avg:86.61ms +step:52/1680 train_time:4504ms step_avg:86.62ms +step:53/1680 train_time:4591ms step_avg:86.62ms +step:54/1680 train_time:4678ms step_avg:86.63ms +step:55/1680 train_time:4765ms step_avg:86.63ms +step:56/1680 train_time:4851ms step_avg:86.63ms +step:57/1680 train_time:4939ms step_avg:86.65ms +step:58/1680 train_time:5026ms step_avg:86.66ms +step:59/1680 train_time:5114ms step_avg:86.68ms +step:60/1680 train_time:5202ms step_avg:86.69ms +step:61/1680 train_time:5289ms step_avg:86.70ms +step:62/1680 train_time:5376ms step_avg:86.71ms +step:63/1680 train_time:5463ms step_avg:86.72ms +step:64/1680 train_time:5550ms step_avg:86.72ms +step:65/1680 train_time:5638ms step_avg:86.74ms +step:66/1680 train_time:5725ms step_avg:86.74ms +step:67/1680 train_time:5812ms step_avg:86.75ms +step:68/1680 train_time:5899ms step_avg:86.75ms +step:69/1680 train_time:5986ms step_avg:86.75ms +step:70/1680 train_time:6074ms step_avg:86.77ms +step:71/1680 train_time:6161ms step_avg:86.77ms +step:72/1680 train_time:6248ms step_avg:86.78ms +step:73/1680 train_time:6335ms step_avg:86.79ms +step:74/1680 train_time:6423ms step_avg:86.80ms +step:75/1680 train_time:6510ms step_avg:86.80ms +step:76/1680 train_time:6597ms step_avg:86.80ms +step:77/1680 train_time:6684ms step_avg:86.81ms +step:78/1680 train_time:6771ms step_avg:86.81ms +step:79/1680 train_time:6858ms step_avg:86.81ms +step:80/1680 train_time:6945ms step_avg:86.81ms +step:81/1680 train_time:7031ms step_avg:86.80ms +step:82/1680 train_time:7118ms step_avg:86.80ms +step:83/1680 train_time:7205ms step_avg:86.81ms +step:84/1680 train_time:7292ms step_avg:86.81ms +step:85/1680 train_time:7380ms step_avg:86.83ms +step:86/1680 train_time:7467ms step_avg:86.83ms +step:87/1680 train_time:7554ms step_avg:86.83ms +step:88/1680 train_time:7642ms step_avg:86.84ms +step:89/1680 train_time:7728ms step_avg:86.83ms +step:90/1680 train_time:7816ms step_avg:86.84ms +step:91/1680 train_time:7904ms step_avg:86.85ms +step:92/1680 train_time:7990ms step_avg:86.85ms +step:93/1680 train_time:8077ms step_avg:86.85ms +step:94/1680 train_time:8165ms step_avg:86.86ms +step:95/1680 train_time:8252ms step_avg:86.86ms +step:96/1680 train_time:8339ms step_avg:86.86ms +step:97/1680 train_time:8426ms step_avg:86.86ms +step:98/1680 train_time:8514ms step_avg:86.88ms +step:99/1680 train_time:8601ms step_avg:86.88ms +step:100/1680 train_time:8688ms step_avg:86.88ms +step:101/1680 train_time:8775ms step_avg:86.89ms +step:102/1680 train_time:8862ms step_avg:86.89ms +step:103/1680 train_time:8950ms step_avg:86.90ms +step:104/1680 train_time:9037ms step_avg:86.90ms +step:105/1680 train_time:9124ms step_avg:86.90ms +step:106/1680 train_time:9211ms step_avg:86.90ms +step:107/1680 train_time:9298ms step_avg:86.90ms +step:108/1680 train_time:9385ms step_avg:86.90ms +step:109/1680 train_time:9473ms step_avg:86.90ms +step:110/1680 train_time:9559ms step_avg:86.90ms +step:111/1680 train_time:9647ms step_avg:86.91ms +step:112/1680 train_time:9734ms step_avg:86.91ms +step:113/1680 train_time:9821ms step_avg:86.91ms +step:114/1680 train_time:9908ms step_avg:86.91ms +step:115/1680 train_time:9995ms step_avg:86.91ms +step:116/1680 train_time:10082ms step_avg:86.91ms +step:117/1680 train_time:10169ms step_avg:86.91ms +step:118/1680 train_time:10256ms step_avg:86.91ms +step:119/1680 
train_time:10343ms step_avg:86.92ms +step:120/1680 train_time:10430ms step_avg:86.92ms +step:121/1680 train_time:10517ms step_avg:86.92ms +step:122/1680 train_time:10604ms step_avg:86.92ms +step:123/1680 train_time:10691ms step_avg:86.92ms +step:124/1680 train_time:10778ms step_avg:86.92ms +step:125/1680 train_time:10865ms step_avg:86.92ms +step:125/1680 val_loss:4.3070 train_time:10953ms step_avg:87.62ms +step:126/1680 train_time:10971ms step_avg:87.07ms +step:127/1680 train_time:11043ms step_avg:86.95ms +step:128/1680 train_time:11140ms step_avg:87.03ms +step:129/1680 train_time:11230ms step_avg:87.05ms +step:130/1680 train_time:11317ms step_avg:87.05ms +step:131/1680 train_time:11405ms step_avg:87.06ms +step:132/1680 train_time:11491ms step_avg:87.05ms +step:133/1680 train_time:11577ms step_avg:87.04ms +step:134/1680 train_time:11662ms step_avg:87.03ms +step:135/1680 train_time:11749ms step_avg:87.03ms +step:136/1680 train_time:11835ms step_avg:87.02ms +step:137/1680 train_time:11920ms step_avg:87.01ms +step:138/1680 train_time:12008ms step_avg:87.01ms +step:139/1680 train_time:12097ms step_avg:87.03ms +step:140/1680 train_time:12186ms step_avg:87.04ms +step:141/1680 train_time:12274ms step_avg:87.05ms +step:142/1680 train_time:12361ms step_avg:87.05ms +step:143/1680 train_time:12448ms step_avg:87.05ms +step:144/1680 train_time:12535ms step_avg:87.05ms +step:145/1680 train_time:12622ms step_avg:87.05ms +step:146/1680 train_time:12708ms step_avg:87.04ms +step:147/1680 train_time:12794ms step_avg:87.03ms +step:148/1680 train_time:12880ms step_avg:87.03ms +step:149/1680 train_time:12967ms step_avg:87.03ms +step:150/1680 train_time:13055ms step_avg:87.03ms +step:151/1680 train_time:13144ms step_avg:87.04ms +step:152/1680 train_time:13232ms step_avg:87.05ms +step:153/1680 train_time:13320ms step_avg:87.06ms +step:154/1680 train_time:13408ms step_avg:87.06ms +step:155/1680 train_time:13495ms step_avg:87.06ms +step:156/1680 train_time:13582ms step_avg:87.06ms +step:157/1680 train_time:13668ms step_avg:87.06ms +step:158/1680 train_time:13754ms step_avg:87.05ms +step:159/1680 train_time:13840ms step_avg:87.05ms +step:160/1680 train_time:13926ms step_avg:87.04ms +step:161/1680 train_time:14013ms step_avg:87.04ms +step:162/1680 train_time:14101ms step_avg:87.04ms +step:163/1680 train_time:14189ms step_avg:87.05ms +step:164/1680 train_time:14277ms step_avg:87.05ms +step:165/1680 train_time:14364ms step_avg:87.05ms +step:166/1680 train_time:14452ms step_avg:87.06ms +step:167/1680 train_time:14538ms step_avg:87.05ms +step:168/1680 train_time:14625ms step_avg:87.05ms +step:169/1680 train_time:14712ms step_avg:87.05ms +step:170/1680 train_time:14799ms step_avg:87.05ms +step:171/1680 train_time:14885ms step_avg:87.05ms +step:172/1680 train_time:14972ms step_avg:87.05ms +step:173/1680 train_time:15059ms step_avg:87.05ms +step:174/1680 train_time:15146ms step_avg:87.05ms +step:175/1680 train_time:15234ms step_avg:87.05ms +step:176/1680 train_time:15321ms step_avg:87.05ms +step:177/1680 train_time:15409ms step_avg:87.06ms +step:178/1680 train_time:15496ms step_avg:87.06ms +step:179/1680 train_time:15583ms step_avg:87.06ms +step:180/1680 train_time:15670ms step_avg:87.06ms +step:181/1680 train_time:15756ms step_avg:87.05ms +step:182/1680 train_time:15843ms step_avg:87.05ms +step:183/1680 train_time:15929ms step_avg:87.04ms +step:184/1680 train_time:16016ms step_avg:87.04ms +step:185/1680 train_time:16103ms step_avg:87.05ms +step:186/1680 train_time:16191ms step_avg:87.05ms +step:187/1680 train_time:16278ms 
step_avg:87.05ms +step:188/1680 train_time:16365ms step_avg:87.05ms +step:189/1680 train_time:16452ms step_avg:87.05ms +step:190/1680 train_time:16539ms step_avg:87.05ms +step:191/1680 train_time:16626ms step_avg:87.05ms +step:192/1680 train_time:16713ms step_avg:87.05ms +step:193/1680 train_time:16801ms step_avg:87.05ms +step:194/1680 train_time:16887ms step_avg:87.05ms +step:195/1680 train_time:16974ms step_avg:87.05ms +step:196/1680 train_time:17061ms step_avg:87.04ms +step:197/1680 train_time:17148ms step_avg:87.04ms +step:198/1680 train_time:17234ms step_avg:87.04ms +step:199/1680 train_time:17322ms step_avg:87.04ms +step:200/1680 train_time:17409ms step_avg:87.04ms +step:201/1680 train_time:17496ms step_avg:87.05ms +step:202/1680 train_time:17583ms step_avg:87.05ms +step:203/1680 train_time:17670ms step_avg:87.05ms +step:204/1680 train_time:17757ms step_avg:87.04ms +step:205/1680 train_time:17844ms step_avg:87.04ms +step:206/1680 train_time:17930ms step_avg:87.04ms +step:207/1680 train_time:18018ms step_avg:87.04ms +step:208/1680 train_time:18105ms step_avg:87.04ms +step:209/1680 train_time:18192ms step_avg:87.04ms +step:210/1680 train_time:18279ms step_avg:87.04ms +step:211/1680 train_time:18366ms step_avg:87.04ms +step:212/1680 train_time:18452ms step_avg:87.04ms +step:213/1680 train_time:18540ms step_avg:87.04ms +step:214/1680 train_time:18627ms step_avg:87.04ms +step:215/1680 train_time:18713ms step_avg:87.04ms +step:216/1680 train_time:18800ms step_avg:87.04ms +step:217/1680 train_time:18887ms step_avg:87.04ms +step:218/1680 train_time:18974ms step_avg:87.04ms +step:219/1680 train_time:19061ms step_avg:87.04ms +step:220/1680 train_time:19148ms step_avg:87.04ms +step:221/1680 train_time:19235ms step_avg:87.03ms +step:222/1680 train_time:19322ms step_avg:87.04ms +step:223/1680 train_time:19409ms step_avg:87.04ms +step:224/1680 train_time:19496ms step_avg:87.03ms +step:225/1680 train_time:19583ms step_avg:87.03ms +step:226/1680 train_time:19669ms step_avg:87.03ms +step:227/1680 train_time:19756ms step_avg:87.03ms +step:228/1680 train_time:19842ms step_avg:87.03ms +step:229/1680 train_time:19929ms step_avg:87.03ms +step:230/1680 train_time:20017ms step_avg:87.03ms +step:231/1680 train_time:20104ms step_avg:87.03ms +step:232/1680 train_time:20191ms step_avg:87.03ms +step:233/1680 train_time:20277ms step_avg:87.03ms +step:234/1680 train_time:20365ms step_avg:87.03ms +step:235/1680 train_time:20451ms step_avg:87.03ms +step:236/1680 train_time:20538ms step_avg:87.02ms +step:237/1680 train_time:20625ms step_avg:87.02ms +step:238/1680 train_time:20712ms step_avg:87.03ms +step:239/1680 train_time:20799ms step_avg:87.02ms +step:240/1680 train_time:20885ms step_avg:87.02ms +step:241/1680 train_time:20972ms step_avg:87.02ms +step:242/1680 train_time:21059ms step_avg:87.02ms +step:243/1680 train_time:21146ms step_avg:87.02ms +step:244/1680 train_time:21233ms step_avg:87.02ms +step:245/1680 train_time:21320ms step_avg:87.02ms +step:246/1680 train_time:21407ms step_avg:87.02ms +step:247/1680 train_time:21494ms step_avg:87.02ms +step:248/1680 train_time:21582ms step_avg:87.02ms +step:249/1680 train_time:21668ms step_avg:87.02ms +step:250/1680 train_time:21755ms step_avg:87.02ms +step:250/1680 val_loss:3.9735 train_time:21843ms step_avg:87.37ms +step:251/1680 train_time:21861ms step_avg:87.09ms +step:252/1680 train_time:21933ms step_avg:87.04ms +step:253/1680 train_time:22021ms step_avg:87.04ms +step:254/1680 train_time:22109ms step_avg:87.04ms +step:255/1680 train_time:22196ms step_avg:87.04ms 
+step:256/1680 train_time:22283ms step_avg:87.04ms +step:257/1680 train_time:22369ms step_avg:87.04ms +step:258/1680 train_time:22455ms step_avg:87.04ms +step:259/1680 train_time:22541ms step_avg:87.03ms +step:260/1680 train_time:22627ms step_avg:87.03ms +step:261/1680 train_time:22714ms step_avg:87.03ms +step:262/1680 train_time:22801ms step_avg:87.03ms +step:263/1680 train_time:22890ms step_avg:87.03ms +step:264/1680 train_time:22978ms step_avg:87.04ms +step:265/1680 train_time:23066ms step_avg:87.04ms +step:266/1680 train_time:23153ms step_avg:87.04ms +step:267/1680 train_time:23240ms step_avg:87.04ms +step:268/1680 train_time:23326ms step_avg:87.04ms +step:269/1680 train_time:23413ms step_avg:87.04ms +step:270/1680 train_time:23500ms step_avg:87.04ms +step:271/1680 train_time:23586ms step_avg:87.03ms +step:272/1680 train_time:23673ms step_avg:87.03ms +step:273/1680 train_time:23759ms step_avg:87.03ms +step:274/1680 train_time:23847ms step_avg:87.03ms +step:275/1680 train_time:23935ms step_avg:87.04ms +step:276/1680 train_time:24022ms step_avg:87.04ms +step:277/1680 train_time:24110ms step_avg:87.04ms +step:278/1680 train_time:24197ms step_avg:87.04ms +step:279/1680 train_time:24284ms step_avg:87.04ms +step:280/1680 train_time:24371ms step_avg:87.04ms +step:281/1680 train_time:24457ms step_avg:87.04ms +step:282/1680 train_time:24544ms step_avg:87.04ms +step:283/1680 train_time:24630ms step_avg:87.03ms +step:284/1680 train_time:24717ms step_avg:87.03ms +step:285/1680 train_time:24804ms step_avg:87.03ms +step:286/1680 train_time:24891ms step_avg:87.03ms +step:287/1680 train_time:24978ms step_avg:87.03ms +step:288/1680 train_time:25066ms step_avg:87.03ms +step:289/1680 train_time:25152ms step_avg:87.03ms +step:290/1680 train_time:25240ms step_avg:87.03ms +step:291/1680 train_time:25327ms step_avg:87.03ms +step:292/1680 train_time:25413ms step_avg:87.03ms +step:293/1680 train_time:25500ms step_avg:87.03ms +step:294/1680 train_time:25586ms step_avg:87.03ms +step:295/1680 train_time:25673ms step_avg:87.03ms +step:296/1680 train_time:25761ms step_avg:87.03ms +step:297/1680 train_time:25848ms step_avg:87.03ms +step:298/1680 train_time:25935ms step_avg:87.03ms +step:299/1680 train_time:26024ms step_avg:87.04ms +step:300/1680 train_time:26110ms step_avg:87.03ms +step:301/1680 train_time:26197ms step_avg:87.03ms +step:302/1680 train_time:26284ms step_avg:87.03ms +step:303/1680 train_time:26371ms step_avg:87.03ms +step:304/1680 train_time:26458ms step_avg:87.03ms +step:305/1680 train_time:26545ms step_avg:87.03ms +step:306/1680 train_time:26632ms step_avg:87.03ms +step:307/1680 train_time:26718ms step_avg:87.03ms +step:308/1680 train_time:26805ms step_avg:87.03ms +step:309/1680 train_time:26892ms step_avg:87.03ms +step:310/1680 train_time:26980ms step_avg:87.03ms +step:311/1680 train_time:27068ms step_avg:87.03ms +step:312/1680 train_time:27154ms step_avg:87.03ms +step:313/1680 train_time:27241ms step_avg:87.03ms +step:314/1680 train_time:27328ms step_avg:87.03ms +step:315/1680 train_time:27415ms step_avg:87.03ms +step:316/1680 train_time:27502ms step_avg:87.03ms +step:317/1680 train_time:27588ms step_avg:87.03ms +step:318/1680 train_time:27675ms step_avg:87.03ms +step:319/1680 train_time:27762ms step_avg:87.03ms +step:320/1680 train_time:27849ms step_avg:87.03ms +step:321/1680 train_time:27936ms step_avg:87.03ms +step:322/1680 train_time:28024ms step_avg:87.03ms +step:323/1680 train_time:28111ms step_avg:87.03ms +step:324/1680 train_time:28199ms step_avg:87.03ms +step:325/1680 train_time:28286ms 
step_avg:87.03ms +step:326/1680 train_time:28372ms step_avg:87.03ms +step:327/1680 train_time:28459ms step_avg:87.03ms +step:328/1680 train_time:28546ms step_avg:87.03ms +step:329/1680 train_time:28632ms step_avg:87.03ms +step:330/1680 train_time:28718ms step_avg:87.03ms +step:331/1680 train_time:28805ms step_avg:87.03ms +step:332/1680 train_time:28892ms step_avg:87.02ms +step:333/1680 train_time:28979ms step_avg:87.02ms +step:334/1680 train_time:29066ms step_avg:87.03ms +step:335/1680 train_time:29153ms step_avg:87.02ms +step:336/1680 train_time:29240ms step_avg:87.02ms +step:337/1680 train_time:29327ms step_avg:87.02ms +step:338/1680 train_time:29414ms step_avg:87.02ms +step:339/1680 train_time:29502ms step_avg:87.03ms +step:340/1680 train_time:29588ms step_avg:87.02ms +step:341/1680 train_time:29675ms step_avg:87.02ms +step:342/1680 train_time:29762ms step_avg:87.02ms +step:343/1680 train_time:29849ms step_avg:87.02ms +step:344/1680 train_time:29937ms step_avg:87.02ms +step:345/1680 train_time:30024ms step_avg:87.03ms +step:346/1680 train_time:30111ms step_avg:87.03ms +step:347/1680 train_time:30198ms step_avg:87.03ms +step:348/1680 train_time:30285ms step_avg:87.03ms +step:349/1680 train_time:30372ms step_avg:87.03ms +step:350/1680 train_time:30459ms step_avg:87.03ms +step:351/1680 train_time:30546ms step_avg:87.03ms +step:352/1680 train_time:30633ms step_avg:87.03ms +step:353/1680 train_time:30720ms step_avg:87.03ms +step:354/1680 train_time:30807ms step_avg:87.02ms +step:355/1680 train_time:30894ms step_avg:87.02ms +step:356/1680 train_time:30981ms step_avg:87.03ms +step:357/1680 train_time:31068ms step_avg:87.03ms +step:358/1680 train_time:31155ms step_avg:87.02ms +step:359/1680 train_time:31242ms step_avg:87.02ms +step:360/1680 train_time:31329ms step_avg:87.02ms +step:361/1680 train_time:31416ms step_avg:87.02ms +step:362/1680 train_time:31503ms step_avg:87.02ms +step:363/1680 train_time:31589ms step_avg:87.02ms +step:364/1680 train_time:31677ms step_avg:87.02ms +step:365/1680 train_time:31764ms step_avg:87.02ms +step:366/1680 train_time:31851ms step_avg:87.03ms +step:367/1680 train_time:31938ms step_avg:87.02ms +step:368/1680 train_time:32025ms step_avg:87.02ms +step:369/1680 train_time:32112ms step_avg:87.02ms +step:370/1680 train_time:32199ms step_avg:87.03ms +step:371/1680 train_time:32286ms step_avg:87.02ms +step:372/1680 train_time:32373ms step_avg:87.02ms +step:373/1680 train_time:32460ms step_avg:87.02ms +step:374/1680 train_time:32547ms step_avg:87.02ms +step:375/1680 train_time:32635ms step_avg:87.03ms +step:375/1680 val_loss:3.8216 train_time:32723ms step_avg:87.26ms +step:376/1680 train_time:32743ms step_avg:87.08ms +step:377/1680 train_time:32814ms step_avg:87.04ms +step:378/1680 train_time:32904ms step_avg:87.05ms +step:379/1680 train_time:32992ms step_avg:87.05ms +step:380/1680 train_time:33079ms step_avg:87.05ms +step:381/1680 train_time:33165ms step_avg:87.05ms +step:382/1680 train_time:33251ms step_avg:87.04ms +step:383/1680 train_time:33337ms step_avg:87.04ms +step:384/1680 train_time:33423ms step_avg:87.04ms +step:385/1680 train_time:33509ms step_avg:87.04ms +step:386/1680 train_time:33595ms step_avg:87.03ms +step:387/1680 train_time:33684ms step_avg:87.04ms +step:388/1680 train_time:33772ms step_avg:87.04ms +step:389/1680 train_time:33861ms step_avg:87.05ms +step:390/1680 train_time:33950ms step_avg:87.05ms +step:391/1680 train_time:34037ms step_avg:87.05ms +step:392/1680 train_time:34124ms step_avg:87.05ms +step:393/1680 train_time:34210ms step_avg:87.05ms 
+step:394/1680 train_time:34296ms step_avg:87.05ms +step:395/1680 train_time:34383ms step_avg:87.04ms +step:396/1680 train_time:34469ms step_avg:87.04ms +step:397/1680 train_time:34556ms step_avg:87.04ms +step:398/1680 train_time:34642ms step_avg:87.04ms +step:399/1680 train_time:34729ms step_avg:87.04ms +step:400/1680 train_time:34817ms step_avg:87.04ms +step:401/1680 train_time:34906ms step_avg:87.05ms +step:402/1680 train_time:34993ms step_avg:87.05ms +step:403/1680 train_time:35080ms step_avg:87.05ms +step:404/1680 train_time:35167ms step_avg:87.05ms +step:405/1680 train_time:35254ms step_avg:87.05ms +step:406/1680 train_time:35340ms step_avg:87.04ms +step:407/1680 train_time:35426ms step_avg:87.04ms +step:408/1680 train_time:35513ms step_avg:87.04ms +step:409/1680 train_time:35599ms step_avg:87.04ms +step:410/1680 train_time:35687ms step_avg:87.04ms +step:411/1680 train_time:35773ms step_avg:87.04ms +step:412/1680 train_time:35861ms step_avg:87.04ms +step:413/1680 train_time:35948ms step_avg:87.04ms +step:414/1680 train_time:36035ms step_avg:87.04ms +step:415/1680 train_time:36122ms step_avg:87.04ms +step:416/1680 train_time:36209ms step_avg:87.04ms +step:417/1680 train_time:36296ms step_avg:87.04ms +step:418/1680 train_time:36383ms step_avg:87.04ms +step:419/1680 train_time:36470ms step_avg:87.04ms +step:420/1680 train_time:36556ms step_avg:87.04ms +step:421/1680 train_time:36643ms step_avg:87.04ms +step:422/1680 train_time:36730ms step_avg:87.04ms +step:423/1680 train_time:36817ms step_avg:87.04ms +step:424/1680 train_time:36904ms step_avg:87.04ms +step:425/1680 train_time:36991ms step_avg:87.04ms +step:426/1680 train_time:37078ms step_avg:87.04ms +step:427/1680 train_time:37165ms step_avg:87.04ms +step:428/1680 train_time:37252ms step_avg:87.04ms +step:429/1680 train_time:37339ms step_avg:87.04ms +step:430/1680 train_time:37425ms step_avg:87.04ms +step:431/1680 train_time:37512ms step_avg:87.04ms +step:432/1680 train_time:37600ms step_avg:87.04ms +step:433/1680 train_time:37687ms step_avg:87.04ms +step:434/1680 train_time:37773ms step_avg:87.03ms +step:435/1680 train_time:37860ms step_avg:87.04ms +step:436/1680 train_time:37948ms step_avg:87.04ms +step:437/1680 train_time:38035ms step_avg:87.04ms +step:438/1680 train_time:38122ms step_avg:87.04ms +step:439/1680 train_time:38209ms step_avg:87.04ms +step:440/1680 train_time:38296ms step_avg:87.04ms +step:441/1680 train_time:38383ms step_avg:87.04ms +step:442/1680 train_time:38470ms step_avg:87.04ms +step:443/1680 train_time:38557ms step_avg:87.04ms +step:444/1680 train_time:38644ms step_avg:87.04ms +step:445/1680 train_time:38731ms step_avg:87.04ms +step:446/1680 train_time:38818ms step_avg:87.04ms +step:447/1680 train_time:38905ms step_avg:87.04ms +step:448/1680 train_time:38992ms step_avg:87.04ms +step:449/1680 train_time:39080ms step_avg:87.04ms +step:450/1680 train_time:39167ms step_avg:87.04ms +step:451/1680 train_time:39253ms step_avg:87.04ms +step:452/1680 train_time:39340ms step_avg:87.04ms +step:453/1680 train_time:39427ms step_avg:87.04ms +step:454/1680 train_time:39514ms step_avg:87.04ms +step:455/1680 train_time:39601ms step_avg:87.03ms +step:456/1680 train_time:39688ms step_avg:87.04ms +step:457/1680 train_time:39776ms step_avg:87.04ms +step:458/1680 train_time:39863ms step_avg:87.04ms +step:459/1680 train_time:39950ms step_avg:87.04ms +step:460/1680 train_time:40037ms step_avg:87.04ms +step:461/1680 train_time:40125ms step_avg:87.04ms +step:462/1680 train_time:40212ms step_avg:87.04ms +step:463/1680 train_time:40299ms 
step_avg:87.04ms +step:464/1680 train_time:40386ms step_avg:87.04ms +step:465/1680 train_time:40473ms step_avg:87.04ms +step:466/1680 train_time:40559ms step_avg:87.04ms +step:467/1680 train_time:40647ms step_avg:87.04ms +step:468/1680 train_time:40734ms step_avg:87.04ms +step:469/1680 train_time:40821ms step_avg:87.04ms +step:470/1680 train_time:40908ms step_avg:87.04ms +step:471/1680 train_time:40995ms step_avg:87.04ms +step:472/1680 train_time:41082ms step_avg:87.04ms +step:473/1680 train_time:41169ms step_avg:87.04ms +step:474/1680 train_time:41256ms step_avg:87.04ms +step:475/1680 train_time:41343ms step_avg:87.04ms +step:476/1680 train_time:41429ms step_avg:87.04ms +step:477/1680 train_time:41516ms step_avg:87.04ms +step:478/1680 train_time:41603ms step_avg:87.03ms +step:479/1680 train_time:41689ms step_avg:87.03ms +step:480/1680 train_time:41776ms step_avg:87.03ms +step:481/1680 train_time:41863ms step_avg:87.03ms +step:482/1680 train_time:41950ms step_avg:87.03ms +step:483/1680 train_time:42037ms step_avg:87.03ms +step:484/1680 train_time:42124ms step_avg:87.03ms +step:485/1680 train_time:42211ms step_avg:87.03ms +step:486/1680 train_time:42298ms step_avg:87.03ms +step:487/1680 train_time:42385ms step_avg:87.03ms +step:488/1680 train_time:42473ms step_avg:87.03ms +step:489/1680 train_time:42559ms step_avg:87.03ms +step:490/1680 train_time:42646ms step_avg:87.03ms +step:491/1680 train_time:42733ms step_avg:87.03ms +step:492/1680 train_time:42820ms step_avg:87.03ms +step:493/1680 train_time:42907ms step_avg:87.03ms +step:494/1680 train_time:42993ms step_avg:87.03ms +step:495/1680 train_time:43081ms step_avg:87.03ms +step:496/1680 train_time:43168ms step_avg:87.03ms +step:497/1680 train_time:43255ms step_avg:87.03ms +step:498/1680 train_time:43342ms step_avg:87.03ms +step:499/1680 train_time:43430ms step_avg:87.03ms +step:500/1680 train_time:43517ms step_avg:87.03ms +step:500/1680 val_loss:3.7176 train_time:43605ms step_avg:87.21ms +step:501/1680 train_time:43623ms step_avg:87.07ms +step:502/1680 train_time:43696ms step_avg:87.04ms +step:503/1680 train_time:43788ms step_avg:87.05ms +step:504/1680 train_time:43879ms step_avg:87.06ms +step:505/1680 train_time:43967ms step_avg:87.06ms +step:506/1680 train_time:44053ms step_avg:87.06ms +step:507/1680 train_time:44139ms step_avg:87.06ms +step:508/1680 train_time:44225ms step_avg:87.06ms +step:509/1680 train_time:44311ms step_avg:87.05ms +step:510/1680 train_time:44397ms step_avg:87.05ms +step:511/1680 train_time:44483ms step_avg:87.05ms +step:512/1680 train_time:44569ms step_avg:87.05ms +step:513/1680 train_time:44657ms step_avg:87.05ms +step:514/1680 train_time:44746ms step_avg:87.05ms +step:515/1680 train_time:44834ms step_avg:87.06ms +step:516/1680 train_time:44922ms step_avg:87.06ms +step:517/1680 train_time:45009ms step_avg:87.06ms +step:518/1680 train_time:45095ms step_avg:87.06ms +step:519/1680 train_time:45181ms step_avg:87.05ms +step:520/1680 train_time:45267ms step_avg:87.05ms +step:521/1680 train_time:45354ms step_avg:87.05ms +step:522/1680 train_time:45440ms step_avg:87.05ms +step:523/1680 train_time:45526ms step_avg:87.05ms +step:524/1680 train_time:45612ms step_avg:87.05ms +step:525/1680 train_time:45700ms step_avg:87.05ms +step:526/1680 train_time:45788ms step_avg:87.05ms +step:527/1680 train_time:45875ms step_avg:87.05ms +step:528/1680 train_time:45963ms step_avg:87.05ms +step:529/1680 train_time:46050ms step_avg:87.05ms +step:530/1680 train_time:46137ms step_avg:87.05ms +step:531/1680 train_time:46224ms step_avg:87.05ms 
+step:532/1680 train_time:46310ms step_avg:87.05ms +step:533/1680 train_time:46397ms step_avg:87.05ms +step:534/1680 train_time:46484ms step_avg:87.05ms +step:535/1680 train_time:46571ms step_avg:87.05ms +step:536/1680 train_time:46658ms step_avg:87.05ms +step:537/1680 train_time:46745ms step_avg:87.05ms +step:538/1680 train_time:46833ms step_avg:87.05ms +step:539/1680 train_time:46920ms step_avg:87.05ms +step:540/1680 train_time:47007ms step_avg:87.05ms +step:541/1680 train_time:47095ms step_avg:87.05ms +step:542/1680 train_time:47181ms step_avg:87.05ms +step:543/1680 train_time:47268ms step_avg:87.05ms +step:544/1680 train_time:47355ms step_avg:87.05ms +step:545/1680 train_time:47442ms step_avg:87.05ms +step:546/1680 train_time:47529ms step_avg:87.05ms +step:547/1680 train_time:47616ms step_avg:87.05ms +step:548/1680 train_time:47703ms step_avg:87.05ms +step:549/1680 train_time:47793ms step_avg:87.05ms +step:550/1680 train_time:47882ms step_avg:87.06ms +step:551/1680 train_time:47970ms step_avg:87.06ms +step:552/1680 train_time:48058ms step_avg:87.06ms +step:553/1680 train_time:48146ms step_avg:87.06ms +step:554/1680 train_time:48234ms step_avg:87.06ms +step:555/1680 train_time:48322ms step_avg:87.07ms +step:556/1680 train_time:48410ms step_avg:87.07ms +step:557/1680 train_time:48498ms step_avg:87.07ms +step:558/1680 train_time:48586ms step_avg:87.07ms +step:559/1680 train_time:48674ms step_avg:87.07ms +step:560/1680 train_time:48763ms step_avg:87.08ms +step:561/1680 train_time:48852ms step_avg:87.08ms +step:562/1680 train_time:48940ms step_avg:87.08ms +step:563/1680 train_time:49028ms step_avg:87.08ms +step:564/1680 train_time:49116ms step_avg:87.09ms +step:565/1680 train_time:49204ms step_avg:87.09ms +step:566/1680 train_time:49292ms step_avg:87.09ms +step:567/1680 train_time:49379ms step_avg:87.09ms +step:568/1680 train_time:49467ms step_avg:87.09ms +step:569/1680 train_time:49555ms step_avg:87.09ms +step:570/1680 train_time:49643ms step_avg:87.09ms +step:571/1680 train_time:49731ms step_avg:87.10ms +step:572/1680 train_time:49819ms step_avg:87.10ms +step:573/1680 train_time:49908ms step_avg:87.10ms +step:574/1680 train_time:49996ms step_avg:87.10ms +step:575/1680 train_time:50084ms step_avg:87.10ms +step:576/1680 train_time:50172ms step_avg:87.10ms +step:577/1680 train_time:50261ms step_avg:87.11ms +step:578/1680 train_time:50348ms step_avg:87.11ms +step:579/1680 train_time:50436ms step_avg:87.11ms +step:580/1680 train_time:50524ms step_avg:87.11ms +step:581/1680 train_time:50611ms step_avg:87.11ms +step:582/1680 train_time:50699ms step_avg:87.11ms +step:583/1680 train_time:50789ms step_avg:87.12ms +step:584/1680 train_time:50877ms step_avg:87.12ms +step:585/1680 train_time:50965ms step_avg:87.12ms +step:586/1680 train_time:51054ms step_avg:87.12ms +step:587/1680 train_time:51143ms step_avg:87.13ms +step:588/1680 train_time:51230ms step_avg:87.13ms +step:589/1680 train_time:51318ms step_avg:87.13ms +step:590/1680 train_time:51406ms step_avg:87.13ms +step:591/1680 train_time:51494ms step_avg:87.13ms +step:592/1680 train_time:51582ms step_avg:87.13ms +step:593/1680 train_time:51670ms step_avg:87.13ms +step:594/1680 train_time:51758ms step_avg:87.14ms +step:595/1680 train_time:51846ms step_avg:87.14ms +step:596/1680 train_time:51935ms step_avg:87.14ms +step:597/1680 train_time:52024ms step_avg:87.14ms +step:598/1680 train_time:52112ms step_avg:87.14ms +step:599/1680 train_time:52200ms step_avg:87.15ms +step:600/1680 train_time:52289ms step_avg:87.15ms +step:601/1680 train_time:52377ms 
step_avg:87.15ms +step:602/1680 train_time:52465ms step_avg:87.15ms +step:603/1680 train_time:52553ms step_avg:87.15ms +step:604/1680 train_time:52640ms step_avg:87.15ms +step:605/1680 train_time:52728ms step_avg:87.15ms +step:606/1680 train_time:52816ms step_avg:87.15ms +step:607/1680 train_time:52904ms step_avg:87.16ms +step:608/1680 train_time:52992ms step_avg:87.16ms +step:609/1680 train_time:53080ms step_avg:87.16ms +step:610/1680 train_time:53169ms step_avg:87.16ms +step:611/1680 train_time:53257ms step_avg:87.16ms +step:612/1680 train_time:53345ms step_avg:87.17ms +step:613/1680 train_time:53433ms step_avg:87.17ms +step:614/1680 train_time:53522ms step_avg:87.17ms +step:615/1680 train_time:53610ms step_avg:87.17ms +step:616/1680 train_time:53697ms step_avg:87.17ms +step:617/1680 train_time:53785ms step_avg:87.17ms +step:618/1680 train_time:53874ms step_avg:87.17ms +step:619/1680 train_time:53962ms step_avg:87.18ms +step:620/1680 train_time:54050ms step_avg:87.18ms +step:621/1680 train_time:54138ms step_avg:87.18ms +step:622/1680 train_time:54226ms step_avg:87.18ms +step:623/1680 train_time:54314ms step_avg:87.18ms +step:624/1680 train_time:54402ms step_avg:87.18ms +step:625/1680 train_time:54490ms step_avg:87.18ms +step:625/1680 val_loss:3.6166 train_time:54580ms step_avg:87.33ms +step:626/1680 train_time:54599ms step_avg:87.22ms +step:627/1680 train_time:54669ms step_avg:87.19ms +step:628/1680 train_time:54757ms step_avg:87.19ms +step:629/1680 train_time:54848ms step_avg:87.20ms +step:630/1680 train_time:54935ms step_avg:87.20ms +step:631/1680 train_time:55022ms step_avg:87.20ms +step:632/1680 train_time:55109ms step_avg:87.20ms +step:633/1680 train_time:55197ms step_avg:87.20ms +step:634/1680 train_time:55284ms step_avg:87.20ms +step:635/1680 train_time:55371ms step_avg:87.20ms +step:636/1680 train_time:55459ms step_avg:87.20ms +step:637/1680 train_time:55550ms step_avg:87.21ms +step:638/1680 train_time:55640ms step_avg:87.21ms +step:639/1680 train_time:55728ms step_avg:87.21ms +step:640/1680 train_time:55817ms step_avg:87.21ms +step:641/1680 train_time:55905ms step_avg:87.21ms +step:642/1680 train_time:55992ms step_avg:87.22ms +step:643/1680 train_time:56080ms step_avg:87.22ms +step:644/1680 train_time:56167ms step_avg:87.22ms +step:645/1680 train_time:56255ms step_avg:87.22ms +step:646/1680 train_time:56342ms step_avg:87.22ms +step:647/1680 train_time:56431ms step_avg:87.22ms +step:648/1680 train_time:56520ms step_avg:87.22ms +step:649/1680 train_time:56609ms step_avg:87.23ms +step:650/1680 train_time:56698ms step_avg:87.23ms +step:651/1680 train_time:56786ms step_avg:87.23ms +step:652/1680 train_time:56875ms step_avg:87.23ms +step:653/1680 train_time:56963ms step_avg:87.23ms +step:654/1680 train_time:57050ms step_avg:87.23ms +step:655/1680 train_time:57138ms step_avg:87.23ms +step:656/1680 train_time:57226ms step_avg:87.23ms +step:657/1680 train_time:57313ms step_avg:87.23ms +step:658/1680 train_time:57401ms step_avg:87.24ms +step:659/1680 train_time:57489ms step_avg:87.24ms +step:660/1680 train_time:57577ms step_avg:87.24ms +step:661/1680 train_time:57666ms step_avg:87.24ms +step:662/1680 train_time:57755ms step_avg:87.24ms +step:663/1680 train_time:57843ms step_avg:87.24ms +step:664/1680 train_time:57931ms step_avg:87.25ms +step:665/1680 train_time:58019ms step_avg:87.25ms +step:666/1680 train_time:58107ms step_avg:87.25ms +step:667/1680 train_time:58194ms step_avg:87.25ms +step:668/1680 train_time:58283ms step_avg:87.25ms +step:669/1680 train_time:58370ms step_avg:87.25ms 
+step:670/1680 train_time:58459ms step_avg:87.25ms +step:671/1680 train_time:58547ms step_avg:87.25ms +step:672/1680 train_time:58635ms step_avg:87.25ms +step:673/1680 train_time:58723ms step_avg:87.26ms +step:674/1680 train_time:58811ms step_avg:87.26ms +step:675/1680 train_time:58899ms step_avg:87.26ms +step:676/1680 train_time:58989ms step_avg:87.26ms +step:677/1680 train_time:59077ms step_avg:87.26ms +step:678/1680 train_time:59165ms step_avg:87.26ms +step:679/1680 train_time:59254ms step_avg:87.27ms +step:680/1680 train_time:59341ms step_avg:87.27ms +step:681/1680 train_time:59429ms step_avg:87.27ms +step:682/1680 train_time:59517ms step_avg:87.27ms +step:683/1680 train_time:59605ms step_avg:87.27ms +step:684/1680 train_time:59693ms step_avg:87.27ms +step:685/1680 train_time:59781ms step_avg:87.27ms +step:686/1680 train_time:59869ms step_avg:87.27ms +step:687/1680 train_time:59957ms step_avg:87.27ms +step:688/1680 train_time:60045ms step_avg:87.27ms +step:689/1680 train_time:60133ms step_avg:87.28ms +step:690/1680 train_time:60221ms step_avg:87.28ms +step:691/1680 train_time:60310ms step_avg:87.28ms +step:692/1680 train_time:60398ms step_avg:87.28ms +step:693/1680 train_time:60486ms step_avg:87.28ms +step:694/1680 train_time:60574ms step_avg:87.28ms +step:695/1680 train_time:60663ms step_avg:87.28ms +step:696/1680 train_time:60751ms step_avg:87.29ms +step:697/1680 train_time:60839ms step_avg:87.29ms +step:698/1680 train_time:60928ms step_avg:87.29ms +step:699/1680 train_time:61015ms step_avg:87.29ms +step:700/1680 train_time:61104ms step_avg:87.29ms +step:701/1680 train_time:61192ms step_avg:87.29ms +step:702/1680 train_time:61280ms step_avg:87.29ms +step:703/1680 train_time:61368ms step_avg:87.29ms +step:704/1680 train_time:61456ms step_avg:87.30ms +step:705/1680 train_time:61544ms step_avg:87.30ms +step:706/1680 train_time:61632ms step_avg:87.30ms +step:707/1680 train_time:61720ms step_avg:87.30ms +step:708/1680 train_time:61808ms step_avg:87.30ms +step:709/1680 train_time:61896ms step_avg:87.30ms +step:710/1680 train_time:61983ms step_avg:87.30ms +step:711/1680 train_time:62072ms step_avg:87.30ms +step:712/1680 train_time:62159ms step_avg:87.30ms +step:713/1680 train_time:62248ms step_avg:87.30ms +step:714/1680 train_time:62336ms step_avg:87.31ms +step:715/1680 train_time:62424ms step_avg:87.31ms +step:716/1680 train_time:62512ms step_avg:87.31ms +step:717/1680 train_time:62601ms step_avg:87.31ms +step:718/1680 train_time:62689ms step_avg:87.31ms +step:719/1680 train_time:62777ms step_avg:87.31ms +step:720/1680 train_time:62865ms step_avg:87.31ms +step:721/1680 train_time:62953ms step_avg:87.31ms +step:722/1680 train_time:63042ms step_avg:87.32ms +step:723/1680 train_time:63129ms step_avg:87.32ms +step:724/1680 train_time:63217ms step_avg:87.32ms +step:725/1680 train_time:63305ms step_avg:87.32ms +step:726/1680 train_time:63393ms step_avg:87.32ms +step:727/1680 train_time:63482ms step_avg:87.32ms +step:728/1680 train_time:63571ms step_avg:87.32ms +step:729/1680 train_time:63659ms step_avg:87.32ms +step:730/1680 train_time:63747ms step_avg:87.32ms +step:731/1680 train_time:63835ms step_avg:87.32ms +step:732/1680 train_time:63922ms step_avg:87.33ms +step:733/1680 train_time:64011ms step_avg:87.33ms +step:734/1680 train_time:64098ms step_avg:87.33ms +step:735/1680 train_time:64186ms step_avg:87.33ms +step:736/1680 train_time:64274ms step_avg:87.33ms +step:737/1680 train_time:64362ms step_avg:87.33ms +step:738/1680 train_time:64450ms step_avg:87.33ms +step:739/1680 train_time:64539ms 
step_avg:87.33ms +step:740/1680 train_time:64626ms step_avg:87.33ms +step:741/1680 train_time:64714ms step_avg:87.33ms +step:742/1680 train_time:64802ms step_avg:87.33ms +step:743/1680 train_time:64891ms step_avg:87.34ms +step:744/1680 train_time:64979ms step_avg:87.34ms +step:745/1680 train_time:65067ms step_avg:87.34ms +step:746/1680 train_time:65155ms step_avg:87.34ms +step:747/1680 train_time:65243ms step_avg:87.34ms +step:748/1680 train_time:65331ms step_avg:87.34ms +step:749/1680 train_time:65419ms step_avg:87.34ms +step:750/1680 train_time:65508ms step_avg:87.34ms +step:750/1680 val_loss:3.5658 train_time:65597ms step_avg:87.46ms +step:751/1680 train_time:65615ms step_avg:87.37ms +step:752/1680 train_time:65688ms step_avg:87.35ms +step:753/1680 train_time:65780ms step_avg:87.36ms +step:754/1680 train_time:65870ms step_avg:87.36ms +step:755/1680 train_time:65958ms step_avg:87.36ms +step:756/1680 train_time:66045ms step_avg:87.36ms +step:757/1680 train_time:66133ms step_avg:87.36ms +step:758/1680 train_time:66221ms step_avg:87.36ms +step:759/1680 train_time:66308ms step_avg:87.36ms +step:760/1680 train_time:66395ms step_avg:87.36ms +step:761/1680 train_time:66482ms step_avg:87.36ms +step:762/1680 train_time:66570ms step_avg:87.36ms +step:763/1680 train_time:66659ms step_avg:87.36ms +step:764/1680 train_time:66748ms step_avg:87.37ms +step:765/1680 train_time:66838ms step_avg:87.37ms +step:766/1680 train_time:66928ms step_avg:87.37ms +step:767/1680 train_time:67016ms step_avg:87.37ms +step:768/1680 train_time:67103ms step_avg:87.37ms +step:769/1680 train_time:67191ms step_avg:87.37ms +step:770/1680 train_time:67278ms step_avg:87.37ms +step:771/1680 train_time:67365ms step_avg:87.37ms +step:772/1680 train_time:67453ms step_avg:87.37ms +step:773/1680 train_time:67541ms step_avg:87.37ms +step:774/1680 train_time:67629ms step_avg:87.38ms +step:775/1680 train_time:67718ms step_avg:87.38ms +step:776/1680 train_time:67807ms step_avg:87.38ms +step:777/1680 train_time:67896ms step_avg:87.38ms +step:778/1680 train_time:67984ms step_avg:87.38ms +step:779/1680 train_time:68073ms step_avg:87.38ms +step:780/1680 train_time:68161ms step_avg:87.39ms +step:781/1680 train_time:68248ms step_avg:87.39ms +step:782/1680 train_time:68335ms step_avg:87.39ms +step:783/1680 train_time:68423ms step_avg:87.39ms +step:784/1680 train_time:68510ms step_avg:87.39ms +step:785/1680 train_time:68599ms step_avg:87.39ms +step:786/1680 train_time:68687ms step_avg:87.39ms +step:787/1680 train_time:68775ms step_avg:87.39ms +step:788/1680 train_time:68864ms step_avg:87.39ms +step:789/1680 train_time:68954ms step_avg:87.39ms +step:790/1680 train_time:69042ms step_avg:87.40ms +step:791/1680 train_time:69130ms step_avg:87.40ms +step:792/1680 train_time:69218ms step_avg:87.40ms +step:793/1680 train_time:69306ms step_avg:87.40ms +step:794/1680 train_time:69395ms step_avg:87.40ms +step:795/1680 train_time:69483ms step_avg:87.40ms +step:796/1680 train_time:69571ms step_avg:87.40ms +step:797/1680 train_time:69659ms step_avg:87.40ms +step:798/1680 train_time:69748ms step_avg:87.40ms +step:799/1680 train_time:69837ms step_avg:87.41ms +step:800/1680 train_time:69925ms step_avg:87.41ms +step:801/1680 train_time:70014ms step_avg:87.41ms +step:802/1680 train_time:70101ms step_avg:87.41ms +step:803/1680 train_time:70189ms step_avg:87.41ms +step:804/1680 train_time:70276ms step_avg:87.41ms +step:805/1680 train_time:70364ms step_avg:87.41ms +step:806/1680 train_time:70453ms step_avg:87.41ms +step:807/1680 train_time:70541ms step_avg:87.41ms 
+step:808/1680 train_time:70628ms step_avg:87.41ms +step:809/1680 train_time:70716ms step_avg:87.41ms +step:810/1680 train_time:70804ms step_avg:87.41ms +step:811/1680 train_time:70892ms step_avg:87.41ms +step:812/1680 train_time:70981ms step_avg:87.42ms +step:813/1680 train_time:71069ms step_avg:87.42ms +step:814/1680 train_time:71157ms step_avg:87.42ms +step:815/1680 train_time:71245ms step_avg:87.42ms +step:816/1680 train_time:71333ms step_avg:87.42ms +step:817/1680 train_time:71422ms step_avg:87.42ms +step:818/1680 train_time:71511ms step_avg:87.42ms +step:819/1680 train_time:71599ms step_avg:87.42ms +step:820/1680 train_time:71687ms step_avg:87.42ms +step:821/1680 train_time:71775ms step_avg:87.42ms +step:822/1680 train_time:71863ms step_avg:87.42ms +step:823/1680 train_time:71952ms step_avg:87.43ms +step:824/1680 train_time:72041ms step_avg:87.43ms +step:825/1680 train_time:72130ms step_avg:87.43ms +step:826/1680 train_time:72217ms step_avg:87.43ms +step:827/1680 train_time:72305ms step_avg:87.43ms +step:828/1680 train_time:72393ms step_avg:87.43ms +step:829/1680 train_time:72481ms step_avg:87.43ms +step:830/1680 train_time:72569ms step_avg:87.43ms +step:831/1680 train_time:72657ms step_avg:87.43ms +step:832/1680 train_time:72745ms step_avg:87.43ms +step:833/1680 train_time:72833ms step_avg:87.43ms +step:834/1680 train_time:72922ms step_avg:87.44ms +step:835/1680 train_time:73009ms step_avg:87.44ms +step:836/1680 train_time:73097ms step_avg:87.44ms +step:837/1680 train_time:73186ms step_avg:87.44ms +step:838/1680 train_time:73274ms step_avg:87.44ms +step:839/1680 train_time:73362ms step_avg:87.44ms +step:840/1680 train_time:73450ms step_avg:87.44ms +step:841/1680 train_time:73538ms step_avg:87.44ms +step:842/1680 train_time:73626ms step_avg:87.44ms +step:843/1680 train_time:73715ms step_avg:87.44ms +step:844/1680 train_time:73803ms step_avg:87.44ms +step:845/1680 train_time:73891ms step_avg:87.44ms +step:846/1680 train_time:73979ms step_avg:87.45ms +step:847/1680 train_time:74068ms step_avg:87.45ms +step:848/1680 train_time:74156ms step_avg:87.45ms +step:849/1680 train_time:74244ms step_avg:87.45ms +step:850/1680 train_time:74332ms step_avg:87.45ms +step:851/1680 train_time:74420ms step_avg:87.45ms +step:852/1680 train_time:74508ms step_avg:87.45ms +step:853/1680 train_time:74596ms step_avg:87.45ms +step:854/1680 train_time:74685ms step_avg:87.45ms +step:855/1680 train_time:74773ms step_avg:87.45ms +step:856/1680 train_time:74861ms step_avg:87.45ms +step:857/1680 train_time:74949ms step_avg:87.46ms +step:858/1680 train_time:75037ms step_avg:87.46ms +step:859/1680 train_time:75126ms step_avg:87.46ms +step:860/1680 train_time:75215ms step_avg:87.46ms +step:861/1680 train_time:75303ms step_avg:87.46ms +step:862/1680 train_time:75391ms step_avg:87.46ms +step:863/1680 train_time:75479ms step_avg:87.46ms +step:864/1680 train_time:75567ms step_avg:87.46ms +step:865/1680 train_time:75655ms step_avg:87.46ms +step:866/1680 train_time:75743ms step_avg:87.46ms +step:867/1680 train_time:75832ms step_avg:87.46ms +step:868/1680 train_time:75919ms step_avg:87.46ms +step:869/1680 train_time:76007ms step_avg:87.47ms +step:870/1680 train_time:76095ms step_avg:87.47ms +step:871/1680 train_time:76184ms step_avg:87.47ms +step:872/1680 train_time:76272ms step_avg:87.47ms +step:873/1680 train_time:76360ms step_avg:87.47ms +step:874/1680 train_time:76448ms step_avg:87.47ms +step:875/1680 train_time:76536ms step_avg:87.47ms +step:875/1680 val_loss:3.5187 train_time:76626ms step_avg:87.57ms +step:876/1680 
train_time:76646ms step_avg:87.50ms +step:877/1680 train_time:76719ms step_avg:87.48ms +step:878/1680 train_time:76813ms step_avg:87.49ms +step:879/1680 train_time:76903ms step_avg:87.49ms +step:880/1680 train_time:76991ms step_avg:87.49ms +step:881/1680 train_time:77078ms step_avg:87.49ms +step:882/1680 train_time:77165ms step_avg:87.49ms +step:883/1680 train_time:77252ms step_avg:87.49ms +step:884/1680 train_time:77339ms step_avg:87.49ms +step:885/1680 train_time:77426ms step_avg:87.49ms +step:886/1680 train_time:77514ms step_avg:87.49ms +step:887/1680 train_time:77603ms step_avg:87.49ms +step:888/1680 train_time:77692ms step_avg:87.49ms +step:889/1680 train_time:77782ms step_avg:87.49ms +step:890/1680 train_time:77872ms step_avg:87.50ms +step:891/1680 train_time:77961ms step_avg:87.50ms +step:892/1680 train_time:78050ms step_avg:87.50ms +step:893/1680 train_time:78137ms step_avg:87.50ms +step:894/1680 train_time:78224ms step_avg:87.50ms +step:895/1680 train_time:78312ms step_avg:87.50ms +step:896/1680 train_time:78399ms step_avg:87.50ms +step:897/1680 train_time:78487ms step_avg:87.50ms +step:898/1680 train_time:78575ms step_avg:87.50ms +step:899/1680 train_time:78664ms step_avg:87.50ms +step:900/1680 train_time:78754ms step_avg:87.50ms +step:901/1680 train_time:78842ms step_avg:87.51ms +step:902/1680 train_time:78931ms step_avg:87.51ms +step:903/1680 train_time:79019ms step_avg:87.51ms +step:904/1680 train_time:79108ms step_avg:87.51ms +step:905/1680 train_time:79196ms step_avg:87.51ms +step:906/1680 train_time:79283ms step_avg:87.51ms +step:907/1680 train_time:79371ms step_avg:87.51ms +step:908/1680 train_time:79459ms step_avg:87.51ms +step:909/1680 train_time:79546ms step_avg:87.51ms +step:910/1680 train_time:79634ms step_avg:87.51ms +step:911/1680 train_time:79724ms step_avg:87.51ms +step:912/1680 train_time:79812ms step_avg:87.51ms +step:913/1680 train_time:79902ms step_avg:87.52ms +step:914/1680 train_time:79990ms step_avg:87.52ms +step:915/1680 train_time:80078ms step_avg:87.52ms +step:916/1680 train_time:80166ms step_avg:87.52ms +step:917/1680 train_time:80255ms step_avg:87.52ms +step:918/1680 train_time:80342ms step_avg:87.52ms +step:919/1680 train_time:80430ms step_avg:87.52ms +step:920/1680 train_time:80518ms step_avg:87.52ms +step:921/1680 train_time:80605ms step_avg:87.52ms +step:922/1680 train_time:80694ms step_avg:87.52ms +step:923/1680 train_time:80783ms step_avg:87.52ms +step:924/1680 train_time:80872ms step_avg:87.52ms +step:925/1680 train_time:80961ms step_avg:87.53ms +step:926/1680 train_time:81049ms step_avg:87.53ms +step:927/1680 train_time:81136ms step_avg:87.53ms +step:928/1680 train_time:81224ms step_avg:87.53ms +step:929/1680 train_time:81312ms step_avg:87.53ms +step:930/1680 train_time:81401ms step_avg:87.53ms +step:931/1680 train_time:81489ms step_avg:87.53ms +step:932/1680 train_time:81576ms step_avg:87.53ms +step:933/1680 train_time:81664ms step_avg:87.53ms +step:934/1680 train_time:81752ms step_avg:87.53ms +step:935/1680 train_time:81840ms step_avg:87.53ms +step:936/1680 train_time:81928ms step_avg:87.53ms +step:937/1680 train_time:82017ms step_avg:87.53ms +step:938/1680 train_time:82105ms step_avg:87.53ms +step:939/1680 train_time:82193ms step_avg:87.53ms +step:940/1680 train_time:82282ms step_avg:87.53ms +step:941/1680 train_time:82369ms step_avg:87.53ms +step:942/1680 train_time:82457ms step_avg:87.53ms +step:943/1680 train_time:82545ms step_avg:87.53ms +step:944/1680 train_time:82633ms step_avg:87.53ms +step:945/1680 train_time:82721ms step_avg:87.54ms 
+step:946/1680 train_time:82809ms step_avg:87.54ms +step:947/1680 train_time:82898ms step_avg:87.54ms +step:948/1680 train_time:82986ms step_avg:87.54ms +step:949/1680 train_time:83075ms step_avg:87.54ms +step:950/1680 train_time:83163ms step_avg:87.54ms +step:951/1680 train_time:83251ms step_avg:87.54ms +step:952/1680 train_time:83339ms step_avg:87.54ms +step:953/1680 train_time:83426ms step_avg:87.54ms +step:954/1680 train_time:83514ms step_avg:87.54ms +step:955/1680 train_time:83601ms step_avg:87.54ms +step:956/1680 train_time:83690ms step_avg:87.54ms +step:957/1680 train_time:83778ms step_avg:87.54ms +step:958/1680 train_time:83866ms step_avg:87.54ms +step:959/1680 train_time:83954ms step_avg:87.54ms +step:960/1680 train_time:84043ms step_avg:87.54ms +step:961/1680 train_time:84131ms step_avg:87.55ms +step:962/1680 train_time:84220ms step_avg:87.55ms +step:963/1680 train_time:84307ms step_avg:87.55ms +step:964/1680 train_time:84396ms step_avg:87.55ms +step:965/1680 train_time:84484ms step_avg:87.55ms +step:966/1680 train_time:84573ms step_avg:87.55ms +step:967/1680 train_time:84662ms step_avg:87.55ms +step:968/1680 train_time:84748ms step_avg:87.55ms +step:969/1680 train_time:84837ms step_avg:87.55ms +step:970/1680 train_time:84924ms step_avg:87.55ms +step:971/1680 train_time:85013ms step_avg:87.55ms +step:972/1680 train_time:85101ms step_avg:87.55ms +step:973/1680 train_time:85189ms step_avg:87.55ms +step:974/1680 train_time:85277ms step_avg:87.55ms +step:975/1680 train_time:85366ms step_avg:87.55ms +step:976/1680 train_time:85454ms step_avg:87.55ms +step:977/1680 train_time:85542ms step_avg:87.56ms +step:978/1680 train_time:85630ms step_avg:87.56ms +step:979/1680 train_time:85718ms step_avg:87.56ms +step:980/1680 train_time:85807ms step_avg:87.56ms +step:981/1680 train_time:85895ms step_avg:87.56ms +step:982/1680 train_time:85983ms step_avg:87.56ms +step:983/1680 train_time:86071ms step_avg:87.56ms +step:984/1680 train_time:86160ms step_avg:87.56ms +step:985/1680 train_time:86247ms step_avg:87.56ms +step:986/1680 train_time:86335ms step_avg:87.56ms +step:987/1680 train_time:86423ms step_avg:87.56ms +step:988/1680 train_time:86512ms step_avg:87.56ms +step:989/1680 train_time:86600ms step_avg:87.56ms +step:990/1680 train_time:86688ms step_avg:87.56ms +step:991/1680 train_time:86776ms step_avg:87.56ms +step:992/1680 train_time:86865ms step_avg:87.57ms +step:993/1680 train_time:86953ms step_avg:87.57ms +step:994/1680 train_time:87041ms step_avg:87.57ms +step:995/1680 train_time:87129ms step_avg:87.57ms +step:996/1680 train_time:87217ms step_avg:87.57ms +step:997/1680 train_time:87305ms step_avg:87.57ms +step:998/1680 train_time:87393ms step_avg:87.57ms +step:999/1680 train_time:87482ms step_avg:87.57ms +step:1000/1680 train_time:87570ms step_avg:87.57ms +step:1000/1680 val_loss:3.4694 train_time:87660ms step_avg:87.66ms +step:1001/1680 train_time:87678ms step_avg:87.59ms +step:1002/1680 train_time:87753ms step_avg:87.58ms +step:1003/1680 train_time:87846ms step_avg:87.58ms +step:1004/1680 train_time:87935ms step_avg:87.59ms +step:1005/1680 train_time:88023ms step_avg:87.58ms +step:1006/1680 train_time:88110ms step_avg:87.58ms +step:1007/1680 train_time:88197ms step_avg:87.58ms +step:1008/1680 train_time:88285ms step_avg:87.58ms +step:1009/1680 train_time:88372ms step_avg:87.58ms +step:1010/1680 train_time:88459ms step_avg:87.58ms +step:1011/1680 train_time:88546ms step_avg:87.58ms +step:1012/1680 train_time:88635ms step_avg:87.58ms +step:1013/1680 train_time:88725ms step_avg:87.59ms 
+step:1014/1680 train_time:88815ms step_avg:87.59ms +step:1015/1680 train_time:88905ms step_avg:87.59ms +step:1016/1680 train_time:88993ms step_avg:87.59ms +step:1017/1680 train_time:89081ms step_avg:87.59ms +step:1018/1680 train_time:89168ms step_avg:87.59ms +step:1019/1680 train_time:89256ms step_avg:87.59ms +step:1020/1680 train_time:89343ms step_avg:87.59ms +step:1021/1680 train_time:89430ms step_avg:87.59ms +step:1022/1680 train_time:89518ms step_avg:87.59ms +step:1023/1680 train_time:89606ms step_avg:87.59ms +step:1024/1680 train_time:89695ms step_avg:87.59ms +step:1025/1680 train_time:89785ms step_avg:87.60ms +step:1026/1680 train_time:89874ms step_avg:87.60ms +step:1027/1680 train_time:89963ms step_avg:87.60ms +step:1028/1680 train_time:90051ms step_avg:87.60ms +step:1029/1680 train_time:90139ms step_avg:87.60ms +step:1030/1680 train_time:90226ms step_avg:87.60ms +step:1031/1680 train_time:90314ms step_avg:87.60ms +step:1032/1680 train_time:90402ms step_avg:87.60ms +step:1033/1680 train_time:90489ms step_avg:87.60ms +step:1034/1680 train_time:90577ms step_avg:87.60ms +step:1035/1680 train_time:90666ms step_avg:87.60ms +step:1036/1680 train_time:90754ms step_avg:87.60ms +step:1037/1680 train_time:90843ms step_avg:87.60ms +step:1038/1680 train_time:90931ms step_avg:87.60ms +step:1039/1680 train_time:91019ms step_avg:87.60ms +step:1040/1680 train_time:91108ms step_avg:87.60ms +step:1041/1680 train_time:91197ms step_avg:87.61ms +step:1042/1680 train_time:91284ms step_avg:87.60ms +step:1043/1680 train_time:91372ms step_avg:87.61ms +step:1044/1680 train_time:91460ms step_avg:87.61ms +step:1045/1680 train_time:91548ms step_avg:87.61ms +step:1046/1680 train_time:91636ms step_avg:87.61ms +step:1047/1680 train_time:91725ms step_avg:87.61ms +step:1048/1680 train_time:91814ms step_avg:87.61ms +step:1049/1680 train_time:91902ms step_avg:87.61ms +step:1050/1680 train_time:91990ms step_avg:87.61ms +step:1051/1680 train_time:92078ms step_avg:87.61ms +step:1052/1680 train_time:92166ms step_avg:87.61ms +step:1053/1680 train_time:92254ms step_avg:87.61ms +step:1054/1680 train_time:92342ms step_avg:87.61ms +step:1055/1680 train_time:92429ms step_avg:87.61ms +step:1056/1680 train_time:92518ms step_avg:87.61ms +step:1057/1680 train_time:92606ms step_avg:87.61ms +step:1058/1680 train_time:92694ms step_avg:87.61ms +step:1059/1680 train_time:92783ms step_avg:87.61ms +step:1060/1680 train_time:92871ms step_avg:87.61ms +step:1061/1680 train_time:92959ms step_avg:87.61ms +step:1062/1680 train_time:93047ms step_avg:87.61ms +step:1063/1680 train_time:93135ms step_avg:87.62ms +step:1064/1680 train_time:93223ms step_avg:87.62ms +step:1065/1680 train_time:93311ms step_avg:87.62ms +step:1066/1680 train_time:93398ms step_avg:87.62ms +step:1067/1680 train_time:93486ms step_avg:87.62ms +step:1068/1680 train_time:93574ms step_avg:87.62ms +step:1069/1680 train_time:93661ms step_avg:87.62ms +step:1070/1680 train_time:93750ms step_avg:87.62ms +step:1071/1680 train_time:93838ms step_avg:87.62ms +step:1072/1680 train_time:93927ms step_avg:87.62ms +step:1073/1680 train_time:94015ms step_avg:87.62ms +step:1074/1680 train_time:94103ms step_avg:87.62ms +step:1075/1680 train_time:94191ms step_avg:87.62ms +step:1076/1680 train_time:94279ms step_avg:87.62ms +step:1077/1680 train_time:94367ms step_avg:87.62ms +step:1078/1680 train_time:94455ms step_avg:87.62ms +step:1079/1680 train_time:94543ms step_avg:87.62ms +step:1080/1680 train_time:94631ms step_avg:87.62ms +step:1081/1680 train_time:94719ms step_avg:87.62ms +step:1082/1680 
train_time:94808ms step_avg:87.62ms +step:1083/1680 train_time:94896ms step_avg:87.62ms +step:1084/1680 train_time:94984ms step_avg:87.62ms +step:1085/1680 train_time:95072ms step_avg:87.62ms +step:1086/1680 train_time:95160ms step_avg:87.62ms +step:1087/1680 train_time:95249ms step_avg:87.63ms +step:1088/1680 train_time:95338ms step_avg:87.63ms +step:1089/1680 train_time:95425ms step_avg:87.63ms +step:1090/1680 train_time:95513ms step_avg:87.63ms +step:1091/1680 train_time:95601ms step_avg:87.63ms +step:1092/1680 train_time:95689ms step_avg:87.63ms +step:1093/1680 train_time:95777ms step_avg:87.63ms +step:1094/1680 train_time:95865ms step_avg:87.63ms +step:1095/1680 train_time:95954ms step_avg:87.63ms +step:1096/1680 train_time:96043ms step_avg:87.63ms +step:1097/1680 train_time:96131ms step_avg:87.63ms +step:1098/1680 train_time:96220ms step_avg:87.63ms +step:1099/1680 train_time:96310ms step_avg:87.63ms +step:1100/1680 train_time:96399ms step_avg:87.64ms +step:1101/1680 train_time:96488ms step_avg:87.64ms +step:1102/1680 train_time:96577ms step_avg:87.64ms +step:1103/1680 train_time:96665ms step_avg:87.64ms +step:1104/1680 train_time:96755ms step_avg:87.64ms +step:1105/1680 train_time:96843ms step_avg:87.64ms +step:1106/1680 train_time:96933ms step_avg:87.64ms +step:1107/1680 train_time:97022ms step_avg:87.64ms +step:1108/1680 train_time:97111ms step_avg:87.65ms +step:1109/1680 train_time:97200ms step_avg:87.65ms +step:1110/1680 train_time:97288ms step_avg:87.65ms +step:1111/1680 train_time:97377ms step_avg:87.65ms +step:1112/1680 train_time:97466ms step_avg:87.65ms +step:1113/1680 train_time:97555ms step_avg:87.65ms +step:1114/1680 train_time:97643ms step_avg:87.65ms +step:1115/1680 train_time:97734ms step_avg:87.65ms +step:1116/1680 train_time:97823ms step_avg:87.66ms +step:1117/1680 train_time:97912ms step_avg:87.66ms +step:1118/1680 train_time:98002ms step_avg:87.66ms +step:1119/1680 train_time:98091ms step_avg:87.66ms +step:1120/1680 train_time:98179ms step_avg:87.66ms +step:1121/1680 train_time:98267ms step_avg:87.66ms +step:1122/1680 train_time:98356ms step_avg:87.66ms +step:1123/1680 train_time:98445ms step_avg:87.66ms +step:1124/1680 train_time:98534ms step_avg:87.66ms +step:1125/1680 train_time:98623ms step_avg:87.66ms +step:1125/1680 val_loss:3.4160 train_time:98713ms step_avg:87.75ms +step:1126/1680 train_time:98731ms step_avg:87.68ms +step:1127/1680 train_time:98803ms step_avg:87.67ms +step:1128/1680 train_time:98893ms step_avg:87.67ms +step:1129/1680 train_time:98984ms step_avg:87.67ms +step:1130/1680 train_time:99072ms step_avg:87.67ms +step:1131/1680 train_time:99160ms step_avg:87.67ms +step:1132/1680 train_time:99248ms step_avg:87.68ms +step:1133/1680 train_time:99337ms step_avg:87.68ms +step:1134/1680 train_time:99424ms step_avg:87.68ms +step:1135/1680 train_time:99512ms step_avg:87.68ms +step:1136/1680 train_time:99602ms step_avg:87.68ms +step:1137/1680 train_time:99693ms step_avg:87.68ms +step:1138/1680 train_time:99783ms step_avg:87.68ms +step:1139/1680 train_time:99874ms step_avg:87.69ms +step:1140/1680 train_time:99963ms step_avg:87.69ms +step:1141/1680 train_time:100052ms step_avg:87.69ms +step:1142/1680 train_time:100141ms step_avg:87.69ms +step:1143/1680 train_time:100230ms step_avg:87.69ms +step:1144/1680 train_time:100318ms step_avg:87.69ms +step:1145/1680 train_time:100406ms step_avg:87.69ms +step:1146/1680 train_time:100494ms step_avg:87.69ms +step:1147/1680 train_time:100584ms step_avg:87.69ms +step:1148/1680 train_time:100673ms step_avg:87.69ms 
+step:1149/1680 train_time:100763ms step_avg:87.70ms +step:1150/1680 train_time:100853ms step_avg:87.70ms +step:1151/1680 train_time:100942ms step_avg:87.70ms +step:1152/1680 train_time:101031ms step_avg:87.70ms +step:1153/1680 train_time:101120ms step_avg:87.70ms +step:1154/1680 train_time:101208ms step_avg:87.70ms +step:1155/1680 train_time:101297ms step_avg:87.70ms +step:1156/1680 train_time:101385ms step_avg:87.70ms +step:1157/1680 train_time:101473ms step_avg:87.70ms +step:1158/1680 train_time:101562ms step_avg:87.70ms +step:1159/1680 train_time:101651ms step_avg:87.71ms +step:1160/1680 train_time:101741ms step_avg:87.71ms +step:1161/1680 train_time:101831ms step_avg:87.71ms +step:1162/1680 train_time:101921ms step_avg:87.71ms +step:1163/1680 train_time:102010ms step_avg:87.71ms +step:1164/1680 train_time:102099ms step_avg:87.71ms +step:1165/1680 train_time:102188ms step_avg:87.71ms +step:1166/1680 train_time:102276ms step_avg:87.72ms +step:1167/1680 train_time:102365ms step_avg:87.72ms +step:1168/1680 train_time:102453ms step_avg:87.72ms +step:1169/1680 train_time:102543ms step_avg:87.72ms +step:1170/1680 train_time:102632ms step_avg:87.72ms +step:1171/1680 train_time:102721ms step_avg:87.72ms +step:1172/1680 train_time:102810ms step_avg:87.72ms +step:1173/1680 train_time:102900ms step_avg:87.72ms +step:1174/1680 train_time:102989ms step_avg:87.72ms +step:1175/1680 train_time:103078ms step_avg:87.73ms +step:1176/1680 train_time:103166ms step_avg:87.73ms +step:1177/1680 train_time:103255ms step_avg:87.73ms +step:1178/1680 train_time:103343ms step_avg:87.73ms +step:1179/1680 train_time:103431ms step_avg:87.73ms +step:1180/1680 train_time:103520ms step_avg:87.73ms +step:1181/1680 train_time:103609ms step_avg:87.73ms +step:1182/1680 train_time:103698ms step_avg:87.73ms +step:1183/1680 train_time:103787ms step_avg:87.73ms +step:1184/1680 train_time:103876ms step_avg:87.73ms +step:1185/1680 train_time:103965ms step_avg:87.73ms +step:1186/1680 train_time:104054ms step_avg:87.74ms +step:1187/1680 train_time:104143ms step_avg:87.74ms +step:1188/1680 train_time:104232ms step_avg:87.74ms +step:1189/1680 train_time:104321ms step_avg:87.74ms +step:1190/1680 train_time:104410ms step_avg:87.74ms +step:1191/1680 train_time:104499ms step_avg:87.74ms +step:1192/1680 train_time:104587ms step_avg:87.74ms +step:1193/1680 train_time:104676ms step_avg:87.74ms +step:1194/1680 train_time:104765ms step_avg:87.74ms +step:1195/1680 train_time:104854ms step_avg:87.74ms +step:1196/1680 train_time:104943ms step_avg:87.75ms +step:1197/1680 train_time:105032ms step_avg:87.75ms +step:1198/1680 train_time:105121ms step_avg:87.75ms +step:1199/1680 train_time:105209ms step_avg:87.75ms +step:1200/1680 train_time:105298ms step_avg:87.75ms +step:1201/1680 train_time:105387ms step_avg:87.75ms +step:1202/1680 train_time:105476ms step_avg:87.75ms +step:1203/1680 train_time:105565ms step_avg:87.75ms +step:1204/1680 train_time:105654ms step_avg:87.75ms +step:1205/1680 train_time:105743ms step_avg:87.75ms +step:1206/1680 train_time:105831ms step_avg:87.75ms +step:1207/1680 train_time:105921ms step_avg:87.76ms +step:1208/1680 train_time:106010ms step_avg:87.76ms +step:1209/1680 train_time:106099ms step_avg:87.76ms +step:1210/1680 train_time:106188ms step_avg:87.76ms +step:1211/1680 train_time:106277ms step_avg:87.76ms +step:1212/1680 train_time:106366ms step_avg:87.76ms +step:1213/1680 train_time:106455ms step_avg:87.76ms +step:1214/1680 train_time:106544ms step_avg:87.76ms +step:1215/1680 train_time:106633ms step_avg:87.76ms 
+step:1216/1680 train_time:106722ms step_avg:87.77ms +step:1217/1680 train_time:106811ms step_avg:87.77ms +step:1218/1680 train_time:106901ms step_avg:87.77ms +step:1219/1680 train_time:106990ms step_avg:87.77ms +step:1220/1680 train_time:107080ms step_avg:87.77ms +step:1221/1680 train_time:107169ms step_avg:87.77ms +step:1222/1680 train_time:107257ms step_avg:87.77ms +step:1223/1680 train_time:107347ms step_avg:87.77ms +step:1224/1680 train_time:107436ms step_avg:87.77ms +step:1225/1680 train_time:107525ms step_avg:87.78ms +step:1226/1680 train_time:107614ms step_avg:87.78ms +step:1227/1680 train_time:107703ms step_avg:87.78ms +step:1228/1680 train_time:107792ms step_avg:87.78ms +step:1229/1680 train_time:107882ms step_avg:87.78ms +step:1230/1680 train_time:107971ms step_avg:87.78ms +step:1231/1680 train_time:108061ms step_avg:87.78ms +step:1232/1680 train_time:108151ms step_avg:87.78ms +step:1233/1680 train_time:108241ms step_avg:87.79ms +step:1234/1680 train_time:108330ms step_avg:87.79ms +step:1235/1680 train_time:108419ms step_avg:87.79ms +step:1236/1680 train_time:108508ms step_avg:87.79ms +step:1237/1680 train_time:108596ms step_avg:87.79ms +step:1238/1680 train_time:108685ms step_avg:87.79ms +step:1239/1680 train_time:108773ms step_avg:87.79ms +step:1240/1680 train_time:108863ms step_avg:87.79ms +step:1241/1680 train_time:108952ms step_avg:87.79ms +step:1242/1680 train_time:109041ms step_avg:87.79ms +step:1243/1680 train_time:109130ms step_avg:87.80ms +step:1244/1680 train_time:109219ms step_avg:87.80ms +step:1245/1680 train_time:109309ms step_avg:87.80ms +step:1246/1680 train_time:109397ms step_avg:87.80ms +step:1247/1680 train_time:109486ms step_avg:87.80ms +step:1248/1680 train_time:109575ms step_avg:87.80ms +step:1249/1680 train_time:109663ms step_avg:87.80ms +step:1250/1680 train_time:109752ms step_avg:87.80ms +step:1250/1680 val_loss:3.3777 train_time:109842ms step_avg:87.87ms +step:1251/1680 train_time:109859ms step_avg:87.82ms +step:1252/1680 train_time:109936ms step_avg:87.81ms +step:1253/1680 train_time:110031ms step_avg:87.81ms +step:1254/1680 train_time:110121ms step_avg:87.82ms +step:1255/1680 train_time:110210ms step_avg:87.82ms +step:1256/1680 train_time:110297ms step_avg:87.82ms +step:1257/1680 train_time:110385ms step_avg:87.82ms +step:1258/1680 train_time:110472ms step_avg:87.82ms +step:1259/1680 train_time:110560ms step_avg:87.82ms +step:1260/1680 train_time:110648ms step_avg:87.82ms +step:1261/1680 train_time:110735ms step_avg:87.82ms +step:1262/1680 train_time:110825ms step_avg:87.82ms +step:1263/1680 train_time:110916ms step_avg:87.82ms +step:1264/1680 train_time:111007ms step_avg:87.82ms +step:1265/1680 train_time:111097ms step_avg:87.82ms +step:1266/1680 train_time:111187ms step_avg:87.83ms +step:1267/1680 train_time:111275ms step_avg:87.83ms +step:1268/1680 train_time:111363ms step_avg:87.83ms +step:1269/1680 train_time:111452ms step_avg:87.83ms +step:1270/1680 train_time:111540ms step_avg:87.83ms +step:1271/1680 train_time:111628ms step_avg:87.83ms +step:1272/1680 train_time:111716ms step_avg:87.83ms +step:1273/1680 train_time:111805ms step_avg:87.83ms +step:1274/1680 train_time:111896ms step_avg:87.83ms +step:1275/1680 train_time:111987ms step_avg:87.83ms +step:1276/1680 train_time:112078ms step_avg:87.84ms +step:1277/1680 train_time:112168ms step_avg:87.84ms +step:1278/1680 train_time:112256ms step_avg:87.84ms +step:1279/1680 train_time:112344ms step_avg:87.84ms +step:1280/1680 train_time:112433ms step_avg:87.84ms +step:1281/1680 train_time:112521ms 
step_avg:87.84ms +step:1282/1680 train_time:112610ms step_avg:87.84ms +step:1283/1680 train_time:112698ms step_avg:87.84ms +step:1284/1680 train_time:112787ms step_avg:87.84ms +step:1285/1680 train_time:112876ms step_avg:87.84ms +step:1286/1680 train_time:112965ms step_avg:87.84ms +step:1287/1680 train_time:113055ms step_avg:87.84ms +step:1288/1680 train_time:113145ms step_avg:87.85ms +step:1289/1680 train_time:113234ms step_avg:87.85ms +step:1290/1680 train_time:113323ms step_avg:87.85ms +step:1291/1680 train_time:113412ms step_avg:87.85ms +step:1292/1680 train_time:113500ms step_avg:87.85ms +step:1293/1680 train_time:113589ms step_avg:87.85ms +step:1294/1680 train_time:113678ms step_avg:87.85ms +step:1295/1680 train_time:113766ms step_avg:87.85ms +step:1296/1680 train_time:113854ms step_avg:87.85ms +step:1297/1680 train_time:113943ms step_avg:87.85ms +step:1298/1680 train_time:114032ms step_avg:87.85ms +step:1299/1680 train_time:114122ms step_avg:87.85ms +step:1300/1680 train_time:114211ms step_avg:87.85ms +step:1301/1680 train_time:114300ms step_avg:87.86ms +step:1302/1680 train_time:114389ms step_avg:87.86ms +step:1303/1680 train_time:114478ms step_avg:87.86ms +step:1304/1680 train_time:114567ms step_avg:87.86ms +step:1305/1680 train_time:114655ms step_avg:87.86ms +step:1306/1680 train_time:114744ms step_avg:87.86ms +step:1307/1680 train_time:114832ms step_avg:87.86ms +step:1308/1680 train_time:114922ms step_avg:87.86ms +step:1309/1680 train_time:115011ms step_avg:87.86ms +step:1310/1680 train_time:115101ms step_avg:87.86ms +step:1311/1680 train_time:115190ms step_avg:87.86ms +step:1312/1680 train_time:115279ms step_avg:87.87ms +step:1313/1680 train_time:115369ms step_avg:87.87ms +step:1314/1680 train_time:115457ms step_avg:87.87ms +step:1315/1680 train_time:115546ms step_avg:87.87ms +step:1316/1680 train_time:115635ms step_avg:87.87ms +step:1317/1680 train_time:115724ms step_avg:87.87ms +step:1318/1680 train_time:115813ms step_avg:87.87ms +step:1319/1680 train_time:115902ms step_avg:87.87ms +step:1320/1680 train_time:115992ms step_avg:87.87ms +step:1321/1680 train_time:116082ms step_avg:87.87ms +step:1322/1680 train_time:116172ms step_avg:87.88ms +step:1323/1680 train_time:116262ms step_avg:87.88ms +step:1324/1680 train_time:116351ms step_avg:87.88ms +step:1325/1680 train_time:116440ms step_avg:87.88ms +step:1326/1680 train_time:116528ms step_avg:87.88ms +step:1327/1680 train_time:116616ms step_avg:87.88ms +step:1328/1680 train_time:116705ms step_avg:87.88ms +step:1329/1680 train_time:116794ms step_avg:87.88ms +step:1330/1680 train_time:116885ms step_avg:87.88ms +step:1331/1680 train_time:116974ms step_avg:87.88ms +step:1332/1680 train_time:117063ms step_avg:87.89ms +step:1333/1680 train_time:117152ms step_avg:87.89ms +step:1334/1680 train_time:117241ms step_avg:87.89ms +step:1335/1680 train_time:117330ms step_avg:87.89ms +step:1336/1680 train_time:117419ms step_avg:87.89ms +step:1337/1680 train_time:117508ms step_avg:87.89ms +step:1338/1680 train_time:117597ms step_avg:87.89ms +step:1339/1680 train_time:117685ms step_avg:87.89ms +step:1340/1680 train_time:117774ms step_avg:87.89ms +step:1341/1680 train_time:117863ms step_avg:87.89ms +step:1342/1680 train_time:117951ms step_avg:87.89ms +step:1343/1680 train_time:118040ms step_avg:87.89ms +step:1344/1680 train_time:118129ms step_avg:87.89ms +step:1345/1680 train_time:118218ms step_avg:87.89ms +step:1346/1680 train_time:118307ms step_avg:87.90ms +step:1347/1680 train_time:118397ms step_avg:87.90ms +step:1348/1680 train_time:118485ms 
step_avg:87.90ms +step:1349/1680 train_time:118574ms step_avg:87.90ms +step:1350/1680 train_time:118663ms step_avg:87.90ms +step:1351/1680 train_time:118752ms step_avg:87.90ms +step:1352/1680 train_time:118842ms step_avg:87.90ms +step:1353/1680 train_time:118933ms step_avg:87.90ms +step:1354/1680 train_time:119022ms step_avg:87.90ms +step:1355/1680 train_time:119111ms step_avg:87.90ms +step:1356/1680 train_time:119199ms step_avg:87.90ms +step:1357/1680 train_time:119288ms step_avg:87.91ms +step:1358/1680 train_time:119378ms step_avg:87.91ms +step:1359/1680 train_time:119467ms step_avg:87.91ms +step:1360/1680 train_time:119556ms step_avg:87.91ms +step:1361/1680 train_time:119644ms step_avg:87.91ms +step:1362/1680 train_time:119733ms step_avg:87.91ms +step:1363/1680 train_time:119823ms step_avg:87.91ms +step:1364/1680 train_time:119911ms step_avg:87.91ms +step:1365/1680 train_time:120000ms step_avg:87.91ms +step:1366/1680 train_time:120089ms step_avg:87.91ms +step:1367/1680 train_time:120178ms step_avg:87.91ms +step:1368/1680 train_time:120267ms step_avg:87.91ms +step:1369/1680 train_time:120356ms step_avg:87.92ms +step:1370/1680 train_time:120445ms step_avg:87.92ms +step:1371/1680 train_time:120534ms step_avg:87.92ms +step:1372/1680 train_time:120623ms step_avg:87.92ms +step:1373/1680 train_time:120711ms step_avg:87.92ms +step:1374/1680 train_time:120800ms step_avg:87.92ms +step:1375/1680 train_time:120889ms step_avg:87.92ms +step:1375/1680 val_loss:3.3433 train_time:120980ms step_avg:87.99ms +step:1376/1680 train_time:120998ms step_avg:87.93ms +step:1377/1680 train_time:121071ms step_avg:87.92ms +step:1378/1680 train_time:121163ms step_avg:87.93ms +step:1379/1680 train_time:121252ms step_avg:87.93ms +step:1380/1680 train_time:121340ms step_avg:87.93ms +step:1381/1680 train_time:121428ms step_avg:87.93ms +step:1382/1680 train_time:121516ms step_avg:87.93ms +step:1383/1680 train_time:121604ms step_avg:87.93ms +step:1384/1680 train_time:121692ms step_avg:87.93ms +step:1385/1680 train_time:121781ms step_avg:87.93ms +step:1386/1680 train_time:121870ms step_avg:87.93ms +step:1387/1680 train_time:121961ms step_avg:87.93ms +step:1388/1680 train_time:122052ms step_avg:87.93ms +step:1389/1680 train_time:122142ms step_avg:87.93ms +step:1390/1680 train_time:122231ms step_avg:87.94ms +step:1391/1680 train_time:122320ms step_avg:87.94ms +step:1392/1680 train_time:122409ms step_avg:87.94ms +step:1393/1680 train_time:122497ms step_avg:87.94ms +step:1394/1680 train_time:122585ms step_avg:87.94ms +step:1395/1680 train_time:122673ms step_avg:87.94ms +step:1396/1680 train_time:122762ms step_avg:87.94ms +step:1397/1680 train_time:122852ms step_avg:87.94ms +step:1398/1680 train_time:122941ms step_avg:87.94ms +step:1399/1680 train_time:123030ms step_avg:87.94ms +step:1400/1680 train_time:123120ms step_avg:87.94ms +step:1401/1680 train_time:123209ms step_avg:87.94ms +step:1402/1680 train_time:123297ms step_avg:87.94ms +step:1403/1680 train_time:123387ms step_avg:87.94ms +step:1404/1680 train_time:123475ms step_avg:87.95ms +step:1405/1680 train_time:123564ms step_avg:87.95ms +step:1406/1680 train_time:123653ms step_avg:87.95ms +step:1407/1680 train_time:123741ms step_avg:87.95ms +step:1408/1680 train_time:123829ms step_avg:87.95ms +step:1409/1680 train_time:123918ms step_avg:87.95ms +step:1410/1680 train_time:124008ms step_avg:87.95ms +step:1411/1680 train_time:124097ms step_avg:87.95ms +step:1412/1680 train_time:124187ms step_avg:87.95ms +step:1413/1680 train_time:124277ms step_avg:87.95ms +step:1414/1680 
train_time:124366ms step_avg:87.95ms +step:1415/1680 train_time:124456ms step_avg:87.95ms +step:1416/1680 train_time:124544ms step_avg:87.95ms +step:1417/1680 train_time:124634ms step_avg:87.96ms +step:1418/1680 train_time:124722ms step_avg:87.96ms +step:1419/1680 train_time:124811ms step_avg:87.96ms +step:1420/1680 train_time:124900ms step_avg:87.96ms +step:1421/1680 train_time:124989ms step_avg:87.96ms +step:1422/1680 train_time:125079ms step_avg:87.96ms +step:1423/1680 train_time:125168ms step_avg:87.96ms +step:1424/1680 train_time:125257ms step_avg:87.96ms +step:1425/1680 train_time:125346ms step_avg:87.96ms +step:1426/1680 train_time:125435ms step_avg:87.96ms +step:1427/1680 train_time:125524ms step_avg:87.96ms +step:1428/1680 train_time:125613ms step_avg:87.96ms +step:1429/1680 train_time:125702ms step_avg:87.97ms +step:1430/1680 train_time:125791ms step_avg:87.97ms +step:1431/1680 train_time:125880ms step_avg:87.97ms +step:1432/1680 train_time:125969ms step_avg:87.97ms +step:1433/1680 train_time:126059ms step_avg:87.97ms +step:1434/1680 train_time:126148ms step_avg:87.97ms +step:1435/1680 train_time:126238ms step_avg:87.97ms +step:1436/1680 train_time:126328ms step_avg:87.97ms +step:1437/1680 train_time:126417ms step_avg:87.97ms +step:1438/1680 train_time:126506ms step_avg:87.97ms +step:1439/1680 train_time:126594ms step_avg:87.97ms +step:1440/1680 train_time:126683ms step_avg:87.97ms +step:1441/1680 train_time:126772ms step_avg:87.97ms +step:1442/1680 train_time:126861ms step_avg:87.98ms +step:1443/1680 train_time:126950ms step_avg:87.98ms +step:1444/1680 train_time:127039ms step_avg:87.98ms +step:1445/1680 train_time:127128ms step_avg:87.98ms +step:1446/1680 train_time:127217ms step_avg:87.98ms +step:1447/1680 train_time:127306ms step_avg:87.98ms +step:1448/1680 train_time:127395ms step_avg:87.98ms +step:1449/1680 train_time:127485ms step_avg:87.98ms +step:1450/1680 train_time:127574ms step_avg:87.98ms +step:1451/1680 train_time:127664ms step_avg:87.98ms +step:1452/1680 train_time:127753ms step_avg:87.98ms +step:1453/1680 train_time:127842ms step_avg:87.98ms +step:1454/1680 train_time:127931ms step_avg:87.99ms +step:1455/1680 train_time:128020ms step_avg:87.99ms +step:1456/1680 train_time:128109ms step_avg:87.99ms +step:1457/1680 train_time:128198ms step_avg:87.99ms +step:1458/1680 train_time:128287ms step_avg:87.99ms +step:1459/1680 train_time:128376ms step_avg:87.99ms +step:1460/1680 train_time:128466ms step_avg:87.99ms +step:1461/1680 train_time:128555ms step_avg:87.99ms +step:1462/1680 train_time:128644ms step_avg:87.99ms +step:1463/1680 train_time:128734ms step_avg:87.99ms +step:1464/1680 train_time:128823ms step_avg:87.99ms +step:1465/1680 train_time:128911ms step_avg:87.99ms +step:1466/1680 train_time:129001ms step_avg:87.99ms +step:1467/1680 train_time:129090ms step_avg:88.00ms +step:1468/1680 train_time:129178ms step_avg:88.00ms +step:1469/1680 train_time:129268ms step_avg:88.00ms +step:1470/1680 train_time:129356ms step_avg:88.00ms +step:1471/1680 train_time:129446ms step_avg:88.00ms +step:1472/1680 train_time:129535ms step_avg:88.00ms +step:1473/1680 train_time:129626ms step_avg:88.00ms +step:1474/1680 train_time:129715ms step_avg:88.00ms +step:1475/1680 train_time:129804ms step_avg:88.00ms +step:1476/1680 train_time:129892ms step_avg:88.00ms +step:1477/1680 train_time:129981ms step_avg:88.00ms +step:1478/1680 train_time:130070ms step_avg:88.00ms +step:1479/1680 train_time:130160ms step_avg:88.01ms +step:1480/1680 train_time:130249ms step_avg:88.01ms +step:1481/1680 
train_time:130338ms step_avg:88.01ms +step:1482/1680 train_time:130427ms step_avg:88.01ms +step:1483/1680 train_time:130516ms step_avg:88.01ms +step:1484/1680 train_time:130606ms step_avg:88.01ms +step:1485/1680 train_time:130695ms step_avg:88.01ms +step:1486/1680 train_time:130783ms step_avg:88.01ms +step:1487/1680 train_time:130873ms step_avg:88.01ms +step:1488/1680 train_time:130961ms step_avg:88.01ms +step:1489/1680 train_time:131050ms step_avg:88.01ms +step:1490/1680 train_time:131139ms step_avg:88.01ms +step:1491/1680 train_time:131229ms step_avg:88.01ms +step:1492/1680 train_time:131318ms step_avg:88.01ms +step:1493/1680 train_time:131407ms step_avg:88.02ms +step:1494/1680 train_time:131496ms step_avg:88.02ms +step:1495/1680 train_time:131585ms step_avg:88.02ms +step:1496/1680 train_time:131673ms step_avg:88.02ms +step:1497/1680 train_time:131762ms step_avg:88.02ms +step:1498/1680 train_time:131851ms step_avg:88.02ms +step:1499/1680 train_time:131941ms step_avg:88.02ms +step:1500/1680 train_time:132030ms step_avg:88.02ms +step:1500/1680 val_loss:3.3135 train_time:132121ms step_avg:88.08ms +step:1501/1680 train_time:132139ms step_avg:88.03ms +step:1502/1680 train_time:132212ms step_avg:88.02ms +step:1503/1680 train_time:132304ms step_avg:88.03ms +step:1504/1680 train_time:132395ms step_avg:88.03ms +step:1505/1680 train_time:132483ms step_avg:88.03ms +step:1506/1680 train_time:132572ms step_avg:88.03ms +step:1507/1680 train_time:132660ms step_avg:88.03ms +step:1508/1680 train_time:132748ms step_avg:88.03ms +step:1509/1680 train_time:132836ms step_avg:88.03ms +step:1510/1680 train_time:132924ms step_avg:88.03ms +step:1511/1680 train_time:133012ms step_avg:88.03ms +step:1512/1680 train_time:133103ms step_avg:88.03ms +step:1513/1680 train_time:133193ms step_avg:88.03ms +step:1514/1680 train_time:133284ms step_avg:88.03ms +step:1515/1680 train_time:133374ms step_avg:88.04ms +step:1516/1680 train_time:133463ms step_avg:88.04ms +step:1517/1680 train_time:133551ms step_avg:88.04ms +step:1518/1680 train_time:133640ms step_avg:88.04ms +step:1519/1680 train_time:133728ms step_avg:88.04ms +step:1520/1680 train_time:133816ms step_avg:88.04ms +step:1521/1680 train_time:133904ms step_avg:88.04ms +step:1522/1680 train_time:133993ms step_avg:88.04ms +step:1523/1680 train_time:134082ms step_avg:88.04ms +step:1524/1680 train_time:134171ms step_avg:88.04ms +step:1525/1680 train_time:134262ms step_avg:88.04ms +step:1526/1680 train_time:134351ms step_avg:88.04ms +step:1527/1680 train_time:134441ms step_avg:88.04ms +step:1528/1680 train_time:134529ms step_avg:88.04ms +step:1529/1680 train_time:134617ms step_avg:88.04ms +step:1530/1680 train_time:134706ms step_avg:88.04ms +step:1531/1680 train_time:134794ms step_avg:88.04ms +step:1532/1680 train_time:134882ms step_avg:88.04ms +step:1533/1680 train_time:134971ms step_avg:88.04ms +step:1534/1680 train_time:135061ms step_avg:88.04ms +step:1535/1680 train_time:135149ms step_avg:88.05ms +step:1536/1680 train_time:135239ms step_avg:88.05ms +step:1537/1680 train_time:135329ms step_avg:88.05ms +step:1538/1680 train_time:135418ms step_avg:88.05ms +step:1539/1680 train_time:135507ms step_avg:88.05ms +step:1540/1680 train_time:135596ms step_avg:88.05ms +step:1541/1680 train_time:135686ms step_avg:88.05ms +step:1542/1680 train_time:135774ms step_avg:88.05ms +step:1543/1680 train_time:135863ms step_avg:88.05ms +step:1544/1680 train_time:135951ms step_avg:88.05ms +step:1545/1680 train_time:136040ms step_avg:88.05ms +step:1546/1680 train_time:136129ms step_avg:88.05ms 
+step:1547/1680 train_time:136218ms step_avg:88.05ms +step:1548/1680 train_time:136308ms step_avg:88.05ms +step:1549/1680 train_time:136398ms step_avg:88.06ms +step:1550/1680 train_time:136487ms step_avg:88.06ms +step:1551/1680 train_time:136575ms step_avg:88.06ms +step:1552/1680 train_time:136664ms step_avg:88.06ms +step:1553/1680 train_time:136753ms step_avg:88.06ms +step:1554/1680 train_time:136842ms step_avg:88.06ms +step:1555/1680 train_time:136930ms step_avg:88.06ms +step:1556/1680 train_time:137019ms step_avg:88.06ms +step:1557/1680 train_time:137108ms step_avg:88.06ms +step:1558/1680 train_time:137197ms step_avg:88.06ms +step:1559/1680 train_time:137286ms step_avg:88.06ms +step:1560/1680 train_time:137375ms step_avg:88.06ms +step:1561/1680 train_time:137464ms step_avg:88.06ms +step:1562/1680 train_time:137554ms step_avg:88.06ms +step:1563/1680 train_time:137642ms step_avg:88.06ms +step:1564/1680 train_time:137731ms step_avg:88.06ms +step:1565/1680 train_time:137820ms step_avg:88.06ms +step:1566/1680 train_time:137908ms step_avg:88.06ms +step:1567/1680 train_time:137998ms step_avg:88.06ms +step:1568/1680 train_time:138087ms step_avg:88.07ms +step:1569/1680 train_time:138176ms step_avg:88.07ms +step:1570/1680 train_time:138265ms step_avg:88.07ms +step:1571/1680 train_time:138354ms step_avg:88.07ms +step:1572/1680 train_time:138442ms step_avg:88.07ms +step:1573/1680 train_time:138531ms step_avg:88.07ms +step:1574/1680 train_time:138621ms step_avg:88.07ms +step:1575/1680 train_time:138709ms step_avg:88.07ms +step:1576/1680 train_time:138798ms step_avg:88.07ms +step:1577/1680 train_time:138887ms step_avg:88.07ms +step:1578/1680 train_time:138977ms step_avg:88.07ms +step:1579/1680 train_time:139066ms step_avg:88.07ms +step:1580/1680 train_time:139155ms step_avg:88.07ms +step:1581/1680 train_time:139244ms step_avg:88.07ms +step:1582/1680 train_time:139333ms step_avg:88.07ms +step:1583/1680 train_time:139422ms step_avg:88.07ms +step:1584/1680 train_time:139511ms step_avg:88.07ms +step:1585/1680 train_time:139600ms step_avg:88.08ms +step:1586/1680 train_time:139688ms step_avg:88.08ms +step:1587/1680 train_time:139777ms step_avg:88.08ms +step:1588/1680 train_time:139866ms step_avg:88.08ms +step:1589/1680 train_time:139956ms step_avg:88.08ms +step:1590/1680 train_time:140045ms step_avg:88.08ms +step:1591/1680 train_time:140134ms step_avg:88.08ms +step:1592/1680 train_time:140223ms step_avg:88.08ms +step:1593/1680 train_time:140312ms step_avg:88.08ms +step:1594/1680 train_time:140400ms step_avg:88.08ms +step:1595/1680 train_time:140490ms step_avg:88.08ms +step:1596/1680 train_time:140579ms step_avg:88.08ms +step:1597/1680 train_time:140667ms step_avg:88.08ms +step:1598/1680 train_time:140757ms step_avg:88.08ms +step:1599/1680 train_time:140846ms step_avg:88.08ms +step:1600/1680 train_time:140935ms step_avg:88.08ms +step:1601/1680 train_time:141024ms step_avg:88.08ms +step:1602/1680 train_time:141112ms step_avg:88.09ms +step:1603/1680 train_time:141202ms step_avg:88.09ms +step:1604/1680 train_time:141291ms step_avg:88.09ms +step:1605/1680 train_time:141380ms step_avg:88.09ms +step:1606/1680 train_time:141469ms step_avg:88.09ms +step:1607/1680 train_time:141558ms step_avg:88.09ms +step:1608/1680 train_time:141646ms step_avg:88.09ms +step:1609/1680 train_time:141736ms step_avg:88.09ms +step:1610/1680 train_time:141825ms step_avg:88.09ms +step:1611/1680 train_time:141914ms step_avg:88.09ms +step:1612/1680 train_time:142003ms step_avg:88.09ms +step:1613/1680 train_time:142091ms step_avg:88.09ms 
+step:1614/1680 train_time:142181ms step_avg:88.09ms +step:1615/1680 train_time:142270ms step_avg:88.09ms +step:1616/1680 train_time:142360ms step_avg:88.09ms +step:1617/1680 train_time:142449ms step_avg:88.09ms +step:1618/1680 train_time:142538ms step_avg:88.10ms +step:1619/1680 train_time:142627ms step_avg:88.10ms +step:1620/1680 train_time:142716ms step_avg:88.10ms +step:1621/1680 train_time:142806ms step_avg:88.10ms +step:1622/1680 train_time:142895ms step_avg:88.10ms +step:1623/1680 train_time:142984ms step_avg:88.10ms +step:1624/1680 train_time:143073ms step_avg:88.10ms +step:1625/1680 train_time:143161ms step_avg:88.10ms +step:1625/1680 val_loss:3.2898 train_time:143251ms step_avg:88.15ms +step:1626/1680 train_time:143269ms step_avg:88.11ms +step:1627/1680 train_time:143343ms step_avg:88.10ms +step:1628/1680 train_time:143438ms step_avg:88.11ms +step:1629/1680 train_time:143526ms step_avg:88.11ms +step:1630/1680 train_time:143615ms step_avg:88.11ms +step:1631/1680 train_time:143704ms step_avg:88.11ms +step:1632/1680 train_time:143792ms step_avg:88.11ms +step:1633/1680 train_time:143880ms step_avg:88.11ms +step:1634/1680 train_time:143968ms step_avg:88.11ms +step:1635/1680 train_time:144056ms step_avg:88.11ms +step:1636/1680 train_time:144144ms step_avg:88.11ms +step:1637/1680 train_time:144234ms step_avg:88.11ms +step:1638/1680 train_time:144326ms step_avg:88.11ms +step:1639/1680 train_time:144417ms step_avg:88.11ms +step:1640/1680 train_time:144507ms step_avg:88.11ms +step:1641/1680 train_time:144597ms step_avg:88.11ms +step:1642/1680 train_time:144685ms step_avg:88.12ms +step:1643/1680 train_time:144773ms step_avg:88.12ms +step:1644/1680 train_time:144861ms step_avg:88.12ms +step:1645/1680 train_time:144950ms step_avg:88.12ms +step:1646/1680 train_time:145038ms step_avg:88.12ms +step:1647/1680 train_time:145126ms step_avg:88.12ms +step:1648/1680 train_time:145215ms step_avg:88.12ms +step:1649/1680 train_time:145304ms step_avg:88.12ms +step:1650/1680 train_time:145394ms step_avg:88.12ms +step:1651/1680 train_time:145484ms step_avg:88.12ms +step:1652/1680 train_time:145574ms step_avg:88.12ms +step:1653/1680 train_time:145662ms step_avg:88.12ms +step:1654/1680 train_time:145750ms step_avg:88.12ms +step:1655/1680 train_time:145840ms step_avg:88.12ms +step:1656/1680 train_time:145928ms step_avg:88.12ms +step:1657/1680 train_time:146017ms step_avg:88.12ms +step:1658/1680 train_time:146106ms step_avg:88.12ms +step:1659/1680 train_time:146195ms step_avg:88.12ms +step:1660/1680 train_time:146283ms step_avg:88.12ms +step:1661/1680 train_time:146372ms step_avg:88.12ms +step:1662/1680 train_time:146461ms step_avg:88.12ms +step:1663/1680 train_time:146552ms step_avg:88.12ms +step:1664/1680 train_time:146641ms step_avg:88.13ms +step:1665/1680 train_time:146730ms step_avg:88.13ms +step:1666/1680 train_time:146820ms step_avg:88.13ms +step:1667/1680 train_time:146908ms step_avg:88.13ms +step:1668/1680 train_time:146997ms step_avg:88.13ms +step:1669/1680 train_time:147085ms step_avg:88.13ms +step:1670/1680 train_time:147174ms step_avg:88.13ms +step:1671/1680 train_time:147262ms step_avg:88.13ms +step:1672/1680 train_time:147351ms step_avg:88.13ms +step:1673/1680 train_time:147440ms step_avg:88.13ms +step:1674/1680 train_time:147530ms step_avg:88.13ms +step:1675/1680 train_time:147621ms step_avg:88.13ms +step:1676/1680 train_time:147711ms step_avg:88.13ms +step:1677/1680 train_time:147800ms step_avg:88.13ms +step:1678/1680 train_time:147889ms step_avg:88.13ms +step:1679/1680 train_time:147977ms 
step_avg:88.13ms
+step:1680/1680 train_time:148066ms step_avg:88.13ms
+step:1680/1680 val_loss:3.2791 train_time:148156ms step_avg:88.19ms
+peak memory allocated: 30760 MiB reserved: 46114 MiB
diff --git a/records/092725_BF16CE/550ba6aa-d6a7-4a20-8303-f2b8d93c5f52.txt b/records/092725_BF16CE/550ba6aa-d6a7-4a20-8303-f2b8d93c5f52.txt
new file mode 100644
index 000000000..7ea1f1d19
--- /dev/null
+++ b/records/092725_BF16CE/550ba6aa-d6a7-4a20-8303-f2b8d93c5f52.txt
@@ -0,0 +1,3206 @@
+import os
+import sys
+
+with open(sys.argv[0]) as f:
+    code = f.read()  # read the code of this file ASAP, for logging
+import copy
+import glob
+import math
+import threading
+import time
+import uuid
+from dataclasses import dataclass
+from itertools import accumulate
+from pathlib import Path
+
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+import torch
+
+torch.empty(
+    1, device="cuda", requires_grad=True
+).backward()  # prevents a bug on some systems
+import torch._dynamo as dynamo
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# torch._inductor.config.coordinate_descent_tuning = True  # we have banned this flag for new records because it causes compilation to take 30min
+import triton
+import triton.language as tl
+from flash_attn_interface import flash_attn_varlen_func
+from torch import Tensor, nn
+
+dynamo.config.recompile_limit = 64
+
+# -----------------------------------------------------------------------------
+# Custom operators: FP8 matmul by @YouJiacheng
+
+
+@torch.library.custom_op("nanogpt::mm", mutates_args=())
+def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
+    @torch.compile
+    def impl(x: Tensor, w: Tensor):
+        assert x.is_contiguous() and w.is_contiguous()
+        x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
+        w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
+        out = torch._scaled_mm(
+            x_f8,
+            w_f8.T,
+            out_dtype=torch.bfloat16,
+            scale_a=x.new_tensor(x_s, dtype=torch.float32),
+            scale_b=x.new_tensor(w_s, dtype=torch.float32),
+            use_fast_accum=True,
+        )
+        return out, x_f8, w_f8
+
+    return impl(x, w)
+
+@mm_op.register_fake
+def _(x: Tensor, w: Tensor, *_):
+    assert x.ndim == w.ndim == 2
+    assert x.shape[1] == w.shape[1]
+    assert x.device == w.device
+    assert x.is_contiguous() and w.is_contiguous()
+    return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)
+
+@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
+def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]:
+    @torch.compile
+    def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
+        assert grad.is_contiguous()
+        x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
+        w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
+        grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
+        grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
+        grad_x = torch._scaled_mm(
+            grad_f8,
+            w_f8.T.contiguous().T,
+            out_dtype=torch.bfloat16,
+            scale_a=grad_inv_s,
+            scale_b=w_inv_s,
+            use_fast_accum=False,
+        )
+        # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
+        grad_w = torch._scaled_mm(
+            x_f8.T.contiguous(),
+            grad_f8.T.contiguous().T,
+            out_dtype=torch.float32,
+            scale_a=x_inv_s,
+            scale_b=grad_inv_s,
+            use_fast_accum=False,
+        ).T
+        return grad_x, grad_w
+
+    return impl(g, x_f8, w_f8)
+
+@mm_backward_op.register_fake
+def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
+    return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32)
+
+def backward(ctx, grad_out: Tensor, *_):
+    x_f8, w_f8 = ctx.saved_tensors
+    x_s, w_s, grad_s = ctx.scales
+    grad_x, grad_w = torch.ops.nanogpt.mm_backward(
+        grad_out, x_f8, w_f8, x_s, w_s, grad_s
+    )
+    return grad_x, grad_w, None, None, None
+
+def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
+    *_, x_s, w_s, grad_s = inputs
+    _, x_f8, w_f8 = output
+    ctx.save_for_backward(x_f8, w_f8)
+    ctx.scales = x_s, w_s, grad_s
+    ctx.set_materialize_grads(False)
+
+mm_op.register_autograd(backward, setup_context=setup_context)
+
+# -----------------------------------------------------------------------------
+# Triton kernel for symmetric matrix multiplication by @byronxu99
+
+def _get_autotune_configs():
+    return [
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": bm,
+                "BLOCK_SIZE_N": bn,
+                "BLOCK_SIZE_K": bk,
+                "GROUP_SIZE_M": 8,
+                "LOWER_UPPER": 1,
+            },
+            num_stages=stages,
+            num_warps=warps,
+        )
+        for bm in [64, 128]
+        for bn in [64, 128, 256]
+        for bk in [64, 128]
+        for stages, warps in [(3, 4), (3, 8), (4, 4)]
+        if bm // bn <= 2 and bn // bm <= 2
+    ]
+
+@triton.jit
+def _pid_to_block(
+    pid,
+    M,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+):
+    # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(M, BLOCK_SIZE_N)
+
+    # Map PID to a single matrix in batch
+    batch_idx = pid // (num_pid_m * num_pid_n)
+    pid = pid % (num_pid_m * num_pid_n)
+
+    # Map PID to 2D grid of blocks
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
+
+    m_idx = pid_m * BLOCK_SIZE_M
+    n_idx = pid_n * BLOCK_SIZE_N
+    return batch_idx, m_idx, n_idx
+
+@triton.autotune(
+    configs=_get_autotune_configs(),
+    key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
+)
+@triton.jit
+def ns_line_1_kernel(
+    A_ptr, C_ptr,
+    M, K,
+    a_stride_b, a_stride_r, a_stride_c,
+    c_stride_b, c_stride_r, c_stride_c,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    LOWER_UPPER: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    batch_idx, m_idx, n_idx = _pid_to_block(
+        pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
+    )
+
+    # Skip blocks that don't need to be computed
+    skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx)
+    skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx)
+    if skip_block_below_diag or skip_block_above_diag:
+        return
+
+    # Index into one matrix of batch
+    A_ptr += batch_idx * a_stride_b
+    C_ptr += batch_idx * c_stride_b
+
+    # Create pointer arrays for A and A.T
+    offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c)
+    at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r)
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+    # Accumulate over blocks of K
+    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        accumulator = tl.dot(a, at, accumulator)
+        a_ptrs += BLOCK_SIZE_K * a_stride_c
+        at_ptrs += BLOCK_SIZE_K * a_stride_c
+
+    out_dtype = C_ptr.dtype.element_ty
+    output = accumulator.to(out_dtype)
+
+    # Store block of C
+    offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c)
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
+    tl.store(c_ptrs, output, mask=c_mask)
+
+    # Store block of C mirrored across the diagonal
+    c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c)
+    c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
+    tl.store(c_ptrs_t, output.T, mask=c_mask_t)
+
+def ns_line_1(A: torch.Tensor, out: torch.Tensor):
+    """
+    Launch Triton kernel to compute C = A @ A.T
+    """
+    assert A.ndim == 2 or A.ndim == 3
+    M, K = A.shape[-2:]
+    assert out.size(-2) == M, "Output matrix has incorrect shape"
+    assert out.size(-1) == M, "Output matrix has incorrect shape"
+
+    batch_size = A.size(0) if A.ndim == 3 else 1
+    input_batch_stride = A.stride(0) if A.ndim == 3 else 0
+    output_batch_stride = out.stride(0) if out.ndim == 3 else 0
+
+    grid = lambda meta: (
+        batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]),
+    )
+    ns_line_1_kernel[grid](
+        A_ptr=A,
+        C_ptr=out,
+        M=M,
+        K=K,
+        a_stride_b=input_batch_stride,
+        a_stride_r=A.stride(-2),
+        a_stride_c=A.stride(-1),
+        c_stride_b=output_batch_stride,
+        c_stride_r=out.stride(-2),
+        c_stride_c=out.stride(-1),
+    )
+    return out
+
+@triton.autotune(
+    configs=_get_autotune_configs(),
+    key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
+)
+@triton.jit
+def ns_line_2_kernel(
+    A_ptr, C_ptr,
+    M,
+    a_stride_b, a_stride_r, a_stride_c,
+    c_stride_b, c_stride_r, c_stride_c,
+    alpha, beta,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    LOWER_UPPER: tl.constexpr,
+):
+    # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A
+    # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels
+    pid = tl.program_id(axis=0)
+    batch_idx, m_idx, n_idx = _pid_to_block(
+        pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
+    )
+
+    # Skip blocks that don't need to be computed
+    skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx)
+    skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx)
+    if skip_block_below_diag or skip_block_above_diag:
+        return
+
+    # Index into one matrix of batch
+    A_ptr += batch_idx * a_stride_b
+    C_ptr += batch_idx * c_stride_b
+
+    # Create pointer arrays for A and A.T
+    offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c)
+    at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r)
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+    # Accumulate over blocks of K
+    for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0)
+        at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0)
+        accumulator = tl.dot(a, at, accumulator)
+        a_ptrs += BLOCK_SIZE_K * a_stride_c
+        at_ptrs += BLOCK_SIZE_K * a_stride_c
+
+    # Load block of A to add (corresponds to the current block of C)
+    offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M)
+    offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N)
+    a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c)
+    a_add_mask = (offs_am[:, None] < M)
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+    In practice, though, small 1D params train efficiently here:
+    NS approximately performs a magnitude normalization of the grad, and
+    this hyper-optimized class has faster execution time than the current impl of Adam for small params.
+
+    Custom distributed sizing:
+    The model stores all attn and mlp weights in the same shape, and then updates the view as
+    needed on the forward pass. This enables attn and mlp weights to be contained within the same
+    dist.reduce_scatter_tensor() call. The model architecture has been customized to enable
+    (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn.
+    The scheduling is:
+    1. reduce scatter smear_gate (1 param, 7 padding params)
+    2. reduce scatter attn_gate (10 params, 6 padding params)
+    3. reduce scatter attn/mlp round 1 (10 attn params, 6 mlp params)
+    4. reduce scatter attn/mlp round 2 (16 mlp params)
+    5. wait on step 1, then compute NS of 1 and schedule all gather
+    6. wait on step 2, then compute NS of 2 and schedule all gather
+    7. wait on step 3, then compute NS of 3 and schedule all gather
+       GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP]
+       GPUs that receive params of type attn reshape before NS
+    8. wait on step 4, then compute NS of 4 and schedule all gather
+    9. wait for each all gather to complete and update params
+    Empirically, leading with small params provides an additional 0.2s improvement.
+    """
+    def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
+        # custom sizing requires 8 GPUs
+        if custom_sizing and dist.get_world_size() == 8:
+            param_groups = self.generate_custom_param_groups(params)
+        else:
+            param_groups = self.generate_standard_param_groups(params)
+        super().__init__(param_groups, defaults)
+
+    def generate_standard_param_groups(self, params):
+        """
+        Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules.
+        Creates one param group per size, while giving attn its own param group for the resize op.
+        """
+        params = list(params)
+        param_groups = []
+        attn_subset = [p for p in params if p.module == 'attn']
+        non_attn_subset = [p for p in params if p.module != 'attn']
+        param_groups.append(dict(params=attn_subset))
+
+        sizes = {p.shape for p in non_attn_subset}
+        for size in sizes:
+            group_params = [p for p in non_attn_subset if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        return param_groups
+
+    def generate_custom_param_groups(self, params):
+        """
+        Implementation requires that a single GPU does not receive both attn
+        and mlp params when a param group is split across GPUs.
+        """
+        module_ranks = {
+            'smear_gate': 1, # 1 param
+            'attn_gate': 2, # 10 params
+            'attn': 3, # 10 params
+            'mlp': 4, # 22 params
+        }
+        params = list(params)
+        params.sort(key=lambda x: module_ranks.get(x.module))
+        idx = 0
+        group_sizes = [1, 10, 16, 16]
+        assert len(params) == sum(group_sizes)
+        param_groups = []
+        for size in group_sizes:
+            group_params = params[idx:idx + size]
+            param_groups.append(dict(params=group_params))
+            idx += size
+        return param_groups
+
+    @torch.no_grad()
+    def step(self):
+        # Efficient systems-wise implementation of step developed by @YouJiacheng,
+        # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad,
+        # @ryanyang0, and @vagrawal.
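+        # Worked example of the padding arithmetic below (illustration only; assumes
+        # world_size == 8 and the 10-param attn/mlp round-1 group from the docstring):
+        #   padded_num_params = (10 + 8 - 1) // 8 * 8   # == 16
+        #   chunk_size = padded_num_params // 8         # == 2 matrices per rank
+        # so reduce_scatter_tensor hands every rank an equal 2-matrix chunk, with 6
+        # zero-gradient pads appended to the stack.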
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
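+                    # (Weight decay and the NS update are applied to this contiguous
+                    # buffer rather than to p itself, because the buffer is what
+                    # dist.all_gather_into_tensor later broadcasts back to every rank.)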
+                    updated_param_chunk[i].copy_(p)
+
+                    # Apply weight decay directly to the buffer.
+                    updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+            original_shape = batched_update_grads.shape
+            # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O
+            module_idx = start_idx if start_idx < len(params) else 0
+            if params[module_idx].module == 'attn':
+                batched_update_grads = batched_update_grads.view(-1, original_shape[-2], original_shape[-1] // 4)
+            v_chunk = newton_schulz_triton(batched_update_grads).reshape(original_shape)
+
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+                if param_idx >= len(params): # Skip padded entries again.
+                    continue
+
+                # Add the computed zeropower update to the parameter in the buffer.
+                updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val)
+
+            stacked_params = torch.empty(
+                (info["padded_num_params"], *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+            gather_future = dist.all_gather_into_tensor(
+                stacked_params, updated_param_chunk, async_op=True
+            ).get_future()
+
+            all_gather_infos.append(
+                {
+                    "gather_future": gather_future,
+                    "stacked_params": stacked_params,
+                    "orig_params": params,
+                }
+            )
+
+        # Final pass: wait for all_gather to complete and copy results back into original parameter tensors.
+        for info in all_gather_infos:
+            info["gather_future"].wait()
+            stacked_params = info["stacked_params"]
+            orig_params = info["orig_params"]
+
+            unstacked_params = torch.unbind(stacked_params)
+            for i, p in enumerate(orig_params):
+                p.copy_(unstacked_params[i], non_blocking=True)
+
+
+class DistAdam(torch.optim.Optimizer):
+    def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        params = list(params)
+        sizes = {p.shape for p in params}
+        # create one buffer per unique parameter-size
+        param_groups = []
+        for size in sizes:
+            group_params = [p for p in params if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        super().__init__(param_groups, defaults)
+        # DistributedAdam implementation by @vagrawal
+
+    @torch.compile
+    @torch.no_grad()
+    def step(self):
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        reduce_scatter_futures: list[torch.Future] = []
+        all_gather_futures: list[torch.Future] = []
+        grad_slices = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            for base_i in range(len(params)):
+                grad = params[base_i].grad
+                rank_size = grad.shape[0] // world_size
+                grad_slice = torch.empty_like(grad[:rank_size])
+                reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
+                grad_slices.append(grad_slice)
+
+        idx = 0
+        for group in self.param_groups:
+            beta1, beta2 = group['betas']
+            eps = group['eps']
+            wd = group['weight_decay']
+            params = group['params']
+            for base in range(len(params)):
+                reduce_scatter_futures[idx].wait()
+                p = params[base]
+                rank_size = p.shape[0] // world_size
+                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
+                lr = group['lr'] * getattr(p, "lr_mul", 1.0)
+                state = self.state[p]
+                g_slice = grad_slices[idx]
+                # State init
+                if not state:
+                    state["step"] = torch.tensor(
+                        0, dtype=torch.int64, device=p.device
+                    )
+                    state["exp_avg"] = torch.zeros(
+                        p_slice.shape,
+                        dtype=torch.bfloat16,
+                        device=p_slice.device,
+                    )
+                    state["exp_avg_sq"] = torch.zeros(
+                        p_slice.shape,
+                        dtype=torch.bfloat16,
+                        device=p_slice.device,
+                    )
+                exp_avg = state["exp_avg"]
+                exp_avg_sq = 
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module='attn' + with torch.no_grad(): + self.qkvo_w.view(4,self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4,self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4,self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4,self.hdim, self.dim)[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module='mlp' + self.c_proj.module='mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
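+        # Hedged reading of the FP8 scales below (interpretation, not an authorial note):
+        # mm_op divides by these scales before casting, and float8_e4m3fn's max normal
+        # value is 448, so x_s=(model_dim**0.5)/448, w_s=2**-9 and grad_s=1/448 appear
+        # tuned to spread unit-RMS activations, the near-zero head weights, and the
+        # e5m2 gradients across their respective FP8 ranges.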
+        use_fp8 = not os.environ.get("DISABLE_FP8", False)
+        self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448)
+        self.lm_head.weight.detach().zero_() # @Grad62304977
+        # Add learnable skip connection weights for decoder layers
+        assert num_layers % 2 == 0
+        pad = (-num_layers * 6) % dist.get_world_size()
+        self.scalars = nn.Parameter(
+            torch.cat(
+                [
+                    -1.5
+                    * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18
+                    *[
+                        torch.tensor([1.0, 0.0]) for _ in range(num_layers)
+                    ], # block lambdas
+                    *[
+                        torch.tensor([0.5, 0.5]) for _ in range(num_layers)
+                    ], # SA lambdas
+                    torch.zeros(num_layers), # extra zero params for smear_lambda
+                    torch.ones(pad),
+                ]
+            )
+        )
+        # set learning rates
+        for param in self.embed.parameters():
+            param.lr_mul = 75.
+        for param in self.value_embeds.parameters():
+            param.lr_mul = 75.
+        self.lm_head.weight.lr_mul = 1.0
+        self.scalars.lr_mul = 5.0
+
+    def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int):
+        assert input_seq.ndim == 1
+
+        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
+        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        short_bm = ws_short * args.block_size
+        long_bm = ws_long * args.block_size
+        bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = self.embed(input_seq)
+
+        # smear token embed forward 1 position @classiclarryd
+        smear_lambda = self.scalars[5 * len(self.blocks)]
+        smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)]))
+        x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]])
+        x = x0 = norm(x[None])
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        # skip layer zero
+        for i in range(1, len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                cos=self.yarn.cos,
+                sin=self.yarn.sin,
+                attn_scale=self.yarn.attn_scale
+            )
+            if i >= n and i < 11:
+                gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1)
+                x = x + gate * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x)
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0)
+        logits_for_loss = logits.float() if not self.training else logits
+        loss = F.cross_entropy(
+            logits_for_loss.view(-1, logits_for_loss.size(-1)),
+            target_seq,
+            reduction="sum" if self.training else "mean",
+        )
+        return loss
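+
+# Quick illustration (not part of training) of the softcap identity noted in GPT.forward:
+# 2*sigmoid(2x) = tanh(x)+1, so 30*sigmoid(logits/7.5) == 15*(tanh(logits/15) + 1),
+# i.e. a tanh cap at ±15 shifted into (0, 30).
+_softcap_x = torch.tensor(3.7)
+assert torch.allclose(30 * torch.sigmoid(_softcap_x / 7.5), 15 * (torch.tanh(_softcap_x / 15) + 1))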
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan the first 4 million tokens, then kick off an async thread to scan the rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS tokens after index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading the next shard and indexing BOS tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
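+
+# Aside (illustration only): the generator below uses Python's .send() protocol so the
+# training loop can push new (num_tokens, max_seq_len, grad_accum_steps) mid-run.
+# A toy version of the same pattern:
+#     def _toy(n):
+#         while True:
+#             new_n = yield n
+#             if new_n is not None:
+#                 n = new_n
+#     g = _toy(4); next(g)  # -> 4
+#     g.send(8)             # -> 8; later yields use n == 8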
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning of Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document was too long; account for the _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40 # number of iterations to continue training at the final cooldown and window size
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 
0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = min(0.9999, step / args.num_iterations)
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+def get_ws(step: int):
+    if step == args.num_iterations + args.iteration_extension:
+        return args.ws_validate // 2, args.ws_validate
+    x = min(step / (1 + args.num_iterations), 0.9999)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws_long = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws_long > ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    elif new_ws_long < ws_long:
+        model.yarn.reset() # RoPE state must be reset before shrinking the window back down
+        ws_long = new_ws_long
+    model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+
+# re-initialize the model and optimizer states so the warmup steps don't count as training
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del initial_state
+model.yarn.reset()
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+
+########################################
+#        Training and validation       #
+########################################
+
+train_steps = args.num_iterations + args.iteration_extension
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+ws_short, ws_long = get_ws(0)
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    new_ws_short, new_ws_long = get_ws(step)
+    if new_ws_long > ws_long: # YaRN-extend RoPE whenever the long window grows
+        model.yarn.apply(ws_long, new_ws_long)
+    ws_short, ws_long = new_ws_short, new_ws_long
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        if last_step:
+            ws_long = args.ws_long_validate
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = torch.zeros((), device=device, dtype=torch.float32)
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Sep 27 12:58:57 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 27C P0 123W / 700W | 5856MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 25C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 22C P0 116W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 26C P0 122W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 27C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 25C P0 114W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 28C P0 119W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 24C P0 119W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 166464 C /usr/bin/python 0MiB | +| 0 N/A N/A 166465 C /usr/bin/python 0MiB | +| 0 N/A N/A 166466 C /usr/bin/python 0MiB | +| 0 N/A N/A 166467 C /usr/bin/python 0MiB | +| 0 N/A N/A 166468 C /usr/bin/python 0MiB | +| 0 N/A N/A 166469 C /usr/bin/python 0MiB | +| 0 N/A N/A 166470 C /usr/bin/python 0MiB | +| 0 N/A N/A 166471 C /usr/bin/python 0MiB | +| 1 N/A N/A 166465 C /usr/bin/python 0MiB | +| 2 N/A N/A 166466 C /usr/bin/python 0MiB | +| 3 N/A N/A 166467 C /usr/bin/python 0MiB | +| 4 N/A N/A 166468 C /usr/bin/python 0MiB | +| 5 N/A N/A 166469 C /usr/bin/python 0MiB | +| 6 N/A N/A 166470 C /usr/bin/python 0MiB | +| 7 N/A N/A 166471 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1680 train_time:136ms step_avg:135.60ms +step:2/1680 train_time:157ms step_avg:78.30ms +step:3/1680 train_time:219ms step_avg:73.08ms +step:4/1680 train_time:304ms step_avg:76.00ms +step:5/1680 train_time:390ms step_avg:78.02ms +step:6/1680 train_time:476ms step_avg:79.35ms +step:7/1680 train_time:563ms step_avg:80.38ms +step:8/1680 train_time:649ms step_avg:81.09ms +step:9/1680 train_time:735ms step_avg:81.63ms +step:10/1680 train_time:821ms step_avg:82.10ms +step:11/1680 train_time:907ms step_avg:82.50ms +step:12/1680 train_time:996ms step_avg:83.01ms +step:13/1680 train_time:1089ms step_avg:83.79ms +step:14/1680 train_time:1179ms step_avg:84.18ms +step:15/1680 train_time:1266ms step_avg:84.39ms +step:16/1680 train_time:1353ms step_avg:84.54ms +step:17/1680 train_time:1440ms step_avg:84.69ms +step:18/1680 train_time:1527ms step_avg:84.85ms +step:19/1680 train_time:1614ms step_avg:84.92ms +step:20/1680 train_time:1701ms step_avg:85.03ms +step:21/1680 train_time:1787ms step_avg:85.10ms +step:22/1680 train_time:1874ms step_avg:85.17ms +step:23/1680 train_time:1961ms step_avg:85.28ms +step:24/1680 train_time:2050ms step_avg:85.42ms +step:25/1680 train_time:2139ms step_avg:85.55ms +step:26/1680 train_time:2227ms step_avg:85.65ms +step:27/1680 train_time:2315ms step_avg:85.73ms +step:28/1680 train_time:2402ms step_avg:85.80ms +step:29/1680 train_time:2490ms step_avg:85.86ms +step:30/1680 train_time:2577ms step_avg:85.89ms +step:31/1680 train_time:2664ms step_avg:85.92ms +step:32/1680 train_time:2750ms step_avg:85.94ms +step:33/1680 train_time:2837ms step_avg:85.98ms +step:34/1680 train_time:2924ms step_avg:86.00ms +step:35/1680 train_time:3013ms step_avg:86.07ms +step:36/1680 train_time:3101ms step_avg:86.13ms +step:37/1680 train_time:3189ms step_avg:86.19ms +step:38/1680 train_time:3276ms step_avg:86.22ms +step:39/1680 train_time:3365ms step_avg:86.27ms +step:40/1680 train_time:3452ms step_avg:86.29ms +step:41/1680 train_time:3539ms step_avg:86.31ms +step:42/1680 train_time:3626ms step_avg:86.33ms +step:43/1680 train_time:3712ms step_avg:86.34ms +step:44/1680 train_time:3799ms step_avg:86.35ms +step:45/1680 train_time:3886ms step_avg:86.35ms +step:46/1680 train_time:3973ms step_avg:86.36ms +step:47/1680 
train_time:4061ms step_avg:86.40ms +step:48/1680 train_time:4149ms step_avg:86.43ms +step:49/1680 train_time:4236ms step_avg:86.45ms +step:50/1680 train_time:4324ms step_avg:86.48ms +step:51/1680 train_time:4411ms step_avg:86.49ms +step:52/1680 train_time:4498ms step_avg:86.50ms +step:53/1680 train_time:4585ms step_avg:86.51ms +step:54/1680 train_time:4672ms step_avg:86.52ms +step:55/1680 train_time:4759ms step_avg:86.53ms +step:56/1680 train_time:4846ms step_avg:86.54ms +step:57/1680 train_time:4933ms step_avg:86.54ms +step:58/1680 train_time:5020ms step_avg:86.55ms +step:59/1680 train_time:5107ms step_avg:86.57ms +step:60/1680 train_time:5195ms step_avg:86.58ms +step:61/1680 train_time:5282ms step_avg:86.60ms +step:62/1680 train_time:5370ms step_avg:86.61ms +step:63/1680 train_time:5458ms step_avg:86.63ms +step:64/1680 train_time:5544ms step_avg:86.63ms +step:65/1680 train_time:5632ms step_avg:86.65ms +step:66/1680 train_time:5719ms step_avg:86.65ms +step:67/1680 train_time:5806ms step_avg:86.66ms +step:68/1680 train_time:5894ms step_avg:86.68ms +step:69/1680 train_time:5981ms step_avg:86.69ms +step:70/1680 train_time:6068ms step_avg:86.69ms +step:71/1680 train_time:6155ms step_avg:86.69ms +step:72/1680 train_time:6243ms step_avg:86.70ms +step:73/1680 train_time:6330ms step_avg:86.72ms +step:74/1680 train_time:6418ms step_avg:86.73ms +step:75/1680 train_time:6505ms step_avg:86.74ms +step:76/1680 train_time:6592ms step_avg:86.74ms +step:77/1680 train_time:6680ms step_avg:86.75ms +step:78/1680 train_time:6767ms step_avg:86.76ms +step:79/1680 train_time:6854ms step_avg:86.76ms +step:80/1680 train_time:6941ms step_avg:86.77ms +step:81/1680 train_time:7029ms step_avg:86.77ms +step:82/1680 train_time:7115ms step_avg:86.77ms +step:83/1680 train_time:7203ms step_avg:86.78ms +step:84/1680 train_time:7290ms step_avg:86.78ms +step:85/1680 train_time:7378ms step_avg:86.79ms +step:86/1680 train_time:7465ms step_avg:86.80ms +step:87/1680 train_time:7551ms step_avg:86.80ms +step:88/1680 train_time:7639ms step_avg:86.81ms +step:89/1680 train_time:7727ms step_avg:86.82ms +step:90/1680 train_time:7814ms step_avg:86.82ms +step:91/1680 train_time:7901ms step_avg:86.82ms +step:92/1680 train_time:7989ms step_avg:86.83ms +step:93/1680 train_time:8076ms step_avg:86.84ms +step:94/1680 train_time:8163ms step_avg:86.84ms +step:95/1680 train_time:8250ms step_avg:86.84ms +step:96/1680 train_time:8337ms step_avg:86.84ms +step:97/1680 train_time:8424ms step_avg:86.85ms +step:98/1680 train_time:8512ms step_avg:86.85ms +step:99/1680 train_time:8599ms step_avg:86.86ms +step:100/1680 train_time:8686ms step_avg:86.86ms +step:101/1680 train_time:8774ms step_avg:86.87ms +step:102/1680 train_time:8862ms step_avg:86.88ms +step:103/1680 train_time:8949ms step_avg:86.88ms +step:104/1680 train_time:9035ms step_avg:86.88ms +step:105/1680 train_time:9123ms step_avg:86.88ms +step:106/1680 train_time:9210ms step_avg:86.89ms +step:107/1680 train_time:9297ms step_avg:86.89ms +step:108/1680 train_time:9384ms step_avg:86.89ms +step:109/1680 train_time:9471ms step_avg:86.89ms +step:110/1680 train_time:9558ms step_avg:86.89ms +step:111/1680 train_time:9645ms step_avg:86.89ms +step:112/1680 train_time:9732ms step_avg:86.89ms +step:113/1680 train_time:9819ms step_avg:86.90ms +step:114/1680 train_time:9907ms step_avg:86.90ms +step:115/1680 train_time:9994ms step_avg:86.91ms +step:116/1680 train_time:10081ms step_avg:86.91ms +step:117/1680 train_time:10168ms step_avg:86.91ms +step:118/1680 train_time:10255ms step_avg:86.91ms +step:119/1680 
train_time:10343ms step_avg:86.92ms +step:120/1680 train_time:10430ms step_avg:86.92ms +step:121/1680 train_time:10517ms step_avg:86.92ms +step:122/1680 train_time:10605ms step_avg:86.93ms +step:123/1680 train_time:10691ms step_avg:86.92ms +step:124/1680 train_time:10779ms step_avg:86.93ms +step:125/1680 train_time:10866ms step_avg:86.93ms +step:125/1680 val_loss:4.3233 train_time:10954ms step_avg:87.63ms +step:126/1680 train_time:10978ms step_avg:87.12ms +step:127/1680 train_time:11046ms step_avg:86.98ms +step:128/1680 train_time:11143ms step_avg:87.05ms +step:129/1680 train_time:11232ms step_avg:87.07ms +step:130/1680 train_time:11318ms step_avg:87.06ms +step:131/1680 train_time:11404ms step_avg:87.06ms +step:132/1680 train_time:11491ms step_avg:87.05ms +step:133/1680 train_time:11576ms step_avg:87.04ms +step:134/1680 train_time:11662ms step_avg:87.03ms +step:135/1680 train_time:11748ms step_avg:87.02ms +step:136/1680 train_time:11834ms step_avg:87.02ms +step:137/1680 train_time:11920ms step_avg:87.01ms +step:138/1680 train_time:12008ms step_avg:87.01ms +step:139/1680 train_time:12098ms step_avg:87.03ms +step:140/1680 train_time:12187ms step_avg:87.05ms +step:141/1680 train_time:12275ms step_avg:87.06ms +step:142/1680 train_time:12362ms step_avg:87.06ms +step:143/1680 train_time:12449ms step_avg:87.05ms +step:144/1680 train_time:12535ms step_avg:87.05ms +step:145/1680 train_time:12621ms step_avg:87.04ms +step:146/1680 train_time:12707ms step_avg:87.04ms +step:147/1680 train_time:12794ms step_avg:87.03ms +step:148/1680 train_time:12880ms step_avg:87.03ms +step:149/1680 train_time:12968ms step_avg:87.03ms +step:150/1680 train_time:13057ms step_avg:87.04ms +step:151/1680 train_time:13145ms step_avg:87.05ms +step:152/1680 train_time:13232ms step_avg:87.06ms +step:153/1680 train_time:13320ms step_avg:87.06ms +step:154/1680 train_time:13406ms step_avg:87.05ms +step:155/1680 train_time:13493ms step_avg:87.05ms +step:156/1680 train_time:13580ms step_avg:87.05ms +step:157/1680 train_time:13667ms step_avg:87.05ms +step:158/1680 train_time:13753ms step_avg:87.04ms +step:159/1680 train_time:13839ms step_avg:87.04ms +step:160/1680 train_time:13926ms step_avg:87.04ms +step:161/1680 train_time:14013ms step_avg:87.04ms +step:162/1680 train_time:14101ms step_avg:87.04ms +step:163/1680 train_time:14188ms step_avg:87.04ms +step:164/1680 train_time:14276ms step_avg:87.05ms +step:165/1680 train_time:14364ms step_avg:87.05ms +step:166/1680 train_time:14450ms step_avg:87.05ms +step:167/1680 train_time:14537ms step_avg:87.05ms +step:168/1680 train_time:14624ms step_avg:87.05ms +step:169/1680 train_time:14711ms step_avg:87.05ms +step:170/1680 train_time:14797ms step_avg:87.04ms +step:171/1680 train_time:14885ms step_avg:87.04ms +step:172/1680 train_time:14972ms step_avg:87.05ms +step:173/1680 train_time:15059ms step_avg:87.05ms +step:174/1680 train_time:15147ms step_avg:87.05ms +step:175/1680 train_time:15235ms step_avg:87.05ms +step:176/1680 train_time:15322ms step_avg:87.06ms +step:177/1680 train_time:15409ms step_avg:87.06ms +step:178/1680 train_time:15497ms step_avg:87.06ms +step:179/1680 train_time:15584ms step_avg:87.06ms +step:180/1680 train_time:15671ms step_avg:87.06ms +step:181/1680 train_time:15757ms step_avg:87.06ms +step:182/1680 train_time:15844ms step_avg:87.05ms +step:183/1680 train_time:15931ms step_avg:87.05ms +step:184/1680 train_time:16017ms step_avg:87.05ms +step:185/1680 train_time:16105ms step_avg:87.05ms +step:186/1680 train_time:16192ms step_avg:87.05ms +step:187/1680 train_time:16279ms 
+step:188/1680 train_time:16367ms step_avg:87.06ms
+step:189/1680 train_time:16454ms step_avg:87.06ms
+step:190/1680 train_time:16541ms step_avg:87.06ms
+step:191/1680 train_time:16627ms step_avg:87.05ms
+step:192/1680 train_time:16714ms step_avg:87.05ms
+step:193/1680 train_time:16801ms step_avg:87.05ms
+step:194/1680 train_time:16888ms step_avg:87.05ms
+step:195/1680 train_time:16974ms step_avg:87.05ms
+step:196/1680 train_time:17061ms step_avg:87.05ms
+step:197/1680 train_time:17148ms step_avg:87.05ms
+step:198/1680 train_time:17236ms step_avg:87.05ms
+step:199/1680 train_time:17324ms step_avg:87.05ms
+step:200/1680 train_time:17411ms step_avg:87.05ms
+step:201/1680 train_time:17498ms step_avg:87.05ms
+step:202/1680 train_time:17585ms step_avg:87.05ms
+step:203/1680 train_time:17672ms step_avg:87.05ms
+step:204/1680 train_time:17759ms step_avg:87.05ms
+step:205/1680 train_time:17846ms step_avg:87.05ms
+step:206/1680 train_time:17933ms step_avg:87.05ms
+step:207/1680 train_time:18020ms step_avg:87.05ms
+step:208/1680 train_time:18107ms step_avg:87.05ms
+step:209/1680 train_time:18194ms step_avg:87.05ms
+step:210/1680 train_time:18281ms step_avg:87.05ms
+step:211/1680 train_time:18368ms step_avg:87.05ms
+step:212/1680 train_time:18455ms step_avg:87.05ms
+step:213/1680 train_time:18542ms step_avg:87.05ms
+step:214/1680 train_time:18630ms step_avg:87.06ms
+step:215/1680 train_time:18717ms step_avg:87.06ms
+step:216/1680 train_time:18804ms step_avg:87.06ms
+step:217/1680 train_time:18892ms step_avg:87.06ms
+step:218/1680 train_time:18978ms step_avg:87.06ms
+step:219/1680 train_time:19067ms step_avg:87.06ms
+step:220/1680 train_time:19153ms step_avg:87.06ms
+step:221/1680 train_time:19240ms step_avg:87.06ms
+step:222/1680 train_time:19327ms step_avg:87.06ms
+step:223/1680 train_time:19415ms step_avg:87.06ms
+step:224/1680 train_time:19502ms step_avg:87.06ms
+step:225/1680 train_time:19589ms step_avg:87.06ms
+step:226/1680 train_time:19675ms step_avg:87.06ms
+step:227/1680 train_time:19763ms step_avg:87.06ms
+step:228/1680 train_time:19850ms step_avg:87.06ms
+step:229/1680 train_time:19937ms step_avg:87.06ms
+step:230/1680 train_time:20024ms step_avg:87.06ms
+step:231/1680 train_time:20111ms step_avg:87.06ms
+step:232/1680 train_time:20199ms step_avg:87.06ms
+step:233/1680 train_time:20286ms step_avg:87.06ms
+step:234/1680 train_time:20373ms step_avg:87.06ms
+step:235/1680 train_time:20461ms step_avg:87.07ms
+step:236/1680 train_time:20547ms step_avg:87.06ms
+step:237/1680 train_time:20634ms step_avg:87.06ms
+step:238/1680 train_time:20721ms step_avg:87.06ms
+step:239/1680 train_time:20808ms step_avg:87.06ms
+step:240/1680 train_time:20896ms step_avg:87.07ms
+step:241/1680 train_time:20983ms step_avg:87.06ms
+step:242/1680 train_time:21070ms step_avg:87.06ms
+step:243/1680 train_time:21156ms step_avg:87.06ms
+step:244/1680 train_time:21243ms step_avg:87.06ms
+step:245/1680 train_time:21331ms step_avg:87.06ms
+step:246/1680 train_time:21417ms step_avg:87.06ms
+step:247/1680 train_time:21504ms step_avg:87.06ms
+step:248/1680 train_time:21591ms step_avg:87.06ms
+step:249/1680 train_time:21678ms step_avg:87.06ms
+step:250/1680 train_time:21765ms step_avg:87.06ms
+step:250/1680 val_loss:3.9789 train_time:21855ms step_avg:87.42ms
+step:251/1680 train_time:21875ms step_avg:87.15ms
+step:252/1680 train_time:21944ms step_avg:87.08ms
+step:253/1680 train_time:22035ms step_avg:87.09ms
+step:254/1680 train_time:22123ms step_avg:87.10ms
+step:255/1680 train_time:22209ms step_avg:87.10ms
+step:256/1680 train_time:22296ms step_avg:87.09ms
+step:257/1680 train_time:22382ms step_avg:87.09ms
+step:258/1680 train_time:22468ms step_avg:87.08ms
+step:259/1680 train_time:22554ms step_avg:87.08ms
+step:260/1680 train_time:22641ms step_avg:87.08ms
+step:261/1680 train_time:22727ms step_avg:87.08ms
+step:262/1680 train_time:22815ms step_avg:87.08ms
+step:263/1680 train_time:22902ms step_avg:87.08ms
+step:264/1680 train_time:22992ms step_avg:87.09ms
+step:265/1680 train_time:23080ms step_avg:87.09ms
+step:266/1680 train_time:23167ms step_avg:87.10ms
+step:267/1680 train_time:23255ms step_avg:87.10ms
+step:268/1680 train_time:23341ms step_avg:87.09ms
+step:269/1680 train_time:23428ms step_avg:87.09ms
+step:270/1680 train_time:23515ms step_avg:87.09ms
+step:271/1680 train_time:23602ms step_avg:87.09ms
+step:272/1680 train_time:23687ms step_avg:87.09ms
+step:273/1680 train_time:23774ms step_avg:87.08ms
+step:274/1680 train_time:23862ms step_avg:87.09ms
+step:275/1680 train_time:23949ms step_avg:87.09ms
+step:276/1680 train_time:24038ms step_avg:87.09ms
+step:277/1680 train_time:24125ms step_avg:87.09ms
+step:278/1680 train_time:24213ms step_avg:87.10ms
+step:279/1680 train_time:24300ms step_avg:87.10ms
+step:280/1680 train_time:24386ms step_avg:87.09ms
+step:281/1680 train_time:24473ms step_avg:87.09ms
+step:282/1680 train_time:24559ms step_avg:87.09ms
+step:283/1680 train_time:24645ms step_avg:87.09ms
+step:284/1680 train_time:24732ms step_avg:87.09ms
+step:285/1680 train_time:24820ms step_avg:87.09ms
+step:286/1680 train_time:24907ms step_avg:87.09ms
+step:287/1680 train_time:24995ms step_avg:87.09ms
+step:288/1680 train_time:25082ms step_avg:87.09ms
+step:289/1680 train_time:25170ms step_avg:87.09ms
+step:290/1680 train_time:25257ms step_avg:87.09ms
+step:291/1680 train_time:25343ms step_avg:87.09ms
+step:292/1680 train_time:25430ms step_avg:87.09ms
+step:293/1680 train_time:25517ms step_avg:87.09ms
+step:294/1680 train_time:25604ms step_avg:87.09ms
+step:295/1680 train_time:25691ms step_avg:87.09ms
+step:296/1680 train_time:25777ms step_avg:87.09ms
+step:297/1680 train_time:25864ms step_avg:87.08ms
+step:298/1680 train_time:25951ms step_avg:87.08ms
+step:299/1680 train_time:26038ms step_avg:87.08ms
+step:300/1680 train_time:26125ms step_avg:87.08ms
+step:301/1680 train_time:26213ms step_avg:87.08ms
+step:302/1680 train_time:26299ms step_avg:87.08ms
+step:303/1680 train_time:26386ms step_avg:87.08ms
+step:304/1680 train_time:26474ms step_avg:87.08ms
+step:305/1680 train_time:26560ms step_avg:87.08ms
+step:306/1680 train_time:26647ms step_avg:87.08ms
+step:307/1680 train_time:26734ms step_avg:87.08ms
+step:308/1680 train_time:26822ms step_avg:87.08ms
+step:309/1680 train_time:26908ms step_avg:87.08ms
+step:310/1680 train_time:26996ms step_avg:87.09ms
+step:311/1680 train_time:27083ms step_avg:87.08ms
+step:312/1680 train_time:27171ms step_avg:87.09ms
+step:313/1680 train_time:27258ms step_avg:87.09ms
+step:314/1680 train_time:27345ms step_avg:87.09ms
+step:315/1680 train_time:27433ms step_avg:87.09ms
+step:316/1680 train_time:27520ms step_avg:87.09ms
+step:317/1680 train_time:27607ms step_avg:87.09ms
+step:318/1680 train_time:27695ms step_avg:87.09ms
+step:319/1680 train_time:27781ms step_avg:87.09ms
+step:320/1680 train_time:27868ms step_avg:87.09ms
+step:321/1680 train_time:27955ms step_avg:87.09ms
+step:322/1680 train_time:28042ms step_avg:87.09ms
+step:323/1680 train_time:28129ms step_avg:87.09ms
+step:324/1680 train_time:28217ms step_avg:87.09ms
+step:325/1680 train_time:28304ms step_avg:87.09ms
+step:326/1680 train_time:28392ms step_avg:87.09ms
+step:327/1680 train_time:28478ms step_avg:87.09ms
+step:328/1680 train_time:28566ms step_avg:87.09ms
+step:329/1680 train_time:28653ms step_avg:87.09ms
+step:330/1680 train_time:28739ms step_avg:87.09ms
+step:331/1680 train_time:28826ms step_avg:87.09ms
+step:332/1680 train_time:28915ms step_avg:87.09ms
+step:333/1680 train_time:29000ms step_avg:87.09ms
+step:334/1680 train_time:29088ms step_avg:87.09ms
+step:335/1680 train_time:29175ms step_avg:87.09ms
+step:336/1680 train_time:29262ms step_avg:87.09ms
+step:337/1680 train_time:29350ms step_avg:87.09ms
+step:338/1680 train_time:29437ms step_avg:87.09ms
+step:339/1680 train_time:29524ms step_avg:87.09ms
+step:340/1680 train_time:29611ms step_avg:87.09ms
+step:341/1680 train_time:29697ms step_avg:87.09ms
+step:342/1680 train_time:29784ms step_avg:87.09ms
+step:343/1680 train_time:29872ms step_avg:87.09ms
+step:344/1680 train_time:29960ms step_avg:87.09ms
+step:345/1680 train_time:30047ms step_avg:87.09ms
+step:346/1680 train_time:30134ms step_avg:87.09ms
+step:347/1680 train_time:30221ms step_avg:87.09ms
+step:348/1680 train_time:30308ms step_avg:87.09ms
+step:349/1680 train_time:30396ms step_avg:87.09ms
+step:350/1680 train_time:30483ms step_avg:87.10ms
+step:351/1680 train_time:30571ms step_avg:87.10ms
+step:352/1680 train_time:30657ms step_avg:87.09ms
+step:353/1680 train_time:30744ms step_avg:87.09ms
+step:354/1680 train_time:30832ms step_avg:87.10ms
+step:355/1680 train_time:30919ms step_avg:87.10ms
+step:356/1680 train_time:31007ms step_avg:87.10ms
+step:357/1680 train_time:31094ms step_avg:87.10ms
+step:358/1680 train_time:31182ms step_avg:87.10ms
+step:359/1680 train_time:31268ms step_avg:87.10ms
+step:360/1680 train_time:31355ms step_avg:87.10ms
+step:361/1680 train_time:31442ms step_avg:87.10ms
+step:362/1680 train_time:31529ms step_avg:87.10ms
+step:363/1680 train_time:31617ms step_avg:87.10ms
+step:364/1680 train_time:31704ms step_avg:87.10ms
+step:365/1680 train_time:31792ms step_avg:87.10ms
+step:366/1680 train_time:31879ms step_avg:87.10ms
+step:367/1680 train_time:31966ms step_avg:87.10ms
+step:368/1680 train_time:32053ms step_avg:87.10ms
+step:369/1680 train_time:32140ms step_avg:87.10ms
+step:370/1680 train_time:32227ms step_avg:87.10ms
+step:371/1680 train_time:32315ms step_avg:87.10ms
+step:372/1680 train_time:32401ms step_avg:87.10ms
+step:373/1680 train_time:32489ms step_avg:87.10ms
+step:374/1680 train_time:32576ms step_avg:87.10ms
+step:375/1680 train_time:32663ms step_avg:87.10ms
+step:375/1680 val_loss:3.8240 train_time:32751ms step_avg:87.34ms
+step:376/1680 train_time:32773ms step_avg:87.16ms
+step:377/1680 train_time:32841ms step_avg:87.11ms
+step:378/1680 train_time:32931ms step_avg:87.12ms
+step:379/1680 train_time:33018ms step_avg:87.12ms
+step:380/1680 train_time:33105ms step_avg:87.12ms
+step:381/1680 train_time:33191ms step_avg:87.11ms
+step:382/1680 train_time:33277ms step_avg:87.11ms
+step:383/1680 train_time:33363ms step_avg:87.11ms
+step:384/1680 train_time:33449ms step_avg:87.11ms
+step:385/1680 train_time:33535ms step_avg:87.10ms
+step:386/1680 train_time:33621ms step_avg:87.10ms
+step:387/1680 train_time:33709ms step_avg:87.10ms
+step:388/1680 train_time:33797ms step_avg:87.11ms
+step:389/1680 train_time:33886ms step_avg:87.11ms
+step:390/1680 train_time:33975ms step_avg:87.12ms
+step:391/1680 train_time:34062ms step_avg:87.12ms
+step:392/1680 train_time:34149ms step_avg:87.12ms
+step:393/1680 train_time:34236ms step_avg:87.11ms
+step:394/1680 train_time:34322ms step_avg:87.11ms
+step:395/1680 train_time:34408ms step_avg:87.11ms
+step:396/1680 train_time:34495ms step_avg:87.11ms
+step:397/1680 train_time:34582ms step_avg:87.11ms
+step:398/1680 train_time:34669ms step_avg:87.11ms
+step:399/1680 train_time:34756ms step_avg:87.11ms
+step:400/1680 train_time:34844ms step_avg:87.11ms
+step:401/1680 train_time:34932ms step_avg:87.11ms
+step:402/1680 train_time:35020ms step_avg:87.11ms
+step:403/1680 train_time:35107ms step_avg:87.11ms
+step:404/1680 train_time:35194ms step_avg:87.11ms
+step:405/1680 train_time:35280ms step_avg:87.11ms
+step:406/1680 train_time:35367ms step_avg:87.11ms
+step:407/1680 train_time:35454ms step_avg:87.11ms
+step:408/1680 train_time:35541ms step_avg:87.11ms
+step:409/1680 train_time:35628ms step_avg:87.11ms
+step:410/1680 train_time:35715ms step_avg:87.11ms
+step:411/1680 train_time:35802ms step_avg:87.11ms
+step:412/1680 train_time:35890ms step_avg:87.11ms
+step:413/1680 train_time:35977ms step_avg:87.11ms
+step:414/1680 train_time:36065ms step_avg:87.11ms
+step:415/1680 train_time:36152ms step_avg:87.11ms
+step:416/1680 train_time:36238ms step_avg:87.11ms
+step:417/1680 train_time:36325ms step_avg:87.11ms
+step:418/1680 train_time:36412ms step_avg:87.11ms
+step:419/1680 train_time:36499ms step_avg:87.11ms
+step:420/1680 train_time:36586ms step_avg:87.11ms
+step:421/1680 train_time:36674ms step_avg:87.11ms
+step:422/1680 train_time:36760ms step_avg:87.11ms
+step:423/1680 train_time:36847ms step_avg:87.11ms
+step:424/1680 train_time:36935ms step_avg:87.11ms
+step:425/1680 train_time:37022ms step_avg:87.11ms
+step:426/1680 train_time:37109ms step_avg:87.11ms
+step:427/1680 train_time:37196ms step_avg:87.11ms
+step:428/1680 train_time:37284ms step_avg:87.11ms
+step:429/1680 train_time:37370ms step_avg:87.11ms
+step:430/1680 train_time:37457ms step_avg:87.11ms
+step:431/1680 train_time:37544ms step_avg:87.11ms
+step:432/1680 train_time:37631ms step_avg:87.11ms
+step:433/1680 train_time:37717ms step_avg:87.11ms
+step:434/1680 train_time:37804ms step_avg:87.10ms
+step:435/1680 train_time:37890ms step_avg:87.10ms
+step:436/1680 train_time:37978ms step_avg:87.11ms
+step:437/1680 train_time:38065ms step_avg:87.11ms
+step:438/1680 train_time:38152ms step_avg:87.11ms
+step:439/1680 train_time:38239ms step_avg:87.11ms
+step:440/1680 train_time:38326ms step_avg:87.10ms
+step:441/1680 train_time:38414ms step_avg:87.11ms
+step:442/1680 train_time:38501ms step_avg:87.11ms
+step:443/1680 train_time:38588ms step_avg:87.11ms
+step:444/1680 train_time:38674ms step_avg:87.10ms
+step:445/1680 train_time:38761ms step_avg:87.10ms
+step:446/1680 train_time:38848ms step_avg:87.10ms
+step:447/1680 train_time:38935ms step_avg:87.10ms
+step:448/1680 train_time:39022ms step_avg:87.10ms
+step:449/1680 train_time:39109ms step_avg:87.10ms
+step:450/1680 train_time:39195ms step_avg:87.10ms
+step:451/1680 train_time:39283ms step_avg:87.10ms
+step:452/1680 train_time:39370ms step_avg:87.10ms
+step:453/1680 train_time:39458ms step_avg:87.10ms
+step:454/1680 train_time:39544ms step_avg:87.10ms
+step:455/1680 train_time:39631ms step_avg:87.10ms
+step:456/1680 train_time:39718ms step_avg:87.10ms
+step:457/1680 train_time:39805ms step_avg:87.10ms
+step:458/1680 train_time:39891ms step_avg:87.10ms
+step:459/1680 train_time:39979ms step_avg:87.10ms
+step:460/1680 train_time:40065ms step_avg:87.10ms
+step:461/1680 train_time:40153ms step_avg:87.10ms
+step:462/1680 train_time:40239ms step_avg:87.10ms
+step:463/1680 train_time:40327ms step_avg:87.10ms
+step:464/1680 train_time:40414ms step_avg:87.10ms
+step:465/1680 train_time:40501ms step_avg:87.10ms
+step:466/1680 train_time:40588ms step_avg:87.10ms
+step:467/1680 train_time:40675ms step_avg:87.10ms
+step:468/1680 train_time:40762ms step_avg:87.10ms
+step:469/1680 train_time:40849ms step_avg:87.10ms
+step:470/1680 train_time:40935ms step_avg:87.10ms
+step:471/1680 train_time:41022ms step_avg:87.10ms
+step:472/1680 train_time:41109ms step_avg:87.10ms
+step:473/1680 train_time:41196ms step_avg:87.10ms
+step:474/1680 train_time:41283ms step_avg:87.09ms
+step:475/1680 train_time:41371ms step_avg:87.10ms
+step:476/1680 train_time:41458ms step_avg:87.10ms
+step:477/1680 train_time:41545ms step_avg:87.10ms
+step:478/1680 train_time:41633ms step_avg:87.10ms
+step:479/1680 train_time:41720ms step_avg:87.10ms
+step:480/1680 train_time:41807ms step_avg:87.10ms
+step:481/1680 train_time:41894ms step_avg:87.10ms
+step:482/1680 train_time:41981ms step_avg:87.10ms
+step:483/1680 train_time:42069ms step_avg:87.10ms
+step:484/1680 train_time:42155ms step_avg:87.10ms
+step:485/1680 train_time:42242ms step_avg:87.10ms
+step:486/1680 train_time:42329ms step_avg:87.10ms
+step:487/1680 train_time:42416ms step_avg:87.10ms
+step:488/1680 train_time:42503ms step_avg:87.10ms
+step:489/1680 train_time:42590ms step_avg:87.10ms
+step:490/1680 train_time:42678ms step_avg:87.10ms
+step:491/1680 train_time:42765ms step_avg:87.10ms
+step:492/1680 train_time:42852ms step_avg:87.10ms
+step:493/1680 train_time:42939ms step_avg:87.10ms
+step:494/1680 train_time:43026ms step_avg:87.10ms
+step:495/1680 train_time:43113ms step_avg:87.10ms
+step:496/1680 train_time:43201ms step_avg:87.10ms
+step:497/1680 train_time:43288ms step_avg:87.10ms
+step:498/1680 train_time:43375ms step_avg:87.10ms
+step:499/1680 train_time:43462ms step_avg:87.10ms
+step:500/1680 train_time:43548ms step_avg:87.10ms
+step:500/1680 val_loss:3.7209 train_time:43638ms step_avg:87.28ms
+step:501/1680 train_time:43656ms step_avg:87.14ms
+step:502/1680 train_time:43727ms step_avg:87.11ms
+step:503/1680 train_time:43817ms step_avg:87.11ms
+step:504/1680 train_time:43905ms step_avg:87.11ms
+step:505/1680 train_time:43992ms step_avg:87.11ms
+step:506/1680 train_time:44078ms step_avg:87.11ms
+step:507/1680 train_time:44165ms step_avg:87.11ms
+step:508/1680 train_time:44251ms step_avg:87.11ms
+step:509/1680 train_time:44337ms step_avg:87.11ms
+step:510/1680 train_time:44424ms step_avg:87.11ms
+step:511/1680 train_time:44510ms step_avg:87.10ms
+step:512/1680 train_time:44597ms step_avg:87.10ms
+step:513/1680 train_time:44685ms step_avg:87.11ms
+step:514/1680 train_time:44773ms step_avg:87.11ms
+step:515/1680 train_time:44861ms step_avg:87.11ms
+step:516/1680 train_time:44949ms step_avg:87.11ms
+step:517/1680 train_time:45037ms step_avg:87.11ms
+step:518/1680 train_time:45123ms step_avg:87.11ms
+step:519/1680 train_time:45210ms step_avg:87.11ms
+step:520/1680 train_time:45297ms step_avg:87.11ms
+step:521/1680 train_time:45385ms step_avg:87.11ms
+step:522/1680 train_time:45471ms step_avg:87.11ms
+step:523/1680 train_time:45558ms step_avg:87.11ms
+step:524/1680 train_time:45645ms step_avg:87.11ms
+step:525/1680 train_time:45732ms step_avg:87.11ms
+step:526/1680 train_time:45821ms step_avg:87.11ms
+step:527/1680 train_time:45908ms step_avg:87.11ms
+step:528/1680 train_time:45996ms step_avg:87.11ms
+step:529/1680 train_time:46083ms step_avg:87.11ms
+step:530/1680 train_time:46169ms step_avg:87.11ms
+step:531/1680 train_time:46257ms step_avg:87.11ms
+step:532/1680 train_time:46345ms step_avg:87.11ms
+step:533/1680 train_time:46431ms step_avg:87.11ms
+step:534/1680 train_time:46518ms step_avg:87.11ms
+step:535/1680 train_time:46604ms step_avg:87.11ms
+step:536/1680 train_time:46692ms step_avg:87.11ms
+step:537/1680 train_time:46779ms step_avg:87.11ms
+step:538/1680 train_time:46868ms step_avg:87.11ms
+step:539/1680 train_time:46955ms step_avg:87.11ms
+step:540/1680 train_time:47042ms step_avg:87.11ms
+step:541/1680 train_time:47129ms step_avg:87.11ms
+step:542/1680 train_time:47216ms step_avg:87.11ms
+step:543/1680 train_time:47303ms step_avg:87.11ms
+step:544/1680 train_time:47390ms step_avg:87.11ms
+step:545/1680 train_time:47477ms step_avg:87.11ms
+step:546/1680 train_time:47565ms step_avg:87.11ms
+step:547/1680 train_time:47651ms step_avg:87.11ms
+step:548/1680 train_time:47739ms step_avg:87.12ms
+step:549/1680 train_time:47828ms step_avg:87.12ms
+step:550/1680 train_time:47917ms step_avg:87.12ms
+step:551/1680 train_time:48005ms step_avg:87.12ms
+step:552/1680 train_time:48094ms step_avg:87.13ms
+step:553/1680 train_time:48181ms step_avg:87.13ms
+step:554/1680 train_time:48269ms step_avg:87.13ms
+step:555/1680 train_time:48357ms step_avg:87.13ms
+step:556/1680 train_time:48445ms step_avg:87.13ms
+step:557/1680 train_time:48534ms step_avg:87.13ms
+step:558/1680 train_time:48621ms step_avg:87.14ms
+step:559/1680 train_time:48710ms step_avg:87.14ms
+step:560/1680 train_time:48798ms step_avg:87.14ms
+step:561/1680 train_time:48886ms step_avg:87.14ms
+step:562/1680 train_time:48974ms step_avg:87.14ms
+step:563/1680 train_time:49063ms step_avg:87.15ms
+step:564/1680 train_time:49151ms step_avg:87.15ms
+step:565/1680 train_time:49239ms step_avg:87.15ms
+step:566/1680 train_time:49327ms step_avg:87.15ms
+step:567/1680 train_time:49415ms step_avg:87.15ms
+step:568/1680 train_time:49503ms step_avg:87.15ms
+step:569/1680 train_time:49591ms step_avg:87.15ms
+step:570/1680 train_time:49679ms step_avg:87.16ms
+step:571/1680 train_time:49769ms step_avg:87.16ms
+step:572/1680 train_time:49857ms step_avg:87.16ms
+step:573/1680 train_time:49945ms step_avg:87.16ms
+step:574/1680 train_time:50033ms step_avg:87.16ms
+step:575/1680 train_time:50121ms step_avg:87.17ms
+step:576/1680 train_time:50209ms step_avg:87.17ms
+step:577/1680 train_time:50297ms step_avg:87.17ms
+step:578/1680 train_time:50385ms step_avg:87.17ms
+step:579/1680 train_time:50473ms step_avg:87.17ms
+step:580/1680 train_time:50561ms step_avg:87.17ms
+step:581/1680 train_time:50650ms step_avg:87.18ms
+step:582/1680 train_time:50738ms step_avg:87.18ms
+step:583/1680 train_time:50827ms step_avg:87.18ms
+step:584/1680 train_time:50915ms step_avg:87.18ms
+step:585/1680 train_time:51003ms step_avg:87.18ms
+step:586/1680 train_time:51091ms step_avg:87.19ms
+step:587/1680 train_time:51179ms step_avg:87.19ms
+step:588/1680 train_time:51268ms step_avg:87.19ms
+step:589/1680 train_time:51356ms step_avg:87.19ms
+step:590/1680 train_time:51443ms step_avg:87.19ms
+step:591/1680 train_time:51531ms step_avg:87.19ms
+step:592/1680 train_time:51619ms step_avg:87.19ms
+step:593/1680 train_time:51707ms step_avg:87.20ms
+step:594/1680 train_time:51795ms step_avg:87.20ms
+step:595/1680 train_time:51883ms step_avg:87.20ms
+step:596/1680 train_time:51971ms step_avg:87.20ms
+step:597/1680 train_time:52059ms step_avg:87.20ms
+step:598/1680 train_time:52148ms step_avg:87.20ms
+step:599/1680 train_time:52236ms step_avg:87.21ms
+step:600/1680 train_time:52325ms step_avg:87.21ms
+step:601/1680 train_time:52413ms step_avg:87.21ms
+step:602/1680 train_time:52501ms step_avg:87.21ms
+step:603/1680 train_time:52589ms step_avg:87.21ms
+step:604/1680 train_time:52677ms step_avg:87.21ms
+step:605/1680 train_time:52766ms step_avg:87.22ms
+step:606/1680 train_time:52854ms step_avg:87.22ms
+step:607/1680 train_time:52942ms step_avg:87.22ms
+step:608/1680 train_time:53030ms step_avg:87.22ms
+step:609/1680 train_time:53118ms step_avg:87.22ms
+step:610/1680 train_time:53206ms step_avg:87.22ms
+step:611/1680 train_time:53294ms step_avg:87.22ms
+step:612/1680 train_time:53382ms step_avg:87.23ms
+step:613/1680 train_time:53470ms step_avg:87.23ms
+step:614/1680 train_time:53559ms step_avg:87.23ms
+step:615/1680 train_time:53648ms step_avg:87.23ms
+step:616/1680 train_time:53736ms step_avg:87.23ms
+step:617/1680 train_time:53825ms step_avg:87.24ms
+step:618/1680 train_time:53912ms step_avg:87.24ms
+step:619/1680 train_time:54001ms step_avg:87.24ms
+step:620/1680 train_time:54089ms step_avg:87.24ms
+step:621/1680 train_time:54176ms step_avg:87.24ms
+step:622/1680 train_time:54265ms step_avg:87.24ms
+step:623/1680 train_time:54352ms step_avg:87.24ms
+step:624/1680 train_time:54441ms step_avg:87.25ms
+step:625/1680 train_time:54529ms step_avg:87.25ms
+step:625/1680 val_loss:3.6205 train_time:54619ms step_avg:87.39ms
+step:626/1680 train_time:54640ms step_avg:87.28ms
+step:627/1680 train_time:54712ms step_avg:87.26ms
+step:628/1680 train_time:54801ms step_avg:87.26ms
+step:629/1680 train_time:54890ms step_avg:87.27ms
+step:630/1680 train_time:54979ms step_avg:87.27ms
+step:631/1680 train_time:55066ms step_avg:87.27ms
+step:632/1680 train_time:55152ms step_avg:87.27ms
+step:633/1680 train_time:55240ms step_avg:87.27ms
+step:634/1680 train_time:55326ms step_avg:87.27ms
+step:635/1680 train_time:55413ms step_avg:87.27ms
+step:636/1680 train_time:55501ms step_avg:87.27ms
+step:637/1680 train_time:55592ms step_avg:87.27ms
+step:638/1680 train_time:55682ms step_avg:87.28ms
+step:639/1680 train_time:55772ms step_avg:87.28ms
+step:640/1680 train_time:55861ms step_avg:87.28ms
+step:641/1680 train_time:55949ms step_avg:87.28ms
+step:642/1680 train_time:56037ms step_avg:87.28ms
+step:643/1680 train_time:56125ms step_avg:87.29ms
+step:644/1680 train_time:56212ms step_avg:87.29ms
+step:645/1680 train_time:56299ms step_avg:87.29ms
+step:646/1680 train_time:56386ms step_avg:87.29ms
+step:647/1680 train_time:56475ms step_avg:87.29ms
+step:648/1680 train_time:56564ms step_avg:87.29ms
+step:649/1680 train_time:56653ms step_avg:87.29ms
+step:650/1680 train_time:56742ms step_avg:87.29ms
+step:651/1680 train_time:56831ms step_avg:87.30ms
+step:652/1680 train_time:56919ms step_avg:87.30ms
+step:653/1680 train_time:57008ms step_avg:87.30ms
+step:654/1680 train_time:57096ms step_avg:87.30ms
+step:655/1680 train_time:57184ms step_avg:87.30ms
+step:656/1680 train_time:57271ms step_avg:87.30ms
+step:657/1680 train_time:57359ms step_avg:87.30ms
+step:658/1680 train_time:57447ms step_avg:87.31ms
+step:659/1680 train_time:57535ms step_avg:87.31ms
+step:660/1680 train_time:57623ms step_avg:87.31ms
+step:661/1680 train_time:57712ms step_avg:87.31ms
+step:662/1680 train_time:57801ms step_avg:87.31ms
+step:663/1680 train_time:57889ms step_avg:87.31ms
+step:664/1680 train_time:57977ms step_avg:87.32ms
+step:665/1680 train_time:58066ms step_avg:87.32ms
+step:666/1680 train_time:58153ms step_avg:87.32ms
+step:667/1680 train_time:58241ms step_avg:87.32ms
+step:668/1680 train_time:58329ms step_avg:87.32ms
+step:669/1680 train_time:58416ms step_avg:87.32ms
+step:670/1680 train_time:58504ms step_avg:87.32ms
+step:671/1680 train_time:58592ms step_avg:87.32ms
+step:672/1680 train_time:58681ms step_avg:87.32ms
+step:673/1680 train_time:58769ms step_avg:87.32ms
+step:674/1680 train_time:58858ms step_avg:87.33ms
+step:675/1680 train_time:58946ms step_avg:87.33ms
+step:676/1680 train_time:59034ms step_avg:87.33ms
+step:677/1680 train_time:59122ms step_avg:87.33ms
+step:678/1680 train_time:59209ms step_avg:87.33ms
+step:679/1680 train_time:59297ms step_avg:87.33ms
+step:680/1680 train_time:59384ms step_avg:87.33ms
+step:681/1680 train_time:59472ms step_avg:87.33ms
+step:682/1680 train_time:59560ms step_avg:87.33ms
+step:683/1680 train_time:59648ms step_avg:87.33ms
+step:684/1680 train_time:59737ms step_avg:87.33ms
+step:685/1680 train_time:59825ms step_avg:87.34ms
+step:686/1680 train_time:59914ms step_avg:87.34ms
+step:687/1680 train_time:60002ms step_avg:87.34ms
+step:688/1680 train_time:60091ms step_avg:87.34ms
+step:689/1680 train_time:60179ms step_avg:87.34ms
+step:690/1680 train_time:60267ms step_avg:87.34ms
+step:691/1680 train_time:60355ms step_avg:87.34ms
+step:692/1680 train_time:60444ms step_avg:87.35ms
+step:693/1680 train_time:60532ms step_avg:87.35ms
+step:694/1680 train_time:60619ms step_avg:87.35ms
+step:695/1680 train_time:60707ms step_avg:87.35ms
+step:696/1680 train_time:60795ms step_avg:87.35ms
+step:697/1680 train_time:60883ms step_avg:87.35ms
+step:698/1680 train_time:60972ms step_avg:87.35ms
+step:699/1680 train_time:61060ms step_avg:87.35ms
+step:700/1680 train_time:61148ms step_avg:87.35ms
+step:701/1680 train_time:61236ms step_avg:87.36ms
+step:702/1680 train_time:61324ms step_avg:87.36ms
+step:703/1680 train_time:61412ms step_avg:87.36ms
+step:704/1680 train_time:61499ms step_avg:87.36ms
+step:705/1680 train_time:61587ms step_avg:87.36ms
+step:706/1680 train_time:61676ms step_avg:87.36ms
+step:707/1680 train_time:61765ms step_avg:87.36ms
+step:708/1680 train_time:61853ms step_avg:87.36ms
+step:709/1680 train_time:61942ms step_avg:87.36ms
+step:710/1680 train_time:62029ms step_avg:87.37ms
+step:711/1680 train_time:62117ms step_avg:87.37ms
+step:712/1680 train_time:62205ms step_avg:87.37ms
+step:713/1680 train_time:62292ms step_avg:87.37ms
+step:714/1680 train_time:62380ms step_avg:87.37ms
+step:715/1680 train_time:62469ms step_avg:87.37ms
+step:716/1680 train_time:62556ms step_avg:87.37ms
+step:717/1680 train_time:62644ms step_avg:87.37ms
+step:718/1680 train_time:62732ms step_avg:87.37ms
+step:719/1680 train_time:62821ms step_avg:87.37ms
+step:720/1680 train_time:62909ms step_avg:87.37ms
+step:721/1680 train_time:62997ms step_avg:87.37ms
+step:722/1680 train_time:63085ms step_avg:87.38ms
+step:723/1680 train_time:63173ms step_avg:87.38ms
+step:724/1680 train_time:63261ms step_avg:87.38ms
+step:725/1680 train_time:63349ms step_avg:87.38ms
+step:726/1680 train_time:63437ms step_avg:87.38ms
+step:727/1680 train_time:63525ms step_avg:87.38ms
+step:728/1680 train_time:63613ms step_avg:87.38ms
+step:729/1680 train_time:63702ms step_avg:87.38ms
+step:730/1680 train_time:63790ms step_avg:87.38ms
+step:731/1680 train_time:63877ms step_avg:87.38ms
+step:732/1680 train_time:63966ms step_avg:87.38ms
+step:733/1680 train_time:64054ms step_avg:87.39ms
+step:734/1680 train_time:64142ms step_avg:87.39ms
+step:735/1680 train_time:64230ms step_avg:87.39ms
+step:736/1680 train_time:64318ms step_avg:87.39ms
+step:737/1680 train_time:64407ms step_avg:87.39ms
+step:738/1680 train_time:64494ms step_avg:87.39ms
+step:739/1680 train_time:64582ms step_avg:87.39ms
+step:740/1680 train_time:64669ms step_avg:87.39ms
+step:741/1680 train_time:64757ms step_avg:87.39ms
+step:742/1680 train_time:64845ms step_avg:87.39ms
+step:743/1680 train_time:64933ms step_avg:87.39ms
+step:744/1680 train_time:65022ms step_avg:87.40ms
+step:745/1680 train_time:65111ms step_avg:87.40ms
+step:746/1680 train_time:65199ms step_avg:87.40ms
+step:747/1680 train_time:65287ms step_avg:87.40ms
+step:748/1680 train_time:65376ms step_avg:87.40ms
+step:749/1680 train_time:65464ms step_avg:87.40ms
+step:750/1680 train_time:65552ms step_avg:87.40ms
+step:750/1680 val_loss:3.5684 train_time:65642ms step_avg:87.52ms
+step:751/1680 train_time:65662ms step_avg:87.43ms
+step:752/1680 train_time:65733ms step_avg:87.41ms
+step:753/1680 train_time:65824ms step_avg:87.42ms
+step:754/1680 train_time:65913ms step_avg:87.42ms
+step:755/1680 train_time:66000ms step_avg:87.42ms
+step:756/1680 train_time:66087ms step_avg:87.42ms
+step:757/1680 train_time:66174ms step_avg:87.42ms
+step:758/1680 train_time:66261ms step_avg:87.42ms
+step:759/1680 train_time:66348ms step_avg:87.42ms
+step:760/1680 train_time:66435ms step_avg:87.41ms
+step:761/1680 train_time:66523ms step_avg:87.41ms
+step:762/1680 train_time:66611ms step_avg:87.42ms
+step:763/1680 train_time:66700ms step_avg:87.42ms
+step:764/1680 train_time:66791ms step_avg:87.42ms
+step:765/1680 train_time:66880ms step_avg:87.43ms
+step:766/1680 train_time:66969ms step_avg:87.43ms
+step:767/1680 train_time:67057ms step_avg:87.43ms
+step:768/1680 train_time:67145ms step_avg:87.43ms
+step:769/1680 train_time:67232ms step_avg:87.43ms
+step:770/1680 train_time:67320ms step_avg:87.43ms
+step:771/1680 train_time:67408ms step_avg:87.43ms
+step:772/1680 train_time:67495ms step_avg:87.43ms
+step:773/1680 train_time:67583ms step_avg:87.43ms
+step:774/1680 train_time:67672ms step_avg:87.43ms
+step:775/1680 train_time:67761ms step_avg:87.43ms
+step:776/1680 train_time:67850ms step_avg:87.44ms
+step:777/1680 train_time:67939ms step_avg:87.44ms
+step:778/1680 train_time:68027ms step_avg:87.44ms
+step:779/1680 train_time:68115ms step_avg:87.44ms
+step:780/1680 train_time:68203ms step_avg:87.44ms
+step:781/1680 train_time:68290ms step_avg:87.44ms
+step:782/1680 train_time:68378ms step_avg:87.44ms
+step:783/1680 train_time:68465ms step_avg:87.44ms
+step:784/1680 train_time:68554ms step_avg:87.44ms
+step:785/1680 train_time:68642ms step_avg:87.44ms
+step:786/1680 train_time:68731ms step_avg:87.44ms
+step:787/1680 train_time:68820ms step_avg:87.45ms
+step:788/1680 train_time:68909ms step_avg:87.45ms
+step:789/1680 train_time:68998ms step_avg:87.45ms
+step:790/1680 train_time:69086ms step_avg:87.45ms
+step:791/1680 train_time:69174ms step_avg:87.45ms
+step:792/1680 train_time:69261ms step_avg:87.45ms
+step:793/1680 train_time:69349ms step_avg:87.45ms
+step:794/1680 train_time:69437ms step_avg:87.45ms
+step:795/1680 train_time:69525ms step_avg:87.45ms
+step:796/1680 train_time:69612ms step_avg:87.45ms
+step:797/1680 train_time:69700ms step_avg:87.45ms
+step:798/1680 train_time:69789ms step_avg:87.46ms
+step:799/1680 train_time:69879ms step_avg:87.46ms
+step:800/1680 train_time:69967ms step_avg:87.46ms
+step:801/1680 train_time:70055ms step_avg:87.46ms
+step:802/1680 train_time:70143ms step_avg:87.46ms
+step:803/1680 train_time:70231ms step_avg:87.46ms
+step:804/1680 train_time:70319ms step_avg:87.46ms
+step:805/1680 train_time:70407ms step_avg:87.46ms
+step:806/1680 train_time:70495ms step_avg:87.46ms
+step:807/1680 train_time:70583ms step_avg:87.46ms
+step:808/1680 train_time:70672ms step_avg:87.46ms
+step:809/1680 train_time:70761ms step_avg:87.47ms
+step:810/1680 train_time:70849ms step_avg:87.47ms
+step:811/1680 train_time:70938ms step_avg:87.47ms
+step:812/1680 train_time:71026ms step_avg:87.47ms
+step:813/1680 train_time:71115ms step_avg:87.47ms
+step:814/1680 train_time:71203ms step_avg:87.47ms
+step:815/1680 train_time:71291ms step_avg:87.47ms
+step:816/1680 train_time:71379ms step_avg:87.47ms
+step:817/1680 train_time:71466ms step_avg:87.47ms
+step:818/1680 train_time:71555ms step_avg:87.47ms
+step:819/1680 train_time:71642ms step_avg:87.48ms
+step:820/1680 train_time:71730ms step_avg:87.48ms
+step:821/1680 train_time:71818ms step_avg:87.48ms
+step:822/1680 train_time:71906ms step_avg:87.48ms
+step:823/1680 train_time:71994ms step_avg:87.48ms
+step:824/1680 train_time:72082ms step_avg:87.48ms
+step:825/1680 train_time:72171ms step_avg:87.48ms
+step:826/1680 train_time:72260ms step_avg:87.48ms
+step:827/1680 train_time:72348ms step_avg:87.48ms
+step:828/1680 train_time:72435ms step_avg:87.48ms
+step:829/1680 train_time:72524ms step_avg:87.48ms
+step:830/1680 train_time:72612ms step_avg:87.48ms
+step:831/1680 train_time:72700ms step_avg:87.48ms
+step:832/1680 train_time:72788ms step_avg:87.49ms
+step:833/1680 train_time:72876ms step_avg:87.49ms
+step:834/1680 train_time:72964ms step_avg:87.49ms
+step:835/1680 train_time:73052ms step_avg:87.49ms
+step:836/1680 train_time:73141ms step_avg:87.49ms
+step:837/1680 train_time:73229ms step_avg:87.49ms
+step:838/1680 train_time:73317ms step_avg:87.49ms
+step:839/1680 train_time:73405ms step_avg:87.49ms
+step:840/1680 train_time:73493ms step_avg:87.49ms
+step:841/1680 train_time:73581ms step_avg:87.49ms
+step:842/1680 train_time:73668ms step_avg:87.49ms
+step:843/1680 train_time:73756ms step_avg:87.49ms
+step:844/1680 train_time:73844ms step_avg:87.49ms
+step:845/1680 train_time:73933ms step_avg:87.49ms
+step:846/1680 train_time:74021ms step_avg:87.50ms
+step:847/1680 train_time:74109ms step_avg:87.50ms
+step:848/1680 train_time:74197ms step_avg:87.50ms
+step:849/1680 train_time:74284ms step_avg:87.50ms
+step:850/1680 train_time:74372ms step_avg:87.50ms
+step:851/1680 train_time:74461ms step_avg:87.50ms
+step:852/1680 train_time:74548ms step_avg:87.50ms
+step:853/1680 train_time:74637ms step_avg:87.50ms
+step:854/1680 train_time:74725ms step_avg:87.50ms
+step:855/1680 train_time:74812ms step_avg:87.50ms
+step:856/1680 train_time:74900ms step_avg:87.50ms
+step:857/1680 train_time:74989ms step_avg:87.50ms
+step:858/1680 train_time:75077ms step_avg:87.50ms
+step:859/1680 train_time:75164ms step_avg:87.50ms
+step:860/1680 train_time:75253ms step_avg:87.50ms
+step:861/1680 train_time:75341ms step_avg:87.50ms
+step:862/1680 train_time:75429ms step_avg:87.51ms
+step:863/1680 train_time:75518ms step_avg:87.51ms
+step:864/1680 train_time:75606ms step_avg:87.51ms
+step:865/1680 train_time:75694ms step_avg:87.51ms
+step:866/1680 train_time:75783ms step_avg:87.51ms
+step:867/1680 train_time:75870ms step_avg:87.51ms
+step:868/1680 train_time:75959ms step_avg:87.51ms
+step:869/1680 train_time:76047ms step_avg:87.51ms
+step:870/1680 train_time:76135ms step_avg:87.51ms
+step:871/1680 train_time:76223ms step_avg:87.51ms
+step:872/1680 train_time:76311ms step_avg:87.51ms
+step:873/1680 train_time:76399ms step_avg:87.51ms
+step:874/1680 train_time:76487ms step_avg:87.51ms
+step:875/1680 train_time:76575ms step_avg:87.51ms
+step:875/1680 val_loss:3.5211 train_time:76665ms step_avg:87.62ms
+step:876/1680 train_time:76684ms step_avg:87.54ms
+step:877/1680 train_time:76756ms step_avg:87.52ms
+step:878/1680 train_time:76850ms step_avg:87.53ms
+step:879/1680 train_time:76939ms step_avg:87.53ms
+step:880/1680 train_time:77026ms step_avg:87.53ms
+step:881/1680 train_time:77113ms step_avg:87.53ms
+step:882/1680 train_time:77200ms step_avg:87.53ms
+step:883/1680 train_time:77287ms step_avg:87.53ms
+step:884/1680 train_time:77374ms step_avg:87.53ms
+step:885/1680 train_time:77461ms step_avg:87.53ms
+step:886/1680 train_time:77549ms step_avg:87.53ms
+step:887/1680 train_time:77637ms step_avg:87.53ms
+step:888/1680 train_time:77726ms step_avg:87.53ms
+step:889/1680 train_time:77817ms step_avg:87.53ms
+step:890/1680 train_time:77906ms step_avg:87.53ms
+step:891/1680 train_time:77994ms step_avg:87.54ms
+step:892/1680 train_time:78082ms step_avg:87.54ms
+step:893/1680 train_time:78170ms step_avg:87.54ms
+step:894/1680 train_time:78258ms step_avg:87.54ms
+step:895/1680 train_time:78345ms step_avg:87.54ms
+step:896/1680 train_time:78433ms step_avg:87.54ms
+step:897/1680 train_time:78520ms step_avg:87.54ms
+step:898/1680 train_time:78608ms step_avg:87.54ms
+step:899/1680 train_time:78697ms step_avg:87.54ms
+step:900/1680 train_time:78786ms step_avg:87.54ms
+step:901/1680 train_time:78875ms step_avg:87.54ms
+step:902/1680 train_time:78964ms step_avg:87.54ms
+step:903/1680 train_time:79052ms step_avg:87.54ms
+step:904/1680 train_time:79140ms step_avg:87.54ms
+step:905/1680 train_time:79228ms step_avg:87.54ms
+step:906/1680 train_time:79315ms step_avg:87.54ms
+step:907/1680 train_time:79403ms step_avg:87.54ms
+step:908/1680 train_time:79490ms step_avg:87.54ms
+step:909/1680 train_time:79579ms step_avg:87.55ms
+step:910/1680 train_time:79667ms step_avg:87.55ms
+step:911/1680 train_time:79755ms step_avg:87.55ms
+step:912/1680 train_time:79845ms step_avg:87.55ms
+step:913/1680 train_time:79933ms step_avg:87.55ms
+step:914/1680 train_time:80022ms step_avg:87.55ms
+step:915/1680 train_time:80110ms step_avg:87.55ms
+step:916/1680 train_time:80197ms step_avg:87.55ms
+step:917/1680 train_time:80285ms step_avg:87.55ms
+step:918/1680 train_time:80373ms step_avg:87.55ms
+step:919/1680 train_time:80460ms step_avg:87.55ms
+step:920/1680 train_time:80548ms step_avg:87.55ms
+step:921/1680 train_time:80636ms step_avg:87.55ms
+step:922/1680 train_time:80724ms step_avg:87.55ms
+step:923/1680 train_time:80814ms step_avg:87.56ms
+step:924/1680 train_time:80903ms step_avg:87.56ms
+step:925/1680 train_time:80991ms step_avg:87.56ms
+step:926/1680 train_time:81080ms step_avg:87.56ms
+step:927/1680 train_time:81168ms step_avg:87.56ms
+step:928/1680 train_time:81256ms step_avg:87.56ms
+step:929/1680 train_time:81344ms step_avg:87.56ms
+step:930/1680 train_time:81432ms step_avg:87.56ms
+step:931/1680 train_time:81519ms step_avg:87.56ms
+step:932/1680 train_time:81607ms step_avg:87.56ms
+step:933/1680 train_time:81696ms step_avg:87.56ms
+step:934/1680 train_time:81784ms step_avg:87.56ms
+step:935/1680 train_time:81873ms step_avg:87.56ms
+step:936/1680 train_time:81961ms step_avg:87.57ms
+step:937/1680 train_time:82049ms step_avg:87.57ms
+step:938/1680 train_time:82138ms step_avg:87.57ms
+step:939/1680 train_time:82225ms step_avg:87.57ms
+step:940/1680 train_time:82313ms step_avg:87.57ms
+step:941/1680 train_time:82402ms step_avg:87.57ms
+step:942/1680 train_time:82489ms step_avg:87.57ms
+step:943/1680 train_time:82577ms step_avg:87.57ms
+step:944/1680 train_time:82665ms step_avg:87.57ms
+step:945/1680 train_time:82754ms step_avg:87.57ms
+step:946/1680 train_time:82842ms step_avg:87.57ms
+step:947/1680 train_time:82931ms step_avg:87.57ms
+step:948/1680 train_time:83019ms step_avg:87.57ms
+step:949/1680 train_time:83107ms step_avg:87.57ms
+step:950/1680 train_time:83195ms step_avg:87.57ms
+step:951/1680 train_time:83284ms step_avg:87.57ms
+step:952/1680 train_time:83372ms step_avg:87.58ms
+step:953/1680 train_time:83460ms step_avg:87.58ms
+step:954/1680 train_time:83548ms step_avg:87.58ms
+step:955/1680 train_time:83636ms step_avg:87.58ms
+step:956/1680 train_time:83724ms step_avg:87.58ms
+step:957/1680 train_time:83813ms step_avg:87.58ms
+step:958/1680 train_time:83901ms step_avg:87.58ms
+step:959/1680 train_time:83989ms step_avg:87.58ms
+step:960/1680 train_time:84077ms step_avg:87.58ms
+step:961/1680 train_time:84164ms step_avg:87.58ms
+step:962/1680 train_time:84253ms step_avg:87.58ms
+step:963/1680 train_time:84341ms step_avg:87.58ms
+step:964/1680 train_time:84429ms step_avg:87.58ms
+step:965/1680 train_time:84516ms step_avg:87.58ms
+step:966/1680 train_time:84604ms step_avg:87.58ms
+step:967/1680 train_time:84692ms step_avg:87.58ms
+step:968/1680 train_time:84780ms step_avg:87.58ms
+step:969/1680 train_time:84869ms step_avg:87.58ms
+step:970/1680 train_time:84957ms step_avg:87.58ms
+step:971/1680 train_time:85045ms step_avg:87.59ms
+step:972/1680 train_time:85133ms step_avg:87.59ms
+step:973/1680 train_time:85220ms step_avg:87.59ms
+step:974/1680 train_time:85308ms step_avg:87.59ms
+step:975/1680 train_time:85397ms step_avg:87.59ms
+step:976/1680 train_time:85484ms step_avg:87.59ms
+step:977/1680 train_time:85573ms step_avg:87.59ms
+step:978/1680 train_time:85661ms step_avg:87.59ms
+step:979/1680 train_time:85749ms step_avg:87.59ms
+step:980/1680 train_time:85837ms step_avg:87.59ms
+step:981/1680 train_time:85925ms step_avg:87.59ms
+step:982/1680 train_time:86013ms step_avg:87.59ms
+step:983/1680 train_time:86102ms step_avg:87.59ms
+step:984/1680 train_time:86189ms step_avg:87.59ms
+step:985/1680 train_time:86277ms step_avg:87.59ms
+step:986/1680 train_time:86365ms step_avg:87.59ms
+step:987/1680 train_time:86453ms step_avg:87.59ms
+step:988/1680 train_time:86541ms step_avg:87.59ms
+step:989/1680 train_time:86629ms step_avg:87.59ms
+step:990/1680 train_time:86718ms step_avg:87.59ms
+step:991/1680 train_time:86805ms step_avg:87.59ms
+step:992/1680 train_time:86893ms step_avg:87.59ms
+step:993/1680 train_time:86981ms step_avg:87.59ms
+step:994/1680 train_time:87069ms step_avg:87.59ms
+step:995/1680 train_time:87157ms step_avg:87.59ms
+step:996/1680 train_time:87245ms step_avg:87.60ms
+step:997/1680 train_time:87333ms step_avg:87.60ms
+step:998/1680 train_time:87422ms step_avg:87.60ms
+step:999/1680 train_time:87511ms step_avg:87.60ms
+step:1000/1680 train_time:87599ms step_avg:87.60ms
+step:1000/1680 val_loss:3.4705 train_time:87689ms step_avg:87.69ms
+step:1001/1680 train_time:87710ms step_avg:87.62ms
+step:1002/1680 train_time:87779ms step_avg:87.60ms
+step:1003/1680 train_time:87870ms step_avg:87.61ms
+step:1004/1680 train_time:87959ms step_avg:87.61ms
+step:1005/1680 train_time:88046ms step_avg:87.61ms
+step:1006/1680 train_time:88133ms step_avg:87.61ms
+step:1007/1680 train_time:88220ms step_avg:87.61ms
+step:1008/1680 train_time:88307ms step_avg:87.61ms
+step:1009/1680 train_time:88395ms step_avg:87.61ms
+step:1010/1680 train_time:88484ms step_avg:87.61ms
+step:1011/1680 train_time:88571ms step_avg:87.61ms
+step:1012/1680 train_time:88659ms step_avg:87.61ms
+step:1013/1680 train_time:88749ms step_avg:87.61ms
+step:1014/1680 train_time:88838ms step_avg:87.61ms
+step:1015/1680 train_time:88927ms step_avg:87.61ms
+step:1016/1680 train_time:89014ms step_avg:87.61ms
+step:1017/1680 train_time:89102ms step_avg:87.61ms
+step:1018/1680 train_time:89190ms step_avg:87.61ms
+step:1019/1680 train_time:89277ms step_avg:87.61ms
+step:1020/1680 train_time:89365ms step_avg:87.61ms
+step:1021/1680 train_time:89452ms step_avg:87.61ms
+step:1022/1680 train_time:89540ms step_avg:87.61ms
+step:1023/1680 train_time:89628ms step_avg:87.61ms
+step:1024/1680 train_time:89717ms step_avg:87.61ms
+step:1025/1680 train_time:89806ms step_avg:87.62ms
+step:1026/1680 train_time:89895ms step_avg:87.62ms
+step:1027/1680 train_time:89983ms step_avg:87.62ms
+step:1028/1680 train_time:90071ms step_avg:87.62ms
+step:1029/1680 train_time:90159ms step_avg:87.62ms
+step:1030/1680 train_time:90247ms step_avg:87.62ms
+step:1031/1680 train_time:90335ms step_avg:87.62ms
+step:1032/1680 train_time:90424ms step_avg:87.62ms
+step:1033/1680 train_time:90512ms step_avg:87.62ms
+step:1034/1680 train_time:90599ms step_avg:87.62ms
+step:1035/1680 train_time:90687ms step_avg:87.62ms
+step:1036/1680 train_time:90776ms step_avg:87.62ms
+step:1037/1680 train_time:90866ms step_avg:87.62ms
+step:1038/1680 train_time:90954ms step_avg:87.62ms
+step:1039/1680 train_time:91043ms step_avg:87.63ms
+step:1040/1680 train_time:91130ms step_avg:87.63ms
+step:1041/1680 train_time:91219ms step_avg:87.63ms
+step:1042/1680 train_time:91306ms step_avg:87.63ms
+step:1043/1680 train_time:91394ms step_avg:87.63ms
+step:1044/1680 train_time:91482ms step_avg:87.63ms
+step:1045/1680 train_time:91570ms step_avg:87.63ms
+step:1046/1680 train_time:91658ms step_avg:87.63ms
+step:1047/1680 train_time:91746ms step_avg:87.63ms
+step:1048/1680 train_time:91835ms step_avg:87.63ms
+step:1049/1680 train_time:91925ms step_avg:87.63ms
+step:1050/1680 train_time:92013ms step_avg:87.63ms
+step:1051/1680 train_time:92101ms step_avg:87.63ms
+step:1052/1680 train_time:92189ms step_avg:87.63ms
+step:1053/1680 train_time:92276ms step_avg:87.63ms
+step:1054/1680 train_time:92364ms step_avg:87.63ms
+step:1055/1680 train_time:92452ms step_avg:87.63ms
+step:1056/1680 train_time:92540ms step_avg:87.63ms
+step:1057/1680 train_time:92628ms step_avg:87.63ms
+step:1058/1680 train_time:92716ms step_avg:87.63ms
+step:1059/1680 train_time:92805ms step_avg:87.63ms
+step:1060/1680 train_time:92893ms step_avg:87.63ms
+step:1061/1680 train_time:92981ms step_avg:87.64ms
+step:1062/1680 train_time:93069ms step_avg:87.64ms
+step:1063/1680 train_time:93158ms step_avg:87.64ms
+step:1064/1680 train_time:93246ms step_avg:87.64ms
+step:1065/1680 train_time:93334ms step_avg:87.64ms
+step:1066/1680 train_time:93422ms step_avg:87.64ms
+step:1067/1680 train_time:93509ms step_avg:87.64ms
+step:1068/1680 train_time:93597ms step_avg:87.64ms
+step:1069/1680 train_time:93685ms step_avg:87.64ms
+step:1070/1680 train_time:93774ms step_avg:87.64ms
+step:1071/1680 train_time:93863ms step_avg:87.64ms
+step:1072/1680 train_time:93952ms step_avg:87.64ms
+step:1073/1680 train_time:94040ms step_avg:87.64ms
+step:1074/1680 train_time:94128ms step_avg:87.64ms
+step:1075/1680 train_time:94216ms step_avg:87.64ms
+step:1076/1680 train_time:94304ms step_avg:87.64ms
+step:1077/1680 train_time:94391ms step_avg:87.64ms
+step:1078/1680 train_time:94479ms step_avg:87.64ms
+step:1079/1680 train_time:94567ms step_avg:87.64ms
+step:1080/1680 train_time:94656ms step_avg:87.64ms
+step:1081/1680 train_time:94743ms step_avg:87.64ms
+step:1082/1680 train_time:94832ms step_avg:87.65ms
+step:1083/1680 train_time:94920ms step_avg:87.65ms
+step:1084/1680 train_time:95008ms step_avg:87.65ms
+step:1085/1680 train_time:95097ms step_avg:87.65ms
+step:1086/1680 train_time:95185ms step_avg:87.65ms
+step:1087/1680 train_time:95273ms step_avg:87.65ms
+step:1088/1680 train_time:95361ms step_avg:87.65ms
+step:1089/1680 train_time:95449ms step_avg:87.65ms
+step:1090/1680 train_time:95537ms step_avg:87.65ms
+step:1091/1680 train_time:95625ms step_avg:87.65ms
+step:1092/1680 train_time:95713ms step_avg:87.65ms
+step:1093/1680 train_time:95800ms step_avg:87.65ms
+step:1094/1680 train_time:95889ms step_avg:87.65ms
+step:1095/1680 train_time:95978ms step_avg:87.65ms
+step:1096/1680 train_time:96066ms step_avg:87.65ms
+step:1097/1680 train_time:96155ms step_avg:87.65ms
+step:1098/1680 train_time:96244ms step_avg:87.65ms
+step:1099/1680 train_time:96333ms step_avg:87.66ms
+step:1100/1680 train_time:96423ms step_avg:87.66ms
+step:1101/1680 train_time:96511ms step_avg:87.66ms
+step:1102/1680 train_time:96599ms step_avg:87.66ms
+step:1103/1680 train_time:96688ms step_avg:87.66ms
+step:1104/1680 train_time:96776ms step_avg:87.66ms
+step:1105/1680 train_time:96865ms step_avg:87.66ms
+step:1106/1680 train_time:96954ms step_avg:87.66ms
+step:1107/1680 train_time:97044ms step_avg:87.66ms
+step:1108/1680 train_time:97133ms step_avg:87.67ms
+step:1109/1680 train_time:97222ms step_avg:87.67ms
+step:1110/1680 train_time:97311ms step_avg:87.67ms
+step:1111/1680 train_time:97400ms step_avg:87.67ms
+step:1112/1680 train_time:97488ms step_avg:87.67ms
+step:1113/1680 train_time:97577ms step_avg:87.67ms
+step:1114/1680 train_time:97665ms step_avg:87.67ms
+step:1115/1680 train_time:97754ms step_avg:87.67ms
+step:1116/1680 train_time:97843ms step_avg:87.67ms
+step:1117/1680 train_time:97932ms step_avg:87.67ms
+step:1118/1680 train_time:98021ms step_avg:87.68ms
+step:1119/1680 train_time:98109ms step_avg:87.68ms
+step:1120/1680 train_time:98198ms step_avg:87.68ms
+step:1121/1680 train_time:98287ms step_avg:87.68ms
+step:1122/1680 train_time:98376ms step_avg:87.68ms
+step:1123/1680 train_time:98465ms step_avg:87.68ms
+step:1124/1680 train_time:98553ms step_avg:87.68ms
+step:1125/1680 train_time:98642ms step_avg:87.68ms
+step:1125/1680 val_loss:3.4174 train_time:98732ms step_avg:87.76ms
+step:1126/1680 train_time:98752ms step_avg:87.70ms
+step:1127/1680 train_time:98825ms step_avg:87.69ms
+step:1128/1680 train_time:98915ms step_avg:87.69ms
+step:1129/1680 train_time:99007ms step_avg:87.69ms
+step:1130/1680 train_time:99096ms step_avg:87.70ms
+step:1131/1680 train_time:99186ms step_avg:87.70ms
+step:1132/1680 train_time:99273ms step_avg:87.70ms
+step:1133/1680 train_time:99361ms step_avg:87.70ms
+step:1134/1680 train_time:99448ms step_avg:87.70ms
+step:1135/1680 train_time:99536ms step_avg:87.70ms
+step:1136/1680 train_time:99624ms step_avg:87.70ms
+step:1137/1680 train_time:99714ms step_avg:87.70ms
+step:1138/1680 train_time:99805ms step_avg:87.70ms
+step:1139/1680 train_time:99895ms step_avg:87.70ms
+step:1140/1680 train_time:99985ms step_avg:87.71ms
+step:1141/1680 train_time:100073ms step_avg:87.71ms
+step:1142/1680 train_time:100162ms step_avg:87.71ms
+step:1143/1680 train_time:100251ms step_avg:87.71ms
+step:1144/1680 train_time:100339ms step_avg:87.71ms
+step:1145/1680 train_time:100427ms step_avg:87.71ms
+step:1146/1680 train_time:100515ms step_avg:87.71ms
+step:1147/1680 train_time:100603ms step_avg:87.71ms
+step:1148/1680 train_time:100692ms step_avg:87.71ms
+step:1149/1680 train_time:100783ms step_avg:87.71ms
+step:1150/1680 train_time:100873ms step_avg:87.72ms
+step:1151/1680 train_time:100963ms step_avg:87.72ms
+step:1152/1680 train_time:101051ms step_avg:87.72ms
+step:1153/1680 train_time:101140ms step_avg:87.72ms
+step:1154/1680 train_time:101228ms step_avg:87.72ms
+step:1155/1680 train_time:101318ms step_avg:87.72ms
+step:1156/1680 train_time:101406ms step_avg:87.72ms
+step:1157/1680 train_time:101494ms step_avg:87.72ms
+step:1158/1680 train_time:101583ms step_avg:87.72ms
+step:1159/1680 train_time:101671ms step_avg:87.72ms
+step:1160/1680 train_time:101760ms step_avg:87.72ms
+step:1161/1680 train_time:101852ms step_avg:87.73ms
+step:1162/1680 train_time:101941ms step_avg:87.73ms
+step:1163/1680 train_time:102030ms step_avg:87.73ms
+step:1164/1680 train_time:102120ms step_avg:87.73ms
+step:1165/1680 train_time:102209ms step_avg:87.73ms
+step:1166/1680 train_time:102298ms step_avg:87.73ms
+step:1167/1680 train_time:102386ms step_avg:87.73ms
+step:1168/1680 train_time:102476ms step_avg:87.74ms
+step:1169/1680 train_time:102563ms step_avg:87.74ms
+step:1170/1680 train_time:102651ms step_avg:87.74ms
+step:1171/1680 train_time:102740ms step_avg:87.74ms
+step:1172/1680 train_time:102829ms step_avg:87.74ms
+step:1173/1680 train_time:102919ms step_avg:87.74ms
+step:1174/1680 train_time:103008ms step_avg:87.74ms
+step:1175/1680 train_time:103097ms step_avg:87.74ms
+step:1176/1680 train_time:103186ms step_avg:87.74ms
+step:1177/1680 train_time:103275ms step_avg:87.74ms
+step:1178/1680 train_time:103363ms step_avg:87.74ms
+step:1179/1680 train_time:103451ms step_avg:87.74ms
+step:1180/1680 train_time:103540ms step_avg:87.75ms
+step:1181/1680 train_time:103628ms step_avg:87.75ms
+step:1182/1680 train_time:103717ms step_avg:87.75ms
+step:1183/1680 train_time:103806ms step_avg:87.75ms
+step:1184/1680 train_time:103895ms step_avg:87.75ms
+step:1185/1680 train_time:103984ms step_avg:87.75ms
+step:1186/1680 train_time:104073ms step_avg:87.75ms
+step:1187/1680 train_time:104163ms step_avg:87.75ms
+step:1188/1680 train_time:104252ms step_avg:87.75ms
+step:1189/1680 train_time:104341ms step_avg:87.76ms
+step:1190/1680 train_time:104430ms step_avg:87.76ms
+step:1191/1680 train_time:104518ms step_avg:87.76ms
+step:1192/1680 train_time:104607ms step_avg:87.76ms
+step:1193/1680 train_time:104695ms step_avg:87.76ms
+step:1194/1680 train_time:104784ms step_avg:87.76ms
+step:1195/1680 train_time:104873ms step_avg:87.76ms
+step:1196/1680 train_time:104962ms step_avg:87.76ms
+step:1197/1680 train_time:105051ms step_avg:87.76ms
+step:1198/1680 train_time:105141ms step_avg:87.76ms
+step:1199/1680 train_time:105230ms step_avg:87.76ms
+step:1200/1680 train_time:105319ms step_avg:87.77ms
+step:1201/1680 train_time:105407ms step_avg:87.77ms
+step:1202/1680 train_time:105496ms step_avg:87.77ms
+step:1203/1680 train_time:105584ms step_avg:87.77ms
+step:1204/1680 train_time:105674ms step_avg:87.77ms
+step:1205/1680 train_time:105763ms step_avg:87.77ms
+step:1206/1680 train_time:105852ms step_avg:87.77ms
+step:1207/1680 train_time:105942ms step_avg:87.77ms
+step:1208/1680 train_time:106030ms step_avg:87.77ms
+step:1209/1680 train_time:106120ms step_avg:87.78ms
+step:1210/1680 train_time:106209ms step_avg:87.78ms
+step:1211/1680 train_time:106298ms step_avg:87.78ms
+step:1212/1680 train_time:106388ms step_avg:87.78ms
+step:1213/1680 train_time:106477ms step_avg:87.78ms
+step:1214/1680 train_time:106565ms step_avg:87.78ms
+step:1215/1680 train_time:106654ms step_avg:87.78ms
+step:1216/1680 train_time:106742ms step_avg:87.78ms
+step:1217/1680 train_time:106830ms step_avg:87.78ms
+step:1218/1680 train_time:106920ms step_avg:87.78ms
+step:1219/1680 train_time:107010ms step_avg:87.78ms
+step:1220/1680 train_time:107100ms step_avg:87.79ms
+step:1221/1680 train_time:107189ms step_avg:87.79ms
+step:1222/1680 train_time:107279ms step_avg:87.79ms
+step:1223/1680 train_time:107368ms step_avg:87.79ms
+step:1224/1680 train_time:107457ms step_avg:87.79ms
+step:1225/1680 train_time:107546ms step_avg:87.79ms
+step:1226/1680 train_time:107635ms step_avg:87.79ms
+step:1227/1680 train_time:107723ms step_avg:87.79ms
+step:1228/1680 train_time:107812ms step_avg:87.79ms
+step:1229/1680 train_time:107900ms step_avg:87.80ms
+step:1230/1680 train_time:107989ms step_avg:87.80ms
+step:1231/1680 train_time:108078ms step_avg:87.80ms
+step:1232/1680 train_time:108167ms step_avg:87.80ms
+step:1233/1680 train_time:108256ms step_avg:87.80ms
+step:1234/1680 train_time:108345ms step_avg:87.80ms
+step:1235/1680 train_time:108434ms step_avg:87.80ms
+step:1236/1680 train_time:108523ms step_avg:87.80ms
+step:1237/1680 train_time:108612ms step_avg:87.80ms
+step:1238/1680 train_time:108701ms step_avg:87.80ms
+step:1239/1680 train_time:108790ms step_avg:87.81ms
+step:1240/1680 train_time:108879ms step_avg:87.81ms
+step:1241/1680 train_time:108968ms step_avg:87.81ms
+step:1242/1680 train_time:109057ms step_avg:87.81ms
+step:1243/1680 train_time:109145ms step_avg:87.81ms
+step:1244/1680 train_time:109235ms step_avg:87.81ms
+step:1245/1680 train_time:109324ms step_avg:87.81ms
+step:1246/1680 train_time:109413ms step_avg:87.81ms
+step:1247/1680 train_time:109501ms step_avg:87.81ms
+step:1248/1680 train_time:109590ms step_avg:87.81ms
+step:1249/1680 train_time:109680ms step_avg:87.81ms
+step:1250/1680 train_time:109768ms step_avg:87.81ms
+step:1250/1680 val_loss:3.3791 train_time:109859ms step_avg:87.89ms
+step:1251/1680 train_time:109879ms step_avg:87.83ms
+step:1252/1680 train_time:109950ms step_avg:87.82ms
+step:1253/1680 train_time:110041ms step_avg:87.82ms
+step:1254/1680 train_time:110131ms step_avg:87.82ms
+step:1255/1680 train_time:110219ms step_avg:87.82ms
+step:1256/1680 train_time:110306ms step_avg:87.82ms
+step:1257/1680 train_time:110394ms step_avg:87.82ms
+step:1258/1680 train_time:110482ms step_avg:87.82ms
+step:1259/1680 train_time:110570ms step_avg:87.82ms
+step:1260/1680 train_time:110658ms step_avg:87.82ms
+step:1261/1680 train_time:110746ms step_avg:87.82ms
+step:1262/1680 train_time:110838ms step_avg:87.83ms
+step:1263/1680 train_time:110929ms step_avg:87.83ms
+step:1264/1680 train_time:111020ms step_avg:87.83ms
+step:1265/1680 train_time:111110ms step_avg:87.83ms
+step:1266/1680 train_time:111200ms step_avg:87.84ms
+step:1267/1680 train_time:111288ms step_avg:87.84ms
+step:1268/1680 train_time:111376ms step_avg:87.84ms
+step:1269/1680 train_time:111465ms step_avg:87.84ms
+step:1270/1680 train_time:111553ms step_avg:87.84ms
+step:1271/1680 train_time:111641ms step_avg:87.84ms
+step:1272/1680 train_time:111730ms step_avg:87.84ms
+step:1273/1680 train_time:111820ms step_avg:87.84ms
+step:1274/1680 train_time:111909ms step_avg:87.84ms
+step:1275/1680 train_time:112000ms step_avg:87.84ms
+step:1276/1680 train_time:112090ms step_avg:87.84ms
+step:1277/1680 train_time:112180ms step_avg:87.85ms
+step:1278/1680 train_time:112269ms step_avg:87.85ms
+step:1279/1680 train_time:112357ms step_avg:87.85ms
+step:1280/1680 train_time:112445ms step_avg:87.85ms
+step:1281/1680 train_time:112533ms step_avg:87.85ms
+step:1282/1680 train_time:112623ms step_avg:87.85ms
+step:1283/1680 train_time:112711ms step_avg:87.85ms
+step:1284/1680 train_time:112800ms step_avg:87.85ms
+step:1285/1680 train_time:112889ms step_avg:87.85ms
+step:1286/1680 train_time:112979ms step_avg:87.85ms
+step:1287/1680 train_time:113069ms step_avg:87.85ms
+step:1288/1680 train_time:113158ms step_avg:87.86ms
+step:1289/1680 train_time:113248ms step_avg:87.86ms
+step:1290/1680 train_time:113338ms step_avg:87.86ms
+step:1291/1680 train_time:113427ms step_avg:87.86ms
+step:1292/1680 train_time:113515ms step_avg:87.86ms
+step:1293/1680 train_time:113603ms step_avg:87.86ms
+step:1294/1680 train_time:113691ms step_avg:87.86ms
+step:1295/1680 train_time:113780ms step_avg:87.86ms
+step:1296/1680 train_time:113869ms step_avg:87.86ms
+step:1297/1680 train_time:113959ms step_avg:87.86ms
+step:1298/1680 train_time:114049ms step_avg:87.87ms
+step:1299/1680 train_time:114138ms step_avg:87.87ms
+step:1300/1680 train_time:114228ms step_avg:87.87ms
+step:1301/1680 train_time:114317ms step_avg:87.87ms
+step:1302/1680 train_time:114406ms step_avg:87.87ms
+step:1303/1680 train_time:114494ms step_avg:87.87ms
+step:1304/1680 train_time:114582ms step_avg:87.87ms
+step:1305/1680 train_time:114671ms step_avg:87.87ms
+step:1306/1680 train_time:114760ms step_avg:87.87ms
+step:1307/1680 train_time:114848ms step_avg:87.87ms
+step:1308/1680 train_time:114937ms step_avg:87.87ms
+step:1309/1680 train_time:115026ms step_avg:87.87ms
+step:1310/1680 train_time:115116ms step_avg:87.87ms
+step:1311/1680 train_time:115204ms step_avg:87.87ms
+step:1312/1680 train_time:115293ms step_avg:87.88ms
+step:1313/1680 train_time:115381ms step_avg:87.88ms
+step:1314/1680 train_time:115470ms step_avg:87.88ms
+step:1315/1680 train_time:115558ms step_avg:87.88ms
+step:1316/1680 train_time:115647ms step_avg:87.88ms
+step:1317/1680 train_time:115736ms step_avg:87.88ms
+step:1318/1680 train_time:115826ms step_avg:87.88ms
+step:1319/1680 train_time:115915ms step_avg:87.88ms
+step:1320/1680 train_time:116004ms step_avg:87.88ms
+step:1321/1680 train_time:116093ms step_avg:87.88ms
+step:1322/1680 train_time:116182ms step_avg:87.88ms
+step:1323/1680 train_time:116272ms step_avg:87.88ms
+step:1324/1680 train_time:116360ms step_avg:87.89ms
+step:1325/1680 train_time:116448ms step_avg:87.89ms
+step:1326/1680 train_time:116536ms step_avg:87.89ms
+step:1327/1680 train_time:116625ms step_avg:87.89ms
+step:1328/1680 train_time:116714ms step_avg:87.89ms
+step:1329/1680 train_time:116802ms step_avg:87.89ms
+step:1330/1680 train_time:116891ms step_avg:87.89ms
+step:1331/1680 train_time:116980ms step_avg:87.89ms
+step:1332/1680 train_time:117070ms step_avg:87.89ms
+step:1333/1680 train_time:117158ms step_avg:87.89ms
+step:1334/1680 train_time:117248ms step_avg:87.89ms
+step:1335/1680 train_time:117337ms step_avg:87.89ms
+step:1336/1680 train_time:117426ms step_avg:87.89ms
+step:1337/1680 train_time:117515ms step_avg:87.89ms
+step:1338/1680 train_time:117603ms step_avg:87.89ms
+step:1339/1680 train_time:117693ms step_avg:87.90ms
+step:1340/1680 train_time:117782ms step_avg:87.90ms
+step:1341/1680 train_time:117870ms step_avg:87.90ms
+step:1342/1680 train_time:117959ms step_avg:87.90ms
+step:1343/1680 train_time:118048ms step_avg:87.90ms
+step:1344/1680 train_time:118137ms step_avg:87.90ms
+step:1345/1680 train_time:118227ms step_avg:87.90ms
+step:1346/1680 train_time:118315ms step_avg:87.90ms
+step:1347/1680 train_time:118404ms step_avg:87.90ms
+step:1348/1680 train_time:118493ms step_avg:87.90ms
step_avg:87.90ms +step:1349/1680 train_time:118582ms step_avg:87.90ms +step:1350/1680 train_time:118671ms step_avg:87.90ms +step:1351/1680 train_time:118761ms step_avg:87.91ms +step:1352/1680 train_time:118850ms step_avg:87.91ms +step:1353/1680 train_time:118940ms step_avg:87.91ms +step:1354/1680 train_time:119028ms step_avg:87.91ms +step:1355/1680 train_time:119118ms step_avg:87.91ms +step:1356/1680 train_time:119206ms step_avg:87.91ms +step:1357/1680 train_time:119295ms step_avg:87.91ms +step:1358/1680 train_time:119384ms step_avg:87.91ms +step:1359/1680 train_time:119472ms step_avg:87.91ms +step:1360/1680 train_time:119561ms step_avg:87.91ms +step:1361/1680 train_time:119649ms step_avg:87.91ms +step:1362/1680 train_time:119738ms step_avg:87.91ms +step:1363/1680 train_time:119827ms step_avg:87.91ms +step:1364/1680 train_time:119916ms step_avg:87.91ms +step:1365/1680 train_time:120005ms step_avg:87.92ms +step:1366/1680 train_time:120093ms step_avg:87.92ms +step:1367/1680 train_time:120182ms step_avg:87.92ms +step:1368/1680 train_time:120271ms step_avg:87.92ms +step:1369/1680 train_time:120360ms step_avg:87.92ms +step:1370/1680 train_time:120449ms step_avg:87.92ms +step:1371/1680 train_time:120538ms step_avg:87.92ms +step:1372/1680 train_time:120627ms step_avg:87.92ms +step:1373/1680 train_time:120716ms step_avg:87.92ms +step:1374/1680 train_time:120805ms step_avg:87.92ms +step:1375/1680 train_time:120894ms step_avg:87.92ms +step:1375/1680 val_loss:3.3440 train_time:120985ms step_avg:87.99ms +step:1376/1680 train_time:121006ms step_avg:87.94ms +step:1377/1680 train_time:121077ms step_avg:87.93ms +step:1378/1680 train_time:121167ms step_avg:87.93ms +step:1379/1680 train_time:121256ms step_avg:87.93ms +step:1380/1680 train_time:121344ms step_avg:87.93ms +step:1381/1680 train_time:121432ms step_avg:87.93ms +step:1382/1680 train_time:121519ms step_avg:87.93ms +step:1383/1680 train_time:121608ms step_avg:87.93ms +step:1384/1680 train_time:121696ms step_avg:87.93ms +step:1385/1680 train_time:121785ms step_avg:87.93ms +step:1386/1680 train_time:121874ms step_avg:87.93ms +step:1387/1680 train_time:121965ms step_avg:87.93ms +step:1388/1680 train_time:122056ms step_avg:87.94ms +step:1389/1680 train_time:122147ms step_avg:87.94ms +step:1390/1680 train_time:122238ms step_avg:87.94ms +step:1391/1680 train_time:122326ms step_avg:87.94ms +step:1392/1680 train_time:122414ms step_avg:87.94ms +step:1393/1680 train_time:122503ms step_avg:87.94ms +step:1394/1680 train_time:122591ms step_avg:87.94ms +step:1395/1680 train_time:122679ms step_avg:87.94ms +step:1396/1680 train_time:122767ms step_avg:87.94ms +step:1397/1680 train_time:122855ms step_avg:87.94ms +step:1398/1680 train_time:122944ms step_avg:87.94ms +step:1399/1680 train_time:123034ms step_avg:87.94ms +step:1400/1680 train_time:123125ms step_avg:87.95ms +step:1401/1680 train_time:123216ms step_avg:87.95ms +step:1402/1680 train_time:123305ms step_avg:87.95ms +step:1403/1680 train_time:123394ms step_avg:87.95ms +step:1404/1680 train_time:123483ms step_avg:87.95ms +step:1405/1680 train_time:123571ms step_avg:87.95ms +step:1406/1680 train_time:123659ms step_avg:87.95ms +step:1407/1680 train_time:123747ms step_avg:87.95ms +step:1408/1680 train_time:123835ms step_avg:87.95ms +step:1409/1680 train_time:123925ms step_avg:87.95ms +step:1410/1680 train_time:124015ms step_avg:87.95ms +step:1411/1680 train_time:124105ms step_avg:87.96ms +step:1412/1680 train_time:124195ms step_avg:87.96ms +step:1413/1680 train_time:124284ms step_avg:87.96ms +step:1414/1680 
train_time:124374ms step_avg:87.96ms +step:1415/1680 train_time:124463ms step_avg:87.96ms +step:1416/1680 train_time:124551ms step_avg:87.96ms +step:1417/1680 train_time:124639ms step_avg:87.96ms +step:1418/1680 train_time:124728ms step_avg:87.96ms +step:1419/1680 train_time:124816ms step_avg:87.96ms +step:1420/1680 train_time:124905ms step_avg:87.96ms +step:1421/1680 train_time:124995ms step_avg:87.96ms +step:1422/1680 train_time:125084ms step_avg:87.96ms +step:1423/1680 train_time:125174ms step_avg:87.96ms +step:1424/1680 train_time:125263ms step_avg:87.97ms +step:1425/1680 train_time:125353ms step_avg:87.97ms +step:1426/1680 train_time:125442ms step_avg:87.97ms +step:1427/1680 train_time:125530ms step_avg:87.97ms +step:1428/1680 train_time:125620ms step_avg:87.97ms +step:1429/1680 train_time:125708ms step_avg:87.97ms +step:1430/1680 train_time:125797ms step_avg:87.97ms +step:1431/1680 train_time:125885ms step_avg:87.97ms +step:1432/1680 train_time:125975ms step_avg:87.97ms +step:1433/1680 train_time:126064ms step_avg:87.97ms +step:1434/1680 train_time:126153ms step_avg:87.97ms +step:1435/1680 train_time:126242ms step_avg:87.97ms +step:1436/1680 train_time:126332ms step_avg:87.97ms +step:1437/1680 train_time:126420ms step_avg:87.98ms +step:1438/1680 train_time:126509ms step_avg:87.98ms +step:1439/1680 train_time:126598ms step_avg:87.98ms +step:1440/1680 train_time:126687ms step_avg:87.98ms +step:1441/1680 train_time:126776ms step_avg:87.98ms +step:1442/1680 train_time:126865ms step_avg:87.98ms +step:1443/1680 train_time:126954ms step_avg:87.98ms +step:1444/1680 train_time:127043ms step_avg:87.98ms +step:1445/1680 train_time:127133ms step_avg:87.98ms +step:1446/1680 train_time:127222ms step_avg:87.98ms +step:1447/1680 train_time:127311ms step_avg:87.98ms +step:1448/1680 train_time:127400ms step_avg:87.98ms +step:1449/1680 train_time:127489ms step_avg:87.98ms +step:1450/1680 train_time:127578ms step_avg:87.99ms +step:1451/1680 train_time:127667ms step_avg:87.99ms +step:1452/1680 train_time:127756ms step_avg:87.99ms +step:1453/1680 train_time:127845ms step_avg:87.99ms +step:1454/1680 train_time:127933ms step_avg:87.99ms +step:1455/1680 train_time:128022ms step_avg:87.99ms +step:1456/1680 train_time:128112ms step_avg:87.99ms +step:1457/1680 train_time:128201ms step_avg:87.99ms +step:1458/1680 train_time:128291ms step_avg:87.99ms +step:1459/1680 train_time:128380ms step_avg:87.99ms +step:1460/1680 train_time:128468ms step_avg:87.99ms +step:1461/1680 train_time:128557ms step_avg:87.99ms +step:1462/1680 train_time:128646ms step_avg:87.99ms +step:1463/1680 train_time:128734ms step_avg:87.99ms +step:1464/1680 train_time:128823ms step_avg:87.99ms +step:1465/1680 train_time:128912ms step_avg:87.99ms +step:1466/1680 train_time:129002ms step_avg:88.00ms +step:1467/1680 train_time:129091ms step_avg:88.00ms +step:1468/1680 train_time:129180ms step_avg:88.00ms +step:1469/1680 train_time:129270ms step_avg:88.00ms +step:1470/1680 train_time:129359ms step_avg:88.00ms +step:1471/1680 train_time:129448ms step_avg:88.00ms +step:1472/1680 train_time:129536ms step_avg:88.00ms +step:1473/1680 train_time:129625ms step_avg:88.00ms +step:1474/1680 train_time:129714ms step_avg:88.00ms +step:1475/1680 train_time:129803ms step_avg:88.00ms +step:1476/1680 train_time:129892ms step_avg:88.00ms +step:1477/1680 train_time:129982ms step_avg:88.00ms +step:1478/1680 train_time:130070ms step_avg:88.00ms +step:1479/1680 train_time:130159ms step_avg:88.00ms +step:1480/1680 train_time:130249ms step_avg:88.01ms +step:1481/1680 
train_time:130338ms step_avg:88.01ms +step:1482/1680 train_time:130428ms step_avg:88.01ms +step:1483/1680 train_time:130517ms step_avg:88.01ms +step:1484/1680 train_time:130606ms step_avg:88.01ms +step:1485/1680 train_time:130695ms step_avg:88.01ms +step:1486/1680 train_time:130783ms step_avg:88.01ms +step:1487/1680 train_time:130872ms step_avg:88.01ms +step:1488/1680 train_time:130961ms step_avg:88.01ms +step:1489/1680 train_time:131049ms step_avg:88.01ms +step:1490/1680 train_time:131138ms step_avg:88.01ms +step:1491/1680 train_time:131227ms step_avg:88.01ms +step:1492/1680 train_time:131316ms step_avg:88.01ms +step:1493/1680 train_time:131406ms step_avg:88.01ms +step:1494/1680 train_time:131495ms step_avg:88.02ms +step:1495/1680 train_time:131584ms step_avg:88.02ms +step:1496/1680 train_time:131673ms step_avg:88.02ms +step:1497/1680 train_time:131761ms step_avg:88.02ms +step:1498/1680 train_time:131849ms step_avg:88.02ms +step:1499/1680 train_time:131939ms step_avg:88.02ms +step:1500/1680 train_time:132028ms step_avg:88.02ms +step:1500/1680 val_loss:3.3142 train_time:132118ms step_avg:88.08ms +step:1501/1680 train_time:132138ms step_avg:88.03ms +step:1502/1680 train_time:132211ms step_avg:88.02ms +step:1503/1680 train_time:132301ms step_avg:88.02ms +step:1504/1680 train_time:132390ms step_avg:88.03ms +step:1505/1680 train_time:132478ms step_avg:88.03ms +step:1506/1680 train_time:132567ms step_avg:88.03ms +step:1507/1680 train_time:132655ms step_avg:88.03ms +step:1508/1680 train_time:132744ms step_avg:88.03ms +step:1509/1680 train_time:132832ms step_avg:88.03ms +step:1510/1680 train_time:132921ms step_avg:88.03ms +step:1511/1680 train_time:133009ms step_avg:88.03ms +step:1512/1680 train_time:133099ms step_avg:88.03ms +step:1513/1680 train_time:133189ms step_avg:88.03ms +step:1514/1680 train_time:133279ms step_avg:88.03ms +step:1515/1680 train_time:133369ms step_avg:88.03ms +step:1516/1680 train_time:133458ms step_avg:88.03ms +step:1517/1680 train_time:133547ms step_avg:88.03ms +step:1518/1680 train_time:133635ms step_avg:88.03ms +step:1519/1680 train_time:133723ms step_avg:88.03ms +step:1520/1680 train_time:133812ms step_avg:88.03ms +step:1521/1680 train_time:133900ms step_avg:88.03ms +step:1522/1680 train_time:133989ms step_avg:88.03ms +step:1523/1680 train_time:134078ms step_avg:88.04ms +step:1524/1680 train_time:134167ms step_avg:88.04ms +step:1525/1680 train_time:134257ms step_avg:88.04ms +step:1526/1680 train_time:134346ms step_avg:88.04ms +step:1527/1680 train_time:134435ms step_avg:88.04ms +step:1528/1680 train_time:134524ms step_avg:88.04ms +step:1529/1680 train_time:134613ms step_avg:88.04ms +step:1530/1680 train_time:134701ms step_avg:88.04ms +step:1531/1680 train_time:134790ms step_avg:88.04ms +step:1532/1680 train_time:134878ms step_avg:88.04ms +step:1533/1680 train_time:134966ms step_avg:88.04ms +step:1534/1680 train_time:135056ms step_avg:88.04ms +step:1535/1680 train_time:135145ms step_avg:88.04ms +step:1536/1680 train_time:135235ms step_avg:88.04ms +step:1537/1680 train_time:135324ms step_avg:88.04ms +step:1538/1680 train_time:135414ms step_avg:88.05ms +step:1539/1680 train_time:135502ms step_avg:88.05ms +step:1540/1680 train_time:135591ms step_avg:88.05ms +step:1541/1680 train_time:135680ms step_avg:88.05ms +step:1542/1680 train_time:135769ms step_avg:88.05ms +step:1543/1680 train_time:135857ms step_avg:88.05ms +step:1544/1680 train_time:135946ms step_avg:88.05ms +step:1545/1680 train_time:136035ms step_avg:88.05ms +step:1546/1680 train_time:136124ms step_avg:88.05ms 
+step:1547/1680 train_time:136214ms step_avg:88.05ms +step:1548/1680 train_time:136303ms step_avg:88.05ms +step:1549/1680 train_time:136391ms step_avg:88.05ms +step:1550/1680 train_time:136480ms step_avg:88.05ms +step:1551/1680 train_time:136569ms step_avg:88.05ms +step:1552/1680 train_time:136658ms step_avg:88.05ms +step:1553/1680 train_time:136746ms step_avg:88.05ms +step:1554/1680 train_time:136835ms step_avg:88.05ms +step:1555/1680 train_time:136924ms step_avg:88.05ms +step:1556/1680 train_time:137013ms step_avg:88.05ms +step:1557/1680 train_time:137102ms step_avg:88.06ms +step:1558/1680 train_time:137190ms step_avg:88.06ms +step:1559/1680 train_time:137279ms step_avg:88.06ms +step:1560/1680 train_time:137367ms step_avg:88.06ms +step:1561/1680 train_time:137457ms step_avg:88.06ms +step:1562/1680 train_time:137546ms step_avg:88.06ms +step:1563/1680 train_time:137635ms step_avg:88.06ms +step:1564/1680 train_time:137723ms step_avg:88.06ms +step:1565/1680 train_time:137811ms step_avg:88.06ms +step:1566/1680 train_time:137900ms step_avg:88.06ms +step:1567/1680 train_time:137990ms step_avg:88.06ms +step:1568/1680 train_time:138078ms step_avg:88.06ms +step:1569/1680 train_time:138167ms step_avg:88.06ms +step:1570/1680 train_time:138257ms step_avg:88.06ms +step:1571/1680 train_time:138345ms step_avg:88.06ms +step:1572/1680 train_time:138435ms step_avg:88.06ms +step:1573/1680 train_time:138524ms step_avg:88.06ms +step:1574/1680 train_time:138614ms step_avg:88.06ms +step:1575/1680 train_time:138703ms step_avg:88.07ms +step:1576/1680 train_time:138791ms step_avg:88.07ms +step:1577/1680 train_time:138881ms step_avg:88.07ms +step:1578/1680 train_time:138970ms step_avg:88.07ms +step:1579/1680 train_time:139058ms step_avg:88.07ms +step:1580/1680 train_time:139147ms step_avg:88.07ms +step:1581/1680 train_time:139236ms step_avg:88.07ms +step:1582/1680 train_time:139326ms step_avg:88.07ms +step:1583/1680 train_time:139415ms step_avg:88.07ms +step:1584/1680 train_time:139504ms step_avg:88.07ms +step:1585/1680 train_time:139594ms step_avg:88.07ms +step:1586/1680 train_time:139682ms step_avg:88.07ms +step:1587/1680 train_time:139771ms step_avg:88.07ms +step:1588/1680 train_time:139861ms step_avg:88.07ms +step:1589/1680 train_time:139950ms step_avg:88.07ms +step:1590/1680 train_time:140038ms step_avg:88.07ms +step:1591/1680 train_time:140127ms step_avg:88.07ms +step:1592/1680 train_time:140216ms step_avg:88.08ms +step:1593/1680 train_time:140305ms step_avg:88.08ms +step:1594/1680 train_time:140393ms step_avg:88.08ms +step:1595/1680 train_time:140482ms step_avg:88.08ms +step:1596/1680 train_time:140571ms step_avg:88.08ms +step:1597/1680 train_time:140660ms step_avg:88.08ms +step:1598/1680 train_time:140749ms step_avg:88.08ms +step:1599/1680 train_time:140838ms step_avg:88.08ms +step:1600/1680 train_time:140927ms step_avg:88.08ms +step:1601/1680 train_time:141017ms step_avg:88.08ms +step:1602/1680 train_time:141105ms step_avg:88.08ms +step:1603/1680 train_time:141194ms step_avg:88.08ms +step:1604/1680 train_time:141283ms step_avg:88.08ms +step:1605/1680 train_time:141371ms step_avg:88.08ms +step:1606/1680 train_time:141461ms step_avg:88.08ms +step:1607/1680 train_time:141550ms step_avg:88.08ms +step:1608/1680 train_time:141639ms step_avg:88.08ms +step:1609/1680 train_time:141727ms step_avg:88.08ms +step:1610/1680 train_time:141816ms step_avg:88.08ms +step:1611/1680 train_time:141906ms step_avg:88.09ms +step:1612/1680 train_time:141996ms step_avg:88.09ms +step:1613/1680 train_time:142086ms step_avg:88.09ms 
+step:1614/1680 train_time:142175ms step_avg:88.09ms +step:1615/1680 train_time:142263ms step_avg:88.09ms +step:1616/1680 train_time:142353ms step_avg:88.09ms +step:1617/1680 train_time:142443ms step_avg:88.09ms +step:1618/1680 train_time:142532ms step_avg:88.09ms +step:1619/1680 train_time:142621ms step_avg:88.09ms +step:1620/1680 train_time:142709ms step_avg:88.09ms +step:1621/1680 train_time:142799ms step_avg:88.09ms +step:1622/1680 train_time:142887ms step_avg:88.09ms +step:1623/1680 train_time:142977ms step_avg:88.09ms +step:1624/1680 train_time:143066ms step_avg:88.09ms +step:1625/1680 train_time:143154ms step_avg:88.10ms +step:1625/1680 val_loss:3.2907 train_time:143244ms step_avg:88.15ms +step:1626/1680 train_time:143264ms step_avg:88.11ms +step:1627/1680 train_time:143335ms step_avg:88.10ms +step:1628/1680 train_time:143426ms step_avg:88.10ms +step:1629/1680 train_time:143516ms step_avg:88.10ms +step:1630/1680 train_time:143604ms step_avg:88.10ms +step:1631/1680 train_time:143692ms step_avg:88.10ms +step:1632/1680 train_time:143779ms step_avg:88.10ms +step:1633/1680 train_time:143867ms step_avg:88.10ms +step:1634/1680 train_time:143955ms step_avg:88.10ms +step:1635/1680 train_time:144044ms step_avg:88.10ms +step:1636/1680 train_time:144134ms step_avg:88.10ms +step:1637/1680 train_time:144224ms step_avg:88.10ms +step:1638/1680 train_time:144314ms step_avg:88.10ms +step:1639/1680 train_time:144405ms step_avg:88.11ms +step:1640/1680 train_time:144495ms step_avg:88.11ms +step:1641/1680 train_time:144584ms step_avg:88.11ms +step:1642/1680 train_time:144672ms step_avg:88.11ms +step:1643/1680 train_time:144761ms step_avg:88.11ms +step:1644/1680 train_time:144849ms step_avg:88.11ms +step:1645/1680 train_time:144937ms step_avg:88.11ms +step:1646/1680 train_time:145025ms step_avg:88.11ms +step:1647/1680 train_time:145113ms step_avg:88.11ms +step:1648/1680 train_time:145203ms step_avg:88.11ms +step:1649/1680 train_time:145292ms step_avg:88.11ms +step:1650/1680 train_time:145383ms step_avg:88.11ms +step:1651/1680 train_time:145473ms step_avg:88.11ms +step:1652/1680 train_time:145562ms step_avg:88.11ms +step:1653/1680 train_time:145651ms step_avg:88.11ms +step:1654/1680 train_time:145739ms step_avg:88.11ms +step:1655/1680 train_time:145827ms step_avg:88.11ms +step:1656/1680 train_time:145916ms step_avg:88.11ms +step:1657/1680 train_time:146004ms step_avg:88.11ms +step:1658/1680 train_time:146093ms step_avg:88.11ms +step:1659/1680 train_time:146183ms step_avg:88.12ms +step:1660/1680 train_time:146273ms step_avg:88.12ms +step:1661/1680 train_time:146365ms step_avg:88.12ms +step:1662/1680 train_time:146455ms step_avg:88.12ms +step:1663/1680 train_time:146545ms step_avg:88.12ms +step:1664/1680 train_time:146634ms step_avg:88.12ms +step:1665/1680 train_time:146722ms step_avg:88.12ms +step:1666/1680 train_time:146811ms step_avg:88.12ms +step:1667/1680 train_time:146899ms step_avg:88.12ms +step:1668/1680 train_time:146988ms step_avg:88.12ms +step:1669/1680 train_time:147076ms step_avg:88.12ms +step:1670/1680 train_time:147165ms step_avg:88.12ms +step:1671/1680 train_time:147254ms step_avg:88.12ms +step:1672/1680 train_time:147343ms step_avg:88.12ms +step:1673/1680 train_time:147433ms step_avg:88.12ms +step:1674/1680 train_time:147522ms step_avg:88.13ms +step:1675/1680 train_time:147611ms step_avg:88.13ms +step:1676/1680 train_time:147700ms step_avg:88.13ms +step:1677/1680 train_time:147788ms step_avg:88.13ms +step:1678/1680 train_time:147877ms step_avg:88.13ms +step:1679/1680 train_time:147966ms 
step_avg:88.13ms +step:1680/1680 train_time:148054ms step_avg:88.13ms +step:1680/1680 val_loss:3.2800 train_time:148144ms step_avg:88.18ms +peak memory allocated: 30760 MiB reserved: 46054 MiB diff --git a/records/092725_BF16CE/559a562e-aaa1-46ef-aa6a-06f46a3b019d.txt b/records/092725_BF16CE/559a562e-aaa1-46ef-aa6a-06f46a3b019d.txt new file mode 100644 index 000000000..0cb6f6514 --- /dev/null +++ b/records/092725_BF16CE/559a562e-aaa1-46ef-aa6a-06f46a3b019d.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
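+    # (C = A @ A.T is symmetric, so blocks on one side of the diagonal were
+    # skipped above; each computed tile is written twice, once in place and
+    # once mirrored across the diagonal.)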
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+    Though empirically, small 1D params perform efficiently here:
+    NS approximately performs a magnitude normalization of the grad, and
+    this hyper-optimized class has faster execution time than the current impl of Adam for small params.
+
+    Custom distributed sizing:
+    The model stores all attn and mlp weights in the same shape, and then updates the view as
+    needed on the forward pass. This enables attn and mlp weights to be contained within the same
+    dist.reduce_scatter_tensor() call. The model architecture has been customized to enable
+    (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn.
+    The scheduling is:
+    1. reduce scatter smear_gate (1 param, 7 padding params)
+    2. reduce scatter attn_gate (10 params, 6 padding params)
+    3. reduce scatter attn/mlp round 1 (10 attn params, 6 mlp params)
+    4. reduce scatter attn/mlp round 2 (16 mlp params)
+    5. wait on step 1, then compute NS of 1 and schedule all gather
+    6. wait on step 2, then compute NS of 2 and schedule all gather
+    7. wait on step 3, then compute NS of 3 and schedule all gather
+       GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP]
+       GPUs that receive attn params reshape them before NS
+    8. wait on step 4, then compute NS of 4 and schedule all gather
+    9. wait for each all gather to complete and update params
+    Empirically, leading with small params provides an additional 0.2s improvement.
+    """
+    def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
+        # custom sizing requires 8 GPUs
+        if custom_sizing and dist.get_world_size()==8:
+            param_groups = self.generate_custom_param_groups(params)
+        else:
+            param_groups = self.generate_standard_param_groups(params)
+        super().__init__(param_groups, defaults)
+
+    def generate_standard_param_groups(self, params):
+        """
+        Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules.
+        Creates one param group per size, while giving attn its own param group for the resize op.
+        """
+        params = list(params)
+        param_groups = []
+        attn_subset = [p for p in params if p.module == 'attn']
+        non_attn_subset = [p for p in params if p.module != 'attn']
+        param_groups.append(dict(params=attn_subset))
+
+        sizes = {p.shape for p in non_attn_subset}
+        for size in sizes:
+            group_params = [p for p in non_attn_subset if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        return param_groups
+
+    def generate_custom_param_groups(self, params):
+        """
+        Implementation requires that a single GPU does not receive both attn
+        and mlp params when a param group is split across GPUs.
+        """
+        module_ranks = {
+            'smear_gate': 1, # 1 param
+            'attn_gate': 2, # 10 params
+            'attn': 3, # 10 params
+            'mlp': 4, # 22 params
+        }
+        params = list(params)
+        params.sort(key=lambda x: module_ranks.get(x.module))
+        idx = 0
+        group_sizes = [1,10,16,16]
+        assert len(params)==sum(group_sizes)
+        param_groups = []
+        for size in group_sizes:
+            group_params = params[idx:idx+size]
+            param_groups.append(dict(params=group_params))
+            idx += size
+        return param_groups
+
+    @torch.no_grad()
+    def step(self):
+        # Efficient systems-wise implementation of step developed by @YouJiacheng,
+        # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad,
+        # @ryanyang0, and @vagrawal.
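+        # Worked example of the sharding arithmetic below (hypothetical numbers,
+        # following the docstring's schedule): with world_size=8 and the 10-param
+        # attn_gate group, padded_num_params = (10 + 8 - 1) // 8 * 8 = 16 and
+        # chunk_size = 2, so rank r owns stacked gradients [2*r : 2*r + 2] after
+        # the reduce_scatter; the 6 padding entries are zero gradients that are
+        # never copied back into a real parameter.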
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
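+                # (Note: the copy below targets the gather buffer, not p itself;
+                # p is only overwritten after the all_gather in the final pass.)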
+                updated_param_chunk[i].copy_(p)
+
+                # Apply weight decay directly to the buffer.
+                updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+            original_shape = batched_update_grads.shape
+            # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O
+            module_idx = start_idx if start_idx < len(params) else 0
+            if getattr(params[module_idx], "module", None) == 'attn':
+                batched_update_grads = batched_update_grads.view(-1, original_shape[-2], original_shape[-1] // 4)
+            v_chunk = newton_schulz_triton(batched_update_grads)
+            v_chunk = v_chunk.view(original_shape)
+
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+                if param_idx >= len(params): # Skip padded entries again.
+                    continue
+
+                # Add the computed zeropower update to the parameter in the buffer.
+                updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val)
+
+            stacked_params = torch.empty(
+                (info["padded_num_params"], *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+            gather_future = dist.all_gather_into_tensor(
+                stacked_params, updated_param_chunk, async_op=True
+            ).get_future()
+
+            all_gather_infos.append(
+                {
+                    "gather_future": gather_future,
+                    "stacked_params": stacked_params,
+                    "orig_params": params,
+                }
+            )
+
+        # Final pass: wait for all_gather to complete and copy results back into original parameter tensors.
+        for info in all_gather_infos:
+            info["gather_future"].wait()
+            stacked_params = info["stacked_params"]
+            orig_params = info["orig_params"]
+
+            unstacked_params = torch.unbind(stacked_params)
+            for i, p in enumerate(orig_params):
+                p.copy_(unstacked_params[i], non_blocking=True)
+
+
+class DistAdam(torch.optim.Optimizer):
+    def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        params = list(params)
+        sizes = {p.shape for p in params}
+        # create one buffer per unique parameter-size
+        param_groups = []
+        for size in sizes:
+            group_params = [p for p in params if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        super().__init__(param_groups, defaults)
+        # DistributedAdam implementation by @vagrawal
+
+    @torch.compile
+    @torch.no_grad()
+    def step(self):
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        reduce_scatter_futures: list[torch.Future] = []
+        all_gather_futures: list[torch.Future] = []
+        grad_slices = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            for base_i in range(len(params)):
+                grad = params[base_i].grad
+                rank_size = grad.shape[0] // world_size
+                grad_slice = torch.empty_like(grad[:rank_size])
+                reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
+                grad_slices.append(grad_slice)
+
+        idx = 0
+        for group in self.param_groups:
+            beta1, beta2 = group['betas']
+            eps = group['eps']
+            wd = group['weight_decay']
+            params = group['params']
+            for base in range(len(params)):
+                reduce_scatter_futures[idx].wait()
+                p = params[base]
+                rank_size = p.shape[0] // world_size
+                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
+                lr = group['lr'] * getattr(p, "lr_mul", 1.0)
+                state = self.state[p]
+                g_slice = grad_slices[idx]
+                # State init
+                if not state:
+                    state["step"] = torch.tensor(
+                        0, dtype=torch.int64, device=p.device
+                    )
+                    state["exp_avg"] = torch.zeros(
+                        p_slice.shape,
+                        dtype=torch.bfloat16,
+                        device=p_slice.device,
+                    )
+                    state["exp_avg_sq"] = torch.zeros(
+                        p_slice.shape,
+                        dtype=torch.bfloat16,
+                        device=p_slice.device,
+                    )
+                exp_avg = state["exp_avg"]
+                exp_avg_sq =
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module='attn' + with torch.no_grad(): + self.qkvo_w.view(4,self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4,self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4,self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4,self.hdim, self.dim)[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module='mlp' + self.c_proj.module='mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
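+        # (Sketch of the fp8 scale choice below: 448 is the max normal value of
+        # float8_e4m3fn, so x_s, w_s, and grad_s rescale activations, weights, and
+        # incoming grads to fit fp8 dynamic range before the casts inside nanogpt::mm.)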
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the 
data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter==5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter+=1 + return starts, ends + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * 
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a BOS token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for the unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 tokens
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted; load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # trim the last sequence by one token to account for the _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens // new_grad_accum_steps
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
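+# A minimal usage sketch (added for illustration, not part of the original record
+# run): the loader is a generator, so next() yields one micro-batch of
+# (inputs, targets, cum_lengths), and send() retargets
+# (num_tokens, max_seq_len, grad_accum_steps) without rebuilding the loader.
+# The _demo_loader_protocol name and the numbers passed to send() are made up,
+# and the sketch assumes a single process (dist is not yet initialized here)
+# with the fineweb shards present on disk.
+def _demo_loader_protocol():
+    loader = distributed_data_generator("data/fineweb10B/fineweb_val_*.bin", 8 * 64 * 1024, 128 * 16)
+    inputs, targets, cum_lengths = next(loader)  # primes the generator and yields batch 0
+    inputs, targets, cum_lengths = loader.send((8 * 128 * 1024, 128 * 32, 1))  # applies from this batch onward
+    return inputs.shape, targets.shape, cum_lengths.shape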
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40 # number of iterations to continue training at the final cooldown and window size
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increased ws for final validation, used for YaRN extension and the short window size @classiclarryd
+    ws_long_validate: int = 20 # extend long windows out even further
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
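+# Worked numbers (added for illustration, not part of the original record run):
+# the batch-size arithmetic implied by the hyperparameters above on the 8xH100
+# target, where world_size == 8 and grad_accum_steps == 1. The _batch_arithmetic
+# name is hypothetical.
+def _batch_arithmetic():
+    tokens_per_step = args.train_batch_size // grad_accum_steps   # 2048 * 24 * 8 = 393,216 tokens per optimizer step
+    tokens_per_rank = tokens_per_step // world_size               # 49,152 tokens per GPU per forward/backward
+    val_tokens_per_step = args.val_batch_size // grad_accum_steps # 4 * 64 * 1024 * 8 = 2,097,152 tokens
+    assert args.val_tokens % args.val_batch_size == 0             # 10,485,760 / 2,097,152 = exactly 5 val steps
+    return tokens_per_rank, val_tokens_per_step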
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, (nn.Embedding, nn.Linear)):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+gate_params = [p for n, p in model.named_parameters() if "gate" in n]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = min(0.9999, step / args.num_iterations)
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+def get_ws(step: int):
+    if step == args.num_iterations + args.iteration_extension:
+        return args.ws_validate // 2, args.ws_validate
+    x = min(step / (1 + args.num_iterations), 0.9999)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws_long = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws_long > ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    elif new_ws_long < ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del initial_state, train_loader
+
+########################################
+#        Training and validation       #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+train_steps = args.num_iterations + args.iteration_extension
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    ws_short, ws_long = get_ws(step)
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        if last_step:
+            ws_long = args.ws_long_validate
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = torch.zeros((), device=device, dtype=torch.float32)
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Sep 27 12:50:35 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 28C P0 123W / 700W | 5856MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 25C P0 118W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 23C P0 116W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 27C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 27C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 25C P0 114W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 28C P0 119W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 25C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 164050 C /usr/bin/python 0MiB | +| 0 N/A N/A 164051 C /usr/bin/python 0MiB | +| 0 N/A N/A 164052 C /usr/bin/python 0MiB | +| 0 N/A N/A 164053 C /usr/bin/python 0MiB | +| 0 N/A N/A 164054 C /usr/bin/python 0MiB | +| 0 N/A N/A 164055 C /usr/bin/python 0MiB | +| 0 N/A N/A 164056 C /usr/bin/python 0MiB | +| 0 N/A N/A 164057 C /usr/bin/python 0MiB | +| 1 N/A N/A 164051 C /usr/bin/python 0MiB | +| 2 N/A N/A 164052 C /usr/bin/python 0MiB | +| 3 N/A N/A 164053 C /usr/bin/python 0MiB | +| 4 N/A N/A 164054 C /usr/bin/python 0MiB | +| 5 N/A N/A 164055 C /usr/bin/python 0MiB | +| 6 N/A N/A 164056 C /usr/bin/python 0MiB | +| 7 N/A N/A 164057 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1680 train_time:140ms step_avg:139.82ms +step:2/1680 train_time:161ms step_avg:80.40ms +step:3/1680 train_time:225ms step_avg:74.87ms +step:4/1680 train_time:309ms step_avg:77.37ms +step:5/1680 train_time:395ms step_avg:79.08ms +step:6/1680 train_time:481ms step_avg:80.22ms +step:7/1680 train_time:567ms step_avg:81.06ms +step:8/1680 train_time:654ms step_avg:81.70ms +step:9/1680 train_time:740ms step_avg:82.19ms +step:10/1680 train_time:826ms step_avg:82.65ms +step:11/1680 train_time:913ms step_avg:82.96ms +step:12/1680 train_time:1000ms step_avg:83.30ms +step:13/1680 train_time:1090ms step_avg:83.86ms +step:14/1680 train_time:1180ms step_avg:84.30ms +step:15/1680 train_time:1268ms step_avg:84.55ms +step:16/1680 train_time:1355ms step_avg:84.72ms +step:17/1680 train_time:1443ms step_avg:84.89ms +step:18/1680 train_time:1529ms step_avg:84.97ms +step:19/1680 train_time:1616ms step_avg:85.05ms +step:20/1680 train_time:1703ms step_avg:85.14ms +step:21/1680 train_time:1789ms step_avg:85.19ms +step:22/1680 train_time:1876ms step_avg:85.26ms +step:23/1680 train_time:1963ms step_avg:85.36ms +step:24/1680 train_time:2052ms step_avg:85.50ms +step:25/1680 train_time:2141ms step_avg:85.63ms +step:26/1680 train_time:2230ms step_avg:85.76ms +step:27/1680 train_time:2318ms step_avg:85.84ms +step:28/1680 train_time:2405ms step_avg:85.91ms +step:29/1680 train_time:2493ms step_avg:85.95ms +step:30/1680 train_time:2579ms step_avg:85.97ms +step:31/1680 train_time:2667ms step_avg:86.02ms +step:32/1680 train_time:2753ms step_avg:86.04ms +step:33/1680 train_time:2840ms step_avg:86.06ms +step:34/1680 train_time:2926ms step_avg:86.07ms +step:35/1680 train_time:3014ms step_avg:86.11ms +step:36/1680 train_time:3101ms step_avg:86.14ms +step:37/1680 train_time:3189ms step_avg:86.19ms +step:38/1680 train_time:3277ms step_avg:86.23ms +step:39/1680 train_time:3365ms step_avg:86.28ms +step:40/1680 train_time:3452ms step_avg:86.30ms +step:41/1680 train_time:3539ms step_avg:86.31ms +step:42/1680 train_time:3626ms step_avg:86.32ms +step:43/1680 train_time:3712ms step_avg:86.32ms +step:44/1680 train_time:3800ms step_avg:86.36ms +step:45/1680 train_time:3887ms step_avg:86.38ms +step:46/1680 train_time:3974ms step_avg:86.40ms +step:47/1680 
train_time:4062ms step_avg:86.42ms +step:48/1680 train_time:4150ms step_avg:86.45ms +step:49/1680 train_time:4238ms step_avg:86.49ms +step:50/1680 train_time:4326ms step_avg:86.51ms +step:51/1680 train_time:4414ms step_avg:86.54ms +step:52/1680 train_time:4501ms step_avg:86.55ms +step:53/1680 train_time:4588ms step_avg:86.56ms +step:54/1680 train_time:4675ms step_avg:86.58ms +step:55/1680 train_time:4762ms step_avg:86.58ms +step:56/1680 train_time:4849ms step_avg:86.59ms +step:57/1680 train_time:4937ms step_avg:86.61ms +step:58/1680 train_time:5023ms step_avg:86.61ms +step:59/1680 train_time:5111ms step_avg:86.62ms +step:60/1680 train_time:5198ms step_avg:86.64ms +step:61/1680 train_time:5287ms step_avg:86.66ms +step:62/1680 train_time:5375ms step_avg:86.69ms +step:63/1680 train_time:5462ms step_avg:86.70ms +step:64/1680 train_time:5549ms step_avg:86.71ms +step:65/1680 train_time:5637ms step_avg:86.72ms +step:66/1680 train_time:5724ms step_avg:86.72ms +step:67/1680 train_time:5810ms step_avg:86.72ms +step:68/1680 train_time:5897ms step_avg:86.72ms +step:69/1680 train_time:5985ms step_avg:86.73ms +step:70/1680 train_time:6071ms step_avg:86.73ms +step:71/1680 train_time:6158ms step_avg:86.74ms +step:72/1680 train_time:6246ms step_avg:86.75ms +step:73/1680 train_time:6335ms step_avg:86.78ms +step:74/1680 train_time:6421ms step_avg:86.78ms +step:75/1680 train_time:6508ms step_avg:86.78ms +step:76/1680 train_time:6595ms step_avg:86.78ms +step:77/1680 train_time:6682ms step_avg:86.79ms +step:78/1680 train_time:6769ms step_avg:86.78ms +step:79/1680 train_time:6857ms step_avg:86.80ms +step:80/1680 train_time:6945ms step_avg:86.81ms +step:81/1680 train_time:7033ms step_avg:86.82ms +step:82/1680 train_time:7120ms step_avg:86.83ms +step:83/1680 train_time:7207ms step_avg:86.84ms +step:84/1680 train_time:7295ms step_avg:86.84ms +step:85/1680 train_time:7381ms step_avg:86.84ms +step:86/1680 train_time:7469ms step_avg:86.84ms +step:87/1680 train_time:7556ms step_avg:86.85ms +step:88/1680 train_time:7643ms step_avg:86.85ms +step:89/1680 train_time:7730ms step_avg:86.85ms +step:90/1680 train_time:7817ms step_avg:86.86ms +step:91/1680 train_time:7904ms step_avg:86.86ms +step:92/1680 train_time:7991ms step_avg:86.86ms +step:93/1680 train_time:8078ms step_avg:86.86ms +step:94/1680 train_time:8165ms step_avg:86.86ms +step:95/1680 train_time:8252ms step_avg:86.86ms +step:96/1680 train_time:8339ms step_avg:86.86ms +step:97/1680 train_time:8426ms step_avg:86.87ms +step:98/1680 train_time:8514ms step_avg:86.87ms +step:99/1680 train_time:8601ms step_avg:86.87ms +step:100/1680 train_time:8687ms step_avg:86.87ms +step:101/1680 train_time:8775ms step_avg:86.89ms +step:102/1680 train_time:8862ms step_avg:86.88ms +step:103/1680 train_time:8949ms step_avg:86.88ms +step:104/1680 train_time:9037ms step_avg:86.89ms +step:105/1680 train_time:9124ms step_avg:86.89ms +step:106/1680 train_time:9210ms step_avg:86.89ms +step:107/1680 train_time:9297ms step_avg:86.89ms +step:108/1680 train_time:9385ms step_avg:86.90ms +step:109/1680 train_time:9472ms step_avg:86.90ms +step:110/1680 train_time:9559ms step_avg:86.90ms +step:111/1680 train_time:9648ms step_avg:86.92ms +step:112/1680 train_time:9735ms step_avg:86.92ms +step:113/1680 train_time:9822ms step_avg:86.92ms +step:114/1680 train_time:9909ms step_avg:86.92ms +step:115/1680 train_time:9996ms step_avg:86.92ms +step:116/1680 train_time:10083ms step_avg:86.92ms +step:117/1680 train_time:10170ms step_avg:86.92ms +step:118/1680 train_time:10257ms step_avg:86.92ms +step:119/1680 
train_time:10344ms step_avg:86.93ms +step:120/1680 train_time:10431ms step_avg:86.93ms +step:121/1680 train_time:10518ms step_avg:86.92ms +step:122/1680 train_time:10606ms step_avg:86.93ms +step:123/1680 train_time:10692ms step_avg:86.93ms +step:124/1680 train_time:10779ms step_avg:86.93ms +step:125/1680 train_time:10866ms step_avg:86.93ms +step:125/1680 val_loss:4.3066 train_time:10955ms step_avg:87.64ms +step:126/1680 train_time:10975ms step_avg:87.10ms +step:127/1680 train_time:11044ms step_avg:86.96ms +step:128/1680 train_time:11139ms step_avg:87.02ms +step:129/1680 train_time:11229ms step_avg:87.04ms +step:130/1680 train_time:11316ms step_avg:87.05ms +step:131/1680 train_time:11403ms step_avg:87.04ms +step:132/1680 train_time:11489ms step_avg:87.04ms +step:133/1680 train_time:11575ms step_avg:87.03ms +step:134/1680 train_time:11661ms step_avg:87.02ms +step:135/1680 train_time:11747ms step_avg:87.01ms +step:136/1680 train_time:11833ms step_avg:87.01ms +step:137/1680 train_time:11919ms step_avg:87.00ms +step:138/1680 train_time:12006ms step_avg:87.00ms +step:139/1680 train_time:12097ms step_avg:87.03ms +step:140/1680 train_time:12187ms step_avg:87.05ms +step:141/1680 train_time:12275ms step_avg:87.06ms +step:142/1680 train_time:12363ms step_avg:87.06ms +step:143/1680 train_time:12449ms step_avg:87.05ms +step:144/1680 train_time:12535ms step_avg:87.05ms +step:145/1680 train_time:12622ms step_avg:87.05ms +step:146/1680 train_time:12708ms step_avg:87.04ms +step:147/1680 train_time:12794ms step_avg:87.03ms +step:148/1680 train_time:12881ms step_avg:87.03ms +step:149/1680 train_time:12968ms step_avg:87.03ms +step:150/1680 train_time:13056ms step_avg:87.04ms +step:151/1680 train_time:13144ms step_avg:87.05ms +step:152/1680 train_time:13234ms step_avg:87.07ms +step:153/1680 train_time:13321ms step_avg:87.07ms +step:154/1680 train_time:13408ms step_avg:87.07ms +step:155/1680 train_time:13496ms step_avg:87.07ms +step:156/1680 train_time:13583ms step_avg:87.07ms +step:157/1680 train_time:13669ms step_avg:87.07ms +step:158/1680 train_time:13756ms step_avg:87.06ms +step:159/1680 train_time:13843ms step_avg:87.06ms +step:160/1680 train_time:13929ms step_avg:87.05ms +step:161/1680 train_time:14016ms step_avg:87.06ms +step:162/1680 train_time:14104ms step_avg:87.06ms +step:163/1680 train_time:14193ms step_avg:87.07ms +step:164/1680 train_time:14280ms step_avg:87.07ms +step:165/1680 train_time:14367ms step_avg:87.07ms +step:166/1680 train_time:14454ms step_avg:87.07ms +step:167/1680 train_time:14541ms step_avg:87.07ms +step:168/1680 train_time:14628ms step_avg:87.07ms +step:169/1680 train_time:14715ms step_avg:87.07ms +step:170/1680 train_time:14802ms step_avg:87.07ms +step:171/1680 train_time:14888ms step_avg:87.06ms +step:172/1680 train_time:14975ms step_avg:87.06ms +step:173/1680 train_time:15062ms step_avg:87.06ms +step:174/1680 train_time:15150ms step_avg:87.07ms +step:175/1680 train_time:15238ms step_avg:87.07ms +step:176/1680 train_time:15325ms step_avg:87.07ms +step:177/1680 train_time:15412ms step_avg:87.08ms +step:178/1680 train_time:15499ms step_avg:87.07ms +step:179/1680 train_time:15586ms step_avg:87.07ms +step:180/1680 train_time:15673ms step_avg:87.07ms +step:181/1680 train_time:15760ms step_avg:87.07ms +step:182/1680 train_time:15847ms step_avg:87.07ms +step:183/1680 train_time:15934ms step_avg:87.07ms +step:184/1680 train_time:16022ms step_avg:87.08ms +step:185/1680 train_time:16109ms step_avg:87.08ms +step:186/1680 train_time:16197ms step_avg:87.08ms +step:187/1680 train_time:16284ms 
step_avg:87.08ms +step:188/1680 train_time:16372ms step_avg:87.08ms +step:189/1680 train_time:16459ms step_avg:87.09ms +step:190/1680 train_time:16546ms step_avg:87.09ms +step:191/1680 train_time:16634ms step_avg:87.09ms +step:192/1680 train_time:16720ms step_avg:87.08ms +step:193/1680 train_time:16807ms step_avg:87.08ms +step:194/1680 train_time:16894ms step_avg:87.08ms +step:195/1680 train_time:16981ms step_avg:87.08ms +step:196/1680 train_time:17068ms step_avg:87.08ms +step:197/1680 train_time:17156ms step_avg:87.09ms +step:198/1680 train_time:17243ms step_avg:87.09ms +step:199/1680 train_time:17330ms step_avg:87.09ms +step:200/1680 train_time:17417ms step_avg:87.09ms +step:201/1680 train_time:17504ms step_avg:87.09ms +step:202/1680 train_time:17592ms step_avg:87.09ms +step:203/1680 train_time:17679ms step_avg:87.09ms +step:204/1680 train_time:17766ms step_avg:87.09ms +step:205/1680 train_time:17853ms step_avg:87.09ms +step:206/1680 train_time:17941ms step_avg:87.09ms +step:207/1680 train_time:18028ms step_avg:87.09ms +step:208/1680 train_time:18115ms step_avg:87.09ms +step:209/1680 train_time:18202ms step_avg:87.09ms +step:210/1680 train_time:18289ms step_avg:87.09ms +step:211/1680 train_time:18376ms step_avg:87.09ms +step:212/1680 train_time:18463ms step_avg:87.09ms +step:213/1680 train_time:18551ms step_avg:87.09ms +step:214/1680 train_time:18638ms step_avg:87.09ms +step:215/1680 train_time:18724ms step_avg:87.09ms +step:216/1680 train_time:18811ms step_avg:87.09ms +step:217/1680 train_time:18898ms step_avg:87.09ms +step:218/1680 train_time:18985ms step_avg:87.09ms +step:219/1680 train_time:19073ms step_avg:87.09ms +step:220/1680 train_time:19160ms step_avg:87.09ms +step:221/1680 train_time:19247ms step_avg:87.09ms +step:222/1680 train_time:19335ms step_avg:87.09ms +step:223/1680 train_time:19421ms step_avg:87.09ms +step:224/1680 train_time:19509ms step_avg:87.09ms +step:225/1680 train_time:19597ms step_avg:87.10ms +step:226/1680 train_time:19684ms step_avg:87.10ms +step:227/1680 train_time:19771ms step_avg:87.10ms +step:228/1680 train_time:19858ms step_avg:87.10ms +step:229/1680 train_time:19944ms step_avg:87.09ms +step:230/1680 train_time:20032ms step_avg:87.10ms +step:231/1680 train_time:20119ms step_avg:87.10ms +step:232/1680 train_time:20206ms step_avg:87.10ms +step:233/1680 train_time:20294ms step_avg:87.10ms +step:234/1680 train_time:20380ms step_avg:87.09ms +step:235/1680 train_time:20467ms step_avg:87.09ms +step:236/1680 train_time:20555ms step_avg:87.10ms +step:237/1680 train_time:20643ms step_avg:87.10ms +step:238/1680 train_time:20730ms step_avg:87.10ms +step:239/1680 train_time:20817ms step_avg:87.10ms +step:240/1680 train_time:20904ms step_avg:87.10ms +step:241/1680 train_time:20991ms step_avg:87.10ms +step:242/1680 train_time:21078ms step_avg:87.10ms +step:243/1680 train_time:21166ms step_avg:87.10ms +step:244/1680 train_time:21253ms step_avg:87.10ms +step:245/1680 train_time:21340ms step_avg:87.10ms +step:246/1680 train_time:21427ms step_avg:87.10ms +step:247/1680 train_time:21514ms step_avg:87.10ms +step:248/1680 train_time:21601ms step_avg:87.10ms +step:249/1680 train_time:21688ms step_avg:87.10ms +step:250/1680 train_time:21775ms step_avg:87.10ms +step:250/1680 val_loss:3.9735 train_time:21864ms step_avg:87.46ms +step:251/1680 train_time:21883ms step_avg:87.18ms +step:252/1680 train_time:21951ms step_avg:87.11ms +step:253/1680 train_time:22041ms step_avg:87.12ms +step:254/1680 train_time:22128ms step_avg:87.12ms +step:255/1680 train_time:22215ms step_avg:87.12ms 
+step:256/1680 train_time:22301ms step_avg:87.12ms +step:257/1680 train_time:22387ms step_avg:87.11ms +step:258/1680 train_time:22473ms step_avg:87.11ms +step:259/1680 train_time:22559ms step_avg:87.10ms +step:260/1680 train_time:22646ms step_avg:87.10ms +step:261/1680 train_time:22732ms step_avg:87.10ms +step:262/1680 train_time:22819ms step_avg:87.10ms +step:263/1680 train_time:22908ms step_avg:87.10ms +step:264/1680 train_time:22998ms step_avg:87.11ms +step:265/1680 train_time:23086ms step_avg:87.12ms +step:266/1680 train_time:23174ms step_avg:87.12ms +step:267/1680 train_time:23260ms step_avg:87.12ms +step:268/1680 train_time:23347ms step_avg:87.11ms +step:269/1680 train_time:23433ms step_avg:87.11ms +step:270/1680 train_time:23519ms step_avg:87.11ms +step:271/1680 train_time:23605ms step_avg:87.10ms +step:272/1680 train_time:23692ms step_avg:87.10ms +step:273/1680 train_time:23779ms step_avg:87.10ms +step:274/1680 train_time:23866ms step_avg:87.10ms +step:275/1680 train_time:23955ms step_avg:87.11ms +step:276/1680 train_time:24043ms step_avg:87.11ms +step:277/1680 train_time:24131ms step_avg:87.11ms +step:278/1680 train_time:24218ms step_avg:87.11ms +step:279/1680 train_time:24305ms step_avg:87.11ms +step:280/1680 train_time:24392ms step_avg:87.11ms +step:281/1680 train_time:24478ms step_avg:87.11ms +step:282/1680 train_time:24564ms step_avg:87.11ms +step:283/1680 train_time:24651ms step_avg:87.11ms +step:284/1680 train_time:24738ms step_avg:87.10ms +step:285/1680 train_time:24825ms step_avg:87.11ms +step:286/1680 train_time:24913ms step_avg:87.11ms +step:287/1680 train_time:25000ms step_avg:87.11ms +step:288/1680 train_time:25089ms step_avg:87.11ms +step:289/1680 train_time:25177ms step_avg:87.12ms +step:290/1680 train_time:25264ms step_avg:87.12ms +step:291/1680 train_time:25351ms step_avg:87.12ms +step:292/1680 train_time:25437ms step_avg:87.11ms +step:293/1680 train_time:25524ms step_avg:87.11ms +step:294/1680 train_time:25610ms step_avg:87.11ms +step:295/1680 train_time:25697ms step_avg:87.11ms +step:296/1680 train_time:25784ms step_avg:87.11ms +step:297/1680 train_time:25872ms step_avg:87.11ms +step:298/1680 train_time:25959ms step_avg:87.11ms +step:299/1680 train_time:26046ms step_avg:87.11ms +step:300/1680 train_time:26134ms step_avg:87.11ms +step:301/1680 train_time:26221ms step_avg:87.11ms +step:302/1680 train_time:26308ms step_avg:87.11ms +step:303/1680 train_time:26396ms step_avg:87.11ms +step:304/1680 train_time:26482ms step_avg:87.11ms +step:305/1680 train_time:26569ms step_avg:87.11ms +step:306/1680 train_time:26656ms step_avg:87.11ms +step:307/1680 train_time:26743ms step_avg:87.11ms +step:308/1680 train_time:26830ms step_avg:87.11ms +step:309/1680 train_time:26917ms step_avg:87.11ms +step:310/1680 train_time:27005ms step_avg:87.11ms +step:311/1680 train_time:27092ms step_avg:87.11ms +step:312/1680 train_time:27179ms step_avg:87.11ms +step:313/1680 train_time:27266ms step_avg:87.11ms +step:314/1680 train_time:27355ms step_avg:87.12ms +step:315/1680 train_time:27442ms step_avg:87.12ms +step:316/1680 train_time:27529ms step_avg:87.12ms +step:317/1680 train_time:27617ms step_avg:87.12ms +step:318/1680 train_time:27704ms step_avg:87.12ms +step:319/1680 train_time:27791ms step_avg:87.12ms +step:320/1680 train_time:27878ms step_avg:87.12ms +step:321/1680 train_time:27965ms step_avg:87.12ms +step:322/1680 train_time:28052ms step_avg:87.12ms +step:323/1680 train_time:28139ms step_avg:87.12ms +step:324/1680 train_time:28226ms step_avg:87.12ms +step:325/1680 train_time:28313ms 
step_avg:87.12ms +step:326/1680 train_time:28399ms step_avg:87.11ms +step:327/1680 train_time:28486ms step_avg:87.11ms +step:328/1680 train_time:28574ms step_avg:87.11ms +step:329/1680 train_time:28660ms step_avg:87.11ms +step:330/1680 train_time:28748ms step_avg:87.11ms +step:331/1680 train_time:28835ms step_avg:87.11ms +step:332/1680 train_time:28922ms step_avg:87.11ms +step:333/1680 train_time:29009ms step_avg:87.11ms +step:334/1680 train_time:29097ms step_avg:87.12ms +step:335/1680 train_time:29183ms step_avg:87.11ms +step:336/1680 train_time:29270ms step_avg:87.11ms +step:337/1680 train_time:29357ms step_avg:87.11ms +step:338/1680 train_time:29445ms step_avg:87.11ms +step:339/1680 train_time:29531ms step_avg:87.11ms +step:340/1680 train_time:29618ms step_avg:87.11ms +step:341/1680 train_time:29706ms step_avg:87.11ms +step:342/1680 train_time:29793ms step_avg:87.11ms +step:343/1680 train_time:29880ms step_avg:87.11ms +step:344/1680 train_time:29968ms step_avg:87.12ms +step:345/1680 train_time:30055ms step_avg:87.12ms +step:346/1680 train_time:30142ms step_avg:87.12ms +step:347/1680 train_time:30230ms step_avg:87.12ms +step:348/1680 train_time:30317ms step_avg:87.12ms +step:349/1680 train_time:30403ms step_avg:87.12ms +step:350/1680 train_time:30491ms step_avg:87.12ms +step:351/1680 train_time:30577ms step_avg:87.12ms +step:352/1680 train_time:30665ms step_avg:87.12ms +step:353/1680 train_time:30752ms step_avg:87.12ms +step:354/1680 train_time:30839ms step_avg:87.11ms +step:355/1680 train_time:30926ms step_avg:87.11ms +step:356/1680 train_time:31012ms step_avg:87.11ms +step:357/1680 train_time:31099ms step_avg:87.11ms +step:358/1680 train_time:31186ms step_avg:87.11ms +step:359/1680 train_time:31274ms step_avg:87.11ms +step:360/1680 train_time:31361ms step_avg:87.11ms +step:361/1680 train_time:31448ms step_avg:87.11ms +step:362/1680 train_time:31535ms step_avg:87.11ms +step:363/1680 train_time:31621ms step_avg:87.11ms +step:364/1680 train_time:31708ms step_avg:87.11ms +step:365/1680 train_time:31795ms step_avg:87.11ms +step:366/1680 train_time:31882ms step_avg:87.11ms +step:367/1680 train_time:31970ms step_avg:87.11ms +step:368/1680 train_time:32057ms step_avg:87.11ms +step:369/1680 train_time:32144ms step_avg:87.11ms +step:370/1680 train_time:32231ms step_avg:87.11ms +step:371/1680 train_time:32318ms step_avg:87.11ms +step:372/1680 train_time:32405ms step_avg:87.11ms +step:373/1680 train_time:32493ms step_avg:87.11ms +step:374/1680 train_time:32580ms step_avg:87.11ms +step:375/1680 train_time:32667ms step_avg:87.11ms +step:375/1680 val_loss:3.8183 train_time:32756ms step_avg:87.35ms +step:376/1680 train_time:32776ms step_avg:87.17ms +step:377/1680 train_time:32846ms step_avg:87.12ms +step:378/1680 train_time:32938ms step_avg:87.14ms +step:379/1680 train_time:33026ms step_avg:87.14ms +step:380/1680 train_time:33114ms step_avg:87.14ms +step:381/1680 train_time:33200ms step_avg:87.14ms +step:382/1680 train_time:33286ms step_avg:87.14ms +step:383/1680 train_time:33372ms step_avg:87.13ms +step:384/1680 train_time:33458ms step_avg:87.13ms +step:385/1680 train_time:33544ms step_avg:87.13ms +step:386/1680 train_time:33629ms step_avg:87.12ms +step:387/1680 train_time:33716ms step_avg:87.12ms +step:388/1680 train_time:33804ms step_avg:87.12ms +step:389/1680 train_time:33892ms step_avg:87.13ms +step:390/1680 train_time:33981ms step_avg:87.13ms +step:391/1680 train_time:34069ms step_avg:87.13ms +step:392/1680 train_time:34156ms step_avg:87.13ms +step:393/1680 train_time:34243ms step_avg:87.13ms 
+step:394/1680 train_time:34330ms step_avg:87.13ms +step:395/1680 train_time:34416ms step_avg:87.13ms +step:396/1680 train_time:34502ms step_avg:87.13ms +step:397/1680 train_time:34588ms step_avg:87.12ms +step:398/1680 train_time:34675ms step_avg:87.12ms +step:399/1680 train_time:34762ms step_avg:87.12ms +step:400/1680 train_time:34850ms step_avg:87.12ms +step:401/1680 train_time:34938ms step_avg:87.13ms +step:402/1680 train_time:35026ms step_avg:87.13ms +step:403/1680 train_time:35114ms step_avg:87.13ms +step:404/1680 train_time:35201ms step_avg:87.13ms +step:405/1680 train_time:35287ms step_avg:87.13ms +step:406/1680 train_time:35374ms step_avg:87.13ms +step:407/1680 train_time:35461ms step_avg:87.13ms +step:408/1680 train_time:35547ms step_avg:87.13ms +step:409/1680 train_time:35634ms step_avg:87.13ms +step:410/1680 train_time:35721ms step_avg:87.12ms +step:411/1680 train_time:35808ms step_avg:87.12ms +step:412/1680 train_time:35896ms step_avg:87.13ms +step:413/1680 train_time:35984ms step_avg:87.13ms +step:414/1680 train_time:36071ms step_avg:87.13ms +step:415/1680 train_time:36159ms step_avg:87.13ms +step:416/1680 train_time:36246ms step_avg:87.13ms +step:417/1680 train_time:36333ms step_avg:87.13ms +step:418/1680 train_time:36419ms step_avg:87.13ms +step:419/1680 train_time:36506ms step_avg:87.13ms +step:420/1680 train_time:36593ms step_avg:87.13ms +step:421/1680 train_time:36681ms step_avg:87.13ms +step:422/1680 train_time:36768ms step_avg:87.13ms +step:423/1680 train_time:36855ms step_avg:87.13ms +step:424/1680 train_time:36943ms step_avg:87.13ms +step:425/1680 train_time:37030ms step_avg:87.13ms +step:426/1680 train_time:37118ms step_avg:87.13ms +step:427/1680 train_time:37205ms step_avg:87.13ms +step:428/1680 train_time:37292ms step_avg:87.13ms +step:429/1680 train_time:37378ms step_avg:87.13ms +step:430/1680 train_time:37465ms step_avg:87.13ms +step:431/1680 train_time:37551ms step_avg:87.13ms +step:432/1680 train_time:37638ms step_avg:87.12ms +step:433/1680 train_time:37725ms step_avg:87.12ms +step:434/1680 train_time:37812ms step_avg:87.12ms +step:435/1680 train_time:37899ms step_avg:87.12ms +step:436/1680 train_time:37986ms step_avg:87.12ms +step:437/1680 train_time:38073ms step_avg:87.12ms +step:438/1680 train_time:38160ms step_avg:87.12ms +step:439/1680 train_time:38248ms step_avg:87.12ms +step:440/1680 train_time:38335ms step_avg:87.13ms +step:441/1680 train_time:38423ms step_avg:87.13ms +step:442/1680 train_time:38510ms step_avg:87.13ms +step:443/1680 train_time:38596ms step_avg:87.12ms +step:444/1680 train_time:38684ms step_avg:87.13ms +step:445/1680 train_time:38771ms step_avg:87.13ms +step:446/1680 train_time:38858ms step_avg:87.13ms +step:447/1680 train_time:38945ms step_avg:87.13ms +step:448/1680 train_time:39032ms step_avg:87.13ms +step:449/1680 train_time:39119ms step_avg:87.13ms +step:450/1680 train_time:39206ms step_avg:87.13ms +step:451/1680 train_time:39293ms step_avg:87.12ms +step:452/1680 train_time:39380ms step_avg:87.12ms +step:453/1680 train_time:39467ms step_avg:87.12ms +step:454/1680 train_time:39554ms step_avg:87.12ms +step:455/1680 train_time:39641ms step_avg:87.12ms +step:456/1680 train_time:39728ms step_avg:87.12ms +step:457/1680 train_time:39815ms step_avg:87.12ms +step:458/1680 train_time:39902ms step_avg:87.12ms +step:459/1680 train_time:39989ms step_avg:87.12ms +step:460/1680 train_time:40077ms step_avg:87.12ms +step:461/1680 train_time:40164ms step_avg:87.12ms +step:462/1680 train_time:40251ms step_avg:87.12ms +step:463/1680 train_time:40338ms 
step_avg:87.12ms +step:464/1680 train_time:40425ms step_avg:87.12ms +step:465/1680 train_time:40511ms step_avg:87.12ms +step:466/1680 train_time:40598ms step_avg:87.12ms +step:467/1680 train_time:40684ms step_avg:87.12ms +step:468/1680 train_time:40772ms step_avg:87.12ms +step:469/1680 train_time:40859ms step_avg:87.12ms +step:470/1680 train_time:40946ms step_avg:87.12ms +step:471/1680 train_time:41033ms step_avg:87.12ms +step:472/1680 train_time:41121ms step_avg:87.12ms +step:473/1680 train_time:41208ms step_avg:87.12ms +step:474/1680 train_time:41295ms step_avg:87.12ms +step:475/1680 train_time:41383ms step_avg:87.12ms +step:476/1680 train_time:41469ms step_avg:87.12ms +step:477/1680 train_time:41556ms step_avg:87.12ms +step:478/1680 train_time:41643ms step_avg:87.12ms +step:479/1680 train_time:41729ms step_avg:87.12ms +step:480/1680 train_time:41817ms step_avg:87.12ms +step:481/1680 train_time:41904ms step_avg:87.12ms +step:482/1680 train_time:41990ms step_avg:87.12ms +step:483/1680 train_time:42078ms step_avg:87.12ms +step:484/1680 train_time:42165ms step_avg:87.12ms +step:485/1680 train_time:42252ms step_avg:87.12ms +step:486/1680 train_time:42339ms step_avg:87.12ms +step:487/1680 train_time:42426ms step_avg:87.12ms +step:488/1680 train_time:42512ms step_avg:87.11ms +step:489/1680 train_time:42599ms step_avg:87.11ms +step:490/1680 train_time:42685ms step_avg:87.11ms +step:491/1680 train_time:42773ms step_avg:87.11ms +step:492/1680 train_time:42860ms step_avg:87.11ms +step:493/1680 train_time:42948ms step_avg:87.11ms +step:494/1680 train_time:43035ms step_avg:87.11ms +step:495/1680 train_time:43122ms step_avg:87.12ms +step:496/1680 train_time:43210ms step_avg:87.12ms +step:497/1680 train_time:43297ms step_avg:87.12ms +step:498/1680 train_time:43384ms step_avg:87.12ms +step:499/1680 train_time:43471ms step_avg:87.12ms +step:500/1680 train_time:43558ms step_avg:87.12ms +step:500/1680 val_loss:3.7202 train_time:43646ms step_avg:87.29ms +step:501/1680 train_time:43665ms step_avg:87.16ms +step:502/1680 train_time:43734ms step_avg:87.12ms +step:503/1680 train_time:43825ms step_avg:87.13ms +step:504/1680 train_time:43913ms step_avg:87.13ms +step:505/1680 train_time:44000ms step_avg:87.13ms +step:506/1680 train_time:44088ms step_avg:87.13ms +step:507/1680 train_time:44174ms step_avg:87.13ms +step:508/1680 train_time:44261ms step_avg:87.13ms +step:509/1680 train_time:44347ms step_avg:87.13ms +step:510/1680 train_time:44433ms step_avg:87.12ms +step:511/1680 train_time:44519ms step_avg:87.12ms +step:512/1680 train_time:44606ms step_avg:87.12ms +step:513/1680 train_time:44694ms step_avg:87.12ms +step:514/1680 train_time:44782ms step_avg:87.13ms +step:515/1680 train_time:44871ms step_avg:87.13ms +step:516/1680 train_time:44959ms step_avg:87.13ms +step:517/1680 train_time:45046ms step_avg:87.13ms +step:518/1680 train_time:45132ms step_avg:87.13ms +step:519/1680 train_time:45220ms step_avg:87.13ms +step:520/1680 train_time:45307ms step_avg:87.13ms +step:521/1680 train_time:45393ms step_avg:87.13ms +step:522/1680 train_time:45479ms step_avg:87.12ms +step:523/1680 train_time:45566ms step_avg:87.12ms +step:524/1680 train_time:45652ms step_avg:87.12ms +step:525/1680 train_time:45740ms step_avg:87.12ms +step:526/1680 train_time:45828ms step_avg:87.13ms +step:527/1680 train_time:45916ms step_avg:87.13ms +step:528/1680 train_time:46004ms step_avg:87.13ms +step:529/1680 train_time:46091ms step_avg:87.13ms +step:530/1680 train_time:46179ms step_avg:87.13ms +step:531/1680 train_time:46265ms step_avg:87.13ms 
+step:532/1680 train_time:46352ms step_avg:87.13ms +step:533/1680 train_time:46438ms step_avg:87.13ms +step:534/1680 train_time:46525ms step_avg:87.13ms +step:535/1680 train_time:46612ms step_avg:87.13ms +step:536/1680 train_time:46699ms step_avg:87.13ms +step:537/1680 train_time:46788ms step_avg:87.13ms +step:538/1680 train_time:46875ms step_avg:87.13ms +step:539/1680 train_time:46962ms step_avg:87.13ms +step:540/1680 train_time:47049ms step_avg:87.13ms +step:541/1680 train_time:47136ms step_avg:87.13ms +step:542/1680 train_time:47223ms step_avg:87.13ms +step:543/1680 train_time:47310ms step_avg:87.13ms +step:544/1680 train_time:47398ms step_avg:87.13ms +step:545/1680 train_time:47485ms step_avg:87.13ms +step:546/1680 train_time:47571ms step_avg:87.13ms +step:547/1680 train_time:47659ms step_avg:87.13ms +step:548/1680 train_time:47746ms step_avg:87.13ms +step:549/1680 train_time:47834ms step_avg:87.13ms +step:550/1680 train_time:47923ms step_avg:87.13ms +step:551/1680 train_time:48011ms step_avg:87.13ms +step:552/1680 train_time:48100ms step_avg:87.14ms +step:553/1680 train_time:48188ms step_avg:87.14ms +step:554/1680 train_time:48276ms step_avg:87.14ms +step:555/1680 train_time:48364ms step_avg:87.14ms +step:556/1680 train_time:48452ms step_avg:87.14ms +step:557/1680 train_time:48539ms step_avg:87.14ms +step:558/1680 train_time:48628ms step_avg:87.15ms +step:559/1680 train_time:48716ms step_avg:87.15ms +step:560/1680 train_time:48804ms step_avg:87.15ms +step:561/1680 train_time:48892ms step_avg:87.15ms +step:562/1680 train_time:48980ms step_avg:87.15ms +step:563/1680 train_time:49069ms step_avg:87.16ms +step:564/1680 train_time:49157ms step_avg:87.16ms +step:565/1680 train_time:49245ms step_avg:87.16ms +step:566/1680 train_time:49333ms step_avg:87.16ms +step:567/1680 train_time:49422ms step_avg:87.16ms +step:568/1680 train_time:49510ms step_avg:87.17ms +step:569/1680 train_time:49598ms step_avg:87.17ms +step:570/1680 train_time:49687ms step_avg:87.17ms +step:571/1680 train_time:49775ms step_avg:87.17ms +step:572/1680 train_time:49863ms step_avg:87.17ms +step:573/1680 train_time:49951ms step_avg:87.17ms +step:574/1680 train_time:50039ms step_avg:87.18ms +step:575/1680 train_time:50127ms step_avg:87.18ms +step:576/1680 train_time:50215ms step_avg:87.18ms +step:577/1680 train_time:50304ms step_avg:87.18ms +step:578/1680 train_time:50392ms step_avg:87.18ms +step:579/1680 train_time:50480ms step_avg:87.18ms +step:580/1680 train_time:50568ms step_avg:87.19ms +step:581/1680 train_time:50656ms step_avg:87.19ms +step:582/1680 train_time:50744ms step_avg:87.19ms +step:583/1680 train_time:50832ms step_avg:87.19ms +step:584/1680 train_time:50921ms step_avg:87.19ms +step:585/1680 train_time:51009ms step_avg:87.19ms +step:586/1680 train_time:51097ms step_avg:87.20ms +step:587/1680 train_time:51185ms step_avg:87.20ms +step:588/1680 train_time:51273ms step_avg:87.20ms +step:589/1680 train_time:51361ms step_avg:87.20ms +step:590/1680 train_time:51450ms step_avg:87.20ms +step:591/1680 train_time:51537ms step_avg:87.20ms +step:592/1680 train_time:51626ms step_avg:87.21ms +step:593/1680 train_time:51713ms step_avg:87.21ms +step:594/1680 train_time:51802ms step_avg:87.21ms +step:595/1680 train_time:51892ms step_avg:87.21ms +step:596/1680 train_time:51980ms step_avg:87.21ms +step:597/1680 train_time:52068ms step_avg:87.22ms +step:598/1680 train_time:52156ms step_avg:87.22ms +step:599/1680 train_time:52243ms step_avg:87.22ms +step:600/1680 train_time:52332ms step_avg:87.22ms +step:601/1680 train_time:52420ms 
step_avg:87.22ms +step:602/1680 train_time:52508ms step_avg:87.22ms +step:603/1680 train_time:52596ms step_avg:87.22ms +step:604/1680 train_time:52685ms step_avg:87.23ms +step:605/1680 train_time:52773ms step_avg:87.23ms +step:606/1680 train_time:52861ms step_avg:87.23ms +step:607/1680 train_time:52949ms step_avg:87.23ms +step:608/1680 train_time:53037ms step_avg:87.23ms +step:609/1680 train_time:53125ms step_avg:87.23ms +step:610/1680 train_time:53213ms step_avg:87.24ms +step:611/1680 train_time:53302ms step_avg:87.24ms +step:612/1680 train_time:53391ms step_avg:87.24ms +step:613/1680 train_time:53480ms step_avg:87.24ms +step:614/1680 train_time:53567ms step_avg:87.24ms +step:615/1680 train_time:53655ms step_avg:87.24ms +step:616/1680 train_time:53744ms step_avg:87.25ms +step:617/1680 train_time:53832ms step_avg:87.25ms +step:618/1680 train_time:53920ms step_avg:87.25ms +step:619/1680 train_time:54009ms step_avg:87.25ms +step:620/1680 train_time:54097ms step_avg:87.25ms +step:621/1680 train_time:54185ms step_avg:87.25ms +step:622/1680 train_time:54273ms step_avg:87.26ms +step:623/1680 train_time:54362ms step_avg:87.26ms +step:624/1680 train_time:54451ms step_avg:87.26ms +step:625/1680 train_time:54539ms step_avg:87.26ms +step:625/1680 val_loss:3.6182 train_time:54630ms step_avg:87.41ms +step:626/1680 train_time:54650ms step_avg:87.30ms +step:627/1680 train_time:54719ms step_avg:87.27ms +step:628/1680 train_time:54809ms step_avg:87.28ms +step:629/1680 train_time:54901ms step_avg:87.28ms +step:630/1680 train_time:54990ms step_avg:87.29ms +step:631/1680 train_time:55077ms step_avg:87.29ms +step:632/1680 train_time:55165ms step_avg:87.29ms +step:633/1680 train_time:55251ms step_avg:87.29ms +step:634/1680 train_time:55338ms step_avg:87.28ms +step:635/1680 train_time:55425ms step_avg:87.28ms +step:636/1680 train_time:55513ms step_avg:87.29ms +step:637/1680 train_time:55607ms step_avg:87.30ms +step:638/1680 train_time:55697ms step_avg:87.30ms +step:639/1680 train_time:55785ms step_avg:87.30ms +step:640/1680 train_time:55875ms step_avg:87.30ms +step:641/1680 train_time:55963ms step_avg:87.31ms +step:642/1680 train_time:56050ms step_avg:87.31ms +step:643/1680 train_time:56138ms step_avg:87.31ms +step:644/1680 train_time:56225ms step_avg:87.31ms +step:645/1680 train_time:56312ms step_avg:87.31ms +step:646/1680 train_time:56400ms step_avg:87.31ms +step:647/1680 train_time:56488ms step_avg:87.31ms +step:648/1680 train_time:56577ms step_avg:87.31ms +step:649/1680 train_time:56667ms step_avg:87.31ms +step:650/1680 train_time:56755ms step_avg:87.32ms +step:651/1680 train_time:56844ms step_avg:87.32ms +step:652/1680 train_time:56933ms step_avg:87.32ms +step:653/1680 train_time:57021ms step_avg:87.32ms +step:654/1680 train_time:57109ms step_avg:87.32ms +step:655/1680 train_time:57198ms step_avg:87.32ms +step:656/1680 train_time:57285ms step_avg:87.32ms +step:657/1680 train_time:57373ms step_avg:87.33ms +step:658/1680 train_time:57461ms step_avg:87.33ms +step:659/1680 train_time:57549ms step_avg:87.33ms +step:660/1680 train_time:57637ms step_avg:87.33ms +step:661/1680 train_time:57725ms step_avg:87.33ms +step:662/1680 train_time:57814ms step_avg:87.33ms +step:663/1680 train_time:57903ms step_avg:87.34ms +step:664/1680 train_time:57992ms step_avg:87.34ms +step:665/1680 train_time:58080ms step_avg:87.34ms +step:666/1680 train_time:58168ms step_avg:87.34ms +step:667/1680 train_time:58257ms step_avg:87.34ms +step:668/1680 train_time:58344ms step_avg:87.34ms +step:669/1680 train_time:58432ms step_avg:87.34ms 
+step:670/1680 train_time:58520ms step_avg:87.34ms +step:671/1680 train_time:58608ms step_avg:87.34ms +step:672/1680 train_time:58696ms step_avg:87.34ms +step:673/1680 train_time:58784ms step_avg:87.35ms +step:674/1680 train_time:58873ms step_avg:87.35ms +step:675/1680 train_time:58962ms step_avg:87.35ms +step:676/1680 train_time:59050ms step_avg:87.35ms +step:677/1680 train_time:59138ms step_avg:87.35ms +step:678/1680 train_time:59226ms step_avg:87.35ms +step:679/1680 train_time:59314ms step_avg:87.35ms +step:680/1680 train_time:59402ms step_avg:87.36ms +step:681/1680 train_time:59490ms step_avg:87.36ms +step:682/1680 train_time:59578ms step_avg:87.36ms +step:683/1680 train_time:59667ms step_avg:87.36ms +step:684/1680 train_time:59755ms step_avg:87.36ms +step:685/1680 train_time:59844ms step_avg:87.36ms +step:686/1680 train_time:59933ms step_avg:87.37ms +step:687/1680 train_time:60021ms step_avg:87.37ms +step:688/1680 train_time:60109ms step_avg:87.37ms +step:689/1680 train_time:60197ms step_avg:87.37ms +step:690/1680 train_time:60285ms step_avg:87.37ms +step:691/1680 train_time:60373ms step_avg:87.37ms +step:692/1680 train_time:60462ms step_avg:87.37ms +step:693/1680 train_time:60551ms step_avg:87.37ms +step:694/1680 train_time:60638ms step_avg:87.37ms +step:695/1680 train_time:60726ms step_avg:87.38ms +step:696/1680 train_time:60815ms step_avg:87.38ms +step:697/1680 train_time:60903ms step_avg:87.38ms +step:698/1680 train_time:60991ms step_avg:87.38ms +step:699/1680 train_time:61080ms step_avg:87.38ms +step:700/1680 train_time:61168ms step_avg:87.38ms +step:701/1680 train_time:61255ms step_avg:87.38ms +step:702/1680 train_time:61343ms step_avg:87.38ms +step:703/1680 train_time:61431ms step_avg:87.38ms +step:704/1680 train_time:61519ms step_avg:87.38ms +step:705/1680 train_time:61607ms step_avg:87.39ms +step:706/1680 train_time:61695ms step_avg:87.39ms +step:707/1680 train_time:61784ms step_avg:87.39ms +step:708/1680 train_time:61873ms step_avg:87.39ms +step:709/1680 train_time:61962ms step_avg:87.39ms +step:710/1680 train_time:62051ms step_avg:87.40ms +step:711/1680 train_time:62139ms step_avg:87.40ms +step:712/1680 train_time:62227ms step_avg:87.40ms +step:713/1680 train_time:62315ms step_avg:87.40ms +step:714/1680 train_time:62403ms step_avg:87.40ms +step:715/1680 train_time:62491ms step_avg:87.40ms +step:716/1680 train_time:62580ms step_avg:87.40ms +step:717/1680 train_time:62668ms step_avg:87.40ms +step:718/1680 train_time:62756ms step_avg:87.40ms +step:719/1680 train_time:62843ms step_avg:87.40ms +step:720/1680 train_time:62931ms step_avg:87.40ms +step:721/1680 train_time:63020ms step_avg:87.41ms +step:722/1680 train_time:63107ms step_avg:87.41ms +step:723/1680 train_time:63195ms step_avg:87.41ms +step:724/1680 train_time:63283ms step_avg:87.41ms +step:725/1680 train_time:63371ms step_avg:87.41ms +step:726/1680 train_time:63459ms step_avg:87.41ms +step:727/1680 train_time:63548ms step_avg:87.41ms +step:728/1680 train_time:63635ms step_avg:87.41ms +step:729/1680 train_time:63723ms step_avg:87.41ms +step:730/1680 train_time:63811ms step_avg:87.41ms +step:731/1680 train_time:63900ms step_avg:87.41ms +step:732/1680 train_time:63989ms step_avg:87.42ms +step:733/1680 train_time:64077ms step_avg:87.42ms +step:734/1680 train_time:64165ms step_avg:87.42ms +step:735/1680 train_time:64253ms step_avg:87.42ms +step:736/1680 train_time:64340ms step_avg:87.42ms +step:737/1680 train_time:64428ms step_avg:87.42ms +step:738/1680 train_time:64517ms step_avg:87.42ms +step:739/1680 train_time:64604ms 
step_avg:87.42ms +step:740/1680 train_time:64692ms step_avg:87.42ms +step:741/1680 train_time:64781ms step_avg:87.42ms +step:742/1680 train_time:64870ms step_avg:87.43ms +step:743/1680 train_time:64959ms step_avg:87.43ms +step:744/1680 train_time:65047ms step_avg:87.43ms +step:745/1680 train_time:65134ms step_avg:87.43ms +step:746/1680 train_time:65223ms step_avg:87.43ms +step:747/1680 train_time:65311ms step_avg:87.43ms +step:748/1680 train_time:65399ms step_avg:87.43ms +step:749/1680 train_time:65487ms step_avg:87.43ms +step:750/1680 train_time:65575ms step_avg:87.43ms +step:750/1680 val_loss:3.5689 train_time:65664ms step_avg:87.55ms +step:751/1680 train_time:65683ms step_avg:87.46ms +step:752/1680 train_time:65757ms step_avg:87.44ms +step:753/1680 train_time:65850ms step_avg:87.45ms +step:754/1680 train_time:65941ms step_avg:87.46ms +step:755/1680 train_time:66029ms step_avg:87.46ms +step:756/1680 train_time:66116ms step_avg:87.46ms +step:757/1680 train_time:66204ms step_avg:87.46ms +step:758/1680 train_time:66290ms step_avg:87.45ms +step:759/1680 train_time:66377ms step_avg:87.45ms +step:760/1680 train_time:66464ms step_avg:87.45ms +step:761/1680 train_time:66551ms step_avg:87.45ms +step:762/1680 train_time:66640ms step_avg:87.45ms +step:763/1680 train_time:66730ms step_avg:87.46ms +step:764/1680 train_time:66821ms step_avg:87.46ms +step:765/1680 train_time:66911ms step_avg:87.47ms +step:766/1680 train_time:67000ms step_avg:87.47ms +step:767/1680 train_time:67087ms step_avg:87.47ms +step:768/1680 train_time:67175ms step_avg:87.47ms +step:769/1680 train_time:67262ms step_avg:87.47ms +step:770/1680 train_time:67350ms step_avg:87.47ms +step:771/1680 train_time:67437ms step_avg:87.47ms +step:772/1680 train_time:67524ms step_avg:87.47ms +step:773/1680 train_time:67612ms step_avg:87.47ms +step:774/1680 train_time:67700ms step_avg:87.47ms +step:775/1680 train_time:67791ms step_avg:87.47ms +step:776/1680 train_time:67880ms step_avg:87.47ms +step:777/1680 train_time:67969ms step_avg:87.48ms +step:778/1680 train_time:68057ms step_avg:87.48ms +step:779/1680 train_time:68146ms step_avg:87.48ms +step:780/1680 train_time:68234ms step_avg:87.48ms +step:781/1680 train_time:68321ms step_avg:87.48ms +step:782/1680 train_time:68410ms step_avg:87.48ms +step:783/1680 train_time:68497ms step_avg:87.48ms +step:784/1680 train_time:68586ms step_avg:87.48ms +step:785/1680 train_time:68674ms step_avg:87.48ms +step:786/1680 train_time:68763ms step_avg:87.48ms +step:787/1680 train_time:68853ms step_avg:87.49ms +step:788/1680 train_time:68942ms step_avg:87.49ms +step:789/1680 train_time:69030ms step_avg:87.49ms +step:790/1680 train_time:69119ms step_avg:87.49ms +step:791/1680 train_time:69206ms step_avg:87.49ms +step:792/1680 train_time:69294ms step_avg:87.49ms +step:793/1680 train_time:69382ms step_avg:87.49ms +step:794/1680 train_time:69470ms step_avg:87.49ms +step:795/1680 train_time:69557ms step_avg:87.49ms +step:796/1680 train_time:69646ms step_avg:87.50ms +step:797/1680 train_time:69735ms step_avg:87.50ms +step:798/1680 train_time:69823ms step_avg:87.50ms +step:799/1680 train_time:69913ms step_avg:87.50ms +step:800/1680 train_time:70001ms step_avg:87.50ms +step:801/1680 train_time:70089ms step_avg:87.50ms +step:802/1680 train_time:70176ms step_avg:87.50ms +step:803/1680 train_time:70265ms step_avg:87.50ms +step:804/1680 train_time:70353ms step_avg:87.50ms +step:805/1680 train_time:70441ms step_avg:87.50ms +step:806/1680 train_time:70528ms step_avg:87.50ms +step:807/1680 train_time:70616ms step_avg:87.50ms 
+step:808/1680 train_time:70704ms step_avg:87.51ms +step:809/1680 train_time:70793ms step_avg:87.51ms +step:810/1680 train_time:70882ms step_avg:87.51ms +step:811/1680 train_time:70971ms step_avg:87.51ms +step:812/1680 train_time:71059ms step_avg:87.51ms +step:813/1680 train_time:71147ms step_avg:87.51ms +step:814/1680 train_time:71235ms step_avg:87.51ms +step:815/1680 train_time:71323ms step_avg:87.51ms +step:816/1680 train_time:71411ms step_avg:87.51ms +step:817/1680 train_time:71499ms step_avg:87.51ms +step:818/1680 train_time:71587ms step_avg:87.51ms +step:819/1680 train_time:71675ms step_avg:87.51ms +step:820/1680 train_time:71763ms step_avg:87.52ms +step:821/1680 train_time:71851ms step_avg:87.52ms +step:822/1680 train_time:71939ms step_avg:87.52ms +step:823/1680 train_time:72028ms step_avg:87.52ms +step:824/1680 train_time:72116ms step_avg:87.52ms +step:825/1680 train_time:72204ms step_avg:87.52ms +step:826/1680 train_time:72293ms step_avg:87.52ms +step:827/1680 train_time:72381ms step_avg:87.52ms +step:828/1680 train_time:72469ms step_avg:87.52ms +step:829/1680 train_time:72557ms step_avg:87.52ms +step:830/1680 train_time:72645ms step_avg:87.52ms +step:831/1680 train_time:72733ms step_avg:87.53ms +step:832/1680 train_time:72822ms step_avg:87.53ms +step:833/1680 train_time:72911ms step_avg:87.53ms +step:834/1680 train_time:72999ms step_avg:87.53ms +step:835/1680 train_time:73088ms step_avg:87.53ms +step:836/1680 train_time:73176ms step_avg:87.53ms +step:837/1680 train_time:73264ms step_avg:87.53ms +step:838/1680 train_time:73352ms step_avg:87.53ms +step:839/1680 train_time:73441ms step_avg:87.53ms +step:840/1680 train_time:73528ms step_avg:87.53ms +step:841/1680 train_time:73617ms step_avg:87.53ms +step:842/1680 train_time:73705ms step_avg:87.54ms +step:843/1680 train_time:73793ms step_avg:87.54ms +step:844/1680 train_time:73881ms step_avg:87.54ms +step:845/1680 train_time:73969ms step_avg:87.54ms +step:846/1680 train_time:74057ms step_avg:87.54ms +step:847/1680 train_time:74145ms step_avg:87.54ms +step:848/1680 train_time:74234ms step_avg:87.54ms +step:849/1680 train_time:74322ms step_avg:87.54ms +step:850/1680 train_time:74410ms step_avg:87.54ms +step:851/1680 train_time:74498ms step_avg:87.54ms +step:852/1680 train_time:74586ms step_avg:87.54ms +step:853/1680 train_time:74673ms step_avg:87.54ms +step:854/1680 train_time:74762ms step_avg:87.54ms +step:855/1680 train_time:74850ms step_avg:87.54ms +step:856/1680 train_time:74938ms step_avg:87.54ms +step:857/1680 train_time:75027ms step_avg:87.55ms +step:858/1680 train_time:75116ms step_avg:87.55ms +step:859/1680 train_time:75205ms step_avg:87.55ms +step:860/1680 train_time:75293ms step_avg:87.55ms +step:861/1680 train_time:75381ms step_avg:87.55ms +step:862/1680 train_time:75468ms step_avg:87.55ms +step:863/1680 train_time:75556ms step_avg:87.55ms +step:864/1680 train_time:75644ms step_avg:87.55ms +step:865/1680 train_time:75732ms step_avg:87.55ms +step:866/1680 train_time:75820ms step_avg:87.55ms +step:867/1680 train_time:75908ms step_avg:87.55ms +step:868/1680 train_time:75996ms step_avg:87.55ms +step:869/1680 train_time:76085ms step_avg:87.55ms +step:870/1680 train_time:76173ms step_avg:87.55ms +step:871/1680 train_time:76261ms step_avg:87.56ms +step:872/1680 train_time:76350ms step_avg:87.56ms +step:873/1680 train_time:76437ms step_avg:87.56ms +step:874/1680 train_time:76525ms step_avg:87.56ms +step:875/1680 train_time:76614ms step_avg:87.56ms +step:875/1680 val_loss:3.5219 train_time:76704ms step_avg:87.66ms +step:876/1680 
train_time:76722ms step_avg:87.58ms +step:877/1680 train_time:76795ms step_avg:87.57ms +step:878/1680 train_time:76889ms step_avg:87.57ms +step:879/1680 train_time:76978ms step_avg:87.57ms +step:880/1680 train_time:77066ms step_avg:87.57ms +step:881/1680 train_time:77153ms step_avg:87.57ms +step:882/1680 train_time:77239ms step_avg:87.57ms +step:883/1680 train_time:77326ms step_avg:87.57ms +step:884/1680 train_time:77413ms step_avg:87.57ms +step:885/1680 train_time:77500ms step_avg:87.57ms +step:886/1680 train_time:77588ms step_avg:87.57ms +step:887/1680 train_time:77677ms step_avg:87.57ms +step:888/1680 train_time:77768ms step_avg:87.58ms +step:889/1680 train_time:77859ms step_avg:87.58ms +step:890/1680 train_time:77950ms step_avg:87.58ms +step:891/1680 train_time:78039ms step_avg:87.59ms +step:892/1680 train_time:78126ms step_avg:87.59ms +step:893/1680 train_time:78214ms step_avg:87.59ms +step:894/1680 train_time:78301ms step_avg:87.59ms +step:895/1680 train_time:78389ms step_avg:87.58ms +step:896/1680 train_time:78476ms step_avg:87.58ms +step:897/1680 train_time:78564ms step_avg:87.59ms +step:898/1680 train_time:78652ms step_avg:87.59ms +step:899/1680 train_time:78741ms step_avg:87.59ms +step:900/1680 train_time:78831ms step_avg:87.59ms +step:901/1680 train_time:78920ms step_avg:87.59ms +step:902/1680 train_time:79009ms step_avg:87.59ms +step:903/1680 train_time:79097ms step_avg:87.59ms +step:904/1680 train_time:79185ms step_avg:87.59ms +step:905/1680 train_time:79272ms step_avg:87.59ms +step:906/1680 train_time:79360ms step_avg:87.59ms +step:907/1680 train_time:79448ms step_avg:87.59ms +step:908/1680 train_time:79535ms step_avg:87.59ms +step:909/1680 train_time:79623ms step_avg:87.59ms +step:910/1680 train_time:79711ms step_avg:87.59ms +step:911/1680 train_time:79800ms step_avg:87.60ms +step:912/1680 train_time:79890ms step_avg:87.60ms +step:913/1680 train_time:79979ms step_avg:87.60ms +step:914/1680 train_time:80068ms step_avg:87.60ms +step:915/1680 train_time:80156ms step_avg:87.60ms +step:916/1680 train_time:80243ms step_avg:87.60ms +step:917/1680 train_time:80331ms step_avg:87.60ms +step:918/1680 train_time:80418ms step_avg:87.60ms +step:919/1680 train_time:80506ms step_avg:87.60ms +step:920/1680 train_time:80594ms step_avg:87.60ms +step:921/1680 train_time:80683ms step_avg:87.60ms +step:922/1680 train_time:80771ms step_avg:87.60ms +step:923/1680 train_time:80860ms step_avg:87.61ms +step:924/1680 train_time:80949ms step_avg:87.61ms +step:925/1680 train_time:81037ms step_avg:87.61ms +step:926/1680 train_time:81126ms step_avg:87.61ms +step:927/1680 train_time:81213ms step_avg:87.61ms +step:928/1680 train_time:81301ms step_avg:87.61ms +step:929/1680 train_time:81389ms step_avg:87.61ms +step:930/1680 train_time:81477ms step_avg:87.61ms +step:931/1680 train_time:81565ms step_avg:87.61ms +step:932/1680 train_time:81654ms step_avg:87.61ms +step:933/1680 train_time:81742ms step_avg:87.61ms +step:934/1680 train_time:81830ms step_avg:87.61ms +step:935/1680 train_time:81919ms step_avg:87.61ms +step:936/1680 train_time:82008ms step_avg:87.62ms +step:937/1680 train_time:82096ms step_avg:87.62ms +step:938/1680 train_time:82185ms step_avg:87.62ms +step:939/1680 train_time:82273ms step_avg:87.62ms +step:940/1680 train_time:82361ms step_avg:87.62ms +step:941/1680 train_time:82448ms step_avg:87.62ms +step:942/1680 train_time:82536ms step_avg:87.62ms +step:943/1680 train_time:82624ms step_avg:87.62ms +step:944/1680 train_time:82712ms step_avg:87.62ms +step:945/1680 train_time:82800ms step_avg:87.62ms 
+step:946/1680 train_time:82889ms step_avg:87.62ms +step:947/1680 train_time:82978ms step_avg:87.62ms +step:948/1680 train_time:83066ms step_avg:87.62ms +step:949/1680 train_time:83155ms step_avg:87.62ms +step:950/1680 train_time:83242ms step_avg:87.62ms +step:951/1680 train_time:83330ms step_avg:87.62ms +step:952/1680 train_time:83419ms step_avg:87.62ms +step:953/1680 train_time:83508ms step_avg:87.63ms +step:954/1680 train_time:83596ms step_avg:87.63ms +step:955/1680 train_time:83684ms step_avg:87.63ms +step:956/1680 train_time:83771ms step_avg:87.63ms +step:957/1680 train_time:83860ms step_avg:87.63ms +step:958/1680 train_time:83948ms step_avg:87.63ms +step:959/1680 train_time:84037ms step_avg:87.63ms +step:960/1680 train_time:84125ms step_avg:87.63ms +step:961/1680 train_time:84214ms step_avg:87.63ms +step:962/1680 train_time:84303ms step_avg:87.63ms +step:963/1680 train_time:84390ms step_avg:87.63ms +step:964/1680 train_time:84478ms step_avg:87.63ms +step:965/1680 train_time:84567ms step_avg:87.63ms +step:966/1680 train_time:84656ms step_avg:87.64ms +step:967/1680 train_time:84744ms step_avg:87.64ms +step:968/1680 train_time:84832ms step_avg:87.64ms +step:969/1680 train_time:84919ms step_avg:87.64ms +step:970/1680 train_time:85008ms step_avg:87.64ms +step:971/1680 train_time:85096ms step_avg:87.64ms +step:972/1680 train_time:85185ms step_avg:87.64ms +step:973/1680 train_time:85273ms step_avg:87.64ms +step:974/1680 train_time:85361ms step_avg:87.64ms +step:975/1680 train_time:85448ms step_avg:87.64ms +step:976/1680 train_time:85536ms step_avg:87.64ms +step:977/1680 train_time:85624ms step_avg:87.64ms +step:978/1680 train_time:85712ms step_avg:87.64ms +step:979/1680 train_time:85801ms step_avg:87.64ms +step:980/1680 train_time:85889ms step_avg:87.64ms +step:981/1680 train_time:85977ms step_avg:87.64ms +step:982/1680 train_time:86066ms step_avg:87.64ms +step:983/1680 train_time:86155ms step_avg:87.65ms +step:984/1680 train_time:86244ms step_avg:87.65ms +step:985/1680 train_time:86332ms step_avg:87.65ms +step:986/1680 train_time:86420ms step_avg:87.65ms +step:987/1680 train_time:86508ms step_avg:87.65ms +step:988/1680 train_time:86596ms step_avg:87.65ms +step:989/1680 train_time:86685ms step_avg:87.65ms +step:990/1680 train_time:86773ms step_avg:87.65ms +step:991/1680 train_time:86861ms step_avg:87.65ms +step:992/1680 train_time:86949ms step_avg:87.65ms +step:993/1680 train_time:87037ms step_avg:87.65ms +step:994/1680 train_time:87126ms step_avg:87.65ms +step:995/1680 train_time:87214ms step_avg:87.65ms +step:996/1680 train_time:87302ms step_avg:87.65ms +step:997/1680 train_time:87390ms step_avg:87.65ms +step:998/1680 train_time:87479ms step_avg:87.65ms +step:999/1680 train_time:87567ms step_avg:87.65ms +step:1000/1680 train_time:87656ms step_avg:87.66ms +step:1000/1680 val_loss:3.4720 train_time:87745ms step_avg:87.75ms +step:1001/1680 train_time:87764ms step_avg:87.68ms +step:1002/1680 train_time:87840ms step_avg:87.66ms +step:1003/1680 train_time:87931ms step_avg:87.67ms +step:1004/1680 train_time:88021ms step_avg:87.67ms +step:1005/1680 train_time:88109ms step_avg:87.67ms +step:1006/1680 train_time:88196ms step_avg:87.67ms +step:1007/1680 train_time:88282ms step_avg:87.67ms +step:1008/1680 train_time:88369ms step_avg:87.67ms +step:1009/1680 train_time:88457ms step_avg:87.67ms +step:1010/1680 train_time:88544ms step_avg:87.67ms +step:1011/1680 train_time:88631ms step_avg:87.67ms +step:1012/1680 train_time:88721ms step_avg:87.67ms +step:1013/1680 train_time:88811ms step_avg:87.67ms 
+step:1014/1680 train_time:88901ms step_avg:87.67ms +step:1015/1680 train_time:88990ms step_avg:87.68ms +step:1016/1680 train_time:89079ms step_avg:87.68ms +step:1017/1680 train_time:89167ms step_avg:87.68ms +step:1018/1680 train_time:89254ms step_avg:87.68ms +step:1019/1680 train_time:89341ms step_avg:87.68ms +step:1020/1680 train_time:89429ms step_avg:87.68ms +step:1021/1680 train_time:89517ms step_avg:87.68ms +step:1022/1680 train_time:89604ms step_avg:87.68ms +step:1023/1680 train_time:89692ms step_avg:87.68ms +step:1024/1680 train_time:89781ms step_avg:87.68ms +step:1025/1680 train_time:89871ms step_avg:87.68ms +step:1026/1680 train_time:89961ms step_avg:87.68ms +step:1027/1680 train_time:90051ms step_avg:87.68ms +step:1028/1680 train_time:90140ms step_avg:87.68ms +step:1029/1680 train_time:90227ms step_avg:87.68ms +step:1030/1680 train_time:90315ms step_avg:87.68ms +step:1031/1680 train_time:90402ms step_avg:87.68ms +step:1032/1680 train_time:90490ms step_avg:87.68ms +step:1033/1680 train_time:90577ms step_avg:87.68ms +step:1034/1680 train_time:90665ms step_avg:87.68ms +step:1035/1680 train_time:90753ms step_avg:87.68ms +step:1036/1680 train_time:90842ms step_avg:87.69ms +step:1037/1680 train_time:90932ms step_avg:87.69ms +step:1038/1680 train_time:91023ms step_avg:87.69ms +step:1039/1680 train_time:91112ms step_avg:87.69ms +step:1040/1680 train_time:91201ms step_avg:87.69ms +step:1041/1680 train_time:91289ms step_avg:87.69ms +step:1042/1680 train_time:91377ms step_avg:87.69ms +step:1043/1680 train_time:91464ms step_avg:87.69ms +step:1044/1680 train_time:91552ms step_avg:87.69ms +step:1045/1680 train_time:91639ms step_avg:87.69ms +step:1046/1680 train_time:91727ms step_avg:87.69ms +step:1047/1680 train_time:91816ms step_avg:87.69ms +step:1048/1680 train_time:91905ms step_avg:87.70ms +step:1049/1680 train_time:91994ms step_avg:87.70ms +step:1050/1680 train_time:92082ms step_avg:87.70ms +step:1051/1680 train_time:92171ms step_avg:87.70ms +step:1052/1680 train_time:92260ms step_avg:87.70ms +step:1053/1680 train_time:92347ms step_avg:87.70ms +step:1054/1680 train_time:92435ms step_avg:87.70ms +step:1055/1680 train_time:92523ms step_avg:87.70ms +step:1056/1680 train_time:92610ms step_avg:87.70ms +step:1057/1680 train_time:92698ms step_avg:87.70ms +step:1058/1680 train_time:92786ms step_avg:87.70ms +step:1059/1680 train_time:92875ms step_avg:87.70ms +step:1060/1680 train_time:92963ms step_avg:87.70ms +step:1061/1680 train_time:93052ms step_avg:87.70ms +step:1062/1680 train_time:93141ms step_avg:87.70ms +step:1063/1680 train_time:93229ms step_avg:87.70ms +step:1064/1680 train_time:93318ms step_avg:87.70ms +step:1065/1680 train_time:93405ms step_avg:87.70ms +step:1066/1680 train_time:93493ms step_avg:87.70ms +step:1067/1680 train_time:93581ms step_avg:87.70ms +step:1068/1680 train_time:93669ms step_avg:87.70ms +step:1069/1680 train_time:93758ms step_avg:87.71ms +step:1070/1680 train_time:93846ms step_avg:87.71ms +step:1071/1680 train_time:93934ms step_avg:87.71ms +step:1072/1680 train_time:94023ms step_avg:87.71ms +step:1073/1680 train_time:94111ms step_avg:87.71ms +step:1074/1680 train_time:94200ms step_avg:87.71ms +step:1075/1680 train_time:94289ms step_avg:87.71ms +step:1076/1680 train_time:94377ms step_avg:87.71ms +step:1077/1680 train_time:94465ms step_avg:87.71ms +step:1078/1680 train_time:94553ms step_avg:87.71ms +step:1079/1680 train_time:94641ms step_avg:87.71ms +step:1080/1680 train_time:94730ms step_avg:87.71ms +step:1081/1680 train_time:94818ms step_avg:87.71ms +step:1082/1680 
train_time:94906ms step_avg:87.71ms +step:1083/1680 train_time:94995ms step_avg:87.71ms +step:1084/1680 train_time:95083ms step_avg:87.72ms +step:1085/1680 train_time:95172ms step_avg:87.72ms +step:1086/1680 train_time:95261ms step_avg:87.72ms +step:1087/1680 train_time:95350ms step_avg:87.72ms +step:1088/1680 train_time:95438ms step_avg:87.72ms +step:1089/1680 train_time:95526ms step_avg:87.72ms +step:1090/1680 train_time:95614ms step_avg:87.72ms +step:1091/1680 train_time:95702ms step_avg:87.72ms +step:1092/1680 train_time:95790ms step_avg:87.72ms +step:1093/1680 train_time:95878ms step_avg:87.72ms +step:1094/1680 train_time:95966ms step_avg:87.72ms +step:1095/1680 train_time:96055ms step_avg:87.72ms +step:1096/1680 train_time:96144ms step_avg:87.72ms +step:1097/1680 train_time:96233ms step_avg:87.72ms +step:1098/1680 train_time:96322ms step_avg:87.73ms +step:1099/1680 train_time:96411ms step_avg:87.73ms +step:1100/1680 train_time:96500ms step_avg:87.73ms +step:1101/1680 train_time:96588ms step_avg:87.73ms +step:1102/1680 train_time:96677ms step_avg:87.73ms +step:1103/1680 train_time:96765ms step_avg:87.73ms +step:1104/1680 train_time:96854ms step_avg:87.73ms +step:1105/1680 train_time:96943ms step_avg:87.73ms +step:1106/1680 train_time:97031ms step_avg:87.73ms +step:1107/1680 train_time:97120ms step_avg:87.73ms +step:1108/1680 train_time:97209ms step_avg:87.73ms +step:1109/1680 train_time:97298ms step_avg:87.74ms +step:1110/1680 train_time:97387ms step_avg:87.74ms +step:1111/1680 train_time:97476ms step_avg:87.74ms +step:1112/1680 train_time:97564ms step_avg:87.74ms +step:1113/1680 train_time:97653ms step_avg:87.74ms +step:1114/1680 train_time:97741ms step_avg:87.74ms +step:1115/1680 train_time:97830ms step_avg:87.74ms +step:1116/1680 train_time:97920ms step_avg:87.74ms +step:1117/1680 train_time:98008ms step_avg:87.74ms +step:1118/1680 train_time:98096ms step_avg:87.74ms +step:1119/1680 train_time:98185ms step_avg:87.74ms +step:1120/1680 train_time:98274ms step_avg:87.74ms +step:1121/1680 train_time:98363ms step_avg:87.75ms +step:1122/1680 train_time:98452ms step_avg:87.75ms +step:1123/1680 train_time:98541ms step_avg:87.75ms +step:1124/1680 train_time:98630ms step_avg:87.75ms +step:1125/1680 train_time:98720ms step_avg:87.75ms +step:1125/1680 val_loss:3.4189 train_time:98809ms step_avg:87.83ms +step:1126/1680 train_time:98829ms step_avg:87.77ms +step:1127/1680 train_time:98900ms step_avg:87.75ms +step:1128/1680 train_time:98990ms step_avg:87.76ms +step:1129/1680 train_time:99082ms step_avg:87.76ms +step:1130/1680 train_time:99170ms step_avg:87.76ms +step:1131/1680 train_time:99259ms step_avg:87.76ms +step:1132/1680 train_time:99347ms step_avg:87.76ms +step:1133/1680 train_time:99434ms step_avg:87.76ms +step:1134/1680 train_time:99522ms step_avg:87.76ms +step:1135/1680 train_time:99610ms step_avg:87.76ms +step:1136/1680 train_time:99699ms step_avg:87.76ms +step:1137/1680 train_time:99789ms step_avg:87.77ms +step:1138/1680 train_time:99878ms step_avg:87.77ms +step:1139/1680 train_time:99968ms step_avg:87.77ms +step:1140/1680 train_time:100059ms step_avg:87.77ms +step:1141/1680 train_time:100148ms step_avg:87.77ms +step:1142/1680 train_time:100238ms step_avg:87.77ms +step:1143/1680 train_time:100326ms step_avg:87.77ms +step:1144/1680 train_time:100414ms step_avg:87.77ms +step:1145/1680 train_time:100502ms step_avg:87.77ms +step:1146/1680 train_time:100590ms step_avg:87.77ms +step:1147/1680 train_time:100679ms step_avg:87.78ms +step:1148/1680 train_time:100768ms step_avg:87.78ms 
+step:1149/1680 train_time:100858ms step_avg:87.78ms +step:1150/1680 train_time:100947ms step_avg:87.78ms +step:1151/1680 train_time:101037ms step_avg:87.78ms +step:1152/1680 train_time:101126ms step_avg:87.78ms +step:1153/1680 train_time:101216ms step_avg:87.78ms +step:1154/1680 train_time:101305ms step_avg:87.79ms +step:1155/1680 train_time:101394ms step_avg:87.79ms +step:1156/1680 train_time:101481ms step_avg:87.79ms +step:1157/1680 train_time:101570ms step_avg:87.79ms +step:1158/1680 train_time:101658ms step_avg:87.79ms +step:1159/1680 train_time:101747ms step_avg:87.79ms +step:1160/1680 train_time:101836ms step_avg:87.79ms +step:1161/1680 train_time:101926ms step_avg:87.79ms +step:1162/1680 train_time:102015ms step_avg:87.79ms +step:1163/1680 train_time:102105ms step_avg:87.79ms +step:1164/1680 train_time:102194ms step_avg:87.80ms +step:1165/1680 train_time:102283ms step_avg:87.80ms +step:1166/1680 train_time:102372ms step_avg:87.80ms +step:1167/1680 train_time:102460ms step_avg:87.80ms +step:1168/1680 train_time:102548ms step_avg:87.80ms +step:1169/1680 train_time:102637ms step_avg:87.80ms +step:1170/1680 train_time:102726ms step_avg:87.80ms +step:1171/1680 train_time:102814ms step_avg:87.80ms +step:1172/1680 train_time:102904ms step_avg:87.80ms +step:1173/1680 train_time:102994ms step_avg:87.80ms +step:1174/1680 train_time:103082ms step_avg:87.80ms +step:1175/1680 train_time:103172ms step_avg:87.81ms +step:1176/1680 train_time:103261ms step_avg:87.81ms +step:1177/1680 train_time:103350ms step_avg:87.81ms +step:1178/1680 train_time:103438ms step_avg:87.81ms +step:1179/1680 train_time:103527ms step_avg:87.81ms +step:1180/1680 train_time:103616ms step_avg:87.81ms +step:1181/1680 train_time:103704ms step_avg:87.81ms +step:1182/1680 train_time:103793ms step_avg:87.81ms +step:1183/1680 train_time:103883ms step_avg:87.81ms +step:1184/1680 train_time:103972ms step_avg:87.81ms +step:1185/1680 train_time:104061ms step_avg:87.82ms +step:1186/1680 train_time:104151ms step_avg:87.82ms +step:1187/1680 train_time:104239ms step_avg:87.82ms +step:1188/1680 train_time:104329ms step_avg:87.82ms +step:1189/1680 train_time:104419ms step_avg:87.82ms +step:1190/1680 train_time:104509ms step_avg:87.82ms +step:1191/1680 train_time:104598ms step_avg:87.82ms +step:1192/1680 train_time:104688ms step_avg:87.83ms +step:1193/1680 train_time:104777ms step_avg:87.83ms +step:1194/1680 train_time:104866ms step_avg:87.83ms +step:1195/1680 train_time:104955ms step_avg:87.83ms +step:1196/1680 train_time:105044ms step_avg:87.83ms +step:1197/1680 train_time:105133ms step_avg:87.83ms +step:1198/1680 train_time:105222ms step_avg:87.83ms +step:1199/1680 train_time:105311ms step_avg:87.83ms +step:1200/1680 train_time:105399ms step_avg:87.83ms +step:1201/1680 train_time:105488ms step_avg:87.83ms +step:1202/1680 train_time:105576ms step_avg:87.83ms +step:1203/1680 train_time:105664ms step_avg:87.83ms +step:1204/1680 train_time:105753ms step_avg:87.83ms +step:1205/1680 train_time:105842ms step_avg:87.84ms +step:1206/1680 train_time:105932ms step_avg:87.84ms +step:1207/1680 train_time:106021ms step_avg:87.84ms +step:1208/1680 train_time:106110ms step_avg:87.84ms +step:1209/1680 train_time:106199ms step_avg:87.84ms +step:1210/1680 train_time:106288ms step_avg:87.84ms +step:1211/1680 train_time:106377ms step_avg:87.84ms +step:1212/1680 train_time:106465ms step_avg:87.84ms +step:1213/1680 train_time:106554ms step_avg:87.84ms +step:1214/1680 train_time:106642ms step_avg:87.84ms +step:1215/1680 train_time:106731ms step_avg:87.84ms 
+step:1216/1680 train_time:106820ms step_avg:87.85ms +step:1217/1680 train_time:106910ms step_avg:87.85ms +step:1218/1680 train_time:107000ms step_avg:87.85ms +step:1219/1680 train_time:107091ms step_avg:87.85ms +step:1220/1680 train_time:107180ms step_avg:87.85ms +step:1221/1680 train_time:107268ms step_avg:87.85ms +step:1222/1680 train_time:107357ms step_avg:87.85ms +step:1223/1680 train_time:107445ms step_avg:87.85ms +step:1224/1680 train_time:107534ms step_avg:87.85ms +step:1225/1680 train_time:107623ms step_avg:87.86ms +step:1226/1680 train_time:107712ms step_avg:87.86ms +step:1227/1680 train_time:107801ms step_avg:87.86ms +step:1228/1680 train_time:107890ms step_avg:87.86ms +step:1229/1680 train_time:107979ms step_avg:87.86ms +step:1230/1680 train_time:108067ms step_avg:87.86ms +step:1231/1680 train_time:108156ms step_avg:87.86ms +step:1232/1680 train_time:108246ms step_avg:87.86ms +step:1233/1680 train_time:108334ms step_avg:87.86ms +step:1234/1680 train_time:108423ms step_avg:87.86ms +step:1235/1680 train_time:108512ms step_avg:87.86ms +step:1236/1680 train_time:108602ms step_avg:87.87ms +step:1237/1680 train_time:108691ms step_avg:87.87ms +step:1238/1680 train_time:108780ms step_avg:87.87ms +step:1239/1680 train_time:108869ms step_avg:87.87ms +step:1240/1680 train_time:108958ms step_avg:87.87ms +step:1241/1680 train_time:109047ms step_avg:87.87ms +step:1242/1680 train_time:109136ms step_avg:87.87ms +step:1243/1680 train_time:109225ms step_avg:87.87ms +step:1244/1680 train_time:109313ms step_avg:87.87ms +step:1245/1680 train_time:109402ms step_avg:87.87ms +step:1246/1680 train_time:109491ms step_avg:87.87ms +step:1247/1680 train_time:109581ms step_avg:87.88ms +step:1248/1680 train_time:109670ms step_avg:87.88ms +step:1249/1680 train_time:109758ms step_avg:87.88ms +step:1250/1680 train_time:109847ms step_avg:87.88ms +step:1250/1680 val_loss:3.3813 train_time:109938ms step_avg:87.95ms +step:1251/1680 train_time:109956ms step_avg:87.89ms +step:1252/1680 train_time:110028ms step_avg:87.88ms +step:1253/1680 train_time:110120ms step_avg:87.89ms +step:1254/1680 train_time:110210ms step_avg:87.89ms +step:1255/1680 train_time:110298ms step_avg:87.89ms +step:1256/1680 train_time:110387ms step_avg:87.89ms +step:1257/1680 train_time:110475ms step_avg:87.89ms +step:1258/1680 train_time:110564ms step_avg:87.89ms +step:1259/1680 train_time:110652ms step_avg:87.89ms +step:1260/1680 train_time:110740ms step_avg:87.89ms +step:1261/1680 train_time:110829ms step_avg:87.89ms +step:1262/1680 train_time:110920ms step_avg:87.89ms +step:1263/1680 train_time:111010ms step_avg:87.89ms +step:1264/1680 train_time:111100ms step_avg:87.90ms +step:1265/1680 train_time:111189ms step_avg:87.90ms +step:1266/1680 train_time:111278ms step_avg:87.90ms +step:1267/1680 train_time:111367ms step_avg:87.90ms +step:1268/1680 train_time:111456ms step_avg:87.90ms +step:1269/1680 train_time:111544ms step_avg:87.90ms +step:1270/1680 train_time:111633ms step_avg:87.90ms +step:1271/1680 train_time:111721ms step_avg:87.90ms +step:1272/1680 train_time:111809ms step_avg:87.90ms +step:1273/1680 train_time:111899ms step_avg:87.90ms +step:1274/1680 train_time:111989ms step_avg:87.90ms +step:1275/1680 train_time:112079ms step_avg:87.91ms +step:1276/1680 train_time:112169ms step_avg:87.91ms +step:1277/1680 train_time:112258ms step_avg:87.91ms +step:1278/1680 train_time:112348ms step_avg:87.91ms +step:1279/1680 train_time:112437ms step_avg:87.91ms +step:1280/1680 train_time:112525ms step_avg:87.91ms +step:1281/1680 train_time:112613ms 
step_avg:87.91ms +step:1282/1680 train_time:112702ms step_avg:87.91ms +step:1283/1680 train_time:112790ms step_avg:87.91ms +step:1284/1680 train_time:112879ms step_avg:87.91ms +step:1285/1680 train_time:112969ms step_avg:87.91ms +step:1286/1680 train_time:113059ms step_avg:87.92ms +step:1287/1680 train_time:113149ms step_avg:87.92ms +step:1288/1680 train_time:113238ms step_avg:87.92ms +step:1289/1680 train_time:113329ms step_avg:87.92ms +step:1290/1680 train_time:113418ms step_avg:87.92ms +step:1291/1680 train_time:113506ms step_avg:87.92ms +step:1292/1680 train_time:113595ms step_avg:87.92ms +step:1293/1680 train_time:113683ms step_avg:87.92ms +step:1294/1680 train_time:113771ms step_avg:87.92ms +step:1295/1680 train_time:113860ms step_avg:87.92ms +step:1296/1680 train_time:113949ms step_avg:87.92ms +step:1297/1680 train_time:114039ms step_avg:87.93ms +step:1298/1680 train_time:114128ms step_avg:87.93ms +step:1299/1680 train_time:114217ms step_avg:87.93ms +step:1300/1680 train_time:114307ms step_avg:87.93ms +step:1301/1680 train_time:114397ms step_avg:87.93ms +step:1302/1680 train_time:114485ms step_avg:87.93ms +step:1303/1680 train_time:114574ms step_avg:87.93ms +step:1304/1680 train_time:114663ms step_avg:87.93ms +step:1305/1680 train_time:114752ms step_avg:87.93ms +step:1306/1680 train_time:114840ms step_avg:87.93ms +step:1307/1680 train_time:114930ms step_avg:87.93ms +step:1308/1680 train_time:115019ms step_avg:87.93ms +step:1309/1680 train_time:115108ms step_avg:87.94ms +step:1310/1680 train_time:115197ms step_avg:87.94ms +step:1311/1680 train_time:115286ms step_avg:87.94ms +step:1312/1680 train_time:115376ms step_avg:87.94ms +step:1313/1680 train_time:115465ms step_avg:87.94ms +step:1314/1680 train_time:115554ms step_avg:87.94ms +step:1315/1680 train_time:115642ms step_avg:87.94ms +step:1316/1680 train_time:115731ms step_avg:87.94ms +step:1317/1680 train_time:115820ms step_avg:87.94ms +step:1318/1680 train_time:115910ms step_avg:87.94ms +step:1319/1680 train_time:115999ms step_avg:87.94ms +step:1320/1680 train_time:116088ms step_avg:87.95ms +step:1321/1680 train_time:116177ms step_avg:87.95ms +step:1322/1680 train_time:116267ms step_avg:87.95ms +step:1323/1680 train_time:116357ms step_avg:87.95ms +step:1324/1680 train_time:116445ms step_avg:87.95ms +step:1325/1680 train_time:116534ms step_avg:87.95ms +step:1326/1680 train_time:116623ms step_avg:87.95ms +step:1327/1680 train_time:116712ms step_avg:87.95ms +step:1328/1680 train_time:116800ms step_avg:87.95ms +step:1329/1680 train_time:116889ms step_avg:87.95ms +step:1330/1680 train_time:116978ms step_avg:87.95ms +step:1331/1680 train_time:117068ms step_avg:87.96ms +step:1332/1680 train_time:117158ms step_avg:87.96ms +step:1333/1680 train_time:117248ms step_avg:87.96ms +step:1334/1680 train_time:117337ms step_avg:87.96ms +step:1335/1680 train_time:117426ms step_avg:87.96ms +step:1336/1680 train_time:117515ms step_avg:87.96ms +step:1337/1680 train_time:117604ms step_avg:87.96ms +step:1338/1680 train_time:117693ms step_avg:87.96ms +step:1339/1680 train_time:117782ms step_avg:87.96ms +step:1340/1680 train_time:117870ms step_avg:87.96ms +step:1341/1680 train_time:117959ms step_avg:87.96ms +step:1342/1680 train_time:118048ms step_avg:87.96ms +step:1343/1680 train_time:118137ms step_avg:87.96ms +step:1344/1680 train_time:118226ms step_avg:87.97ms +step:1345/1680 train_time:118315ms step_avg:87.97ms +step:1346/1680 train_time:118404ms step_avg:87.97ms +step:1347/1680 train_time:118493ms step_avg:87.97ms +step:1348/1680 train_time:118582ms 
step_avg:87.97ms +step:1349/1680 train_time:118672ms step_avg:87.97ms +step:1350/1680 train_time:118761ms step_avg:87.97ms +step:1351/1680 train_time:118849ms step_avg:87.97ms +step:1352/1680 train_time:118938ms step_avg:87.97ms +step:1353/1680 train_time:119028ms step_avg:87.97ms +step:1354/1680 train_time:119118ms step_avg:87.97ms +step:1355/1680 train_time:119206ms step_avg:87.98ms +step:1356/1680 train_time:119296ms step_avg:87.98ms +step:1357/1680 train_time:119385ms step_avg:87.98ms +step:1358/1680 train_time:119474ms step_avg:87.98ms +step:1359/1680 train_time:119563ms step_avg:87.98ms +step:1360/1680 train_time:119653ms step_avg:87.98ms +step:1361/1680 train_time:119741ms step_avg:87.98ms +step:1362/1680 train_time:119830ms step_avg:87.98ms +step:1363/1680 train_time:119919ms step_avg:87.98ms +step:1364/1680 train_time:120008ms step_avg:87.98ms +step:1365/1680 train_time:120097ms step_avg:87.98ms +step:1366/1680 train_time:120186ms step_avg:87.98ms +step:1367/1680 train_time:120275ms step_avg:87.98ms +step:1368/1680 train_time:120364ms step_avg:87.99ms +step:1369/1680 train_time:120454ms step_avg:87.99ms +step:1370/1680 train_time:120543ms step_avg:87.99ms +step:1371/1680 train_time:120631ms step_avg:87.99ms +step:1372/1680 train_time:120720ms step_avg:87.99ms +step:1373/1680 train_time:120809ms step_avg:87.99ms +step:1374/1680 train_time:120898ms step_avg:87.99ms +step:1375/1680 train_time:120986ms step_avg:87.99ms +step:1375/1680 val_loss:3.3462 train_time:121077ms step_avg:88.06ms +step:1376/1680 train_time:121096ms step_avg:88.01ms +step:1377/1680 train_time:121169ms step_avg:87.99ms +step:1378/1680 train_time:121261ms step_avg:88.00ms +step:1379/1680 train_time:121349ms step_avg:88.00ms +step:1380/1680 train_time:121437ms step_avg:88.00ms +step:1381/1680 train_time:121525ms step_avg:88.00ms +step:1382/1680 train_time:121613ms step_avg:88.00ms +step:1383/1680 train_time:121701ms step_avg:88.00ms +step:1384/1680 train_time:121789ms step_avg:88.00ms +step:1385/1680 train_time:121877ms step_avg:88.00ms +step:1386/1680 train_time:121965ms step_avg:88.00ms +step:1387/1680 train_time:122056ms step_avg:88.00ms +step:1388/1680 train_time:122147ms step_avg:88.00ms +step:1389/1680 train_time:122238ms step_avg:88.00ms +step:1390/1680 train_time:122326ms step_avg:88.00ms +step:1391/1680 train_time:122415ms step_avg:88.00ms +step:1392/1680 train_time:122505ms step_avg:88.01ms +step:1393/1680 train_time:122593ms step_avg:88.01ms +step:1394/1680 train_time:122682ms step_avg:88.01ms +step:1395/1680 train_time:122771ms step_avg:88.01ms +step:1396/1680 train_time:122859ms step_avg:88.01ms +step:1397/1680 train_time:122947ms step_avg:88.01ms +step:1398/1680 train_time:123037ms step_avg:88.01ms +step:1399/1680 train_time:123127ms step_avg:88.01ms +step:1400/1680 train_time:123217ms step_avg:88.01ms +step:1401/1680 train_time:123305ms step_avg:88.01ms +step:1402/1680 train_time:123396ms step_avg:88.01ms +step:1403/1680 train_time:123485ms step_avg:88.01ms +step:1404/1680 train_time:123573ms step_avg:88.02ms +step:1405/1680 train_time:123661ms step_avg:88.02ms +step:1406/1680 train_time:123750ms step_avg:88.02ms +step:1407/1680 train_time:123838ms step_avg:88.02ms +step:1408/1680 train_time:123926ms step_avg:88.02ms +step:1409/1680 train_time:124015ms step_avg:88.02ms +step:1410/1680 train_time:124104ms step_avg:88.02ms +step:1411/1680 train_time:124194ms step_avg:88.02ms +step:1412/1680 train_time:124283ms step_avg:88.02ms +step:1413/1680 train_time:124373ms step_avg:88.02ms +step:1414/1680 
train_time:124462ms step_avg:88.02ms +step:1415/1680 train_time:124551ms step_avg:88.02ms +step:1416/1680 train_time:124640ms step_avg:88.02ms +step:1417/1680 train_time:124729ms step_avg:88.02ms +step:1418/1680 train_time:124817ms step_avg:88.02ms +step:1419/1680 train_time:124905ms step_avg:88.02ms +step:1420/1680 train_time:124995ms step_avg:88.02ms +step:1421/1680 train_time:125085ms step_avg:88.03ms +step:1422/1680 train_time:125174ms step_avg:88.03ms +step:1423/1680 train_time:125263ms step_avg:88.03ms +step:1424/1680 train_time:125352ms step_avg:88.03ms +step:1425/1680 train_time:125441ms step_avg:88.03ms +step:1426/1680 train_time:125530ms step_avg:88.03ms +step:1427/1680 train_time:125619ms step_avg:88.03ms +step:1428/1680 train_time:125709ms step_avg:88.03ms +step:1429/1680 train_time:125797ms step_avg:88.03ms +step:1430/1680 train_time:125886ms step_avg:88.03ms +step:1431/1680 train_time:125974ms step_avg:88.03ms +step:1432/1680 train_time:126064ms step_avg:88.03ms +step:1433/1680 train_time:126154ms step_avg:88.04ms +step:1434/1680 train_time:126243ms step_avg:88.04ms +step:1435/1680 train_time:126332ms step_avg:88.04ms +step:1436/1680 train_time:126421ms step_avg:88.04ms +step:1437/1680 train_time:126510ms step_avg:88.04ms +step:1438/1680 train_time:126600ms step_avg:88.04ms +step:1439/1680 train_time:126688ms step_avg:88.04ms +step:1440/1680 train_time:126777ms step_avg:88.04ms +step:1441/1680 train_time:126866ms step_avg:88.04ms +step:1442/1680 train_time:126954ms step_avg:88.04ms +step:1443/1680 train_time:127043ms step_avg:88.04ms +step:1444/1680 train_time:127133ms step_avg:88.04ms +step:1445/1680 train_time:127223ms step_avg:88.04ms +step:1446/1680 train_time:127312ms step_avg:88.04ms +step:1447/1680 train_time:127403ms step_avg:88.05ms +step:1448/1680 train_time:127491ms step_avg:88.05ms +step:1449/1680 train_time:127580ms step_avg:88.05ms +step:1450/1680 train_time:127669ms step_avg:88.05ms +step:1451/1680 train_time:127758ms step_avg:88.05ms +step:1452/1680 train_time:127847ms step_avg:88.05ms +step:1453/1680 train_time:127936ms step_avg:88.05ms +step:1454/1680 train_time:128025ms step_avg:88.05ms +step:1455/1680 train_time:128114ms step_avg:88.05ms +step:1456/1680 train_time:128203ms step_avg:88.05ms +step:1457/1680 train_time:128293ms step_avg:88.05ms +step:1458/1680 train_time:128382ms step_avg:88.05ms +step:1459/1680 train_time:128472ms step_avg:88.05ms +step:1460/1680 train_time:128561ms step_avg:88.06ms +step:1461/1680 train_time:128650ms step_avg:88.06ms +step:1462/1680 train_time:128739ms step_avg:88.06ms +step:1463/1680 train_time:128827ms step_avg:88.06ms +step:1464/1680 train_time:128917ms step_avg:88.06ms +step:1465/1680 train_time:129005ms step_avg:88.06ms +step:1466/1680 train_time:129094ms step_avg:88.06ms +step:1467/1680 train_time:129184ms step_avg:88.06ms +step:1468/1680 train_time:129274ms step_avg:88.06ms +step:1469/1680 train_time:129364ms step_avg:88.06ms +step:1470/1680 train_time:129453ms step_avg:88.06ms +step:1471/1680 train_time:129541ms step_avg:88.06ms +step:1472/1680 train_time:129630ms step_avg:88.06ms +step:1473/1680 train_time:129719ms step_avg:88.06ms +step:1474/1680 train_time:129808ms step_avg:88.07ms +step:1475/1680 train_time:129897ms step_avg:88.07ms +step:1476/1680 train_time:129986ms step_avg:88.07ms +step:1477/1680 train_time:130074ms step_avg:88.07ms +step:1478/1680 train_time:130163ms step_avg:88.07ms +step:1479/1680 train_time:130252ms step_avg:88.07ms +step:1480/1680 train_time:130342ms step_avg:88.07ms +step:1481/1680 
train_time:130431ms step_avg:88.07ms +step:1482/1680 train_time:130521ms step_avg:88.07ms +step:1483/1680 train_time:130610ms step_avg:88.07ms +step:1484/1680 train_time:130699ms step_avg:88.07ms +step:1485/1680 train_time:130788ms step_avg:88.07ms +step:1486/1680 train_time:130877ms step_avg:88.07ms +step:1487/1680 train_time:130966ms step_avg:88.07ms +step:1488/1680 train_time:131055ms step_avg:88.07ms +step:1489/1680 train_time:131143ms step_avg:88.07ms +step:1490/1680 train_time:131233ms step_avg:88.08ms +step:1491/1680 train_time:131322ms step_avg:88.08ms +step:1492/1680 train_time:131411ms step_avg:88.08ms +step:1493/1680 train_time:131500ms step_avg:88.08ms +step:1494/1680 train_time:131589ms step_avg:88.08ms +step:1495/1680 train_time:131678ms step_avg:88.08ms +step:1496/1680 train_time:131767ms step_avg:88.08ms +step:1497/1680 train_time:131856ms step_avg:88.08ms +step:1498/1680 train_time:131945ms step_avg:88.08ms +step:1499/1680 train_time:132034ms step_avg:88.08ms +step:1500/1680 train_time:132123ms step_avg:88.08ms +step:1500/1680 val_loss:3.3168 train_time:132214ms step_avg:88.14ms +step:1501/1680 train_time:132232ms step_avg:88.10ms +step:1502/1680 train_time:132306ms step_avg:88.09ms +step:1503/1680 train_time:132399ms step_avg:88.09ms +step:1504/1680 train_time:132488ms step_avg:88.09ms +step:1505/1680 train_time:132576ms step_avg:88.09ms +step:1506/1680 train_time:132664ms step_avg:88.09ms +step:1507/1680 train_time:132752ms step_avg:88.09ms +step:1508/1680 train_time:132839ms step_avg:88.09ms +step:1509/1680 train_time:132927ms step_avg:88.09ms +step:1510/1680 train_time:133016ms step_avg:88.09ms +step:1511/1680 train_time:133104ms step_avg:88.09ms +step:1512/1680 train_time:133194ms step_avg:88.09ms +step:1513/1680 train_time:133285ms step_avg:88.09ms +step:1514/1680 train_time:133376ms step_avg:88.10ms +step:1515/1680 train_time:133468ms step_avg:88.10ms +step:1516/1680 train_time:133557ms step_avg:88.10ms +step:1517/1680 train_time:133645ms step_avg:88.10ms +step:1518/1680 train_time:133734ms step_avg:88.10ms +step:1519/1680 train_time:133823ms step_avg:88.10ms +step:1520/1680 train_time:133911ms step_avg:88.10ms +step:1521/1680 train_time:133999ms step_avg:88.10ms +step:1522/1680 train_time:134087ms step_avg:88.10ms +step:1523/1680 train_time:134176ms step_avg:88.10ms +step:1524/1680 train_time:134267ms step_avg:88.10ms +step:1525/1680 train_time:134357ms step_avg:88.10ms +step:1526/1680 train_time:134448ms step_avg:88.11ms +step:1527/1680 train_time:134537ms step_avg:88.11ms +step:1528/1680 train_time:134626ms step_avg:88.11ms +step:1529/1680 train_time:134715ms step_avg:88.11ms +step:1530/1680 train_time:134804ms step_avg:88.11ms +step:1531/1680 train_time:134893ms step_avg:88.11ms +step:1532/1680 train_time:134981ms step_avg:88.11ms +step:1533/1680 train_time:135069ms step_avg:88.11ms +step:1534/1680 train_time:135157ms step_avg:88.11ms +step:1535/1680 train_time:135247ms step_avg:88.11ms +step:1536/1680 train_time:135338ms step_avg:88.11ms +step:1537/1680 train_time:135428ms step_avg:88.11ms +step:1538/1680 train_time:135518ms step_avg:88.11ms +step:1539/1680 train_time:135607ms step_avg:88.11ms +step:1540/1680 train_time:135696ms step_avg:88.11ms +step:1541/1680 train_time:135784ms step_avg:88.11ms +step:1542/1680 train_time:135873ms step_avg:88.11ms +step:1543/1680 train_time:135962ms step_avg:88.12ms +step:1544/1680 train_time:136051ms step_avg:88.12ms +step:1545/1680 train_time:136139ms step_avg:88.12ms +step:1546/1680 train_time:136228ms step_avg:88.12ms 
+step:1547/1680 train_time:136317ms step_avg:88.12ms +step:1548/1680 train_time:136406ms step_avg:88.12ms +step:1549/1680 train_time:136496ms step_avg:88.12ms +step:1550/1680 train_time:136585ms step_avg:88.12ms +step:1551/1680 train_time:136674ms step_avg:88.12ms +step:1552/1680 train_time:136762ms step_avg:88.12ms +step:1553/1680 train_time:136851ms step_avg:88.12ms +step:1554/1680 train_time:136940ms step_avg:88.12ms +step:1555/1680 train_time:137029ms step_avg:88.12ms +step:1556/1680 train_time:137118ms step_avg:88.12ms +step:1557/1680 train_time:137207ms step_avg:88.12ms +step:1558/1680 train_time:137296ms step_avg:88.12ms +step:1559/1680 train_time:137386ms step_avg:88.12ms +step:1560/1680 train_time:137475ms step_avg:88.13ms +step:1561/1680 train_time:137565ms step_avg:88.13ms +step:1562/1680 train_time:137654ms step_avg:88.13ms +step:1563/1680 train_time:137743ms step_avg:88.13ms +step:1564/1680 train_time:137831ms step_avg:88.13ms +step:1565/1680 train_time:137921ms step_avg:88.13ms +step:1566/1680 train_time:138009ms step_avg:88.13ms +step:1567/1680 train_time:138098ms step_avg:88.13ms +step:1568/1680 train_time:138186ms step_avg:88.13ms +step:1569/1680 train_time:138275ms step_avg:88.13ms +step:1570/1680 train_time:138365ms step_avg:88.13ms +step:1571/1680 train_time:138454ms step_avg:88.13ms +step:1572/1680 train_time:138543ms step_avg:88.13ms +step:1573/1680 train_time:138632ms step_avg:88.13ms +step:1574/1680 train_time:138722ms step_avg:88.13ms +step:1575/1680 train_time:138811ms step_avg:88.13ms +step:1576/1680 train_time:138900ms step_avg:88.13ms +step:1577/1680 train_time:138988ms step_avg:88.13ms +step:1578/1680 train_time:139078ms step_avg:88.14ms +step:1579/1680 train_time:139166ms step_avg:88.14ms +step:1580/1680 train_time:139256ms step_avg:88.14ms +step:1581/1680 train_time:139345ms step_avg:88.14ms +step:1582/1680 train_time:139435ms step_avg:88.14ms +step:1583/1680 train_time:139524ms step_avg:88.14ms +step:1584/1680 train_time:139613ms step_avg:88.14ms +step:1585/1680 train_time:139704ms step_avg:88.14ms +step:1586/1680 train_time:139792ms step_avg:88.14ms +step:1587/1680 train_time:139881ms step_avg:88.14ms +step:1588/1680 train_time:139970ms step_avg:88.14ms +step:1589/1680 train_time:140058ms step_avg:88.14ms +step:1590/1680 train_time:140146ms step_avg:88.14ms +step:1591/1680 train_time:140236ms step_avg:88.14ms +step:1592/1680 train_time:140324ms step_avg:88.14ms +step:1593/1680 train_time:140414ms step_avg:88.14ms +step:1594/1680 train_time:140503ms step_avg:88.14ms +step:1595/1680 train_time:140593ms step_avg:88.15ms +step:1596/1680 train_time:140682ms step_avg:88.15ms +step:1597/1680 train_time:140771ms step_avg:88.15ms +step:1598/1680 train_time:140860ms step_avg:88.15ms +step:1599/1680 train_time:140949ms step_avg:88.15ms +step:1600/1680 train_time:141038ms step_avg:88.15ms +step:1601/1680 train_time:141127ms step_avg:88.15ms +step:1602/1680 train_time:141216ms step_avg:88.15ms +step:1603/1680 train_time:141305ms step_avg:88.15ms +step:1604/1680 train_time:141394ms step_avg:88.15ms +step:1605/1680 train_time:141483ms step_avg:88.15ms +step:1606/1680 train_time:141572ms step_avg:88.15ms +step:1607/1680 train_time:141661ms step_avg:88.15ms +step:1608/1680 train_time:141751ms step_avg:88.15ms +step:1609/1680 train_time:141840ms step_avg:88.15ms +step:1610/1680 train_time:141929ms step_avg:88.15ms +step:1611/1680 train_time:142018ms step_avg:88.16ms +step:1612/1680 train_time:142107ms step_avg:88.16ms +step:1613/1680 train_time:142196ms step_avg:88.16ms 
+step:1614/1680 train_time:142285ms step_avg:88.16ms +step:1615/1680 train_time:142374ms step_avg:88.16ms +step:1616/1680 train_time:142463ms step_avg:88.16ms +step:1617/1680 train_time:142553ms step_avg:88.16ms +step:1618/1680 train_time:142643ms step_avg:88.16ms +step:1619/1680 train_time:142733ms step_avg:88.16ms +step:1620/1680 train_time:142822ms step_avg:88.16ms +step:1621/1680 train_time:142912ms step_avg:88.16ms +step:1622/1680 train_time:143001ms step_avg:88.16ms +step:1623/1680 train_time:143089ms step_avg:88.16ms +step:1624/1680 train_time:143178ms step_avg:88.16ms +step:1625/1680 train_time:143267ms step_avg:88.16ms +step:1625/1680 val_loss:3.2933 train_time:143357ms step_avg:88.22ms +step:1626/1680 train_time:143377ms step_avg:88.18ms +step:1627/1680 train_time:143450ms step_avg:88.17ms +step:1628/1680 train_time:143541ms step_avg:88.17ms +step:1629/1680 train_time:143631ms step_avg:88.17ms +step:1630/1680 train_time:143719ms step_avg:88.17ms +step:1631/1680 train_time:143807ms step_avg:88.17ms +step:1632/1680 train_time:143896ms step_avg:88.17ms +step:1633/1680 train_time:143984ms step_avg:88.17ms +step:1634/1680 train_time:144072ms step_avg:88.17ms +step:1635/1680 train_time:144160ms step_avg:88.17ms +step:1636/1680 train_time:144249ms step_avg:88.17ms +step:1637/1680 train_time:144338ms step_avg:88.17ms +step:1638/1680 train_time:144429ms step_avg:88.17ms +step:1639/1680 train_time:144519ms step_avg:88.18ms +step:1640/1680 train_time:144609ms step_avg:88.18ms +step:1641/1680 train_time:144700ms step_avg:88.18ms +step:1642/1680 train_time:144789ms step_avg:88.18ms +step:1643/1680 train_time:144878ms step_avg:88.18ms +step:1644/1680 train_time:144966ms step_avg:88.18ms +step:1645/1680 train_time:145056ms step_avg:88.18ms +step:1646/1680 train_time:145143ms step_avg:88.18ms +step:1647/1680 train_time:145232ms step_avg:88.18ms +step:1648/1680 train_time:145321ms step_avg:88.18ms +step:1649/1680 train_time:145410ms step_avg:88.18ms +step:1650/1680 train_time:145500ms step_avg:88.18ms +step:1651/1680 train_time:145590ms step_avg:88.18ms +step:1652/1680 train_time:145680ms step_avg:88.18ms +step:1653/1680 train_time:145769ms step_avg:88.18ms +step:1654/1680 train_time:145857ms step_avg:88.18ms +step:1655/1680 train_time:145946ms step_avg:88.19ms +step:1656/1680 train_time:146035ms step_avg:88.19ms +step:1657/1680 train_time:146124ms step_avg:88.19ms +step:1658/1680 train_time:146211ms step_avg:88.19ms +step:1659/1680 train_time:146301ms step_avg:88.19ms +step:1660/1680 train_time:146390ms step_avg:88.19ms +step:1661/1680 train_time:146480ms step_avg:88.19ms +step:1662/1680 train_time:146570ms step_avg:88.19ms +step:1663/1680 train_time:146659ms step_avg:88.19ms +step:1664/1680 train_time:146747ms step_avg:88.19ms +step:1665/1680 train_time:146836ms step_avg:88.19ms +step:1666/1680 train_time:146925ms step_avg:88.19ms +step:1667/1680 train_time:147014ms step_avg:88.19ms +step:1668/1680 train_time:147103ms step_avg:88.19ms +step:1669/1680 train_time:147192ms step_avg:88.19ms +step:1670/1680 train_time:147280ms step_avg:88.19ms +step:1671/1680 train_time:147369ms step_avg:88.19ms +step:1672/1680 train_time:147459ms step_avg:88.19ms +step:1673/1680 train_time:147549ms step_avg:88.19ms +step:1674/1680 train_time:147638ms step_avg:88.19ms +step:1675/1680 train_time:147728ms step_avg:88.20ms +step:1676/1680 train_time:147816ms step_avg:88.20ms +step:1677/1680 train_time:147905ms step_avg:88.20ms +step:1678/1680 train_time:147994ms step_avg:88.20ms +step:1679/1680 train_time:148083ms 
step_avg:88.20ms +step:1680/1680 train_time:148172ms step_avg:88.20ms +step:1680/1680 val_loss:3.2828 train_time:148262ms step_avg:88.25ms +peak memory allocated: 30760 MiB reserved: 45794 MiB diff --git a/records/092725_BF16CE/5c44ff06-998a-4310-af6e-f0f5441452f4.txt b/records/092725_BF16CE/5c44ff06-998a-4310-af6e-f0f5441452f4.txt new file mode 100644 index 000000000..334291fd8 --- /dev/null +++ b/records/092725_BF16CE/5c44ff06-998a-4310-af6e-f0f5441452f4.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+ Empirically, however, small 1D params also perform well here: + NS approximately performs a magnitude normalization of the grad, and + this hyper-optimized class executes faster for small params than the current impl of Adam. + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized to enable + (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param + 7 padding params) + 2. reduce scatter attn_gate (10 params + 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params + 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive params of type attn reshape before NS + 8. wait on step 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size()==8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for the resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1,10,16,16] + assert len(params)==sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx+size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
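+ # Per-group flow (summarizing the schedule in the class docstring; the 8-GPU + # layout is an assumption of custom_sizing): (1) stack the group's grads and + # reduce_scatter an averaged chunk to each rank; (2) on the local chunk, lerp the + # momentum buffer, run one batched Newton-Schulz call, and stage the weight-decayed + # update; (3) all_gather the updated chunks and copy them back into the params. 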
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
+                updated_param_chunk[i].copy_(p)
+
+                # Apply weight decay directly to the buffer.
+                updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+            original_shape = batched_update_grads.shape
+            # Reshape attn params from [hdim, dim*4] to [4, hdim, dim] to apply NS independently to Q,K,V,O
+            module_idx = start_idx if start_idx < len(params) else 0
+            if getattr(params[module_idx], "module", None) == 'attn':
+                batched_update_grads = batched_update_grads.view(-1, original_shape[-2], original_shape[-1] // 4)
+            v_chunk = newton_schulz_triton(batched_update_grads).view(original_shape)
+
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+                if param_idx >= len(params): # Skip padded entries again.
+                    continue
+
+                # Add the computed zeropower update to the parameter in the buffer.
+                updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val)
+
+            stacked_params = torch.empty(
+                (info["padded_num_params"], *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+            gather_future = dist.all_gather_into_tensor(
+                stacked_params, updated_param_chunk, async_op=True
+            ).get_future()
+
+            all_gather_infos.append(
+                {
+                    "gather_future": gather_future,
+                    "stacked_params": stacked_params,
+                    "orig_params": params,
+                }
+            )
+
+        # Final pass: wait for all_gather to complete and copy results back into original parameter tensors.
+        for info in all_gather_infos:
+            info["gather_future"].wait()
+            stacked_params = info["stacked_params"]
+            orig_params = info["orig_params"]
+
+            unstacked_params = torch.unbind(stacked_params)
+            for i, p in enumerate(orig_params):
+                p.copy_(unstacked_params[i], non_blocking=True)
+
+
+class DistAdam(torch.optim.Optimizer):
+    def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        params = list(params)
+        sizes = {p.shape for p in params}
+        # create one buffer per unique parameter-size
+        param_groups = []
+        for size in sizes:
+            group_params = [p for p in params if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        super().__init__(param_groups, defaults)
+        # DistributedAdam implementation by @vagrawal
+
+    @torch.compile
+    @torch.no_grad()
+    def step(self):
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        reduce_scatter_futures: list[torch.Future] = []
+        all_gather_futures: list[torch.Future] = []
+        grad_slices = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            for base_i in range(len(params)):
+                grad = params[base_i].grad
+                rank_size = grad.shape[0] // world_size
+                grad_slice = torch.empty_like(grad[:rank_size])
+                reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
+                grad_slices.append(grad_slice)
+
+        idx = 0
+        for group in self.param_groups:
+            beta1, beta2 = group['betas']
+            eps = group['eps']
+            wd = group['weight_decay']
+            params = group['params']
+            for base in range(len(params)):
+                reduce_scatter_futures[idx].wait()
+                p = params[base]
+                rank_size = p.shape[0] // world_size
+                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
+                lr = group['lr'] * getattr(p, "lr_mul", 1.0)
+                state = self.state[p]
+                g_slice = grad_slices[idx]
+                # State init
+                if not state:
+                    state["step"] = torch.tensor(0, dtype=torch.int64, device=p.device)
+                    state["exp_avg"] = torch.zeros(p_slice.shape, dtype=torch.bfloat16, device=p_slice.device)
+                    state["exp_avg_sq"] = torch.zeros(p_slice.shape, dtype=torch.bfloat16, device=p_slice.device)
+                exp_avg = state["exp_avg"]
+                exp_avg_sq = state["exp_avg_sq"]
+                state["step"] += 1
+                t = state["step"]
+                # weight decay
+                if wd != 0:
+                    eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
+                    p_slice.mul_(1 - eff_weight_decay)
+                # update running averages
+                exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
+                # bias corrections
+                bias1 = 1 - beta1 ** t
+                bias2 = 1 - beta2 ** t
+                # compute step
+                denom = exp_avg_sq.sqrt().add_(eps)
+                step_size = lr * (torch.sqrt(bias2) / bias1)
+                update = exp_avg.div(denom).mul_(step_size)
+                p_slice.add_(other=update, alpha=-1.0)
+                idx += 1
+                all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
+        torch.futures.collect_all(all_gather_futures).wait()
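+# Worked example of the bias-corrected step size above, using the betas this run
+# actually passes to DistAdam (betas=(0.8, 0.95)): at t=1, bias1 = 1 - 0.8 = 0.2 and
+# bias2 = 1 - 0.95 = 0.05, so step_size = lr * sqrt(0.05) / 0.2 ≈ 1.118 * lr, which
+# compensates for the zero-initialized exp_avg / exp_avg_sq buffers.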
+# -----------------------------------------------------------------------------
+# PyTorch nn.Module definitions for the model
+
+def norm(x: Tensor):
+    return F.rms_norm(x, (x.size(-1),))
+
+class CastedLinear(nn.Linear):
+    def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0):
+        super().__init__(in_features, out_features, bias=False)
+        self.use_fp8 = use_fp8
+        self.x_s = x_s
+        self.w_s = w_s
+        self.grad_s = grad_s
+
+    def reset_parameters(self) -> None:
+        std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3)
+        bound = (3 ** 0.5) * std
+        with torch.no_grad():
+            self.weight.uniform_(-bound, bound)
+
+    def forward(self, x: Tensor):
+        if self.use_fp8 and self.training:
+            _x = x.flatten(0, -2)
+            out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0]
+            return out.reshape(*x.shape[:-1], -1)
+        else:
+            return F.linear(x, self.weight.type_as(x))
+
+# yarn implementation @classiclarryd
+class Yarn(nn.Module):
+    def __init__(self, head_dim, max_seq_len):
+        super().__init__()
+        self.head_dim = head_dim
+        self.max_seq_len = max_seq_len
+        self.reset()
+
+    def reset(self):
+        angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim // 4, dtype=torch.float32, device=device)
+        # half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
+        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim // 4)])
+        t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device)
+        theta = torch.outer(t, angular_freq)
+        self.cos = nn.Buffer(theta.cos().to(torch.bfloat16), persistent=False)
+        self.sin = nn.Buffer(theta.sin().to(torch.bfloat16), persistent=False)
+        self.angular_freq = angular_freq
+        # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
+        self.attn_scale = 0.1
+
+    def apply(self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32):
+        rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi)
+        scaling_factor = old_window / new_window
+        interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1)
+        self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor)
+        t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device)
+        theta = torch.outer(t, self.angular_freq)
+        self.cos.copy_(theta.cos())
+        self.sin.copy_(theta.sin())
+        self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1
+
+def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor):
+    assert cos.size(0) >= x_BTHD.size(-3)
+    cos, sin = (
+        cos[None, : x_BTHD.size(-3), None, :],
+        sin[None, : x_BTHD.size(-3), None, :],
+    )
+    x1, x2 = x_BTHD.chunk(2, dim=-1)
+    y1 = x1 * cos + x2 * sin
+    y2 = x1 * (-sin) + x2 * cos
+    return torch.cat((y1, y2), 3)
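+# Note on the layout above: after Yarn.reset(), the second half of angular_freq is all
+# zeros, so for those columns cos == 1 and sin == 0 and the corresponding slices of
+# (x1, x2) pass through rotary() unrotated; only the first head_dim // 4 frequency
+# pairs are actually rotated (the "half-truncate RoPE" trick credited above).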
+@dataclass
+class AttnArgs:
+    ve: torch.Tensor
+    sa_lambdas: torch.Tensor
+    seqlens: torch.Tensor
+    bm_size: int
+    cos: torch.Tensor
+    sin: torch.Tensor
+    attn_scale: float
+
+class CausalSelfAttention(nn.Module):
+    def __init__(self, dim: int, head_dim: int, num_heads: int):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.dim = dim
+        self.hdim = num_heads * head_dim
+
+        assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim"
+        std = 0.5 * (self.dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
+        # https://x.com/hi_tysam/status/1879699187107033311
+        # make matrices the same shape as MLP to enable batched call in optimizer
+        self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim * 4))
+        # label module to enable custom optimizer sizing
+        self.qkvo_w.module = 'attn'
+        with torch.no_grad():
+            self.qkvo_w.view(4, self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights
+            self.qkvo_w.view(4, self.hdim, self.dim)[3].zero_() # init output weights to zero
+
+        # sparse gated attention to enable context based no-op by @classiclarryd
+        self.attn_gate = CastedLinear(12, num_heads)
+        # label module to enable custom optimizer sizing
+        self.attn_gate.weight.module = 'attn_gate'
+        self.attn_gate.weight.detach().zero_()
+
+    def forward(self, x: Tensor, attn_args: AttnArgs):
+        B, T = x.size(0), x.size(1) # batch size, sequence length
+        assert B == 1, "varlen sequences require B == 1"
+        assert T % 16 == 0
+        # unpack attention args
+        cos, sin = attn_args.cos, attn_args.sin
+        ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas
+        seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size
+
+        q, k, v = F.linear(x, self.qkvo_w.view(4, self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
+        q, k = norm(q), norm(k) # QK norm @Grad62304977
+        q, k = rotary(q, cos, sin), rotary(k, cos, sin)
+        if ve is not None:
+            v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
+        else: # skip mid-layers token value embeddings by @YouJiacheng
+            v = sa_lambdas[0] * v
+
+        max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size))
+
+        # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng
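+        # window_size=(bm_size, 0) below gives a causal sliding window: each query can
+        # attend at most bm_size tokens back, where bm_size = window_size_in_blocks * args.block_size
+        # is chosen per layer in GPT.forward.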
+        y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len,
+                                   causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0))
+        y = y.view(B, T, self.num_heads, self.head_dim)
+        y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1)
+        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
+        y = F.linear(y, self.qkvo_w.view(4, self.hdim, self.dim)[3].type_as(y))
+        return y
+
+class MLP(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        hdim = 4 * dim
+        # make matrices the same shape to enable batched call in optimizer
+        self.c_fc = nn.Parameter(torch.empty(dim, hdim))
+        self.c_proj = nn.Parameter(torch.empty(dim, hdim))
+        # label modules to enable custom optimizer sizing
+        self.c_fc.module = 'mlp'
+        self.c_proj.module = 'mlp'
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        with torch.no_grad():
+            self.c_fc.uniform_(-bound, bound)
+            self.c_proj.zero_() # zero init suggested by @Grad62304977
+
+    def forward(self, x: Tensor):
+        x = F.linear(x, self.c_fc.T.type_as(x))
+        x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
+        x = F.linear(x, self.c_proj.type_as(x))
+        return x
+
+class Block(nn.Module):
+    def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int):
+        super().__init__()
+        # skip attention of blocks.0 and blocks.7 (the 8th layer) by @YouJiacheng
+        self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None
+        # skip the MLP block for the first layer by @EmelyanenkoK
+        self.mlp = MLP(dim) if layer_idx != 0 else None
+
+    def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs):
+        x = lambdas[0] * x + lambdas[1] * x0
+        if self.attn is not None:
+            x = x + self.attn(norm(x), attn_args)
+        if self.mlp is not None:
+            x = x + self.mlp(norm(x))
+        return x
+
+# -----------------------------------------------------------------------------
+# The main model
+
+def next_multiple_of_n(v: float | int, *, n: int):
+    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
+
+class GPT(nn.Module):
+    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int):
+        super().__init__()
+        vocab_size = next_multiple_of_n(vocab_size, n=128)
+        self.embed = nn.Embedding(vocab_size, model_dim)
+        self.smear_gate = CastedLinear(12, 1)
+        self.smear_gate.weight.detach().zero_()
+        # label modules to enable custom optimizer sizing
+        self.smear_gate.weight.module = 'smear_gate'
+        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
+        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
+        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
+        self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)])
+        self.yarn = Yarn(head_dim, max_seq_len)
+        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
+        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
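+        # e.g. next_multiple_of_n(50257, n=128) == 50304 == 128 * 393, so the fp8 lm_head
+        # below gets a nicely divisible output dimension; the 448 in its scales is the
+        # largest finite value representable in float8_e4m3fn.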
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the 
data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan the first 4 million tokens, then kick off an async thread to scan the rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading the next shard and indexing BOS tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
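+# Together, BOSFinder(quickload=True) and DataPreloader hide data-indexing latency:
+# the first shard scans only its first 4M tokens for BOS positions before training
+# starts, while the full scan and the next shard's load both run on background threads.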
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning of Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document was too long; account for the _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40 # number of iterations to continue training at final cooldown and window size
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss?
0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999,step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations+args.iteration_extension: + return args.ws_validate//2, args.ws_validate + x = min(step / (1 + args.num_iterations),0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx]//2, args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws_long = args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params + if new_ws_long > ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + elif new_ws_long 0 and step % args.val_loss_every == 0): + if last_step: + ws_long = args.ws_long_validate + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = torch.zeros((), device=device, dtype=torch.float32) + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + 
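+        # grad_accum_steps == 8 // world_size (set above), so the global token batch stays
+        # fixed regardless of GPU count: each micro-batch below calls .backward() without
+        # zeroing, and gradients accumulate until the single optimizer step that follows.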
inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Sep 27 13:32:31 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 27C P0 121W / 700W | 5856MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 25C P0 119W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 22C P0 116W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 27C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 27C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 25C P0 115W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 28C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 24C P0 119W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 176192 C /usr/bin/python 0MiB | +| 0 N/A N/A 176193 C /usr/bin/python 0MiB | +| 0 N/A N/A 176194 C /usr/bin/python 0MiB | +| 0 N/A N/A 176195 C /usr/bin/python 0MiB | +| 0 N/A N/A 176196 C /usr/bin/python 0MiB | +| 0 N/A N/A 176197 C /usr/bin/python 0MiB | +| 0 N/A N/A 176198 C /usr/bin/python 0MiB | +| 0 N/A N/A 176199 C /usr/bin/python 0MiB | +| 1 N/A N/A 176193 C /usr/bin/python 0MiB | +| 2 N/A N/A 176194 C /usr/bin/python 0MiB | +| 3 N/A N/A 176195 C /usr/bin/python 0MiB | +| 4 N/A N/A 176196 C /usr/bin/python 0MiB | +| 5 N/A N/A 176197 C /usr/bin/python 0MiB | +| 6 N/A N/A 176198 C /usr/bin/python 0MiB | +| 7 N/A N/A 176199 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1680 train_time:146ms step_avg:145.82ms +step:2/1680 train_time:165ms step_avg:82.52ms +step:3/1680 train_time:230ms step_avg:76.64ms +step:4/1680 train_time:315ms step_avg:78.70ms +step:5/1680 train_time:401ms step_avg:80.25ms +step:6/1680 train_time:487ms step_avg:81.18ms +step:7/1680 train_time:574ms step_avg:81.94ms +step:8/1680 train_time:660ms step_avg:82.46ms +step:9/1680 train_time:746ms step_avg:82.86ms +step:10/1680 train_time:832ms step_avg:83.17ms +step:11/1680 train_time:918ms step_avg:83.42ms +step:12/1680 train_time:1005ms step_avg:83.76ms +step:13/1680 train_time:1094ms step_avg:84.18ms +step:14/1680 train_time:1185ms step_avg:84.61ms +step:15/1680 train_time:1272ms step_avg:84.83ms +step:16/1680 train_time:1360ms step_avg:85.02ms +step:17/1680 train_time:1447ms step_avg:85.13ms +step:18/1680 train_time:1535ms step_avg:85.26ms +step:19/1680 train_time:1621ms step_avg:85.34ms +step:20/1680 train_time:1708ms step_avg:85.39ms +step:21/1680 train_time:1794ms step_avg:85.42ms +step:22/1680 train_time:1880ms step_avg:85.47ms +step:23/1680 train_time:1967ms step_avg:85.52ms +step:24/1680 train_time:2055ms step_avg:85.62ms +step:25/1680 train_time:2144ms step_avg:85.76ms +step:26/1680 train_time:2232ms step_avg:85.86ms +step:27/1680 train_time:2320ms step_avg:85.92ms +step:28/1680 train_time:2408ms step_avg:86.00ms +step:29/1680 train_time:2495ms step_avg:86.04ms +step:30/1680 train_time:2583ms step_avg:86.09ms +step:31/1680 train_time:2670ms step_avg:86.11ms +step:32/1680 train_time:2756ms step_avg:86.11ms +step:33/1680 train_time:2842ms step_avg:86.12ms +step:34/1680 train_time:2928ms step_avg:86.13ms +step:35/1680 train_time:3015ms step_avg:86.15ms +step:36/1680 train_time:3103ms step_avg:86.20ms +step:37/1680 train_time:3192ms step_avg:86.26ms +step:38/1680 train_time:3280ms step_avg:86.32ms +step:39/1680 train_time:3368ms step_avg:86.36ms +step:40/1680 train_time:3455ms step_avg:86.38ms +step:41/1680 train_time:3543ms step_avg:86.41ms +step:42/1680 train_time:3629ms step_avg:86.41ms +step:43/1680 train_time:3715ms step_avg:86.41ms +step:44/1680 train_time:3802ms step_avg:86.42ms +step:45/1680 train_time:3889ms step_avg:86.43ms +step:46/1680 train_time:3975ms step_avg:86.42ms +step:47/1680 
train_time:4063ms step_avg:86.44ms +step:48/1680 train_time:4150ms step_avg:86.47ms +step:49/1680 train_time:4238ms step_avg:86.48ms +step:50/1680 train_time:4325ms step_avg:86.50ms +step:51/1680 train_time:4413ms step_avg:86.52ms +step:52/1680 train_time:4500ms step_avg:86.54ms +step:53/1680 train_time:4588ms step_avg:86.57ms +step:54/1680 train_time:4675ms step_avg:86.58ms +step:55/1680 train_time:4762ms step_avg:86.59ms +step:56/1680 train_time:4849ms step_avg:86.58ms +step:57/1680 train_time:4936ms step_avg:86.59ms +step:58/1680 train_time:5023ms step_avg:86.61ms +step:59/1680 train_time:5110ms step_avg:86.61ms +step:60/1680 train_time:5198ms step_avg:86.63ms +step:61/1680 train_time:5285ms step_avg:86.64ms +step:62/1680 train_time:5372ms step_avg:86.65ms +step:63/1680 train_time:5460ms step_avg:86.67ms +step:64/1680 train_time:5548ms step_avg:86.68ms +step:65/1680 train_time:5634ms step_avg:86.68ms +step:66/1680 train_time:5721ms step_avg:86.69ms +step:67/1680 train_time:5808ms step_avg:86.69ms +step:68/1680 train_time:5895ms step_avg:86.69ms +step:69/1680 train_time:5982ms step_avg:86.69ms +step:70/1680 train_time:6069ms step_avg:86.70ms +step:71/1680 train_time:6157ms step_avg:86.71ms +step:72/1680 train_time:6244ms step_avg:86.73ms +step:73/1680 train_time:6331ms step_avg:86.73ms +step:74/1680 train_time:6419ms step_avg:86.74ms +step:75/1680 train_time:6505ms step_avg:86.74ms +step:76/1680 train_time:6592ms step_avg:86.74ms +step:77/1680 train_time:6680ms step_avg:86.76ms +step:78/1680 train_time:6767ms step_avg:86.76ms +step:79/1680 train_time:6854ms step_avg:86.76ms +step:80/1680 train_time:6941ms step_avg:86.77ms +step:81/1680 train_time:7028ms step_avg:86.76ms +step:82/1680 train_time:7115ms step_avg:86.77ms +step:83/1680 train_time:7203ms step_avg:86.78ms +step:84/1680 train_time:7290ms step_avg:86.78ms +step:85/1680 train_time:7377ms step_avg:86.79ms +step:86/1680 train_time:7465ms step_avg:86.80ms +step:87/1680 train_time:7552ms step_avg:86.80ms +step:88/1680 train_time:7640ms step_avg:86.82ms +step:89/1680 train_time:7727ms step_avg:86.82ms +step:90/1680 train_time:7814ms step_avg:86.82ms +step:91/1680 train_time:7901ms step_avg:86.83ms +step:92/1680 train_time:7988ms step_avg:86.83ms +step:93/1680 train_time:8075ms step_avg:86.83ms +step:94/1680 train_time:8162ms step_avg:86.83ms +step:95/1680 train_time:8250ms step_avg:86.84ms +step:96/1680 train_time:8337ms step_avg:86.85ms +step:97/1680 train_time:8425ms step_avg:86.85ms +step:98/1680 train_time:8512ms step_avg:86.85ms +step:99/1680 train_time:8600ms step_avg:86.86ms +step:100/1680 train_time:8686ms step_avg:86.86ms +step:101/1680 train_time:8774ms step_avg:86.87ms +step:102/1680 train_time:8861ms step_avg:86.87ms +step:103/1680 train_time:8947ms step_avg:86.87ms +step:104/1680 train_time:9034ms step_avg:86.87ms +step:105/1680 train_time:9122ms step_avg:86.87ms +step:106/1680 train_time:9208ms step_avg:86.87ms +step:107/1680 train_time:9295ms step_avg:86.87ms +step:108/1680 train_time:9383ms step_avg:86.88ms +step:109/1680 train_time:9470ms step_avg:86.88ms +step:110/1680 train_time:9556ms step_avg:86.88ms +step:111/1680 train_time:9644ms step_avg:86.88ms +step:112/1680 train_time:9730ms step_avg:86.88ms +step:113/1680 train_time:9818ms step_avg:86.89ms +step:114/1680 train_time:9905ms step_avg:86.89ms +step:115/1680 train_time:9992ms step_avg:86.89ms +step:116/1680 train_time:10079ms step_avg:86.89ms +step:117/1680 train_time:10166ms step_avg:86.89ms +step:118/1680 train_time:10253ms step_avg:86.89ms +step:119/1680 
train_time:10340ms step_avg:86.89ms +step:120/1680 train_time:10426ms step_avg:86.88ms +step:121/1680 train_time:10513ms step_avg:86.89ms +step:122/1680 train_time:10600ms step_avg:86.89ms +step:123/1680 train_time:10687ms step_avg:86.89ms +step:124/1680 train_time:10774ms step_avg:86.89ms +step:125/1680 train_time:10862ms step_avg:86.89ms +step:125/1680 val_loss:4.3222 train_time:10950ms step_avg:87.60ms +step:126/1680 train_time:10969ms step_avg:87.06ms +step:127/1680 train_time:11038ms step_avg:86.91ms +step:128/1680 train_time:11133ms step_avg:86.97ms +step:129/1680 train_time:11225ms step_avg:87.02ms +step:130/1680 train_time:11312ms step_avg:87.01ms +step:131/1680 train_time:11397ms step_avg:87.00ms +step:132/1680 train_time:11483ms step_avg:86.99ms +step:133/1680 train_time:11569ms step_avg:86.98ms +step:134/1680 train_time:11655ms step_avg:86.98ms +step:135/1680 train_time:11740ms step_avg:86.96ms +step:136/1680 train_time:11826ms step_avg:86.96ms +step:137/1680 train_time:11913ms step_avg:86.95ms +step:138/1680 train_time:12000ms step_avg:86.95ms +step:139/1680 train_time:12089ms step_avg:86.97ms +step:140/1680 train_time:12178ms step_avg:86.98ms +step:141/1680 train_time:12267ms step_avg:87.00ms +step:142/1680 train_time:12354ms step_avg:87.00ms +step:143/1680 train_time:12441ms step_avg:87.00ms +step:144/1680 train_time:12528ms step_avg:87.00ms +step:145/1680 train_time:12614ms step_avg:86.99ms +step:146/1680 train_time:12700ms step_avg:86.99ms +step:147/1680 train_time:12787ms step_avg:86.98ms +step:148/1680 train_time:12873ms step_avg:86.98ms +step:149/1680 train_time:12959ms step_avg:86.98ms +step:150/1680 train_time:13047ms step_avg:86.98ms +step:151/1680 train_time:13135ms step_avg:86.98ms +step:152/1680 train_time:13222ms step_avg:86.99ms +step:153/1680 train_time:13310ms step_avg:86.99ms +step:154/1680 train_time:13397ms step_avg:86.99ms +step:155/1680 train_time:13484ms step_avg:86.99ms +step:156/1680 train_time:13571ms step_avg:87.00ms +step:157/1680 train_time:13658ms step_avg:86.99ms +step:158/1680 train_time:13745ms step_avg:86.99ms +step:159/1680 train_time:13832ms step_avg:86.99ms +step:160/1680 train_time:13919ms step_avg:86.99ms +step:161/1680 train_time:14006ms step_avg:86.99ms +step:162/1680 train_time:14094ms step_avg:87.00ms +step:163/1680 train_time:14180ms step_avg:87.00ms +step:164/1680 train_time:14268ms step_avg:87.00ms +step:165/1680 train_time:14355ms step_avg:87.00ms +step:166/1680 train_time:14442ms step_avg:87.00ms +step:167/1680 train_time:14529ms step_avg:87.00ms +step:168/1680 train_time:14616ms step_avg:87.00ms +step:169/1680 train_time:14703ms step_avg:87.00ms +step:170/1680 train_time:14791ms step_avg:87.00ms +step:171/1680 train_time:14877ms step_avg:87.00ms +step:172/1680 train_time:14964ms step_avg:87.00ms +step:173/1680 train_time:15052ms step_avg:87.01ms +step:174/1680 train_time:15139ms step_avg:87.00ms +step:175/1680 train_time:15226ms step_avg:87.01ms +step:176/1680 train_time:15313ms step_avg:87.01ms +step:177/1680 train_time:15400ms step_avg:87.01ms +step:178/1680 train_time:15487ms step_avg:87.01ms +step:179/1680 train_time:15574ms step_avg:87.01ms +step:180/1680 train_time:15661ms step_avg:87.01ms +step:181/1680 train_time:15747ms step_avg:87.00ms +step:182/1680 train_time:15834ms step_avg:87.00ms +step:183/1680 train_time:15920ms step_avg:87.00ms +step:184/1680 train_time:16007ms step_avg:87.00ms +step:185/1680 train_time:16094ms step_avg:87.00ms +step:186/1680 train_time:16181ms step_avg:86.99ms +step:187/1680 train_time:16268ms 
step_avg:87.00ms +step:188/1680 train_time:16355ms step_avg:87.00ms +step:189/1680 train_time:16442ms step_avg:87.00ms +step:190/1680 train_time:16529ms step_avg:87.00ms +step:191/1680 train_time:16616ms step_avg:86.99ms +step:192/1680 train_time:16703ms step_avg:87.00ms +step:193/1680 train_time:16791ms step_avg:87.00ms +step:194/1680 train_time:16877ms step_avg:86.99ms +step:195/1680 train_time:16964ms step_avg:87.00ms +step:196/1680 train_time:17051ms step_avg:87.00ms +step:197/1680 train_time:17138ms step_avg:87.00ms +step:198/1680 train_time:17226ms step_avg:87.00ms +step:199/1680 train_time:17313ms step_avg:87.00ms +step:200/1680 train_time:17400ms step_avg:87.00ms +step:201/1680 train_time:17487ms step_avg:87.00ms +step:202/1680 train_time:17574ms step_avg:87.00ms +step:203/1680 train_time:17661ms step_avg:87.00ms +step:204/1680 train_time:17748ms step_avg:87.00ms +step:205/1680 train_time:17835ms step_avg:87.00ms +step:206/1680 train_time:17922ms step_avg:87.00ms +step:207/1680 train_time:18009ms step_avg:87.00ms +step:208/1680 train_time:18096ms step_avg:87.00ms +step:209/1680 train_time:18183ms step_avg:87.00ms +step:210/1680 train_time:18271ms step_avg:87.01ms +step:211/1680 train_time:18358ms step_avg:87.00ms +step:212/1680 train_time:18445ms step_avg:87.00ms +step:213/1680 train_time:18532ms step_avg:87.01ms +step:214/1680 train_time:18619ms step_avg:87.00ms +step:215/1680 train_time:18706ms step_avg:87.01ms +step:216/1680 train_time:18794ms step_avg:87.01ms +step:217/1680 train_time:18880ms step_avg:87.00ms +step:218/1680 train_time:18967ms step_avg:87.00ms +step:219/1680 train_time:19054ms step_avg:87.01ms +step:220/1680 train_time:19141ms step_avg:87.00ms +step:221/1680 train_time:19228ms step_avg:87.00ms +step:222/1680 train_time:19315ms step_avg:87.01ms +step:223/1680 train_time:19402ms step_avg:87.01ms +step:224/1680 train_time:19490ms step_avg:87.01ms +step:225/1680 train_time:19576ms step_avg:87.01ms +step:226/1680 train_time:19663ms step_avg:87.01ms +step:227/1680 train_time:19751ms step_avg:87.01ms +step:228/1680 train_time:19837ms step_avg:87.01ms +step:229/1680 train_time:19925ms step_avg:87.01ms +step:230/1680 train_time:20012ms step_avg:87.01ms +step:231/1680 train_time:20099ms step_avg:87.01ms +step:232/1680 train_time:20185ms step_avg:87.01ms +step:233/1680 train_time:20272ms step_avg:87.00ms +step:234/1680 train_time:20358ms step_avg:87.00ms +step:235/1680 train_time:20446ms step_avg:87.00ms +step:236/1680 train_time:20534ms step_avg:87.01ms +step:237/1680 train_time:20620ms step_avg:87.01ms +step:238/1680 train_time:20708ms step_avg:87.01ms +step:239/1680 train_time:20794ms step_avg:87.01ms +step:240/1680 train_time:20881ms step_avg:87.00ms +step:241/1680 train_time:20968ms step_avg:87.00ms +step:242/1680 train_time:21054ms step_avg:87.00ms +step:243/1680 train_time:21141ms step_avg:87.00ms +step:244/1680 train_time:21227ms step_avg:87.00ms +step:245/1680 train_time:21314ms step_avg:87.00ms +step:246/1680 train_time:21401ms step_avg:87.00ms +step:247/1680 train_time:21489ms step_avg:87.00ms +step:248/1680 train_time:21576ms step_avg:87.00ms +step:249/1680 train_time:21663ms step_avg:87.00ms +step:250/1680 train_time:21751ms step_avg:87.00ms +step:250/1680 val_loss:3.9704 train_time:21839ms step_avg:87.35ms +step:251/1680 train_time:21857ms step_avg:87.08ms +step:252/1680 train_time:21929ms step_avg:87.02ms +step:253/1680 train_time:22019ms step_avg:87.03ms +step:254/1680 train_time:22106ms step_avg:87.03ms +step:255/1680 train_time:22194ms step_avg:87.03ms 
+step:256/1680 train_time:22279ms step_avg:87.03ms
[steps 257-374 elided: train_time 22365ms → 32541ms, step_avg 87.01-87.04ms]
+step:375/1680 train_time:32628ms step_avg:87.01ms
+step:375/1680 val_loss:3.8212 train_time:32716ms step_avg:87.24ms
[steps 376-499 elided: train_time 32736ms → 43429ms, step_avg 87.02-87.06ms]
+step:500/1680 train_time:43516ms step_avg:87.03ms
+step:500/1680 val_loss:3.7171 train_time:43604ms step_avg:87.21ms
[steps 501-624 elided: train_time 43622ms → 54399ms, step_avg 87.04-87.18ms]
+step:625/1680 train_time:54488ms step_avg:87.18ms
+step:625/1680 val_loss:3.6172 train_time:54578ms step_avg:87.33ms
[steps 626-749 elided: train_time 54598ms → 65427ms, step_avg 87.19-87.35ms]
+step:750/1680 train_time:65515ms step_avg:87.35ms
+step:750/1680 val_loss:3.5643 train_time:65605ms step_avg:87.47ms
[steps 751-874 elided: train_time 65623ms → 76458ms, step_avg 87.36-87.48ms]
+step:875/1680 train_time:76546ms step_avg:87.48ms
+step:875/1680 val_loss:3.5179 train_time:76635ms step_avg:87.58ms
[steps 876-999 elided: train_time 76654ms → 87483ms, step_avg 87.49-87.57ms]
+step:1000/1680 train_time:87571ms step_avg:87.57ms
+step:1000/1680 val_loss:3.4696 train_time:87661ms step_avg:87.66ms
[steps 1001-1124 elided: train_time 87679ms → 98537ms, step_avg 87.58-87.67ms]
+step:1125/1680 train_time:98626ms step_avg:87.67ms
+step:1125/1680 val_loss:3.4152 train_time:98716ms step_avg:87.75ms
[steps 1126-1249 elided: train_time 98735ms → 109665ms, step_avg 87.67-87.80ms]
+step:1250/1680 train_time:109755ms step_avg:87.80ms
+step:1250/1680 val_loss:3.3772 train_time:109845ms step_avg:87.88ms
[steps 1251-1374 elided: train_time 109864ms → 120821ms, step_avg 87.81-87.93ms]
+step:1375/1680 train_time:120910ms step_avg:87.93ms
+step:1375/1680 val_loss:3.3431 train_time:121000ms step_avg:88.00ms
[steps 1376-1499 elided: train_time 121019ms → 131964ms, step_avg 87.94-88.03ms]
+step:1500/1680 train_time:132052ms step_avg:88.03ms
+step:1500/1680 val_loss:3.3132 train_time:132143ms step_avg:88.10ms
[steps 1501-1624 elided: train_time 132163ms → 143104ms, step_avg 88.04-88.12ms]
+step:1625/1680 train_time:143192ms step_avg:88.12ms
+step:1625/1680 val_loss:3.2893 train_time:143282ms step_avg:88.17ms
[steps 1626-1678 elided: train_time 143300ms → 147927ms, step_avg 88.12-88.16ms]
+step:1679/1680 train_time:148016ms
step_avg:88.16ms +step:1680/1680 train_time:148105ms step_avg:88.16ms +step:1680/1680 val_loss:3.2783 train_time:148196ms step_avg:88.21ms +peak memory allocated: 30760 MiB reserved: 45914 MiB diff --git a/records/092725_BF16CE/68d6605a-f386-4e4f-84c0-2582dc6989d8.txt b/records/092725_BF16CE/68d6605a-f386-4e4f-84c0-2582dc6989d8.txt new file mode 100644 index 000000000..079c044bb --- /dev/null +++ b/records/092725_BF16CE/68d6605a-f386-4e4f-84c0-2582dc6989d8.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
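+ As a rough reference, the orthogonalization step is mathematically the same quintic Newton-Schulz + iteration implemented by newton_schulz_triton above; a minimal pure-PyTorch sketch (without the + fused kernels or buffer reuse) looks like: + + def zeropower_sketch(G: Tensor, steps: int = 5) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) # quintic iteration coefficients + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT # operate on the wide orientation + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) # ensure spectral norm is at most 1 + for _ in range(steps): + A = X @ X.mT + B = b * A + c * A @ A + X = a * X + B @ X + if G.size(-2) > G.size(-1): + X = X.mT + return X +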
+ That said, small 1D params empirically perform well here: + NS approximately normalizes the magnitude of the grad, and + this hyper-optimized class executes faster than the current Adam impl for small params. + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized to enable + (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param, 7 padding params) + 2. reduce scatter attn_gate (10 params, 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params, 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive params of type attn reshape before NS + 8. wait on step 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size()==8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for the resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1,10,16,16] + assert len(params)==sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx+size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal.
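+ # Overall flow below: (1) reduce-scatter each group's stacked grads so every rank owns one chunk, + # (2) per group, apply momentum and a batched Newton-Schulz to the local chunk, then launch an async + # all-gather of the updated params, (3) wait on each all-gather and copy the results back in place.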
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
+ updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + original_shape = batched_update_grads.shape + # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O + module_idx = start_idx if start_idx < len(params) else 0 + if getattr(params[module_idx], "module", None) == 'attn': + batched_update_grads = batched_update_grads.view(-1, original_shape[-2], original_shape[-1] // 4) + v_chunk = newton_schulz_triton(batched_update_grads).view(original_shape) + + # Apply the computed updates to this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = 
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module='attn' + with torch.no_grad(): + self.qkvo_w.view(4,self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4,self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4,self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4,self.hdim, self.dim)[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module='mlp' + self.c_proj.module='mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
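+ # e.g. next_multiple_of_n(50257, n=128) == 50304, i.e. 47 padded vocab slots that never occur as target tokens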
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the 
data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter==5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter+=1 + return starts, ends + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * 
grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1640 # number of iterations to run + iteration_extension = 40 # number of iterations to continue training at final cooldown and window size + cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 
0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999,step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations+args.iteration_extension: + return args.ws_validate//2, args.ws_validate + x = min(step / (1 + args.num_iterations),0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx]//2, args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws_long = args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params + if new_ws_long > ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + elif new_ws_long < ws_long: + model.yarn.reset() # schedule wrapped back to the smallest window + ws_long = new_ws_long + model(inputs, targets, cum_seqlens, ws_long//2, ws_long).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) + +# re-initialize the model and optimizers so the warmup steps don't count as training +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +model.yarn.reset() +del initial_state +del train_loader + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +train_steps = args.num_iterations + args.iteration_extension +ws_short, ws_long = get_ws(0) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +for step in range(train_steps + 1): + last_step = (step == train_steps) + new_ws_short, new_ws_long = get_ws(step) + if new_ws_long != ws_long: # window schedule advanced: update YaRN params for the new window + model.yarn.apply(ws_long, new_ws_long) + ws_short, ws_long = new_ws_short, new_ws_long + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + if last_step: + ws_long = args.ws_long_validate + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = torch.zeros((), device=device, dtype=torch.float32) + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + 
inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Sep 27 12:25:28 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 30C P0 122W / 700W | 5856MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 27C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 25C P0 117W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 29C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 30C P0 123W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 28C P0 117W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 30C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 27C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 156805 C /usr/bin/python 0MiB | +| 0 N/A N/A 156806 C /usr/bin/python 0MiB | +| 0 N/A N/A 156807 C /usr/bin/python 0MiB | +| 0 N/A N/A 156808 C /usr/bin/python 0MiB | +| 0 N/A N/A 156809 C /usr/bin/python 0MiB | +| 0 N/A N/A 156810 C /usr/bin/python 0MiB | +| 0 N/A N/A 156811 C /usr/bin/python 0MiB | +| 0 N/A N/A 156812 C /usr/bin/python 0MiB | +| 1 N/A N/A 156806 C /usr/bin/python 0MiB | +| 2 N/A N/A 156807 C /usr/bin/python 0MiB | +| 3 N/A N/A 156808 C /usr/bin/python 0MiB | +| 4 N/A N/A 156809 C /usr/bin/python 0MiB | +| 5 N/A N/A 156810 C /usr/bin/python 0MiB | +| 6 N/A N/A 156811 C /usr/bin/python 0MiB | +| 7 N/A N/A 156812 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1680 train_time:151ms step_avg:151.29ms +step:2/1680 train_time:172ms step_avg:85.76ms +step:3/1680 train_time:235ms step_avg:78.29ms +step:4/1680 train_time:320ms step_avg:80.05ms +step:5/1680 train_time:406ms step_avg:81.23ms +step:6/1680 train_time:493ms step_avg:82.17ms +step:7/1680 train_time:580ms step_avg:82.80ms +step:8/1680 train_time:666ms step_avg:83.28ms +step:9/1680 train_time:752ms step_avg:83.61ms +step:10/1680 train_time:839ms step_avg:83.88ms +step:11/1680 train_time:926ms step_avg:84.19ms +step:12/1680 train_time:1012ms step_avg:84.34ms +step:13/1680 train_time:1101ms step_avg:84.70ms +step:14/1680 train_time:1193ms step_avg:85.22ms +step:15/1680 train_time:1282ms step_avg:85.49ms +step:16/1680 train_time:1370ms step_avg:85.62ms +step:17/1680 train_time:1457ms step_avg:85.71ms +step:18/1680 train_time:1545ms step_avg:85.81ms +step:19/1680 train_time:1632ms step_avg:85.87ms +step:20/1680 train_time:1718ms step_avg:85.90ms +step:21/1680 train_time:1804ms step_avg:85.91ms +step:22/1680 train_time:1891ms step_avg:85.95ms +step:23/1680 train_time:1978ms step_avg:86.00ms +step:24/1680 train_time:2065ms step_avg:86.05ms +step:25/1680 train_time:2154ms step_avg:86.16ms +step:26/1680 train_time:2243ms step_avg:86.26ms +step:27/1680 train_time:2331ms step_avg:86.32ms +step:28/1680 train_time:2418ms step_avg:86.35ms +step:29/1680 train_time:2505ms step_avg:86.38ms +step:30/1680 train_time:2592ms step_avg:86.41ms +step:31/1680 train_time:2680ms step_avg:86.44ms +step:32/1680 train_time:2767ms step_avg:86.45ms +step:33/1680 train_time:2853ms step_avg:86.46ms +step:34/1680 train_time:2940ms step_avg:86.47ms +step:35/1680 train_time:3027ms step_avg:86.50ms +step:36/1680 train_time:3115ms step_avg:86.53ms +step:37/1680 train_time:3203ms step_avg:86.57ms +step:38/1680 train_time:3292ms step_avg:86.62ms +step:39/1680 train_time:3379ms step_avg:86.65ms +step:40/1680 train_time:3467ms step_avg:86.68ms +step:41/1680 train_time:3555ms step_avg:86.70ms +step:42/1680 train_time:3643ms step_avg:86.73ms +step:43/1680 train_time:3729ms step_avg:86.73ms +step:44/1680 train_time:3817ms step_avg:86.74ms +step:45/1680 train_time:3904ms step_avg:86.75ms +step:46/1680 train_time:3991ms step_avg:86.76ms +step:47/1680 
train_time:4078ms step_avg:86.76ms +step:48/1680 train_time:4166ms step_avg:86.79ms +step:49/1680 train_time:4254ms step_avg:86.82ms +step:50/1680 train_time:4342ms step_avg:86.85ms +step:51/1680 train_time:4430ms step_avg:86.87ms +step:52/1680 train_time:4517ms step_avg:86.87ms +step:53/1680 train_time:4604ms step_avg:86.88ms +step:54/1680 train_time:4691ms step_avg:86.88ms +step:55/1680 train_time:4778ms step_avg:86.88ms +step:56/1680 train_time:4866ms step_avg:86.88ms +step:57/1680 train_time:4953ms step_avg:86.90ms +step:58/1680 train_time:5040ms step_avg:86.90ms +step:59/1680 train_time:5127ms step_avg:86.90ms +step:60/1680 train_time:5214ms step_avg:86.90ms +step:61/1680 train_time:5302ms step_avg:86.93ms +step:62/1680 train_time:5391ms step_avg:86.95ms +step:63/1680 train_time:5478ms step_avg:86.95ms +step:64/1680 train_time:5566ms step_avg:86.96ms +step:65/1680 train_time:5653ms step_avg:86.97ms +step:66/1680 train_time:5741ms step_avg:86.99ms +step:67/1680 train_time:5828ms step_avg:86.99ms +step:68/1680 train_time:5915ms step_avg:86.98ms +step:69/1680 train_time:6002ms step_avg:86.99ms +step:70/1680 train_time:6090ms step_avg:87.00ms +step:71/1680 train_time:6177ms step_avg:87.00ms +step:72/1680 train_time:6264ms step_avg:87.00ms +step:73/1680 train_time:6351ms step_avg:87.01ms +step:74/1680 train_time:6439ms step_avg:87.01ms +step:75/1680 train_time:6527ms step_avg:87.02ms +step:76/1680 train_time:6614ms step_avg:87.02ms +step:77/1680 train_time:6701ms step_avg:87.02ms +step:78/1680 train_time:6788ms step_avg:87.03ms +step:79/1680 train_time:6875ms step_avg:87.03ms +step:80/1680 train_time:6962ms step_avg:87.03ms +step:81/1680 train_time:7050ms step_avg:87.04ms +step:82/1680 train_time:7137ms step_avg:87.04ms +step:83/1680 train_time:7225ms step_avg:87.05ms +step:84/1680 train_time:7312ms step_avg:87.05ms +step:85/1680 train_time:7399ms step_avg:87.05ms +step:86/1680 train_time:7487ms step_avg:87.06ms +step:87/1680 train_time:7574ms step_avg:87.06ms +step:88/1680 train_time:7661ms step_avg:87.06ms +step:89/1680 train_time:7749ms step_avg:87.06ms +step:90/1680 train_time:7836ms step_avg:87.06ms +step:91/1680 train_time:7923ms step_avg:87.07ms +step:92/1680 train_time:8011ms step_avg:87.07ms +step:93/1680 train_time:8098ms step_avg:87.07ms +step:94/1680 train_time:8186ms step_avg:87.09ms +step:95/1680 train_time:8273ms step_avg:87.09ms +step:96/1680 train_time:8361ms step_avg:87.09ms +step:97/1680 train_time:8448ms step_avg:87.09ms +step:98/1680 train_time:8535ms step_avg:87.09ms +step:99/1680 train_time:8622ms step_avg:87.09ms +step:100/1680 train_time:8709ms step_avg:87.09ms +step:101/1680 train_time:8797ms step_avg:87.10ms +step:102/1680 train_time:8885ms step_avg:87.11ms +step:103/1680 train_time:8972ms step_avg:87.11ms +step:104/1680 train_time:9059ms step_avg:87.11ms +step:105/1680 train_time:9147ms step_avg:87.11ms +step:106/1680 train_time:9234ms step_avg:87.11ms +step:107/1680 train_time:9322ms step_avg:87.12ms +step:108/1680 train_time:9409ms step_avg:87.12ms +step:109/1680 train_time:9496ms step_avg:87.12ms +step:110/1680 train_time:9584ms step_avg:87.12ms +step:111/1680 train_time:9671ms step_avg:87.13ms +step:112/1680 train_time:9758ms step_avg:87.13ms +step:113/1680 train_time:9846ms step_avg:87.13ms +step:114/1680 train_time:9933ms step_avg:87.13ms +step:115/1680 train_time:10021ms step_avg:87.14ms +step:116/1680 train_time:10109ms step_avg:87.15ms +step:117/1680 train_time:10196ms step_avg:87.14ms +step:118/1680 train_time:10284ms step_avg:87.15ms +step:119/1680 
train_time:10371ms step_avg:87.15ms +step:120/1680 train_time:10460ms step_avg:87.16ms +step:121/1680 train_time:10546ms step_avg:87.16ms +step:122/1680 train_time:10633ms step_avg:87.16ms +step:123/1680 train_time:10721ms step_avg:87.16ms +step:124/1680 train_time:10808ms step_avg:87.16ms +step:125/1680 train_time:10895ms step_avg:87.16ms +step:125/1680 val_loss:4.3079 train_time:10983ms step_avg:87.86ms +step:126/1680 train_time:11013ms step_avg:87.40ms +step:127/1680 train_time:11072ms step_avg:87.18ms +step:128/1680 train_time:11167ms step_avg:87.25ms +step:129/1680 train_time:11256ms step_avg:87.26ms +step:130/1680 train_time:11344ms step_avg:87.26ms +step:131/1680 train_time:11430ms step_avg:87.25ms +step:132/1680 train_time:11516ms step_avg:87.24ms +step:133/1680 train_time:11602ms step_avg:87.24ms +step:134/1680 train_time:11688ms step_avg:87.23ms +step:135/1680 train_time:11774ms step_avg:87.22ms +step:136/1680 train_time:11860ms step_avg:87.21ms +step:137/1680 train_time:11947ms step_avg:87.20ms +step:138/1680 train_time:12034ms step_avg:87.20ms +step:139/1680 train_time:12124ms step_avg:87.22ms +step:140/1680 train_time:12214ms step_avg:87.24ms +step:141/1680 train_time:12303ms step_avg:87.25ms +step:142/1680 train_time:12390ms step_avg:87.26ms +step:143/1680 train_time:12477ms step_avg:87.25ms +step:144/1680 train_time:12564ms step_avg:87.25ms +step:145/1680 train_time:12650ms step_avg:87.24ms +step:146/1680 train_time:12736ms step_avg:87.24ms +step:147/1680 train_time:12822ms step_avg:87.23ms +step:148/1680 train_time:12909ms step_avg:87.22ms +step:149/1680 train_time:12996ms step_avg:87.22ms +step:150/1680 train_time:13085ms step_avg:87.23ms +step:151/1680 train_time:13173ms step_avg:87.24ms +step:152/1680 train_time:13262ms step_avg:87.25ms +step:153/1680 train_time:13350ms step_avg:87.26ms +step:154/1680 train_time:13437ms step_avg:87.25ms +step:155/1680 train_time:13524ms step_avg:87.25ms +step:156/1680 train_time:13610ms step_avg:87.25ms +step:157/1680 train_time:13697ms step_avg:87.24ms +step:158/1680 train_time:13784ms step_avg:87.24ms +step:159/1680 train_time:13871ms step_avg:87.24ms +step:160/1680 train_time:13957ms step_avg:87.23ms +step:161/1680 train_time:14045ms step_avg:87.24ms +step:162/1680 train_time:14133ms step_avg:87.24ms +step:163/1680 train_time:14221ms step_avg:87.24ms +step:164/1680 train_time:14309ms step_avg:87.25ms +step:165/1680 train_time:14396ms step_avg:87.25ms +step:166/1680 train_time:14484ms step_avg:87.25ms +step:167/1680 train_time:14571ms step_avg:87.25ms +step:168/1680 train_time:14658ms step_avg:87.25ms +step:169/1680 train_time:14745ms step_avg:87.25ms +step:170/1680 train_time:14831ms step_avg:87.24ms +step:171/1680 train_time:14918ms step_avg:87.24ms +step:172/1680 train_time:15005ms step_avg:87.24ms +step:173/1680 train_time:15092ms step_avg:87.24ms +step:174/1680 train_time:15180ms step_avg:87.24ms +step:175/1680 train_time:15269ms step_avg:87.25ms +step:176/1680 train_time:15356ms step_avg:87.25ms +step:177/1680 train_time:15444ms step_avg:87.25ms +step:178/1680 train_time:15531ms step_avg:87.25ms +step:179/1680 train_time:15618ms step_avg:87.25ms +step:180/1680 train_time:15705ms step_avg:87.25ms +step:181/1680 train_time:15791ms step_avg:87.24ms +step:182/1680 train_time:15879ms step_avg:87.25ms +step:183/1680 train_time:15965ms step_avg:87.24ms +step:184/1680 train_time:16051ms step_avg:87.23ms +step:185/1680 train_time:16139ms step_avg:87.24ms +step:186/1680 train_time:16226ms step_avg:87.24ms +step:187/1680 train_time:16314ms 
step_avg:87.24ms +step:188/1680 train_time:16401ms step_avg:87.24ms +step:189/1680 train_time:16489ms step_avg:87.25ms +step:190/1680 train_time:16576ms step_avg:87.24ms +step:191/1680 train_time:16663ms step_avg:87.24ms +step:192/1680 train_time:16751ms step_avg:87.24ms +step:193/1680 train_time:16838ms step_avg:87.24ms +step:194/1680 train_time:16925ms step_avg:87.24ms +step:195/1680 train_time:17012ms step_avg:87.24ms +step:196/1680 train_time:17099ms step_avg:87.24ms +step:197/1680 train_time:17187ms step_avg:87.24ms +step:198/1680 train_time:17274ms step_avg:87.24ms +step:199/1680 train_time:17362ms step_avg:87.25ms +step:200/1680 train_time:17449ms step_avg:87.25ms +step:201/1680 train_time:17537ms step_avg:87.25ms +step:202/1680 train_time:17624ms step_avg:87.25ms +step:203/1680 train_time:17710ms step_avg:87.24ms +step:204/1680 train_time:17797ms step_avg:87.24ms +step:205/1680 train_time:17884ms step_avg:87.24ms +step:206/1680 train_time:17972ms step_avg:87.24ms +step:207/1680 train_time:18058ms step_avg:87.24ms +step:208/1680 train_time:18146ms step_avg:87.24ms +step:209/1680 train_time:18233ms step_avg:87.24ms +step:210/1680 train_time:18321ms step_avg:87.24ms +step:211/1680 train_time:18408ms step_avg:87.24ms +step:212/1680 train_time:18496ms step_avg:87.24ms +step:213/1680 train_time:18583ms step_avg:87.24ms +step:214/1680 train_time:18670ms step_avg:87.24ms +step:215/1680 train_time:18757ms step_avg:87.24ms +step:216/1680 train_time:18844ms step_avg:87.24ms +step:217/1680 train_time:18931ms step_avg:87.24ms +step:218/1680 train_time:19017ms step_avg:87.24ms +step:219/1680 train_time:19105ms step_avg:87.24ms +step:220/1680 train_time:19192ms step_avg:87.24ms +step:221/1680 train_time:19279ms step_avg:87.24ms +step:222/1680 train_time:19367ms step_avg:87.24ms +step:223/1680 train_time:19454ms step_avg:87.24ms +step:224/1680 train_time:19541ms step_avg:87.24ms +step:225/1680 train_time:19629ms step_avg:87.24ms +step:226/1680 train_time:19716ms step_avg:87.24ms +step:227/1680 train_time:19803ms step_avg:87.24ms +step:228/1680 train_time:19891ms step_avg:87.24ms +step:229/1680 train_time:19978ms step_avg:87.24ms +step:230/1680 train_time:20065ms step_avg:87.24ms +step:231/1680 train_time:20152ms step_avg:87.24ms +step:232/1680 train_time:20239ms step_avg:87.24ms +step:233/1680 train_time:20326ms step_avg:87.24ms +step:234/1680 train_time:20413ms step_avg:87.23ms +step:235/1680 train_time:20500ms step_avg:87.23ms +step:236/1680 train_time:20588ms step_avg:87.24ms +step:237/1680 train_time:20675ms step_avg:87.24ms +step:238/1680 train_time:20763ms step_avg:87.24ms +step:239/1680 train_time:20850ms step_avg:87.24ms +step:240/1680 train_time:20937ms step_avg:87.24ms +step:241/1680 train_time:21024ms step_avg:87.24ms +step:242/1680 train_time:21111ms step_avg:87.23ms +step:243/1680 train_time:21197ms step_avg:87.23ms +step:244/1680 train_time:21284ms step_avg:87.23ms +step:245/1680 train_time:21371ms step_avg:87.23ms +step:246/1680 train_time:21458ms step_avg:87.23ms +step:247/1680 train_time:21546ms step_avg:87.23ms +step:248/1680 train_time:21633ms step_avg:87.23ms +step:249/1680 train_time:21720ms step_avg:87.23ms +step:250/1680 train_time:21808ms step_avg:87.23ms +step:250/1680 val_loss:3.9612 train_time:21896ms step_avg:87.58ms +step:251/1680 train_time:21921ms step_avg:87.34ms +step:252/1680 train_time:21990ms step_avg:87.26ms +step:253/1680 train_time:22082ms step_avg:87.28ms +step:254/1680 train_time:22169ms step_avg:87.28ms +step:255/1680 train_time:22255ms step_avg:87.28ms 
+step:256/1680 train_time:22341ms step_avg:87.27ms +step:257/1680 train_time:22428ms step_avg:87.27ms +step:258/1680 train_time:22514ms step_avg:87.26ms +step:259/1680 train_time:22601ms step_avg:87.26ms +step:260/1680 train_time:22687ms step_avg:87.26ms +step:261/1680 train_time:22774ms step_avg:87.25ms +step:262/1680 train_time:22861ms step_avg:87.26ms +step:263/1680 train_time:22950ms step_avg:87.26ms +step:264/1680 train_time:23039ms step_avg:87.27ms +step:265/1680 train_time:23127ms step_avg:87.27ms +step:266/1680 train_time:23215ms step_avg:87.28ms +step:267/1680 train_time:23302ms step_avg:87.27ms +step:268/1680 train_time:23388ms step_avg:87.27ms +step:269/1680 train_time:23475ms step_avg:87.27ms +step:270/1680 train_time:23562ms step_avg:87.26ms +step:271/1680 train_time:23648ms step_avg:87.26ms +step:272/1680 train_time:23735ms step_avg:87.26ms +step:273/1680 train_time:23821ms step_avg:87.26ms +step:274/1680 train_time:23909ms step_avg:87.26ms +step:275/1680 train_time:23997ms step_avg:87.26ms +step:276/1680 train_time:24086ms step_avg:87.27ms +step:277/1680 train_time:24173ms step_avg:87.27ms +step:278/1680 train_time:24261ms step_avg:87.27ms +step:279/1680 train_time:24348ms step_avg:87.27ms +step:280/1680 train_time:24435ms step_avg:87.27ms +step:281/1680 train_time:24522ms step_avg:87.27ms +step:282/1680 train_time:24609ms step_avg:87.27ms +step:283/1680 train_time:24696ms step_avg:87.26ms +step:284/1680 train_time:24782ms step_avg:87.26ms +step:285/1680 train_time:24869ms step_avg:87.26ms +step:286/1680 train_time:24957ms step_avg:87.26ms +step:287/1680 train_time:25045ms step_avg:87.26ms +step:288/1680 train_time:25132ms step_avg:87.26ms +step:289/1680 train_time:25220ms step_avg:87.27ms +step:290/1680 train_time:25307ms step_avg:87.27ms +step:291/1680 train_time:25395ms step_avg:87.27ms +step:292/1680 train_time:25481ms step_avg:87.26ms +step:293/1680 train_time:25568ms step_avg:87.26ms +step:294/1680 train_time:25655ms step_avg:87.26ms +step:295/1680 train_time:25742ms step_avg:87.26ms +step:296/1680 train_time:25828ms step_avg:87.26ms +step:297/1680 train_time:25916ms step_avg:87.26ms +step:298/1680 train_time:26003ms step_avg:87.26ms +step:299/1680 train_time:26090ms step_avg:87.26ms +step:300/1680 train_time:26179ms step_avg:87.26ms +step:301/1680 train_time:26266ms step_avg:87.26ms +step:302/1680 train_time:26353ms step_avg:87.26ms +step:303/1680 train_time:26441ms step_avg:87.26ms +step:304/1680 train_time:26528ms step_avg:87.26ms +step:305/1680 train_time:26615ms step_avg:87.26ms +step:306/1680 train_time:26703ms step_avg:87.26ms +step:307/1680 train_time:26789ms step_avg:87.26ms +step:308/1680 train_time:26876ms step_avg:87.26ms +step:309/1680 train_time:26964ms step_avg:87.26ms +step:310/1680 train_time:27052ms step_avg:87.26ms +step:311/1680 train_time:27139ms step_avg:87.26ms +step:312/1680 train_time:27226ms step_avg:87.26ms +step:313/1680 train_time:27314ms step_avg:87.27ms +step:314/1680 train_time:27402ms step_avg:87.27ms +step:315/1680 train_time:27489ms step_avg:87.27ms +step:316/1680 train_time:27577ms step_avg:87.27ms +step:317/1680 train_time:27664ms step_avg:87.27ms +step:318/1680 train_time:27751ms step_avg:87.27ms +step:319/1680 train_time:27837ms step_avg:87.26ms +step:320/1680 train_time:27925ms step_avg:87.27ms +step:321/1680 train_time:28012ms step_avg:87.27ms +step:322/1680 train_time:28100ms step_avg:87.27ms +step:323/1680 train_time:28187ms step_avg:87.27ms +step:324/1680 train_time:28274ms step_avg:87.27ms +step:325/1680 train_time:28361ms 
step_avg:87.27ms +step:326/1680 train_time:28449ms step_avg:87.27ms +step:327/1680 train_time:28537ms step_avg:87.27ms +step:328/1680 train_time:28624ms step_avg:87.27ms +step:329/1680 train_time:28711ms step_avg:87.27ms +step:330/1680 train_time:28798ms step_avg:87.27ms +step:331/1680 train_time:28886ms step_avg:87.27ms +step:332/1680 train_time:28972ms step_avg:87.27ms +step:333/1680 train_time:29059ms step_avg:87.27ms +step:334/1680 train_time:29147ms step_avg:87.27ms +step:335/1680 train_time:29234ms step_avg:87.26ms +step:336/1680 train_time:29321ms step_avg:87.27ms +step:337/1680 train_time:29408ms step_avg:87.27ms +step:338/1680 train_time:29496ms step_avg:87.27ms +step:339/1680 train_time:29583ms step_avg:87.27ms +step:340/1680 train_time:29670ms step_avg:87.26ms +step:341/1680 train_time:29757ms step_avg:87.26ms +step:342/1680 train_time:29844ms step_avg:87.26ms +step:343/1680 train_time:29931ms step_avg:87.26ms +step:344/1680 train_time:30019ms step_avg:87.26ms +step:345/1680 train_time:30106ms step_avg:87.26ms +step:346/1680 train_time:30193ms step_avg:87.26ms +step:347/1680 train_time:30281ms step_avg:87.26ms +step:348/1680 train_time:30367ms step_avg:87.26ms +step:349/1680 train_time:30455ms step_avg:87.26ms +step:350/1680 train_time:30543ms step_avg:87.26ms +step:351/1680 train_time:30630ms step_avg:87.26ms +step:352/1680 train_time:30717ms step_avg:87.26ms +step:353/1680 train_time:30804ms step_avg:87.26ms +step:354/1680 train_time:30891ms step_avg:87.26ms +step:355/1680 train_time:30979ms step_avg:87.26ms +step:356/1680 train_time:31066ms step_avg:87.26ms +step:357/1680 train_time:31153ms step_avg:87.26ms +step:358/1680 train_time:31240ms step_avg:87.26ms +step:359/1680 train_time:31327ms step_avg:87.26ms +step:360/1680 train_time:31414ms step_avg:87.26ms +step:361/1680 train_time:31503ms step_avg:87.27ms +step:362/1680 train_time:31590ms step_avg:87.26ms +step:363/1680 train_time:31678ms step_avg:87.27ms +step:364/1680 train_time:31765ms step_avg:87.27ms +step:365/1680 train_time:31852ms step_avg:87.26ms +step:366/1680 train_time:31939ms step_avg:87.26ms +step:367/1680 train_time:32026ms step_avg:87.26ms +step:368/1680 train_time:32113ms step_avg:87.26ms +step:369/1680 train_time:32201ms step_avg:87.27ms +step:370/1680 train_time:32288ms step_avg:87.26ms +step:371/1680 train_time:32375ms step_avg:87.26ms +step:372/1680 train_time:32463ms step_avg:87.27ms +step:373/1680 train_time:32550ms step_avg:87.27ms +step:374/1680 train_time:32637ms step_avg:87.27ms +step:375/1680 train_time:32725ms step_avg:87.27ms +step:375/1680 val_loss:3.8140 train_time:32813ms step_avg:87.50ms +step:376/1680 train_time:32837ms step_avg:87.33ms +step:377/1680 train_time:32903ms step_avg:87.28ms +step:378/1680 train_time:32993ms step_avg:87.28ms +step:379/1680 train_time:33083ms step_avg:87.29ms +step:380/1680 train_time:33170ms step_avg:87.29ms +step:381/1680 train_time:33256ms step_avg:87.29ms +step:382/1680 train_time:33342ms step_avg:87.28ms +step:383/1680 train_time:33428ms step_avg:87.28ms +step:384/1680 train_time:33514ms step_avg:87.28ms +step:385/1680 train_time:33601ms step_avg:87.28ms +step:386/1680 train_time:33688ms step_avg:87.27ms +step:387/1680 train_time:33775ms step_avg:87.27ms +step:388/1680 train_time:33863ms step_avg:87.28ms +step:389/1680 train_time:33951ms step_avg:87.28ms +step:390/1680 train_time:34040ms step_avg:87.28ms +step:391/1680 train_time:34128ms step_avg:87.28ms +step:392/1680 train_time:34216ms step_avg:87.28ms +step:393/1680 train_time:34302ms step_avg:87.28ms 
+step:394/1680 train_time:34389ms step_avg:87.28ms +step:395/1680 train_time:34475ms step_avg:87.28ms +step:396/1680 train_time:34561ms step_avg:87.28ms +step:397/1680 train_time:34648ms step_avg:87.27ms +step:398/1680 train_time:34734ms step_avg:87.27ms +step:399/1680 train_time:34821ms step_avg:87.27ms +step:400/1680 train_time:34909ms step_avg:87.27ms +step:401/1680 train_time:34997ms step_avg:87.27ms +step:402/1680 train_time:35085ms step_avg:87.28ms +step:403/1680 train_time:35173ms step_avg:87.28ms +step:404/1680 train_time:35260ms step_avg:87.28ms +step:405/1680 train_time:35347ms step_avg:87.28ms +step:406/1680 train_time:35434ms step_avg:87.27ms +step:407/1680 train_time:35520ms step_avg:87.27ms +step:408/1680 train_time:35607ms step_avg:87.27ms +step:409/1680 train_time:35693ms step_avg:87.27ms +step:410/1680 train_time:35780ms step_avg:87.27ms +step:411/1680 train_time:35868ms step_avg:87.27ms +step:412/1680 train_time:35956ms step_avg:87.27ms +step:413/1680 train_time:36043ms step_avg:87.27ms +step:414/1680 train_time:36131ms step_avg:87.27ms +step:415/1680 train_time:36219ms step_avg:87.27ms +step:416/1680 train_time:36306ms step_avg:87.27ms +step:417/1680 train_time:36392ms step_avg:87.27ms +step:418/1680 train_time:36480ms step_avg:87.27ms +step:419/1680 train_time:36567ms step_avg:87.27ms +step:420/1680 train_time:36654ms step_avg:87.27ms +step:421/1680 train_time:36741ms step_avg:87.27ms +step:422/1680 train_time:36828ms step_avg:87.27ms +step:423/1680 train_time:36915ms step_avg:87.27ms +step:424/1680 train_time:37002ms step_avg:87.27ms +step:425/1680 train_time:37089ms step_avg:87.27ms +step:426/1680 train_time:37176ms step_avg:87.27ms +step:427/1680 train_time:37264ms step_avg:87.27ms +step:428/1680 train_time:37351ms step_avg:87.27ms +step:429/1680 train_time:37439ms step_avg:87.27ms +step:430/1680 train_time:37526ms step_avg:87.27ms +step:431/1680 train_time:37614ms step_avg:87.27ms +step:432/1680 train_time:37701ms step_avg:87.27ms +step:433/1680 train_time:37788ms step_avg:87.27ms +step:434/1680 train_time:37875ms step_avg:87.27ms +step:435/1680 train_time:37962ms step_avg:87.27ms +step:436/1680 train_time:38049ms step_avg:87.27ms +step:437/1680 train_time:38137ms step_avg:87.27ms +step:438/1680 train_time:38224ms step_avg:87.27ms +step:439/1680 train_time:38312ms step_avg:87.27ms +step:440/1680 train_time:38399ms step_avg:87.27ms +step:441/1680 train_time:38486ms step_avg:87.27ms +step:442/1680 train_time:38573ms step_avg:87.27ms +step:443/1680 train_time:38661ms step_avg:87.27ms +step:444/1680 train_time:38748ms step_avg:87.27ms +step:445/1680 train_time:38837ms step_avg:87.27ms +step:446/1680 train_time:38923ms step_avg:87.27ms +step:447/1680 train_time:39010ms step_avg:87.27ms +step:448/1680 train_time:39098ms step_avg:87.27ms +step:449/1680 train_time:39186ms step_avg:87.27ms +step:450/1680 train_time:39272ms step_avg:87.27ms +step:451/1680 train_time:39360ms step_avg:87.27ms +step:452/1680 train_time:39446ms step_avg:87.27ms +step:453/1680 train_time:39534ms step_avg:87.27ms +step:454/1680 train_time:39622ms step_avg:87.27ms +step:455/1680 train_time:39709ms step_avg:87.27ms +step:456/1680 train_time:39796ms step_avg:87.27ms +step:457/1680 train_time:39883ms step_avg:87.27ms +step:458/1680 train_time:39970ms step_avg:87.27ms +step:459/1680 train_time:40057ms step_avg:87.27ms +step:460/1680 train_time:40144ms step_avg:87.27ms +step:461/1680 train_time:40232ms step_avg:87.27ms +step:462/1680 train_time:40319ms step_avg:87.27ms +step:463/1680 train_time:40407ms 
step_avg:87.27ms +step:464/1680 train_time:40495ms step_avg:87.27ms +step:465/1680 train_time:40582ms step_avg:87.27ms +step:466/1680 train_time:40669ms step_avg:87.27ms +step:467/1680 train_time:40756ms step_avg:87.27ms +step:468/1680 train_time:40843ms step_avg:87.27ms +step:469/1680 train_time:40930ms step_avg:87.27ms +step:470/1680 train_time:41018ms step_avg:87.27ms +step:471/1680 train_time:41105ms step_avg:87.27ms +step:472/1680 train_time:41192ms step_avg:87.27ms +step:473/1680 train_time:41280ms step_avg:87.27ms +step:474/1680 train_time:41367ms step_avg:87.27ms +step:475/1680 train_time:41454ms step_avg:87.27ms +step:476/1680 train_time:41541ms step_avg:87.27ms +step:477/1680 train_time:41628ms step_avg:87.27ms +step:478/1680 train_time:41716ms step_avg:87.27ms +step:479/1680 train_time:41803ms step_avg:87.27ms +step:480/1680 train_time:41891ms step_avg:87.27ms +step:481/1680 train_time:41978ms step_avg:87.27ms +step:482/1680 train_time:42065ms step_avg:87.27ms +step:483/1680 train_time:42152ms step_avg:87.27ms +step:484/1680 train_time:42240ms step_avg:87.27ms +step:485/1680 train_time:42327ms step_avg:87.27ms +step:486/1680 train_time:42414ms step_avg:87.27ms +step:487/1680 train_time:42502ms step_avg:87.27ms +step:488/1680 train_time:42589ms step_avg:87.27ms +step:489/1680 train_time:42676ms step_avg:87.27ms +step:490/1680 train_time:42764ms step_avg:87.27ms +step:491/1680 train_time:42851ms step_avg:87.27ms +step:492/1680 train_time:42938ms step_avg:87.27ms +step:493/1680 train_time:43026ms step_avg:87.27ms +step:494/1680 train_time:43113ms step_avg:87.27ms +step:495/1680 train_time:43200ms step_avg:87.27ms +step:496/1680 train_time:43287ms step_avg:87.27ms +step:497/1680 train_time:43375ms step_avg:87.27ms +step:498/1680 train_time:43462ms step_avg:87.27ms +step:499/1680 train_time:43549ms step_avg:87.27ms +step:500/1680 train_time:43636ms step_avg:87.27ms +step:500/1680 val_loss:3.7139 train_time:43726ms step_avg:87.45ms +step:501/1680 train_time:43747ms step_avg:87.32ms +step:502/1680 train_time:43815ms step_avg:87.28ms +step:503/1680 train_time:43905ms step_avg:87.29ms +step:504/1680 train_time:43993ms step_avg:87.29ms +step:505/1680 train_time:44081ms step_avg:87.29ms +step:506/1680 train_time:44168ms step_avg:87.29ms +step:507/1680 train_time:44254ms step_avg:87.29ms +step:508/1680 train_time:44340ms step_avg:87.28ms +step:509/1680 train_time:44426ms step_avg:87.28ms +step:510/1680 train_time:44512ms step_avg:87.28ms +step:511/1680 train_time:44598ms step_avg:87.28ms +step:512/1680 train_time:44687ms step_avg:87.28ms +step:513/1680 train_time:44775ms step_avg:87.28ms +step:514/1680 train_time:44865ms step_avg:87.29ms +step:515/1680 train_time:44953ms step_avg:87.29ms +step:516/1680 train_time:45041ms step_avg:87.29ms +step:517/1680 train_time:45127ms step_avg:87.29ms +step:518/1680 train_time:45214ms step_avg:87.29ms +step:519/1680 train_time:45301ms step_avg:87.29ms +step:520/1680 train_time:45387ms step_avg:87.28ms +step:521/1680 train_time:45474ms step_avg:87.28ms +step:522/1680 train_time:45561ms step_avg:87.28ms +step:523/1680 train_time:45648ms step_avg:87.28ms +step:524/1680 train_time:45736ms step_avg:87.28ms +step:525/1680 train_time:45825ms step_avg:87.28ms +step:526/1680 train_time:45914ms step_avg:87.29ms +step:527/1680 train_time:46001ms step_avg:87.29ms +step:528/1680 train_time:46088ms step_avg:87.29ms +step:529/1680 train_time:46175ms step_avg:87.29ms +step:530/1680 train_time:46262ms step_avg:87.29ms +step:531/1680 train_time:46349ms step_avg:87.29ms 
+step:532/1680 train_time:46436ms step_avg:87.29ms +step:533/1680 train_time:46523ms step_avg:87.29ms +step:534/1680 train_time:46610ms step_avg:87.28ms +step:535/1680 train_time:46697ms step_avg:87.28ms +step:536/1680 train_time:46785ms step_avg:87.29ms +step:537/1680 train_time:46873ms step_avg:87.29ms +step:538/1680 train_time:46961ms step_avg:87.29ms +step:539/1680 train_time:47049ms step_avg:87.29ms +step:540/1680 train_time:47136ms step_avg:87.29ms +step:541/1680 train_time:47223ms step_avg:87.29ms +step:542/1680 train_time:47310ms step_avg:87.29ms +step:543/1680 train_time:47397ms step_avg:87.29ms +step:544/1680 train_time:47484ms step_avg:87.29ms +step:545/1680 train_time:47571ms step_avg:87.29ms +step:546/1680 train_time:47658ms step_avg:87.29ms +step:547/1680 train_time:47745ms step_avg:87.29ms +step:548/1680 train_time:47833ms step_avg:87.29ms +step:549/1680 train_time:47922ms step_avg:87.29ms +step:550/1680 train_time:48011ms step_avg:87.29ms +step:551/1680 train_time:48100ms step_avg:87.30ms +step:552/1680 train_time:48188ms step_avg:87.30ms +step:553/1680 train_time:48276ms step_avg:87.30ms +step:554/1680 train_time:48365ms step_avg:87.30ms +step:555/1680 train_time:48453ms step_avg:87.30ms +step:556/1680 train_time:48542ms step_avg:87.31ms +step:557/1680 train_time:48630ms step_avg:87.31ms +step:558/1680 train_time:48718ms step_avg:87.31ms +step:559/1680 train_time:48807ms step_avg:87.31ms +step:560/1680 train_time:48896ms step_avg:87.31ms +step:561/1680 train_time:48985ms step_avg:87.32ms +step:562/1680 train_time:49074ms step_avg:87.32ms +step:563/1680 train_time:49163ms step_avg:87.32ms +step:564/1680 train_time:49251ms step_avg:87.32ms +step:565/1680 train_time:49340ms step_avg:87.33ms +step:566/1680 train_time:49428ms step_avg:87.33ms +step:567/1680 train_time:49517ms step_avg:87.33ms +step:568/1680 train_time:49604ms step_avg:87.33ms +step:569/1680 train_time:49692ms step_avg:87.33ms +step:570/1680 train_time:49782ms step_avg:87.34ms +step:571/1680 train_time:49870ms step_avg:87.34ms +step:572/1680 train_time:49959ms step_avg:87.34ms +step:573/1680 train_time:50048ms step_avg:87.34ms +step:574/1680 train_time:50136ms step_avg:87.35ms +step:575/1680 train_time:50225ms step_avg:87.35ms +step:576/1680 train_time:50313ms step_avg:87.35ms +step:577/1680 train_time:50402ms step_avg:87.35ms +step:578/1680 train_time:50490ms step_avg:87.35ms +step:579/1680 train_time:50578ms step_avg:87.35ms +step:580/1680 train_time:50666ms step_avg:87.35ms +step:581/1680 train_time:50755ms step_avg:87.36ms +step:582/1680 train_time:50844ms step_avg:87.36ms +step:583/1680 train_time:50933ms step_avg:87.36ms +step:584/1680 train_time:51021ms step_avg:87.36ms +step:585/1680 train_time:51110ms step_avg:87.37ms +step:586/1680 train_time:51199ms step_avg:87.37ms +step:587/1680 train_time:51288ms step_avg:87.37ms +step:588/1680 train_time:51376ms step_avg:87.37ms +step:589/1680 train_time:51465ms step_avg:87.38ms +step:590/1680 train_time:51553ms step_avg:87.38ms +step:591/1680 train_time:51641ms step_avg:87.38ms +step:592/1680 train_time:51729ms step_avg:87.38ms +step:593/1680 train_time:51817ms step_avg:87.38ms +step:594/1680 train_time:51906ms step_avg:87.38ms +step:595/1680 train_time:51994ms step_avg:87.38ms +step:596/1680 train_time:52083ms step_avg:87.39ms +step:597/1680 train_time:52171ms step_avg:87.39ms +step:598/1680 train_time:52260ms step_avg:87.39ms +step:599/1680 train_time:52349ms step_avg:87.39ms +step:600/1680 train_time:52437ms step_avg:87.39ms +step:601/1680 train_time:52525ms 
step_avg:87.40ms +step:602/1680 train_time:52613ms step_avg:87.40ms +step:603/1680 train_time:52702ms step_avg:87.40ms +step:604/1680 train_time:52790ms step_avg:87.40ms +step:605/1680 train_time:52879ms step_avg:87.40ms +step:606/1680 train_time:52967ms step_avg:87.40ms +step:607/1680 train_time:53056ms step_avg:87.41ms +step:608/1680 train_time:53146ms step_avg:87.41ms +step:609/1680 train_time:53234ms step_avg:87.41ms +step:610/1680 train_time:53323ms step_avg:87.41ms +step:611/1680 train_time:53411ms step_avg:87.42ms +step:612/1680 train_time:53499ms step_avg:87.42ms +step:613/1680 train_time:53587ms step_avg:87.42ms +step:614/1680 train_time:53676ms step_avg:87.42ms +step:615/1680 train_time:53764ms step_avg:87.42ms +step:616/1680 train_time:53853ms step_avg:87.42ms +step:617/1680 train_time:53942ms step_avg:87.43ms +step:618/1680 train_time:54030ms step_avg:87.43ms +step:619/1680 train_time:54118ms step_avg:87.43ms +step:620/1680 train_time:54207ms step_avg:87.43ms +step:621/1680 train_time:54295ms step_avg:87.43ms +step:622/1680 train_time:54384ms step_avg:87.43ms +step:623/1680 train_time:54472ms step_avg:87.44ms +step:624/1680 train_time:54561ms step_avg:87.44ms +step:625/1680 train_time:54649ms step_avg:87.44ms +step:625/1680 val_loss:3.6126 train_time:54739ms step_avg:87.58ms +step:626/1680 train_time:54763ms step_avg:87.48ms +step:627/1680 train_time:54829ms step_avg:87.45ms +step:628/1680 train_time:54919ms step_avg:87.45ms +step:629/1680 train_time:55009ms step_avg:87.46ms +step:630/1680 train_time:55099ms step_avg:87.46ms +step:631/1680 train_time:55186ms step_avg:87.46ms +step:632/1680 train_time:55274ms step_avg:87.46ms +step:633/1680 train_time:55361ms step_avg:87.46ms +step:634/1680 train_time:55448ms step_avg:87.46ms +step:635/1680 train_time:55535ms step_avg:87.46ms +step:636/1680 train_time:55625ms step_avg:87.46ms +step:637/1680 train_time:55718ms step_avg:87.47ms +step:638/1680 train_time:55809ms step_avg:87.48ms +step:639/1680 train_time:55898ms step_avg:87.48ms +step:640/1680 train_time:55986ms step_avg:87.48ms +step:641/1680 train_time:56075ms step_avg:87.48ms +step:642/1680 train_time:56164ms step_avg:87.48ms +step:643/1680 train_time:56251ms step_avg:87.48ms +step:644/1680 train_time:56339ms step_avg:87.48ms +step:645/1680 train_time:56426ms step_avg:87.48ms +step:646/1680 train_time:56515ms step_avg:87.48ms +step:647/1680 train_time:56604ms step_avg:87.49ms +step:648/1680 train_time:56693ms step_avg:87.49ms +step:649/1680 train_time:56783ms step_avg:87.49ms +step:650/1680 train_time:56872ms step_avg:87.50ms +step:651/1680 train_time:56960ms step_avg:87.50ms +step:652/1680 train_time:57048ms step_avg:87.50ms +step:653/1680 train_time:57137ms step_avg:87.50ms +step:654/1680 train_time:57224ms step_avg:87.50ms +step:655/1680 train_time:57312ms step_avg:87.50ms +step:656/1680 train_time:57399ms step_avg:87.50ms +step:657/1680 train_time:57487ms step_avg:87.50ms +step:658/1680 train_time:57576ms step_avg:87.50ms +step:659/1680 train_time:57665ms step_avg:87.50ms +step:660/1680 train_time:57753ms step_avg:87.51ms +step:661/1680 train_time:57842ms step_avg:87.51ms +step:662/1680 train_time:57930ms step_avg:87.51ms +step:663/1680 train_time:58019ms step_avg:87.51ms +step:664/1680 train_time:58107ms step_avg:87.51ms +step:665/1680 train_time:58196ms step_avg:87.51ms +step:666/1680 train_time:58284ms step_avg:87.51ms +step:667/1680 train_time:58371ms step_avg:87.51ms +step:668/1680 train_time:58460ms step_avg:87.52ms +step:669/1680 train_time:58548ms step_avg:87.52ms 
+step:670/1680 train_time:58637ms step_avg:87.52ms +step:671/1680 train_time:58725ms step_avg:87.52ms +step:672/1680 train_time:58813ms step_avg:87.52ms +step:673/1680 train_time:58903ms step_avg:87.52ms +step:674/1680 train_time:58991ms step_avg:87.52ms +step:675/1680 train_time:59079ms step_avg:87.52ms +step:676/1680 train_time:59167ms step_avg:87.53ms +step:677/1680 train_time:59256ms step_avg:87.53ms +step:678/1680 train_time:59343ms step_avg:87.53ms +step:679/1680 train_time:59431ms step_avg:87.53ms +step:680/1680 train_time:59519ms step_avg:87.53ms +step:681/1680 train_time:59607ms step_avg:87.53ms +step:682/1680 train_time:59696ms step_avg:87.53ms +step:683/1680 train_time:59785ms step_avg:87.53ms +step:684/1680 train_time:59873ms step_avg:87.53ms +step:685/1680 train_time:59962ms step_avg:87.54ms +step:686/1680 train_time:60051ms step_avg:87.54ms +step:687/1680 train_time:60139ms step_avg:87.54ms +step:688/1680 train_time:60227ms step_avg:87.54ms +step:689/1680 train_time:60315ms step_avg:87.54ms +step:690/1680 train_time:60403ms step_avg:87.54ms +step:691/1680 train_time:60492ms step_avg:87.54ms +step:692/1680 train_time:60581ms step_avg:87.54ms +step:693/1680 train_time:60669ms step_avg:87.55ms +step:694/1680 train_time:60757ms step_avg:87.55ms +step:695/1680 train_time:60845ms step_avg:87.55ms +step:696/1680 train_time:60933ms step_avg:87.55ms +step:697/1680 train_time:61023ms step_avg:87.55ms +step:698/1680 train_time:61111ms step_avg:87.55ms +step:699/1680 train_time:61200ms step_avg:87.55ms +step:700/1680 train_time:61289ms step_avg:87.56ms +step:701/1680 train_time:61378ms step_avg:87.56ms +step:702/1680 train_time:61467ms step_avg:87.56ms +step:703/1680 train_time:61555ms step_avg:87.56ms +step:704/1680 train_time:61643ms step_avg:87.56ms +step:705/1680 train_time:61732ms step_avg:87.56ms +step:706/1680 train_time:61820ms step_avg:87.56ms +step:707/1680 train_time:61908ms step_avg:87.56ms +step:708/1680 train_time:61996ms step_avg:87.57ms +step:709/1680 train_time:62084ms step_avg:87.57ms +step:710/1680 train_time:62172ms step_avg:87.57ms +step:711/1680 train_time:62260ms step_avg:87.57ms +step:712/1680 train_time:62349ms step_avg:87.57ms +step:713/1680 train_time:62437ms step_avg:87.57ms +step:714/1680 train_time:62524ms step_avg:87.57ms +step:715/1680 train_time:62612ms step_avg:87.57ms +step:716/1680 train_time:62700ms step_avg:87.57ms +step:717/1680 train_time:62789ms step_avg:87.57ms +step:718/1680 train_time:62878ms step_avg:87.57ms +step:719/1680 train_time:62965ms step_avg:87.57ms +step:720/1680 train_time:63053ms step_avg:87.57ms +step:721/1680 train_time:63142ms step_avg:87.58ms +step:722/1680 train_time:63230ms step_avg:87.58ms +step:723/1680 train_time:63318ms step_avg:87.58ms +step:724/1680 train_time:63406ms step_avg:87.58ms +step:725/1680 train_time:63495ms step_avg:87.58ms +step:726/1680 train_time:63583ms step_avg:87.58ms +step:727/1680 train_time:63670ms step_avg:87.58ms +step:728/1680 train_time:63759ms step_avg:87.58ms +step:729/1680 train_time:63848ms step_avg:87.58ms +step:730/1680 train_time:63936ms step_avg:87.58ms +step:731/1680 train_time:64024ms step_avg:87.58ms +step:732/1680 train_time:64113ms step_avg:87.59ms +step:733/1680 train_time:64202ms step_avg:87.59ms +step:734/1680 train_time:64290ms step_avg:87.59ms +step:735/1680 train_time:64379ms step_avg:87.59ms +step:736/1680 train_time:64468ms step_avg:87.59ms +step:737/1680 train_time:64556ms step_avg:87.59ms +step:738/1680 train_time:64644ms step_avg:87.59ms +step:739/1680 train_time:64732ms 
step_avg:87.59ms +step:740/1680 train_time:64820ms step_avg:87.60ms +step:741/1680 train_time:64910ms step_avg:87.60ms +step:742/1680 train_time:64998ms step_avg:87.60ms +step:743/1680 train_time:65086ms step_avg:87.60ms +step:744/1680 train_time:65175ms step_avg:87.60ms +step:745/1680 train_time:65263ms step_avg:87.60ms +step:746/1680 train_time:65351ms step_avg:87.60ms +step:747/1680 train_time:65440ms step_avg:87.60ms +step:748/1680 train_time:65528ms step_avg:87.60ms +step:749/1680 train_time:65616ms step_avg:87.61ms +step:750/1680 train_time:65705ms step_avg:87.61ms +step:750/1680 val_loss:3.5614 train_time:65795ms step_avg:87.73ms +step:751/1680 train_time:65818ms step_avg:87.64ms +step:752/1680 train_time:65885ms step_avg:87.61ms +step:753/1680 train_time:65977ms step_avg:87.62ms +step:754/1680 train_time:66066ms step_avg:87.62ms +step:755/1680 train_time:66154ms step_avg:87.62ms +step:756/1680 train_time:66243ms step_avg:87.62ms +step:757/1680 train_time:66331ms step_avg:87.62ms +step:758/1680 train_time:66418ms step_avg:87.62ms +step:759/1680 train_time:66506ms step_avg:87.62ms +step:760/1680 train_time:66593ms step_avg:87.62ms +step:761/1680 train_time:66681ms step_avg:87.62ms +step:762/1680 train_time:66769ms step_avg:87.62ms +step:763/1680 train_time:66860ms step_avg:87.63ms +step:764/1680 train_time:66949ms step_avg:87.63ms +step:765/1680 train_time:67038ms step_avg:87.63ms +step:766/1680 train_time:67127ms step_avg:87.63ms +step:767/1680 train_time:67217ms step_avg:87.64ms +step:768/1680 train_time:67305ms step_avg:87.64ms +step:769/1680 train_time:67392ms step_avg:87.64ms +step:770/1680 train_time:67480ms step_avg:87.64ms +step:771/1680 train_time:67568ms step_avg:87.64ms +step:772/1680 train_time:67656ms step_avg:87.64ms +step:773/1680 train_time:67744ms step_avg:87.64ms +step:774/1680 train_time:67833ms step_avg:87.64ms +step:775/1680 train_time:67922ms step_avg:87.64ms +step:776/1680 train_time:68011ms step_avg:87.64ms +step:777/1680 train_time:68101ms step_avg:87.65ms +step:778/1680 train_time:68189ms step_avg:87.65ms +step:779/1680 train_time:68278ms step_avg:87.65ms +step:780/1680 train_time:68366ms step_avg:87.65ms +step:781/1680 train_time:68454ms step_avg:87.65ms +step:782/1680 train_time:68542ms step_avg:87.65ms +step:783/1680 train_time:68630ms step_avg:87.65ms +step:784/1680 train_time:68718ms step_avg:87.65ms +step:785/1680 train_time:68807ms step_avg:87.65ms +step:786/1680 train_time:68896ms step_avg:87.65ms +step:787/1680 train_time:68985ms step_avg:87.66ms +step:788/1680 train_time:69075ms step_avg:87.66ms +step:789/1680 train_time:69163ms step_avg:87.66ms +step:790/1680 train_time:69252ms step_avg:87.66ms +step:791/1680 train_time:69340ms step_avg:87.66ms +step:792/1680 train_time:69428ms step_avg:87.66ms +step:793/1680 train_time:69517ms step_avg:87.66ms +step:794/1680 train_time:69605ms step_avg:87.66ms +step:795/1680 train_time:69694ms step_avg:87.66ms +step:796/1680 train_time:69782ms step_avg:87.67ms +step:797/1680 train_time:69870ms step_avg:87.67ms +step:798/1680 train_time:69958ms step_avg:87.67ms +step:799/1680 train_time:70047ms step_avg:87.67ms +step:800/1680 train_time:70137ms step_avg:87.67ms +step:801/1680 train_time:70225ms step_avg:87.67ms +step:802/1680 train_time:70314ms step_avg:87.67ms +step:803/1680 train_time:70402ms step_avg:87.67ms +step:804/1680 train_time:70489ms step_avg:87.67ms +step:805/1680 train_time:70577ms step_avg:87.67ms +step:806/1680 train_time:70665ms step_avg:87.67ms +step:807/1680 train_time:70754ms step_avg:87.68ms 
+step:808/1680 train_time:70843ms step_avg:87.68ms +step:809/1680 train_time:70931ms step_avg:87.68ms +step:810/1680 train_time:71019ms step_avg:87.68ms +step:811/1680 train_time:71108ms step_avg:87.68ms +step:812/1680 train_time:71196ms step_avg:87.68ms +step:813/1680 train_time:71285ms step_avg:87.68ms +step:814/1680 train_time:71374ms step_avg:87.68ms +step:815/1680 train_time:71462ms step_avg:87.68ms +step:816/1680 train_time:71550ms step_avg:87.68ms +step:817/1680 train_time:71638ms step_avg:87.68ms +step:818/1680 train_time:71727ms step_avg:87.69ms +step:819/1680 train_time:71815ms step_avg:87.69ms +step:820/1680 train_time:71903ms step_avg:87.69ms +step:821/1680 train_time:71992ms step_avg:87.69ms +step:822/1680 train_time:72080ms step_avg:87.69ms +step:823/1680 train_time:72168ms step_avg:87.69ms +step:824/1680 train_time:72257ms step_avg:87.69ms +step:825/1680 train_time:72345ms step_avg:87.69ms +step:826/1680 train_time:72434ms step_avg:87.69ms +step:827/1680 train_time:72522ms step_avg:87.69ms +step:828/1680 train_time:72610ms step_avg:87.69ms +step:829/1680 train_time:72698ms step_avg:87.69ms +step:830/1680 train_time:72786ms step_avg:87.69ms +step:831/1680 train_time:72874ms step_avg:87.69ms +step:832/1680 train_time:72962ms step_avg:87.69ms +step:833/1680 train_time:73050ms step_avg:87.70ms +step:834/1680 train_time:73139ms step_avg:87.70ms +step:835/1680 train_time:73228ms step_avg:87.70ms +step:836/1680 train_time:73317ms step_avg:87.70ms +step:837/1680 train_time:73405ms step_avg:87.70ms +step:838/1680 train_time:73494ms step_avg:87.70ms +step:839/1680 train_time:73582ms step_avg:87.70ms +step:840/1680 train_time:73671ms step_avg:87.70ms +step:841/1680 train_time:73759ms step_avg:87.70ms +step:842/1680 train_time:73847ms step_avg:87.70ms +step:843/1680 train_time:73936ms step_avg:87.71ms +step:844/1680 train_time:74024ms step_avg:87.71ms +step:845/1680 train_time:74112ms step_avg:87.71ms +step:846/1680 train_time:74200ms step_avg:87.71ms +step:847/1680 train_time:74289ms step_avg:87.71ms +step:848/1680 train_time:74378ms step_avg:87.71ms +step:849/1680 train_time:74466ms step_avg:87.71ms +step:850/1680 train_time:74554ms step_avg:87.71ms +step:851/1680 train_time:74643ms step_avg:87.71ms +step:852/1680 train_time:74731ms step_avg:87.71ms +step:853/1680 train_time:74819ms step_avg:87.71ms +step:854/1680 train_time:74908ms step_avg:87.71ms +step:855/1680 train_time:74995ms step_avg:87.71ms +step:856/1680 train_time:75083ms step_avg:87.71ms +step:857/1680 train_time:75172ms step_avg:87.72ms +step:858/1680 train_time:75260ms step_avg:87.72ms +step:859/1680 train_time:75349ms step_avg:87.72ms +step:860/1680 train_time:75437ms step_avg:87.72ms +step:861/1680 train_time:75526ms step_avg:87.72ms +step:862/1680 train_time:75614ms step_avg:87.72ms +step:863/1680 train_time:75702ms step_avg:87.72ms +step:864/1680 train_time:75790ms step_avg:87.72ms +step:865/1680 train_time:75878ms step_avg:87.72ms +step:866/1680 train_time:75967ms step_avg:87.72ms +step:867/1680 train_time:76055ms step_avg:87.72ms +step:868/1680 train_time:76144ms step_avg:87.72ms +step:869/1680 train_time:76233ms step_avg:87.72ms +step:870/1680 train_time:76321ms step_avg:87.73ms +step:871/1680 train_time:76410ms step_avg:87.73ms +step:872/1680 train_time:76498ms step_avg:87.73ms +step:873/1680 train_time:76587ms step_avg:87.73ms +step:874/1680 train_time:76675ms step_avg:87.73ms +step:875/1680 train_time:76763ms step_avg:87.73ms +step:875/1680 val_loss:3.5152 train_time:76853ms step_avg:87.83ms +step:876/1680 
train_time:76872ms step_avg:87.75ms +step:877/1680 train_time:76947ms step_avg:87.74ms +step:878/1680 train_time:77042ms step_avg:87.75ms +step:879/1680 train_time:77131ms step_avg:87.75ms +step:880/1680 train_time:77219ms step_avg:87.75ms +step:881/1680 train_time:77308ms step_avg:87.75ms +step:882/1680 train_time:77395ms step_avg:87.75ms +step:883/1680 train_time:77483ms step_avg:87.75ms +step:884/1680 train_time:77571ms step_avg:87.75ms +step:885/1680 train_time:77659ms step_avg:87.75ms +step:886/1680 train_time:77746ms step_avg:87.75ms +step:887/1680 train_time:77835ms step_avg:87.75ms +step:888/1680 train_time:77925ms step_avg:87.75ms +step:889/1680 train_time:78015ms step_avg:87.76ms +step:890/1680 train_time:78104ms step_avg:87.76ms +step:891/1680 train_time:78193ms step_avg:87.76ms +step:892/1680 train_time:78282ms step_avg:87.76ms +step:893/1680 train_time:78370ms step_avg:87.76ms +step:894/1680 train_time:78457ms step_avg:87.76ms +step:895/1680 train_time:78545ms step_avg:87.76ms +step:896/1680 train_time:78632ms step_avg:87.76ms +step:897/1680 train_time:78720ms step_avg:87.76ms +step:898/1680 train_time:78809ms step_avg:87.76ms +step:899/1680 train_time:78897ms step_avg:87.76ms +step:900/1680 train_time:78986ms step_avg:87.76ms +step:901/1680 train_time:79075ms step_avg:87.76ms +step:902/1680 train_time:79164ms step_avg:87.76ms +step:903/1680 train_time:79254ms step_avg:87.77ms +step:904/1680 train_time:79342ms step_avg:87.77ms +step:905/1680 train_time:79430ms step_avg:87.77ms +step:906/1680 train_time:79519ms step_avg:87.77ms +step:907/1680 train_time:79606ms step_avg:87.77ms +step:908/1680 train_time:79694ms step_avg:87.77ms +step:909/1680 train_time:79783ms step_avg:87.77ms +step:910/1680 train_time:79871ms step_avg:87.77ms +step:911/1680 train_time:79960ms step_avg:87.77ms +step:912/1680 train_time:80049ms step_avg:87.77ms +step:913/1680 train_time:80138ms step_avg:87.77ms +step:914/1680 train_time:80227ms step_avg:87.78ms +step:915/1680 train_time:80315ms step_avg:87.78ms +step:916/1680 train_time:80403ms step_avg:87.78ms +step:917/1680 train_time:80491ms step_avg:87.78ms +step:918/1680 train_time:80580ms step_avg:87.78ms +step:919/1680 train_time:80668ms step_avg:87.78ms +step:920/1680 train_time:80757ms step_avg:87.78ms +step:921/1680 train_time:80845ms step_avg:87.78ms +step:922/1680 train_time:80933ms step_avg:87.78ms +step:923/1680 train_time:81022ms step_avg:87.78ms +step:924/1680 train_time:81112ms step_avg:87.78ms +step:925/1680 train_time:81202ms step_avg:87.79ms +step:926/1680 train_time:81290ms step_avg:87.79ms +step:927/1680 train_time:81379ms step_avg:87.79ms +step:928/1680 train_time:81467ms step_avg:87.79ms +step:929/1680 train_time:81556ms step_avg:87.79ms +step:930/1680 train_time:81644ms step_avg:87.79ms +step:931/1680 train_time:81733ms step_avg:87.79ms +step:932/1680 train_time:81821ms step_avg:87.79ms +step:933/1680 train_time:81909ms step_avg:87.79ms +step:934/1680 train_time:81998ms step_avg:87.79ms +step:935/1680 train_time:82086ms step_avg:87.79ms +step:936/1680 train_time:82175ms step_avg:87.79ms +step:937/1680 train_time:82264ms step_avg:87.79ms +step:938/1680 train_time:82351ms step_avg:87.79ms +step:939/1680 train_time:82440ms step_avg:87.80ms +step:940/1680 train_time:82528ms step_avg:87.80ms +step:941/1680 train_time:82617ms step_avg:87.80ms +step:942/1680 train_time:82705ms step_avg:87.80ms +step:943/1680 train_time:82792ms step_avg:87.80ms +step:944/1680 train_time:82882ms step_avg:87.80ms +step:945/1680 train_time:82970ms step_avg:87.80ms 
+step:946/1680 train_time:83059ms step_avg:87.80ms +step:947/1680 train_time:83147ms step_avg:87.80ms +step:948/1680 train_time:83236ms step_avg:87.80ms +step:949/1680 train_time:83324ms step_avg:87.80ms +step:950/1680 train_time:83412ms step_avg:87.80ms +step:951/1680 train_time:83500ms step_avg:87.80ms +step:952/1680 train_time:83589ms step_avg:87.80ms +step:953/1680 train_time:83678ms step_avg:87.80ms +step:954/1680 train_time:83766ms step_avg:87.80ms +step:955/1680 train_time:83854ms step_avg:87.81ms +step:956/1680 train_time:83943ms step_avg:87.81ms +step:957/1680 train_time:84032ms step_avg:87.81ms +step:958/1680 train_time:84121ms step_avg:87.81ms +step:959/1680 train_time:84210ms step_avg:87.81ms +step:960/1680 train_time:84299ms step_avg:87.81ms +step:961/1680 train_time:84387ms step_avg:87.81ms +step:962/1680 train_time:84475ms step_avg:87.81ms +step:963/1680 train_time:84563ms step_avg:87.81ms +step:964/1680 train_time:84652ms step_avg:87.81ms +step:965/1680 train_time:84739ms step_avg:87.81ms +step:966/1680 train_time:84828ms step_avg:87.81ms +step:967/1680 train_time:84916ms step_avg:87.81ms +step:968/1680 train_time:85005ms step_avg:87.82ms +step:969/1680 train_time:85094ms step_avg:87.82ms +step:970/1680 train_time:85183ms step_avg:87.82ms +step:971/1680 train_time:85271ms step_avg:87.82ms +step:972/1680 train_time:85359ms step_avg:87.82ms +step:973/1680 train_time:85447ms step_avg:87.82ms +step:974/1680 train_time:85535ms step_avg:87.82ms +step:975/1680 train_time:85623ms step_avg:87.82ms +step:976/1680 train_time:85711ms step_avg:87.82ms +step:977/1680 train_time:85799ms step_avg:87.82ms +step:978/1680 train_time:85887ms step_avg:87.82ms +step:979/1680 train_time:85975ms step_avg:87.82ms +step:980/1680 train_time:86065ms step_avg:87.82ms +step:981/1680 train_time:86154ms step_avg:87.82ms +step:982/1680 train_time:86242ms step_avg:87.82ms +step:983/1680 train_time:86332ms step_avg:87.83ms +step:984/1680 train_time:86421ms step_avg:87.83ms +step:985/1680 train_time:86510ms step_avg:87.83ms +step:986/1680 train_time:86599ms step_avg:87.83ms +step:987/1680 train_time:86688ms step_avg:87.83ms +step:988/1680 train_time:86776ms step_avg:87.83ms +step:989/1680 train_time:86864ms step_avg:87.83ms +step:990/1680 train_time:86952ms step_avg:87.83ms +step:991/1680 train_time:87040ms step_avg:87.83ms +step:992/1680 train_time:87129ms step_avg:87.83ms +step:993/1680 train_time:87219ms step_avg:87.83ms +step:994/1680 train_time:87307ms step_avg:87.83ms +step:995/1680 train_time:87394ms step_avg:87.83ms +step:996/1680 train_time:87483ms step_avg:87.83ms +step:997/1680 train_time:87572ms step_avg:87.84ms +step:998/1680 train_time:87661ms step_avg:87.84ms +step:999/1680 train_time:87749ms step_avg:87.84ms +step:1000/1680 train_time:87837ms step_avg:87.84ms +step:1000/1680 val_loss:3.4656 train_time:87926ms step_avg:87.93ms +step:1001/1680 train_time:87947ms step_avg:87.86ms +step:1002/1680 train_time:88017ms step_avg:87.84ms +step:1003/1680 train_time:88109ms step_avg:87.85ms +step:1004/1680 train_time:88198ms step_avg:87.85ms +step:1005/1680 train_time:88285ms step_avg:87.85ms +step:1006/1680 train_time:88373ms step_avg:87.85ms +step:1007/1680 train_time:88460ms step_avg:87.85ms +step:1008/1680 train_time:88549ms step_avg:87.85ms +step:1009/1680 train_time:88637ms step_avg:87.85ms +step:1010/1680 train_time:88726ms step_avg:87.85ms +step:1011/1680 train_time:88814ms step_avg:87.85ms +step:1012/1680 train_time:88903ms step_avg:87.85ms +step:1013/1680 train_time:88992ms step_avg:87.85ms 
+step:1014/1680 train_time:89081ms step_avg:87.85ms +step:1015/1680 train_time:89170ms step_avg:87.85ms +step:1016/1680 train_time:89259ms step_avg:87.85ms +step:1017/1680 train_time:89348ms step_avg:87.85ms +step:1018/1680 train_time:89436ms step_avg:87.85ms +step:1019/1680 train_time:89524ms step_avg:87.85ms +step:1020/1680 train_time:89611ms step_avg:87.85ms +step:1021/1680 train_time:89699ms step_avg:87.85ms +step:1022/1680 train_time:89788ms step_avg:87.85ms +step:1023/1680 train_time:89876ms step_avg:87.85ms +step:1024/1680 train_time:89965ms step_avg:87.86ms +step:1025/1680 train_time:90054ms step_avg:87.86ms +step:1026/1680 train_time:90143ms step_avg:87.86ms +step:1027/1680 train_time:90231ms step_avg:87.86ms +step:1028/1680 train_time:90319ms step_avg:87.86ms +step:1029/1680 train_time:90407ms step_avg:87.86ms +step:1030/1680 train_time:90495ms step_avg:87.86ms +step:1031/1680 train_time:90582ms step_avg:87.86ms +step:1032/1680 train_time:90670ms step_avg:87.86ms +step:1033/1680 train_time:90758ms step_avg:87.86ms +step:1034/1680 train_time:90847ms step_avg:87.86ms +step:1035/1680 train_time:90935ms step_avg:87.86ms +step:1036/1680 train_time:91024ms step_avg:87.86ms +step:1037/1680 train_time:91112ms step_avg:87.86ms +step:1038/1680 train_time:91202ms step_avg:87.86ms +step:1039/1680 train_time:91290ms step_avg:87.86ms +step:1040/1680 train_time:91378ms step_avg:87.86ms +step:1041/1680 train_time:91467ms step_avg:87.86ms +step:1042/1680 train_time:91556ms step_avg:87.87ms +step:1043/1680 train_time:91645ms step_avg:87.87ms +step:1044/1680 train_time:91733ms step_avg:87.87ms +step:1045/1680 train_time:91821ms step_avg:87.87ms +step:1046/1680 train_time:91910ms step_avg:87.87ms +step:1047/1680 train_time:91999ms step_avg:87.87ms +step:1048/1680 train_time:92087ms step_avg:87.87ms +step:1049/1680 train_time:92175ms step_avg:87.87ms +step:1050/1680 train_time:92264ms step_avg:87.87ms +step:1051/1680 train_time:92353ms step_avg:87.87ms +step:1052/1680 train_time:92441ms step_avg:87.87ms +step:1053/1680 train_time:92529ms step_avg:87.87ms +step:1054/1680 train_time:92617ms step_avg:87.87ms +step:1055/1680 train_time:92705ms step_avg:87.87ms +step:1056/1680 train_time:92794ms step_avg:87.87ms +step:1057/1680 train_time:92882ms step_avg:87.87ms +step:1058/1680 train_time:92971ms step_avg:87.87ms +step:1059/1680 train_time:93059ms step_avg:87.87ms +step:1060/1680 train_time:93149ms step_avg:87.88ms +step:1061/1680 train_time:93238ms step_avg:87.88ms +step:1062/1680 train_time:93326ms step_avg:87.88ms +step:1063/1680 train_time:93415ms step_avg:87.88ms +step:1064/1680 train_time:93503ms step_avg:87.88ms +step:1065/1680 train_time:93591ms step_avg:87.88ms +step:1066/1680 train_time:93679ms step_avg:87.88ms +step:1067/1680 train_time:93768ms step_avg:87.88ms +step:1068/1680 train_time:93856ms step_avg:87.88ms +step:1069/1680 train_time:93945ms step_avg:87.88ms +step:1070/1680 train_time:94034ms step_avg:87.88ms +step:1071/1680 train_time:94121ms step_avg:87.88ms +step:1072/1680 train_time:94211ms step_avg:87.88ms +step:1073/1680 train_time:94300ms step_avg:87.88ms +step:1074/1680 train_time:94388ms step_avg:87.88ms +step:1075/1680 train_time:94475ms step_avg:87.88ms +step:1076/1680 train_time:94564ms step_avg:87.88ms +step:1077/1680 train_time:94653ms step_avg:87.89ms +step:1078/1680 train_time:94741ms step_avg:87.89ms +step:1079/1680 train_time:94829ms step_avg:87.89ms +step:1080/1680 train_time:94918ms step_avg:87.89ms +step:1081/1680 train_time:95007ms step_avg:87.89ms +step:1082/1680 
train_time:95096ms step_avg:87.89ms +step:1083/1680 train_time:95184ms step_avg:87.89ms +step:1084/1680 train_time:95273ms step_avg:87.89ms +step:1085/1680 train_time:95361ms step_avg:87.89ms +step:1086/1680 train_time:95450ms step_avg:87.89ms +step:1087/1680 train_time:95539ms step_avg:87.89ms +step:1088/1680 train_time:95627ms step_avg:87.89ms +step:1089/1680 train_time:95715ms step_avg:87.89ms +step:1090/1680 train_time:95803ms step_avg:87.89ms +step:1091/1680 train_time:95891ms step_avg:87.89ms +step:1092/1680 train_time:95980ms step_avg:87.89ms +step:1093/1680 train_time:96069ms step_avg:87.89ms +step:1094/1680 train_time:96158ms step_avg:87.90ms +step:1095/1680 train_time:96247ms step_avg:87.90ms +step:1096/1680 train_time:96336ms step_avg:87.90ms +step:1097/1680 train_time:96425ms step_avg:87.90ms +step:1098/1680 train_time:96515ms step_avg:87.90ms +step:1099/1680 train_time:96604ms step_avg:87.90ms +step:1100/1680 train_time:96692ms step_avg:87.90ms +step:1101/1680 train_time:96781ms step_avg:87.90ms +step:1102/1680 train_time:96870ms step_avg:87.90ms +step:1103/1680 train_time:96960ms step_avg:87.91ms +step:1104/1680 train_time:97049ms step_avg:87.91ms +step:1105/1680 train_time:97138ms step_avg:87.91ms +step:1106/1680 train_time:97228ms step_avg:87.91ms +step:1107/1680 train_time:97317ms step_avg:87.91ms +step:1108/1680 train_time:97406ms step_avg:87.91ms +step:1109/1680 train_time:97495ms step_avg:87.91ms +step:1110/1680 train_time:97585ms step_avg:87.91ms +step:1111/1680 train_time:97674ms step_avg:87.92ms +step:1112/1680 train_time:97763ms step_avg:87.92ms +step:1113/1680 train_time:97852ms step_avg:87.92ms +step:1114/1680 train_time:97941ms step_avg:87.92ms +step:1115/1680 train_time:98031ms step_avg:87.92ms +step:1116/1680 train_time:98121ms step_avg:87.92ms +step:1117/1680 train_time:98211ms step_avg:87.92ms +step:1118/1680 train_time:98300ms step_avg:87.93ms +step:1119/1680 train_time:98390ms step_avg:87.93ms +step:1120/1680 train_time:98479ms step_avg:87.93ms +step:1121/1680 train_time:98567ms step_avg:87.93ms +step:1122/1680 train_time:98656ms step_avg:87.93ms +step:1123/1680 train_time:98745ms step_avg:87.93ms +step:1124/1680 train_time:98836ms step_avg:87.93ms +step:1125/1680 train_time:98925ms step_avg:87.93ms +step:1125/1680 val_loss:3.4121 train_time:99016ms step_avg:88.01ms +step:1126/1680 train_time:99036ms step_avg:87.95ms +step:1127/1680 train_time:99106ms step_avg:87.94ms +step:1128/1680 train_time:99197ms step_avg:87.94ms +step:1129/1680 train_time:99289ms step_avg:87.94ms +step:1130/1680 train_time:99378ms step_avg:87.94ms +step:1131/1680 train_time:99466ms step_avg:87.95ms +step:1132/1680 train_time:99555ms step_avg:87.95ms +step:1133/1680 train_time:99644ms step_avg:87.95ms +step:1134/1680 train_time:99732ms step_avg:87.95ms +step:1135/1680 train_time:99821ms step_avg:87.95ms +step:1136/1680 train_time:99910ms step_avg:87.95ms +step:1137/1680 train_time:100001ms step_avg:87.95ms +step:1138/1680 train_time:100092ms step_avg:87.95ms +step:1139/1680 train_time:100184ms step_avg:87.96ms +step:1140/1680 train_time:100274ms step_avg:87.96ms +step:1141/1680 train_time:100363ms step_avg:87.96ms +step:1142/1680 train_time:100452ms step_avg:87.96ms +step:1143/1680 train_time:100541ms step_avg:87.96ms +step:1144/1680 train_time:100629ms step_avg:87.96ms +step:1145/1680 train_time:100717ms step_avg:87.96ms +step:1146/1680 train_time:100806ms step_avg:87.96ms +step:1147/1680 train_time:100895ms step_avg:87.96ms +step:1148/1680 train_time:100985ms step_avg:87.97ms 
+step:1149/1680 train_time:101075ms step_avg:87.97ms +step:1150/1680 train_time:101164ms step_avg:87.97ms +step:1151/1680 train_time:101255ms step_avg:87.97ms +step:1152/1680 train_time:101345ms step_avg:87.97ms +step:1153/1680 train_time:101434ms step_avg:87.97ms +step:1154/1680 train_time:101523ms step_avg:87.97ms +step:1155/1680 train_time:101612ms step_avg:87.98ms +step:1156/1680 train_time:101700ms step_avg:87.98ms +step:1157/1680 train_time:101788ms step_avg:87.98ms +step:1158/1680 train_time:101878ms step_avg:87.98ms +step:1159/1680 train_time:101967ms step_avg:87.98ms +step:1160/1680 train_time:102056ms step_avg:87.98ms +step:1161/1680 train_time:102145ms step_avg:87.98ms +step:1162/1680 train_time:102234ms step_avg:87.98ms +step:1163/1680 train_time:102323ms step_avg:87.98ms +step:1164/1680 train_time:102414ms step_avg:87.98ms +step:1165/1680 train_time:102503ms step_avg:87.99ms +step:1166/1680 train_time:102592ms step_avg:87.99ms +step:1167/1680 train_time:102682ms step_avg:87.99ms +step:1168/1680 train_time:102770ms step_avg:87.99ms +step:1169/1680 train_time:102859ms step_avg:87.99ms +step:1170/1680 train_time:102948ms step_avg:87.99ms +step:1171/1680 train_time:103038ms step_avg:87.99ms +step:1172/1680 train_time:103126ms step_avg:87.99ms +step:1173/1680 train_time:103216ms step_avg:87.99ms +step:1174/1680 train_time:103305ms step_avg:87.99ms +step:1175/1680 train_time:103395ms step_avg:88.00ms +step:1176/1680 train_time:103484ms step_avg:88.00ms +step:1177/1680 train_time:103573ms step_avg:88.00ms +step:1178/1680 train_time:103663ms step_avg:88.00ms +step:1179/1680 train_time:103751ms step_avg:88.00ms +step:1180/1680 train_time:103841ms step_avg:88.00ms +step:1181/1680 train_time:103929ms step_avg:88.00ms +step:1182/1680 train_time:104019ms step_avg:88.00ms +step:1183/1680 train_time:104108ms step_avg:88.00ms +step:1184/1680 train_time:104198ms step_avg:88.01ms +step:1185/1680 train_time:104287ms step_avg:88.01ms +step:1186/1680 train_time:104377ms step_avg:88.01ms +step:1187/1680 train_time:104466ms step_avg:88.01ms +step:1188/1680 train_time:104555ms step_avg:88.01ms +step:1189/1680 train_time:104645ms step_avg:88.01ms +step:1190/1680 train_time:104734ms step_avg:88.01ms +step:1191/1680 train_time:104824ms step_avg:88.01ms +step:1192/1680 train_time:104913ms step_avg:88.01ms +step:1193/1680 train_time:105002ms step_avg:88.02ms +step:1194/1680 train_time:105092ms step_avg:88.02ms +step:1195/1680 train_time:105182ms step_avg:88.02ms +step:1196/1680 train_time:105271ms step_avg:88.02ms +step:1197/1680 train_time:105360ms step_avg:88.02ms +step:1198/1680 train_time:105449ms step_avg:88.02ms +step:1199/1680 train_time:105538ms step_avg:88.02ms +step:1200/1680 train_time:105627ms step_avg:88.02ms +step:1201/1680 train_time:105715ms step_avg:88.02ms +step:1202/1680 train_time:105805ms step_avg:88.02ms +step:1203/1680 train_time:105894ms step_avg:88.02ms +step:1204/1680 train_time:105984ms step_avg:88.03ms +step:1205/1680 train_time:106074ms step_avg:88.03ms +step:1206/1680 train_time:106163ms step_avg:88.03ms +step:1207/1680 train_time:106253ms step_avg:88.03ms +step:1208/1680 train_time:106342ms step_avg:88.03ms +step:1209/1680 train_time:106431ms step_avg:88.03ms +step:1210/1680 train_time:106521ms step_avg:88.03ms +step:1211/1680 train_time:106610ms step_avg:88.03ms +step:1212/1680 train_time:106699ms step_avg:88.04ms +step:1213/1680 train_time:106788ms step_avg:88.04ms +step:1214/1680 train_time:106878ms step_avg:88.04ms +step:1215/1680 train_time:106967ms step_avg:88.04ms 
+step:1216/1680 train_time:107055ms step_avg:88.04ms +step:1217/1680 train_time:107144ms step_avg:88.04ms +step:1218/1680 train_time:107233ms step_avg:88.04ms +step:1219/1680 train_time:107324ms step_avg:88.04ms +step:1220/1680 train_time:107414ms step_avg:88.04ms +step:1221/1680 train_time:107502ms step_avg:88.04ms +step:1222/1680 train_time:107591ms step_avg:88.05ms +step:1223/1680 train_time:107681ms step_avg:88.05ms +step:1224/1680 train_time:107769ms step_avg:88.05ms +step:1225/1680 train_time:107858ms step_avg:88.05ms +step:1226/1680 train_time:107948ms step_avg:88.05ms +step:1227/1680 train_time:108037ms step_avg:88.05ms +step:1228/1680 train_time:108126ms step_avg:88.05ms +step:1229/1680 train_time:108216ms step_avg:88.05ms +step:1230/1680 train_time:108305ms step_avg:88.05ms +step:1231/1680 train_time:108395ms step_avg:88.05ms +step:1232/1680 train_time:108484ms step_avg:88.06ms +step:1233/1680 train_time:108574ms step_avg:88.06ms +step:1234/1680 train_time:108662ms step_avg:88.06ms +step:1235/1680 train_time:108751ms step_avg:88.06ms +step:1236/1680 train_time:108840ms step_avg:88.06ms +step:1237/1680 train_time:108929ms step_avg:88.06ms +step:1238/1680 train_time:109018ms step_avg:88.06ms +step:1239/1680 train_time:109107ms step_avg:88.06ms +step:1240/1680 train_time:109196ms step_avg:88.06ms +step:1241/1680 train_time:109285ms step_avg:88.06ms +step:1242/1680 train_time:109374ms step_avg:88.06ms +step:1243/1680 train_time:109463ms step_avg:88.06ms +step:1244/1680 train_time:109552ms step_avg:88.06ms +step:1245/1680 train_time:109643ms step_avg:88.07ms +step:1246/1680 train_time:109733ms step_avg:88.07ms +step:1247/1680 train_time:109821ms step_avg:88.07ms +step:1248/1680 train_time:109910ms step_avg:88.07ms +step:1249/1680 train_time:109999ms step_avg:88.07ms +step:1250/1680 train_time:110089ms step_avg:88.07ms +step:1250/1680 val_loss:3.3741 train_time:110180ms step_avg:88.14ms +step:1251/1680 train_time:110199ms step_avg:88.09ms +step:1252/1680 train_time:110274ms step_avg:88.08ms +step:1253/1680 train_time:110368ms step_avg:88.08ms +step:1254/1680 train_time:110459ms step_avg:88.09ms +step:1255/1680 train_time:110547ms step_avg:88.09ms +step:1256/1680 train_time:110635ms step_avg:88.09ms +step:1257/1680 train_time:110723ms step_avg:88.09ms +step:1258/1680 train_time:110811ms step_avg:88.08ms +step:1259/1680 train_time:110899ms step_avg:88.08ms +step:1260/1680 train_time:110987ms step_avg:88.08ms +step:1261/1680 train_time:111075ms step_avg:88.08ms +step:1262/1680 train_time:111165ms step_avg:88.09ms +step:1263/1680 train_time:111255ms step_avg:88.09ms +step:1264/1680 train_time:111347ms step_avg:88.09ms +step:1265/1680 train_time:111437ms step_avg:88.09ms +step:1266/1680 train_time:111526ms step_avg:88.09ms +step:1267/1680 train_time:111615ms step_avg:88.09ms +step:1268/1680 train_time:111703ms step_avg:88.09ms +step:1269/1680 train_time:111792ms step_avg:88.09ms +step:1270/1680 train_time:111880ms step_avg:88.09ms +step:1271/1680 train_time:111969ms step_avg:88.09ms +step:1272/1680 train_time:112057ms step_avg:88.09ms +step:1273/1680 train_time:112146ms step_avg:88.10ms +step:1274/1680 train_time:112237ms step_avg:88.10ms +step:1275/1680 train_time:112329ms step_avg:88.10ms +step:1276/1680 train_time:112419ms step_avg:88.10ms +step:1277/1680 train_time:112508ms step_avg:88.10ms +step:1278/1680 train_time:112597ms step_avg:88.10ms +step:1279/1680 train_time:112686ms step_avg:88.10ms +step:1280/1680 train_time:112775ms step_avg:88.11ms +step:1281/1680 train_time:112863ms 
step_avg:88.11ms +step:1282/1680 train_time:112951ms step_avg:88.11ms +step:1283/1680 train_time:113041ms step_avg:88.11ms +step:1284/1680 train_time:113131ms step_avg:88.11ms +step:1285/1680 train_time:113221ms step_avg:88.11ms +step:1286/1680 train_time:113312ms step_avg:88.11ms +step:1287/1680 train_time:113403ms step_avg:88.11ms +step:1288/1680 train_time:113491ms step_avg:88.11ms +step:1289/1680 train_time:113581ms step_avg:88.12ms +step:1290/1680 train_time:113670ms step_avg:88.12ms +step:1291/1680 train_time:113759ms step_avg:88.12ms +step:1292/1680 train_time:113847ms step_avg:88.12ms +step:1293/1680 train_time:113936ms step_avg:88.12ms +step:1294/1680 train_time:114024ms step_avg:88.12ms +step:1295/1680 train_time:114113ms step_avg:88.12ms +step:1296/1680 train_time:114202ms step_avg:88.12ms +step:1297/1680 train_time:114292ms step_avg:88.12ms +step:1298/1680 train_time:114381ms step_avg:88.12ms +step:1299/1680 train_time:114470ms step_avg:88.12ms +step:1300/1680 train_time:114560ms step_avg:88.12ms +step:1301/1680 train_time:114649ms step_avg:88.12ms +step:1302/1680 train_time:114739ms step_avg:88.13ms +step:1303/1680 train_time:114829ms step_avg:88.13ms +step:1304/1680 train_time:114918ms step_avg:88.13ms +step:1305/1680 train_time:115006ms step_avg:88.13ms +step:1306/1680 train_time:115095ms step_avg:88.13ms +step:1307/1680 train_time:115184ms step_avg:88.13ms +step:1308/1680 train_time:115275ms step_avg:88.13ms +step:1309/1680 train_time:115364ms step_avg:88.13ms +step:1310/1680 train_time:115454ms step_avg:88.13ms +step:1311/1680 train_time:115543ms step_avg:88.13ms +step:1312/1680 train_time:115632ms step_avg:88.13ms +step:1313/1680 train_time:115721ms step_avg:88.13ms +step:1314/1680 train_time:115810ms step_avg:88.14ms +step:1315/1680 train_time:115899ms step_avg:88.14ms +step:1316/1680 train_time:115988ms step_avg:88.14ms +step:1317/1680 train_time:116078ms step_avg:88.14ms +step:1318/1680 train_time:116167ms step_avg:88.14ms +step:1319/1680 train_time:116256ms step_avg:88.14ms +step:1320/1680 train_time:116345ms step_avg:88.14ms +step:1321/1680 train_time:116435ms step_avg:88.14ms +step:1322/1680 train_time:116525ms step_avg:88.14ms +step:1323/1680 train_time:116616ms step_avg:88.14ms +step:1324/1680 train_time:116705ms step_avg:88.15ms +step:1325/1680 train_time:116795ms step_avg:88.15ms +step:1326/1680 train_time:116882ms step_avg:88.15ms +step:1327/1680 train_time:116971ms step_avg:88.15ms +step:1328/1680 train_time:117060ms step_avg:88.15ms +step:1329/1680 train_time:117150ms step_avg:88.15ms +step:1330/1680 train_time:117239ms step_avg:88.15ms +step:1331/1680 train_time:117329ms step_avg:88.15ms +step:1332/1680 train_time:117419ms step_avg:88.15ms +step:1333/1680 train_time:117508ms step_avg:88.15ms +step:1334/1680 train_time:117598ms step_avg:88.15ms +step:1335/1680 train_time:117687ms step_avg:88.16ms +step:1336/1680 train_time:117776ms step_avg:88.16ms +step:1337/1680 train_time:117865ms step_avg:88.16ms +step:1338/1680 train_time:117955ms step_avg:88.16ms +step:1339/1680 train_time:118044ms step_avg:88.16ms +step:1340/1680 train_time:118133ms step_avg:88.16ms +step:1341/1680 train_time:118222ms step_avg:88.16ms +step:1342/1680 train_time:118311ms step_avg:88.16ms +step:1343/1680 train_time:118400ms step_avg:88.16ms +step:1344/1680 train_time:118490ms step_avg:88.16ms +step:1345/1680 train_time:118580ms step_avg:88.16ms +step:1346/1680 train_time:118670ms step_avg:88.17ms +step:1347/1680 train_time:118760ms step_avg:88.17ms +step:1348/1680 train_time:118848ms 
step_avg:88.17ms +step:1349/1680 train_time:118938ms step_avg:88.17ms +step:1350/1680 train_time:119028ms step_avg:88.17ms +step:1351/1680 train_time:119119ms step_avg:88.17ms +step:1352/1680 train_time:119208ms step_avg:88.17ms +step:1353/1680 train_time:119296ms step_avg:88.17ms +step:1354/1680 train_time:119387ms step_avg:88.17ms +step:1355/1680 train_time:119475ms step_avg:88.17ms +step:1356/1680 train_time:119565ms step_avg:88.17ms +step:1357/1680 train_time:119655ms step_avg:88.18ms +step:1358/1680 train_time:119744ms step_avg:88.18ms +step:1359/1680 train_time:119834ms step_avg:88.18ms +step:1360/1680 train_time:119923ms step_avg:88.18ms +step:1361/1680 train_time:120012ms step_avg:88.18ms +step:1362/1680 train_time:120101ms step_avg:88.18ms +step:1363/1680 train_time:120190ms step_avg:88.18ms +step:1364/1680 train_time:120279ms step_avg:88.18ms +step:1365/1680 train_time:120368ms step_avg:88.18ms +step:1366/1680 train_time:120457ms step_avg:88.18ms +step:1367/1680 train_time:120547ms step_avg:88.18ms +step:1368/1680 train_time:120636ms step_avg:88.18ms +step:1369/1680 train_time:120726ms step_avg:88.19ms +step:1370/1680 train_time:120815ms step_avg:88.19ms +step:1371/1680 train_time:120904ms step_avg:88.19ms +step:1372/1680 train_time:120993ms step_avg:88.19ms +step:1373/1680 train_time:121083ms step_avg:88.19ms +step:1374/1680 train_time:121172ms step_avg:88.19ms +step:1375/1680 train_time:121262ms step_avg:88.19ms +step:1375/1680 val_loss:3.3393 train_time:121352ms step_avg:88.26ms +step:1376/1680 train_time:121371ms step_avg:88.21ms +step:1377/1680 train_time:121443ms step_avg:88.19ms +step:1378/1680 train_time:121535ms step_avg:88.20ms +step:1379/1680 train_time:121624ms step_avg:88.20ms +step:1380/1680 train_time:121712ms step_avg:88.20ms +step:1381/1680 train_time:121801ms step_avg:88.20ms +step:1382/1680 train_time:121889ms step_avg:88.20ms +step:1383/1680 train_time:121977ms step_avg:88.20ms +step:1384/1680 train_time:122065ms step_avg:88.20ms +step:1385/1680 train_time:122156ms step_avg:88.20ms +step:1386/1680 train_time:122245ms step_avg:88.20ms +step:1387/1680 train_time:122335ms step_avg:88.20ms +step:1388/1680 train_time:122426ms step_avg:88.20ms +step:1389/1680 train_time:122516ms step_avg:88.20ms +step:1390/1680 train_time:122608ms step_avg:88.21ms +step:1391/1680 train_time:122697ms step_avg:88.21ms +step:1392/1680 train_time:122786ms step_avg:88.21ms +step:1393/1680 train_time:122874ms step_avg:88.21ms +step:1394/1680 train_time:122963ms step_avg:88.21ms +step:1395/1680 train_time:123051ms step_avg:88.21ms +step:1396/1680 train_time:123140ms step_avg:88.21ms +step:1397/1680 train_time:123230ms step_avg:88.21ms +step:1398/1680 train_time:123319ms step_avg:88.21ms +step:1399/1680 train_time:123408ms step_avg:88.21ms +step:1400/1680 train_time:123498ms step_avg:88.21ms +step:1401/1680 train_time:123589ms step_avg:88.21ms +step:1402/1680 train_time:123678ms step_avg:88.22ms +step:1403/1680 train_time:123768ms step_avg:88.22ms +step:1404/1680 train_time:123856ms step_avg:88.22ms +step:1405/1680 train_time:123944ms step_avg:88.22ms +step:1406/1680 train_time:124033ms step_avg:88.22ms +step:1407/1680 train_time:124121ms step_avg:88.22ms +step:1408/1680 train_time:124210ms step_avg:88.22ms +step:1409/1680 train_time:124299ms step_avg:88.22ms +step:1410/1680 train_time:124390ms step_avg:88.22ms +step:1411/1680 train_time:124481ms step_avg:88.22ms +step:1412/1680 train_time:124573ms step_avg:88.22ms +step:1413/1680 train_time:124663ms step_avg:88.23ms +step:1414/1680 
train_time:124753ms step_avg:88.23ms +step:1415/1680 train_time:124842ms step_avg:88.23ms +step:1416/1680 train_time:124931ms step_avg:88.23ms +step:1417/1680 train_time:125020ms step_avg:88.23ms +step:1418/1680 train_time:125109ms step_avg:88.23ms +step:1419/1680 train_time:125198ms step_avg:88.23ms +step:1420/1680 train_time:125288ms step_avg:88.23ms +step:1421/1680 train_time:125378ms step_avg:88.23ms +step:1422/1680 train_time:125468ms step_avg:88.23ms +step:1423/1680 train_time:125557ms step_avg:88.23ms +step:1424/1680 train_time:125647ms step_avg:88.24ms +step:1425/1680 train_time:125736ms step_avg:88.24ms +step:1426/1680 train_time:125825ms step_avg:88.24ms +step:1427/1680 train_time:125914ms step_avg:88.24ms +step:1428/1680 train_time:126004ms step_avg:88.24ms +step:1429/1680 train_time:126094ms step_avg:88.24ms +step:1430/1680 train_time:126183ms step_avg:88.24ms +step:1431/1680 train_time:126273ms step_avg:88.24ms +step:1432/1680 train_time:126362ms step_avg:88.24ms +step:1433/1680 train_time:126451ms step_avg:88.24ms +step:1434/1680 train_time:126540ms step_avg:88.24ms +step:1435/1680 train_time:126630ms step_avg:88.24ms +step:1436/1680 train_time:126720ms step_avg:88.24ms +step:1437/1680 train_time:126809ms step_avg:88.25ms +step:1438/1680 train_time:126898ms step_avg:88.25ms +step:1439/1680 train_time:126988ms step_avg:88.25ms +step:1440/1680 train_time:127077ms step_avg:88.25ms +step:1441/1680 train_time:127167ms step_avg:88.25ms +step:1442/1680 train_time:127256ms step_avg:88.25ms +step:1443/1680 train_time:127344ms step_avg:88.25ms +step:1444/1680 train_time:127434ms step_avg:88.25ms +step:1445/1680 train_time:127523ms step_avg:88.25ms +step:1446/1680 train_time:127613ms step_avg:88.25ms +step:1447/1680 train_time:127703ms step_avg:88.25ms +step:1448/1680 train_time:127793ms step_avg:88.26ms +step:1449/1680 train_time:127883ms step_avg:88.26ms +step:1450/1680 train_time:127972ms step_avg:88.26ms +step:1451/1680 train_time:128060ms step_avg:88.26ms +step:1452/1680 train_time:128150ms step_avg:88.26ms +step:1453/1680 train_time:128238ms step_avg:88.26ms +step:1454/1680 train_time:128327ms step_avg:88.26ms +step:1455/1680 train_time:128417ms step_avg:88.26ms +step:1456/1680 train_time:128506ms step_avg:88.26ms +step:1457/1680 train_time:128595ms step_avg:88.26ms +step:1458/1680 train_time:128685ms step_avg:88.26ms +step:1459/1680 train_time:128775ms step_avg:88.26ms +step:1460/1680 train_time:128864ms step_avg:88.26ms +step:1461/1680 train_time:128954ms step_avg:88.26ms +step:1462/1680 train_time:129044ms step_avg:88.27ms +step:1463/1680 train_time:129133ms step_avg:88.27ms +step:1464/1680 train_time:129222ms step_avg:88.27ms +step:1465/1680 train_time:129312ms step_avg:88.27ms +step:1466/1680 train_time:129401ms step_avg:88.27ms +step:1467/1680 train_time:129490ms step_avg:88.27ms +step:1468/1680 train_time:129580ms step_avg:88.27ms +step:1469/1680 train_time:129670ms step_avg:88.27ms +step:1470/1680 train_time:129758ms step_avg:88.27ms +step:1471/1680 train_time:129847ms step_avg:88.27ms +step:1472/1680 train_time:129936ms step_avg:88.27ms +step:1473/1680 train_time:130025ms step_avg:88.27ms +step:1474/1680 train_time:130114ms step_avg:88.27ms +step:1475/1680 train_time:130204ms step_avg:88.27ms +step:1476/1680 train_time:130293ms step_avg:88.27ms +step:1477/1680 train_time:130383ms step_avg:88.28ms +step:1478/1680 train_time:130472ms step_avg:88.28ms +step:1479/1680 train_time:130562ms step_avg:88.28ms +step:1480/1680 train_time:130651ms step_avg:88.28ms +step:1481/1680 
train_time:130740ms step_avg:88.28ms +step:1482/1680 train_time:130828ms step_avg:88.28ms +step:1483/1680 train_time:130918ms step_avg:88.28ms +step:1484/1680 train_time:131007ms step_avg:88.28ms +step:1485/1680 train_time:131096ms step_avg:88.28ms +step:1486/1680 train_time:131185ms step_avg:88.28ms +step:1487/1680 train_time:131274ms step_avg:88.28ms +step:1488/1680 train_time:131364ms step_avg:88.28ms +step:1489/1680 train_time:131454ms step_avg:88.28ms +step:1490/1680 train_time:131543ms step_avg:88.28ms +step:1491/1680 train_time:131633ms step_avg:88.29ms +step:1492/1680 train_time:131723ms step_avg:88.29ms +step:1493/1680 train_time:131813ms step_avg:88.29ms +step:1494/1680 train_time:131903ms step_avg:88.29ms +step:1495/1680 train_time:131992ms step_avg:88.29ms +step:1496/1680 train_time:132080ms step_avg:88.29ms +step:1497/1680 train_time:132169ms step_avg:88.29ms +step:1498/1680 train_time:132258ms step_avg:88.29ms +step:1499/1680 train_time:132348ms step_avg:88.29ms +step:1500/1680 train_time:132437ms step_avg:88.29ms +step:1500/1680 val_loss:3.3097 train_time:132529ms step_avg:88.35ms +step:1501/1680 train_time:132548ms step_avg:88.31ms +step:1502/1680 train_time:132623ms step_avg:88.30ms +step:1503/1680 train_time:132717ms step_avg:88.30ms +step:1504/1680 train_time:132807ms step_avg:88.30ms +step:1505/1680 train_time:132895ms step_avg:88.30ms +step:1506/1680 train_time:132983ms step_avg:88.30ms +step:1507/1680 train_time:133071ms step_avg:88.30ms +step:1508/1680 train_time:133159ms step_avg:88.30ms +step:1509/1680 train_time:133247ms step_avg:88.30ms +step:1510/1680 train_time:133335ms step_avg:88.30ms +step:1511/1680 train_time:133424ms step_avg:88.30ms +step:1512/1680 train_time:133515ms step_avg:88.30ms +step:1513/1680 train_time:133605ms step_avg:88.30ms +step:1514/1680 train_time:133698ms step_avg:88.31ms +step:1515/1680 train_time:133788ms step_avg:88.31ms +step:1516/1680 train_time:133877ms step_avg:88.31ms +step:1517/1680 train_time:133966ms step_avg:88.31ms +step:1518/1680 train_time:134055ms step_avg:88.31ms +step:1519/1680 train_time:134143ms step_avg:88.31ms +step:1520/1680 train_time:134231ms step_avg:88.31ms +step:1521/1680 train_time:134319ms step_avg:88.31ms +step:1522/1680 train_time:134408ms step_avg:88.31ms +step:1523/1680 train_time:134498ms step_avg:88.31ms +step:1524/1680 train_time:134588ms step_avg:88.31ms +step:1525/1680 train_time:134680ms step_avg:88.31ms +step:1526/1680 train_time:134770ms step_avg:88.32ms +step:1527/1680 train_time:134859ms step_avg:88.32ms +step:1528/1680 train_time:134948ms step_avg:88.32ms +step:1529/1680 train_time:135037ms step_avg:88.32ms +step:1530/1680 train_time:135125ms step_avg:88.32ms +step:1531/1680 train_time:135214ms step_avg:88.32ms +step:1532/1680 train_time:135302ms step_avg:88.32ms +step:1533/1680 train_time:135391ms step_avg:88.32ms +step:1534/1680 train_time:135481ms step_avg:88.32ms +step:1535/1680 train_time:135571ms step_avg:88.32ms +step:1536/1680 train_time:135661ms step_avg:88.32ms +step:1537/1680 train_time:135752ms step_avg:88.32ms +step:1538/1680 train_time:135842ms step_avg:88.32ms +step:1539/1680 train_time:135932ms step_avg:88.32ms +step:1540/1680 train_time:136021ms step_avg:88.33ms +step:1541/1680 train_time:136110ms step_avg:88.33ms +step:1542/1680 train_time:136198ms step_avg:88.33ms +step:1543/1680 train_time:136286ms step_avg:88.33ms +step:1544/1680 train_time:136375ms step_avg:88.33ms +step:1545/1680 train_time:136464ms step_avg:88.33ms +step:1546/1680 train_time:136554ms step_avg:88.33ms 
+step:1547/1680 train_time:136643ms step_avg:88.33ms +step:1548/1680 train_time:136733ms step_avg:88.33ms +step:1549/1680 train_time:136822ms step_avg:88.33ms +step:1550/1680 train_time:136911ms step_avg:88.33ms +step:1551/1680 train_time:137000ms step_avg:88.33ms +step:1552/1680 train_time:137088ms step_avg:88.33ms +step:1553/1680 train_time:137177ms step_avg:88.33ms +step:1554/1680 train_time:137265ms step_avg:88.33ms +step:1555/1680 train_time:137354ms step_avg:88.33ms +step:1556/1680 train_time:137443ms step_avg:88.33ms +step:1557/1680 train_time:137533ms step_avg:88.33ms +step:1558/1680 train_time:137623ms step_avg:88.33ms +step:1559/1680 train_time:137712ms step_avg:88.33ms +step:1560/1680 train_time:137801ms step_avg:88.33ms +step:1561/1680 train_time:137892ms step_avg:88.34ms +step:1562/1680 train_time:137981ms step_avg:88.34ms +step:1563/1680 train_time:138069ms step_avg:88.34ms +step:1564/1680 train_time:138158ms step_avg:88.34ms +step:1565/1680 train_time:138247ms step_avg:88.34ms +step:1566/1680 train_time:138336ms step_avg:88.34ms +step:1567/1680 train_time:138425ms step_avg:88.34ms +step:1568/1680 train_time:138514ms step_avg:88.34ms +step:1569/1680 train_time:138603ms step_avg:88.34ms +step:1570/1680 train_time:138692ms step_avg:88.34ms +step:1571/1680 train_time:138782ms step_avg:88.34ms +step:1572/1680 train_time:138870ms step_avg:88.34ms +step:1573/1680 train_time:138960ms step_avg:88.34ms +step:1574/1680 train_time:139049ms step_avg:88.34ms +step:1575/1680 train_time:139139ms step_avg:88.34ms +step:1576/1680 train_time:139227ms step_avg:88.34ms +step:1577/1680 train_time:139317ms step_avg:88.34ms +step:1578/1680 train_time:139405ms step_avg:88.34ms +step:1579/1680 train_time:139494ms step_avg:88.34ms +step:1580/1680 train_time:139584ms step_avg:88.34ms +step:1581/1680 train_time:139674ms step_avg:88.35ms +step:1582/1680 train_time:139763ms step_avg:88.35ms +step:1583/1680 train_time:139852ms step_avg:88.35ms +step:1584/1680 train_time:139942ms step_avg:88.35ms +step:1585/1680 train_time:140032ms step_avg:88.35ms +step:1586/1680 train_time:140121ms step_avg:88.35ms +step:1587/1680 train_time:140210ms step_avg:88.35ms +step:1588/1680 train_time:140299ms step_avg:88.35ms +step:1589/1680 train_time:140388ms step_avg:88.35ms +step:1590/1680 train_time:140477ms step_avg:88.35ms +step:1591/1680 train_time:140567ms step_avg:88.35ms +step:1592/1680 train_time:140657ms step_avg:88.35ms +step:1593/1680 train_time:140746ms step_avg:88.35ms +step:1594/1680 train_time:140836ms step_avg:88.35ms +step:1595/1680 train_time:140926ms step_avg:88.35ms +step:1596/1680 train_time:141015ms step_avg:88.35ms +step:1597/1680 train_time:141104ms step_avg:88.36ms +step:1598/1680 train_time:141193ms step_avg:88.36ms +step:1599/1680 train_time:141282ms step_avg:88.36ms +step:1600/1680 train_time:141371ms step_avg:88.36ms +step:1601/1680 train_time:141461ms step_avg:88.36ms +step:1602/1680 train_time:141550ms step_avg:88.36ms +step:1603/1680 train_time:141640ms step_avg:88.36ms +step:1604/1680 train_time:141729ms step_avg:88.36ms +step:1605/1680 train_time:141819ms step_avg:88.36ms +step:1606/1680 train_time:141909ms step_avg:88.36ms +step:1607/1680 train_time:141998ms step_avg:88.36ms +step:1608/1680 train_time:142087ms step_avg:88.36ms +step:1609/1680 train_time:142177ms step_avg:88.36ms +step:1610/1680 train_time:142266ms step_avg:88.36ms +step:1611/1680 train_time:142355ms step_avg:88.36ms +step:1612/1680 train_time:142444ms step_avg:88.36ms +step:1613/1680 train_time:142534ms step_avg:88.37ms 
+step:1614/1680 train_time:142624ms step_avg:88.37ms +step:1615/1680 train_time:142714ms step_avg:88.37ms +step:1616/1680 train_time:142802ms step_avg:88.37ms +step:1617/1680 train_time:142892ms step_avg:88.37ms +step:1618/1680 train_time:142983ms step_avg:88.37ms +step:1619/1680 train_time:143072ms step_avg:88.37ms +step:1620/1680 train_time:143161ms step_avg:88.37ms +step:1621/1680 train_time:143251ms step_avg:88.37ms +step:1622/1680 train_time:143340ms step_avg:88.37ms +step:1623/1680 train_time:143430ms step_avg:88.37ms +step:1624/1680 train_time:143520ms step_avg:88.37ms +step:1625/1680 train_time:143609ms step_avg:88.38ms +step:1625/1680 val_loss:3.2859 train_time:143700ms step_avg:88.43ms +step:1626/1680 train_time:143719ms step_avg:88.39ms +step:1627/1680 train_time:143792ms step_avg:88.38ms +step:1628/1680 train_time:143884ms step_avg:88.38ms +step:1629/1680 train_time:143975ms step_avg:88.38ms +step:1630/1680 train_time:144064ms step_avg:88.38ms +step:1631/1680 train_time:144152ms step_avg:88.38ms +step:1632/1680 train_time:144240ms step_avg:88.38ms +step:1633/1680 train_time:144328ms step_avg:88.38ms +step:1634/1680 train_time:144416ms step_avg:88.38ms +step:1635/1680 train_time:144504ms step_avg:88.38ms +step:1636/1680 train_time:144593ms step_avg:88.38ms +step:1637/1680 train_time:144684ms step_avg:88.38ms +step:1638/1680 train_time:144776ms step_avg:88.39ms +step:1639/1680 train_time:144867ms step_avg:88.39ms +step:1640/1680 train_time:144958ms step_avg:88.39ms +step:1641/1680 train_time:145048ms step_avg:88.39ms +step:1642/1680 train_time:145137ms step_avg:88.39ms +step:1643/1680 train_time:145226ms step_avg:88.39ms +step:1644/1680 train_time:145314ms step_avg:88.39ms +step:1645/1680 train_time:145402ms step_avg:88.39ms +step:1646/1680 train_time:145491ms step_avg:88.39ms +step:1647/1680 train_time:145580ms step_avg:88.39ms +step:1648/1680 train_time:145670ms step_avg:88.39ms +step:1649/1680 train_time:145760ms step_avg:88.39ms +step:1650/1680 train_time:145850ms step_avg:88.39ms +step:1651/1680 train_time:145942ms step_avg:88.40ms +step:1652/1680 train_time:146031ms step_avg:88.40ms +step:1653/1680 train_time:146120ms step_avg:88.40ms +step:1654/1680 train_time:146208ms step_avg:88.40ms +step:1655/1680 train_time:146297ms step_avg:88.40ms +step:1656/1680 train_time:146386ms step_avg:88.40ms +step:1657/1680 train_time:146476ms step_avg:88.40ms +step:1658/1680 train_time:146565ms step_avg:88.40ms +step:1659/1680 train_time:146655ms step_avg:88.40ms +step:1660/1680 train_time:146744ms step_avg:88.40ms +step:1661/1680 train_time:146834ms step_avg:88.40ms +step:1662/1680 train_time:146924ms step_avg:88.40ms +step:1663/1680 train_time:147014ms step_avg:88.40ms +step:1664/1680 train_time:147103ms step_avg:88.40ms +step:1665/1680 train_time:147192ms step_avg:88.40ms +step:1666/1680 train_time:147281ms step_avg:88.40ms +step:1667/1680 train_time:147369ms step_avg:88.40ms +step:1668/1680 train_time:147458ms step_avg:88.40ms +step:1669/1680 train_time:147547ms step_avg:88.40ms +step:1670/1680 train_time:147636ms step_avg:88.40ms +step:1671/1680 train_time:147726ms step_avg:88.41ms +step:1672/1680 train_time:147816ms step_avg:88.41ms +step:1673/1680 train_time:147906ms step_avg:88.41ms +step:1674/1680 train_time:147996ms step_avg:88.41ms +step:1675/1680 train_time:148086ms step_avg:88.41ms +step:1676/1680 train_time:148176ms step_avg:88.41ms +step:1677/1680 train_time:148265ms step_avg:88.41ms +step:1678/1680 train_time:148354ms step_avg:88.41ms +step:1679/1680 train_time:148442ms 
step_avg:88.41ms +step:1680/1680 train_time:148530ms step_avg:88.41ms +step:1680/1680 val_loss:3.2752 train_time:148621ms step_avg:88.47ms +peak memory allocated: 30760 MiB reserved: 45994 MiB diff --git a/records/092725_BF16CE/7fa6fb13-cac4-46c4-bf34-83c8290e17f0.txt b/records/092725_BF16CE/7fa6fb13-cac4-46c4-bf34-83c8290e17f0.txt new file mode 100644 index 000000000..426cd55dc --- /dev/null +++ b/records/092725_BF16CE/7fa6fb13-cac4-46c4-bf34-83c8290e17f0.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
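+ # Only one triangle of C = A @ A.T is computed directly (off-triangle blocks returned early above); the mirrored block is written across the diagonal below.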
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+ In practice, however, small 1D params are handled efficiently here: + NS approximately normalizes the magnitude of the grad, + and this hyper-optimized class executes faster than the current Adam impl for small params. + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized to enable + (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param + 7 padding params) + 2. reduce scatter attn_gate (10 params + 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params + 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive params of type attn reshape before NS + 8. wait on step 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size()==8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for the resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1,10,16,16] + assert len(params)==sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx+size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
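+ # High-level flow: (1) for each param group, stack grads (padded to a multiple of world_size) and launch an async reduce-scatter so each rank owns one shard; + # (2) on the local shard, apply momentum and weight decay, run batched Newton-Schulz (attn params are reshaped to per-matrix Q,K,V,O first), add the scaled update, then launch an async all-gather of the updated params; + # (3) wait on each all-gather and copy the results back into the original parameter tensors.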
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
+ updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + original_shape = batched_update_grads.shape + # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O + module_idx = start_idx if start_idx < len(params) else len(params) - 1 + if params[module_idx].module == 'attn': + batched_update_grads = batched_update_grads.view(-1, original_shape[-2], original_shape[-1] // 4) + v_chunk = newton_schulz_triton(batched_update_grads).reshape(original_shape) + + # Apply the batched updates back to this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + + class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = 
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
+def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor):
+    assert cos.size(0) >= x_BTHD.size(-3)
+    cos, sin = (
+        cos[None, : x_BTHD.size(-3), None, :],
+        sin[None, : x_BTHD.size(-3), None, :],
+    )
+    x1, x2 = x_BTHD.chunk(2, dim=-1)
+    y1 = x1 * cos + x2 * sin
+    y2 = x1 * (-sin) + x2 * cos
+    return torch.cat((y1, y2), 3)
+
+@dataclass
+class AttnArgs:
+    ve: torch.Tensor
+    sa_lambdas: torch.Tensor
+    seqlens: torch.Tensor
+    bm_size: int
+    cos: torch.Tensor
+    sin: torch.Tensor
+    attn_scale: float
+
+class CausalSelfAttention(nn.Module):
+    def __init__(self, dim: int, head_dim: int, num_heads: int):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.dim = dim
+        self.hdim = num_heads * head_dim
+
+        assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim"
+        std = 0.5 * (self.dim ** -0.5)
+        bound = (3 ** 0.5) * std  # improved init scale by @YouJiacheng
+        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
+        # https://x.com/hi_tysam/status/1879699187107033311
+        # make matrices the same shape as MLP to enable batched call in optimizer
+        self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4))
+        # label module to enable custom optimizer sizing
+        self.qkvo_w.module = 'attn'
+        with torch.no_grad():
+            self.qkvo_w.view(4, self.hdim, self.dim)[:3].uniform_(-bound, bound)  # init QKV weights
+            self.qkvo_w.view(4, self.hdim, self.dim)[3].zero_()  # init output weights to zero
+
+        # sparse gated attention to enable context based no-op by @classiclarryd
+        self.attn_gate = CastedLinear(12, num_heads)
+        # label module to enable custom optimizer sizing
+        self.attn_gate.weight.module = 'attn_gate'
+        self.attn_gate.weight.detach().zero_()
+
+    def forward(self, x: Tensor, attn_args: AttnArgs):
+        B, T = x.size(0), x.size(1)  # batch size, sequence length
+        assert B == 1, "varlen sequences require B == 1"
+        assert T % 16 == 0
+        # unpack attention args
+        cos, sin = attn_args.cos, attn_args.sin
+        ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas
+        seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size
+
+        q, k, v = F.linear(x, self.qkvo_w.view(4, self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
+        q, k = norm(q), norm(k)  # QK norm @Grad62304977
+        q, k = rotary(q, cos, sin), rotary(k, cos, sin)
+        if ve is not None:
+            v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v)  # @KoszarskyB & @Grad62304977
+        else:  # skip mid-layers token value embeddings by @YouJiacheng
+            v = sa_lambdas[0] * v
+
+        max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size))
+
+        # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng
+        y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len,
+                                   causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0))
+        y = y.view(B, T, self.num_heads, self.head_dim)
+        y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1)
+        y = y.contiguous().view(B, T, self.num_heads * self.head_dim)  # re-assemble all head outputs side by side
+        y = F.linear(y, self.qkvo_w.view(4, self.hdim, self.dim)[3].type_as(y))
+        return y
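+# Illustrative sketch (editor's addition, never called): the merged QKVO
+# parameter above stores all four projections in one (hdim, dim*4) tensor so
+# the optimizer can batch them with the MLP matrices; viewing it as
+# (4, hdim, dim) recovers the individual projection matrices.
+def _qkvo_slices_sketch(qkvo_w: Tensor, hdim: int, dim: int) -> tuple[Tensor, Tensor, Tensor, Tensor]:
+    w = qkvo_w.view(4, hdim, dim)
+    return w[0], w[1], w[2], w[3]  # W_q, W_k, W_v, W_o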
+class MLP(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        hdim = 4 * dim
+        # make matrices the same shape to enable batched call in optimizer
+        self.c_fc = nn.Parameter(torch.empty(dim, hdim))
+        self.c_proj = nn.Parameter(torch.empty(dim, hdim))
+        # label modules to enable custom optimizer sizing
+        self.c_fc.module = 'mlp'
+        self.c_proj.module = 'mlp'
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std  # improved init scale by @YouJiacheng
+        with torch.no_grad():
+            self.c_fc.uniform_(-bound, bound)
+            self.c_proj.zero_()  # zero init suggested by @Grad62304977
+
+    def forward(self, x: Tensor):
+        x = F.linear(x, self.c_fc.T.type_as(x))
+        x = F.relu(x).square()  # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
+        x = F.linear(x, self.c_proj.type_as(x))
+        return x
+
+class Block(nn.Module):
+    def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int):
+        super().__init__()
+        # skip attention of blocks.0 and blocks.7 (the 8th layer) by @YouJiacheng
+        self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None
+        # skip the MLP block of the first layer by @EmelyanenkoK
+        self.mlp = MLP(dim) if layer_idx != 0 else None
+
+    def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs):
+        x = lambdas[0] * x + lambdas[1] * x0
+        if self.attn is not None:
+            x = x + self.attn(norm(x), attn_args)
+        if self.mlp is not None:
+            x = x + self.mlp(norm(x))
+        return x
+
+# -----------------------------------------------------------------------------
+# The main model
+
+def next_multiple_of_n(v: float | int, *, n: int):
+    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
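+# Illustrative sketch (editor's addition, never called): next_multiple_of_n
+# rounds up to the nearest multiple of n, e.g. the GPT-2 vocab of 50257 tokens
+# is padded to 50304 (= 128 * 393) by the GPT constructor below.
+def _vocab_padding_sketch() -> int:
+    assert next_multiple_of_n(50257, n=128) == 50304
+    return next_multiple_of_n(50257, n=128)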
+class GPT(nn.Module):
+    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int):
+        super().__init__()
+        vocab_size = next_multiple_of_n(vocab_size, n=128)
+        self.embed = nn.Embedding(vocab_size, model_dim)
+        self.smear_gate = CastedLinear(12, 1)
+        self.smear_gate.weight.detach().zero_()
+        # label modules to enable custom optimizer sizing
+        self.smear_gate.weight.module = 'smear_gate'
+        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
+        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
+        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
+        self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)])
+        self.yarn = Yarn(head_dim, max_seq_len)
+        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
+        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
+        use_fp8 = not os.environ.get("DISABLE_FP8", False)
+        self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448)
+        self.lm_head.weight.detach().zero_()  # @Grad62304977
+        # Add learnable skip connection weights for decoder layers
+        assert num_layers % 2 == 0
+        pad = (-num_layers * 6) % dist.get_world_size()
+        self.scalars = nn.Parameter(
+            torch.cat(
+                [
+                    -1.5 * torch.ones(num_layers),  # skip_weights -> σ(-1.5) ≈ 0.18
+                    *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)],  # block lambdas
+                    *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)],  # SA lambdas
+                    torch.zeros(num_layers),  # extra zero params for smear_lambda
+                    torch.ones(pad),
+                ]
+            )
+        )
+        # set learning rates
+        for param in self.embed.parameters():
+            param.lr_mul = 75.
+        for param in self.value_embeds.parameters():
+            param.lr_mul = 75.
+        self.lm_head.weight.lr_mul = 1.0
+        self.scalars.lr_mul = 5.0
+
+    def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int):
+        assert input_seq.ndim == 1
+
+        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
+        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        short_bm = ws_short * args.block_size
+        long_bm = ws_long * args.block_size
+        bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = self.embed(input_seq)
+
+        # smear token embed forward 1 position @classiclarryd
+        smear_lambda = self.scalars[5 * len(self.blocks)]
+        smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)]))
+        x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]])
+        x = x0 = norm(x[None])
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        # skip layer zero
+        for i in range(1, len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                cos=self.yarn.cos,
+                sin=self.yarn.sin,
+                attn_scale=self.yarn.attn_scale,
+            )
+            if i >= n and i < 11:
+                gate = torch.sigmoid(skip_weights[i - n])  # in (0, 1)
+                x = x + gate * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x)
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0)
+        logits_for_loss = logits.float() if not self.training else logits
+        loss = F.cross_entropy(
+            logits_for_loss.view(-1, logits_for_loss.size(-1)),
+            target_seq,
+            reduction="sum" if self.training else "mean",
+        )
+        return loss
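+# Illustrative sketch (editor's addition, never called): the shifted softcap
+# above is equivalent to 15 * tanh(logits / 15) + 15, via the identity
+# 2*sigmoid(2x) = tanh(x) + 1, so logits are squashed into (0, 30).
+def _softcap_sketch(logits: Tensor) -> Tensor:
+    capped = torch.sigmoid(logits / 7.5) * 30.0
+    assert torch.allclose(capped, 15.0 * torch.tanh(logits / 15.0) + 15.0, atol=1e-3)
+    return capped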
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32)  # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2])  # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True)  # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy())  # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan first 4 million tokens, then kick off async thread to scan the rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading next shard and indexing bos tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
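+# Illustrative sketch (editor's addition, never called): a writer matching the
+# .bin shard layout that _load_data_shard above expects -- a 256-int32 header
+# (magic 20240520, version 1, token count) followed by raw uint16 tokens. The
+# function name and use of a writer at all are assumptions for exposition.
+def _write_shard_sketch(path: str, tokens: Tensor) -> None:
+    header = torch.zeros(256, dtype=torch.int32)
+    header[0], header[1], header[2] = 20240520, 1, tokens.numel()
+    with open(path, "wb") as f:
+        f.write(header.numpy().tobytes())               # 256 * 4 header bytes
+        f.write(tokens.to(torch.uint16).numpy().tobytes())  # 2 bytes per token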
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning of Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files)  # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0  # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128)  # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1  # the last end index overshoots _inputs by the one extra token reserved for the _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens):  # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local, )
+            _targets = buf[1:].view(num_tokens_local, )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True),
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
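+# Illustrative sketch (editor's addition, never called): how packed document
+# lengths become the cumulative-sequence-length tensor that
+# flash_attn_varlen_func consumes via cu_seqlens. Toy example with two
+# documents of lengths 3 and 5; the generator above additionally pads this
+# tensor out to max_num_docs entries so tensor shapes stay static.
+def _cu_seqlens_sketch() -> Tensor:
+    doc_lengths = torch.tensor([3, 5], dtype=torch.int32)
+    cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), doc_lengths.cumsum(0, dtype=torch.int32)])
+    assert cu_seqlens.tolist() == [0, 3, 8]  # doc i spans tokens [cu[i], cu[i+1])
+    return cu_seqlens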
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin"  # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin"  # input .bin to eval validation loss on
+    val_tokens: int = 10485760  # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640  # number of iterations to run
+    iteration_extension: int = 40  # number of iterations to continue training at final cooldown and window size
+    cooldown_frac: float = 0.5  # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = 125  # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13  # increase final validation ws, used for YaRN extension and short window size @classiclarryd
+    ws_long_validate: int = 20  # extend long windows out even further
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0)  # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess  # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, (nn.Embedding, nn.Linear)):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+gate_params = [p for n, p in model.named_parameters() if "gate" in n]
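+# Illustrative sketch (editor's addition, never called): sanity-check that the
+# five parameter groups above partition model.parameters() exactly, i.e. no
+# parameter lands in two optimizers and none is left unoptimized.
+def _param_partition_sketch() -> None:
+    groups = [hidden_matrix_params, embed_params, scalar_params, head_params, gate_params]
+    ids = [id(p) for g in groups for p in g]
+    assert len(ids) == len(set(ids)), "parameter groups overlap"
+    assert set(ids) == {id(p) for p in model.parameters()}, "some parameters were missed"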
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = min(0.9999, step / args.num_iterations)
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+def get_ws(step: int):
+    if step == args.num_iterations + args.iteration_extension:
+        return args.ws_validate // 2, args.ws_validate
+    x = min(step / (1 + args.num_iterations), 0.9999)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers])  # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws_long = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws_long = args.ws_schedule[step % len(args.ws_schedule)]  # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws_long > ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    elif new_ws_long < ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+# restore the initial state
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+model.yarn.reset()
+del initial_state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+
+train_steps = args.num_iterations + args.iteration_extension
+ws_short, ws_long = get_ws(0)
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    new_ws_short, new_ws_long = get_ws(step)
+    if new_ws_long != ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+    ws_short, ws_long = new_ws_short, new_ws_long
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        if last_step:
+            ws_long = args.ws_long_validate
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = torch.zeros((), device=device, dtype=torch.float32)
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws_short, ws_long).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1)  # momentum warmup for muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+
+====================================================================================================
+Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]
+Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Sat Sep 27 13:24:12 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 |
+|-----------------------------------------+------------------------+----------------------+
+| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
+| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
+| | | MIG M. |
+|=========================================+========================+======================|
+| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 |
+| N/A 27C P0 120W / 700W | 5856MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 |
+| N/A 25C P0 118W / 700W | 1518MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 |
+| N/A 22C P0 115W / 700W | 1518MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 |
+| N/A 27C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 |
+| N/A 27C P0 122W / 700W | 1518MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 |
+| N/A 25C P0 114W / 700W | 1518MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 |
+| N/A 28C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 |
+| N/A 24C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 173763 C /usr/bin/python 0MiB | +| 0 N/A N/A 173764 C /usr/bin/python 0MiB | +| 0 N/A N/A 173765 C /usr/bin/python 0MiB | +| 0 N/A N/A 173766 C /usr/bin/python 0MiB | +| 0 N/A N/A 173767 C /usr/bin/python 0MiB | +| 0 N/A N/A 173768 C /usr/bin/python 0MiB | +| 0 N/A N/A 173769 C /usr/bin/python 0MiB | +| 0 N/A N/A 173770 C /usr/bin/python 0MiB | +| 1 N/A N/A 173764 C /usr/bin/python 0MiB | +| 2 N/A N/A 173765 C /usr/bin/python 0MiB | +| 3 N/A N/A 173766 C /usr/bin/python 0MiB | +| 4 N/A N/A 173767 C /usr/bin/python 0MiB | +| 5 N/A N/A 173768 C /usr/bin/python 0MiB | +| 6 N/A N/A 173769 C /usr/bin/python 0MiB | +| 7 N/A N/A 173770 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.06ms +step:1/1680 train_time:139ms step_avg:139.00ms +step:2/1680 train_time:159ms step_avg:79.65ms +step:3/1680 train_time:224ms step_avg:74.60ms +step:4/1680 train_time:309ms step_avg:77.16ms +step:5/1680 train_time:395ms step_avg:78.99ms +step:6/1680 train_time:481ms step_avg:80.15ms +step:7/1680 train_time:567ms step_avg:81.01ms +step:8/1680 train_time:654ms step_avg:81.75ms +step:9/1680 train_time:740ms step_avg:82.25ms +step:10/1680 train_time:827ms step_avg:82.67ms +step:11/1680 train_time:913ms step_avg:82.98ms +step:12/1680 train_time:1000ms step_avg:83.33ms +step:13/1680 train_time:1091ms step_avg:83.90ms +step:14/1680 train_time:1181ms step_avg:84.32ms +step:15/1680 train_time:1269ms step_avg:84.61ms +step:16/1680 train_time:1357ms step_avg:84.78ms +step:17/1680 train_time:1443ms step_avg:84.90ms +step:18/1680 train_time:1530ms step_avg:84.97ms +step:19/1680 train_time:1617ms step_avg:85.10ms +step:20/1680 train_time:1703ms step_avg:85.17ms +step:21/1680 train_time:1790ms step_avg:85.23ms +step:22/1680 train_time:1877ms step_avg:85.30ms +step:23/1680 train_time:1964ms step_avg:85.37ms +step:24/1680 train_time:2052ms step_avg:85.50ms +step:25/1680 train_time:2140ms step_avg:85.61ms +step:26/1680 train_time:2228ms step_avg:85.71ms +step:27/1680 train_time:2316ms step_avg:85.78ms +step:28/1680 train_time:2404ms step_avg:85.85ms +step:29/1680 train_time:2491ms step_avg:85.91ms +step:30/1680 train_time:2578ms step_avg:85.94ms +step:31/1680 train_time:2665ms step_avg:85.95ms +step:32/1680 train_time:2751ms step_avg:85.98ms +step:33/1680 train_time:2838ms step_avg:85.99ms +step:34/1680 train_time:2925ms step_avg:86.04ms +step:35/1680 train_time:3013ms step_avg:86.08ms +step:36/1680 train_time:3100ms step_avg:86.12ms +step:37/1680 train_time:3189ms step_avg:86.18ms +step:38/1680 train_time:3276ms step_avg:86.22ms +step:39/1680 train_time:3364ms step_avg:86.25ms +step:40/1680 train_time:3451ms step_avg:86.29ms +step:41/1680 train_time:3539ms step_avg:86.32ms +step:42/1680 train_time:3626ms step_avg:86.33ms +step:43/1680 train_time:3712ms step_avg:86.33ms +step:44/1680 train_time:3799ms step_avg:86.34ms +step:45/1680 train_time:3887ms step_avg:86.37ms +step:46/1680 train_time:3974ms step_avg:86.39ms +step:47/1680 
train_time:4062ms step_avg:86.42ms +step:48/1680 train_time:4151ms step_avg:86.48ms +step:49/1680 train_time:4238ms step_avg:86.49ms +step:50/1680 train_time:4327ms step_avg:86.53ms +step:51/1680 train_time:4414ms step_avg:86.55ms +step:52/1680 train_time:4501ms step_avg:86.56ms +step:53/1680 train_time:4589ms step_avg:86.58ms +step:54/1680 train_time:4675ms step_avg:86.58ms +step:55/1680 train_time:4763ms step_avg:86.59ms +step:56/1680 train_time:4850ms step_avg:86.61ms +step:57/1680 train_time:4937ms step_avg:86.61ms +step:58/1680 train_time:5024ms step_avg:86.62ms +step:59/1680 train_time:5111ms step_avg:86.63ms +step:60/1680 train_time:5198ms step_avg:86.64ms +step:61/1680 train_time:5286ms step_avg:86.66ms +step:62/1680 train_time:5374ms step_avg:86.68ms +step:63/1680 train_time:5462ms step_avg:86.69ms +step:64/1680 train_time:5548ms step_avg:86.69ms +step:65/1680 train_time:5636ms step_avg:86.70ms +step:66/1680 train_time:5722ms step_avg:86.70ms +step:67/1680 train_time:5809ms step_avg:86.71ms +step:68/1680 train_time:5897ms step_avg:86.71ms +step:69/1680 train_time:5984ms step_avg:86.72ms +step:70/1680 train_time:6071ms step_avg:86.73ms +step:71/1680 train_time:6158ms step_avg:86.73ms +step:72/1680 train_time:6246ms step_avg:86.75ms +step:73/1680 train_time:6333ms step_avg:86.75ms +step:74/1680 train_time:6420ms step_avg:86.76ms +step:75/1680 train_time:6508ms step_avg:86.77ms +step:76/1680 train_time:6595ms step_avg:86.78ms +step:77/1680 train_time:6682ms step_avg:86.78ms +step:78/1680 train_time:6770ms step_avg:86.79ms +step:79/1680 train_time:6857ms step_avg:86.80ms +step:80/1680 train_time:6944ms step_avg:86.81ms +step:81/1680 train_time:7031ms step_avg:86.80ms +step:82/1680 train_time:7118ms step_avg:86.81ms +step:83/1680 train_time:7206ms step_avg:86.81ms +step:84/1680 train_time:7293ms step_avg:86.82ms +step:85/1680 train_time:7380ms step_avg:86.83ms +step:86/1680 train_time:7467ms step_avg:86.82ms +step:87/1680 train_time:7554ms step_avg:86.83ms +step:88/1680 train_time:7640ms step_avg:86.82ms +step:89/1680 train_time:7727ms step_avg:86.82ms +step:90/1680 train_time:7815ms step_avg:86.83ms +step:91/1680 train_time:7902ms step_avg:86.84ms +step:92/1680 train_time:7989ms step_avg:86.84ms +step:93/1680 train_time:8076ms step_avg:86.84ms +step:94/1680 train_time:8163ms step_avg:86.84ms +step:95/1680 train_time:8251ms step_avg:86.85ms +step:96/1680 train_time:8337ms step_avg:86.84ms +step:97/1680 train_time:8425ms step_avg:86.85ms +step:98/1680 train_time:8512ms step_avg:86.86ms +step:99/1680 train_time:8599ms step_avg:86.86ms +step:100/1680 train_time:8686ms step_avg:86.86ms +step:101/1680 train_time:8773ms step_avg:86.86ms +step:102/1680 train_time:8860ms step_avg:86.86ms +step:103/1680 train_time:8947ms step_avg:86.86ms +step:104/1680 train_time:9033ms step_avg:86.86ms +step:105/1680 train_time:9120ms step_avg:86.86ms +step:106/1680 train_time:9207ms step_avg:86.86ms +step:107/1680 train_time:9295ms step_avg:86.87ms +step:108/1680 train_time:9382ms step_avg:86.87ms +step:109/1680 train_time:9469ms step_avg:86.87ms +step:110/1680 train_time:9556ms step_avg:86.87ms +step:111/1680 train_time:9643ms step_avg:86.87ms +step:112/1680 train_time:9731ms step_avg:86.88ms +step:113/1680 train_time:9817ms step_avg:86.88ms +step:114/1680 train_time:9904ms step_avg:86.88ms +step:115/1680 train_time:9992ms step_avg:86.89ms +step:116/1680 train_time:10079ms step_avg:86.89ms +step:117/1680 train_time:10166ms step_avg:86.89ms +step:118/1680 train_time:10254ms step_avg:86.90ms +step:119/1680 
train_time:10340ms step_avg:86.89ms +step:120/1680 train_time:10427ms step_avg:86.89ms +step:121/1680 train_time:10514ms step_avg:86.90ms +step:122/1680 train_time:10602ms step_avg:86.90ms +step:123/1680 train_time:10690ms step_avg:86.91ms +step:124/1680 train_time:10777ms step_avg:86.91ms +step:125/1680 train_time:10864ms step_avg:86.91ms +step:125/1680 val_loss:4.3283 train_time:10952ms step_avg:87.62ms +step:126/1680 train_time:10975ms step_avg:87.10ms +step:127/1680 train_time:11043ms step_avg:86.95ms +step:128/1680 train_time:11139ms step_avg:87.03ms +step:129/1680 train_time:11231ms step_avg:87.06ms +step:130/1680 train_time:11319ms step_avg:87.07ms +step:131/1680 train_time:11405ms step_avg:87.06ms +step:132/1680 train_time:11491ms step_avg:87.05ms +step:133/1680 train_time:11576ms step_avg:87.04ms +step:134/1680 train_time:11662ms step_avg:87.03ms +step:135/1680 train_time:11748ms step_avg:87.02ms +step:136/1680 train_time:11833ms step_avg:87.01ms +step:137/1680 train_time:11919ms step_avg:87.00ms +step:138/1680 train_time:12007ms step_avg:87.01ms +step:139/1680 train_time:12095ms step_avg:87.02ms +step:140/1680 train_time:12184ms step_avg:87.03ms +step:141/1680 train_time:12272ms step_avg:87.04ms +step:142/1680 train_time:12361ms step_avg:87.05ms +step:143/1680 train_time:12448ms step_avg:87.05ms +step:144/1680 train_time:12535ms step_avg:87.05ms +step:145/1680 train_time:12621ms step_avg:87.04ms +step:146/1680 train_time:12707ms step_avg:87.04ms +step:147/1680 train_time:12793ms step_avg:87.03ms +step:148/1680 train_time:12879ms step_avg:87.02ms +step:149/1680 train_time:12966ms step_avg:87.02ms +step:150/1680 train_time:13053ms step_avg:87.02ms +step:151/1680 train_time:13142ms step_avg:87.04ms +step:152/1680 train_time:13231ms step_avg:87.05ms +step:153/1680 train_time:13319ms step_avg:87.05ms +step:154/1680 train_time:13407ms step_avg:87.06ms +step:155/1680 train_time:13493ms step_avg:87.05ms +step:156/1680 train_time:13580ms step_avg:87.05ms +step:157/1680 train_time:13667ms step_avg:87.05ms +step:158/1680 train_time:13753ms step_avg:87.05ms +step:159/1680 train_time:13839ms step_avg:87.04ms +step:160/1680 train_time:13925ms step_avg:87.03ms +step:161/1680 train_time:14012ms step_avg:87.03ms +step:162/1680 train_time:14099ms step_avg:87.03ms +step:163/1680 train_time:14187ms step_avg:87.04ms +step:164/1680 train_time:14275ms step_avg:87.04ms +step:165/1680 train_time:14362ms step_avg:87.04ms +step:166/1680 train_time:14450ms step_avg:87.05ms +step:167/1680 train_time:14537ms step_avg:87.05ms +step:168/1680 train_time:14623ms step_avg:87.04ms +step:169/1680 train_time:14710ms step_avg:87.04ms +step:170/1680 train_time:14797ms step_avg:87.04ms +step:171/1680 train_time:14883ms step_avg:87.04ms +step:172/1680 train_time:14970ms step_avg:87.04ms +step:173/1680 train_time:15058ms step_avg:87.04ms +step:174/1680 train_time:15146ms step_avg:87.04ms +step:175/1680 train_time:15233ms step_avg:87.05ms +step:176/1680 train_time:15321ms step_avg:87.05ms +step:177/1680 train_time:15408ms step_avg:87.05ms +step:178/1680 train_time:15495ms step_avg:87.05ms +step:179/1680 train_time:15582ms step_avg:87.05ms +step:180/1680 train_time:15668ms step_avg:87.05ms +step:181/1680 train_time:15755ms step_avg:87.04ms +step:182/1680 train_time:15842ms step_avg:87.05ms +step:183/1680 train_time:15929ms step_avg:87.04ms +step:184/1680 train_time:16017ms step_avg:87.05ms +step:185/1680 train_time:16103ms step_avg:87.04ms +step:186/1680 train_time:16191ms step_avg:87.05ms +step:187/1680 train_time:16278ms 
step_avg:87.05ms +step:188/1680 train_time:16366ms step_avg:87.05ms +step:189/1680 train_time:16452ms step_avg:87.05ms +step:190/1680 train_time:16539ms step_avg:87.05ms +step:191/1680 train_time:16626ms step_avg:87.05ms +step:192/1680 train_time:16713ms step_avg:87.05ms +step:193/1680 train_time:16800ms step_avg:87.05ms +step:194/1680 train_time:16886ms step_avg:87.04ms +step:195/1680 train_time:16973ms step_avg:87.04ms +step:196/1680 train_time:17060ms step_avg:87.04ms +step:197/1680 train_time:17147ms step_avg:87.04ms +step:198/1680 train_time:17235ms step_avg:87.04ms +step:199/1680 train_time:17321ms step_avg:87.04ms +step:200/1680 train_time:17409ms step_avg:87.04ms +step:201/1680 train_time:17496ms step_avg:87.04ms +step:202/1680 train_time:17582ms step_avg:87.04ms +step:203/1680 train_time:17669ms step_avg:87.04ms +step:204/1680 train_time:17756ms step_avg:87.04ms +step:205/1680 train_time:17842ms step_avg:87.04ms +step:206/1680 train_time:17930ms step_avg:87.04ms +step:207/1680 train_time:18017ms step_avg:87.04ms +step:208/1680 train_time:18105ms step_avg:87.04ms +step:209/1680 train_time:18191ms step_avg:87.04ms +step:210/1680 train_time:18279ms step_avg:87.04ms +step:211/1680 train_time:18366ms step_avg:87.04ms +step:212/1680 train_time:18453ms step_avg:87.04ms +step:213/1680 train_time:18540ms step_avg:87.04ms +step:214/1680 train_time:18627ms step_avg:87.04ms +step:215/1680 train_time:18714ms step_avg:87.04ms +step:216/1680 train_time:18801ms step_avg:87.04ms +step:217/1680 train_time:18888ms step_avg:87.04ms +step:218/1680 train_time:18975ms step_avg:87.04ms +step:219/1680 train_time:19062ms step_avg:87.04ms +step:220/1680 train_time:19149ms step_avg:87.04ms +step:221/1680 train_time:19236ms step_avg:87.04ms +step:222/1680 train_time:19323ms step_avg:87.04ms +step:223/1680 train_time:19410ms step_avg:87.04ms +step:224/1680 train_time:19497ms step_avg:87.04ms +step:225/1680 train_time:19584ms step_avg:87.04ms +step:226/1680 train_time:19670ms step_avg:87.04ms +step:227/1680 train_time:19757ms step_avg:87.04ms +step:228/1680 train_time:19845ms step_avg:87.04ms +step:229/1680 train_time:19932ms step_avg:87.04ms +step:230/1680 train_time:20019ms step_avg:87.04ms +step:231/1680 train_time:20107ms step_avg:87.04ms +step:232/1680 train_time:20193ms step_avg:87.04ms +step:233/1680 train_time:20281ms step_avg:87.04ms +step:234/1680 train_time:20368ms step_avg:87.04ms +step:235/1680 train_time:20454ms step_avg:87.04ms +step:236/1680 train_time:20541ms step_avg:87.04ms +step:237/1680 train_time:20629ms step_avg:87.04ms +step:238/1680 train_time:20716ms step_avg:87.04ms +step:239/1680 train_time:20803ms step_avg:87.04ms +step:240/1680 train_time:20890ms step_avg:87.04ms +step:241/1680 train_time:20977ms step_avg:87.04ms +step:242/1680 train_time:21064ms step_avg:87.04ms +step:243/1680 train_time:21151ms step_avg:87.04ms +step:244/1680 train_time:21238ms step_avg:87.04ms +step:245/1680 train_time:21325ms step_avg:87.04ms +step:246/1680 train_time:21411ms step_avg:87.04ms +step:247/1680 train_time:21498ms step_avg:87.04ms +step:248/1680 train_time:21585ms step_avg:87.04ms +step:249/1680 train_time:21672ms step_avg:87.04ms +step:250/1680 train_time:21759ms step_avg:87.04ms +step:250/1680 val_loss:3.9807 train_time:21848ms step_avg:87.39ms +step:251/1680 train_time:21868ms step_avg:87.12ms +step:252/1680 train_time:21939ms step_avg:87.06ms +step:253/1680 train_time:22033ms step_avg:87.09ms +step:254/1680 train_time:22123ms step_avg:87.10ms +step:255/1680 train_time:22210ms step_avg:87.10ms 
+step:256/1680 train_time:22296ms step_avg:87.09ms +step:257/1680 train_time:22382ms step_avg:87.09ms +step:258/1680 train_time:22468ms step_avg:87.09ms +step:259/1680 train_time:22554ms step_avg:87.08ms +step:260/1680 train_time:22640ms step_avg:87.08ms +step:261/1680 train_time:22726ms step_avg:87.07ms +step:262/1680 train_time:22813ms step_avg:87.07ms +step:263/1680 train_time:22901ms step_avg:87.08ms +step:264/1680 train_time:22990ms step_avg:87.08ms +step:265/1680 train_time:23079ms step_avg:87.09ms +step:266/1680 train_time:23167ms step_avg:87.09ms +step:267/1680 train_time:23254ms step_avg:87.09ms +step:268/1680 train_time:23341ms step_avg:87.09ms +step:269/1680 train_time:23428ms step_avg:87.09ms +step:270/1680 train_time:23514ms step_avg:87.09ms +step:271/1680 train_time:23601ms step_avg:87.09ms +step:272/1680 train_time:23687ms step_avg:87.09ms +step:273/1680 train_time:23774ms step_avg:87.09ms +step:274/1680 train_time:23861ms step_avg:87.08ms +step:275/1680 train_time:23949ms step_avg:87.09ms +step:276/1680 train_time:24037ms step_avg:87.09ms +step:277/1680 train_time:24125ms step_avg:87.09ms +step:278/1680 train_time:24213ms step_avg:87.10ms +step:279/1680 train_time:24300ms step_avg:87.10ms +step:280/1680 train_time:24387ms step_avg:87.09ms +step:281/1680 train_time:24474ms step_avg:87.10ms +step:282/1680 train_time:24561ms step_avg:87.09ms +step:283/1680 train_time:24648ms step_avg:87.09ms +step:284/1680 train_time:24734ms step_avg:87.09ms +step:285/1680 train_time:24820ms step_avg:87.09ms +step:286/1680 train_time:24908ms step_avg:87.09ms +step:287/1680 train_time:24995ms step_avg:87.09ms +step:288/1680 train_time:25083ms step_avg:87.10ms +step:289/1680 train_time:25171ms step_avg:87.10ms +step:290/1680 train_time:25259ms step_avg:87.10ms +step:291/1680 train_time:25346ms step_avg:87.10ms +step:292/1680 train_time:25433ms step_avg:87.10ms +step:293/1680 train_time:25519ms step_avg:87.10ms +step:294/1680 train_time:25606ms step_avg:87.09ms +step:295/1680 train_time:25692ms step_avg:87.09ms +step:296/1680 train_time:25779ms step_avg:87.09ms +step:297/1680 train_time:25866ms step_avg:87.09ms +step:298/1680 train_time:25953ms step_avg:87.09ms +step:299/1680 train_time:26040ms step_avg:87.09ms +step:300/1680 train_time:26128ms step_avg:87.09ms +step:301/1680 train_time:26215ms step_avg:87.09ms +step:302/1680 train_time:26303ms step_avg:87.10ms +step:303/1680 train_time:26390ms step_avg:87.10ms +step:304/1680 train_time:26477ms step_avg:87.09ms +step:305/1680 train_time:26564ms step_avg:87.10ms +step:306/1680 train_time:26650ms step_avg:87.09ms +step:307/1680 train_time:26737ms step_avg:87.09ms +step:308/1680 train_time:26823ms step_avg:87.09ms +step:309/1680 train_time:26910ms step_avg:87.09ms +step:310/1680 train_time:26996ms step_avg:87.08ms +step:311/1680 train_time:27083ms step_avg:87.08ms +step:312/1680 train_time:27170ms step_avg:87.08ms +step:313/1680 train_time:27256ms step_avg:87.08ms +step:314/1680 train_time:27343ms step_avg:87.08ms +step:315/1680 train_time:27430ms step_avg:87.08ms +step:316/1680 train_time:27517ms step_avg:87.08ms +step:317/1680 train_time:27604ms step_avg:87.08ms +step:318/1680 train_time:27691ms step_avg:87.08ms +step:319/1680 train_time:27778ms step_avg:87.08ms +step:320/1680 train_time:27866ms step_avg:87.08ms +step:321/1680 train_time:27952ms step_avg:87.08ms +step:322/1680 train_time:28039ms step_avg:87.08ms +step:323/1680 train_time:28125ms step_avg:87.08ms +step:324/1680 train_time:28213ms step_avg:87.08ms +step:325/1680 train_time:28300ms 
step_avg:87.08ms +step:326/1680 train_time:28388ms step_avg:87.08ms +step:327/1680 train_time:28475ms step_avg:87.08ms +step:328/1680 train_time:28563ms step_avg:87.08ms +step:329/1680 train_time:28649ms step_avg:87.08ms +step:330/1680 train_time:28736ms step_avg:87.08ms +step:331/1680 train_time:28824ms step_avg:87.08ms +step:332/1680 train_time:28910ms step_avg:87.08ms +step:333/1680 train_time:28998ms step_avg:87.08ms +step:334/1680 train_time:29085ms step_avg:87.08ms +step:335/1680 train_time:29172ms step_avg:87.08ms +step:336/1680 train_time:29259ms step_avg:87.08ms +step:337/1680 train_time:29346ms step_avg:87.08ms +step:338/1680 train_time:29433ms step_avg:87.08ms +step:339/1680 train_time:29520ms step_avg:87.08ms +step:340/1680 train_time:29607ms step_avg:87.08ms +step:341/1680 train_time:29694ms step_avg:87.08ms +step:342/1680 train_time:29781ms step_avg:87.08ms +step:343/1680 train_time:29868ms step_avg:87.08ms +step:344/1680 train_time:29954ms step_avg:87.08ms +step:345/1680 train_time:30042ms step_avg:87.08ms +step:346/1680 train_time:30129ms step_avg:87.08ms +step:347/1680 train_time:30216ms step_avg:87.08ms +step:348/1680 train_time:30304ms step_avg:87.08ms +step:349/1680 train_time:30391ms step_avg:87.08ms +step:350/1680 train_time:30478ms step_avg:87.08ms +step:351/1680 train_time:30566ms step_avg:87.08ms +step:352/1680 train_time:30653ms step_avg:87.08ms +step:353/1680 train_time:30740ms step_avg:87.08ms +step:354/1680 train_time:30826ms step_avg:87.08ms +step:355/1680 train_time:30913ms step_avg:87.08ms +step:356/1680 train_time:31000ms step_avg:87.08ms +step:357/1680 train_time:31087ms step_avg:87.08ms +step:358/1680 train_time:31174ms step_avg:87.08ms +step:359/1680 train_time:31260ms step_avg:87.08ms +step:360/1680 train_time:31347ms step_avg:87.08ms +step:361/1680 train_time:31435ms step_avg:87.08ms +step:362/1680 train_time:31522ms step_avg:87.08ms +step:363/1680 train_time:31608ms step_avg:87.08ms +step:364/1680 train_time:31696ms step_avg:87.08ms +step:365/1680 train_time:31783ms step_avg:87.08ms +step:366/1680 train_time:31870ms step_avg:87.08ms +step:367/1680 train_time:31957ms step_avg:87.08ms +step:368/1680 train_time:32044ms step_avg:87.08ms +step:369/1680 train_time:32131ms step_avg:87.08ms +step:370/1680 train_time:32218ms step_avg:87.08ms +step:371/1680 train_time:32306ms step_avg:87.08ms +step:372/1680 train_time:32393ms step_avg:87.08ms +step:373/1680 train_time:32481ms step_avg:87.08ms +step:374/1680 train_time:32568ms step_avg:87.08ms +step:375/1680 train_time:32655ms step_avg:87.08ms +step:375/1680 val_loss:3.8209 train_time:32744ms step_avg:87.32ms +step:376/1680 train_time:32764ms step_avg:87.14ms +step:377/1680 train_time:32834ms step_avg:87.09ms +step:378/1680 train_time:32923ms step_avg:87.10ms +step:379/1680 train_time:33011ms step_avg:87.10ms +step:380/1680 train_time:33099ms step_avg:87.10ms +step:381/1680 train_time:33185ms step_avg:87.10ms +step:382/1680 train_time:33271ms step_avg:87.10ms +step:383/1680 train_time:33357ms step_avg:87.10ms +step:384/1680 train_time:33444ms step_avg:87.09ms +step:385/1680 train_time:33530ms step_avg:87.09ms +step:386/1680 train_time:33617ms step_avg:87.09ms +step:387/1680 train_time:33704ms step_avg:87.09ms +step:388/1680 train_time:33792ms step_avg:87.09ms +step:389/1680 train_time:33882ms step_avg:87.10ms +step:390/1680 train_time:33969ms step_avg:87.10ms +step:391/1680 train_time:34057ms step_avg:87.10ms +step:392/1680 train_time:34144ms step_avg:87.10ms +step:393/1680 train_time:34231ms step_avg:87.10ms 
+step:394/1680 train_time:34318ms step_avg:87.10ms +step:395/1680 train_time:34404ms step_avg:87.10ms +step:396/1680 train_time:34491ms step_avg:87.10ms +step:397/1680 train_time:34577ms step_avg:87.09ms +step:398/1680 train_time:34665ms step_avg:87.10ms +step:399/1680 train_time:34752ms step_avg:87.10ms +step:400/1680 train_time:34840ms step_avg:87.10ms +step:401/1680 train_time:34928ms step_avg:87.10ms +step:402/1680 train_time:35015ms step_avg:87.10ms +step:403/1680 train_time:35102ms step_avg:87.10ms +step:404/1680 train_time:35189ms step_avg:87.10ms +step:405/1680 train_time:35275ms step_avg:87.10ms +step:406/1680 train_time:35363ms step_avg:87.10ms +step:407/1680 train_time:35449ms step_avg:87.10ms +step:408/1680 train_time:35535ms step_avg:87.10ms +step:409/1680 train_time:35622ms step_avg:87.10ms +step:410/1680 train_time:35709ms step_avg:87.09ms +step:411/1680 train_time:35795ms step_avg:87.09ms +step:412/1680 train_time:35885ms step_avg:87.10ms +step:413/1680 train_time:35973ms step_avg:87.10ms +step:414/1680 train_time:36060ms step_avg:87.10ms +step:415/1680 train_time:36148ms step_avg:87.10ms +step:416/1680 train_time:36235ms step_avg:87.10ms +step:417/1680 train_time:36322ms step_avg:87.10ms +step:418/1680 train_time:36408ms step_avg:87.10ms +step:419/1680 train_time:36495ms step_avg:87.10ms +step:420/1680 train_time:36582ms step_avg:87.10ms +step:421/1680 train_time:36668ms step_avg:87.10ms +step:422/1680 train_time:36756ms step_avg:87.10ms +step:423/1680 train_time:36845ms step_avg:87.10ms +step:424/1680 train_time:36932ms step_avg:87.10ms +step:425/1680 train_time:37019ms step_avg:87.10ms +step:426/1680 train_time:37107ms step_avg:87.11ms +step:427/1680 train_time:37194ms step_avg:87.11ms +step:428/1680 train_time:37282ms step_avg:87.11ms +step:429/1680 train_time:37368ms step_avg:87.10ms +step:430/1680 train_time:37454ms step_avg:87.10ms +step:431/1680 train_time:37541ms step_avg:87.10ms +step:432/1680 train_time:37628ms step_avg:87.10ms +step:433/1680 train_time:37715ms step_avg:87.10ms +step:434/1680 train_time:37802ms step_avg:87.10ms +step:435/1680 train_time:37889ms step_avg:87.10ms +step:436/1680 train_time:37976ms step_avg:87.10ms +step:437/1680 train_time:38064ms step_avg:87.10ms +step:438/1680 train_time:38151ms step_avg:87.10ms +step:439/1680 train_time:38239ms step_avg:87.10ms +step:440/1680 train_time:38326ms step_avg:87.10ms +step:441/1680 train_time:38412ms step_avg:87.10ms +step:442/1680 train_time:38499ms step_avg:87.10ms +step:443/1680 train_time:38586ms step_avg:87.10ms +step:444/1680 train_time:38673ms step_avg:87.10ms +step:445/1680 train_time:38760ms step_avg:87.10ms +step:446/1680 train_time:38848ms step_avg:87.10ms +step:447/1680 train_time:38935ms step_avg:87.10ms +step:448/1680 train_time:39022ms step_avg:87.10ms +step:449/1680 train_time:39109ms step_avg:87.10ms +step:450/1680 train_time:39195ms step_avg:87.10ms +step:451/1680 train_time:39283ms step_avg:87.10ms +step:452/1680 train_time:39369ms step_avg:87.10ms +step:453/1680 train_time:39457ms step_avg:87.10ms +step:454/1680 train_time:39544ms step_avg:87.10ms +step:455/1680 train_time:39631ms step_avg:87.10ms +step:456/1680 train_time:39718ms step_avg:87.10ms +step:457/1680 train_time:39805ms step_avg:87.10ms +step:458/1680 train_time:39893ms step_avg:87.10ms +step:459/1680 train_time:39980ms step_avg:87.10ms +step:460/1680 train_time:40067ms step_avg:87.10ms +step:461/1680 train_time:40154ms step_avg:87.10ms +step:462/1680 train_time:40241ms step_avg:87.10ms +step:463/1680 train_time:40328ms 
step_avg:87.10ms +step:464/1680 train_time:40415ms step_avg:87.10ms +step:465/1680 train_time:40502ms step_avg:87.10ms +step:466/1680 train_time:40589ms step_avg:87.10ms +step:467/1680 train_time:40676ms step_avg:87.10ms +step:468/1680 train_time:40763ms step_avg:87.10ms +step:469/1680 train_time:40850ms step_avg:87.10ms +step:470/1680 train_time:40937ms step_avg:87.10ms +step:471/1680 train_time:41025ms step_avg:87.10ms +step:472/1680 train_time:41112ms step_avg:87.10ms +step:473/1680 train_time:41199ms step_avg:87.10ms +step:474/1680 train_time:41286ms step_avg:87.10ms +step:475/1680 train_time:41372ms step_avg:87.10ms +step:476/1680 train_time:41460ms step_avg:87.10ms +step:477/1680 train_time:41547ms step_avg:87.10ms +step:478/1680 train_time:41634ms step_avg:87.10ms +step:479/1680 train_time:41722ms step_avg:87.10ms +step:480/1680 train_time:41809ms step_avg:87.10ms +step:481/1680 train_time:41896ms step_avg:87.10ms +step:482/1680 train_time:41983ms step_avg:87.10ms +step:483/1680 train_time:42070ms step_avg:87.10ms +step:484/1680 train_time:42157ms step_avg:87.10ms +step:485/1680 train_time:42244ms step_avg:87.10ms +step:486/1680 train_time:42331ms step_avg:87.10ms +step:487/1680 train_time:42418ms step_avg:87.10ms +step:488/1680 train_time:42506ms step_avg:87.10ms +step:489/1680 train_time:42592ms step_avg:87.10ms +step:490/1680 train_time:42679ms step_avg:87.10ms +step:491/1680 train_time:42766ms step_avg:87.10ms +step:492/1680 train_time:42853ms step_avg:87.10ms +step:493/1680 train_time:42940ms step_avg:87.10ms +step:494/1680 train_time:43027ms step_avg:87.10ms +step:495/1680 train_time:43115ms step_avg:87.10ms +step:496/1680 train_time:43202ms step_avg:87.10ms +step:497/1680 train_time:43289ms step_avg:87.10ms +step:498/1680 train_time:43376ms step_avg:87.10ms +step:499/1680 train_time:43464ms step_avg:87.10ms +step:500/1680 train_time:43551ms step_avg:87.10ms +step:500/1680 val_loss:3.7200 train_time:43639ms step_avg:87.28ms +step:501/1680 train_time:43660ms step_avg:87.15ms +step:502/1680 train_time:43732ms step_avg:87.12ms +step:503/1680 train_time:43823ms step_avg:87.12ms +step:504/1680 train_time:43910ms step_avg:87.12ms +step:505/1680 train_time:43996ms step_avg:87.12ms +step:506/1680 train_time:44083ms step_avg:87.12ms +step:507/1680 train_time:44169ms step_avg:87.12ms +step:508/1680 train_time:44255ms step_avg:87.12ms +step:509/1680 train_time:44341ms step_avg:87.11ms +step:510/1680 train_time:44427ms step_avg:87.11ms +step:511/1680 train_time:44513ms step_avg:87.11ms +step:512/1680 train_time:44600ms step_avg:87.11ms +step:513/1680 train_time:44689ms step_avg:87.11ms +step:514/1680 train_time:44777ms step_avg:87.12ms +step:515/1680 train_time:44866ms step_avg:87.12ms +step:516/1680 train_time:44953ms step_avg:87.12ms +step:517/1680 train_time:45040ms step_avg:87.12ms +step:518/1680 train_time:45127ms step_avg:87.12ms +step:519/1680 train_time:45214ms step_avg:87.12ms +step:520/1680 train_time:45300ms step_avg:87.12ms +step:521/1680 train_time:45386ms step_avg:87.11ms +step:522/1680 train_time:45472ms step_avg:87.11ms +step:523/1680 train_time:45559ms step_avg:87.11ms +step:524/1680 train_time:45646ms step_avg:87.11ms +step:525/1680 train_time:45734ms step_avg:87.11ms +step:526/1680 train_time:45823ms step_avg:87.12ms +step:527/1680 train_time:45910ms step_avg:87.12ms +step:528/1680 train_time:45998ms step_avg:87.12ms +step:529/1680 train_time:46085ms step_avg:87.12ms +step:530/1680 train_time:46171ms step_avg:87.12ms +step:531/1680 train_time:46258ms step_avg:87.11ms 
+step:532/1680 train_time:46345ms step_avg:87.11ms +step:533/1680 train_time:46431ms step_avg:87.11ms +step:534/1680 train_time:46518ms step_avg:87.11ms +step:535/1680 train_time:46605ms step_avg:87.11ms +step:536/1680 train_time:46692ms step_avg:87.11ms +step:537/1680 train_time:46780ms step_avg:87.11ms +step:538/1680 train_time:46867ms step_avg:87.11ms +step:539/1680 train_time:46955ms step_avg:87.11ms +step:540/1680 train_time:47042ms step_avg:87.11ms +step:541/1680 train_time:47129ms step_avg:87.11ms +step:542/1680 train_time:47215ms step_avg:87.11ms +step:543/1680 train_time:47302ms step_avg:87.11ms +step:544/1680 train_time:47388ms step_avg:87.11ms +step:545/1680 train_time:47474ms step_avg:87.11ms +step:546/1680 train_time:47561ms step_avg:87.11ms +step:547/1680 train_time:47649ms step_avg:87.11ms +step:548/1680 train_time:47737ms step_avg:87.11ms +step:549/1680 train_time:47826ms step_avg:87.11ms +step:550/1680 train_time:47914ms step_avg:87.12ms +step:551/1680 train_time:48003ms step_avg:87.12ms +step:552/1680 train_time:48091ms step_avg:87.12ms +step:553/1680 train_time:48179ms step_avg:87.12ms +step:554/1680 train_time:48267ms step_avg:87.12ms +step:555/1680 train_time:48355ms step_avg:87.13ms +step:556/1680 train_time:48443ms step_avg:87.13ms +step:557/1680 train_time:48531ms step_avg:87.13ms +step:558/1680 train_time:48620ms step_avg:87.13ms +step:559/1680 train_time:48708ms step_avg:87.13ms +step:560/1680 train_time:48796ms step_avg:87.14ms +step:561/1680 train_time:48884ms step_avg:87.14ms +step:562/1680 train_time:48972ms step_avg:87.14ms +step:563/1680 train_time:49061ms step_avg:87.14ms +step:564/1680 train_time:49149ms step_avg:87.14ms +step:565/1680 train_time:49237ms step_avg:87.14ms +step:566/1680 train_time:49325ms step_avg:87.15ms +step:567/1680 train_time:49412ms step_avg:87.15ms +step:568/1680 train_time:49501ms step_avg:87.15ms +step:569/1680 train_time:49589ms step_avg:87.15ms +step:570/1680 train_time:49678ms step_avg:87.15ms +step:571/1680 train_time:49766ms step_avg:87.16ms +step:572/1680 train_time:49854ms step_avg:87.16ms +step:573/1680 train_time:49942ms step_avg:87.16ms +step:574/1680 train_time:50030ms step_avg:87.16ms +step:575/1680 train_time:50119ms step_avg:87.16ms +step:576/1680 train_time:50207ms step_avg:87.16ms +step:577/1680 train_time:50295ms step_avg:87.17ms +step:578/1680 train_time:50383ms step_avg:87.17ms +step:579/1680 train_time:50470ms step_avg:87.17ms +step:580/1680 train_time:50558ms step_avg:87.17ms +step:581/1680 train_time:50646ms step_avg:87.17ms +step:582/1680 train_time:50735ms step_avg:87.17ms +step:583/1680 train_time:50825ms step_avg:87.18ms +step:584/1680 train_time:50913ms step_avg:87.18ms +step:585/1680 train_time:51001ms step_avg:87.18ms +step:586/1680 train_time:51090ms step_avg:87.18ms +step:587/1680 train_time:51178ms step_avg:87.19ms +step:588/1680 train_time:51266ms step_avg:87.19ms +step:589/1680 train_time:51354ms step_avg:87.19ms +step:590/1680 train_time:51441ms step_avg:87.19ms +step:591/1680 train_time:51530ms step_avg:87.19ms +step:592/1680 train_time:51618ms step_avg:87.19ms +step:593/1680 train_time:51707ms step_avg:87.19ms +step:594/1680 train_time:51795ms step_avg:87.20ms +step:595/1680 train_time:51884ms step_avg:87.20ms +step:596/1680 train_time:51972ms step_avg:87.20ms +step:597/1680 train_time:52060ms step_avg:87.20ms +step:598/1680 train_time:52148ms step_avg:87.20ms +step:599/1680 train_time:52236ms step_avg:87.21ms +step:600/1680 train_time:52325ms step_avg:87.21ms +step:601/1680 train_time:52412ms 
step_avg:87.21ms +step:602/1680 train_time:52500ms step_avg:87.21ms +step:603/1680 train_time:52589ms step_avg:87.21ms +step:604/1680 train_time:52676ms step_avg:87.21ms +step:605/1680 train_time:52765ms step_avg:87.21ms +step:606/1680 train_time:52852ms step_avg:87.22ms +step:607/1680 train_time:52941ms step_avg:87.22ms +step:608/1680 train_time:53028ms step_avg:87.22ms +step:609/1680 train_time:53116ms step_avg:87.22ms +step:610/1680 train_time:53204ms step_avg:87.22ms +step:611/1680 train_time:53292ms step_avg:87.22ms +step:612/1680 train_time:53380ms step_avg:87.22ms +step:613/1680 train_time:53468ms step_avg:87.22ms +step:614/1680 train_time:53556ms step_avg:87.22ms +step:615/1680 train_time:53644ms step_avg:87.23ms +step:616/1680 train_time:53732ms step_avg:87.23ms +step:617/1680 train_time:53820ms step_avg:87.23ms +step:618/1680 train_time:53908ms step_avg:87.23ms +step:619/1680 train_time:53996ms step_avg:87.23ms +step:620/1680 train_time:54085ms step_avg:87.23ms +step:621/1680 train_time:54172ms step_avg:87.23ms +step:622/1680 train_time:54261ms step_avg:87.24ms +step:623/1680 train_time:54349ms step_avg:87.24ms +step:624/1680 train_time:54437ms step_avg:87.24ms +step:625/1680 train_time:54525ms step_avg:87.24ms +step:625/1680 val_loss:3.6190 train_time:54615ms step_avg:87.38ms +step:626/1680 train_time:54636ms step_avg:87.28ms +step:627/1680 train_time:54703ms step_avg:87.25ms +step:628/1680 train_time:54791ms step_avg:87.25ms +step:629/1680 train_time:54883ms step_avg:87.25ms +step:630/1680 train_time:54971ms step_avg:87.26ms +step:631/1680 train_time:55057ms step_avg:87.25ms +step:632/1680 train_time:55144ms step_avg:87.25ms +step:633/1680 train_time:55231ms step_avg:87.25ms +step:634/1680 train_time:55318ms step_avg:87.25ms +step:635/1680 train_time:55405ms step_avg:87.25ms +step:636/1680 train_time:55492ms step_avg:87.25ms +step:637/1680 train_time:55584ms step_avg:87.26ms +step:638/1680 train_time:55675ms step_avg:87.27ms +step:639/1680 train_time:55763ms step_avg:87.27ms +step:640/1680 train_time:55852ms step_avg:87.27ms +step:641/1680 train_time:55940ms step_avg:87.27ms +step:642/1680 train_time:56027ms step_avg:87.27ms +step:643/1680 train_time:56115ms step_avg:87.27ms +step:644/1680 train_time:56202ms step_avg:87.27ms +step:645/1680 train_time:56290ms step_avg:87.27ms +step:646/1680 train_time:56377ms step_avg:87.27ms +step:647/1680 train_time:56464ms step_avg:87.27ms +step:648/1680 train_time:56554ms step_avg:87.28ms +step:649/1680 train_time:56644ms step_avg:87.28ms +step:650/1680 train_time:56733ms step_avg:87.28ms +step:651/1680 train_time:56822ms step_avg:87.28ms +step:652/1680 train_time:56910ms step_avg:87.29ms +step:653/1680 train_time:56998ms step_avg:87.29ms +step:654/1680 train_time:57086ms step_avg:87.29ms +step:655/1680 train_time:57173ms step_avg:87.29ms +step:656/1680 train_time:57261ms step_avg:87.29ms +step:657/1680 train_time:57348ms step_avg:87.29ms +step:658/1680 train_time:57436ms step_avg:87.29ms +step:659/1680 train_time:57525ms step_avg:87.29ms +step:660/1680 train_time:57614ms step_avg:87.29ms +step:661/1680 train_time:57702ms step_avg:87.30ms +step:662/1680 train_time:57791ms step_avg:87.30ms +step:663/1680 train_time:57880ms step_avg:87.30ms +step:664/1680 train_time:57967ms step_avg:87.30ms +step:665/1680 train_time:58055ms step_avg:87.30ms +step:666/1680 train_time:58143ms step_avg:87.30ms +step:667/1680 train_time:58231ms step_avg:87.30ms +step:668/1680 train_time:58318ms step_avg:87.30ms +step:669/1680 train_time:58406ms step_avg:87.30ms 
+step:670/1680 train_time:58494ms step_avg:87.30ms +step:671/1680 train_time:58582ms step_avg:87.31ms +step:672/1680 train_time:58670ms step_avg:87.31ms +step:673/1680 train_time:58760ms step_avg:87.31ms +step:674/1680 train_time:58848ms step_avg:87.31ms +step:675/1680 train_time:58936ms step_avg:87.31ms +step:676/1680 train_time:59023ms step_avg:87.31ms +step:677/1680 train_time:59111ms step_avg:87.31ms +step:678/1680 train_time:59198ms step_avg:87.31ms +step:679/1680 train_time:59286ms step_avg:87.31ms +step:680/1680 train_time:59374ms step_avg:87.31ms +step:681/1680 train_time:59462ms step_avg:87.32ms +step:682/1680 train_time:59551ms step_avg:87.32ms +step:683/1680 train_time:59639ms step_avg:87.32ms +step:684/1680 train_time:59728ms step_avg:87.32ms +step:685/1680 train_time:59818ms step_avg:87.33ms +step:686/1680 train_time:59906ms step_avg:87.33ms +step:687/1680 train_time:59994ms step_avg:87.33ms +step:688/1680 train_time:60082ms step_avg:87.33ms +step:689/1680 train_time:60169ms step_avg:87.33ms +step:690/1680 train_time:60257ms step_avg:87.33ms +step:691/1680 train_time:60345ms step_avg:87.33ms +step:692/1680 train_time:60433ms step_avg:87.33ms +step:693/1680 train_time:60521ms step_avg:87.33ms +step:694/1680 train_time:60608ms step_avg:87.33ms +step:695/1680 train_time:60697ms step_avg:87.33ms +step:696/1680 train_time:60785ms step_avg:87.34ms +step:697/1680 train_time:60874ms step_avg:87.34ms +step:698/1680 train_time:60962ms step_avg:87.34ms +step:699/1680 train_time:61050ms step_avg:87.34ms +step:700/1680 train_time:61139ms step_avg:87.34ms +step:701/1680 train_time:61227ms step_avg:87.34ms +step:702/1680 train_time:61317ms step_avg:87.35ms +step:703/1680 train_time:61404ms step_avg:87.35ms +step:704/1680 train_time:61492ms step_avg:87.35ms +step:705/1680 train_time:61579ms step_avg:87.35ms +step:706/1680 train_time:61667ms step_avg:87.35ms +step:707/1680 train_time:61756ms step_avg:87.35ms +step:708/1680 train_time:61844ms step_avg:87.35ms +step:709/1680 train_time:61932ms step_avg:87.35ms +step:710/1680 train_time:62020ms step_avg:87.35ms +step:711/1680 train_time:62108ms step_avg:87.35ms +step:712/1680 train_time:62196ms step_avg:87.35ms +step:713/1680 train_time:62284ms step_avg:87.35ms +step:714/1680 train_time:62372ms step_avg:87.36ms +step:715/1680 train_time:62460ms step_avg:87.36ms +step:716/1680 train_time:62548ms step_avg:87.36ms +step:717/1680 train_time:62636ms step_avg:87.36ms +step:718/1680 train_time:62724ms step_avg:87.36ms +step:719/1680 train_time:62812ms step_avg:87.36ms +step:720/1680 train_time:62900ms step_avg:87.36ms +step:721/1680 train_time:62988ms step_avg:87.36ms +step:722/1680 train_time:63076ms step_avg:87.36ms +step:723/1680 train_time:63164ms step_avg:87.36ms +step:724/1680 train_time:63252ms step_avg:87.36ms +step:725/1680 train_time:63340ms step_avg:87.37ms +step:726/1680 train_time:63428ms step_avg:87.37ms +step:727/1680 train_time:63517ms step_avg:87.37ms +step:728/1680 train_time:63604ms step_avg:87.37ms +step:729/1680 train_time:63692ms step_avg:87.37ms +step:730/1680 train_time:63780ms step_avg:87.37ms +step:731/1680 train_time:63868ms step_avg:87.37ms +step:732/1680 train_time:63956ms step_avg:87.37ms +step:733/1680 train_time:64045ms step_avg:87.37ms +step:734/1680 train_time:64133ms step_avg:87.37ms +step:735/1680 train_time:64221ms step_avg:87.38ms +step:736/1680 train_time:64309ms step_avg:87.38ms +step:737/1680 train_time:64397ms step_avg:87.38ms +step:738/1680 train_time:64485ms step_avg:87.38ms +step:739/1680 train_time:64573ms 
step_avg:87.38ms +step:740/1680 train_time:64661ms step_avg:87.38ms +step:741/1680 train_time:64749ms step_avg:87.38ms +step:742/1680 train_time:64838ms step_avg:87.38ms +step:743/1680 train_time:64926ms step_avg:87.38ms +step:744/1680 train_time:65014ms step_avg:87.38ms +step:745/1680 train_time:65102ms step_avg:87.39ms +step:746/1680 train_time:65190ms step_avg:87.39ms +step:747/1680 train_time:65278ms step_avg:87.39ms +step:748/1680 train_time:65367ms step_avg:87.39ms +step:749/1680 train_time:65455ms step_avg:87.39ms +step:750/1680 train_time:65543ms step_avg:87.39ms +step:750/1680 val_loss:3.5670 train_time:65633ms step_avg:87.51ms +step:751/1680 train_time:65653ms step_avg:87.42ms +step:752/1680 train_time:65722ms step_avg:87.40ms +step:753/1680 train_time:65818ms step_avg:87.41ms +step:754/1680 train_time:65909ms step_avg:87.41ms +step:755/1680 train_time:65997ms step_avg:87.41ms +step:756/1680 train_time:66084ms step_avg:87.41ms +step:757/1680 train_time:66172ms step_avg:87.41ms +step:758/1680 train_time:66258ms step_avg:87.41ms +step:759/1680 train_time:66345ms step_avg:87.41ms +step:760/1680 train_time:66432ms step_avg:87.41ms +step:761/1680 train_time:66519ms step_avg:87.41ms +step:762/1680 train_time:66608ms step_avg:87.41ms +step:763/1680 train_time:66698ms step_avg:87.42ms +step:764/1680 train_time:66789ms step_avg:87.42ms +step:765/1680 train_time:66877ms step_avg:87.42ms +step:766/1680 train_time:66966ms step_avg:87.42ms +step:767/1680 train_time:67054ms step_avg:87.42ms +step:768/1680 train_time:67141ms step_avg:87.42ms +step:769/1680 train_time:67229ms step_avg:87.42ms +step:770/1680 train_time:67316ms step_avg:87.42ms +step:771/1680 train_time:67404ms step_avg:87.42ms +step:772/1680 train_time:67491ms step_avg:87.42ms +step:773/1680 train_time:67579ms step_avg:87.42ms +step:774/1680 train_time:67668ms step_avg:87.43ms +step:775/1680 train_time:67757ms step_avg:87.43ms +step:776/1680 train_time:67847ms step_avg:87.43ms +step:777/1680 train_time:67935ms step_avg:87.43ms +step:778/1680 train_time:68024ms step_avg:87.43ms +step:779/1680 train_time:68112ms step_avg:87.44ms +step:780/1680 train_time:68200ms step_avg:87.44ms +step:781/1680 train_time:68287ms step_avg:87.44ms +step:782/1680 train_time:68374ms step_avg:87.44ms +step:783/1680 train_time:68462ms step_avg:87.44ms +step:784/1680 train_time:68550ms step_avg:87.44ms +step:785/1680 train_time:68638ms step_avg:87.44ms +step:786/1680 train_time:68726ms step_avg:87.44ms +step:787/1680 train_time:68815ms step_avg:87.44ms +step:788/1680 train_time:68903ms step_avg:87.44ms +step:789/1680 train_time:68993ms step_avg:87.44ms +step:790/1680 train_time:69081ms step_avg:87.44ms +step:791/1680 train_time:69169ms step_avg:87.45ms +step:792/1680 train_time:69257ms step_avg:87.45ms +step:793/1680 train_time:69345ms step_avg:87.45ms +step:794/1680 train_time:69433ms step_avg:87.45ms +step:795/1680 train_time:69520ms step_avg:87.45ms +step:796/1680 train_time:69609ms step_avg:87.45ms +step:797/1680 train_time:69696ms step_avg:87.45ms +step:798/1680 train_time:69785ms step_avg:87.45ms +step:799/1680 train_time:69873ms step_avg:87.45ms +step:800/1680 train_time:69962ms step_avg:87.45ms +step:801/1680 train_time:70051ms step_avg:87.45ms +step:802/1680 train_time:70139ms step_avg:87.46ms +step:803/1680 train_time:70227ms step_avg:87.46ms +step:804/1680 train_time:70314ms step_avg:87.46ms +step:805/1680 train_time:70402ms step_avg:87.46ms +step:806/1680 train_time:70490ms step_avg:87.46ms +step:807/1680 train_time:70579ms step_avg:87.46ms 
+step:808/1680 train_time:70667ms step_avg:87.46ms +step:809/1680 train_time:70755ms step_avg:87.46ms +step:810/1680 train_time:70843ms step_avg:87.46ms +step:811/1680 train_time:70932ms step_avg:87.46ms +step:812/1680 train_time:71021ms step_avg:87.46ms +step:813/1680 train_time:71110ms step_avg:87.47ms +step:814/1680 train_time:71198ms step_avg:87.47ms +step:815/1680 train_time:71285ms step_avg:87.47ms +step:816/1680 train_time:71373ms step_avg:87.47ms +step:817/1680 train_time:71461ms step_avg:87.47ms +step:818/1680 train_time:71549ms step_avg:87.47ms +step:819/1680 train_time:71637ms step_avg:87.47ms +step:820/1680 train_time:71725ms step_avg:87.47ms +step:821/1680 train_time:71813ms step_avg:87.47ms +step:822/1680 train_time:71901ms step_avg:87.47ms +step:823/1680 train_time:71989ms step_avg:87.47ms +step:824/1680 train_time:72078ms step_avg:87.47ms +step:825/1680 train_time:72166ms step_avg:87.47ms +step:826/1680 train_time:72253ms step_avg:87.47ms +step:827/1680 train_time:72341ms step_avg:87.47ms +step:828/1680 train_time:72429ms step_avg:87.47ms +step:829/1680 train_time:72517ms step_avg:87.48ms +step:830/1680 train_time:72605ms step_avg:87.48ms +step:831/1680 train_time:72694ms step_avg:87.48ms +step:832/1680 train_time:72782ms step_avg:87.48ms +step:833/1680 train_time:72871ms step_avg:87.48ms +step:834/1680 train_time:72959ms step_avg:87.48ms +step:835/1680 train_time:73047ms step_avg:87.48ms +step:836/1680 train_time:73135ms step_avg:87.48ms +step:837/1680 train_time:73223ms step_avg:87.48ms +step:838/1680 train_time:73311ms step_avg:87.48ms +step:839/1680 train_time:73399ms step_avg:87.48ms +step:840/1680 train_time:73486ms step_avg:87.48ms +step:841/1680 train_time:73575ms step_avg:87.48ms +step:842/1680 train_time:73663ms step_avg:87.49ms +step:843/1680 train_time:73751ms step_avg:87.49ms +step:844/1680 train_time:73838ms step_avg:87.49ms +step:845/1680 train_time:73926ms step_avg:87.49ms +step:846/1680 train_time:74014ms step_avg:87.49ms +step:847/1680 train_time:74103ms step_avg:87.49ms +step:848/1680 train_time:74191ms step_avg:87.49ms +step:849/1680 train_time:74279ms step_avg:87.49ms +step:850/1680 train_time:74367ms step_avg:87.49ms +step:851/1680 train_time:74455ms step_avg:87.49ms +step:852/1680 train_time:74543ms step_avg:87.49ms +step:853/1680 train_time:74630ms step_avg:87.49ms +step:854/1680 train_time:74719ms step_avg:87.49ms +step:855/1680 train_time:74807ms step_avg:87.49ms +step:856/1680 train_time:74895ms step_avg:87.49ms +step:857/1680 train_time:74983ms step_avg:87.49ms +step:858/1680 train_time:75072ms step_avg:87.50ms +step:859/1680 train_time:75160ms step_avg:87.50ms +step:860/1680 train_time:75248ms step_avg:87.50ms +step:861/1680 train_time:75336ms step_avg:87.50ms +step:862/1680 train_time:75424ms step_avg:87.50ms +step:863/1680 train_time:75512ms step_avg:87.50ms +step:864/1680 train_time:75600ms step_avg:87.50ms +step:865/1680 train_time:75689ms step_avg:87.50ms +step:866/1680 train_time:75777ms step_avg:87.50ms +step:867/1680 train_time:75865ms step_avg:87.50ms +step:868/1680 train_time:75953ms step_avg:87.50ms +step:869/1680 train_time:76042ms step_avg:87.51ms +step:870/1680 train_time:76130ms step_avg:87.51ms +step:871/1680 train_time:76218ms step_avg:87.51ms +step:872/1680 train_time:76307ms step_avg:87.51ms +step:873/1680 train_time:76394ms step_avg:87.51ms +step:874/1680 train_time:76482ms step_avg:87.51ms +step:875/1680 train_time:76571ms step_avg:87.51ms +step:875/1680 val_loss:3.5182 train_time:76660ms step_avg:87.61ms +step:876/1680 
train_time:76680ms step_avg:87.53ms +step:877/1680 train_time:76752ms step_avg:87.52ms +step:878/1680 train_time:76843ms step_avg:87.52ms +step:879/1680 train_time:76932ms step_avg:87.52ms +step:880/1680 train_time:77020ms step_avg:87.52ms +step:881/1680 train_time:77108ms step_avg:87.52ms +step:882/1680 train_time:77195ms step_avg:87.52ms +step:883/1680 train_time:77282ms step_avg:87.52ms +step:884/1680 train_time:77369ms step_avg:87.52ms +step:885/1680 train_time:77456ms step_avg:87.52ms +step:886/1680 train_time:77543ms step_avg:87.52ms +step:887/1680 train_time:77632ms step_avg:87.52ms +step:888/1680 train_time:77722ms step_avg:87.52ms +step:889/1680 train_time:77811ms step_avg:87.53ms +step:890/1680 train_time:77902ms step_avg:87.53ms +step:891/1680 train_time:77991ms step_avg:87.53ms +step:892/1680 train_time:78079ms step_avg:87.53ms +step:893/1680 train_time:78167ms step_avg:87.53ms +step:894/1680 train_time:78254ms step_avg:87.53ms +step:895/1680 train_time:78341ms step_avg:87.53ms +step:896/1680 train_time:78429ms step_avg:87.53ms +step:897/1680 train_time:78516ms step_avg:87.53ms +step:898/1680 train_time:78605ms step_avg:87.53ms +step:899/1680 train_time:78694ms step_avg:87.53ms +step:900/1680 train_time:78784ms step_avg:87.54ms +step:901/1680 train_time:78873ms step_avg:87.54ms +step:902/1680 train_time:78962ms step_avg:87.54ms +step:903/1680 train_time:79050ms step_avg:87.54ms +step:904/1680 train_time:79137ms step_avg:87.54ms +step:905/1680 train_time:79225ms step_avg:87.54ms +step:906/1680 train_time:79313ms step_avg:87.54ms +step:907/1680 train_time:79400ms step_avg:87.54ms +step:908/1680 train_time:79487ms step_avg:87.54ms +step:909/1680 train_time:79575ms step_avg:87.54ms +step:910/1680 train_time:79664ms step_avg:87.54ms +step:911/1680 train_time:79752ms step_avg:87.54ms +step:912/1680 train_time:79841ms step_avg:87.54ms +step:913/1680 train_time:79929ms step_avg:87.55ms +step:914/1680 train_time:80018ms step_avg:87.55ms +step:915/1680 train_time:80106ms step_avg:87.55ms +step:916/1680 train_time:80194ms step_avg:87.55ms +step:917/1680 train_time:80282ms step_avg:87.55ms +step:918/1680 train_time:80369ms step_avg:87.55ms +step:919/1680 train_time:80457ms step_avg:87.55ms +step:920/1680 train_time:80545ms step_avg:87.55ms +step:921/1680 train_time:80633ms step_avg:87.55ms +step:922/1680 train_time:80721ms step_avg:87.55ms +step:923/1680 train_time:80809ms step_avg:87.55ms +step:924/1680 train_time:80897ms step_avg:87.55ms +step:925/1680 train_time:80985ms step_avg:87.55ms +step:926/1680 train_time:81074ms step_avg:87.55ms +step:927/1680 train_time:81163ms step_avg:87.55ms +step:928/1680 train_time:81250ms step_avg:87.55ms +step:929/1680 train_time:81338ms step_avg:87.55ms +step:930/1680 train_time:81426ms step_avg:87.55ms +step:931/1680 train_time:81515ms step_avg:87.56ms +step:932/1680 train_time:81603ms step_avg:87.56ms +step:933/1680 train_time:81691ms step_avg:87.56ms +step:934/1680 train_time:81780ms step_avg:87.56ms +step:935/1680 train_time:81868ms step_avg:87.56ms +step:936/1680 train_time:81957ms step_avg:87.56ms +step:937/1680 train_time:82045ms step_avg:87.56ms +step:938/1680 train_time:82133ms step_avg:87.56ms +step:939/1680 train_time:82222ms step_avg:87.56ms +step:940/1680 train_time:82309ms step_avg:87.56ms +step:941/1680 train_time:82396ms step_avg:87.56ms +step:942/1680 train_time:82484ms step_avg:87.56ms +step:943/1680 train_time:82572ms step_avg:87.56ms +step:944/1680 train_time:82660ms step_avg:87.56ms +step:945/1680 train_time:82748ms step_avg:87.56ms 
+step:946/1680 train_time:82836ms step_avg:87.56ms +step:947/1680 train_time:82926ms step_avg:87.57ms +step:948/1680 train_time:83014ms step_avg:87.57ms +step:949/1680 train_time:83103ms step_avg:87.57ms +step:950/1680 train_time:83190ms step_avg:87.57ms +step:951/1680 train_time:83278ms step_avg:87.57ms +step:952/1680 train_time:83366ms step_avg:87.57ms +step:953/1680 train_time:83454ms step_avg:87.57ms +step:954/1680 train_time:83541ms step_avg:87.57ms +step:955/1680 train_time:83629ms step_avg:87.57ms +step:956/1680 train_time:83718ms step_avg:87.57ms +step:957/1680 train_time:83807ms step_avg:87.57ms +step:958/1680 train_time:83896ms step_avg:87.57ms +step:959/1680 train_time:83984ms step_avg:87.57ms +step:960/1680 train_time:84073ms step_avg:87.58ms +step:961/1680 train_time:84161ms step_avg:87.58ms +step:962/1680 train_time:84249ms step_avg:87.58ms +step:963/1680 train_time:84337ms step_avg:87.58ms +step:964/1680 train_time:84425ms step_avg:87.58ms +step:965/1680 train_time:84512ms step_avg:87.58ms +step:966/1680 train_time:84600ms step_avg:87.58ms +step:967/1680 train_time:84688ms step_avg:87.58ms +step:968/1680 train_time:84777ms step_avg:87.58ms +step:969/1680 train_time:84866ms step_avg:87.58ms +step:970/1680 train_time:84954ms step_avg:87.58ms +step:971/1680 train_time:85043ms step_avg:87.58ms +step:972/1680 train_time:85131ms step_avg:87.58ms +step:973/1680 train_time:85219ms step_avg:87.58ms +step:974/1680 train_time:85306ms step_avg:87.58ms +step:975/1680 train_time:85395ms step_avg:87.58ms +step:976/1680 train_time:85482ms step_avg:87.58ms +step:977/1680 train_time:85570ms step_avg:87.58ms +step:978/1680 train_time:85658ms step_avg:87.59ms +step:979/1680 train_time:85747ms step_avg:87.59ms +step:980/1680 train_time:85835ms step_avg:87.59ms +step:981/1680 train_time:85924ms step_avg:87.59ms +step:982/1680 train_time:86012ms step_avg:87.59ms +step:983/1680 train_time:86100ms step_avg:87.59ms +step:984/1680 train_time:86187ms step_avg:87.59ms +step:985/1680 train_time:86275ms step_avg:87.59ms +step:986/1680 train_time:86364ms step_avg:87.59ms +step:987/1680 train_time:86451ms step_avg:87.59ms +step:988/1680 train_time:86539ms step_avg:87.59ms +step:989/1680 train_time:86627ms step_avg:87.59ms +step:990/1680 train_time:86715ms step_avg:87.59ms +step:991/1680 train_time:86803ms step_avg:87.59ms +step:992/1680 train_time:86892ms step_avg:87.59ms +step:993/1680 train_time:86981ms step_avg:87.59ms +step:994/1680 train_time:87069ms step_avg:87.59ms +step:995/1680 train_time:87157ms step_avg:87.60ms +step:996/1680 train_time:87246ms step_avg:87.60ms +step:997/1680 train_time:87334ms step_avg:87.60ms +step:998/1680 train_time:87422ms step_avg:87.60ms +step:999/1680 train_time:87509ms step_avg:87.60ms +step:1000/1680 train_time:87596ms step_avg:87.60ms +step:1000/1680 val_loss:3.4693 train_time:87686ms step_avg:87.69ms +step:1001/1680 train_time:87705ms step_avg:87.62ms +step:1002/1680 train_time:87778ms step_avg:87.60ms +step:1003/1680 train_time:87873ms step_avg:87.61ms +step:1004/1680 train_time:87963ms step_avg:87.61ms +step:1005/1680 train_time:88052ms step_avg:87.61ms +step:1006/1680 train_time:88139ms step_avg:87.61ms +step:1007/1680 train_time:88226ms step_avg:87.61ms +step:1008/1680 train_time:88313ms step_avg:87.61ms +step:1009/1680 train_time:88400ms step_avg:87.61ms +step:1010/1680 train_time:88487ms step_avg:87.61ms +step:1011/1680 train_time:88575ms step_avg:87.61ms +step:1012/1680 train_time:88663ms step_avg:87.61ms +step:1013/1680 train_time:88752ms step_avg:87.61ms 
+step:1014/1680 train_time:88843ms step_avg:87.62ms +step:1015/1680 train_time:88934ms step_avg:87.62ms +step:1016/1680 train_time:89023ms step_avg:87.62ms +step:1017/1680 train_time:89111ms step_avg:87.62ms +step:1018/1680 train_time:89199ms step_avg:87.62ms +step:1019/1680 train_time:89287ms step_avg:87.62ms +step:1020/1680 train_time:89374ms step_avg:87.62ms +step:1021/1680 train_time:89461ms step_avg:87.62ms +step:1022/1680 train_time:89548ms step_avg:87.62ms +step:1023/1680 train_time:89636ms step_avg:87.62ms +step:1024/1680 train_time:89724ms step_avg:87.62ms +step:1025/1680 train_time:89813ms step_avg:87.62ms +step:1026/1680 train_time:89902ms step_avg:87.62ms +step:1027/1680 train_time:89991ms step_avg:87.63ms +step:1028/1680 train_time:90080ms step_avg:87.63ms +step:1029/1680 train_time:90168ms step_avg:87.63ms +step:1030/1680 train_time:90256ms step_avg:87.63ms +step:1031/1680 train_time:90343ms step_avg:87.63ms +step:1032/1680 train_time:90430ms step_avg:87.63ms +step:1033/1680 train_time:90518ms step_avg:87.63ms +step:1034/1680 train_time:90606ms step_avg:87.63ms +step:1035/1680 train_time:90693ms step_avg:87.63ms +step:1036/1680 train_time:90782ms step_avg:87.63ms +step:1037/1680 train_time:90871ms step_avg:87.63ms +step:1038/1680 train_time:90961ms step_avg:87.63ms +step:1039/1680 train_time:91049ms step_avg:87.63ms +step:1040/1680 train_time:91138ms step_avg:87.63ms +step:1041/1680 train_time:91226ms step_avg:87.63ms +step:1042/1680 train_time:91313ms step_avg:87.63ms +step:1043/1680 train_time:91402ms step_avg:87.63ms +step:1044/1680 train_time:91490ms step_avg:87.63ms +step:1045/1680 train_time:91577ms step_avg:87.63ms +step:1046/1680 train_time:91665ms step_avg:87.63ms +step:1047/1680 train_time:91753ms step_avg:87.63ms +step:1048/1680 train_time:91842ms step_avg:87.64ms +step:1049/1680 train_time:91930ms step_avg:87.64ms +step:1050/1680 train_time:92019ms step_avg:87.64ms +step:1051/1680 train_time:92108ms step_avg:87.64ms +step:1052/1680 train_time:92196ms step_avg:87.64ms +step:1053/1680 train_time:92284ms step_avg:87.64ms +step:1054/1680 train_time:92371ms step_avg:87.64ms +step:1055/1680 train_time:92459ms step_avg:87.64ms +step:1056/1680 train_time:92546ms step_avg:87.64ms +step:1057/1680 train_time:92634ms step_avg:87.64ms +step:1058/1680 train_time:92721ms step_avg:87.64ms +step:1059/1680 train_time:92810ms step_avg:87.64ms +step:1060/1680 train_time:92899ms step_avg:87.64ms +step:1061/1680 train_time:92987ms step_avg:87.64ms +step:1062/1680 train_time:93076ms step_avg:87.64ms +step:1063/1680 train_time:93164ms step_avg:87.64ms +step:1064/1680 train_time:93252ms step_avg:87.64ms +step:1065/1680 train_time:93341ms step_avg:87.64ms +step:1066/1680 train_time:93428ms step_avg:87.64ms +step:1067/1680 train_time:93516ms step_avg:87.64ms +step:1068/1680 train_time:93604ms step_avg:87.64ms +step:1069/1680 train_time:93693ms step_avg:87.65ms +step:1070/1680 train_time:93781ms step_avg:87.65ms +step:1071/1680 train_time:93870ms step_avg:87.65ms +step:1072/1680 train_time:93959ms step_avg:87.65ms +step:1073/1680 train_time:94047ms step_avg:87.65ms +step:1074/1680 train_time:94136ms step_avg:87.65ms +step:1075/1680 train_time:94224ms step_avg:87.65ms +step:1076/1680 train_time:94312ms step_avg:87.65ms +step:1077/1680 train_time:94401ms step_avg:87.65ms +step:1078/1680 train_time:94489ms step_avg:87.65ms +step:1079/1680 train_time:94577ms step_avg:87.65ms +step:1080/1680 train_time:94666ms step_avg:87.65ms +step:1081/1680 train_time:94753ms step_avg:87.65ms +step:1082/1680 
train_time:94841ms step_avg:87.65ms +step:1083/1680 train_time:94929ms step_avg:87.65ms +step:1084/1680 train_time:95018ms step_avg:87.65ms +step:1085/1680 train_time:95106ms step_avg:87.66ms +step:1086/1680 train_time:95194ms step_avg:87.66ms +step:1087/1680 train_time:95282ms step_avg:87.66ms +step:1088/1680 train_time:95371ms step_avg:87.66ms +step:1089/1680 train_time:95459ms step_avg:87.66ms +step:1090/1680 train_time:95548ms step_avg:87.66ms +step:1091/1680 train_time:95635ms step_avg:87.66ms +step:1092/1680 train_time:95723ms step_avg:87.66ms +step:1093/1680 train_time:95811ms step_avg:87.66ms +step:1094/1680 train_time:95899ms step_avg:87.66ms +step:1095/1680 train_time:95988ms step_avg:87.66ms +step:1096/1680 train_time:96077ms step_avg:87.66ms +step:1097/1680 train_time:96165ms step_avg:87.66ms +step:1098/1680 train_time:96253ms step_avg:87.66ms +step:1099/1680 train_time:96342ms step_avg:87.66ms +step:1100/1680 train_time:96431ms step_avg:87.66ms +step:1101/1680 train_time:96520ms step_avg:87.67ms +step:1102/1680 train_time:96608ms step_avg:87.67ms +step:1103/1680 train_time:96697ms step_avg:87.67ms +step:1104/1680 train_time:96785ms step_avg:87.67ms +step:1105/1680 train_time:96875ms step_avg:87.67ms +step:1106/1680 train_time:96963ms step_avg:87.67ms +step:1107/1680 train_time:97052ms step_avg:87.67ms +step:1108/1680 train_time:97141ms step_avg:87.67ms +step:1109/1680 train_time:97230ms step_avg:87.67ms +step:1110/1680 train_time:97320ms step_avg:87.68ms +step:1111/1680 train_time:97409ms step_avg:87.68ms +step:1112/1680 train_time:97498ms step_avg:87.68ms +step:1113/1680 train_time:97587ms step_avg:87.68ms +step:1114/1680 train_time:97676ms step_avg:87.68ms +step:1115/1680 train_time:97765ms step_avg:87.68ms +step:1116/1680 train_time:97853ms step_avg:87.68ms +step:1117/1680 train_time:97942ms step_avg:87.68ms +step:1118/1680 train_time:98030ms step_avg:87.68ms +step:1119/1680 train_time:98118ms step_avg:87.68ms +step:1120/1680 train_time:98207ms step_avg:87.69ms +step:1121/1680 train_time:98296ms step_avg:87.69ms +step:1122/1680 train_time:98386ms step_avg:87.69ms +step:1123/1680 train_time:98475ms step_avg:87.69ms +step:1124/1680 train_time:98564ms step_avg:87.69ms +step:1125/1680 train_time:98652ms step_avg:87.69ms +step:1125/1680 val_loss:3.4159 train_time:98743ms step_avg:87.77ms +step:1126/1680 train_time:98762ms step_avg:87.71ms +step:1127/1680 train_time:98833ms step_avg:87.70ms +step:1128/1680 train_time:98925ms step_avg:87.70ms +step:1129/1680 train_time:99019ms step_avg:87.71ms +step:1130/1680 train_time:99108ms step_avg:87.71ms +step:1131/1680 train_time:99197ms step_avg:87.71ms +step:1132/1680 train_time:99284ms step_avg:87.71ms +step:1133/1680 train_time:99373ms step_avg:87.71ms +step:1134/1680 train_time:99460ms step_avg:87.71ms +step:1135/1680 train_time:99549ms step_avg:87.71ms +step:1136/1680 train_time:99637ms step_avg:87.71ms +step:1137/1680 train_time:99726ms step_avg:87.71ms +step:1138/1680 train_time:99816ms step_avg:87.71ms +step:1139/1680 train_time:99906ms step_avg:87.71ms +step:1140/1680 train_time:99996ms step_avg:87.72ms +step:1141/1680 train_time:100086ms step_avg:87.72ms +step:1142/1680 train_time:100175ms step_avg:87.72ms +step:1143/1680 train_time:100264ms step_avg:87.72ms +step:1144/1680 train_time:100352ms step_avg:87.72ms +step:1145/1680 train_time:100440ms step_avg:87.72ms +step:1146/1680 train_time:100528ms step_avg:87.72ms +step:1147/1680 train_time:100616ms step_avg:87.72ms +step:1148/1680 train_time:100705ms step_avg:87.72ms 
+step:1149/1680 train_time:100794ms step_avg:87.72ms +step:1150/1680 train_time:100883ms step_avg:87.72ms +step:1151/1680 train_time:100973ms step_avg:87.73ms +step:1152/1680 train_time:101063ms step_avg:87.73ms +step:1153/1680 train_time:101152ms step_avg:87.73ms +step:1154/1680 train_time:101242ms step_avg:87.73ms +step:1155/1680 train_time:101331ms step_avg:87.73ms +step:1156/1680 train_time:101419ms step_avg:87.73ms +step:1157/1680 train_time:101508ms step_avg:87.73ms +step:1158/1680 train_time:101597ms step_avg:87.73ms +step:1159/1680 train_time:101686ms step_avg:87.74ms +step:1160/1680 train_time:101776ms step_avg:87.74ms +step:1161/1680 train_time:101866ms step_avg:87.74ms +step:1162/1680 train_time:101955ms step_avg:87.74ms +step:1163/1680 train_time:102045ms step_avg:87.74ms +step:1164/1680 train_time:102134ms step_avg:87.74ms +step:1165/1680 train_time:102222ms step_avg:87.74ms +step:1166/1680 train_time:102311ms step_avg:87.75ms +step:1167/1680 train_time:102399ms step_avg:87.75ms +step:1168/1680 train_time:102487ms step_avg:87.75ms +step:1169/1680 train_time:102576ms step_avg:87.75ms +step:1170/1680 train_time:102664ms step_avg:87.75ms +step:1171/1680 train_time:102753ms step_avg:87.75ms +step:1172/1680 train_time:102842ms step_avg:87.75ms +step:1173/1680 train_time:102932ms step_avg:87.75ms +step:1174/1680 train_time:103020ms step_avg:87.75ms +step:1175/1680 train_time:103110ms step_avg:87.75ms +step:1176/1680 train_time:103198ms step_avg:87.75ms +step:1177/1680 train_time:103288ms step_avg:87.75ms +step:1178/1680 train_time:103376ms step_avg:87.76ms +step:1179/1680 train_time:103464ms step_avg:87.76ms +step:1180/1680 train_time:103554ms step_avg:87.76ms +step:1181/1680 train_time:103642ms step_avg:87.76ms +step:1182/1680 train_time:103730ms step_avg:87.76ms +step:1183/1680 train_time:103819ms step_avg:87.76ms +step:1184/1680 train_time:103908ms step_avg:87.76ms +step:1185/1680 train_time:103997ms step_avg:87.76ms +step:1186/1680 train_time:104087ms step_avg:87.76ms +step:1187/1680 train_time:104175ms step_avg:87.76ms +step:1188/1680 train_time:104265ms step_avg:87.77ms +step:1189/1680 train_time:104354ms step_avg:87.77ms +step:1190/1680 train_time:104443ms step_avg:87.77ms +step:1191/1680 train_time:104533ms step_avg:87.77ms +step:1192/1680 train_time:104622ms step_avg:87.77ms +step:1193/1680 train_time:104711ms step_avg:87.77ms +step:1194/1680 train_time:104799ms step_avg:87.77ms +step:1195/1680 train_time:104888ms step_avg:87.77ms +step:1196/1680 train_time:104977ms step_avg:87.77ms +step:1197/1680 train_time:105066ms step_avg:87.77ms +step:1198/1680 train_time:105155ms step_avg:87.78ms +step:1199/1680 train_time:105245ms step_avg:87.78ms +step:1200/1680 train_time:105334ms step_avg:87.78ms +step:1201/1680 train_time:105423ms step_avg:87.78ms +step:1202/1680 train_time:105511ms step_avg:87.78ms +step:1203/1680 train_time:105600ms step_avg:87.78ms +step:1204/1680 train_time:105689ms step_avg:87.78ms +step:1205/1680 train_time:105777ms step_avg:87.78ms +step:1206/1680 train_time:105866ms step_avg:87.78ms +step:1207/1680 train_time:105955ms step_avg:87.78ms +step:1208/1680 train_time:106044ms step_avg:87.78ms +step:1209/1680 train_time:106132ms step_avg:87.78ms +step:1210/1680 train_time:106222ms step_avg:87.79ms +step:1211/1680 train_time:106311ms step_avg:87.79ms +step:1212/1680 train_time:106400ms step_avg:87.79ms +step:1213/1680 train_time:106489ms step_avg:87.79ms +step:1214/1680 train_time:106579ms step_avg:87.79ms +step:1215/1680 train_time:106668ms step_avg:87.79ms 
+step:1216/1680 train_time:106757ms step_avg:87.79ms +step:1217/1680 train_time:106846ms step_avg:87.79ms +step:1218/1680 train_time:106936ms step_avg:87.80ms +step:1219/1680 train_time:107025ms step_avg:87.80ms +step:1220/1680 train_time:107114ms step_avg:87.80ms +step:1221/1680 train_time:107204ms step_avg:87.80ms +step:1222/1680 train_time:107292ms step_avg:87.80ms +step:1223/1680 train_time:107381ms step_avg:87.80ms +step:1224/1680 train_time:107469ms step_avg:87.80ms +step:1225/1680 train_time:107558ms step_avg:87.80ms +step:1226/1680 train_time:107648ms step_avg:87.80ms +step:1227/1680 train_time:107736ms step_avg:87.80ms +step:1228/1680 train_time:107826ms step_avg:87.81ms +step:1229/1680 train_time:107915ms step_avg:87.81ms +step:1230/1680 train_time:108004ms step_avg:87.81ms +step:1231/1680 train_time:108093ms step_avg:87.81ms +step:1232/1680 train_time:108182ms step_avg:87.81ms +step:1233/1680 train_time:108271ms step_avg:87.81ms +step:1234/1680 train_time:108360ms step_avg:87.81ms +step:1235/1680 train_time:108449ms step_avg:87.81ms +step:1236/1680 train_time:108539ms step_avg:87.81ms +step:1237/1680 train_time:108628ms step_avg:87.82ms +step:1238/1680 train_time:108717ms step_avg:87.82ms +step:1239/1680 train_time:108806ms step_avg:87.82ms +step:1240/1680 train_time:108895ms step_avg:87.82ms +step:1241/1680 train_time:108984ms step_avg:87.82ms +step:1242/1680 train_time:109073ms step_avg:87.82ms +step:1243/1680 train_time:109161ms step_avg:87.82ms +step:1244/1680 train_time:109250ms step_avg:87.82ms +step:1245/1680 train_time:109339ms step_avg:87.82ms +step:1246/1680 train_time:109429ms step_avg:87.82ms +step:1247/1680 train_time:109518ms step_avg:87.82ms +step:1248/1680 train_time:109607ms step_avg:87.83ms +step:1249/1680 train_time:109696ms step_avg:87.83ms +step:1250/1680 train_time:109784ms step_avg:87.83ms +step:1250/1680 val_loss:3.3770 train_time:109875ms step_avg:87.90ms +step:1251/1680 train_time:109894ms step_avg:87.84ms +step:1252/1680 train_time:109968ms step_avg:87.83ms +step:1253/1680 train_time:110059ms step_avg:87.84ms +step:1254/1680 train_time:110148ms step_avg:87.84ms +step:1255/1680 train_time:110236ms step_avg:87.84ms +step:1256/1680 train_time:110324ms step_avg:87.84ms +step:1257/1680 train_time:110412ms step_avg:87.84ms +step:1258/1680 train_time:110500ms step_avg:87.84ms +step:1259/1680 train_time:110588ms step_avg:87.84ms +step:1260/1680 train_time:110676ms step_avg:87.84ms +step:1261/1680 train_time:110765ms step_avg:87.84ms +step:1262/1680 train_time:110855ms step_avg:87.84ms +step:1263/1680 train_time:110946ms step_avg:87.84ms +step:1264/1680 train_time:111036ms step_avg:87.85ms +step:1265/1680 train_time:111126ms step_avg:87.85ms +step:1266/1680 train_time:111214ms step_avg:87.85ms +step:1267/1680 train_time:111303ms step_avg:87.85ms +step:1268/1680 train_time:111391ms step_avg:87.85ms +step:1269/1680 train_time:111479ms step_avg:87.85ms +step:1270/1680 train_time:111568ms step_avg:87.85ms +step:1271/1680 train_time:111655ms step_avg:87.85ms +step:1272/1680 train_time:111744ms step_avg:87.85ms +step:1273/1680 train_time:111833ms step_avg:87.85ms +step:1274/1680 train_time:111923ms step_avg:87.85ms +step:1275/1680 train_time:112013ms step_avg:87.85ms +step:1276/1680 train_time:112103ms step_avg:87.85ms +step:1277/1680 train_time:112192ms step_avg:87.86ms +step:1278/1680 train_time:112280ms step_avg:87.86ms +step:1279/1680 train_time:112369ms step_avg:87.86ms +step:1280/1680 train_time:112457ms step_avg:87.86ms +step:1281/1680 train_time:112545ms 
step_avg:87.86ms +step:1282/1680 train_time:112633ms step_avg:87.86ms +step:1283/1680 train_time:112722ms step_avg:87.86ms +step:1284/1680 train_time:112811ms step_avg:87.86ms +step:1285/1680 train_time:112901ms step_avg:87.86ms +step:1286/1680 train_time:112991ms step_avg:87.86ms +step:1287/1680 train_time:113081ms step_avg:87.86ms +step:1288/1680 train_time:113170ms step_avg:87.86ms +step:1289/1680 train_time:113259ms step_avg:87.87ms +step:1290/1680 train_time:113348ms step_avg:87.87ms +step:1291/1680 train_time:113438ms step_avg:87.87ms +step:1292/1680 train_time:113525ms step_avg:87.87ms +step:1293/1680 train_time:113613ms step_avg:87.87ms +step:1294/1680 train_time:113703ms step_avg:87.87ms +step:1295/1680 train_time:113791ms step_avg:87.87ms +step:1296/1680 train_time:113881ms step_avg:87.87ms +step:1297/1680 train_time:113970ms step_avg:87.87ms +step:1298/1680 train_time:114059ms step_avg:87.87ms +step:1299/1680 train_time:114148ms step_avg:87.87ms +step:1300/1680 train_time:114237ms step_avg:87.87ms +step:1301/1680 train_time:114326ms step_avg:87.88ms +step:1302/1680 train_time:114415ms step_avg:87.88ms +step:1303/1680 train_time:114503ms step_avg:87.88ms +step:1304/1680 train_time:114592ms step_avg:87.88ms +step:1305/1680 train_time:114682ms step_avg:87.88ms +step:1306/1680 train_time:114771ms step_avg:87.88ms +step:1307/1680 train_time:114859ms step_avg:87.88ms +step:1308/1680 train_time:114949ms step_avg:87.88ms +step:1309/1680 train_time:115038ms step_avg:87.88ms +step:1310/1680 train_time:115127ms step_avg:87.88ms +step:1311/1680 train_time:115217ms step_avg:87.89ms +step:1312/1680 train_time:115307ms step_avg:87.89ms +step:1313/1680 train_time:115397ms step_avg:87.89ms +step:1314/1680 train_time:115485ms step_avg:87.89ms +step:1315/1680 train_time:115574ms step_avg:87.89ms +step:1316/1680 train_time:115664ms step_avg:87.89ms +step:1317/1680 train_time:115752ms step_avg:87.89ms +step:1318/1680 train_time:115841ms step_avg:87.89ms +step:1319/1680 train_time:115930ms step_avg:87.89ms +step:1320/1680 train_time:116018ms step_avg:87.89ms +step:1321/1680 train_time:116108ms step_avg:87.89ms +step:1322/1680 train_time:116198ms step_avg:87.90ms +step:1323/1680 train_time:116288ms step_avg:87.90ms +step:1324/1680 train_time:116377ms step_avg:87.90ms +step:1325/1680 train_time:116465ms step_avg:87.90ms +step:1326/1680 train_time:116554ms step_avg:87.90ms +step:1327/1680 train_time:116642ms step_avg:87.90ms +step:1328/1680 train_time:116731ms step_avg:87.90ms +step:1329/1680 train_time:116820ms step_avg:87.90ms +step:1330/1680 train_time:116910ms step_avg:87.90ms +step:1331/1680 train_time:116999ms step_avg:87.90ms +step:1332/1680 train_time:117089ms step_avg:87.90ms +step:1333/1680 train_time:117178ms step_avg:87.91ms +step:1334/1680 train_time:117267ms step_avg:87.91ms +step:1335/1680 train_time:117356ms step_avg:87.91ms +step:1336/1680 train_time:117445ms step_avg:87.91ms +step:1337/1680 train_time:117534ms step_avg:87.91ms +step:1338/1680 train_time:117622ms step_avg:87.91ms +step:1339/1680 train_time:117711ms step_avg:87.91ms +step:1340/1680 train_time:117800ms step_avg:87.91ms +step:1341/1680 train_time:117889ms step_avg:87.91ms +step:1342/1680 train_time:117978ms step_avg:87.91ms +step:1343/1680 train_time:118069ms step_avg:87.91ms +step:1344/1680 train_time:118157ms step_avg:87.91ms +step:1345/1680 train_time:118246ms step_avg:87.92ms +step:1346/1680 train_time:118335ms step_avg:87.92ms +step:1347/1680 train_time:118424ms step_avg:87.92ms +step:1348/1680 train_time:118513ms 
step_avg:87.92ms +step:1349/1680 train_time:118603ms step_avg:87.92ms +step:1350/1680 train_time:118691ms step_avg:87.92ms +step:1351/1680 train_time:118780ms step_avg:87.92ms +step:1352/1680 train_time:118869ms step_avg:87.92ms +step:1353/1680 train_time:118957ms step_avg:87.92ms +step:1354/1680 train_time:119047ms step_avg:87.92ms +step:1355/1680 train_time:119136ms step_avg:87.92ms +step:1356/1680 train_time:119225ms step_avg:87.92ms +step:1357/1680 train_time:119313ms step_avg:87.92ms +step:1358/1680 train_time:119402ms step_avg:87.93ms +step:1359/1680 train_time:119491ms step_avg:87.93ms +step:1360/1680 train_time:119581ms step_avg:87.93ms +step:1361/1680 train_time:119669ms step_avg:87.93ms +step:1362/1680 train_time:119758ms step_avg:87.93ms +step:1363/1680 train_time:119847ms step_avg:87.93ms +step:1364/1680 train_time:119935ms step_avg:87.93ms +step:1365/1680 train_time:120025ms step_avg:87.93ms +step:1366/1680 train_time:120113ms step_avg:87.93ms +step:1367/1680 train_time:120202ms step_avg:87.93ms +step:1368/1680 train_time:120292ms step_avg:87.93ms +step:1369/1680 train_time:120381ms step_avg:87.93ms +step:1370/1680 train_time:120470ms step_avg:87.93ms +step:1371/1680 train_time:120559ms step_avg:87.94ms +step:1372/1680 train_time:120647ms step_avg:87.94ms +step:1373/1680 train_time:120736ms step_avg:87.94ms +step:1374/1680 train_time:120825ms step_avg:87.94ms +step:1375/1680 train_time:120914ms step_avg:87.94ms +step:1375/1680 val_loss:3.3426 train_time:121005ms step_avg:88.00ms +step:1376/1680 train_time:121023ms step_avg:87.95ms +step:1377/1680 train_time:121096ms step_avg:87.94ms +step:1378/1680 train_time:121191ms step_avg:87.95ms +step:1379/1680 train_time:121280ms step_avg:87.95ms +step:1380/1680 train_time:121368ms step_avg:87.95ms +step:1381/1680 train_time:121456ms step_avg:87.95ms +step:1382/1680 train_time:121544ms step_avg:87.95ms +step:1383/1680 train_time:121632ms step_avg:87.95ms +step:1384/1680 train_time:121720ms step_avg:87.95ms +step:1385/1680 train_time:121808ms step_avg:87.95ms +step:1386/1680 train_time:121895ms step_avg:87.95ms +step:1387/1680 train_time:121985ms step_avg:87.95ms +step:1388/1680 train_time:122076ms step_avg:87.95ms +step:1389/1680 train_time:122167ms step_avg:87.95ms +step:1390/1680 train_time:122257ms step_avg:87.95ms +step:1391/1680 train_time:122346ms step_avg:87.96ms +step:1392/1680 train_time:122435ms step_avg:87.96ms +step:1393/1680 train_time:122524ms step_avg:87.96ms +step:1394/1680 train_time:122612ms step_avg:87.96ms +step:1395/1680 train_time:122700ms step_avg:87.96ms +step:1396/1680 train_time:122788ms step_avg:87.96ms +step:1397/1680 train_time:122876ms step_avg:87.96ms +step:1398/1680 train_time:122964ms step_avg:87.96ms +step:1399/1680 train_time:123054ms step_avg:87.96ms +step:1400/1680 train_time:123144ms step_avg:87.96ms +step:1401/1680 train_time:123233ms step_avg:87.96ms +step:1402/1680 train_time:123323ms step_avg:87.96ms +step:1403/1680 train_time:123412ms step_avg:87.96ms +step:1404/1680 train_time:123500ms step_avg:87.96ms +step:1405/1680 train_time:123589ms step_avg:87.96ms +step:1406/1680 train_time:123677ms step_avg:87.96ms +step:1407/1680 train_time:123765ms step_avg:87.96ms +step:1408/1680 train_time:123854ms step_avg:87.96ms +step:1409/1680 train_time:123943ms step_avg:87.97ms +step:1410/1680 train_time:124032ms step_avg:87.97ms +step:1411/1680 train_time:124121ms step_avg:87.97ms +step:1412/1680 train_time:124211ms step_avg:87.97ms +step:1413/1680 train_time:124302ms step_avg:87.97ms +step:1414/1680 
train_time:124391ms step_avg:87.97ms +step:1415/1680 train_time:124481ms step_avg:87.97ms +step:1416/1680 train_time:124570ms step_avg:87.97ms +step:1417/1680 train_time:124658ms step_avg:87.97ms +step:1418/1680 train_time:124747ms step_avg:87.97ms +step:1419/1680 train_time:124835ms step_avg:87.97ms +step:1420/1680 train_time:124925ms step_avg:87.98ms +step:1421/1680 train_time:125014ms step_avg:87.98ms +step:1422/1680 train_time:125103ms step_avg:87.98ms +step:1423/1680 train_time:125193ms step_avg:87.98ms +step:1424/1680 train_time:125282ms step_avg:87.98ms +step:1425/1680 train_time:125372ms step_avg:87.98ms +step:1426/1680 train_time:125461ms step_avg:87.98ms +step:1427/1680 train_time:125549ms step_avg:87.98ms +step:1428/1680 train_time:125638ms step_avg:87.98ms +step:1429/1680 train_time:125727ms step_avg:87.98ms +step:1430/1680 train_time:125816ms step_avg:87.98ms +step:1431/1680 train_time:125906ms step_avg:87.98ms +step:1432/1680 train_time:125994ms step_avg:87.98ms +step:1433/1680 train_time:126083ms step_avg:87.99ms +step:1434/1680 train_time:126172ms step_avg:87.99ms +step:1435/1680 train_time:126261ms step_avg:87.99ms +step:1436/1680 train_time:126351ms step_avg:87.99ms +step:1437/1680 train_time:126441ms step_avg:87.99ms +step:1438/1680 train_time:126530ms step_avg:87.99ms +step:1439/1680 train_time:126619ms step_avg:87.99ms +step:1440/1680 train_time:126709ms step_avg:87.99ms +step:1441/1680 train_time:126797ms step_avg:87.99ms +step:1442/1680 train_time:126886ms step_avg:87.99ms +step:1443/1680 train_time:126974ms step_avg:87.99ms +step:1444/1680 train_time:127063ms step_avg:87.99ms +step:1445/1680 train_time:127152ms step_avg:87.99ms +step:1446/1680 train_time:127241ms step_avg:87.99ms +step:1447/1680 train_time:127331ms step_avg:88.00ms +step:1448/1680 train_time:127420ms step_avg:88.00ms +step:1449/1680 train_time:127510ms step_avg:88.00ms +step:1450/1680 train_time:127599ms step_avg:88.00ms +step:1451/1680 train_time:127689ms step_avg:88.00ms +step:1452/1680 train_time:127778ms step_avg:88.00ms +step:1453/1680 train_time:127867ms step_avg:88.00ms +step:1454/1680 train_time:127955ms step_avg:88.00ms +step:1455/1680 train_time:128044ms step_avg:88.00ms +step:1456/1680 train_time:128133ms step_avg:88.00ms +step:1457/1680 train_time:128222ms step_avg:88.00ms +step:1458/1680 train_time:128311ms step_avg:88.00ms +step:1459/1680 train_time:128400ms step_avg:88.01ms +step:1460/1680 train_time:128489ms step_avg:88.01ms +step:1461/1680 train_time:128577ms step_avg:88.01ms +step:1462/1680 train_time:128667ms step_avg:88.01ms +step:1463/1680 train_time:128755ms step_avg:88.01ms +step:1464/1680 train_time:128844ms step_avg:88.01ms +step:1465/1680 train_time:128933ms step_avg:88.01ms +step:1466/1680 train_time:129022ms step_avg:88.01ms +step:1467/1680 train_time:129111ms step_avg:88.01ms +step:1468/1680 train_time:129200ms step_avg:88.01ms +step:1469/1680 train_time:129290ms step_avg:88.01ms +step:1470/1680 train_time:129378ms step_avg:88.01ms +step:1471/1680 train_time:129467ms step_avg:88.01ms +step:1472/1680 train_time:129556ms step_avg:88.01ms +step:1473/1680 train_time:129645ms step_avg:88.01ms +step:1474/1680 train_time:129734ms step_avg:88.02ms +step:1475/1680 train_time:129824ms step_avg:88.02ms +step:1476/1680 train_time:129913ms step_avg:88.02ms +step:1477/1680 train_time:130002ms step_avg:88.02ms +step:1478/1680 train_time:130091ms step_avg:88.02ms +step:1479/1680 train_time:130180ms step_avg:88.02ms +step:1480/1680 train_time:130269ms step_avg:88.02ms +step:1481/1680 
train_time:130359ms step_avg:88.02ms +step:1482/1680 train_time:130448ms step_avg:88.02ms +step:1483/1680 train_time:130536ms step_avg:88.02ms +step:1484/1680 train_time:130626ms step_avg:88.02ms +step:1485/1680 train_time:130715ms step_avg:88.02ms +step:1486/1680 train_time:130804ms step_avg:88.02ms +step:1487/1680 train_time:130892ms step_avg:88.02ms +step:1488/1680 train_time:130981ms step_avg:88.03ms +step:1489/1680 train_time:131070ms step_avg:88.03ms +step:1490/1680 train_time:131159ms step_avg:88.03ms +step:1491/1680 train_time:131249ms step_avg:88.03ms +step:1492/1680 train_time:131338ms step_avg:88.03ms +step:1493/1680 train_time:131427ms step_avg:88.03ms +step:1494/1680 train_time:131515ms step_avg:88.03ms +step:1495/1680 train_time:131604ms step_avg:88.03ms +step:1496/1680 train_time:131693ms step_avg:88.03ms +step:1497/1680 train_time:131782ms step_avg:88.03ms +step:1498/1680 train_time:131871ms step_avg:88.03ms +step:1499/1680 train_time:131961ms step_avg:88.03ms +step:1500/1680 train_time:132051ms step_avg:88.03ms +step:1500/1680 val_loss:3.3129 train_time:132142ms step_avg:88.09ms +step:1501/1680 train_time:132160ms step_avg:88.05ms +step:1502/1680 train_time:132234ms step_avg:88.04ms +step:1503/1680 train_time:132327ms step_avg:88.04ms +step:1504/1680 train_time:132417ms step_avg:88.04ms +step:1505/1680 train_time:132505ms step_avg:88.04ms +step:1506/1680 train_time:132593ms step_avg:88.04ms +step:1507/1680 train_time:132681ms step_avg:88.04ms +step:1508/1680 train_time:132768ms step_avg:88.04ms +step:1509/1680 train_time:132856ms step_avg:88.04ms +step:1510/1680 train_time:132945ms step_avg:88.04ms +step:1511/1680 train_time:133033ms step_avg:88.04ms +step:1512/1680 train_time:133124ms step_avg:88.05ms +step:1513/1680 train_time:133215ms step_avg:88.05ms +step:1514/1680 train_time:133306ms step_avg:88.05ms +step:1515/1680 train_time:133396ms step_avg:88.05ms +step:1516/1680 train_time:133485ms step_avg:88.05ms +step:1517/1680 train_time:133573ms step_avg:88.05ms +step:1518/1680 train_time:133663ms step_avg:88.05ms +step:1519/1680 train_time:133751ms step_avg:88.05ms +step:1520/1680 train_time:133838ms step_avg:88.05ms +step:1521/1680 train_time:133927ms step_avg:88.05ms +step:1522/1680 train_time:134015ms step_avg:88.05ms +step:1523/1680 train_time:134106ms step_avg:88.05ms +step:1524/1680 train_time:134195ms step_avg:88.05ms +step:1525/1680 train_time:134285ms step_avg:88.06ms +step:1526/1680 train_time:134375ms step_avg:88.06ms +step:1527/1680 train_time:134464ms step_avg:88.06ms +step:1528/1680 train_time:134553ms step_avg:88.06ms +step:1529/1680 train_time:134642ms step_avg:88.06ms +step:1530/1680 train_time:134730ms step_avg:88.06ms +step:1531/1680 train_time:134818ms step_avg:88.06ms +step:1532/1680 train_time:134906ms step_avg:88.06ms +step:1533/1680 train_time:134995ms step_avg:88.06ms +step:1534/1680 train_time:135085ms step_avg:88.06ms +step:1535/1680 train_time:135174ms step_avg:88.06ms +step:1536/1680 train_time:135264ms step_avg:88.06ms +step:1537/1680 train_time:135354ms step_avg:88.06ms +step:1538/1680 train_time:135443ms step_avg:88.06ms +step:1539/1680 train_time:135533ms step_avg:88.07ms +step:1540/1680 train_time:135622ms step_avg:88.07ms +step:1541/1680 train_time:135711ms step_avg:88.07ms +step:1542/1680 train_time:135800ms step_avg:88.07ms +step:1543/1680 train_time:135889ms step_avg:88.07ms +step:1544/1680 train_time:135977ms step_avg:88.07ms +step:1545/1680 train_time:136067ms step_avg:88.07ms +step:1546/1680 train_time:136156ms step_avg:88.07ms 
+step:1547/1680 train_time:136246ms step_avg:88.07ms +step:1548/1680 train_time:136335ms step_avg:88.07ms +step:1549/1680 train_time:136425ms step_avg:88.07ms +step:1550/1680 train_time:136514ms step_avg:88.07ms +step:1551/1680 train_time:136603ms step_avg:88.07ms +step:1552/1680 train_time:136691ms step_avg:88.07ms +step:1553/1680 train_time:136779ms step_avg:88.07ms +step:1554/1680 train_time:136868ms step_avg:88.07ms +step:1555/1680 train_time:136957ms step_avg:88.08ms +step:1556/1680 train_time:137046ms step_avg:88.08ms +step:1557/1680 train_time:137134ms step_avg:88.08ms +step:1558/1680 train_time:137224ms step_avg:88.08ms +step:1559/1680 train_time:137313ms step_avg:88.08ms +step:1560/1680 train_time:137403ms step_avg:88.08ms +step:1561/1680 train_time:137491ms step_avg:88.08ms +step:1562/1680 train_time:137580ms step_avg:88.08ms +step:1563/1680 train_time:137669ms step_avg:88.08ms +step:1564/1680 train_time:137757ms step_avg:88.08ms +step:1565/1680 train_time:137846ms step_avg:88.08ms +step:1566/1680 train_time:137935ms step_avg:88.08ms +step:1567/1680 train_time:138023ms step_avg:88.08ms +step:1568/1680 train_time:138112ms step_avg:88.08ms +step:1569/1680 train_time:138201ms step_avg:88.08ms +step:1570/1680 train_time:138290ms step_avg:88.08ms +step:1571/1680 train_time:138379ms step_avg:88.08ms +step:1572/1680 train_time:138469ms step_avg:88.08ms +step:1573/1680 train_time:138558ms step_avg:88.09ms +step:1574/1680 train_time:138648ms step_avg:88.09ms +step:1575/1680 train_time:138735ms step_avg:88.09ms +step:1576/1680 train_time:138823ms step_avg:88.09ms +step:1577/1680 train_time:138911ms step_avg:88.09ms +step:1578/1680 train_time:139000ms step_avg:88.09ms +step:1579/1680 train_time:139089ms step_avg:88.09ms +step:1580/1680 train_time:139178ms step_avg:88.09ms +step:1581/1680 train_time:139269ms step_avg:88.09ms +step:1582/1680 train_time:139359ms step_avg:88.09ms +step:1583/1680 train_time:139449ms step_avg:88.09ms +step:1584/1680 train_time:139539ms step_avg:88.09ms +step:1585/1680 train_time:139628ms step_avg:88.09ms +step:1586/1680 train_time:139717ms step_avg:88.09ms +step:1587/1680 train_time:139806ms step_avg:88.09ms +step:1588/1680 train_time:139895ms step_avg:88.09ms +step:1589/1680 train_time:139985ms step_avg:88.10ms +step:1590/1680 train_time:140074ms step_avg:88.10ms +step:1591/1680 train_time:140162ms step_avg:88.10ms +step:1592/1680 train_time:140251ms step_avg:88.10ms +step:1593/1680 train_time:140340ms step_avg:88.10ms +step:1594/1680 train_time:140429ms step_avg:88.10ms +step:1595/1680 train_time:140519ms step_avg:88.10ms +step:1596/1680 train_time:140607ms step_avg:88.10ms +step:1597/1680 train_time:140696ms step_avg:88.10ms +step:1598/1680 train_time:140784ms step_avg:88.10ms +step:1599/1680 train_time:140873ms step_avg:88.10ms +step:1600/1680 train_time:140962ms step_avg:88.10ms +step:1601/1680 train_time:141051ms step_avg:88.10ms +step:1602/1680 train_time:141140ms step_avg:88.10ms +step:1603/1680 train_time:141229ms step_avg:88.10ms +step:1604/1680 train_time:141318ms step_avg:88.10ms +step:1605/1680 train_time:141408ms step_avg:88.10ms +step:1606/1680 train_time:141498ms step_avg:88.11ms +step:1607/1680 train_time:141588ms step_avg:88.11ms +step:1608/1680 train_time:141677ms step_avg:88.11ms +step:1609/1680 train_time:141767ms step_avg:88.11ms +step:1610/1680 train_time:141856ms step_avg:88.11ms +step:1611/1680 train_time:141947ms step_avg:88.11ms +step:1612/1680 train_time:142035ms step_avg:88.11ms +step:1613/1680 train_time:142124ms step_avg:88.11ms 
+step:1614/1680 train_time:142212ms step_avg:88.11ms +step:1615/1680 train_time:142301ms step_avg:88.11ms +step:1616/1680 train_time:142390ms step_avg:88.11ms +step:1617/1680 train_time:142479ms step_avg:88.11ms +step:1618/1680 train_time:142569ms step_avg:88.11ms +step:1619/1680 train_time:142658ms step_avg:88.11ms +step:1620/1680 train_time:142747ms step_avg:88.12ms +step:1621/1680 train_time:142836ms step_avg:88.12ms +step:1622/1680 train_time:142925ms step_avg:88.12ms +step:1623/1680 train_time:143014ms step_avg:88.12ms +step:1624/1680 train_time:143102ms step_avg:88.12ms +step:1625/1680 train_time:143191ms step_avg:88.12ms +step:1625/1680 val_loss:3.2890 train_time:143282ms step_avg:88.17ms +step:1626/1680 train_time:143300ms step_avg:88.13ms +step:1627/1680 train_time:143377ms step_avg:88.12ms +step:1628/1680 train_time:143469ms step_avg:88.13ms +step:1629/1680 train_time:143559ms step_avg:88.13ms +step:1630/1680 train_time:143647ms step_avg:88.13ms +step:1631/1680 train_time:143735ms step_avg:88.13ms +step:1632/1680 train_time:143823ms step_avg:88.13ms +step:1633/1680 train_time:143911ms step_avg:88.13ms +step:1634/1680 train_time:143999ms step_avg:88.13ms +step:1635/1680 train_time:144086ms step_avg:88.13ms +step:1636/1680 train_time:144174ms step_avg:88.13ms +step:1637/1680 train_time:144264ms step_avg:88.13ms +step:1638/1680 train_time:144356ms step_avg:88.13ms +step:1639/1680 train_time:144447ms step_avg:88.13ms +step:1640/1680 train_time:144537ms step_avg:88.13ms +step:1641/1680 train_time:144626ms step_avg:88.13ms +step:1642/1680 train_time:144715ms step_avg:88.13ms +step:1643/1680 train_time:144804ms step_avg:88.13ms +step:1644/1680 train_time:144893ms step_avg:88.13ms +step:1645/1680 train_time:144981ms step_avg:88.13ms +step:1646/1680 train_time:145069ms step_avg:88.13ms +step:1647/1680 train_time:145156ms step_avg:88.13ms +step:1648/1680 train_time:145246ms step_avg:88.13ms +step:1649/1680 train_time:145336ms step_avg:88.14ms +step:1650/1680 train_time:145427ms step_avg:88.14ms +step:1651/1680 train_time:145517ms step_avg:88.14ms +step:1652/1680 train_time:145606ms step_avg:88.14ms +step:1653/1680 train_time:145696ms step_avg:88.14ms +step:1654/1680 train_time:145785ms step_avg:88.14ms +step:1655/1680 train_time:145873ms step_avg:88.14ms +step:1656/1680 train_time:145961ms step_avg:88.14ms +step:1657/1680 train_time:146049ms step_avg:88.14ms +step:1658/1680 train_time:146136ms step_avg:88.14ms +step:1659/1680 train_time:146226ms step_avg:88.14ms +step:1660/1680 train_time:146315ms step_avg:88.14ms +step:1661/1680 train_time:146406ms step_avg:88.14ms +step:1662/1680 train_time:146496ms step_avg:88.14ms +step:1663/1680 train_time:146586ms step_avg:88.15ms +step:1664/1680 train_time:146674ms step_avg:88.15ms +step:1665/1680 train_time:146764ms step_avg:88.15ms +step:1666/1680 train_time:146853ms step_avg:88.15ms +step:1667/1680 train_time:146942ms step_avg:88.15ms +step:1668/1680 train_time:147030ms step_avg:88.15ms +step:1669/1680 train_time:147119ms step_avg:88.15ms +step:1670/1680 train_time:147207ms step_avg:88.15ms +step:1671/1680 train_time:147296ms step_avg:88.15ms +step:1672/1680 train_time:147386ms step_avg:88.15ms +step:1673/1680 train_time:147475ms step_avg:88.15ms +step:1674/1680 train_time:147565ms step_avg:88.15ms +step:1675/1680 train_time:147655ms step_avg:88.15ms +step:1676/1680 train_time:147744ms step_avg:88.15ms +step:1677/1680 train_time:147832ms step_avg:88.15ms +step:1678/1680 train_time:147921ms step_avg:88.15ms +step:1679/1680 train_time:148010ms 
step_avg:88.15ms +step:1680/1680 train_time:148099ms step_avg:88.15ms +step:1680/1680 val_loss:3.2780 train_time:148189ms step_avg:88.21ms +peak memory allocated: 30760 MiB reserved: 45994 MiB
diff --git a/records/092725_BF16CE/8237bd60-bbc4-4ad6-8f7f-9a2a654a1c5a.txt b/records/092725_BF16CE/8237bd60-bbc4-4ad6-8f7f-9a2a654a1c5a.txt
new file mode 100644
index 000000000..9ef413249
--- /dev/null
+++ b/records/092725_BF16CE/8237bd60-bbc4-4ad6-8f7f-9a2a654a1c5a.txt
@@ -0,0 +1,3206 @@
+import os
+import sys
+
+with open(sys.argv[0]) as f:
+    code = f.read() # read the code of this file ASAP, for logging
+import copy
+import glob
+import math
+import threading
+import time
+import uuid
+from dataclasses import dataclass
+from itertools import accumulate
+from pathlib import Path
+
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+import torch
+
+torch.empty(
+    1, device="cuda", requires_grad=True
+).backward() # prevents a bug on some systems
+import torch._dynamo as dynamo
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min
+import triton
+import triton.language as tl
+from flash_attn_interface import flash_attn_varlen_func
+from torch import Tensor, nn
+
+dynamo.config.recompile_limit = 64
+
+# -----------------------------------------------------------------------------
+# Custom operators: FP8 matmul by @YouJiacheng
+
+
+@torch.library.custom_op("nanogpt::mm", mutates_args=())
+def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
+    @torch.compile
+    def impl(x: Tensor, w: Tensor):
+        assert x.is_contiguous() and w.is_contiguous()
+        x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
+        w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
+        out = torch._scaled_mm(
+            x_f8,
+            w_f8.T,
+            out_dtype=torch.bfloat16,
+            scale_a=x.new_tensor(x_s, dtype=torch.float32),
+            scale_b=x.new_tensor(w_s, dtype=torch.float32),
+            use_fast_accum=True,
+        )
+        return out, x_f8, w_f8
+
+    return impl(x, w)
+
+@mm_op.register_fake
+def _(x: Tensor, w: Tensor, *_):
+    assert x.ndim == w.ndim == 2
+    assert x.shape[1] == w.shape[1]
+    assert x.device == w.device
+    assert x.is_contiguous() and w.is_contiguous()
+    return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)
+
+@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
+def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]:
+    @torch.compile
+    def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
+        assert grad.is_contiguous()
+        x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
+        w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
+        grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
+        grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
+        grad_x = torch._scaled_mm(
+            grad_f8,
+            w_f8.T.contiguous().T,
+            out_dtype=torch.bfloat16,
+            scale_a=grad_inv_s,
+            scale_b=w_inv_s,
+            use_fast_accum=False,
+        )
+        # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
+        grad_w = torch._scaled_mm(
+            x_f8.T.contiguous(),
+            grad_f8.T.contiguous().T,
+            out_dtype=torch.float32,
+            scale_a=x_inv_s,
+            scale_b=grad_inv_s,
+            use_fast_accum=False,
+        ).T
+        return grad_x, grad_w
+
+    return impl(g, x_f8, w_f8)
+
+@mm_backward_op.register_fake
+def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
+    return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32)
+
+def backward(ctx, grad_out: Tensor, *_):
+    x_f8, w_f8 = ctx.saved_tensors
+    x_s, w_s, grad_s = ctx.scales
+    grad_x, grad_w = torch.ops.nanogpt.mm_backward(
+        grad_out, x_f8, w_f8, x_s, w_s, grad_s
+    )
+    return grad_x, grad_w, None, None, None
+
+def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
+    *_, x_s, w_s, grad_s = inputs
+    _, x_f8, w_f8 = output
+    ctx.save_for_backward(x_f8, w_f8)
+    ctx.scales = x_s, w_s, grad_s
+    ctx.set_materialize_grads(False)
+
+mm_op.register_autograd(backward, setup_context=setup_context)
+
+# -----------------------------------------------------------------------------
+# Triton kernel for symmetric matrix multiplication by @byronxu99
+
+def _get_autotune_configs():
+    return [
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": bm,
+                "BLOCK_SIZE_N": bn,
+                "BLOCK_SIZE_K": bk,
+                "GROUP_SIZE_M": 8,
+                "LOWER_UPPER": 1,
+            },
+            num_stages=stages,
+            num_warps=warps,
+        )
+        for bm in [64, 128]
+        for bn in [64, 128, 256]
+        for bk in [64, 128]
+        for stages, warps in [(3, 4), (3, 8), (4, 4)]
+        if bm // bn <= 2 and bn // bm <= 2
+    ]
+
+@triton.jit
+def _pid_to_block(
+    pid,
+    M,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+):
+    # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(M, BLOCK_SIZE_N)
+
+    # Map PID to a single matrix in batch
+    batch_idx = pid // (num_pid_m * num_pid_n)
+    pid = pid % (num_pid_m * num_pid_n)
+
+    # Map PID to 2D grid of blocks
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
+
+    m_idx = pid_m * BLOCK_SIZE_M
+    n_idx = pid_n * BLOCK_SIZE_N
+    return batch_idx, m_idx, n_idx
+
+@triton.autotune(
+    configs=_get_autotune_configs(),
+    key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
+)
+@triton.jit
+def ns_line_1_kernel(
+    A_ptr, C_ptr,
+    M, K,
+    a_stride_b, a_stride_r, a_stride_c,
+    c_stride_b, c_stride_r, c_stride_c,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    LOWER_UPPER: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    batch_idx, m_idx, n_idx = _pid_to_block(
+        pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
+    )
+
+    # Skip blocks that don't need to be computed
+    skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx)
+    skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx)
+    if skip_block_below_diag or skip_block_above_diag:
+        return
+
+    # Index into one matrix of batch
+    A_ptr += batch_idx * a_stride_b
+    C_ptr += batch_idx * c_stride_b
+
+    # Create pointer arrays for A and A.T
+    offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c)
+    at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r)
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+    # Accumulate over blocks of K
+    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        accumulator = tl.dot(a, at, accumulator)
+        a_ptrs += BLOCK_SIZE_K * a_stride_c
+        at_ptrs += BLOCK_SIZE_K * a_stride_c
+
+    out_dtype = C_ptr.dtype.element_ty
+    output = accumulator.to(out_dtype)
+
+    # Store block of C
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+ Though empirically small 1D params perform efficiently here: + NS approximately performs a magnitude normalization of the grad + This hyper-optimized class has faster execution time than the current impl of Adam for small params + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized to enable + (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param 7 padding params) + 2. reduce scatter attn_gate (10 params 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive params of type attn reshape before NS + 8. wait on 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size()==8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on less than 8 GPU or experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1,10,16,16] + assert len(params)==sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx+size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
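+ # Worked example of the sharding math below (inferred from generate_custom_param_groups + # with world_size == 8; not part of the original run log): group sizes [1, 10, 16, 16] + # pad up to the next multiple of 8, so padded_num_params is [8, 16, 16, 16] and the + # per-rank chunk_size is [1, 2, 2, 2]. In the mixed round of 10 attn + 6 mlp params, + # ranks 0-4 each receive 2 attn params and ranks 5-7 each receive 2 mlp params, so no + # rank ever holds attn and mlp weights in the same chunk, matching the docstring above.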
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
+ updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + original_shape = batched_update_grads.shape + # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O + module_idx = start_idx if start_idx < len(params) else 0 + if params[module_idx].module == 'attn': + batched_update_grads = batched_update_grads.view(-1, original_shape[-2], original_shape[-1] // 4) + v_chunk = newton_schulz_triton(batched_update_grads).reshape(original_shape) + + # Apply the orthogonalized updates to this rank's chunk of parameters. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = 
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module='attn' + with torch.no_grad(): + self.qkvo_w.view(4,self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4,self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4,self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4,self.hdim, self.dim)[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module='mlp' + self.c_proj.module='mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
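+ # Aside on the FP8 scales passed to the head below (an inference, not stated in this + # file): torch.float8_e4m3fn saturates near +-448 and torch.float8_e5m2 near +-57344, + # so x_s=(model_dim**0.5)/448 means an RMS-normalized row (norm sqrt(model_dim)) can + # only overflow the e4m3 range if one coordinate carries essentially the whole norm, + # w_s=2**-9 keeps the zero-initialized head weights well inside range, and grad_s=1/448 + # rescales gradients ahead of the e5m2 cast in mm_backward_op.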
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the 
data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter==5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter+=1 + return starts, ends + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * 
grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1640 # number of iterations to run + iteration_extension: int = 40 # number of iterations to continue training at final cooldown and window size + cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 
0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999, step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations + args.iteration_extension: + return args.ws_validate//2, args.ws_validate + x = min(step / (1 + args.num_iterations), 0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx]//2, args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws_long = args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params + if new_ws_long > ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + elif new_ws_long < ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + model(inputs, targets, cum_seqlens, ws_long//2, ws_long).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) + +# restore the initial model/optimizer state so the warmup steps don't count as training +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del initial_state +model.yarn.reset() # YaRN cos/sin buffers are non-persistent, so rebuild them alongside the state_dict +del train_loader + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +train_steps = args.num_iterations + args.iteration_extension +ws_short, ws_long = get_ws(0) +for step in range(train_steps + 1): + last_step = (step == train_steps) + + # advance the window-size schedule and rescale YaRN when the window changes + new_ws_short, new_ws_long = get_ws(step) + if new_ws_long != ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_short, ws_long = new_ws_short, new_ws_long + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + if last_step: + ws_long = args.ws_long_validate + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = torch.zeros((), device=device, dtype=torch.float32) + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + 
inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Sep 27 12:21:15 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 30C P0 122W / 700W | 5856MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 27C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 24C P0 116W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 28C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 29C P0 122W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 27C P0 115W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 30C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 26C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 155604 C /usr/bin/python 0MiB | +| 0 N/A N/A 155605 C /usr/bin/python 0MiB | +| 0 N/A N/A 155606 C /usr/bin/python 0MiB | +| 0 N/A N/A 155607 C /usr/bin/python 0MiB | +| 0 N/A N/A 155608 C /usr/bin/python 0MiB | +| 0 N/A N/A 155609 C /usr/bin/python 0MiB | +| 0 N/A N/A 155610 C /usr/bin/python 0MiB | +| 0 N/A N/A 155611 C /usr/bin/python 0MiB | +| 1 N/A N/A 155605 C /usr/bin/python 0MiB | +| 2 N/A N/A 155606 C /usr/bin/python 0MiB | +| 3 N/A N/A 155607 C /usr/bin/python 0MiB | +| 4 N/A N/A 155608 C /usr/bin/python 0MiB | +| 5 N/A N/A 155609 C /usr/bin/python 0MiB | +| 6 N/A N/A 155610 C /usr/bin/python 0MiB | +| 7 N/A N/A 155611 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1680 train_time:149ms step_avg:148.61ms +step:2/1680 train_time:169ms step_avg:84.53ms +step:3/1680 train_time:232ms step_avg:77.33ms +step:4/1680 train_time:317ms step_avg:79.30ms +step:5/1680 train_time:404ms step_avg:80.85ms +step:6/1680 train_time:490ms step_avg:81.67ms +step:7/1680 train_time:576ms step_avg:82.27ms +step:8/1680 train_time:662ms step_avg:82.80ms +step:9/1680 train_time:749ms step_avg:83.20ms +step:10/1680 train_time:836ms step_avg:83.56ms +step:11/1680 train_time:922ms step_avg:83.81ms +step:12/1680 train_time:1010ms step_avg:84.15ms +step:13/1680 train_time:1102ms step_avg:84.79ms +step:14/1680 train_time:1192ms step_avg:85.12ms +step:15/1680 train_time:1279ms step_avg:85.29ms +step:16/1680 train_time:1367ms step_avg:85.45ms +step:17/1680 train_time:1455ms step_avg:85.57ms +step:18/1680 train_time:1542ms step_avg:85.67ms +step:19/1680 train_time:1628ms step_avg:85.69ms +step:20/1680 train_time:1714ms step_avg:85.72ms +step:21/1680 train_time:1801ms step_avg:85.75ms +step:22/1680 train_time:1887ms step_avg:85.79ms +step:23/1680 train_time:1975ms step_avg:85.85ms +step:24/1680 train_time:2063ms step_avg:85.97ms +step:25/1680 train_time:2152ms step_avg:86.06ms +step:26/1680 train_time:2240ms step_avg:86.16ms +step:27/1680 train_time:2328ms step_avg:86.22ms +step:28/1680 train_time:2416ms step_avg:86.29ms +step:29/1680 train_time:2504ms step_avg:86.34ms +step:30/1680 train_time:2590ms step_avg:86.35ms +step:31/1680 train_time:2677ms step_avg:86.35ms +step:32/1680 train_time:2763ms step_avg:86.35ms +step:33/1680 train_time:2850ms step_avg:86.36ms +step:34/1680 train_time:2937ms step_avg:86.39ms +step:35/1680 train_time:3026ms step_avg:86.44ms +step:36/1680 train_time:3114ms step_avg:86.51ms +step:37/1680 train_time:3203ms step_avg:86.58ms +step:38/1680 train_time:3291ms step_avg:86.60ms +step:39/1680 train_time:3378ms step_avg:86.62ms +step:40/1680 train_time:3466ms step_avg:86.64ms +step:41/1680 train_time:3552ms step_avg:86.65ms +step:42/1680 train_time:3639ms step_avg:86.65ms +step:43/1680 train_time:3726ms step_avg:86.65ms +step:44/1680 train_time:3812ms step_avg:86.64ms +step:45/1680 train_time:3899ms step_avg:86.65ms +step:46/1680 train_time:3987ms step_avg:86.67ms +step:47/1680 
train_time:4075ms step_avg:86.70ms +step:48/1680 train_time:4163ms step_avg:86.73ms +step:49/1680 train_time:4250ms step_avg:86.74ms +step:50/1680 train_time:4338ms step_avg:86.76ms +step:51/1680 train_time:4425ms step_avg:86.77ms +step:52/1680 train_time:4513ms step_avg:86.79ms +step:53/1680 train_time:4600ms step_avg:86.79ms +step:54/1680 train_time:4687ms step_avg:86.79ms +step:55/1680 train_time:4774ms step_avg:86.79ms +step:56/1680 train_time:4861ms step_avg:86.80ms +step:57/1680 train_time:4947ms step_avg:86.79ms +step:58/1680 train_time:5035ms step_avg:86.82ms +step:59/1680 train_time:5123ms step_avg:86.83ms +step:60/1680 train_time:5210ms step_avg:86.84ms +step:61/1680 train_time:5298ms step_avg:86.85ms +step:62/1680 train_time:5385ms step_avg:86.86ms +step:63/1680 train_time:5473ms step_avg:86.87ms +step:64/1680 train_time:5559ms step_avg:86.86ms +step:65/1680 train_time:5646ms step_avg:86.87ms +step:66/1680 train_time:5734ms step_avg:86.88ms +step:67/1680 train_time:5821ms step_avg:86.88ms +step:68/1680 train_time:5908ms step_avg:86.88ms +step:69/1680 train_time:5995ms step_avg:86.89ms +step:70/1680 train_time:6083ms step_avg:86.90ms +step:71/1680 train_time:6170ms step_avg:86.90ms +step:72/1680 train_time:6258ms step_avg:86.91ms +step:73/1680 train_time:6346ms step_avg:86.93ms +step:74/1680 train_time:6433ms step_avg:86.94ms +step:75/1680 train_time:6520ms step_avg:86.94ms +step:76/1680 train_time:6607ms step_avg:86.94ms +step:77/1680 train_time:6694ms step_avg:86.94ms +step:78/1680 train_time:6782ms step_avg:86.94ms +step:79/1680 train_time:6868ms step_avg:86.94ms +step:80/1680 train_time:6955ms step_avg:86.94ms +step:81/1680 train_time:7044ms step_avg:86.96ms +step:82/1680 train_time:7131ms step_avg:86.96ms +step:83/1680 train_time:7218ms step_avg:86.96ms +step:84/1680 train_time:7305ms step_avg:86.97ms +step:85/1680 train_time:7393ms step_avg:86.97ms +step:86/1680 train_time:7481ms step_avg:86.99ms +step:87/1680 train_time:7568ms step_avg:86.99ms +step:88/1680 train_time:7655ms step_avg:86.99ms +step:89/1680 train_time:7742ms step_avg:86.99ms +step:90/1680 train_time:7829ms step_avg:86.99ms +step:91/1680 train_time:7916ms step_avg:86.99ms +step:92/1680 train_time:8003ms step_avg:86.99ms +step:93/1680 train_time:8091ms step_avg:87.00ms +step:94/1680 train_time:8178ms step_avg:87.00ms +step:95/1680 train_time:8265ms step_avg:87.00ms +step:96/1680 train_time:8353ms step_avg:87.01ms +step:97/1680 train_time:8440ms step_avg:87.01ms +step:98/1680 train_time:8527ms step_avg:87.01ms +step:99/1680 train_time:8614ms step_avg:87.01ms +step:100/1680 train_time:8701ms step_avg:87.01ms +step:101/1680 train_time:8788ms step_avg:87.01ms +step:102/1680 train_time:8875ms step_avg:87.01ms +step:103/1680 train_time:8962ms step_avg:87.01ms +step:104/1680 train_time:9050ms step_avg:87.02ms +step:105/1680 train_time:9137ms step_avg:87.02ms +step:106/1680 train_time:9224ms step_avg:87.02ms +step:107/1680 train_time:9311ms step_avg:87.02ms +step:108/1680 train_time:9399ms step_avg:87.03ms +step:109/1680 train_time:9486ms step_avg:87.03ms +step:110/1680 train_time:9574ms step_avg:87.03ms +step:111/1680 train_time:9662ms step_avg:87.04ms +step:112/1680 train_time:9749ms step_avg:87.04ms +step:113/1680 train_time:9836ms step_avg:87.04ms +step:114/1680 train_time:9922ms step_avg:87.04ms +step:115/1680 train_time:10009ms step_avg:87.04ms +step:116/1680 train_time:10097ms step_avg:87.04ms +step:117/1680 train_time:10184ms step_avg:87.04ms +step:118/1680 train_time:10270ms step_avg:87.04ms +step:119/1680 
train_time:10358ms step_avg:87.05ms +step:120/1680 train_time:10445ms step_avg:87.04ms +step:121/1680 train_time:10533ms step_avg:87.05ms +step:122/1680 train_time:10621ms step_avg:87.05ms +step:123/1680 train_time:10708ms step_avg:87.06ms +step:124/1680 train_time:10795ms step_avg:87.06ms +step:125/1680 train_time:10882ms step_avg:87.05ms +step:125/1680 val_loss:4.2925 train_time:10970ms step_avg:87.76ms +step:126/1680 train_time:10990ms step_avg:87.22ms +step:127/1680 train_time:11058ms step_avg:87.07ms +step:128/1680 train_time:11153ms step_avg:87.13ms +step:129/1680 train_time:11245ms step_avg:87.17ms +step:130/1680 train_time:11334ms step_avg:87.18ms +step:131/1680 train_time:11421ms step_avg:87.19ms +step:132/1680 train_time:11508ms step_avg:87.18ms +step:133/1680 train_time:11593ms step_avg:87.17ms +step:134/1680 train_time:11679ms step_avg:87.16ms +step:135/1680 train_time:11766ms step_avg:87.15ms +step:136/1680 train_time:11852ms step_avg:87.14ms +step:137/1680 train_time:11937ms step_avg:87.13ms +step:138/1680 train_time:12024ms step_avg:87.13ms +step:139/1680 train_time:12113ms step_avg:87.14ms +step:140/1680 train_time:12203ms step_avg:87.17ms +step:141/1680 train_time:12292ms step_avg:87.18ms +step:142/1680 train_time:12381ms step_avg:87.19ms +step:143/1680 train_time:12468ms step_avg:87.19ms +step:144/1680 train_time:12554ms step_avg:87.18ms +step:145/1680 train_time:12641ms step_avg:87.18ms +step:146/1680 train_time:12728ms step_avg:87.18ms +step:147/1680 train_time:12814ms step_avg:87.17ms +step:148/1680 train_time:12900ms step_avg:87.16ms +step:149/1680 train_time:12986ms step_avg:87.16ms +step:150/1680 train_time:13074ms step_avg:87.16ms +step:151/1680 train_time:13163ms step_avg:87.17ms +step:152/1680 train_time:13251ms step_avg:87.18ms +step:153/1680 train_time:13339ms step_avg:87.18ms +step:154/1680 train_time:13427ms step_avg:87.19ms +step:155/1680 train_time:13514ms step_avg:87.19ms +step:156/1680 train_time:13602ms step_avg:87.19ms +step:157/1680 train_time:13688ms step_avg:87.19ms +step:158/1680 train_time:13775ms step_avg:87.18ms +step:159/1680 train_time:13861ms step_avg:87.18ms +step:160/1680 train_time:13948ms step_avg:87.17ms +step:161/1680 train_time:14035ms step_avg:87.17ms +step:162/1680 train_time:14123ms step_avg:87.18ms +step:163/1680 train_time:14210ms step_avg:87.18ms +step:164/1680 train_time:14297ms step_avg:87.18ms +step:165/1680 train_time:14385ms step_avg:87.18ms +step:166/1680 train_time:14473ms step_avg:87.19ms +step:167/1680 train_time:14561ms step_avg:87.19ms +step:168/1680 train_time:14648ms step_avg:87.19ms +step:169/1680 train_time:14734ms step_avg:87.18ms +step:170/1680 train_time:14821ms step_avg:87.18ms +step:171/1680 train_time:14907ms step_avg:87.18ms +step:172/1680 train_time:14994ms step_avg:87.18ms +step:173/1680 train_time:15082ms step_avg:87.18ms +step:174/1680 train_time:15169ms step_avg:87.18ms +step:175/1680 train_time:15257ms step_avg:87.18ms +step:176/1680 train_time:15343ms step_avg:87.18ms +step:177/1680 train_time:15431ms step_avg:87.18ms +step:178/1680 train_time:15518ms step_avg:87.18ms +step:179/1680 train_time:15605ms step_avg:87.18ms +step:180/1680 train_time:15692ms step_avg:87.18ms +step:181/1680 train_time:15778ms step_avg:87.17ms +step:182/1680 train_time:15865ms step_avg:87.17ms +step:183/1680 train_time:15952ms step_avg:87.17ms +step:184/1680 train_time:16040ms step_avg:87.17ms +step:185/1680 train_time:16127ms step_avg:87.17ms +step:186/1680 train_time:16215ms step_avg:87.18ms +step:187/1680 train_time:16302ms 
step_avg:87.18ms +step:188/1680 train_time:16389ms step_avg:87.18ms +step:189/1680 train_time:16477ms step_avg:87.18ms +step:190/1680 train_time:16563ms step_avg:87.17ms +step:191/1680 train_time:16650ms step_avg:87.17ms +step:192/1680 train_time:16737ms step_avg:87.17ms +step:193/1680 train_time:16823ms step_avg:87.17ms +step:194/1680 train_time:16910ms step_avg:87.16ms +step:195/1680 train_time:16997ms step_avg:87.16ms +step:196/1680 train_time:17083ms step_avg:87.16ms +step:197/1680 train_time:17171ms step_avg:87.16ms +step:198/1680 train_time:17258ms step_avg:87.16ms +step:199/1680 train_time:17345ms step_avg:87.16ms +step:200/1680 train_time:17433ms step_avg:87.16ms +step:201/1680 train_time:17520ms step_avg:87.16ms +step:202/1680 train_time:17607ms step_avg:87.16ms +step:203/1680 train_time:17695ms step_avg:87.17ms +step:204/1680 train_time:17782ms step_avg:87.17ms +step:205/1680 train_time:17868ms step_avg:87.16ms +step:206/1680 train_time:17956ms step_avg:87.16ms +step:207/1680 train_time:18043ms step_avg:87.16ms +step:208/1680 train_time:18130ms step_avg:87.16ms +step:209/1680 train_time:18217ms step_avg:87.16ms +step:210/1680 train_time:18304ms step_avg:87.16ms +step:211/1680 train_time:18392ms step_avg:87.16ms +step:212/1680 train_time:18479ms step_avg:87.16ms +step:213/1680 train_time:18566ms step_avg:87.16ms +step:214/1680 train_time:18653ms step_avg:87.16ms +step:215/1680 train_time:18740ms step_avg:87.16ms +step:216/1680 train_time:18827ms step_avg:87.16ms +step:217/1680 train_time:18913ms step_avg:87.16ms +step:218/1680 train_time:19001ms step_avg:87.16ms +step:219/1680 train_time:19088ms step_avg:87.16ms +step:220/1680 train_time:19176ms step_avg:87.16ms +step:221/1680 train_time:19264ms step_avg:87.17ms +step:222/1680 train_time:19350ms step_avg:87.16ms +step:223/1680 train_time:19438ms step_avg:87.16ms +step:224/1680 train_time:19524ms step_avg:87.16ms +step:225/1680 train_time:19612ms step_avg:87.16ms +step:226/1680 train_time:19699ms step_avg:87.16ms +step:227/1680 train_time:19785ms step_avg:87.16ms +step:228/1680 train_time:19872ms step_avg:87.16ms +step:229/1680 train_time:19959ms step_avg:87.16ms +step:230/1680 train_time:20046ms step_avg:87.16ms +step:231/1680 train_time:20133ms step_avg:87.16ms +step:232/1680 train_time:20221ms step_avg:87.16ms +step:233/1680 train_time:20308ms step_avg:87.16ms +step:234/1680 train_time:20396ms step_avg:87.16ms +step:235/1680 train_time:20483ms step_avg:87.16ms +step:236/1680 train_time:20570ms step_avg:87.16ms +step:237/1680 train_time:20657ms step_avg:87.16ms +step:238/1680 train_time:20743ms step_avg:87.16ms +step:239/1680 train_time:20830ms step_avg:87.16ms +step:240/1680 train_time:20918ms step_avg:87.16ms +step:241/1680 train_time:21005ms step_avg:87.16ms +step:242/1680 train_time:21092ms step_avg:87.16ms +step:243/1680 train_time:21179ms step_avg:87.16ms +step:244/1680 train_time:21266ms step_avg:87.16ms +step:245/1680 train_time:21353ms step_avg:87.15ms +step:246/1680 train_time:21440ms step_avg:87.16ms +step:247/1680 train_time:21527ms step_avg:87.15ms +step:248/1680 train_time:21615ms step_avg:87.16ms +step:249/1680 train_time:21702ms step_avg:87.16ms +step:250/1680 train_time:21789ms step_avg:87.15ms +step:250/1680 val_loss:3.9621 train_time:21877ms step_avg:87.51ms +step:251/1680 train_time:21897ms step_avg:87.24ms +step:252/1680 train_time:21968ms step_avg:87.17ms +step:253/1680 train_time:22058ms step_avg:87.19ms +step:254/1680 train_time:22148ms step_avg:87.20ms +step:255/1680 train_time:22236ms step_avg:87.20ms 
+step:256/1680 train_time:22323ms step_avg:87.20ms +step:257/1680 train_time:22409ms step_avg:87.19ms +step:258/1680 train_time:22495ms step_avg:87.19ms +step:259/1680 train_time:22582ms step_avg:87.19ms +step:260/1680 train_time:22668ms step_avg:87.18ms +step:261/1680 train_time:22755ms step_avg:87.18ms +step:262/1680 train_time:22842ms step_avg:87.18ms +step:263/1680 train_time:22930ms step_avg:87.19ms +step:264/1680 train_time:23018ms step_avg:87.19ms +step:265/1680 train_time:23106ms step_avg:87.19ms +step:266/1680 train_time:23194ms step_avg:87.20ms +step:267/1680 train_time:23282ms step_avg:87.20ms +step:268/1680 train_time:23369ms step_avg:87.20ms +step:269/1680 train_time:23456ms step_avg:87.20ms +step:270/1680 train_time:23543ms step_avg:87.20ms +step:271/1680 train_time:23629ms step_avg:87.19ms +step:272/1680 train_time:23716ms step_avg:87.19ms +step:273/1680 train_time:23803ms step_avg:87.19ms +step:274/1680 train_time:23891ms step_avg:87.19ms +step:275/1680 train_time:23978ms step_avg:87.19ms +step:276/1680 train_time:24067ms step_avg:87.20ms +step:277/1680 train_time:24155ms step_avg:87.20ms +step:278/1680 train_time:24243ms step_avg:87.20ms +step:279/1680 train_time:24330ms step_avg:87.20ms +step:280/1680 train_time:24417ms step_avg:87.20ms +step:281/1680 train_time:24504ms step_avg:87.20ms +step:282/1680 train_time:24590ms step_avg:87.20ms +step:283/1680 train_time:24677ms step_avg:87.20ms +step:284/1680 train_time:24763ms step_avg:87.19ms +step:285/1680 train_time:24850ms step_avg:87.19ms +step:286/1680 train_time:24938ms step_avg:87.20ms +step:287/1680 train_time:25026ms step_avg:87.20ms +step:288/1680 train_time:25114ms step_avg:87.20ms +step:289/1680 train_time:25201ms step_avg:87.20ms +step:290/1680 train_time:25289ms step_avg:87.20ms +step:291/1680 train_time:25376ms step_avg:87.20ms +step:292/1680 train_time:25464ms step_avg:87.21ms +step:293/1680 train_time:25550ms step_avg:87.20ms +step:294/1680 train_time:25637ms step_avg:87.20ms +step:295/1680 train_time:25724ms step_avg:87.20ms +step:296/1680 train_time:25811ms step_avg:87.20ms +step:297/1680 train_time:25898ms step_avg:87.20ms +step:298/1680 train_time:25986ms step_avg:87.20ms +step:299/1680 train_time:26074ms step_avg:87.20ms +step:300/1680 train_time:26161ms step_avg:87.20ms +step:301/1680 train_time:26248ms step_avg:87.20ms +step:302/1680 train_time:26335ms step_avg:87.20ms +step:303/1680 train_time:26422ms step_avg:87.20ms +step:304/1680 train_time:26510ms step_avg:87.20ms +step:305/1680 train_time:26596ms step_avg:87.20ms +step:306/1680 train_time:26683ms step_avg:87.20ms +step:307/1680 train_time:26770ms step_avg:87.20ms +step:308/1680 train_time:26856ms step_avg:87.20ms +step:309/1680 train_time:26944ms step_avg:87.20ms +step:310/1680 train_time:27031ms step_avg:87.20ms +step:311/1680 train_time:27119ms step_avg:87.20ms +step:312/1680 train_time:27206ms step_avg:87.20ms +step:313/1680 train_time:27294ms step_avg:87.20ms +step:314/1680 train_time:27381ms step_avg:87.20ms +step:315/1680 train_time:27469ms step_avg:87.20ms +step:316/1680 train_time:27556ms step_avg:87.20ms +step:317/1680 train_time:27643ms step_avg:87.20ms +step:318/1680 train_time:27730ms step_avg:87.20ms +step:319/1680 train_time:27816ms step_avg:87.20ms +step:320/1680 train_time:27903ms step_avg:87.20ms +step:321/1680 train_time:27990ms step_avg:87.20ms +step:322/1680 train_time:28077ms step_avg:87.20ms +step:323/1680 train_time:28164ms step_avg:87.20ms +step:324/1680 train_time:28252ms step_avg:87.20ms +step:325/1680 train_time:28339ms 
step_avg:87.20ms +step:326/1680 train_time:28427ms step_avg:87.20ms +step:327/1680 train_time:28514ms step_avg:87.20ms +step:328/1680 train_time:28601ms step_avg:87.20ms +step:329/1680 train_time:28688ms step_avg:87.20ms +step:330/1680 train_time:28775ms step_avg:87.20ms +step:331/1680 train_time:28862ms step_avg:87.20ms +step:332/1680 train_time:28949ms step_avg:87.20ms +step:333/1680 train_time:29037ms step_avg:87.20ms +step:334/1680 train_time:29124ms step_avg:87.20ms +step:335/1680 train_time:29211ms step_avg:87.20ms +step:336/1680 train_time:29299ms step_avg:87.20ms +step:337/1680 train_time:29387ms step_avg:87.20ms +step:338/1680 train_time:29473ms step_avg:87.20ms +step:339/1680 train_time:29560ms step_avg:87.20ms +step:340/1680 train_time:29647ms step_avg:87.20ms +step:341/1680 train_time:29735ms step_avg:87.20ms +step:342/1680 train_time:29821ms step_avg:87.20ms +step:343/1680 train_time:29908ms step_avg:87.20ms +step:344/1680 train_time:29996ms step_avg:87.20ms +step:345/1680 train_time:30084ms step_avg:87.20ms +step:346/1680 train_time:30171ms step_avg:87.20ms +step:347/1680 train_time:30259ms step_avg:87.20ms +step:348/1680 train_time:30346ms step_avg:87.20ms +step:349/1680 train_time:30433ms step_avg:87.20ms +step:350/1680 train_time:30520ms step_avg:87.20ms +step:351/1680 train_time:30607ms step_avg:87.20ms +step:352/1680 train_time:30694ms step_avg:87.20ms +step:353/1680 train_time:30781ms step_avg:87.20ms +step:354/1680 train_time:30868ms step_avg:87.20ms +step:355/1680 train_time:30955ms step_avg:87.20ms +step:356/1680 train_time:31043ms step_avg:87.20ms +step:357/1680 train_time:31130ms step_avg:87.20ms +step:358/1680 train_time:31217ms step_avg:87.20ms +step:359/1680 train_time:31305ms step_avg:87.20ms +step:360/1680 train_time:31392ms step_avg:87.20ms +step:361/1680 train_time:31479ms step_avg:87.20ms +step:362/1680 train_time:31567ms step_avg:87.20ms +step:363/1680 train_time:31654ms step_avg:87.20ms +step:364/1680 train_time:31741ms step_avg:87.20ms +step:365/1680 train_time:31828ms step_avg:87.20ms +step:366/1680 train_time:31915ms step_avg:87.20ms +step:367/1680 train_time:32002ms step_avg:87.20ms +step:368/1680 train_time:32089ms step_avg:87.20ms +step:369/1680 train_time:32176ms step_avg:87.20ms +step:370/1680 train_time:32264ms step_avg:87.20ms +step:371/1680 train_time:32351ms step_avg:87.20ms +step:372/1680 train_time:32439ms step_avg:87.20ms +step:373/1680 train_time:32527ms step_avg:87.20ms +step:374/1680 train_time:32614ms step_avg:87.20ms +step:375/1680 train_time:32701ms step_avg:87.20ms +step:375/1680 val_loss:3.8127 train_time:32789ms step_avg:87.44ms +step:376/1680 train_time:32808ms step_avg:87.25ms +step:377/1680 train_time:32877ms step_avg:87.21ms +step:378/1680 train_time:32970ms step_avg:87.22ms +step:379/1680 train_time:33059ms step_avg:87.23ms +step:380/1680 train_time:33146ms step_avg:87.23ms +step:381/1680 train_time:33232ms step_avg:87.22ms +step:382/1680 train_time:33319ms step_avg:87.22ms +step:383/1680 train_time:33405ms step_avg:87.22ms +step:384/1680 train_time:33491ms step_avg:87.22ms +step:385/1680 train_time:33578ms step_avg:87.22ms +step:386/1680 train_time:33664ms step_avg:87.21ms +step:387/1680 train_time:33751ms step_avg:87.21ms +step:388/1680 train_time:33839ms step_avg:87.22ms +step:389/1680 train_time:33928ms step_avg:87.22ms +step:390/1680 train_time:34016ms step_avg:87.22ms +step:391/1680 train_time:34104ms step_avg:87.22ms +step:392/1680 train_time:34191ms step_avg:87.22ms +step:393/1680 train_time:34279ms step_avg:87.22ms 
+step:394/1680 train_time:34365ms step_avg:87.22ms +step:395/1680 train_time:34452ms step_avg:87.22ms +step:396/1680 train_time:34538ms step_avg:87.22ms +step:397/1680 train_time:34625ms step_avg:87.22ms +step:398/1680 train_time:34711ms step_avg:87.21ms +step:399/1680 train_time:34799ms step_avg:87.22ms +step:400/1680 train_time:34887ms step_avg:87.22ms +step:401/1680 train_time:34975ms step_avg:87.22ms +step:402/1680 train_time:35063ms step_avg:87.22ms +step:403/1680 train_time:35150ms step_avg:87.22ms +step:404/1680 train_time:35238ms step_avg:87.22ms +step:405/1680 train_time:35325ms step_avg:87.22ms +step:406/1680 train_time:35411ms step_avg:87.22ms +step:407/1680 train_time:35498ms step_avg:87.22ms +step:408/1680 train_time:35585ms step_avg:87.22ms +step:409/1680 train_time:35671ms step_avg:87.22ms +step:410/1680 train_time:35758ms step_avg:87.22ms +step:411/1680 train_time:35846ms step_avg:87.22ms +step:412/1680 train_time:35933ms step_avg:87.22ms +step:413/1680 train_time:36021ms step_avg:87.22ms +step:414/1680 train_time:36108ms step_avg:87.22ms +step:415/1680 train_time:36195ms step_avg:87.22ms +step:416/1680 train_time:36283ms step_avg:87.22ms +step:417/1680 train_time:36370ms step_avg:87.22ms +step:418/1680 train_time:36456ms step_avg:87.22ms +step:419/1680 train_time:36543ms step_avg:87.22ms +step:420/1680 train_time:36630ms step_avg:87.21ms +step:421/1680 train_time:36717ms step_avg:87.21ms +step:422/1680 train_time:36804ms step_avg:87.21ms +step:423/1680 train_time:36891ms step_avg:87.21ms +step:424/1680 train_time:36979ms step_avg:87.22ms +step:425/1680 train_time:37067ms step_avg:87.22ms +step:426/1680 train_time:37154ms step_avg:87.22ms +step:427/1680 train_time:37241ms step_avg:87.22ms +step:428/1680 train_time:37329ms step_avg:87.22ms +step:429/1680 train_time:37416ms step_avg:87.22ms +step:430/1680 train_time:37503ms step_avg:87.22ms +step:431/1680 train_time:37589ms step_avg:87.21ms +step:432/1680 train_time:37677ms step_avg:87.21ms +step:433/1680 train_time:37764ms step_avg:87.21ms +step:434/1680 train_time:37851ms step_avg:87.21ms +step:435/1680 train_time:37938ms step_avg:87.21ms +step:436/1680 train_time:38026ms step_avg:87.21ms +step:437/1680 train_time:38112ms step_avg:87.21ms +step:438/1680 train_time:38199ms step_avg:87.21ms +step:439/1680 train_time:38286ms step_avg:87.21ms +step:440/1680 train_time:38374ms step_avg:87.21ms +step:441/1680 train_time:38461ms step_avg:87.21ms +step:442/1680 train_time:38547ms step_avg:87.21ms +step:443/1680 train_time:38634ms step_avg:87.21ms +step:444/1680 train_time:38721ms step_avg:87.21ms +step:445/1680 train_time:38808ms step_avg:87.21ms +step:446/1680 train_time:38895ms step_avg:87.21ms +step:447/1680 train_time:38983ms step_avg:87.21ms +step:448/1680 train_time:39071ms step_avg:87.21ms +step:449/1680 train_time:39158ms step_avg:87.21ms +step:450/1680 train_time:39245ms step_avg:87.21ms +step:451/1680 train_time:39332ms step_avg:87.21ms +step:452/1680 train_time:39420ms step_avg:87.21ms +step:453/1680 train_time:39506ms step_avg:87.21ms +step:454/1680 train_time:39593ms step_avg:87.21ms +step:455/1680 train_time:39680ms step_avg:87.21ms +step:456/1680 train_time:39766ms step_avg:87.21ms +step:457/1680 train_time:39854ms step_avg:87.21ms +step:458/1680 train_time:39941ms step_avg:87.21ms +step:459/1680 train_time:40028ms step_avg:87.21ms +step:460/1680 train_time:40115ms step_avg:87.21ms +step:461/1680 train_time:40202ms step_avg:87.21ms +step:462/1680 train_time:40289ms step_avg:87.21ms +step:463/1680 train_time:40377ms 
step_avg:87.21ms +step:464/1680 train_time:40464ms step_avg:87.21ms +step:465/1680 train_time:40551ms step_avg:87.21ms +step:466/1680 train_time:40638ms step_avg:87.21ms +step:467/1680 train_time:40725ms step_avg:87.20ms +step:468/1680 train_time:40811ms step_avg:87.20ms +step:469/1680 train_time:40899ms step_avg:87.20ms +step:470/1680 train_time:40987ms step_avg:87.21ms +step:471/1680 train_time:41074ms step_avg:87.21ms +step:472/1680 train_time:41161ms step_avg:87.21ms +step:473/1680 train_time:41248ms step_avg:87.21ms +step:474/1680 train_time:41335ms step_avg:87.20ms +step:475/1680 train_time:41423ms step_avg:87.21ms +step:476/1680 train_time:41510ms step_avg:87.20ms +step:477/1680 train_time:41596ms step_avg:87.20ms +step:478/1680 train_time:41684ms step_avg:87.20ms +step:479/1680 train_time:41770ms step_avg:87.20ms +step:480/1680 train_time:41857ms step_avg:87.20ms +step:481/1680 train_time:41944ms step_avg:87.20ms +step:482/1680 train_time:42031ms step_avg:87.20ms +step:483/1680 train_time:42119ms step_avg:87.20ms +step:484/1680 train_time:42206ms step_avg:87.20ms +step:485/1680 train_time:42293ms step_avg:87.20ms +step:486/1680 train_time:42380ms step_avg:87.20ms +step:487/1680 train_time:42467ms step_avg:87.20ms +step:488/1680 train_time:42555ms step_avg:87.20ms +step:489/1680 train_time:42642ms step_avg:87.20ms +step:490/1680 train_time:42729ms step_avg:87.20ms +step:491/1680 train_time:42816ms step_avg:87.20ms +step:492/1680 train_time:42903ms step_avg:87.20ms +step:493/1680 train_time:42990ms step_avg:87.20ms +step:494/1680 train_time:43077ms step_avg:87.20ms +step:495/1680 train_time:43165ms step_avg:87.20ms +step:496/1680 train_time:43252ms step_avg:87.20ms +step:497/1680 train_time:43339ms step_avg:87.20ms +step:498/1680 train_time:43426ms step_avg:87.20ms +step:499/1680 train_time:43514ms step_avg:87.20ms +step:500/1680 train_time:43601ms step_avg:87.20ms +step:500/1680 val_loss:3.7153 train_time:43689ms step_avg:87.38ms +step:501/1680 train_time:43708ms step_avg:87.24ms +step:502/1680 train_time:43777ms step_avg:87.21ms +step:503/1680 train_time:43869ms step_avg:87.22ms +step:504/1680 train_time:43958ms step_avg:87.22ms +step:505/1680 train_time:44046ms step_avg:87.22ms +step:506/1680 train_time:44133ms step_avg:87.22ms +step:507/1680 train_time:44219ms step_avg:87.22ms +step:508/1680 train_time:44306ms step_avg:87.22ms +step:509/1680 train_time:44392ms step_avg:87.21ms +step:510/1680 train_time:44478ms step_avg:87.21ms +step:511/1680 train_time:44564ms step_avg:87.21ms +step:512/1680 train_time:44652ms step_avg:87.21ms +step:513/1680 train_time:44739ms step_avg:87.21ms +step:514/1680 train_time:44828ms step_avg:87.21ms +step:515/1680 train_time:44917ms step_avg:87.22ms +step:516/1680 train_time:45005ms step_avg:87.22ms +step:517/1680 train_time:45092ms step_avg:87.22ms +step:518/1680 train_time:45180ms step_avg:87.22ms +step:519/1680 train_time:45267ms step_avg:87.22ms +step:520/1680 train_time:45353ms step_avg:87.22ms +step:521/1680 train_time:45440ms step_avg:87.22ms +step:522/1680 train_time:45526ms step_avg:87.21ms +step:523/1680 train_time:45612ms step_avg:87.21ms +step:524/1680 train_time:45699ms step_avg:87.21ms +step:525/1680 train_time:45787ms step_avg:87.21ms +step:526/1680 train_time:45876ms step_avg:87.22ms +step:527/1680 train_time:45964ms step_avg:87.22ms +step:528/1680 train_time:46051ms step_avg:87.22ms +step:529/1680 train_time:46139ms step_avg:87.22ms +step:530/1680 train_time:46226ms step_avg:87.22ms +step:531/1680 train_time:46312ms step_avg:87.22ms 
+step:532/1680 train_time:46399ms step_avg:87.22ms +step:533/1680 train_time:46486ms step_avg:87.22ms +step:534/1680 train_time:46572ms step_avg:87.21ms +step:535/1680 train_time:46659ms step_avg:87.21ms +step:536/1680 train_time:46746ms step_avg:87.21ms +step:537/1680 train_time:46834ms step_avg:87.21ms +step:538/1680 train_time:46921ms step_avg:87.21ms +step:539/1680 train_time:47009ms step_avg:87.21ms +step:540/1680 train_time:47096ms step_avg:87.21ms +step:541/1680 train_time:47183ms step_avg:87.22ms +step:542/1680 train_time:47270ms step_avg:87.21ms +step:543/1680 train_time:47357ms step_avg:87.21ms +step:544/1680 train_time:47444ms step_avg:87.21ms +step:545/1680 train_time:47531ms step_avg:87.21ms +step:546/1680 train_time:47617ms step_avg:87.21ms +step:547/1680 train_time:47705ms step_avg:87.21ms +step:548/1680 train_time:47792ms step_avg:87.21ms +step:549/1680 train_time:47881ms step_avg:87.21ms +step:550/1680 train_time:47970ms step_avg:87.22ms +step:551/1680 train_time:48058ms step_avg:87.22ms +step:552/1680 train_time:48148ms step_avg:87.22ms +step:553/1680 train_time:48237ms step_avg:87.23ms +step:554/1680 train_time:48324ms step_avg:87.23ms +step:555/1680 train_time:48412ms step_avg:87.23ms +step:556/1680 train_time:48501ms step_avg:87.23ms +step:557/1680 train_time:48589ms step_avg:87.23ms +step:558/1680 train_time:48677ms step_avg:87.23ms +step:559/1680 train_time:48766ms step_avg:87.24ms +step:560/1680 train_time:48854ms step_avg:87.24ms +step:561/1680 train_time:48943ms step_avg:87.24ms +step:562/1680 train_time:49031ms step_avg:87.24ms +step:563/1680 train_time:49120ms step_avg:87.25ms +step:564/1680 train_time:49208ms step_avg:87.25ms +step:565/1680 train_time:49296ms step_avg:87.25ms +step:566/1680 train_time:49385ms step_avg:87.25ms +step:567/1680 train_time:49473ms step_avg:87.25ms +step:568/1680 train_time:49562ms step_avg:87.26ms +step:569/1680 train_time:49650ms step_avg:87.26ms +step:570/1680 train_time:49738ms step_avg:87.26ms +step:571/1680 train_time:49827ms step_avg:87.26ms +step:572/1680 train_time:49916ms step_avg:87.27ms +step:573/1680 train_time:50005ms step_avg:87.27ms +step:574/1680 train_time:50093ms step_avg:87.27ms +step:575/1680 train_time:50181ms step_avg:87.27ms +step:576/1680 train_time:50269ms step_avg:87.27ms +step:577/1680 train_time:50358ms step_avg:87.28ms +step:578/1680 train_time:50446ms step_avg:87.28ms +step:579/1680 train_time:50535ms step_avg:87.28ms +step:580/1680 train_time:50623ms step_avg:87.28ms +step:581/1680 train_time:50710ms step_avg:87.28ms +step:582/1680 train_time:50799ms step_avg:87.28ms +step:583/1680 train_time:50888ms step_avg:87.29ms +step:584/1680 train_time:50978ms step_avg:87.29ms +step:585/1680 train_time:51066ms step_avg:87.29ms +step:586/1680 train_time:51154ms step_avg:87.29ms +step:587/1680 train_time:51243ms step_avg:87.30ms +step:588/1680 train_time:51330ms step_avg:87.30ms +step:589/1680 train_time:51419ms step_avg:87.30ms +step:590/1680 train_time:51508ms step_avg:87.30ms +step:591/1680 train_time:51595ms step_avg:87.30ms +step:592/1680 train_time:51683ms step_avg:87.30ms +step:593/1680 train_time:51772ms step_avg:87.30ms +step:594/1680 train_time:51860ms step_avg:87.31ms +step:595/1680 train_time:51949ms step_avg:87.31ms +step:596/1680 train_time:52038ms step_avg:87.31ms +step:597/1680 train_time:52126ms step_avg:87.31ms +step:598/1680 train_time:52215ms step_avg:87.32ms +step:599/1680 train_time:52304ms step_avg:87.32ms +step:600/1680 train_time:52392ms step_avg:87.32ms +step:601/1680 train_time:52481ms 
step_avg:87.32ms +step:602/1680 train_time:52570ms step_avg:87.33ms +step:603/1680 train_time:52658ms step_avg:87.33ms +step:604/1680 train_time:52747ms step_avg:87.33ms +step:605/1680 train_time:52836ms step_avg:87.33ms +step:606/1680 train_time:52924ms step_avg:87.33ms +step:607/1680 train_time:53012ms step_avg:87.33ms +step:608/1680 train_time:53099ms step_avg:87.33ms +step:609/1680 train_time:53188ms step_avg:87.34ms +step:610/1680 train_time:53276ms step_avg:87.34ms +step:611/1680 train_time:53364ms step_avg:87.34ms +step:612/1680 train_time:53452ms step_avg:87.34ms +step:613/1680 train_time:53541ms step_avg:87.34ms +step:614/1680 train_time:53629ms step_avg:87.34ms +step:615/1680 train_time:53719ms step_avg:87.35ms +step:616/1680 train_time:53807ms step_avg:87.35ms +step:617/1680 train_time:53896ms step_avg:87.35ms +step:618/1680 train_time:53985ms step_avg:87.35ms +step:619/1680 train_time:54073ms step_avg:87.35ms +step:620/1680 train_time:54161ms step_avg:87.36ms +step:621/1680 train_time:54250ms step_avg:87.36ms +step:622/1680 train_time:54338ms step_avg:87.36ms +step:623/1680 train_time:54427ms step_avg:87.36ms +step:624/1680 train_time:54515ms step_avg:87.36ms +step:625/1680 train_time:54603ms step_avg:87.37ms +step:625/1680 val_loss:3.6149 train_time:54693ms step_avg:87.51ms +step:626/1680 train_time:54713ms step_avg:87.40ms +step:627/1680 train_time:54786ms step_avg:87.38ms +step:628/1680 train_time:54875ms step_avg:87.38ms +step:629/1680 train_time:54966ms step_avg:87.39ms +step:630/1680 train_time:55055ms step_avg:87.39ms +step:631/1680 train_time:55143ms step_avg:87.39ms +step:632/1680 train_time:55229ms step_avg:87.39ms +step:633/1680 train_time:55317ms step_avg:87.39ms +step:634/1680 train_time:55404ms step_avg:87.39ms +step:635/1680 train_time:55491ms step_avg:87.39ms +step:636/1680 train_time:55579ms step_avg:87.39ms +step:637/1680 train_time:55672ms step_avg:87.40ms +step:638/1680 train_time:55762ms step_avg:87.40ms +step:639/1680 train_time:55851ms step_avg:87.40ms +step:640/1680 train_time:55942ms step_avg:87.41ms +step:641/1680 train_time:56030ms step_avg:87.41ms +step:642/1680 train_time:56118ms step_avg:87.41ms +step:643/1680 train_time:56206ms step_avg:87.41ms +step:644/1680 train_time:56293ms step_avg:87.41ms +step:645/1680 train_time:56381ms step_avg:87.41ms +step:646/1680 train_time:56468ms step_avg:87.41ms +step:647/1680 train_time:56557ms step_avg:87.41ms +step:648/1680 train_time:56646ms step_avg:87.42ms +step:649/1680 train_time:56735ms step_avg:87.42ms +step:650/1680 train_time:56824ms step_avg:87.42ms +step:651/1680 train_time:56913ms step_avg:87.42ms +step:652/1680 train_time:57003ms step_avg:87.43ms +step:653/1680 train_time:57091ms step_avg:87.43ms +step:654/1680 train_time:57179ms step_avg:87.43ms +step:655/1680 train_time:57267ms step_avg:87.43ms +step:656/1680 train_time:57354ms step_avg:87.43ms +step:657/1680 train_time:57443ms step_avg:87.43ms +step:658/1680 train_time:57531ms step_avg:87.43ms +step:659/1680 train_time:57619ms step_avg:87.43ms +step:660/1680 train_time:57708ms step_avg:87.44ms +step:661/1680 train_time:57797ms step_avg:87.44ms +step:662/1680 train_time:57886ms step_avg:87.44ms +step:663/1680 train_time:57975ms step_avg:87.44ms +step:664/1680 train_time:58064ms step_avg:87.45ms +step:665/1680 train_time:58152ms step_avg:87.45ms +step:666/1680 train_time:58240ms step_avg:87.45ms +step:667/1680 train_time:58327ms step_avg:87.45ms +step:668/1680 train_time:58416ms step_avg:87.45ms +step:669/1680 train_time:58504ms step_avg:87.45ms 
+step:670/1680 train_time:58593ms step_avg:87.45ms +step:671/1680 train_time:58682ms step_avg:87.45ms +step:672/1680 train_time:58771ms step_avg:87.46ms +step:673/1680 train_time:58859ms step_avg:87.46ms +step:674/1680 train_time:58948ms step_avg:87.46ms +step:675/1680 train_time:59036ms step_avg:87.46ms +step:676/1680 train_time:59125ms step_avg:87.46ms +step:677/1680 train_time:59214ms step_avg:87.47ms +step:678/1680 train_time:59302ms step_avg:87.47ms +step:679/1680 train_time:59390ms step_avg:87.47ms +step:680/1680 train_time:59478ms step_avg:87.47ms +step:681/1680 train_time:59565ms step_avg:87.47ms +step:682/1680 train_time:59653ms step_avg:87.47ms +step:683/1680 train_time:59743ms step_avg:87.47ms +step:684/1680 train_time:59831ms step_avg:87.47ms +step:685/1680 train_time:59921ms step_avg:87.48ms +step:686/1680 train_time:60009ms step_avg:87.48ms +step:687/1680 train_time:60097ms step_avg:87.48ms +step:688/1680 train_time:60185ms step_avg:87.48ms +step:689/1680 train_time:60274ms step_avg:87.48ms +step:690/1680 train_time:60362ms step_avg:87.48ms +step:691/1680 train_time:60450ms step_avg:87.48ms +step:692/1680 train_time:60538ms step_avg:87.48ms +step:693/1680 train_time:60626ms step_avg:87.48ms +step:694/1680 train_time:60715ms step_avg:87.49ms +step:695/1680 train_time:60803ms step_avg:87.49ms +step:696/1680 train_time:60892ms step_avg:87.49ms +step:697/1680 train_time:60980ms step_avg:87.49ms +step:698/1680 train_time:61069ms step_avg:87.49ms +step:699/1680 train_time:61157ms step_avg:87.49ms +step:700/1680 train_time:61245ms step_avg:87.49ms +step:701/1680 train_time:61333ms step_avg:87.49ms +step:702/1680 train_time:61422ms step_avg:87.50ms +step:703/1680 train_time:61509ms step_avg:87.50ms +step:704/1680 train_time:61597ms step_avg:87.50ms +step:705/1680 train_time:61685ms step_avg:87.50ms +step:706/1680 train_time:61774ms step_avg:87.50ms +step:707/1680 train_time:61862ms step_avg:87.50ms +step:708/1680 train_time:61950ms step_avg:87.50ms +step:709/1680 train_time:62039ms step_avg:87.50ms +step:710/1680 train_time:62126ms step_avg:87.50ms +step:711/1680 train_time:62215ms step_avg:87.50ms +step:712/1680 train_time:62303ms step_avg:87.50ms +step:713/1680 train_time:62392ms step_avg:87.51ms +step:714/1680 train_time:62480ms step_avg:87.51ms +step:715/1680 train_time:62568ms step_avg:87.51ms +step:716/1680 train_time:62657ms step_avg:87.51ms +step:717/1680 train_time:62745ms step_avg:87.51ms +step:718/1680 train_time:62833ms step_avg:87.51ms +step:719/1680 train_time:62921ms step_avg:87.51ms +step:720/1680 train_time:63010ms step_avg:87.51ms +step:721/1680 train_time:63098ms step_avg:87.51ms +step:722/1680 train_time:63186ms step_avg:87.51ms +step:723/1680 train_time:63274ms step_avg:87.52ms +step:724/1680 train_time:63362ms step_avg:87.52ms +step:725/1680 train_time:63450ms step_avg:87.52ms +step:726/1680 train_time:63538ms step_avg:87.52ms +step:727/1680 train_time:63627ms step_avg:87.52ms +step:728/1680 train_time:63716ms step_avg:87.52ms +step:729/1680 train_time:63804ms step_avg:87.52ms +step:730/1680 train_time:63893ms step_avg:87.52ms +step:731/1680 train_time:63981ms step_avg:87.53ms +step:732/1680 train_time:64069ms step_avg:87.53ms +step:733/1680 train_time:64157ms step_avg:87.53ms +step:734/1680 train_time:64245ms step_avg:87.53ms +step:735/1680 train_time:64334ms step_avg:87.53ms +step:736/1680 train_time:64422ms step_avg:87.53ms +step:737/1680 train_time:64510ms step_avg:87.53ms +step:738/1680 train_time:64598ms step_avg:87.53ms +step:739/1680 train_time:64686ms 
step_avg:87.53ms +step:740/1680 train_time:64775ms step_avg:87.53ms +step:741/1680 train_time:64864ms step_avg:87.54ms +step:742/1680 train_time:64953ms step_avg:87.54ms +step:743/1680 train_time:65042ms step_avg:87.54ms +step:744/1680 train_time:65130ms step_avg:87.54ms +step:745/1680 train_time:65217ms step_avg:87.54ms +step:746/1680 train_time:65305ms step_avg:87.54ms +step:747/1680 train_time:65394ms step_avg:87.54ms +step:748/1680 train_time:65483ms step_avg:87.54ms +step:749/1680 train_time:65572ms step_avg:87.55ms +step:750/1680 train_time:65660ms step_avg:87.55ms +step:750/1680 val_loss:3.5647 train_time:65750ms step_avg:87.67ms +step:751/1680 train_time:65769ms step_avg:87.58ms +step:752/1680 train_time:65840ms step_avg:87.55ms +step:753/1680 train_time:65931ms step_avg:87.56ms +step:754/1680 train_time:66019ms step_avg:87.56ms +step:755/1680 train_time:66106ms step_avg:87.56ms +step:756/1680 train_time:66193ms step_avg:87.56ms +step:757/1680 train_time:66281ms step_avg:87.56ms +step:758/1680 train_time:66368ms step_avg:87.56ms +step:759/1680 train_time:66455ms step_avg:87.56ms +step:760/1680 train_time:66544ms step_avg:87.56ms +step:761/1680 train_time:66631ms step_avg:87.56ms +step:762/1680 train_time:66720ms step_avg:87.56ms +step:763/1680 train_time:66810ms step_avg:87.56ms +step:764/1680 train_time:66900ms step_avg:87.57ms +step:765/1680 train_time:66989ms step_avg:87.57ms +step:766/1680 train_time:67077ms step_avg:87.57ms +step:767/1680 train_time:67165ms step_avg:87.57ms +step:768/1680 train_time:67253ms step_avg:87.57ms +step:769/1680 train_time:67340ms step_avg:87.57ms +step:770/1680 train_time:67427ms step_avg:87.57ms +step:771/1680 train_time:67515ms step_avg:87.57ms +step:772/1680 train_time:67603ms step_avg:87.57ms +step:773/1680 train_time:67693ms step_avg:87.57ms +step:774/1680 train_time:67783ms step_avg:87.57ms +step:775/1680 train_time:67872ms step_avg:87.58ms +step:776/1680 train_time:67961ms step_avg:87.58ms +step:777/1680 train_time:68050ms step_avg:87.58ms +step:778/1680 train_time:68138ms step_avg:87.58ms +step:779/1680 train_time:68226ms step_avg:87.58ms +step:780/1680 train_time:68314ms step_avg:87.58ms +step:781/1680 train_time:68401ms step_avg:87.58ms +step:782/1680 train_time:68489ms step_avg:87.58ms +step:783/1680 train_time:68577ms step_avg:87.58ms +step:784/1680 train_time:68666ms step_avg:87.58ms +step:785/1680 train_time:68756ms step_avg:87.59ms +step:786/1680 train_time:68845ms step_avg:87.59ms +step:787/1680 train_time:68935ms step_avg:87.59ms +step:788/1680 train_time:69023ms step_avg:87.59ms +step:789/1680 train_time:69112ms step_avg:87.59ms +step:790/1680 train_time:69200ms step_avg:87.60ms +step:791/1680 train_time:69289ms step_avg:87.60ms +step:792/1680 train_time:69376ms step_avg:87.60ms +step:793/1680 train_time:69464ms step_avg:87.60ms +step:794/1680 train_time:69553ms step_avg:87.60ms +step:795/1680 train_time:69641ms step_avg:87.60ms +step:796/1680 train_time:69730ms step_avg:87.60ms +step:797/1680 train_time:69819ms step_avg:87.60ms +step:798/1680 train_time:69908ms step_avg:87.60ms +step:799/1680 train_time:69997ms step_avg:87.61ms +step:800/1680 train_time:70085ms step_avg:87.61ms +step:801/1680 train_time:70173ms step_avg:87.61ms +step:802/1680 train_time:70261ms step_avg:87.61ms +step:803/1680 train_time:70349ms step_avg:87.61ms +step:804/1680 train_time:70437ms step_avg:87.61ms +step:805/1680 train_time:70525ms step_avg:87.61ms +step:806/1680 train_time:70613ms step_avg:87.61ms +step:807/1680 train_time:70702ms step_avg:87.61ms 
+step:808/1680 train_time:70791ms step_avg:87.61ms +step:809/1680 train_time:70880ms step_avg:87.61ms +step:810/1680 train_time:70969ms step_avg:87.62ms +step:811/1680 train_time:71057ms step_avg:87.62ms +step:812/1680 train_time:71146ms step_avg:87.62ms +step:813/1680 train_time:71234ms step_avg:87.62ms +step:814/1680 train_time:71322ms step_avg:87.62ms +step:815/1680 train_time:71410ms step_avg:87.62ms +step:816/1680 train_time:71498ms step_avg:87.62ms +step:817/1680 train_time:71586ms step_avg:87.62ms +step:818/1680 train_time:71675ms step_avg:87.62ms +step:819/1680 train_time:71763ms step_avg:87.62ms +step:820/1680 train_time:71853ms step_avg:87.63ms +step:821/1680 train_time:71941ms step_avg:87.63ms +step:822/1680 train_time:72030ms step_avg:87.63ms +step:823/1680 train_time:72118ms step_avg:87.63ms +step:824/1680 train_time:72207ms step_avg:87.63ms +step:825/1680 train_time:72296ms step_avg:87.63ms +step:826/1680 train_time:72384ms step_avg:87.63ms +step:827/1680 train_time:72472ms step_avg:87.63ms +step:828/1680 train_time:72560ms step_avg:87.63ms +step:829/1680 train_time:72648ms step_avg:87.63ms +step:830/1680 train_time:72736ms step_avg:87.63ms +step:831/1680 train_time:72824ms step_avg:87.63ms +step:832/1680 train_time:72913ms step_avg:87.64ms +step:833/1680 train_time:73001ms step_avg:87.64ms +step:834/1680 train_time:73090ms step_avg:87.64ms +step:835/1680 train_time:73178ms step_avg:87.64ms +step:836/1680 train_time:73266ms step_avg:87.64ms +step:837/1680 train_time:73356ms step_avg:87.64ms +step:838/1680 train_time:73444ms step_avg:87.64ms +step:839/1680 train_time:73532ms step_avg:87.64ms +step:840/1680 train_time:73621ms step_avg:87.64ms +step:841/1680 train_time:73710ms step_avg:87.65ms +step:842/1680 train_time:73798ms step_avg:87.65ms +step:843/1680 train_time:73887ms step_avg:87.65ms +step:844/1680 train_time:73975ms step_avg:87.65ms +step:845/1680 train_time:74064ms step_avg:87.65ms +step:846/1680 train_time:74153ms step_avg:87.65ms +step:847/1680 train_time:74241ms step_avg:87.65ms +step:848/1680 train_time:74329ms step_avg:87.65ms +step:849/1680 train_time:74417ms step_avg:87.65ms +step:850/1680 train_time:74505ms step_avg:87.65ms +step:851/1680 train_time:74593ms step_avg:87.65ms +step:852/1680 train_time:74681ms step_avg:87.65ms +step:853/1680 train_time:74769ms step_avg:87.65ms +step:854/1680 train_time:74857ms step_avg:87.65ms +step:855/1680 train_time:74945ms step_avg:87.66ms +step:856/1680 train_time:75034ms step_avg:87.66ms +step:857/1680 train_time:75122ms step_avg:87.66ms +step:858/1680 train_time:75211ms step_avg:87.66ms +step:859/1680 train_time:75299ms step_avg:87.66ms +step:860/1680 train_time:75387ms step_avg:87.66ms +step:861/1680 train_time:75476ms step_avg:87.66ms +step:862/1680 train_time:75565ms step_avg:87.66ms +step:863/1680 train_time:75653ms step_avg:87.66ms +step:864/1680 train_time:75741ms step_avg:87.66ms +step:865/1680 train_time:75829ms step_avg:87.66ms +step:866/1680 train_time:75918ms step_avg:87.67ms +step:867/1680 train_time:76007ms step_avg:87.67ms +step:868/1680 train_time:76095ms step_avg:87.67ms +step:869/1680 train_time:76184ms step_avg:87.67ms +step:870/1680 train_time:76273ms step_avg:87.67ms +step:871/1680 train_time:76360ms step_avg:87.67ms +step:872/1680 train_time:76448ms step_avg:87.67ms +step:873/1680 train_time:76537ms step_avg:87.67ms +step:874/1680 train_time:76625ms step_avg:87.67ms +step:875/1680 train_time:76715ms step_avg:87.67ms +step:875/1680 val_loss:3.5181 train_time:76804ms step_avg:87.78ms +step:876/1680 
train_time:76823ms step_avg:87.70ms +step:877/1680 train_time:76895ms step_avg:87.68ms +step:878/1680 train_time:76988ms step_avg:87.69ms +step:879/1680 train_time:77079ms step_avg:87.69ms +step:880/1680 train_time:77167ms step_avg:87.69ms +step:881/1680 train_time:77254ms step_avg:87.69ms +step:882/1680 train_time:77342ms step_avg:87.69ms +step:883/1680 train_time:77429ms step_avg:87.69ms +step:884/1680 train_time:77516ms step_avg:87.69ms +step:885/1680 train_time:77604ms step_avg:87.69ms +step:886/1680 train_time:77692ms step_avg:87.69ms +step:887/1680 train_time:77782ms step_avg:87.69ms +step:888/1680 train_time:77872ms step_avg:87.69ms +step:889/1680 train_time:77963ms step_avg:87.70ms +step:890/1680 train_time:78054ms step_avg:87.70ms +step:891/1680 train_time:78143ms step_avg:87.70ms +step:892/1680 train_time:78232ms step_avg:87.70ms +step:893/1680 train_time:78320ms step_avg:87.70ms +step:894/1680 train_time:78407ms step_avg:87.70ms +step:895/1680 train_time:78495ms step_avg:87.70ms +step:896/1680 train_time:78583ms step_avg:87.70ms +step:897/1680 train_time:78670ms step_avg:87.70ms +step:898/1680 train_time:78758ms step_avg:87.70ms +step:899/1680 train_time:78846ms step_avg:87.70ms +step:900/1680 train_time:78936ms step_avg:87.71ms +step:901/1680 train_time:79025ms step_avg:87.71ms +step:902/1680 train_time:79115ms step_avg:87.71ms +step:903/1680 train_time:79204ms step_avg:87.71ms +step:904/1680 train_time:79292ms step_avg:87.71ms +step:905/1680 train_time:79379ms step_avg:87.71ms +step:906/1680 train_time:79467ms step_avg:87.71ms +step:907/1680 train_time:79555ms step_avg:87.71ms +step:908/1680 train_time:79643ms step_avg:87.71ms +step:909/1680 train_time:79730ms step_avg:87.71ms +step:910/1680 train_time:79819ms step_avg:87.71ms +step:911/1680 train_time:79908ms step_avg:87.71ms +step:912/1680 train_time:79996ms step_avg:87.72ms +step:913/1680 train_time:80085ms step_avg:87.72ms +step:914/1680 train_time:80173ms step_avg:87.72ms +step:915/1680 train_time:80262ms step_avg:87.72ms +step:916/1680 train_time:80351ms step_avg:87.72ms +step:917/1680 train_time:80439ms step_avg:87.72ms +step:918/1680 train_time:80526ms step_avg:87.72ms +step:919/1680 train_time:80615ms step_avg:87.72ms +step:920/1680 train_time:80703ms step_avg:87.72ms +step:921/1680 train_time:80791ms step_avg:87.72ms +step:922/1680 train_time:80880ms step_avg:87.72ms +step:923/1680 train_time:80968ms step_avg:87.72ms +step:924/1680 train_time:81056ms step_avg:87.72ms +step:925/1680 train_time:81145ms step_avg:87.72ms +step:926/1680 train_time:81234ms step_avg:87.73ms +step:927/1680 train_time:81323ms step_avg:87.73ms +step:928/1680 train_time:81411ms step_avg:87.73ms +step:929/1680 train_time:81500ms step_avg:87.73ms +step:930/1680 train_time:81587ms step_avg:87.73ms +step:931/1680 train_time:81675ms step_avg:87.73ms +step:932/1680 train_time:81763ms step_avg:87.73ms +step:933/1680 train_time:81853ms step_avg:87.73ms +step:934/1680 train_time:81941ms step_avg:87.73ms +step:935/1680 train_time:82030ms step_avg:87.73ms +step:936/1680 train_time:82119ms step_avg:87.73ms +step:937/1680 train_time:82207ms step_avg:87.73ms +step:938/1680 train_time:82296ms step_avg:87.74ms +step:939/1680 train_time:82384ms step_avg:87.74ms +step:940/1680 train_time:82472ms step_avg:87.74ms +step:941/1680 train_time:82560ms step_avg:87.74ms +step:942/1680 train_time:82648ms step_avg:87.74ms +step:943/1680 train_time:82736ms step_avg:87.74ms +step:944/1680 train_time:82824ms step_avg:87.74ms +step:945/1680 train_time:82913ms step_avg:87.74ms 
+step:946/1680 train_time:83002ms step_avg:87.74ms +step:947/1680 train_time:83090ms step_avg:87.74ms +step:948/1680 train_time:83179ms step_avg:87.74ms +step:949/1680 train_time:83268ms step_avg:87.74ms +step:950/1680 train_time:83356ms step_avg:87.74ms +step:951/1680 train_time:83444ms step_avg:87.74ms +step:952/1680 train_time:83532ms step_avg:87.74ms +step:953/1680 train_time:83621ms step_avg:87.74ms +step:954/1680 train_time:83709ms step_avg:87.74ms +step:955/1680 train_time:83797ms step_avg:87.75ms +step:956/1680 train_time:83885ms step_avg:87.75ms +step:957/1680 train_time:83974ms step_avg:87.75ms +step:958/1680 train_time:84062ms step_avg:87.75ms +step:959/1680 train_time:84150ms step_avg:87.75ms +step:960/1680 train_time:84240ms step_avg:87.75ms +step:961/1680 train_time:84328ms step_avg:87.75ms +step:962/1680 train_time:84419ms step_avg:87.75ms +step:963/1680 train_time:84507ms step_avg:87.75ms +step:964/1680 train_time:84596ms step_avg:87.75ms +step:965/1680 train_time:84683ms step_avg:87.75ms +step:966/1680 train_time:84772ms step_avg:87.76ms +step:967/1680 train_time:84861ms step_avg:87.76ms +step:968/1680 train_time:84949ms step_avg:87.76ms +step:969/1680 train_time:85037ms step_avg:87.76ms +step:970/1680 train_time:85126ms step_avg:87.76ms +step:971/1680 train_time:85216ms step_avg:87.76ms +step:972/1680 train_time:85304ms step_avg:87.76ms +step:973/1680 train_time:85393ms step_avg:87.76ms +step:974/1680 train_time:85482ms step_avg:87.76ms +step:975/1680 train_time:85570ms step_avg:87.76ms +step:976/1680 train_time:85659ms step_avg:87.77ms +step:977/1680 train_time:85747ms step_avg:87.77ms +step:978/1680 train_time:85835ms step_avg:87.77ms +step:979/1680 train_time:85923ms step_avg:87.77ms +step:980/1680 train_time:86011ms step_avg:87.77ms +step:981/1680 train_time:86100ms step_avg:87.77ms +step:982/1680 train_time:86188ms step_avg:87.77ms +step:983/1680 train_time:86277ms step_avg:87.77ms +step:984/1680 train_time:86366ms step_avg:87.77ms +step:985/1680 train_time:86455ms step_avg:87.77ms +step:986/1680 train_time:86543ms step_avg:87.77ms +step:987/1680 train_time:86633ms step_avg:87.77ms +step:988/1680 train_time:86721ms step_avg:87.77ms +step:989/1680 train_time:86810ms step_avg:87.78ms +step:990/1680 train_time:86898ms step_avg:87.78ms +step:991/1680 train_time:86986ms step_avg:87.78ms +step:992/1680 train_time:87074ms step_avg:87.78ms +step:993/1680 train_time:87162ms step_avg:87.78ms +step:994/1680 train_time:87251ms step_avg:87.78ms +step:995/1680 train_time:87340ms step_avg:87.78ms +step:996/1680 train_time:87428ms step_avg:87.78ms +step:997/1680 train_time:87517ms step_avg:87.78ms +step:998/1680 train_time:87607ms step_avg:87.78ms +step:999/1680 train_time:87695ms step_avg:87.78ms +step:1000/1680 train_time:87784ms step_avg:87.78ms +step:1000/1680 val_loss:3.4674 train_time:87873ms step_avg:87.87ms +step:1001/1680 train_time:87893ms step_avg:87.81ms +step:1002/1680 train_time:87966ms step_avg:87.79ms +step:1003/1680 train_time:88058ms step_avg:87.79ms +step:1004/1680 train_time:88147ms step_avg:87.80ms +step:1005/1680 train_time:88235ms step_avg:87.80ms +step:1006/1680 train_time:88322ms step_avg:87.80ms +step:1007/1680 train_time:88410ms step_avg:87.80ms +step:1008/1680 train_time:88497ms step_avg:87.79ms +step:1009/1680 train_time:88584ms step_avg:87.79ms +step:1010/1680 train_time:88671ms step_avg:87.79ms +step:1011/1680 train_time:88759ms step_avg:87.79ms +step:1012/1680 train_time:88848ms step_avg:87.79ms +step:1013/1680 train_time:88937ms step_avg:87.80ms 
+step:1014/1680 train_time:89028ms step_avg:87.80ms +step:1015/1680 train_time:89118ms step_avg:87.80ms +step:1016/1680 train_time:89207ms step_avg:87.80ms +step:1017/1680 train_time:89296ms step_avg:87.80ms +step:1018/1680 train_time:89384ms step_avg:87.80ms +step:1019/1680 train_time:89472ms step_avg:87.80ms +step:1020/1680 train_time:89558ms step_avg:87.80ms +step:1021/1680 train_time:89646ms step_avg:87.80ms +step:1022/1680 train_time:89733ms step_avg:87.80ms +step:1023/1680 train_time:89821ms step_avg:87.80ms +step:1024/1680 train_time:89911ms step_avg:87.80ms +step:1025/1680 train_time:90000ms step_avg:87.80ms +step:1026/1680 train_time:90089ms step_avg:87.81ms +step:1027/1680 train_time:90177ms step_avg:87.81ms +step:1028/1680 train_time:90267ms step_avg:87.81ms +step:1029/1680 train_time:90355ms step_avg:87.81ms +step:1030/1680 train_time:90443ms step_avg:87.81ms +step:1031/1680 train_time:90530ms step_avg:87.81ms +step:1032/1680 train_time:90618ms step_avg:87.81ms +step:1033/1680 train_time:90706ms step_avg:87.81ms +step:1034/1680 train_time:90794ms step_avg:87.81ms +step:1035/1680 train_time:90883ms step_avg:87.81ms +step:1036/1680 train_time:90971ms step_avg:87.81ms +step:1037/1680 train_time:91060ms step_avg:87.81ms +step:1038/1680 train_time:91149ms step_avg:87.81ms +step:1039/1680 train_time:91237ms step_avg:87.81ms +step:1040/1680 train_time:91327ms step_avg:87.81ms +step:1041/1680 train_time:91415ms step_avg:87.81ms +step:1042/1680 train_time:91503ms step_avg:87.81ms +step:1043/1680 train_time:91591ms step_avg:87.81ms +step:1044/1680 train_time:91679ms step_avg:87.81ms +step:1045/1680 train_time:91766ms step_avg:87.81ms +step:1046/1680 train_time:91855ms step_avg:87.82ms +step:1047/1680 train_time:91944ms step_avg:87.82ms +step:1048/1680 train_time:92033ms step_avg:87.82ms +step:1049/1680 train_time:92122ms step_avg:87.82ms +step:1050/1680 train_time:92211ms step_avg:87.82ms +step:1051/1680 train_time:92299ms step_avg:87.82ms +step:1052/1680 train_time:92388ms step_avg:87.82ms +step:1053/1680 train_time:92477ms step_avg:87.82ms +step:1054/1680 train_time:92565ms step_avg:87.82ms +step:1055/1680 train_time:92653ms step_avg:87.82ms +step:1056/1680 train_time:92741ms step_avg:87.82ms +step:1057/1680 train_time:92829ms step_avg:87.82ms +step:1058/1680 train_time:92917ms step_avg:87.82ms +step:1059/1680 train_time:93006ms step_avg:87.82ms +step:1060/1680 train_time:93094ms step_avg:87.82ms +step:1061/1680 train_time:93184ms step_avg:87.83ms +step:1062/1680 train_time:93272ms step_avg:87.83ms +step:1063/1680 train_time:93361ms step_avg:87.83ms +step:1064/1680 train_time:93451ms step_avg:87.83ms +step:1065/1680 train_time:93538ms step_avg:87.83ms +step:1066/1680 train_time:93627ms step_avg:87.83ms +step:1067/1680 train_time:93715ms step_avg:87.83ms +step:1068/1680 train_time:93803ms step_avg:87.83ms +step:1069/1680 train_time:93892ms step_avg:87.83ms +step:1070/1680 train_time:93980ms step_avg:87.83ms +step:1071/1680 train_time:94068ms step_avg:87.83ms +step:1072/1680 train_time:94156ms step_avg:87.83ms +step:1073/1680 train_time:94245ms step_avg:87.83ms +step:1074/1680 train_time:94333ms step_avg:87.83ms +step:1075/1680 train_time:94422ms step_avg:87.83ms +step:1076/1680 train_time:94511ms step_avg:87.84ms +step:1077/1680 train_time:94598ms step_avg:87.83ms +step:1078/1680 train_time:94687ms step_avg:87.84ms +step:1079/1680 train_time:94775ms step_avg:87.84ms +step:1080/1680 train_time:94864ms step_avg:87.84ms +step:1081/1680 train_time:94952ms step_avg:87.84ms +step:1082/1680 
train_time:95040ms step_avg:87.84ms +step:1083/1680 train_time:95129ms step_avg:87.84ms +step:1084/1680 train_time:95217ms step_avg:87.84ms +step:1085/1680 train_time:95305ms step_avg:87.84ms +step:1086/1680 train_time:95395ms step_avg:87.84ms +step:1087/1680 train_time:95484ms step_avg:87.84ms +step:1088/1680 train_time:95572ms step_avg:87.84ms +step:1089/1680 train_time:95660ms step_avg:87.84ms +step:1090/1680 train_time:95749ms step_avg:87.84ms +step:1091/1680 train_time:95837ms step_avg:87.84ms +step:1092/1680 train_time:95927ms step_avg:87.84ms +step:1093/1680 train_time:96015ms step_avg:87.85ms +step:1094/1680 train_time:96103ms step_avg:87.85ms +step:1095/1680 train_time:96193ms step_avg:87.85ms +step:1096/1680 train_time:96282ms step_avg:87.85ms +step:1097/1680 train_time:96371ms step_avg:87.85ms +step:1098/1680 train_time:96461ms step_avg:87.85ms +step:1099/1680 train_time:96552ms step_avg:87.85ms +step:1100/1680 train_time:96640ms step_avg:87.85ms +step:1101/1680 train_time:96730ms step_avg:87.86ms +step:1102/1680 train_time:96818ms step_avg:87.86ms +step:1103/1680 train_time:96907ms step_avg:87.86ms +step:1104/1680 train_time:96997ms step_avg:87.86ms +step:1105/1680 train_time:97086ms step_avg:87.86ms +step:1106/1680 train_time:97176ms step_avg:87.86ms +step:1107/1680 train_time:97265ms step_avg:87.86ms +step:1108/1680 train_time:97354ms step_avg:87.86ms +step:1109/1680 train_time:97444ms step_avg:87.87ms +step:1110/1680 train_time:97533ms step_avg:87.87ms +step:1111/1680 train_time:97623ms step_avg:87.87ms +step:1112/1680 train_time:97712ms step_avg:87.87ms +step:1113/1680 train_time:97802ms step_avg:87.87ms +step:1114/1680 train_time:97891ms step_avg:87.87ms +step:1115/1680 train_time:97980ms step_avg:87.87ms +step:1116/1680 train_time:98070ms step_avg:87.88ms +step:1117/1680 train_time:98158ms step_avg:87.88ms +step:1118/1680 train_time:98248ms step_avg:87.88ms +step:1119/1680 train_time:98337ms step_avg:87.88ms +step:1120/1680 train_time:98425ms step_avg:87.88ms +step:1121/1680 train_time:98514ms step_avg:87.88ms +step:1122/1680 train_time:98603ms step_avg:87.88ms +step:1123/1680 train_time:98693ms step_avg:87.88ms +step:1124/1680 train_time:98782ms step_avg:87.88ms +step:1125/1680 train_time:98872ms step_avg:87.89ms +step:1125/1680 val_loss:3.4151 train_time:98961ms step_avg:87.97ms +step:1126/1680 train_time:98982ms step_avg:87.91ms +step:1127/1680 train_time:99055ms step_avg:87.89ms +step:1128/1680 train_time:99145ms step_avg:87.89ms +step:1129/1680 train_time:99236ms step_avg:87.90ms +step:1130/1680 train_time:99327ms step_avg:87.90ms +step:1131/1680 train_time:99415ms step_avg:87.90ms +step:1132/1680 train_time:99503ms step_avg:87.90ms +step:1133/1680 train_time:99591ms step_avg:87.90ms +step:1134/1680 train_time:99678ms step_avg:87.90ms +step:1135/1680 train_time:99766ms step_avg:87.90ms +step:1136/1680 train_time:99854ms step_avg:87.90ms +step:1137/1680 train_time:99944ms step_avg:87.90ms +step:1138/1680 train_time:100034ms step_avg:87.90ms +step:1139/1680 train_time:100126ms step_avg:87.91ms +step:1140/1680 train_time:100217ms step_avg:87.91ms +step:1141/1680 train_time:100307ms step_avg:87.91ms +step:1142/1680 train_time:100395ms step_avg:87.91ms +step:1143/1680 train_time:100484ms step_avg:87.91ms +step:1144/1680 train_time:100573ms step_avg:87.91ms +step:1145/1680 train_time:100661ms step_avg:87.91ms +step:1146/1680 train_time:100750ms step_avg:87.91ms +step:1147/1680 train_time:100838ms step_avg:87.91ms +step:1148/1680 train_time:100927ms step_avg:87.92ms 
+step:1149/1680 train_time:101018ms step_avg:87.92ms +step:1150/1680 train_time:101108ms step_avg:87.92ms +step:1151/1680 train_time:101198ms step_avg:87.92ms +step:1152/1680 train_time:101288ms step_avg:87.92ms +step:1153/1680 train_time:101377ms step_avg:87.92ms +step:1154/1680 train_time:101465ms step_avg:87.92ms +step:1155/1680 train_time:101554ms step_avg:87.93ms +step:1156/1680 train_time:101643ms step_avg:87.93ms +step:1157/1680 train_time:101731ms step_avg:87.93ms +step:1158/1680 train_time:101821ms step_avg:87.93ms +step:1159/1680 train_time:101911ms step_avg:87.93ms +step:1160/1680 train_time:102000ms step_avg:87.93ms +step:1161/1680 train_time:102091ms step_avg:87.93ms +step:1162/1680 train_time:102181ms step_avg:87.94ms +step:1163/1680 train_time:102270ms step_avg:87.94ms +step:1164/1680 train_time:102359ms step_avg:87.94ms +step:1165/1680 train_time:102448ms step_avg:87.94ms +step:1166/1680 train_time:102537ms step_avg:87.94ms +step:1167/1680 train_time:102625ms step_avg:87.94ms +step:1168/1680 train_time:102714ms step_avg:87.94ms +step:1169/1680 train_time:102803ms step_avg:87.94ms +step:1170/1680 train_time:102893ms step_avg:87.94ms +step:1171/1680 train_time:102981ms step_avg:87.94ms +step:1172/1680 train_time:103071ms step_avg:87.94ms +step:1173/1680 train_time:103160ms step_avg:87.95ms +step:1174/1680 train_time:103250ms step_avg:87.95ms +step:1175/1680 train_time:103339ms step_avg:87.95ms +step:1176/1680 train_time:103428ms step_avg:87.95ms +step:1177/1680 train_time:103516ms step_avg:87.95ms +step:1178/1680 train_time:103605ms step_avg:87.95ms +step:1179/1680 train_time:103694ms step_avg:87.95ms +step:1180/1680 train_time:103783ms step_avg:87.95ms +step:1181/1680 train_time:103872ms step_avg:87.95ms +step:1182/1680 train_time:103961ms step_avg:87.95ms +step:1183/1680 train_time:104051ms step_avg:87.96ms +step:1184/1680 train_time:104141ms step_avg:87.96ms +step:1185/1680 train_time:104230ms step_avg:87.96ms +step:1186/1680 train_time:104319ms step_avg:87.96ms +step:1187/1680 train_time:104409ms step_avg:87.96ms +step:1188/1680 train_time:104498ms step_avg:87.96ms +step:1189/1680 train_time:104587ms step_avg:87.96ms +step:1190/1680 train_time:104676ms step_avg:87.96ms +step:1191/1680 train_time:104765ms step_avg:87.96ms +step:1192/1680 train_time:104854ms step_avg:87.97ms +step:1193/1680 train_time:104944ms step_avg:87.97ms +step:1194/1680 train_time:105033ms step_avg:87.97ms +step:1195/1680 train_time:105122ms step_avg:87.97ms +step:1196/1680 train_time:105213ms step_avg:87.97ms +step:1197/1680 train_time:105302ms step_avg:87.97ms +step:1198/1680 train_time:105391ms step_avg:87.97ms +step:1199/1680 train_time:105479ms step_avg:87.97ms +step:1200/1680 train_time:105569ms step_avg:87.97ms +step:1201/1680 train_time:105658ms step_avg:87.97ms +step:1202/1680 train_time:105747ms step_avg:87.98ms +step:1203/1680 train_time:105836ms step_avg:87.98ms +step:1204/1680 train_time:105925ms step_avg:87.98ms +step:1205/1680 train_time:106015ms step_avg:87.98ms +step:1206/1680 train_time:106104ms step_avg:87.98ms +step:1207/1680 train_time:106193ms step_avg:87.98ms +step:1208/1680 train_time:106282ms step_avg:87.98ms +step:1209/1680 train_time:106372ms step_avg:87.98ms +step:1210/1680 train_time:106460ms step_avg:87.98ms +step:1211/1680 train_time:106549ms step_avg:87.98ms +step:1212/1680 train_time:106639ms step_avg:87.99ms +step:1213/1680 train_time:106730ms step_avg:87.99ms +step:1214/1680 train_time:106818ms step_avg:87.99ms +step:1215/1680 train_time:106907ms step_avg:87.99ms 
+step:1216/1680 train_time:106996ms step_avg:87.99ms +step:1217/1680 train_time:107087ms step_avg:87.99ms +step:1218/1680 train_time:107176ms step_avg:87.99ms +step:1219/1680 train_time:107266ms step_avg:88.00ms +step:1220/1680 train_time:107356ms step_avg:88.00ms +step:1221/1680 train_time:107447ms step_avg:88.00ms +step:1222/1680 train_time:107537ms step_avg:88.00ms +step:1223/1680 train_time:107626ms step_avg:88.00ms +step:1224/1680 train_time:107716ms step_avg:88.00ms +step:1225/1680 train_time:107805ms step_avg:88.00ms +step:1226/1680 train_time:107894ms step_avg:88.01ms +step:1227/1680 train_time:107984ms step_avg:88.01ms +step:1228/1680 train_time:108074ms step_avg:88.01ms +step:1229/1680 train_time:108163ms step_avg:88.01ms +step:1230/1680 train_time:108254ms step_avg:88.01ms +step:1231/1680 train_time:108343ms step_avg:88.01ms +step:1232/1680 train_time:108433ms step_avg:88.01ms +step:1233/1680 train_time:108522ms step_avg:88.01ms +step:1234/1680 train_time:108612ms step_avg:88.02ms +step:1235/1680 train_time:108701ms step_avg:88.02ms +step:1236/1680 train_time:108789ms step_avg:88.02ms +step:1237/1680 train_time:108878ms step_avg:88.02ms +step:1238/1680 train_time:108967ms step_avg:88.02ms +step:1239/1680 train_time:109057ms step_avg:88.02ms +step:1240/1680 train_time:109146ms step_avg:88.02ms +step:1241/1680 train_time:109235ms step_avg:88.02ms +step:1242/1680 train_time:109325ms step_avg:88.02ms +step:1243/1680 train_time:109414ms step_avg:88.02ms +step:1244/1680 train_time:109504ms step_avg:88.03ms +step:1245/1680 train_time:109593ms step_avg:88.03ms +step:1246/1680 train_time:109683ms step_avg:88.03ms +step:1247/1680 train_time:109772ms step_avg:88.03ms +step:1248/1680 train_time:109861ms step_avg:88.03ms +step:1249/1680 train_time:109949ms step_avg:88.03ms +step:1250/1680 train_time:110039ms step_avg:88.03ms +step:1250/1680 val_loss:3.3769 train_time:110130ms step_avg:88.10ms +step:1251/1680 train_time:110148ms step_avg:88.05ms +step:1252/1680 train_time:110221ms step_avg:88.04ms +step:1253/1680 train_time:110312ms step_avg:88.04ms +step:1254/1680 train_time:110402ms step_avg:88.04ms +step:1255/1680 train_time:110492ms step_avg:88.04ms +step:1256/1680 train_time:110580ms step_avg:88.04ms +step:1257/1680 train_time:110668ms step_avg:88.04ms +step:1258/1680 train_time:110757ms step_avg:88.04ms +step:1259/1680 train_time:110845ms step_avg:88.04ms +step:1260/1680 train_time:110934ms step_avg:88.04ms +step:1261/1680 train_time:111022ms step_avg:88.04ms +step:1262/1680 train_time:111113ms step_avg:88.05ms +step:1263/1680 train_time:111203ms step_avg:88.05ms +step:1264/1680 train_time:111294ms step_avg:88.05ms +step:1265/1680 train_time:111385ms step_avg:88.05ms +step:1266/1680 train_time:111474ms step_avg:88.05ms +step:1267/1680 train_time:111563ms step_avg:88.05ms +step:1268/1680 train_time:111651ms step_avg:88.05ms +step:1269/1680 train_time:111740ms step_avg:88.05ms +step:1270/1680 train_time:111829ms step_avg:88.05ms +step:1271/1680 train_time:111918ms step_avg:88.05ms +step:1272/1680 train_time:112006ms step_avg:88.06ms +step:1273/1680 train_time:112095ms step_avg:88.06ms +step:1274/1680 train_time:112184ms step_avg:88.06ms +step:1275/1680 train_time:112274ms step_avg:88.06ms +step:1276/1680 train_time:112365ms step_avg:88.06ms +step:1277/1680 train_time:112455ms step_avg:88.06ms +step:1278/1680 train_time:112544ms step_avg:88.06ms +step:1279/1680 train_time:112632ms step_avg:88.06ms +step:1280/1680 train_time:112721ms step_avg:88.06ms +step:1281/1680 train_time:112811ms 
step_avg:88.06ms +step:1282/1680 train_time:112900ms step_avg:88.07ms +step:1283/1680 train_time:112989ms step_avg:88.07ms +step:1284/1680 train_time:113078ms step_avg:88.07ms +step:1285/1680 train_time:113168ms step_avg:88.07ms +step:1286/1680 train_time:113256ms step_avg:88.07ms +step:1287/1680 train_time:113346ms step_avg:88.07ms +step:1288/1680 train_time:113435ms step_avg:88.07ms +step:1289/1680 train_time:113525ms step_avg:88.07ms +step:1290/1680 train_time:113614ms step_avg:88.07ms +step:1291/1680 train_time:113703ms step_avg:88.07ms +step:1292/1680 train_time:113792ms step_avg:88.07ms +step:1293/1680 train_time:113881ms step_avg:88.08ms +step:1294/1680 train_time:113971ms step_avg:88.08ms +step:1295/1680 train_time:114060ms step_avg:88.08ms +step:1296/1680 train_time:114151ms step_avg:88.08ms +step:1297/1680 train_time:114241ms step_avg:88.08ms +step:1298/1680 train_time:114330ms step_avg:88.08ms +step:1299/1680 train_time:114419ms step_avg:88.08ms +step:1300/1680 train_time:114508ms step_avg:88.08ms +step:1301/1680 train_time:114597ms step_avg:88.08ms +step:1302/1680 train_time:114687ms step_avg:88.08ms +step:1303/1680 train_time:114776ms step_avg:88.09ms +step:1304/1680 train_time:114866ms step_avg:88.09ms +step:1305/1680 train_time:114955ms step_avg:88.09ms +step:1306/1680 train_time:115045ms step_avg:88.09ms +step:1307/1680 train_time:115134ms step_avg:88.09ms +step:1308/1680 train_time:115224ms step_avg:88.09ms +step:1309/1680 train_time:115313ms step_avg:88.09ms +step:1310/1680 train_time:115403ms step_avg:88.09ms +step:1311/1680 train_time:115493ms step_avg:88.10ms +step:1312/1680 train_time:115581ms step_avg:88.10ms +step:1313/1680 train_time:115671ms step_avg:88.10ms +step:1314/1680 train_time:115760ms step_avg:88.10ms +step:1315/1680 train_time:115849ms step_avg:88.10ms +step:1316/1680 train_time:115939ms step_avg:88.10ms +step:1317/1680 train_time:116027ms step_avg:88.10ms +step:1318/1680 train_time:116117ms step_avg:88.10ms +step:1319/1680 train_time:116206ms step_avg:88.10ms +step:1320/1680 train_time:116295ms step_avg:88.10ms +step:1321/1680 train_time:116385ms step_avg:88.10ms +step:1322/1680 train_time:116474ms step_avg:88.10ms +step:1323/1680 train_time:116564ms step_avg:88.11ms +step:1324/1680 train_time:116652ms step_avg:88.11ms +step:1325/1680 train_time:116742ms step_avg:88.11ms +step:1326/1680 train_time:116832ms step_avg:88.11ms +step:1327/1680 train_time:116920ms step_avg:88.11ms +step:1328/1680 train_time:117009ms step_avg:88.11ms +step:1329/1680 train_time:117097ms step_avg:88.11ms +step:1330/1680 train_time:117187ms step_avg:88.11ms +step:1331/1680 train_time:117276ms step_avg:88.11ms +step:1332/1680 train_time:117366ms step_avg:88.11ms +step:1333/1680 train_time:117456ms step_avg:88.11ms +step:1334/1680 train_time:117545ms step_avg:88.11ms +step:1335/1680 train_time:117634ms step_avg:88.12ms +step:1336/1680 train_time:117724ms step_avg:88.12ms +step:1337/1680 train_time:117813ms step_avg:88.12ms +step:1338/1680 train_time:117903ms step_avg:88.12ms +step:1339/1680 train_time:117992ms step_avg:88.12ms +step:1340/1680 train_time:118081ms step_avg:88.12ms +step:1341/1680 train_time:118170ms step_avg:88.12ms +step:1342/1680 train_time:118259ms step_avg:88.12ms +step:1343/1680 train_time:118348ms step_avg:88.12ms +step:1344/1680 train_time:118437ms step_avg:88.12ms +step:1345/1680 train_time:118527ms step_avg:88.12ms +step:1346/1680 train_time:118615ms step_avg:88.12ms +step:1347/1680 train_time:118704ms step_avg:88.13ms +step:1348/1680 train_time:118794ms 
step_avg:88.13ms +step:1349/1680 train_time:118884ms step_avg:88.13ms +step:1350/1680 train_time:118972ms step_avg:88.13ms +step:1351/1680 train_time:119063ms step_avg:88.13ms +step:1352/1680 train_time:119153ms step_avg:88.13ms +step:1353/1680 train_time:119243ms step_avg:88.13ms +step:1354/1680 train_time:119332ms step_avg:88.13ms +step:1355/1680 train_time:119421ms step_avg:88.13ms +step:1356/1680 train_time:119510ms step_avg:88.13ms +step:1357/1680 train_time:119599ms step_avg:88.13ms +step:1358/1680 train_time:119689ms step_avg:88.14ms +step:1359/1680 train_time:119778ms step_avg:88.14ms +step:1360/1680 train_time:119866ms step_avg:88.14ms +step:1361/1680 train_time:119955ms step_avg:88.14ms +step:1362/1680 train_time:120044ms step_avg:88.14ms +step:1363/1680 train_time:120133ms step_avg:88.14ms +step:1364/1680 train_time:120223ms step_avg:88.14ms +step:1365/1680 train_time:120312ms step_avg:88.14ms +step:1366/1680 train_time:120401ms step_avg:88.14ms +step:1367/1680 train_time:120491ms step_avg:88.14ms +step:1368/1680 train_time:120580ms step_avg:88.14ms +step:1369/1680 train_time:120669ms step_avg:88.14ms +step:1370/1680 train_time:120758ms step_avg:88.14ms +step:1371/1680 train_time:120847ms step_avg:88.15ms +step:1372/1680 train_time:120937ms step_avg:88.15ms +step:1373/1680 train_time:121025ms step_avg:88.15ms +step:1374/1680 train_time:121115ms step_avg:88.15ms +step:1375/1680 train_time:121204ms step_avg:88.15ms +step:1375/1680 val_loss:3.3422 train_time:121294ms step_avg:88.21ms +step:1376/1680 train_time:121314ms step_avg:88.16ms +step:1377/1680 train_time:121386ms step_avg:88.15ms +step:1378/1680 train_time:121477ms step_avg:88.15ms +step:1379/1680 train_time:121566ms step_avg:88.16ms +step:1380/1680 train_time:121654ms step_avg:88.16ms +step:1381/1680 train_time:121743ms step_avg:88.16ms +step:1382/1680 train_time:121832ms step_avg:88.16ms +step:1383/1680 train_time:121920ms step_avg:88.16ms +step:1384/1680 train_time:122009ms step_avg:88.16ms +step:1385/1680 train_time:122098ms step_avg:88.16ms +step:1386/1680 train_time:122186ms step_avg:88.16ms +step:1387/1680 train_time:122276ms step_avg:88.16ms +step:1388/1680 train_time:122367ms step_avg:88.16ms +step:1389/1680 train_time:122458ms step_avg:88.16ms +step:1390/1680 train_time:122547ms step_avg:88.16ms +step:1391/1680 train_time:122637ms step_avg:88.16ms +step:1392/1680 train_time:122726ms step_avg:88.17ms +step:1393/1680 train_time:122814ms step_avg:88.17ms +step:1394/1680 train_time:122902ms step_avg:88.17ms +step:1395/1680 train_time:122992ms step_avg:88.17ms +step:1396/1680 train_time:123081ms step_avg:88.17ms +step:1397/1680 train_time:123170ms step_avg:88.17ms +step:1398/1680 train_time:123260ms step_avg:88.17ms +step:1399/1680 train_time:123350ms step_avg:88.17ms +step:1400/1680 train_time:123440ms step_avg:88.17ms +step:1401/1680 train_time:123530ms step_avg:88.17ms +step:1402/1680 train_time:123619ms step_avg:88.17ms +step:1403/1680 train_time:123708ms step_avg:88.17ms +step:1404/1680 train_time:123796ms step_avg:88.17ms +step:1405/1680 train_time:123886ms step_avg:88.18ms +step:1406/1680 train_time:123974ms step_avg:88.17ms +step:1407/1680 train_time:124063ms step_avg:88.18ms +step:1408/1680 train_time:124152ms step_avg:88.18ms +step:1409/1680 train_time:124242ms step_avg:88.18ms +step:1410/1680 train_time:124332ms step_avg:88.18ms +step:1411/1680 train_time:124422ms step_avg:88.18ms +step:1412/1680 train_time:124511ms step_avg:88.18ms +step:1413/1680 train_time:124601ms step_avg:88.18ms +step:1414/1680 
train_time:124690ms step_avg:88.18ms +step:1415/1680 train_time:124779ms step_avg:88.18ms +step:1416/1680 train_time:124869ms step_avg:88.18ms +step:1417/1680 train_time:124958ms step_avg:88.19ms +step:1418/1680 train_time:125046ms step_avg:88.19ms +step:1419/1680 train_time:125136ms step_avg:88.19ms +step:1420/1680 train_time:125226ms step_avg:88.19ms +step:1421/1680 train_time:125316ms step_avg:88.19ms +step:1422/1680 train_time:125406ms step_avg:88.19ms +step:1423/1680 train_time:125495ms step_avg:88.19ms +step:1424/1680 train_time:125585ms step_avg:88.19ms +step:1425/1680 train_time:125674ms step_avg:88.19ms +step:1426/1680 train_time:125764ms step_avg:88.19ms +step:1427/1680 train_time:125853ms step_avg:88.19ms +step:1428/1680 train_time:125942ms step_avg:88.20ms +step:1429/1680 train_time:126032ms step_avg:88.20ms +step:1430/1680 train_time:126121ms step_avg:88.20ms +step:1431/1680 train_time:126212ms step_avg:88.20ms +step:1432/1680 train_time:126302ms step_avg:88.20ms +step:1433/1680 train_time:126392ms step_avg:88.20ms +step:1434/1680 train_time:126481ms step_avg:88.20ms +step:1435/1680 train_time:126571ms step_avg:88.20ms +step:1436/1680 train_time:126659ms step_avg:88.20ms +step:1437/1680 train_time:126748ms step_avg:88.20ms +step:1438/1680 train_time:126837ms step_avg:88.20ms +step:1439/1680 train_time:126927ms step_avg:88.20ms +step:1440/1680 train_time:127016ms step_avg:88.21ms +step:1441/1680 train_time:127105ms step_avg:88.21ms +step:1442/1680 train_time:127195ms step_avg:88.21ms +step:1443/1680 train_time:127284ms step_avg:88.21ms +step:1444/1680 train_time:127373ms step_avg:88.21ms +step:1445/1680 train_time:127463ms step_avg:88.21ms +step:1446/1680 train_time:127554ms step_avg:88.21ms +step:1447/1680 train_time:127643ms step_avg:88.21ms +step:1448/1680 train_time:127734ms step_avg:88.21ms +step:1449/1680 train_time:127824ms step_avg:88.22ms +step:1450/1680 train_time:127914ms step_avg:88.22ms +step:1451/1680 train_time:128004ms step_avg:88.22ms +step:1452/1680 train_time:128093ms step_avg:88.22ms +step:1453/1680 train_time:128182ms step_avg:88.22ms +step:1454/1680 train_time:128271ms step_avg:88.22ms +step:1455/1680 train_time:128360ms step_avg:88.22ms +step:1456/1680 train_time:128450ms step_avg:88.22ms +step:1457/1680 train_time:128539ms step_avg:88.22ms +step:1458/1680 train_time:128628ms step_avg:88.22ms +step:1459/1680 train_time:128717ms step_avg:88.22ms +step:1460/1680 train_time:128807ms step_avg:88.22ms +step:1461/1680 train_time:128896ms step_avg:88.22ms +step:1462/1680 train_time:128986ms step_avg:88.23ms +step:1463/1680 train_time:129074ms step_avg:88.23ms +step:1464/1680 train_time:129164ms step_avg:88.23ms +step:1465/1680 train_time:129253ms step_avg:88.23ms +step:1466/1680 train_time:129342ms step_avg:88.23ms +step:1467/1680 train_time:129433ms step_avg:88.23ms +step:1468/1680 train_time:129522ms step_avg:88.23ms +step:1469/1680 train_time:129611ms step_avg:88.23ms +step:1470/1680 train_time:129700ms step_avg:88.23ms +step:1471/1680 train_time:129790ms step_avg:88.23ms +step:1472/1680 train_time:129879ms step_avg:88.23ms +step:1473/1680 train_time:129968ms step_avg:88.23ms +step:1474/1680 train_time:130057ms step_avg:88.23ms +step:1475/1680 train_time:130147ms step_avg:88.24ms +step:1476/1680 train_time:130236ms step_avg:88.24ms +step:1477/1680 train_time:130326ms step_avg:88.24ms +step:1478/1680 train_time:130415ms step_avg:88.24ms +step:1479/1680 train_time:130505ms step_avg:88.24ms +step:1480/1680 train_time:130595ms step_avg:88.24ms +step:1481/1680 
train_time:130684ms step_avg:88.24ms +step:1482/1680 train_time:130773ms step_avg:88.24ms +step:1483/1680 train_time:130863ms step_avg:88.24ms +step:1484/1680 train_time:130953ms step_avg:88.24ms +step:1485/1680 train_time:131043ms step_avg:88.24ms +step:1486/1680 train_time:131132ms step_avg:88.24ms +step:1487/1680 train_time:131221ms step_avg:88.25ms +step:1488/1680 train_time:131310ms step_avg:88.25ms +step:1489/1680 train_time:131399ms step_avg:88.25ms +step:1490/1680 train_time:131489ms step_avg:88.25ms +step:1491/1680 train_time:131578ms step_avg:88.25ms +step:1492/1680 train_time:131668ms step_avg:88.25ms +step:1493/1680 train_time:131757ms step_avg:88.25ms +step:1494/1680 train_time:131846ms step_avg:88.25ms +step:1495/1680 train_time:131936ms step_avg:88.25ms +step:1496/1680 train_time:132025ms step_avg:88.25ms +step:1497/1680 train_time:132114ms step_avg:88.25ms +step:1498/1680 train_time:132202ms step_avg:88.25ms +step:1499/1680 train_time:132294ms step_avg:88.25ms +step:1500/1680 train_time:132384ms step_avg:88.26ms +step:1500/1680 val_loss:3.3130 train_time:132475ms step_avg:88.32ms +step:1501/1680 train_time:132493ms step_avg:88.27ms +step:1502/1680 train_time:132566ms step_avg:88.26ms +step:1503/1680 train_time:132658ms step_avg:88.26ms +step:1504/1680 train_time:132748ms step_avg:88.26ms +step:1505/1680 train_time:132836ms step_avg:88.26ms +step:1506/1680 train_time:132924ms step_avg:88.26ms +step:1507/1680 train_time:133012ms step_avg:88.26ms +step:1508/1680 train_time:133100ms step_avg:88.26ms +step:1509/1680 train_time:133188ms step_avg:88.26ms +step:1510/1680 train_time:133277ms step_avg:88.26ms +step:1511/1680 train_time:133366ms step_avg:88.26ms +step:1512/1680 train_time:133458ms step_avg:88.27ms +step:1513/1680 train_time:133550ms step_avg:88.27ms +step:1514/1680 train_time:133641ms step_avg:88.27ms +step:1515/1680 train_time:133732ms step_avg:88.27ms +step:1516/1680 train_time:133821ms step_avg:88.27ms +step:1517/1680 train_time:133909ms step_avg:88.27ms +step:1518/1680 train_time:133997ms step_avg:88.27ms +step:1519/1680 train_time:134086ms step_avg:88.27ms +step:1520/1680 train_time:134174ms step_avg:88.27ms +step:1521/1680 train_time:134262ms step_avg:88.27ms +step:1522/1680 train_time:134352ms step_avg:88.27ms +step:1523/1680 train_time:134442ms step_avg:88.27ms +step:1524/1680 train_time:134532ms step_avg:88.28ms +step:1525/1680 train_time:134623ms step_avg:88.28ms +step:1526/1680 train_time:134712ms step_avg:88.28ms +step:1527/1680 train_time:134802ms step_avg:88.28ms +step:1528/1680 train_time:134890ms step_avg:88.28ms +step:1529/1680 train_time:134979ms step_avg:88.28ms +step:1530/1680 train_time:135067ms step_avg:88.28ms +step:1531/1680 train_time:135156ms step_avg:88.28ms +step:1532/1680 train_time:135244ms step_avg:88.28ms +step:1533/1680 train_time:135333ms step_avg:88.28ms +step:1534/1680 train_time:135422ms step_avg:88.28ms +step:1535/1680 train_time:135511ms step_avg:88.28ms +step:1536/1680 train_time:135601ms step_avg:88.28ms +step:1537/1680 train_time:135692ms step_avg:88.28ms +step:1538/1680 train_time:135782ms step_avg:88.29ms +step:1539/1680 train_time:135871ms step_avg:88.29ms +step:1540/1680 train_time:135961ms step_avg:88.29ms +step:1541/1680 train_time:136050ms step_avg:88.29ms +step:1542/1680 train_time:136139ms step_avg:88.29ms +step:1543/1680 train_time:136228ms step_avg:88.29ms +step:1544/1680 train_time:136317ms step_avg:88.29ms +step:1545/1680 train_time:136406ms step_avg:88.29ms +step:1546/1680 train_time:136495ms step_avg:88.29ms 
+step:1547/1680 train_time:136584ms step_avg:88.29ms +step:1548/1680 train_time:136674ms step_avg:88.29ms +step:1549/1680 train_time:136764ms step_avg:88.29ms +step:1550/1680 train_time:136853ms step_avg:88.29ms +step:1551/1680 train_time:136942ms step_avg:88.29ms +step:1552/1680 train_time:137031ms step_avg:88.29ms +step:1553/1680 train_time:137120ms step_avg:88.29ms +step:1554/1680 train_time:137209ms step_avg:88.29ms +step:1555/1680 train_time:137298ms step_avg:88.29ms +step:1556/1680 train_time:137387ms step_avg:88.29ms +step:1557/1680 train_time:137476ms step_avg:88.30ms +step:1558/1680 train_time:137565ms step_avg:88.30ms +step:1559/1680 train_time:137654ms step_avg:88.30ms +step:1560/1680 train_time:137745ms step_avg:88.30ms +step:1561/1680 train_time:137834ms step_avg:88.30ms +step:1562/1680 train_time:137924ms step_avg:88.30ms +step:1563/1680 train_time:138013ms step_avg:88.30ms +step:1564/1680 train_time:138102ms step_avg:88.30ms +step:1565/1680 train_time:138191ms step_avg:88.30ms +step:1566/1680 train_time:138281ms step_avg:88.30ms +step:1567/1680 train_time:138371ms step_avg:88.30ms +step:1568/1680 train_time:138460ms step_avg:88.30ms +step:1569/1680 train_time:138549ms step_avg:88.30ms +step:1570/1680 train_time:138638ms step_avg:88.30ms +step:1571/1680 train_time:138729ms step_avg:88.31ms +step:1572/1680 train_time:138818ms step_avg:88.31ms +step:1573/1680 train_time:138909ms step_avg:88.31ms +step:1574/1680 train_time:138998ms step_avg:88.31ms +step:1575/1680 train_time:139087ms step_avg:88.31ms +step:1576/1680 train_time:139175ms step_avg:88.31ms +step:1577/1680 train_time:139265ms step_avg:88.31ms +step:1578/1680 train_time:139354ms step_avg:88.31ms +step:1579/1680 train_time:139444ms step_avg:88.31ms +step:1580/1680 train_time:139533ms step_avg:88.31ms +step:1581/1680 train_time:139622ms step_avg:88.31ms +step:1582/1680 train_time:139712ms step_avg:88.31ms +step:1583/1680 train_time:139802ms step_avg:88.31ms +step:1584/1680 train_time:139891ms step_avg:88.32ms +step:1585/1680 train_time:139980ms step_avg:88.32ms +step:1586/1680 train_time:140069ms step_avg:88.32ms +step:1587/1680 train_time:140158ms step_avg:88.32ms +step:1588/1680 train_time:140248ms step_avg:88.32ms +step:1589/1680 train_time:140339ms step_avg:88.32ms +step:1590/1680 train_time:140428ms step_avg:88.32ms +step:1591/1680 train_time:140516ms step_avg:88.32ms +step:1592/1680 train_time:140606ms step_avg:88.32ms +step:1593/1680 train_time:140695ms step_avg:88.32ms +step:1594/1680 train_time:140784ms step_avg:88.32ms +step:1595/1680 train_time:140873ms step_avg:88.32ms +step:1596/1680 train_time:140962ms step_avg:88.32ms +step:1597/1680 train_time:141052ms step_avg:88.32ms +step:1598/1680 train_time:141141ms step_avg:88.32ms +step:1599/1680 train_time:141231ms step_avg:88.32ms +step:1600/1680 train_time:141321ms step_avg:88.33ms +step:1601/1680 train_time:141411ms step_avg:88.33ms +step:1602/1680 train_time:141501ms step_avg:88.33ms +step:1603/1680 train_time:141590ms step_avg:88.33ms +step:1604/1680 train_time:141680ms step_avg:88.33ms +step:1605/1680 train_time:141769ms step_avg:88.33ms +step:1606/1680 train_time:141858ms step_avg:88.33ms +step:1607/1680 train_time:141947ms step_avg:88.33ms +step:1608/1680 train_time:142037ms step_avg:88.33ms +step:1609/1680 train_time:142127ms step_avg:88.33ms +step:1610/1680 train_time:142217ms step_avg:88.33ms +step:1611/1680 train_time:142306ms step_avg:88.33ms +step:1612/1680 train_time:142395ms step_avg:88.33ms +step:1613/1680 train_time:142485ms step_avg:88.34ms 
+step:1614/1680 train_time:142575ms step_avg:88.34ms +step:1615/1680 train_time:142663ms step_avg:88.34ms +step:1616/1680 train_time:142752ms step_avg:88.34ms +step:1617/1680 train_time:142843ms step_avg:88.34ms +step:1618/1680 train_time:142932ms step_avg:88.34ms +step:1619/1680 train_time:143021ms step_avg:88.34ms +step:1620/1680 train_time:143109ms step_avg:88.34ms +step:1621/1680 train_time:143200ms step_avg:88.34ms +step:1622/1680 train_time:143289ms step_avg:88.34ms +step:1623/1680 train_time:143380ms step_avg:88.34ms +step:1624/1680 train_time:143470ms step_avg:88.34ms +step:1625/1680 train_time:143560ms step_avg:88.34ms +step:1625/1680 val_loss:3.2892 train_time:143650ms step_avg:88.40ms +step:1626/1680 train_time:143668ms step_avg:88.36ms +step:1627/1680 train_time:143743ms step_avg:88.35ms +step:1628/1680 train_time:143837ms step_avg:88.35ms +step:1629/1680 train_time:143928ms step_avg:88.35ms +step:1630/1680 train_time:144016ms step_avg:88.35ms +step:1631/1680 train_time:144105ms step_avg:88.35ms +step:1632/1680 train_time:144193ms step_avg:88.35ms +step:1633/1680 train_time:144281ms step_avg:88.35ms +step:1634/1680 train_time:144370ms step_avg:88.35ms +step:1635/1680 train_time:144459ms step_avg:88.35ms +step:1636/1680 train_time:144547ms step_avg:88.35ms +step:1637/1680 train_time:144637ms step_avg:88.35ms +step:1638/1680 train_time:144729ms step_avg:88.36ms +step:1639/1680 train_time:144820ms step_avg:88.36ms +step:1640/1680 train_time:144910ms step_avg:88.36ms +step:1641/1680 train_time:144999ms step_avg:88.36ms +step:1642/1680 train_time:145088ms step_avg:88.36ms +step:1643/1680 train_time:145177ms step_avg:88.36ms +step:1644/1680 train_time:145265ms step_avg:88.36ms +step:1645/1680 train_time:145354ms step_avg:88.36ms +step:1646/1680 train_time:145442ms step_avg:88.36ms +step:1647/1680 train_time:145531ms step_avg:88.36ms +step:1648/1680 train_time:145621ms step_avg:88.36ms +step:1649/1680 train_time:145711ms step_avg:88.36ms +step:1650/1680 train_time:145802ms step_avg:88.37ms +step:1651/1680 train_time:145893ms step_avg:88.37ms +step:1652/1680 train_time:145982ms step_avg:88.37ms +step:1653/1680 train_time:146071ms step_avg:88.37ms +step:1654/1680 train_time:146160ms step_avg:88.37ms +step:1655/1680 train_time:146249ms step_avg:88.37ms +step:1656/1680 train_time:146337ms step_avg:88.37ms +step:1657/1680 train_time:146426ms step_avg:88.37ms +step:1658/1680 train_time:146516ms step_avg:88.37ms +step:1659/1680 train_time:146606ms step_avg:88.37ms +step:1660/1680 train_time:146694ms step_avg:88.37ms +step:1661/1680 train_time:146784ms step_avg:88.37ms +step:1662/1680 train_time:146874ms step_avg:88.37ms +step:1663/1680 train_time:146966ms step_avg:88.37ms +step:1664/1680 train_time:147055ms step_avg:88.37ms +step:1665/1680 train_time:147143ms step_avg:88.37ms +step:1666/1680 train_time:147232ms step_avg:88.37ms +step:1667/1680 train_time:147321ms step_avg:88.37ms +step:1668/1680 train_time:147409ms step_avg:88.37ms +step:1669/1680 train_time:147499ms step_avg:88.38ms +step:1670/1680 train_time:147588ms step_avg:88.38ms +step:1671/1680 train_time:147678ms step_avg:88.38ms +step:1672/1680 train_time:147768ms step_avg:88.38ms +step:1673/1680 train_time:147858ms step_avg:88.38ms +step:1674/1680 train_time:147948ms step_avg:88.38ms +step:1675/1680 train_time:148037ms step_avg:88.38ms +step:1676/1680 train_time:148127ms step_avg:88.38ms +step:1677/1680 train_time:148215ms step_avg:88.38ms +step:1678/1680 train_time:148305ms step_avg:88.38ms +step:1679/1680 train_time:148393ms 
step_avg:88.38ms +step:1680/1680 train_time:148482ms step_avg:88.38ms +step:1680/1680 val_loss:3.2785 train_time:148573ms step_avg:88.44ms +peak memory allocated: 30760 MiB reserved: 46354 MiB diff --git a/records/092725_BF16CE/82d579bc-45e2-4600-8436-7d425016e9b3.txt b/records/092725_BF16CE/82d579bc-45e2-4600-8436-7d425016e9b3.txt new file mode 100644 index 000000000..d452a6530 --- /dev/null +++ b/records/092725_BF16CE/82d579bc-45e2-4600-8436-7d425016e9b3.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+ Though empirically, small 1D params perform efficiently here: + NS approximately performs a magnitude normalization of the grad + This hyper-optimized class has faster execution time than the current impl of Adam for small params + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized to enable + (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param, 7 padding params) + 2. reduce scatter attn_gate (10 params, 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params, 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive params of type attn reshape before NS + 8. wait on 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size()==8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for the resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1,10,16,16] + assert len(params)==sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx+size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
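+ # In outline, a minimal sketch of the three phases implemented below (names as used in this method): + # 1) per group: stack the grads, pad the stack to a multiple of world_size, and launch an async + # dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True), + # so that each rank ends up owning one chunk_size slice of the averaged gradients. + # 2) per group: wait on that reduce, apply momentum and one batched Newton-Schulz pass to the + # local chunk only, then launch an async dist.all_gather_into_tensor of the updated slice. + # 3) finally: wait on every gather and copy the stacked results back into the original params. 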
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
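+ # (Copying then scaling implements decoupled weight decay: the buffer holds + # p * (1 - eff_weight_decay_val) before the orthogonalized update is subtracted from it + # below via add_(..., alpha=-eff_lr_val), so the decay never enters the momentum state.) 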
+ updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + original_shape = batched_update_grads.shape + # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O + module_idx = start_idx if start_idx < len(params) else 0 + if params[module_idx].module == 'attn': + batched_update_grads = batched_update_grads.view(-1, original_shape[-2], original_shape[-1] // 4) + v_chunk = newton_schulz_triton(batched_update_grads).view(original_shape) + + for i in range(chunk_size): + if start_idx + i >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = 
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module='attn' + with torch.no_grad(): + self.qkvo_w.view(4,self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4,self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4,self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4,self.hdim, self.dim)[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module='mlp' + self.c_proj.module='mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
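+ # e.g. next_multiple_of_n(50257, n=128) == 50304 (= 393 * 128), i.e. 47 padding rows + # beyond GPT-2's 50257 real tokens, kept purely for matmul/kernel efficiency. 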
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the 
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan the first 4 million tokens, then kick off an async thread to scan the rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
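+
+# Two illustrative sketches follow; neither is called by the training run.
+# First, how BOSFinder packs a batch: every span starts at a BOS token and is
+# capped by the next BOS, by max_seq_len, and by the remaining token budget,
+# so each rank's spans cover exactly num_tokens_local + 1 tokens (the +1 feeds
+# the input/target shift in the data generator below).
+def _bos_packing_demo():
+    # 9 toy tokens with BOS at positions 0, 3 and 5
+    toy = torch.tensor([BOS_ID, 1, 2, BOS_ID, 3, BOS_ID, 4, 5, 6], dtype=torch.int32)
+    starts, ends = BOSFinder(toy).next_batch(num_tokens_local=6, max_seq_len=4)
+    # spans [0,3), [3,5), [5,7) cover 3 + 2 + 2 = 7 = num_tokens_local + 1 tokens
+    assert [int(s) for s in starts[0]] == [0, 3, 5]
+    assert [int(e) for e in ends[0]] == [3, 5, 7]
+
+# Second, a hypothetical writer for the shard format consumed by
+# _load_data_shard above: a 256-int32 header (magic 20240520, version 1,
+# token count, zero padding) followed by the tokens as raw uint16.
+def _write_data_shard(file: Path, tokens: Tensor):
+    assert tokens.dtype == torch.uint16
+    header = torch.zeros(256, dtype=torch.int32)
+    header[0] = 20240520       # magic number checked by _load_data_shard
+    header[1] = 1              # version
+    header[2] = tokens.numel() # number of tokens
+    with file.open("wb") as f:
+        f.write(header.numpy().tobytes())
+        f.write(tokens.numpy().tobytes())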
+
+class DataPreloader:
+    # Helper for asynchronously loading the next shard and indexing bos tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning of Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # shorten the last span by one: buf holds num_tokens_local + 1 tokens to cover the _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
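+
+# A minimal usage sketch (illustrative only; placeholder numbers, not this
+# run's settings). The generator is driven with next() for a fixed batch
+# geometry, and a live generator can be re-parameterized mid-run through
+# .send(), matching the new_params handling above:
+#
+#   loader = distributed_data_generator("data/fineweb10B/fineweb_train_*.bin", num_tokens=8 * 64 * 1024, max_seq_len=2048)
+#   inputs, targets, cum_seqlens = next(loader)  # one micro-batch of token ids on cuda
+#   inputs, targets, cum_seqlens = loader.send((8 * 128 * 1024, 4096, 1))  # new (num_tokens, max_seq_len, grad_accum_steps)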
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40 # number of iterations to continue training at final cooldown and window size
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd
+    ws_long_validate: int = 20 # extend long windows out even further
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, (nn.Embedding, nn.Linear)):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+gate_params = [p for n, p in model.named_parameters() if "gate" in n]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = min(0.9999, step / args.num_iterations)
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+def get_ws(step: int):
+    if step == args.num_iterations + args.iteration_extension:
+        return args.ws_validate // 2, args.ws_validate
+    x = min(step / (1 + args.num_iterations), 0.9999)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws_long = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws_long > ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    elif new_ws_long < ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+
+# re-initialize the state saved before warmup so we aren't cheating
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del initial_state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws_short, ws_long = get_ws(0)
+
+########################################
+#       Training and validation        #
+########################################
+
+train_steps = args.num_iterations + args.iteration_extension
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    new_ws_short, new_ws_long = get_ws(step)
+    if new_ws_long != ws_long:
+        model.yarn.apply(ws_long, new_ws_long) # re-apply YaRN params when the long window size changes
+        ws_long = new_ws_long
+    ws_short = new_ws_short
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        if last_step:
+            ws_long = args.ws_long_validate
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = torch.zeros((), device=device, dtype=torch.float32)
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
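+        # gradient accumulation: each micro-step's backward() below adds into
+        # .grad, so the optimizers later see the summed full-batch gradient;
+        # with WORLD_SIZE=8 here, grad_accum_steps == 8 // 8 == 1 and this
+        # loop body runs once per training step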
inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Sep 27 12:04:08 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 25C P0 119W / 700W | 5856MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 25C P0 118W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 22C P0 115W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 24C P0 118W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 25C P0 119W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 26C P0 115W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 26C P0 118W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 24C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 150803 C /usr/bin/python 0MiB | +| 0 N/A N/A 150804 C /usr/bin/python 0MiB | +| 0 N/A N/A 150805 C /usr/bin/python 0MiB | +| 0 N/A N/A 150806 C /usr/bin/python 0MiB | +| 0 N/A N/A 150807 C /usr/bin/python 0MiB | +| 0 N/A N/A 150808 C /usr/bin/python 0MiB | +| 0 N/A N/A 150809 C /usr/bin/python 0MiB | +| 0 N/A N/A 150810 C /usr/bin/python 0MiB | +| 1 N/A N/A 150804 C /usr/bin/python 0MiB | +| 2 N/A N/A 150805 C /usr/bin/python 0MiB | +| 3 N/A N/A 150806 C /usr/bin/python 0MiB | +| 4 N/A N/A 150807 C /usr/bin/python 0MiB | +| 5 N/A N/A 150808 C /usr/bin/python 0MiB | +| 6 N/A N/A 150809 C /usr/bin/python 0MiB | +| 7 N/A N/A 150810 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1680 train_time:146ms step_avg:146.04ms +step:2/1680 train_time:166ms step_avg:82.79ms +step:3/1680 train_time:230ms step_avg:76.51ms +step:4/1680 train_time:315ms step_avg:78.72ms +step:5/1680 train_time:401ms step_avg:80.17ms +step:6/1680 train_time:487ms step_avg:81.17ms +step:7/1680 train_time:573ms step_avg:81.85ms +step:8/1680 train_time:659ms step_avg:82.38ms +step:9/1680 train_time:745ms step_avg:82.79ms +step:10/1680 train_time:831ms step_avg:83.14ms +step:11/1680 train_time:918ms step_avg:83.42ms +step:12/1680 train_time:1004ms step_avg:83.71ms +step:13/1680 train_time:1094ms step_avg:84.16ms +step:14/1680 train_time:1184ms step_avg:84.55ms +step:15/1680 train_time:1274ms step_avg:84.92ms +step:16/1680 train_time:1361ms step_avg:85.07ms +step:17/1680 train_time:1448ms step_avg:85.17ms +step:18/1680 train_time:1535ms step_avg:85.25ms +step:19/1680 train_time:1621ms step_avg:85.32ms +step:20/1680 train_time:1708ms step_avg:85.39ms +step:21/1680 train_time:1795ms step_avg:85.46ms +step:22/1680 train_time:1881ms step_avg:85.49ms +step:23/1680 train_time:1967ms step_avg:85.54ms +step:24/1680 train_time:2055ms step_avg:85.62ms +step:25/1680 train_time:2144ms step_avg:85.74ms +step:26/1680 train_time:2232ms step_avg:85.85ms +step:27/1680 train_time:2321ms step_avg:85.95ms +step:28/1680 train_time:2409ms step_avg:86.02ms +step:29/1680 train_time:2496ms step_avg:86.06ms +step:30/1680 train_time:2583ms step_avg:86.11ms +step:31/1680 train_time:2670ms step_avg:86.14ms +step:32/1680 train_time:2757ms step_avg:86.16ms +step:33/1680 train_time:2844ms step_avg:86.19ms +step:34/1680 train_time:2931ms step_avg:86.21ms +step:35/1680 train_time:3018ms step_avg:86.24ms +step:36/1680 train_time:3107ms step_avg:86.29ms +step:37/1680 train_time:3195ms step_avg:86.34ms +step:38/1680 train_time:3283ms step_avg:86.39ms +step:39/1680 train_time:3371ms step_avg:86.44ms +step:40/1680 train_time:3458ms step_avg:86.46ms +step:41/1680 train_time:3545ms step_avg:86.47ms +step:42/1680 train_time:3633ms step_avg:86.50ms +step:43/1680 train_time:3720ms step_avg:86.50ms +step:44/1680 train_time:3806ms step_avg:86.50ms +step:45/1680 train_time:3893ms step_avg:86.51ms +step:46/1680 train_time:3980ms step_avg:86.52ms +step:47/1680 
train_time:4067ms step_avg:86.54ms +step:48/1680 train_time:4155ms step_avg:86.56ms +step:49/1680 train_time:4243ms step_avg:86.58ms +step:50/1680 train_time:4330ms step_avg:86.61ms +step:51/1680 train_time:4417ms step_avg:86.61ms +step:52/1680 train_time:4504ms step_avg:86.62ms +step:53/1680 train_time:4592ms step_avg:86.64ms +step:54/1680 train_time:4680ms step_avg:86.66ms +step:55/1680 train_time:4766ms step_avg:86.66ms +step:56/1680 train_time:4853ms step_avg:86.66ms +step:57/1680 train_time:4940ms step_avg:86.66ms +step:58/1680 train_time:5027ms step_avg:86.67ms +step:59/1680 train_time:5116ms step_avg:86.70ms +step:60/1680 train_time:5203ms step_avg:86.71ms +step:61/1680 train_time:5290ms step_avg:86.72ms +step:62/1680 train_time:5378ms step_avg:86.74ms +step:63/1680 train_time:5465ms step_avg:86.75ms +step:64/1680 train_time:5552ms step_avg:86.75ms +step:65/1680 train_time:5639ms step_avg:86.76ms +step:66/1680 train_time:5726ms step_avg:86.76ms +step:67/1680 train_time:5813ms step_avg:86.77ms +step:68/1680 train_time:5900ms step_avg:86.77ms +step:69/1680 train_time:5987ms step_avg:86.76ms +step:70/1680 train_time:6074ms step_avg:86.77ms +step:71/1680 train_time:6162ms step_avg:86.79ms +step:72/1680 train_time:6250ms step_avg:86.81ms +step:73/1680 train_time:6337ms step_avg:86.81ms +step:74/1680 train_time:6424ms step_avg:86.81ms +step:75/1680 train_time:6511ms step_avg:86.82ms +step:76/1680 train_time:6598ms step_avg:86.82ms +step:77/1680 train_time:6685ms step_avg:86.82ms +step:78/1680 train_time:6772ms step_avg:86.82ms +step:79/1680 train_time:6859ms step_avg:86.83ms +step:80/1680 train_time:6946ms step_avg:86.83ms +step:81/1680 train_time:7035ms step_avg:86.85ms +step:82/1680 train_time:7121ms step_avg:86.85ms +step:83/1680 train_time:7209ms step_avg:86.85ms +step:84/1680 train_time:7297ms step_avg:86.86ms +step:85/1680 train_time:7384ms step_avg:86.86ms +step:86/1680 train_time:7470ms step_avg:86.86ms +step:87/1680 train_time:7558ms step_avg:86.87ms +step:88/1680 train_time:7645ms step_avg:86.87ms +step:89/1680 train_time:7732ms step_avg:86.87ms +step:90/1680 train_time:7819ms step_avg:86.88ms +step:91/1680 train_time:7906ms step_avg:86.88ms +step:92/1680 train_time:7993ms step_avg:86.88ms +step:93/1680 train_time:8080ms step_avg:86.88ms +step:94/1680 train_time:8166ms step_avg:86.88ms +step:95/1680 train_time:8254ms step_avg:86.88ms +step:96/1680 train_time:8340ms step_avg:86.87ms +step:97/1680 train_time:8426ms step_avg:86.87ms +step:98/1680 train_time:8514ms step_avg:86.88ms +step:99/1680 train_time:8603ms step_avg:86.90ms +step:100/1680 train_time:8690ms step_avg:86.90ms +step:101/1680 train_time:8777ms step_avg:86.90ms +step:102/1680 train_time:8863ms step_avg:86.90ms +step:103/1680 train_time:8950ms step_avg:86.90ms +step:104/1680 train_time:9037ms step_avg:86.90ms +step:105/1680 train_time:9124ms step_avg:86.90ms +step:106/1680 train_time:9212ms step_avg:86.91ms +step:107/1680 train_time:9299ms step_avg:86.90ms +step:108/1680 train_time:9386ms step_avg:86.91ms +step:109/1680 train_time:9473ms step_avg:86.91ms +step:110/1680 train_time:9560ms step_avg:86.91ms +step:111/1680 train_time:9647ms step_avg:86.91ms +step:112/1680 train_time:9734ms step_avg:86.91ms +step:113/1680 train_time:9820ms step_avg:86.91ms +step:114/1680 train_time:9907ms step_avg:86.91ms +step:115/1680 train_time:9994ms step_avg:86.91ms +step:116/1680 train_time:10081ms step_avg:86.91ms +step:117/1680 train_time:10169ms step_avg:86.91ms +step:118/1680 train_time:10256ms step_avg:86.91ms +step:119/1680 
train_time:10343ms step_avg:86.92ms +step:120/1680 train_time:10431ms step_avg:86.92ms +step:121/1680 train_time:10518ms step_avg:86.93ms +step:122/1680 train_time:10606ms step_avg:86.93ms +step:123/1680 train_time:10693ms step_avg:86.93ms +step:124/1680 train_time:10780ms step_avg:86.94ms +step:125/1680 train_time:10867ms step_avg:86.94ms +step:125/1680 val_loss:4.3408 train_time:10955ms step_avg:87.64ms +step:126/1680 train_time:10976ms step_avg:87.11ms +step:127/1680 train_time:11043ms step_avg:86.95ms +step:128/1680 train_time:11142ms step_avg:87.04ms +step:129/1680 train_time:11235ms step_avg:87.09ms +step:130/1680 train_time:11322ms step_avg:87.09ms +step:131/1680 train_time:11408ms step_avg:87.08ms +step:132/1680 train_time:11494ms step_avg:87.07ms +step:133/1680 train_time:11580ms step_avg:87.06ms +step:134/1680 train_time:11665ms step_avg:87.06ms +step:135/1680 train_time:11751ms step_avg:87.04ms +step:136/1680 train_time:11837ms step_avg:87.04ms +step:137/1680 train_time:11923ms step_avg:87.03ms +step:138/1680 train_time:12010ms step_avg:87.03ms +step:139/1680 train_time:12100ms step_avg:87.05ms +step:140/1680 train_time:12191ms step_avg:87.08ms +step:141/1680 train_time:12279ms step_avg:87.09ms +step:142/1680 train_time:12367ms step_avg:87.09ms +step:143/1680 train_time:12453ms step_avg:87.08ms +step:144/1680 train_time:12539ms step_avg:87.08ms +step:145/1680 train_time:12626ms step_avg:87.08ms +step:146/1680 train_time:12712ms step_avg:87.07ms +step:147/1680 train_time:12798ms step_avg:87.06ms +step:148/1680 train_time:12884ms step_avg:87.05ms +step:149/1680 train_time:12971ms step_avg:87.05ms +step:150/1680 train_time:13060ms step_avg:87.06ms +step:151/1680 train_time:13148ms step_avg:87.07ms +step:152/1680 train_time:13236ms step_avg:87.08ms +step:153/1680 train_time:13324ms step_avg:87.08ms +step:154/1680 train_time:13412ms step_avg:87.09ms +step:155/1680 train_time:13499ms step_avg:87.09ms +step:156/1680 train_time:13585ms step_avg:87.08ms +step:157/1680 train_time:13671ms step_avg:87.08ms +step:158/1680 train_time:13757ms step_avg:87.07ms +step:159/1680 train_time:13844ms step_avg:87.07ms +step:160/1680 train_time:13931ms step_avg:87.07ms +step:161/1680 train_time:14019ms step_avg:87.07ms +step:162/1680 train_time:14107ms step_avg:87.08ms +step:163/1680 train_time:14195ms step_avg:87.08ms +step:164/1680 train_time:14282ms step_avg:87.08ms +step:165/1680 train_time:14369ms step_avg:87.09ms +step:166/1680 train_time:14457ms step_avg:87.09ms +step:167/1680 train_time:14545ms step_avg:87.09ms +step:168/1680 train_time:14631ms step_avg:87.09ms +step:169/1680 train_time:14717ms step_avg:87.08ms +step:170/1680 train_time:14803ms step_avg:87.08ms +step:171/1680 train_time:14889ms step_avg:87.07ms +step:172/1680 train_time:14976ms step_avg:87.07ms +step:173/1680 train_time:15063ms step_avg:87.07ms +step:174/1680 train_time:15152ms step_avg:87.08ms +step:175/1680 train_time:15240ms step_avg:87.09ms +step:176/1680 train_time:15328ms step_avg:87.09ms +step:177/1680 train_time:15416ms step_avg:87.10ms +step:178/1680 train_time:15503ms step_avg:87.10ms +step:179/1680 train_time:15590ms step_avg:87.09ms +step:180/1680 train_time:15676ms step_avg:87.09ms +step:181/1680 train_time:15763ms step_avg:87.09ms +step:182/1680 train_time:15849ms step_avg:87.08ms +step:183/1680 train_time:15936ms step_avg:87.08ms +step:184/1680 train_time:16023ms step_avg:87.08ms +step:185/1680 train_time:16110ms step_avg:87.08ms +step:186/1680 train_time:16197ms step_avg:87.08ms +step:187/1680 train_time:16285ms 
step_avg:87.09ms +step:188/1680 train_time:16373ms step_avg:87.09ms +step:189/1680 train_time:16460ms step_avg:87.09ms +step:190/1680 train_time:16547ms step_avg:87.09ms +step:191/1680 train_time:16634ms step_avg:87.09ms +step:192/1680 train_time:16721ms step_avg:87.09ms +step:193/1680 train_time:16808ms step_avg:87.09ms +step:194/1680 train_time:16894ms step_avg:87.08ms +step:195/1680 train_time:16982ms step_avg:87.09ms +step:196/1680 train_time:17068ms step_avg:87.08ms +step:197/1680 train_time:17155ms step_avg:87.08ms +step:198/1680 train_time:17242ms step_avg:87.08ms +step:199/1680 train_time:17330ms step_avg:87.08ms +step:200/1680 train_time:17416ms step_avg:87.08ms +step:201/1680 train_time:17503ms step_avg:87.08ms +step:202/1680 train_time:17591ms step_avg:87.08ms +step:203/1680 train_time:17678ms step_avg:87.08ms +step:204/1680 train_time:17765ms step_avg:87.08ms +step:205/1680 train_time:17851ms step_avg:87.08ms +step:206/1680 train_time:17938ms step_avg:87.08ms +step:207/1680 train_time:18025ms step_avg:87.08ms +step:208/1680 train_time:18112ms step_avg:87.08ms +step:209/1680 train_time:18200ms step_avg:87.08ms +step:210/1680 train_time:18287ms step_avg:87.08ms +step:211/1680 train_time:18375ms step_avg:87.08ms +step:212/1680 train_time:18462ms step_avg:87.08ms +step:213/1680 train_time:18549ms step_avg:87.08ms +step:214/1680 train_time:18636ms step_avg:87.08ms +step:215/1680 train_time:18723ms step_avg:87.08ms +step:216/1680 train_time:18810ms step_avg:87.08ms +step:217/1680 train_time:18896ms step_avg:87.08ms +step:218/1680 train_time:18983ms step_avg:87.08ms +step:219/1680 train_time:19070ms step_avg:87.08ms +step:220/1680 train_time:19157ms step_avg:87.08ms +step:221/1680 train_time:19244ms step_avg:87.08ms +step:222/1680 train_time:19330ms step_avg:87.07ms +step:223/1680 train_time:19418ms step_avg:87.07ms +step:224/1680 train_time:19505ms step_avg:87.08ms +step:225/1680 train_time:19592ms step_avg:87.08ms +step:226/1680 train_time:19679ms step_avg:87.08ms +step:227/1680 train_time:19766ms step_avg:87.08ms +step:228/1680 train_time:19853ms step_avg:87.08ms +step:229/1680 train_time:19940ms step_avg:87.07ms +step:230/1680 train_time:20027ms step_avg:87.07ms +step:231/1680 train_time:20114ms step_avg:87.07ms +step:232/1680 train_time:20201ms step_avg:87.07ms +step:233/1680 train_time:20288ms step_avg:87.07ms +step:234/1680 train_time:20374ms step_avg:87.07ms +step:235/1680 train_time:20462ms step_avg:87.07ms +step:236/1680 train_time:20549ms step_avg:87.07ms +step:237/1680 train_time:20636ms step_avg:87.07ms +step:238/1680 train_time:20723ms step_avg:87.07ms +step:239/1680 train_time:20810ms step_avg:87.07ms +step:240/1680 train_time:20897ms step_avg:87.07ms +step:241/1680 train_time:20983ms step_avg:87.07ms +step:242/1680 train_time:21070ms step_avg:87.07ms +step:243/1680 train_time:21157ms step_avg:87.07ms +step:244/1680 train_time:21245ms step_avg:87.07ms +step:245/1680 train_time:21332ms step_avg:87.07ms +step:246/1680 train_time:21419ms step_avg:87.07ms +step:247/1680 train_time:21506ms step_avg:87.07ms +step:248/1680 train_time:21592ms step_avg:87.07ms +step:249/1680 train_time:21680ms step_avg:87.07ms +step:250/1680 train_time:21767ms step_avg:87.07ms +step:250/1680 val_loss:3.9769 train_time:21855ms step_avg:87.42ms +step:251/1680 train_time:21874ms step_avg:87.15ms +step:252/1680 train_time:21944ms step_avg:87.08ms +step:253/1680 train_time:22036ms step_avg:87.10ms +step:254/1680 train_time:22124ms step_avg:87.10ms +step:255/1680 train_time:22211ms step_avg:87.10ms 
+step:256/1680 train_time:22297ms step_avg:87.10ms +step:257/1680 train_time:22383ms step_avg:87.09ms +step:258/1680 train_time:22469ms step_avg:87.09ms +step:259/1680 train_time:22555ms step_avg:87.09ms +step:260/1680 train_time:22642ms step_avg:87.08ms +step:261/1680 train_time:22728ms step_avg:87.08ms +step:262/1680 train_time:22816ms step_avg:87.09ms +step:263/1680 train_time:22904ms step_avg:87.09ms +step:264/1680 train_time:22994ms step_avg:87.10ms +step:265/1680 train_time:23083ms step_avg:87.11ms +step:266/1680 train_time:23171ms step_avg:87.11ms +step:267/1680 train_time:23257ms step_avg:87.10ms +step:268/1680 train_time:23344ms step_avg:87.10ms +step:269/1680 train_time:23430ms step_avg:87.10ms +step:270/1680 train_time:23516ms step_avg:87.10ms +step:271/1680 train_time:23602ms step_avg:87.09ms +step:272/1680 train_time:23689ms step_avg:87.09ms +step:273/1680 train_time:23776ms step_avg:87.09ms +step:274/1680 train_time:23863ms step_avg:87.09ms +step:275/1680 train_time:23951ms step_avg:87.10ms +step:276/1680 train_time:24039ms step_avg:87.10ms +step:277/1680 train_time:24126ms step_avg:87.10ms +step:278/1680 train_time:24214ms step_avg:87.10ms +step:279/1680 train_time:24300ms step_avg:87.10ms +step:280/1680 train_time:24387ms step_avg:87.10ms +step:281/1680 train_time:24474ms step_avg:87.09ms +step:282/1680 train_time:24560ms step_avg:87.09ms +step:283/1680 train_time:24647ms step_avg:87.09ms +step:284/1680 train_time:24734ms step_avg:87.09ms +step:285/1680 train_time:24821ms step_avg:87.09ms +step:286/1680 train_time:24909ms step_avg:87.09ms +step:287/1680 train_time:24997ms step_avg:87.10ms +step:288/1680 train_time:25084ms step_avg:87.10ms +step:289/1680 train_time:25171ms step_avg:87.10ms +step:290/1680 train_time:25258ms step_avg:87.09ms +step:291/1680 train_time:25345ms step_avg:87.10ms +step:292/1680 train_time:25432ms step_avg:87.10ms +step:293/1680 train_time:25518ms step_avg:87.09ms +step:294/1680 train_time:25605ms step_avg:87.09ms +step:295/1680 train_time:25691ms step_avg:87.09ms +step:296/1680 train_time:25779ms step_avg:87.09ms +step:297/1680 train_time:25866ms step_avg:87.09ms +step:298/1680 train_time:25954ms step_avg:87.09ms +step:299/1680 train_time:26041ms step_avg:87.09ms +step:300/1680 train_time:26129ms step_avg:87.10ms +step:301/1680 train_time:26216ms step_avg:87.10ms +step:302/1680 train_time:26302ms step_avg:87.09ms +step:303/1680 train_time:26389ms step_avg:87.09ms +step:304/1680 train_time:26476ms step_avg:87.09ms +step:305/1680 train_time:26562ms step_avg:87.09ms +step:306/1680 train_time:26649ms step_avg:87.09ms +step:307/1680 train_time:26736ms step_avg:87.09ms +step:308/1680 train_time:26823ms step_avg:87.09ms +step:309/1680 train_time:26910ms step_avg:87.09ms +step:310/1680 train_time:26998ms step_avg:87.09ms +step:311/1680 train_time:27085ms step_avg:87.09ms +step:312/1680 train_time:27172ms step_avg:87.09ms +step:313/1680 train_time:27260ms step_avg:87.09ms +step:314/1680 train_time:27347ms step_avg:87.09ms +step:315/1680 train_time:27434ms step_avg:87.09ms +step:316/1680 train_time:27521ms step_avg:87.09ms +step:317/1680 train_time:27608ms step_avg:87.09ms +step:318/1680 train_time:27694ms step_avg:87.09ms +step:319/1680 train_time:27782ms step_avg:87.09ms +step:320/1680 train_time:27869ms step_avg:87.09ms +step:321/1680 train_time:27956ms step_avg:87.09ms +step:322/1680 train_time:28043ms step_avg:87.09ms +step:323/1680 train_time:28130ms step_avg:87.09ms +step:324/1680 train_time:28218ms step_avg:87.09ms +step:325/1680 train_time:28305ms 
step_avg:87.09ms +step:326/1680 train_time:28392ms step_avg:87.09ms +step:327/1680 train_time:28479ms step_avg:87.09ms +step:328/1680 train_time:28566ms step_avg:87.09ms +step:329/1680 train_time:28653ms step_avg:87.09ms +step:330/1680 train_time:28739ms step_avg:87.09ms +step:331/1680 train_time:28827ms step_avg:87.09ms +step:332/1680 train_time:28914ms step_avg:87.09ms +step:333/1680 train_time:29001ms step_avg:87.09ms +step:334/1680 train_time:29087ms step_avg:87.09ms +step:335/1680 train_time:29175ms step_avg:87.09ms +step:336/1680 train_time:29262ms step_avg:87.09ms +step:337/1680 train_time:29350ms step_avg:87.09ms +step:338/1680 train_time:29437ms step_avg:87.09ms +step:339/1680 train_time:29523ms step_avg:87.09ms +step:340/1680 train_time:29609ms step_avg:87.09ms +step:341/1680 train_time:29696ms step_avg:87.09ms +step:342/1680 train_time:29784ms step_avg:87.09ms +step:343/1680 train_time:29871ms step_avg:87.09ms +step:344/1680 train_time:29959ms step_avg:87.09ms +step:345/1680 train_time:30047ms step_avg:87.09ms +step:346/1680 train_time:30133ms step_avg:87.09ms +step:347/1680 train_time:30221ms step_avg:87.09ms +step:348/1680 train_time:30308ms step_avg:87.09ms +step:349/1680 train_time:30396ms step_avg:87.09ms +step:350/1680 train_time:30483ms step_avg:87.09ms +step:351/1680 train_time:30570ms step_avg:87.09ms +step:352/1680 train_time:30656ms step_avg:87.09ms +step:353/1680 train_time:30743ms step_avg:87.09ms +step:354/1680 train_time:30831ms step_avg:87.09ms +step:355/1680 train_time:30918ms step_avg:87.09ms +step:356/1680 train_time:31005ms step_avg:87.09ms +step:357/1680 train_time:31093ms step_avg:87.09ms +step:358/1680 train_time:31180ms step_avg:87.10ms +step:359/1680 train_time:31267ms step_avg:87.10ms +step:360/1680 train_time:31354ms step_avg:87.09ms +step:361/1680 train_time:31441ms step_avg:87.09ms +step:362/1680 train_time:31528ms step_avg:87.09ms +step:363/1680 train_time:31615ms step_avg:87.09ms +step:364/1680 train_time:31702ms step_avg:87.09ms +step:365/1680 train_time:31789ms step_avg:87.09ms +step:366/1680 train_time:31876ms step_avg:87.09ms +step:367/1680 train_time:31962ms step_avg:87.09ms +step:368/1680 train_time:32050ms step_avg:87.09ms +step:369/1680 train_time:32137ms step_avg:87.09ms +step:370/1680 train_time:32224ms step_avg:87.09ms +step:371/1680 train_time:32311ms step_avg:87.09ms +step:372/1680 train_time:32398ms step_avg:87.09ms +step:373/1680 train_time:32485ms step_avg:87.09ms +step:374/1680 train_time:32573ms step_avg:87.09ms +step:375/1680 train_time:32660ms step_avg:87.09ms +step:375/1680 val_loss:3.8239 train_time:32748ms step_avg:87.33ms +step:376/1680 train_time:32768ms step_avg:87.15ms +step:377/1680 train_time:32837ms step_avg:87.10ms +step:378/1680 train_time:32927ms step_avg:87.11ms +step:379/1680 train_time:33017ms step_avg:87.12ms +step:380/1680 train_time:33103ms step_avg:87.11ms +step:381/1680 train_time:33190ms step_avg:87.11ms +step:382/1680 train_time:33276ms step_avg:87.11ms +step:383/1680 train_time:33362ms step_avg:87.11ms +step:384/1680 train_time:33449ms step_avg:87.11ms +step:385/1680 train_time:33535ms step_avg:87.10ms +step:386/1680 train_time:33621ms step_avg:87.10ms +step:387/1680 train_time:33708ms step_avg:87.10ms +step:388/1680 train_time:33796ms step_avg:87.10ms +step:389/1680 train_time:33885ms step_avg:87.11ms +step:390/1680 train_time:33974ms step_avg:87.11ms +step:391/1680 train_time:34061ms step_avg:87.11ms +step:392/1680 train_time:34149ms step_avg:87.11ms +step:393/1680 train_time:34235ms step_avg:87.11ms 
+step:394/1680 train_time:34322ms step_avg:87.11ms +step:395/1680 train_time:34408ms step_avg:87.11ms +step:396/1680 train_time:34494ms step_avg:87.11ms +step:397/1680 train_time:34581ms step_avg:87.10ms +step:398/1680 train_time:34667ms step_avg:87.10ms +step:399/1680 train_time:34755ms step_avg:87.10ms +step:400/1680 train_time:34842ms step_avg:87.11ms +step:401/1680 train_time:34930ms step_avg:87.11ms +step:402/1680 train_time:35018ms step_avg:87.11ms +step:403/1680 train_time:35105ms step_avg:87.11ms +step:404/1680 train_time:35192ms step_avg:87.11ms +step:405/1680 train_time:35279ms step_avg:87.11ms +step:406/1680 train_time:35365ms step_avg:87.11ms +step:407/1680 train_time:35451ms step_avg:87.10ms +step:408/1680 train_time:35538ms step_avg:87.10ms +step:409/1680 train_time:35625ms step_avg:87.10ms +step:410/1680 train_time:35712ms step_avg:87.10ms +step:411/1680 train_time:35799ms step_avg:87.10ms +step:412/1680 train_time:35886ms step_avg:87.10ms +step:413/1680 train_time:35974ms step_avg:87.11ms +step:414/1680 train_time:36062ms step_avg:87.11ms +step:415/1680 train_time:36149ms step_avg:87.11ms +step:416/1680 train_time:36236ms step_avg:87.11ms +step:417/1680 train_time:36323ms step_avg:87.10ms +step:418/1680 train_time:36410ms step_avg:87.10ms +step:419/1680 train_time:36496ms step_avg:87.10ms +step:420/1680 train_time:36583ms step_avg:87.10ms +step:421/1680 train_time:36670ms step_avg:87.10ms +step:422/1680 train_time:36758ms step_avg:87.10ms +step:423/1680 train_time:36845ms step_avg:87.10ms +step:424/1680 train_time:36933ms step_avg:87.11ms +step:425/1680 train_time:37019ms step_avg:87.10ms +step:426/1680 train_time:37107ms step_avg:87.10ms +step:427/1680 train_time:37194ms step_avg:87.11ms +step:428/1680 train_time:37281ms step_avg:87.11ms +step:429/1680 train_time:37368ms step_avg:87.11ms +step:430/1680 train_time:37455ms step_avg:87.10ms +step:431/1680 train_time:37542ms step_avg:87.10ms +step:432/1680 train_time:37629ms step_avg:87.10ms +step:433/1680 train_time:37715ms step_avg:87.10ms +step:434/1680 train_time:37802ms step_avg:87.10ms +step:435/1680 train_time:37889ms step_avg:87.10ms +step:436/1680 train_time:37977ms step_avg:87.10ms +step:437/1680 train_time:38064ms step_avg:87.10ms +step:438/1680 train_time:38151ms step_avg:87.10ms +step:439/1680 train_time:38238ms step_avg:87.10ms +step:440/1680 train_time:38325ms step_avg:87.10ms +step:441/1680 train_time:38412ms step_avg:87.10ms +step:442/1680 train_time:38499ms step_avg:87.10ms +step:443/1680 train_time:38586ms step_avg:87.10ms +step:444/1680 train_time:38672ms step_avg:87.10ms +step:445/1680 train_time:38760ms step_avg:87.10ms +step:446/1680 train_time:38847ms step_avg:87.10ms +step:447/1680 train_time:38934ms step_avg:87.10ms +step:448/1680 train_time:39021ms step_avg:87.10ms +step:449/1680 train_time:39109ms step_avg:87.10ms +step:450/1680 train_time:39196ms step_avg:87.10ms +step:451/1680 train_time:39283ms step_avg:87.10ms +step:452/1680 train_time:39370ms step_avg:87.10ms +step:453/1680 train_time:39457ms step_avg:87.10ms +step:454/1680 train_time:39545ms step_avg:87.10ms +step:455/1680 train_time:39631ms step_avg:87.10ms +step:456/1680 train_time:39718ms step_avg:87.10ms +step:457/1680 train_time:39805ms step_avg:87.10ms +step:458/1680 train_time:39892ms step_avg:87.10ms +step:459/1680 train_time:39979ms step_avg:87.10ms +step:460/1680 train_time:40066ms step_avg:87.10ms +step:461/1680 train_time:40153ms step_avg:87.10ms +step:462/1680 train_time:40240ms step_avg:87.10ms +step:463/1680 train_time:40327ms 
step_avg:87.10ms +step:464/1680 train_time:40415ms step_avg:87.10ms +step:465/1680 train_time:40502ms step_avg:87.10ms +step:466/1680 train_time:40589ms step_avg:87.10ms +step:467/1680 train_time:40676ms step_avg:87.10ms +step:468/1680 train_time:40763ms step_avg:87.10ms +step:469/1680 train_time:40850ms step_avg:87.10ms +step:470/1680 train_time:40937ms step_avg:87.10ms +step:471/1680 train_time:41023ms step_avg:87.10ms +step:472/1680 train_time:41110ms step_avg:87.10ms +step:473/1680 train_time:41198ms step_avg:87.10ms +step:474/1680 train_time:41285ms step_avg:87.10ms +step:475/1680 train_time:41371ms step_avg:87.10ms +step:476/1680 train_time:41459ms step_avg:87.10ms +step:477/1680 train_time:41546ms step_avg:87.10ms +step:478/1680 train_time:41633ms step_avg:87.10ms +step:479/1680 train_time:41720ms step_avg:87.10ms +step:480/1680 train_time:41808ms step_avg:87.10ms +step:481/1680 train_time:41895ms step_avg:87.10ms +step:482/1680 train_time:41983ms step_avg:87.10ms +step:483/1680 train_time:42069ms step_avg:87.10ms +step:484/1680 train_time:42157ms step_avg:87.10ms +step:485/1680 train_time:42244ms step_avg:87.10ms +step:486/1680 train_time:42331ms step_avg:87.10ms +step:487/1680 train_time:42418ms step_avg:87.10ms +step:488/1680 train_time:42505ms step_avg:87.10ms +step:489/1680 train_time:42592ms step_avg:87.10ms +step:490/1680 train_time:42679ms step_avg:87.10ms +step:491/1680 train_time:42766ms step_avg:87.10ms +step:492/1680 train_time:42852ms step_avg:87.10ms +step:493/1680 train_time:42939ms step_avg:87.10ms +step:494/1680 train_time:43026ms step_avg:87.10ms +step:495/1680 train_time:43113ms step_avg:87.10ms +step:496/1680 train_time:43200ms step_avg:87.10ms +step:497/1680 train_time:43287ms step_avg:87.10ms +step:498/1680 train_time:43374ms step_avg:87.10ms +step:499/1680 train_time:43461ms step_avg:87.10ms +step:500/1680 train_time:43549ms step_avg:87.10ms +step:500/1680 val_loss:3.7206 train_time:43637ms step_avg:87.27ms +step:501/1680 train_time:43658ms step_avg:87.14ms +step:502/1680 train_time:43728ms step_avg:87.11ms +step:503/1680 train_time:43819ms step_avg:87.12ms +step:504/1680 train_time:43907ms step_avg:87.12ms +step:505/1680 train_time:43994ms step_avg:87.12ms +step:506/1680 train_time:44080ms step_avg:87.12ms +step:507/1680 train_time:44166ms step_avg:87.11ms +step:508/1680 train_time:44252ms step_avg:87.11ms +step:509/1680 train_time:44338ms step_avg:87.11ms +step:510/1680 train_time:44424ms step_avg:87.11ms +step:511/1680 train_time:44510ms step_avg:87.10ms +step:512/1680 train_time:44598ms step_avg:87.11ms +step:513/1680 train_time:44686ms step_avg:87.11ms +step:514/1680 train_time:44775ms step_avg:87.11ms +step:515/1680 train_time:44863ms step_avg:87.11ms +step:516/1680 train_time:44951ms step_avg:87.11ms +step:517/1680 train_time:45038ms step_avg:87.12ms +step:518/1680 train_time:45125ms step_avg:87.11ms +step:519/1680 train_time:45212ms step_avg:87.11ms +step:520/1680 train_time:45298ms step_avg:87.11ms +step:521/1680 train_time:45385ms step_avg:87.11ms +step:522/1680 train_time:45471ms step_avg:87.11ms +step:523/1680 train_time:45559ms step_avg:87.11ms +step:524/1680 train_time:45646ms step_avg:87.11ms +step:525/1680 train_time:45734ms step_avg:87.11ms +step:526/1680 train_time:45822ms step_avg:87.11ms +step:527/1680 train_time:45910ms step_avg:87.12ms +step:528/1680 train_time:45997ms step_avg:87.12ms +step:529/1680 train_time:46084ms step_avg:87.12ms +step:530/1680 train_time:46170ms step_avg:87.11ms +step:531/1680 train_time:46257ms step_avg:87.11ms 
+step:532/1680 train_time:46343ms step_avg:87.11ms +step:533/1680 train_time:46430ms step_avg:87.11ms +step:534/1680 train_time:46516ms step_avg:87.11ms +step:535/1680 train_time:46604ms step_avg:87.11ms +step:536/1680 train_time:46691ms step_avg:87.11ms +step:537/1680 train_time:46779ms step_avg:87.11ms +step:538/1680 train_time:46867ms step_avg:87.11ms +step:539/1680 train_time:46954ms step_avg:87.11ms +step:540/1680 train_time:47042ms step_avg:87.11ms +step:541/1680 train_time:47129ms step_avg:87.12ms +step:542/1680 train_time:47216ms step_avg:87.11ms +step:543/1680 train_time:47303ms step_avg:87.11ms +step:544/1680 train_time:47389ms step_avg:87.11ms +step:545/1680 train_time:47476ms step_avg:87.11ms +step:546/1680 train_time:47563ms step_avg:87.11ms +step:547/1680 train_time:47650ms step_avg:87.11ms +step:548/1680 train_time:47737ms step_avg:87.11ms +step:549/1680 train_time:47827ms step_avg:87.12ms +step:550/1680 train_time:47916ms step_avg:87.12ms +step:551/1680 train_time:48003ms step_avg:87.12ms +step:552/1680 train_time:48091ms step_avg:87.12ms +step:553/1680 train_time:48180ms step_avg:87.12ms +step:554/1680 train_time:48268ms step_avg:87.13ms +step:555/1680 train_time:48356ms step_avg:87.13ms +step:556/1680 train_time:48443ms step_avg:87.13ms +step:557/1680 train_time:48531ms step_avg:87.13ms +step:558/1680 train_time:48619ms step_avg:87.13ms +step:559/1680 train_time:48708ms step_avg:87.13ms +step:560/1680 train_time:48796ms step_avg:87.14ms +step:561/1680 train_time:48885ms step_avg:87.14ms +step:562/1680 train_time:48973ms step_avg:87.14ms +step:563/1680 train_time:49062ms step_avg:87.14ms +step:564/1680 train_time:49150ms step_avg:87.15ms +step:565/1680 train_time:49239ms step_avg:87.15ms +step:566/1680 train_time:49327ms step_avg:87.15ms +step:567/1680 train_time:49414ms step_avg:87.15ms +step:568/1680 train_time:49503ms step_avg:87.15ms +step:569/1680 train_time:49591ms step_avg:87.15ms +step:570/1680 train_time:49679ms step_avg:87.16ms +step:571/1680 train_time:49768ms step_avg:87.16ms +step:572/1680 train_time:49856ms step_avg:87.16ms +step:573/1680 train_time:49945ms step_avg:87.16ms +step:574/1680 train_time:50033ms step_avg:87.17ms +step:575/1680 train_time:50122ms step_avg:87.17ms +step:576/1680 train_time:50211ms step_avg:87.17ms +step:577/1680 train_time:50299ms step_avg:87.17ms +step:578/1680 train_time:50387ms step_avg:87.17ms +step:579/1680 train_time:50476ms step_avg:87.18ms +step:580/1680 train_time:50563ms step_avg:87.18ms +step:581/1680 train_time:50651ms step_avg:87.18ms +step:582/1680 train_time:50740ms step_avg:87.18ms +step:583/1680 train_time:50828ms step_avg:87.18ms +step:584/1680 train_time:50916ms step_avg:87.18ms +step:585/1680 train_time:51005ms step_avg:87.19ms +step:586/1680 train_time:51093ms step_avg:87.19ms +step:587/1680 train_time:51182ms step_avg:87.19ms +step:588/1680 train_time:51269ms step_avg:87.19ms +step:589/1680 train_time:51357ms step_avg:87.19ms +step:590/1680 train_time:51445ms step_avg:87.20ms +step:591/1680 train_time:51534ms step_avg:87.20ms +step:592/1680 train_time:51622ms step_avg:87.20ms +step:593/1680 train_time:51711ms step_avg:87.20ms +step:594/1680 train_time:51798ms step_avg:87.20ms +step:595/1680 train_time:51886ms step_avg:87.20ms +step:596/1680 train_time:51974ms step_avg:87.20ms +step:597/1680 train_time:52062ms step_avg:87.21ms +step:598/1680 train_time:52152ms step_avg:87.21ms +step:599/1680 train_time:52240ms step_avg:87.21ms +step:600/1680 train_time:52329ms step_avg:87.21ms +step:601/1680 train_time:52417ms 
step_avg:87.22ms +step:602/1680 train_time:52505ms step_avg:87.22ms +step:603/1680 train_time:52593ms step_avg:87.22ms +step:604/1680 train_time:52681ms step_avg:87.22ms +step:605/1680 train_time:52769ms step_avg:87.22ms +step:606/1680 train_time:52857ms step_avg:87.22ms +step:607/1680 train_time:52946ms step_avg:87.23ms +step:608/1680 train_time:53034ms step_avg:87.23ms +step:609/1680 train_time:53123ms step_avg:87.23ms +step:610/1680 train_time:53211ms step_avg:87.23ms +step:611/1680 train_time:53300ms step_avg:87.23ms +step:612/1680 train_time:53388ms step_avg:87.24ms +step:613/1680 train_time:53477ms step_avg:87.24ms +step:614/1680 train_time:53565ms step_avg:87.24ms +step:615/1680 train_time:53653ms step_avg:87.24ms +step:616/1680 train_time:53741ms step_avg:87.24ms +step:617/1680 train_time:53829ms step_avg:87.24ms +step:618/1680 train_time:53918ms step_avg:87.25ms +step:619/1680 train_time:54006ms step_avg:87.25ms +step:620/1680 train_time:54094ms step_avg:87.25ms +step:621/1680 train_time:54182ms step_avg:87.25ms +step:622/1680 train_time:54270ms step_avg:87.25ms +step:623/1680 train_time:54359ms step_avg:87.25ms +step:624/1680 train_time:54447ms step_avg:87.25ms +step:625/1680 train_time:54535ms step_avg:87.26ms +step:625/1680 val_loss:3.6201 train_time:54625ms step_avg:87.40ms +step:626/1680 train_time:54645ms step_avg:87.29ms +step:627/1680 train_time:54714ms step_avg:87.26ms +step:628/1680 train_time:54804ms step_avg:87.27ms +step:629/1680 train_time:54895ms step_avg:87.27ms +step:630/1680 train_time:54983ms step_avg:87.27ms +step:631/1680 train_time:55070ms step_avg:87.27ms +step:632/1680 train_time:55156ms step_avg:87.27ms +step:633/1680 train_time:55243ms step_avg:87.27ms +step:634/1680 train_time:55330ms step_avg:87.27ms +step:635/1680 train_time:55417ms step_avg:87.27ms +step:636/1680 train_time:55505ms step_avg:87.27ms +step:637/1680 train_time:55594ms step_avg:87.28ms +step:638/1680 train_time:55685ms step_avg:87.28ms +step:639/1680 train_time:55775ms step_avg:87.28ms +step:640/1680 train_time:55864ms step_avg:87.29ms +step:641/1680 train_time:55953ms step_avg:87.29ms +step:642/1680 train_time:56040ms step_avg:87.29ms +step:643/1680 train_time:56127ms step_avg:87.29ms +step:644/1680 train_time:56214ms step_avg:87.29ms +step:645/1680 train_time:56302ms step_avg:87.29ms +step:646/1680 train_time:56389ms step_avg:87.29ms +step:647/1680 train_time:56477ms step_avg:87.29ms +step:648/1680 train_time:56565ms step_avg:87.29ms +step:649/1680 train_time:56653ms step_avg:87.29ms +step:650/1680 train_time:56742ms step_avg:87.30ms +step:651/1680 train_time:56831ms step_avg:87.30ms +step:652/1680 train_time:56919ms step_avg:87.30ms +step:653/1680 train_time:57008ms step_avg:87.30ms +step:654/1680 train_time:57096ms step_avg:87.30ms +step:655/1680 train_time:57183ms step_avg:87.30ms +step:656/1680 train_time:57271ms step_avg:87.30ms +step:657/1680 train_time:57358ms step_avg:87.30ms +step:658/1680 train_time:57446ms step_avg:87.30ms +step:659/1680 train_time:57533ms step_avg:87.30ms +step:660/1680 train_time:57622ms step_avg:87.31ms +step:661/1680 train_time:57710ms step_avg:87.31ms +step:662/1680 train_time:57799ms step_avg:87.31ms +step:663/1680 train_time:57888ms step_avg:87.31ms +step:664/1680 train_time:57976ms step_avg:87.31ms +step:665/1680 train_time:58064ms step_avg:87.31ms +step:666/1680 train_time:58152ms step_avg:87.31ms +step:667/1680 train_time:58239ms step_avg:87.32ms +step:668/1680 train_time:58328ms step_avg:87.32ms +step:669/1680 train_time:58416ms step_avg:87.32ms 
+step:670/1680 train_time:58504ms step_avg:87.32ms +step:671/1680 train_time:58592ms step_avg:87.32ms +step:672/1680 train_time:58680ms step_avg:87.32ms +step:673/1680 train_time:58768ms step_avg:87.32ms +step:674/1680 train_time:58857ms step_avg:87.32ms +step:675/1680 train_time:58945ms step_avg:87.33ms +step:676/1680 train_time:59033ms step_avg:87.33ms +step:677/1680 train_time:59121ms step_avg:87.33ms +step:678/1680 train_time:59209ms step_avg:87.33ms +step:679/1680 train_time:59298ms step_avg:87.33ms +step:680/1680 train_time:59386ms step_avg:87.33ms +step:681/1680 train_time:59474ms step_avg:87.33ms +step:682/1680 train_time:59562ms step_avg:87.33ms +step:683/1680 train_time:59651ms step_avg:87.34ms +step:684/1680 train_time:59739ms step_avg:87.34ms +step:685/1680 train_time:59828ms step_avg:87.34ms +step:686/1680 train_time:59916ms step_avg:87.34ms +step:687/1680 train_time:60005ms step_avg:87.34ms +step:688/1680 train_time:60092ms step_avg:87.34ms +step:689/1680 train_time:60180ms step_avg:87.34ms +step:690/1680 train_time:60269ms step_avg:87.35ms +step:691/1680 train_time:60357ms step_avg:87.35ms +step:692/1680 train_time:60445ms step_avg:87.35ms +step:693/1680 train_time:60533ms step_avg:87.35ms +step:694/1680 train_time:60621ms step_avg:87.35ms +step:695/1680 train_time:60709ms step_avg:87.35ms +step:696/1680 train_time:60798ms step_avg:87.35ms +step:697/1680 train_time:60886ms step_avg:87.35ms +step:698/1680 train_time:60975ms step_avg:87.36ms +step:699/1680 train_time:61064ms step_avg:87.36ms +step:700/1680 train_time:61152ms step_avg:87.36ms +step:701/1680 train_time:61240ms step_avg:87.36ms +step:702/1680 train_time:61329ms step_avg:87.36ms +step:703/1680 train_time:61417ms step_avg:87.36ms +step:704/1680 train_time:61505ms step_avg:87.36ms +step:705/1680 train_time:61593ms step_avg:87.37ms +step:706/1680 train_time:61680ms step_avg:87.37ms +step:707/1680 train_time:61769ms step_avg:87.37ms +step:708/1680 train_time:61858ms step_avg:87.37ms +step:709/1680 train_time:61946ms step_avg:87.37ms +step:710/1680 train_time:62034ms step_avg:87.37ms +step:711/1680 train_time:62122ms step_avg:87.37ms +step:712/1680 train_time:62210ms step_avg:87.37ms +step:713/1680 train_time:62298ms step_avg:87.38ms +step:714/1680 train_time:62387ms step_avg:87.38ms +step:715/1680 train_time:62475ms step_avg:87.38ms +step:716/1680 train_time:62564ms step_avg:87.38ms +step:717/1680 train_time:62652ms step_avg:87.38ms +step:718/1680 train_time:62740ms step_avg:87.38ms +step:719/1680 train_time:62828ms step_avg:87.38ms +step:720/1680 train_time:62916ms step_avg:87.38ms +step:721/1680 train_time:63005ms step_avg:87.38ms +step:722/1680 train_time:63093ms step_avg:87.39ms +step:723/1680 train_time:63181ms step_avg:87.39ms +step:724/1680 train_time:63269ms step_avg:87.39ms +step:725/1680 train_time:63357ms step_avg:87.39ms +step:726/1680 train_time:63444ms step_avg:87.39ms +step:727/1680 train_time:63533ms step_avg:87.39ms +step:728/1680 train_time:63622ms step_avg:87.39ms +step:729/1680 train_time:63710ms step_avg:87.39ms +step:730/1680 train_time:63798ms step_avg:87.39ms +step:731/1680 train_time:63886ms step_avg:87.40ms +step:732/1680 train_time:63974ms step_avg:87.40ms +step:733/1680 train_time:64063ms step_avg:87.40ms +step:734/1680 train_time:64150ms step_avg:87.40ms +step:735/1680 train_time:64239ms step_avg:87.40ms +step:736/1680 train_time:64327ms step_avg:87.40ms +step:737/1680 train_time:64415ms step_avg:87.40ms +step:738/1680 train_time:64504ms step_avg:87.40ms +step:739/1680 train_time:64592ms 
step_avg:87.41ms +step:740/1680 train_time:64680ms step_avg:87.41ms +step:741/1680 train_time:64768ms step_avg:87.41ms +step:742/1680 train_time:64856ms step_avg:87.41ms +step:743/1680 train_time:64945ms step_avg:87.41ms +step:744/1680 train_time:65033ms step_avg:87.41ms +step:745/1680 train_time:65121ms step_avg:87.41ms +step:746/1680 train_time:65209ms step_avg:87.41ms +step:747/1680 train_time:65297ms step_avg:87.41ms +step:748/1680 train_time:65385ms step_avg:87.41ms +step:749/1680 train_time:65473ms step_avg:87.41ms +step:750/1680 train_time:65561ms step_avg:87.41ms +step:750/1680 val_loss:3.5669 train_time:65652ms step_avg:87.54ms +step:751/1680 train_time:65670ms step_avg:87.44ms +step:752/1680 train_time:65742ms step_avg:87.42ms +step:753/1680 train_time:65834ms step_avg:87.43ms +step:754/1680 train_time:65924ms step_avg:87.43ms +step:755/1680 train_time:66012ms step_avg:87.43ms +step:756/1680 train_time:66099ms step_avg:87.43ms +step:757/1680 train_time:66186ms step_avg:87.43ms +step:758/1680 train_time:66274ms step_avg:87.43ms +step:759/1680 train_time:66361ms step_avg:87.43ms +step:760/1680 train_time:66450ms step_avg:87.43ms +step:761/1680 train_time:66537ms step_avg:87.43ms +step:762/1680 train_time:66626ms step_avg:87.44ms +step:763/1680 train_time:66716ms step_avg:87.44ms +step:764/1680 train_time:66806ms step_avg:87.44ms +step:765/1680 train_time:66895ms step_avg:87.44ms +step:766/1680 train_time:66983ms step_avg:87.45ms +step:767/1680 train_time:67071ms step_avg:87.45ms +step:768/1680 train_time:67159ms step_avg:87.45ms +step:769/1680 train_time:67246ms step_avg:87.45ms +step:770/1680 train_time:67333ms step_avg:87.45ms +step:771/1680 train_time:67421ms step_avg:87.45ms +step:772/1680 train_time:67509ms step_avg:87.45ms +step:773/1680 train_time:67597ms step_avg:87.45ms +step:774/1680 train_time:67685ms step_avg:87.45ms +step:775/1680 train_time:67774ms step_avg:87.45ms +step:776/1680 train_time:67863ms step_avg:87.45ms +step:777/1680 train_time:67951ms step_avg:87.45ms +step:778/1680 train_time:68039ms step_avg:87.45ms +step:779/1680 train_time:68127ms step_avg:87.46ms +step:780/1680 train_time:68215ms step_avg:87.45ms +step:781/1680 train_time:68302ms step_avg:87.46ms +step:782/1680 train_time:68390ms step_avg:87.46ms +step:783/1680 train_time:68479ms step_avg:87.46ms +step:784/1680 train_time:68567ms step_avg:87.46ms +step:785/1680 train_time:68657ms step_avg:87.46ms +step:786/1680 train_time:68745ms step_avg:87.46ms +step:787/1680 train_time:68834ms step_avg:87.46ms +step:788/1680 train_time:68922ms step_avg:87.46ms +step:789/1680 train_time:69010ms step_avg:87.47ms +step:790/1680 train_time:69098ms step_avg:87.47ms +step:791/1680 train_time:69186ms step_avg:87.47ms +step:792/1680 train_time:69275ms step_avg:87.47ms +step:793/1680 train_time:69362ms step_avg:87.47ms +step:794/1680 train_time:69450ms step_avg:87.47ms +step:795/1680 train_time:69538ms step_avg:87.47ms +step:796/1680 train_time:69626ms step_avg:87.47ms +step:797/1680 train_time:69715ms step_avg:87.47ms +step:798/1680 train_time:69803ms step_avg:87.47ms +step:799/1680 train_time:69892ms step_avg:87.47ms +step:800/1680 train_time:69979ms step_avg:87.47ms +step:801/1680 train_time:70067ms step_avg:87.47ms +step:802/1680 train_time:70156ms step_avg:87.48ms +step:803/1680 train_time:70244ms step_avg:87.48ms +step:804/1680 train_time:70331ms step_avg:87.48ms +step:805/1680 train_time:70419ms step_avg:87.48ms +step:806/1680 train_time:70507ms step_avg:87.48ms +step:807/1680 train_time:70595ms step_avg:87.48ms 
+step:808/1680 train_time:70683ms step_avg:87.48ms +step:809/1680 train_time:70772ms step_avg:87.48ms +step:810/1680 train_time:70862ms step_avg:87.48ms +step:811/1680 train_time:70950ms step_avg:87.48ms +step:812/1680 train_time:71038ms step_avg:87.49ms +step:813/1680 train_time:71127ms step_avg:87.49ms +step:814/1680 train_time:71215ms step_avg:87.49ms +step:815/1680 train_time:71303ms step_avg:87.49ms +step:816/1680 train_time:71391ms step_avg:87.49ms +step:817/1680 train_time:71478ms step_avg:87.49ms +step:818/1680 train_time:71567ms step_avg:87.49ms +step:819/1680 train_time:71655ms step_avg:87.49ms +step:820/1680 train_time:71743ms step_avg:87.49ms +step:821/1680 train_time:71832ms step_avg:87.49ms +step:822/1680 train_time:71920ms step_avg:87.49ms +step:823/1680 train_time:72009ms step_avg:87.50ms +step:824/1680 train_time:72097ms step_avg:87.50ms +step:825/1680 train_time:72184ms step_avg:87.50ms +step:826/1680 train_time:72273ms step_avg:87.50ms +step:827/1680 train_time:72361ms step_avg:87.50ms +step:828/1680 train_time:72449ms step_avg:87.50ms +step:829/1680 train_time:72537ms step_avg:87.50ms +step:830/1680 train_time:72625ms step_avg:87.50ms +step:831/1680 train_time:72713ms step_avg:87.50ms +step:832/1680 train_time:72801ms step_avg:87.50ms +step:833/1680 train_time:72889ms step_avg:87.50ms +step:834/1680 train_time:72978ms step_avg:87.50ms +step:835/1680 train_time:73066ms step_avg:87.50ms +step:836/1680 train_time:73154ms step_avg:87.50ms +step:837/1680 train_time:73243ms step_avg:87.51ms +step:838/1680 train_time:73331ms step_avg:87.51ms +step:839/1680 train_time:73419ms step_avg:87.51ms +step:840/1680 train_time:73507ms step_avg:87.51ms +step:841/1680 train_time:73595ms step_avg:87.51ms +step:842/1680 train_time:73683ms step_avg:87.51ms +step:843/1680 train_time:73771ms step_avg:87.51ms +step:844/1680 train_time:73859ms step_avg:87.51ms +step:845/1680 train_time:73947ms step_avg:87.51ms +step:846/1680 train_time:74036ms step_avg:87.51ms +step:847/1680 train_time:74124ms step_avg:87.51ms +step:848/1680 train_time:74213ms step_avg:87.52ms +step:849/1680 train_time:74300ms step_avg:87.52ms +step:850/1680 train_time:74388ms step_avg:87.52ms +step:851/1680 train_time:74477ms step_avg:87.52ms +step:852/1680 train_time:74565ms step_avg:87.52ms +step:853/1680 train_time:74654ms step_avg:87.52ms +step:854/1680 train_time:74741ms step_avg:87.52ms +step:855/1680 train_time:74830ms step_avg:87.52ms +step:856/1680 train_time:74918ms step_avg:87.52ms +step:857/1680 train_time:75006ms step_avg:87.52ms +step:858/1680 train_time:75095ms step_avg:87.52ms +step:859/1680 train_time:75183ms step_avg:87.52ms +step:860/1680 train_time:75271ms step_avg:87.52ms +step:861/1680 train_time:75359ms step_avg:87.53ms +step:862/1680 train_time:75448ms step_avg:87.53ms +step:863/1680 train_time:75536ms step_avg:87.53ms +step:864/1680 train_time:75624ms step_avg:87.53ms +step:865/1680 train_time:75712ms step_avg:87.53ms +step:866/1680 train_time:75801ms step_avg:87.53ms +step:867/1680 train_time:75889ms step_avg:87.53ms +step:868/1680 train_time:75977ms step_avg:87.53ms +step:869/1680 train_time:76065ms step_avg:87.53ms +step:870/1680 train_time:76153ms step_avg:87.53ms +step:871/1680 train_time:76242ms step_avg:87.53ms +step:872/1680 train_time:76331ms step_avg:87.54ms +step:873/1680 train_time:76419ms step_avg:87.54ms +step:874/1680 train_time:76506ms step_avg:87.54ms +step:875/1680 train_time:76594ms step_avg:87.54ms +step:875/1680 val_loss:3.5207 train_time:76683ms step_avg:87.64ms +step:876/1680 
train_time:76703ms step_avg:87.56ms +step:877/1680 train_time:76774ms step_avg:87.54ms +step:878/1680 train_time:76867ms step_avg:87.55ms +step:879/1680 train_time:76955ms step_avg:87.55ms +step:880/1680 train_time:77042ms step_avg:87.55ms +step:881/1680 train_time:77129ms step_avg:87.55ms +step:882/1680 train_time:77216ms step_avg:87.55ms +step:883/1680 train_time:77303ms step_avg:87.55ms +step:884/1680 train_time:77390ms step_avg:87.55ms +step:885/1680 train_time:77479ms step_avg:87.55ms +step:886/1680 train_time:77567ms step_avg:87.55ms +step:887/1680 train_time:77657ms step_avg:87.55ms +step:888/1680 train_time:77746ms step_avg:87.55ms +step:889/1680 train_time:77836ms step_avg:87.55ms +step:890/1680 train_time:77924ms step_avg:87.56ms +step:891/1680 train_time:78012ms step_avg:87.56ms +step:892/1680 train_time:78100ms step_avg:87.56ms +step:893/1680 train_time:78188ms step_avg:87.56ms +step:894/1680 train_time:78276ms step_avg:87.56ms +step:895/1680 train_time:78363ms step_avg:87.56ms +step:896/1680 train_time:78451ms step_avg:87.56ms +step:897/1680 train_time:78538ms step_avg:87.56ms +step:898/1680 train_time:78627ms step_avg:87.56ms +step:899/1680 train_time:78716ms step_avg:87.56ms +step:900/1680 train_time:78805ms step_avg:87.56ms +step:901/1680 train_time:78893ms step_avg:87.56ms +step:902/1680 train_time:78982ms step_avg:87.56ms +step:903/1680 train_time:79070ms step_avg:87.56ms +step:904/1680 train_time:79158ms step_avg:87.56ms +step:905/1680 train_time:79246ms step_avg:87.56ms +step:906/1680 train_time:79334ms step_avg:87.56ms +step:907/1680 train_time:79421ms step_avg:87.56ms +step:908/1680 train_time:79510ms step_avg:87.57ms +step:909/1680 train_time:79598ms step_avg:87.57ms +step:910/1680 train_time:79688ms step_avg:87.57ms +step:911/1680 train_time:79776ms step_avg:87.57ms +step:912/1680 train_time:79864ms step_avg:87.57ms +step:913/1680 train_time:79953ms step_avg:87.57ms +step:914/1680 train_time:80042ms step_avg:87.57ms +step:915/1680 train_time:80130ms step_avg:87.57ms +step:916/1680 train_time:80218ms step_avg:87.57ms +step:917/1680 train_time:80306ms step_avg:87.58ms +step:918/1680 train_time:80394ms step_avg:87.58ms +step:919/1680 train_time:80483ms step_avg:87.58ms +step:920/1680 train_time:80572ms step_avg:87.58ms +step:921/1680 train_time:80660ms step_avg:87.58ms +step:922/1680 train_time:80749ms step_avg:87.58ms +step:923/1680 train_time:80837ms step_avg:87.58ms +step:924/1680 train_time:80925ms step_avg:87.58ms +step:925/1680 train_time:81013ms step_avg:87.58ms +step:926/1680 train_time:81102ms step_avg:87.58ms +step:927/1680 train_time:81190ms step_avg:87.58ms +step:928/1680 train_time:81278ms step_avg:87.58ms +step:929/1680 train_time:81366ms step_avg:87.58ms +step:930/1680 train_time:81454ms step_avg:87.59ms +step:931/1680 train_time:81543ms step_avg:87.59ms +step:932/1680 train_time:81632ms step_avg:87.59ms +step:933/1680 train_time:81721ms step_avg:87.59ms +step:934/1680 train_time:81809ms step_avg:87.59ms +step:935/1680 train_time:81898ms step_avg:87.59ms +step:936/1680 train_time:81987ms step_avg:87.59ms +step:937/1680 train_time:82075ms step_avg:87.59ms +step:938/1680 train_time:82163ms step_avg:87.59ms +step:939/1680 train_time:82252ms step_avg:87.59ms +step:940/1680 train_time:82341ms step_avg:87.60ms +step:941/1680 train_time:82428ms step_avg:87.60ms +step:942/1680 train_time:82517ms step_avg:87.60ms +step:943/1680 train_time:82605ms step_avg:87.60ms +step:944/1680 train_time:82693ms step_avg:87.60ms +step:945/1680 train_time:82782ms step_avg:87.60ms 
+step:946/1680 train_time:82870ms step_avg:87.60ms +step:947/1680 train_time:82959ms step_avg:87.60ms +step:948/1680 train_time:83047ms step_avg:87.60ms +step:949/1680 train_time:83136ms step_avg:87.60ms +step:950/1680 train_time:83224ms step_avg:87.60ms +step:951/1680 train_time:83312ms step_avg:87.60ms +step:952/1680 train_time:83401ms step_avg:87.61ms +step:953/1680 train_time:83488ms step_avg:87.61ms +step:954/1680 train_time:83577ms step_avg:87.61ms +step:955/1680 train_time:83665ms step_avg:87.61ms +step:956/1680 train_time:83753ms step_avg:87.61ms +step:957/1680 train_time:83842ms step_avg:87.61ms +step:958/1680 train_time:83930ms step_avg:87.61ms +step:959/1680 train_time:84019ms step_avg:87.61ms +step:960/1680 train_time:84108ms step_avg:87.61ms +step:961/1680 train_time:84195ms step_avg:87.61ms +step:962/1680 train_time:84284ms step_avg:87.61ms +step:963/1680 train_time:84372ms step_avg:87.61ms +step:964/1680 train_time:84460ms step_avg:87.61ms +step:965/1680 train_time:84548ms step_avg:87.61ms +step:966/1680 train_time:84636ms step_avg:87.61ms +step:967/1680 train_time:84725ms step_avg:87.62ms +step:968/1680 train_time:84813ms step_avg:87.62ms +step:969/1680 train_time:84902ms step_avg:87.62ms +step:970/1680 train_time:84990ms step_avg:87.62ms +step:971/1680 train_time:85078ms step_avg:87.62ms +step:972/1680 train_time:85166ms step_avg:87.62ms +step:973/1680 train_time:85255ms step_avg:87.62ms +step:974/1680 train_time:85342ms step_avg:87.62ms +step:975/1680 train_time:85430ms step_avg:87.62ms +step:976/1680 train_time:85518ms step_avg:87.62ms +step:977/1680 train_time:85606ms step_avg:87.62ms +step:978/1680 train_time:85694ms step_avg:87.62ms +step:979/1680 train_time:85783ms step_avg:87.62ms +step:980/1680 train_time:85872ms step_avg:87.62ms +step:981/1680 train_time:85960ms step_avg:87.63ms +step:982/1680 train_time:86048ms step_avg:87.63ms +step:983/1680 train_time:86137ms step_avg:87.63ms +step:984/1680 train_time:86224ms step_avg:87.63ms +step:985/1680 train_time:86313ms step_avg:87.63ms +step:986/1680 train_time:86401ms step_avg:87.63ms +step:987/1680 train_time:86489ms step_avg:87.63ms +step:988/1680 train_time:86578ms step_avg:87.63ms +step:989/1680 train_time:86666ms step_avg:87.63ms +step:990/1680 train_time:86754ms step_avg:87.63ms +step:991/1680 train_time:86843ms step_avg:87.63ms +step:992/1680 train_time:86932ms step_avg:87.63ms +step:993/1680 train_time:87020ms step_avg:87.63ms +step:994/1680 train_time:87108ms step_avg:87.63ms +step:995/1680 train_time:87197ms step_avg:87.63ms +step:996/1680 train_time:87285ms step_avg:87.64ms +step:997/1680 train_time:87373ms step_avg:87.64ms +step:998/1680 train_time:87461ms step_avg:87.64ms +step:999/1680 train_time:87550ms step_avg:87.64ms +step:1000/1680 train_time:87638ms step_avg:87.64ms +step:1000/1680 val_loss:3.4703 train_time:87728ms step_avg:87.73ms +step:1001/1680 train_time:87747ms step_avg:87.66ms +step:1002/1680 train_time:87817ms step_avg:87.64ms +step:1003/1680 train_time:87910ms step_avg:87.65ms +step:1004/1680 train_time:87999ms step_avg:87.65ms +step:1005/1680 train_time:88088ms step_avg:87.65ms +step:1006/1680 train_time:88176ms step_avg:87.65ms +step:1007/1680 train_time:88263ms step_avg:87.65ms +step:1008/1680 train_time:88351ms step_avg:87.65ms +step:1009/1680 train_time:88438ms step_avg:87.65ms +step:1010/1680 train_time:88527ms step_avg:87.65ms +step:1011/1680 train_time:88614ms step_avg:87.65ms +step:1012/1680 train_time:88704ms step_avg:87.65ms +step:1013/1680 train_time:88794ms step_avg:87.65ms 
+step:1014/1680 train_time:88884ms step_avg:87.66ms +step:1015/1680 train_time:88973ms step_avg:87.66ms +step:1016/1680 train_time:89062ms step_avg:87.66ms +step:1017/1680 train_time:89150ms step_avg:87.66ms +step:1018/1680 train_time:89239ms step_avg:87.66ms +step:1019/1680 train_time:89327ms step_avg:87.66ms +step:1020/1680 train_time:89414ms step_avg:87.66ms +step:1021/1680 train_time:89502ms step_avg:87.66ms +step:1022/1680 train_time:89590ms step_avg:87.66ms +step:1023/1680 train_time:89677ms step_avg:87.66ms +step:1024/1680 train_time:89767ms step_avg:87.66ms +step:1025/1680 train_time:89856ms step_avg:87.66ms +step:1026/1680 train_time:89945ms step_avg:87.67ms +step:1027/1680 train_time:90034ms step_avg:87.67ms +step:1028/1680 train_time:90122ms step_avg:87.67ms +step:1029/1680 train_time:90210ms step_avg:87.67ms +step:1030/1680 train_time:90298ms step_avg:87.67ms +step:1031/1680 train_time:90385ms step_avg:87.67ms +step:1032/1680 train_time:90473ms step_avg:87.67ms +step:1033/1680 train_time:90561ms step_avg:87.67ms +step:1034/1680 train_time:90649ms step_avg:87.67ms +step:1035/1680 train_time:90738ms step_avg:87.67ms +step:1036/1680 train_time:90827ms step_avg:87.67ms +step:1037/1680 train_time:90915ms step_avg:87.67ms +step:1038/1680 train_time:91004ms step_avg:87.67ms +step:1039/1680 train_time:91092ms step_avg:87.67ms +step:1040/1680 train_time:91180ms step_avg:87.67ms +step:1041/1680 train_time:91268ms step_avg:87.67ms +step:1042/1680 train_time:91356ms step_avg:87.67ms +step:1043/1680 train_time:91444ms step_avg:87.67ms +step:1044/1680 train_time:91532ms step_avg:87.67ms +step:1045/1680 train_time:91621ms step_avg:87.68ms +step:1046/1680 train_time:91709ms step_avg:87.68ms +step:1047/1680 train_time:91797ms step_avg:87.68ms +step:1048/1680 train_time:91886ms step_avg:87.68ms +step:1049/1680 train_time:91974ms step_avg:87.68ms +step:1050/1680 train_time:92063ms step_avg:87.68ms +step:1051/1680 train_time:92152ms step_avg:87.68ms +step:1052/1680 train_time:92239ms step_avg:87.68ms +step:1053/1680 train_time:92327ms step_avg:87.68ms +step:1054/1680 train_time:92416ms step_avg:87.68ms +step:1055/1680 train_time:92505ms step_avg:87.68ms +step:1056/1680 train_time:92593ms step_avg:87.68ms +step:1057/1680 train_time:92681ms step_avg:87.68ms +step:1058/1680 train_time:92770ms step_avg:87.68ms +step:1059/1680 train_time:92858ms step_avg:87.68ms +step:1060/1680 train_time:92948ms step_avg:87.69ms +step:1061/1680 train_time:93035ms step_avg:87.69ms +step:1062/1680 train_time:93124ms step_avg:87.69ms +step:1063/1680 train_time:93212ms step_avg:87.69ms +step:1064/1680 train_time:93300ms step_avg:87.69ms +step:1065/1680 train_time:93389ms step_avg:87.69ms +step:1066/1680 train_time:93477ms step_avg:87.69ms +step:1067/1680 train_time:93565ms step_avg:87.69ms +step:1068/1680 train_time:93653ms step_avg:87.69ms +step:1069/1680 train_time:93741ms step_avg:87.69ms +step:1070/1680 train_time:93830ms step_avg:87.69ms +step:1071/1680 train_time:93919ms step_avg:87.69ms +step:1072/1680 train_time:94007ms step_avg:87.69ms +step:1073/1680 train_time:94095ms step_avg:87.69ms +step:1074/1680 train_time:94184ms step_avg:87.69ms +step:1075/1680 train_time:94272ms step_avg:87.69ms +step:1076/1680 train_time:94360ms step_avg:87.69ms +step:1077/1680 train_time:94448ms step_avg:87.70ms +step:1078/1680 train_time:94536ms step_avg:87.70ms +step:1079/1680 train_time:94624ms step_avg:87.70ms +step:1080/1680 train_time:94714ms step_avg:87.70ms +step:1081/1680 train_time:94802ms step_avg:87.70ms +step:1082/1680 
train_time:94890ms step_avg:87.70ms +step:1083/1680 train_time:94978ms step_avg:87.70ms +step:1084/1680 train_time:95067ms step_avg:87.70ms +step:1085/1680 train_time:95155ms step_avg:87.70ms +step:1086/1680 train_time:95243ms step_avg:87.70ms +step:1087/1680 train_time:95331ms step_avg:87.70ms +step:1088/1680 train_time:95419ms step_avg:87.70ms +step:1089/1680 train_time:95507ms step_avg:87.70ms +step:1090/1680 train_time:95595ms step_avg:87.70ms +step:1091/1680 train_time:95683ms step_avg:87.70ms +step:1092/1680 train_time:95772ms step_avg:87.70ms +step:1093/1680 train_time:95860ms step_avg:87.70ms +step:1094/1680 train_time:95949ms step_avg:87.70ms +step:1095/1680 train_time:96037ms step_avg:87.70ms +step:1096/1680 train_time:96126ms step_avg:87.71ms +step:1097/1680 train_time:96214ms step_avg:87.71ms +step:1098/1680 train_time:96303ms step_avg:87.71ms +step:1099/1680 train_time:96392ms step_avg:87.71ms +step:1100/1680 train_time:96481ms step_avg:87.71ms +step:1101/1680 train_time:96569ms step_avg:87.71ms +step:1102/1680 train_time:96659ms step_avg:87.71ms +step:1103/1680 train_time:96748ms step_avg:87.71ms +step:1104/1680 train_time:96837ms step_avg:87.71ms +step:1105/1680 train_time:96927ms step_avg:87.72ms +step:1106/1680 train_time:97017ms step_avg:87.72ms +step:1107/1680 train_time:97106ms step_avg:87.72ms +step:1108/1680 train_time:97195ms step_avg:87.72ms +step:1109/1680 train_time:97284ms step_avg:87.72ms +step:1110/1680 train_time:97373ms step_avg:87.72ms +step:1111/1680 train_time:97461ms step_avg:87.72ms +step:1112/1680 train_time:97550ms step_avg:87.72ms +step:1113/1680 train_time:97639ms step_avg:87.73ms +step:1114/1680 train_time:97728ms step_avg:87.73ms +step:1115/1680 train_time:97817ms step_avg:87.73ms +step:1116/1680 train_time:97907ms step_avg:87.73ms +step:1117/1680 train_time:97996ms step_avg:87.73ms +step:1118/1680 train_time:98086ms step_avg:87.73ms +step:1119/1680 train_time:98174ms step_avg:87.73ms +step:1120/1680 train_time:98263ms step_avg:87.74ms +step:1121/1680 train_time:98352ms step_avg:87.74ms +step:1122/1680 train_time:98441ms step_avg:87.74ms +step:1123/1680 train_time:98530ms step_avg:87.74ms +step:1124/1680 train_time:98618ms step_avg:87.74ms +step:1125/1680 train_time:98707ms step_avg:87.74ms +step:1125/1680 val_loss:3.4176 train_time:98798ms step_avg:87.82ms +step:1126/1680 train_time:98817ms step_avg:87.76ms +step:1127/1680 train_time:98888ms step_avg:87.74ms +step:1128/1680 train_time:98981ms step_avg:87.75ms +step:1129/1680 train_time:99073ms step_avg:87.75ms +step:1130/1680 train_time:99161ms step_avg:87.75ms +step:1131/1680 train_time:99249ms step_avg:87.75ms +step:1132/1680 train_time:99336ms step_avg:87.75ms +step:1133/1680 train_time:99424ms step_avg:87.75ms +step:1134/1680 train_time:99512ms step_avg:87.75ms +step:1135/1680 train_time:99600ms step_avg:87.75ms +step:1136/1680 train_time:99689ms step_avg:87.75ms +step:1137/1680 train_time:99779ms step_avg:87.76ms +step:1138/1680 train_time:99869ms step_avg:87.76ms +step:1139/1680 train_time:99961ms step_avg:87.76ms +step:1140/1680 train_time:100051ms step_avg:87.76ms +step:1141/1680 train_time:100140ms step_avg:87.77ms +step:1142/1680 train_time:100229ms step_avg:87.77ms +step:1143/1680 train_time:100318ms step_avg:87.77ms +step:1144/1680 train_time:100406ms step_avg:87.77ms +step:1145/1680 train_time:100494ms step_avg:87.77ms +step:1146/1680 train_time:100583ms step_avg:87.77ms +step:1147/1680 train_time:100671ms step_avg:87.77ms +step:1148/1680 train_time:100759ms step_avg:87.77ms 
+step:1149/1680 train_time:100849ms step_avg:87.77ms +step:1150/1680 train_time:100939ms step_avg:87.77ms +step:1151/1680 train_time:101029ms step_avg:87.78ms +step:1152/1680 train_time:101119ms step_avg:87.78ms +step:1153/1680 train_time:101209ms step_avg:87.78ms +step:1154/1680 train_time:101297ms step_avg:87.78ms +step:1155/1680 train_time:101385ms step_avg:87.78ms +step:1156/1680 train_time:101474ms step_avg:87.78ms +step:1157/1680 train_time:101562ms step_avg:87.78ms +step:1158/1680 train_time:101651ms step_avg:87.78ms +step:1159/1680 train_time:101739ms step_avg:87.78ms +step:1160/1680 train_time:101828ms step_avg:87.78ms +step:1161/1680 train_time:101919ms step_avg:87.79ms +step:1162/1680 train_time:102008ms step_avg:87.79ms +step:1163/1680 train_time:102098ms step_avg:87.79ms +step:1164/1680 train_time:102187ms step_avg:87.79ms +step:1165/1680 train_time:102276ms step_avg:87.79ms +step:1166/1680 train_time:102365ms step_avg:87.79ms +step:1167/1680 train_time:102453ms step_avg:87.79ms +step:1168/1680 train_time:102542ms step_avg:87.79ms +step:1169/1680 train_time:102631ms step_avg:87.79ms +step:1170/1680 train_time:102720ms step_avg:87.80ms +step:1171/1680 train_time:102810ms step_avg:87.80ms +step:1172/1680 train_time:102899ms step_avg:87.80ms +step:1173/1680 train_time:102989ms step_avg:87.80ms +step:1174/1680 train_time:103079ms step_avg:87.80ms +step:1175/1680 train_time:103168ms step_avg:87.80ms +step:1176/1680 train_time:103257ms step_avg:87.80ms +step:1177/1680 train_time:103346ms step_avg:87.80ms +step:1178/1680 train_time:103435ms step_avg:87.81ms +step:1179/1680 train_time:103524ms step_avg:87.81ms +step:1180/1680 train_time:103612ms step_avg:87.81ms +step:1181/1680 train_time:103701ms step_avg:87.81ms +step:1182/1680 train_time:103790ms step_avg:87.81ms +step:1183/1680 train_time:103880ms step_avg:87.81ms +step:1184/1680 train_time:103968ms step_avg:87.81ms +step:1185/1680 train_time:104058ms step_avg:87.81ms +step:1186/1680 train_time:104147ms step_avg:87.81ms +step:1187/1680 train_time:104236ms step_avg:87.81ms +step:1188/1680 train_time:104324ms step_avg:87.82ms +step:1189/1680 train_time:104413ms step_avg:87.82ms +step:1190/1680 train_time:104502ms step_avg:87.82ms +step:1191/1680 train_time:104592ms step_avg:87.82ms +step:1192/1680 train_time:104680ms step_avg:87.82ms +step:1193/1680 train_time:104769ms step_avg:87.82ms +step:1194/1680 train_time:104858ms step_avg:87.82ms +step:1195/1680 train_time:104947ms step_avg:87.82ms +step:1196/1680 train_time:105036ms step_avg:87.82ms +step:1197/1680 train_time:105125ms step_avg:87.82ms +step:1198/1680 train_time:105214ms step_avg:87.82ms +step:1199/1680 train_time:105302ms step_avg:87.83ms +step:1200/1680 train_time:105391ms step_avg:87.83ms +step:1201/1680 train_time:105481ms step_avg:87.83ms +step:1202/1680 train_time:105570ms step_avg:87.83ms +step:1203/1680 train_time:105658ms step_avg:87.83ms +step:1204/1680 train_time:105747ms step_avg:87.83ms +step:1205/1680 train_time:105835ms step_avg:87.83ms +step:1206/1680 train_time:105924ms step_avg:87.83ms +step:1207/1680 train_time:106013ms step_avg:87.83ms +step:1208/1680 train_time:106102ms step_avg:87.83ms +step:1209/1680 train_time:106191ms step_avg:87.83ms +step:1210/1680 train_time:106279ms step_avg:87.83ms +step:1211/1680 train_time:106368ms step_avg:87.84ms +step:1212/1680 train_time:106458ms step_avg:87.84ms +step:1213/1680 train_time:106546ms step_avg:87.84ms +step:1214/1680 train_time:106635ms step_avg:87.84ms +step:1215/1680 train_time:106724ms step_avg:87.84ms 
+step:1216/1680 train_time:106813ms step_avg:87.84ms +step:1217/1680 train_time:106901ms step_avg:87.84ms +step:1218/1680 train_time:106990ms step_avg:87.84ms +step:1219/1680 train_time:107080ms step_avg:87.84ms +step:1220/1680 train_time:107169ms step_avg:87.84ms +step:1221/1680 train_time:107258ms step_avg:87.84ms +step:1222/1680 train_time:107348ms step_avg:87.85ms +step:1223/1680 train_time:107436ms step_avg:87.85ms +step:1224/1680 train_time:107525ms step_avg:87.85ms +step:1225/1680 train_time:107614ms step_avg:87.85ms +step:1226/1680 train_time:107703ms step_avg:87.85ms +step:1227/1680 train_time:107792ms step_avg:87.85ms +step:1228/1680 train_time:107882ms step_avg:87.85ms +step:1229/1680 train_time:107970ms step_avg:87.85ms +step:1230/1680 train_time:108059ms step_avg:87.85ms +step:1231/1680 train_time:108148ms step_avg:87.85ms +step:1232/1680 train_time:108236ms step_avg:87.85ms +step:1233/1680 train_time:108325ms step_avg:87.85ms +step:1234/1680 train_time:108415ms step_avg:87.86ms +step:1235/1680 train_time:108504ms step_avg:87.86ms +step:1236/1680 train_time:108593ms step_avg:87.86ms +step:1237/1680 train_time:108682ms step_avg:87.86ms +step:1238/1680 train_time:108770ms step_avg:87.86ms +step:1239/1680 train_time:108859ms step_avg:87.86ms +step:1240/1680 train_time:108948ms step_avg:87.86ms +step:1241/1680 train_time:109037ms step_avg:87.86ms +step:1242/1680 train_time:109126ms step_avg:87.86ms +step:1243/1680 train_time:109215ms step_avg:87.86ms +step:1244/1680 train_time:109303ms step_avg:87.86ms +step:1245/1680 train_time:109393ms step_avg:87.87ms +step:1246/1680 train_time:109483ms step_avg:87.87ms +step:1247/1680 train_time:109572ms step_avg:87.87ms +step:1248/1680 train_time:109661ms step_avg:87.87ms +step:1249/1680 train_time:109749ms step_avg:87.87ms +step:1250/1680 train_time:109838ms step_avg:87.87ms +step:1250/1680 val_loss:3.3791 train_time:109929ms step_avg:87.94ms +step:1251/1680 train_time:109948ms step_avg:87.89ms +step:1252/1680 train_time:110019ms step_avg:87.87ms +step:1253/1680 train_time:110111ms step_avg:87.88ms +step:1254/1680 train_time:110200ms step_avg:87.88ms +step:1255/1680 train_time:110288ms step_avg:87.88ms +step:1256/1680 train_time:110377ms step_avg:87.88ms +step:1257/1680 train_time:110466ms step_avg:87.88ms +step:1258/1680 train_time:110554ms step_avg:87.88ms +step:1259/1680 train_time:110643ms step_avg:87.88ms +step:1260/1680 train_time:110732ms step_avg:87.88ms +step:1261/1680 train_time:110820ms step_avg:87.88ms +step:1262/1680 train_time:110911ms step_avg:87.89ms +step:1263/1680 train_time:111003ms step_avg:87.89ms +step:1264/1680 train_time:111094ms step_avg:87.89ms +step:1265/1680 train_time:111183ms step_avg:87.89ms +step:1266/1680 train_time:111272ms step_avg:87.89ms +step:1267/1680 train_time:111360ms step_avg:87.89ms +step:1268/1680 train_time:111449ms step_avg:87.89ms +step:1269/1680 train_time:111538ms step_avg:87.89ms +step:1270/1680 train_time:111627ms step_avg:87.89ms +step:1271/1680 train_time:111716ms step_avg:87.90ms +step:1272/1680 train_time:111804ms step_avg:87.90ms +step:1273/1680 train_time:111894ms step_avg:87.90ms +step:1274/1680 train_time:111984ms step_avg:87.90ms +step:1275/1680 train_time:112074ms step_avg:87.90ms +step:1276/1680 train_time:112165ms step_avg:87.90ms +step:1277/1680 train_time:112254ms step_avg:87.90ms +step:1278/1680 train_time:112343ms step_avg:87.91ms +step:1279/1680 train_time:112432ms step_avg:87.91ms +step:1280/1680 train_time:112521ms step_avg:87.91ms +step:1281/1680 train_time:112609ms 
step_avg:87.91ms +step:1282/1680 train_time:112697ms step_avg:87.91ms +step:1283/1680 train_time:112787ms step_avg:87.91ms +step:1284/1680 train_time:112876ms step_avg:87.91ms +step:1285/1680 train_time:112966ms step_avg:87.91ms +step:1286/1680 train_time:113056ms step_avg:87.91ms +step:1287/1680 train_time:113146ms step_avg:87.91ms +step:1288/1680 train_time:113236ms step_avg:87.92ms +step:1289/1680 train_time:113325ms step_avg:87.92ms +step:1290/1680 train_time:113414ms step_avg:87.92ms +step:1291/1680 train_time:113503ms step_avg:87.92ms +step:1292/1680 train_time:113592ms step_avg:87.92ms +step:1293/1680 train_time:113681ms step_avg:87.92ms +step:1294/1680 train_time:113770ms step_avg:87.92ms +step:1295/1680 train_time:113859ms step_avg:87.92ms +step:1296/1680 train_time:113948ms step_avg:87.92ms +step:1297/1680 train_time:114037ms step_avg:87.92ms +step:1298/1680 train_time:114127ms step_avg:87.93ms +step:1299/1680 train_time:114216ms step_avg:87.93ms +step:1300/1680 train_time:114306ms step_avg:87.93ms +step:1301/1680 train_time:114394ms step_avg:87.93ms +step:1302/1680 train_time:114484ms step_avg:87.93ms +step:1303/1680 train_time:114573ms step_avg:87.93ms +step:1304/1680 train_time:114662ms step_avg:87.93ms +step:1305/1680 train_time:114751ms step_avg:87.93ms +step:1306/1680 train_time:114842ms step_avg:87.93ms +step:1307/1680 train_time:114931ms step_avg:87.93ms +step:1308/1680 train_time:115020ms step_avg:87.94ms +step:1309/1680 train_time:115109ms step_avg:87.94ms +step:1310/1680 train_time:115199ms step_avg:87.94ms +step:1311/1680 train_time:115289ms step_avg:87.94ms +step:1312/1680 train_time:115377ms step_avg:87.94ms +step:1313/1680 train_time:115466ms step_avg:87.94ms +step:1314/1680 train_time:115556ms step_avg:87.94ms +step:1315/1680 train_time:115645ms step_avg:87.94ms +step:1316/1680 train_time:115734ms step_avg:87.94ms +step:1317/1680 train_time:115823ms step_avg:87.94ms +step:1318/1680 train_time:115912ms step_avg:87.95ms +step:1319/1680 train_time:116002ms step_avg:87.95ms +step:1320/1680 train_time:116091ms step_avg:87.95ms +step:1321/1680 train_time:116181ms step_avg:87.95ms +step:1322/1680 train_time:116269ms step_avg:87.95ms +step:1323/1680 train_time:116359ms step_avg:87.95ms +step:1324/1680 train_time:116448ms step_avg:87.95ms +step:1325/1680 train_time:116537ms step_avg:87.95ms +step:1326/1680 train_time:116625ms step_avg:87.95ms +step:1327/1680 train_time:116714ms step_avg:87.95ms +step:1328/1680 train_time:116803ms step_avg:87.95ms +step:1329/1680 train_time:116893ms step_avg:87.96ms +step:1330/1680 train_time:116982ms step_avg:87.96ms +step:1331/1680 train_time:117071ms step_avg:87.96ms +step:1332/1680 train_time:117161ms step_avg:87.96ms +step:1333/1680 train_time:117250ms step_avg:87.96ms +step:1334/1680 train_time:117339ms step_avg:87.96ms +step:1335/1680 train_time:117428ms step_avg:87.96ms +step:1336/1680 train_time:117517ms step_avg:87.96ms +step:1337/1680 train_time:117606ms step_avg:87.96ms +step:1338/1680 train_time:117694ms step_avg:87.96ms +step:1339/1680 train_time:117783ms step_avg:87.96ms +step:1340/1680 train_time:117872ms step_avg:87.96ms +step:1341/1680 train_time:117961ms step_avg:87.96ms +step:1342/1680 train_time:118051ms step_avg:87.97ms +step:1343/1680 train_time:118139ms step_avg:87.97ms +step:1344/1680 train_time:118228ms step_avg:87.97ms +step:1345/1680 train_time:118317ms step_avg:87.97ms +step:1346/1680 train_time:118407ms step_avg:87.97ms +step:1347/1680 train_time:118495ms step_avg:87.97ms +step:1348/1680 train_time:118584ms 
step_avg:87.97ms +step:1349/1680 train_time:118673ms step_avg:87.97ms +step:1350/1680 train_time:118762ms step_avg:87.97ms +step:1351/1680 train_time:118851ms step_avg:87.97ms +step:1352/1680 train_time:118939ms step_avg:87.97ms +step:1353/1680 train_time:119028ms step_avg:87.97ms +step:1354/1680 train_time:119118ms step_avg:87.98ms +step:1355/1680 train_time:119208ms step_avg:87.98ms +step:1356/1680 train_time:119297ms step_avg:87.98ms +step:1357/1680 train_time:119386ms step_avg:87.98ms +step:1358/1680 train_time:119475ms step_avg:87.98ms +step:1359/1680 train_time:119563ms step_avg:87.98ms +step:1360/1680 train_time:119652ms step_avg:87.98ms +step:1361/1680 train_time:119741ms step_avg:87.98ms +step:1362/1680 train_time:119830ms step_avg:87.98ms +step:1363/1680 train_time:119918ms step_avg:87.98ms +step:1364/1680 train_time:120008ms step_avg:87.98ms +step:1365/1680 train_time:120097ms step_avg:87.98ms +step:1366/1680 train_time:120186ms step_avg:87.98ms +step:1367/1680 train_time:120276ms step_avg:87.99ms +step:1368/1680 train_time:120365ms step_avg:87.99ms +step:1369/1680 train_time:120455ms step_avg:87.99ms +step:1370/1680 train_time:120543ms step_avg:87.99ms +step:1371/1680 train_time:120633ms step_avg:87.99ms +step:1372/1680 train_time:120721ms step_avg:87.99ms +step:1373/1680 train_time:120810ms step_avg:87.99ms +step:1374/1680 train_time:120899ms step_avg:87.99ms +step:1375/1680 train_time:120988ms step_avg:87.99ms +step:1375/1680 val_loss:3.3448 train_time:121080ms step_avg:88.06ms +step:1376/1680 train_time:121098ms step_avg:88.01ms +step:1377/1680 train_time:121174ms step_avg:88.00ms +step:1378/1680 train_time:121268ms step_avg:88.00ms +step:1379/1680 train_time:121357ms step_avg:88.00ms +step:1380/1680 train_time:121446ms step_avg:88.00ms +step:1381/1680 train_time:121534ms step_avg:88.00ms +step:1382/1680 train_time:121622ms step_avg:88.00ms +step:1383/1680 train_time:121710ms step_avg:88.00ms +step:1384/1680 train_time:121799ms step_avg:88.00ms +step:1385/1680 train_time:121886ms step_avg:88.00ms +step:1386/1680 train_time:121974ms step_avg:88.00ms +step:1387/1680 train_time:122063ms step_avg:88.01ms +step:1388/1680 train_time:122154ms step_avg:88.01ms +step:1389/1680 train_time:122246ms step_avg:88.01ms +step:1390/1680 train_time:122336ms step_avg:88.01ms +step:1391/1680 train_time:122427ms step_avg:88.01ms +step:1392/1680 train_time:122516ms step_avg:88.01ms +step:1393/1680 train_time:122604ms step_avg:88.01ms +step:1394/1680 train_time:122692ms step_avg:88.01ms +step:1395/1680 train_time:122780ms step_avg:88.01ms +step:1396/1680 train_time:122868ms step_avg:88.01ms +step:1397/1680 train_time:122957ms step_avg:88.01ms +step:1398/1680 train_time:123046ms step_avg:88.02ms +step:1399/1680 train_time:123136ms step_avg:88.02ms +step:1400/1680 train_time:123225ms step_avg:88.02ms +step:1401/1680 train_time:123316ms step_avg:88.02ms +step:1402/1680 train_time:123407ms step_avg:88.02ms +step:1403/1680 train_time:123496ms step_avg:88.02ms +step:1404/1680 train_time:123585ms step_avg:88.02ms +step:1405/1680 train_time:123673ms step_avg:88.02ms +step:1406/1680 train_time:123761ms step_avg:88.02ms +step:1407/1680 train_time:123849ms step_avg:88.02ms +step:1408/1680 train_time:123938ms step_avg:88.02ms +step:1409/1680 train_time:124027ms step_avg:88.02ms +step:1410/1680 train_time:124116ms step_avg:88.03ms +step:1411/1680 train_time:124205ms step_avg:88.03ms +step:1412/1680 train_time:124295ms step_avg:88.03ms +step:1413/1680 train_time:124385ms step_avg:88.03ms +step:1414/1680 
train_time:124475ms step_avg:88.03ms +step:1415/1680 train_time:124564ms step_avg:88.03ms +step:1416/1680 train_time:124652ms step_avg:88.03ms +step:1417/1680 train_time:124740ms step_avg:88.03ms +step:1418/1680 train_time:124829ms step_avg:88.03ms +step:1419/1680 train_time:124918ms step_avg:88.03ms +step:1420/1680 train_time:125006ms step_avg:88.03ms +step:1421/1680 train_time:125096ms step_avg:88.03ms +step:1422/1680 train_time:125185ms step_avg:88.03ms +step:1423/1680 train_time:125275ms step_avg:88.04ms +step:1424/1680 train_time:125364ms step_avg:88.04ms +step:1425/1680 train_time:125453ms step_avg:88.04ms +step:1426/1680 train_time:125541ms step_avg:88.04ms +step:1427/1680 train_time:125630ms step_avg:88.04ms +step:1428/1680 train_time:125719ms step_avg:88.04ms +step:1429/1680 train_time:125808ms step_avg:88.04ms +step:1430/1680 train_time:125897ms step_avg:88.04ms +step:1431/1680 train_time:125985ms step_avg:88.04ms +step:1432/1680 train_time:126074ms step_avg:88.04ms +step:1433/1680 train_time:126163ms step_avg:88.04ms +step:1434/1680 train_time:126253ms step_avg:88.04ms +step:1435/1680 train_time:126342ms step_avg:88.04ms +step:1436/1680 train_time:126431ms step_avg:88.04ms +step:1437/1680 train_time:126520ms step_avg:88.04ms +step:1438/1680 train_time:126609ms step_avg:88.04ms +step:1439/1680 train_time:126698ms step_avg:88.05ms +step:1440/1680 train_time:126787ms step_avg:88.05ms +step:1441/1680 train_time:126876ms step_avg:88.05ms +step:1442/1680 train_time:126965ms step_avg:88.05ms +step:1443/1680 train_time:127054ms step_avg:88.05ms +step:1444/1680 train_time:127143ms step_avg:88.05ms +step:1445/1680 train_time:127232ms step_avg:88.05ms +step:1446/1680 train_time:127321ms step_avg:88.05ms +step:1447/1680 train_time:127411ms step_avg:88.05ms +step:1448/1680 train_time:127500ms step_avg:88.05ms +step:1449/1680 train_time:127589ms step_avg:88.05ms +step:1450/1680 train_time:127678ms step_avg:88.05ms +step:1451/1680 train_time:127767ms step_avg:88.05ms +step:1452/1680 train_time:127856ms step_avg:88.06ms +step:1453/1680 train_time:127945ms step_avg:88.06ms +step:1454/1680 train_time:128034ms step_avg:88.06ms +step:1455/1680 train_time:128123ms step_avg:88.06ms +step:1456/1680 train_time:128213ms step_avg:88.06ms +step:1457/1680 train_time:128302ms step_avg:88.06ms +step:1458/1680 train_time:128392ms step_avg:88.06ms +step:1459/1680 train_time:128480ms step_avg:88.06ms +step:1460/1680 train_time:128569ms step_avg:88.06ms +step:1461/1680 train_time:128658ms step_avg:88.06ms +step:1462/1680 train_time:128747ms step_avg:88.06ms +step:1463/1680 train_time:128836ms step_avg:88.06ms +step:1464/1680 train_time:128926ms step_avg:88.06ms +step:1465/1680 train_time:129015ms step_avg:88.06ms +step:1466/1680 train_time:129103ms step_avg:88.06ms +step:1467/1680 train_time:129193ms step_avg:88.07ms +step:1468/1680 train_time:129282ms step_avg:88.07ms +step:1469/1680 train_time:129372ms step_avg:88.07ms +step:1470/1680 train_time:129460ms step_avg:88.07ms +step:1471/1680 train_time:129549ms step_avg:88.07ms +step:1472/1680 train_time:129638ms step_avg:88.07ms +step:1473/1680 train_time:129727ms step_avg:88.07ms +step:1474/1680 train_time:129815ms step_avg:88.07ms +step:1475/1680 train_time:129906ms step_avg:88.07ms +step:1476/1680 train_time:129994ms step_avg:88.07ms +step:1477/1680 train_time:130082ms step_avg:88.07ms +step:1478/1680 train_time:130172ms step_avg:88.07ms +step:1479/1680 train_time:130262ms step_avg:88.07ms +step:1480/1680 train_time:130353ms step_avg:88.08ms +step:1481/1680 
train_time:130442ms step_avg:88.08ms +step:1482/1680 train_time:130532ms step_avg:88.08ms +step:1483/1680 train_time:130620ms step_avg:88.08ms +step:1484/1680 train_time:130709ms step_avg:88.08ms +step:1485/1680 train_time:130798ms step_avg:88.08ms +step:1486/1680 train_time:130888ms step_avg:88.08ms +step:1487/1680 train_time:130976ms step_avg:88.08ms +step:1488/1680 train_time:131065ms step_avg:88.08ms +step:1489/1680 train_time:131155ms step_avg:88.08ms +step:1490/1680 train_time:131243ms step_avg:88.08ms +step:1491/1680 train_time:131334ms step_avg:88.08ms +step:1492/1680 train_time:131422ms step_avg:88.08ms +step:1493/1680 train_time:131512ms step_avg:88.09ms +step:1494/1680 train_time:131601ms step_avg:88.09ms +step:1495/1680 train_time:131691ms step_avg:88.09ms +step:1496/1680 train_time:131779ms step_avg:88.09ms +step:1497/1680 train_time:131867ms step_avg:88.09ms +step:1498/1680 train_time:131957ms step_avg:88.09ms +step:1499/1680 train_time:132046ms step_avg:88.09ms +step:1500/1680 train_time:132135ms step_avg:88.09ms +step:1500/1680 val_loss:3.3152 train_time:132225ms step_avg:88.15ms +step:1501/1680 train_time:132244ms step_avg:88.10ms +step:1502/1680 train_time:132317ms step_avg:88.09ms +step:1503/1680 train_time:132410ms step_avg:88.10ms +step:1504/1680 train_time:132499ms step_avg:88.10ms +step:1505/1680 train_time:132587ms step_avg:88.10ms +step:1506/1680 train_time:132675ms step_avg:88.10ms +step:1507/1680 train_time:132763ms step_avg:88.10ms +step:1508/1680 train_time:132852ms step_avg:88.10ms +step:1509/1680 train_time:132939ms step_avg:88.10ms +step:1510/1680 train_time:133028ms step_avg:88.10ms +step:1511/1680 train_time:133117ms step_avg:88.10ms +step:1512/1680 train_time:133208ms step_avg:88.10ms +step:1513/1680 train_time:133299ms step_avg:88.10ms +step:1514/1680 train_time:133391ms step_avg:88.10ms +step:1515/1680 train_time:133481ms step_avg:88.11ms +step:1516/1680 train_time:133570ms step_avg:88.11ms +step:1517/1680 train_time:133659ms step_avg:88.11ms +step:1518/1680 train_time:133748ms step_avg:88.11ms +step:1519/1680 train_time:133837ms step_avg:88.11ms +step:1520/1680 train_time:133925ms step_avg:88.11ms +step:1521/1680 train_time:134013ms step_avg:88.11ms +step:1522/1680 train_time:134102ms step_avg:88.11ms +step:1523/1680 train_time:134191ms step_avg:88.11ms +step:1524/1680 train_time:134282ms step_avg:88.11ms +step:1525/1680 train_time:134373ms step_avg:88.11ms +step:1526/1680 train_time:134464ms step_avg:88.12ms +step:1527/1680 train_time:134553ms step_avg:88.12ms +step:1528/1680 train_time:134641ms step_avg:88.12ms +step:1529/1680 train_time:134729ms step_avg:88.12ms +step:1530/1680 train_time:134818ms step_avg:88.12ms +step:1531/1680 train_time:134907ms step_avg:88.12ms +step:1532/1680 train_time:134996ms step_avg:88.12ms +step:1533/1680 train_time:135084ms step_avg:88.12ms +step:1534/1680 train_time:135173ms step_avg:88.12ms +step:1535/1680 train_time:135263ms step_avg:88.12ms +step:1536/1680 train_time:135353ms step_avg:88.12ms +step:1537/1680 train_time:135443ms step_avg:88.12ms +step:1538/1680 train_time:135532ms step_avg:88.12ms +step:1539/1680 train_time:135622ms step_avg:88.12ms +step:1540/1680 train_time:135710ms step_avg:88.12ms +step:1541/1680 train_time:135799ms step_avg:88.12ms +step:1542/1680 train_time:135888ms step_avg:88.12ms +step:1543/1680 train_time:135976ms step_avg:88.12ms +step:1544/1680 train_time:136064ms step_avg:88.12ms +step:1545/1680 train_time:136153ms step_avg:88.13ms +step:1546/1680 train_time:136243ms step_avg:88.13ms 
+step:1547/1680 train_time:136332ms step_avg:88.13ms +step:1548/1680 train_time:136422ms step_avg:88.13ms +step:1549/1680 train_time:136511ms step_avg:88.13ms +step:1550/1680 train_time:136600ms step_avg:88.13ms +step:1551/1680 train_time:136688ms step_avg:88.13ms +step:1552/1680 train_time:136777ms step_avg:88.13ms +step:1553/1680 train_time:136866ms step_avg:88.13ms +step:1554/1680 train_time:136955ms step_avg:88.13ms +step:1555/1680 train_time:137044ms step_avg:88.13ms +step:1556/1680 train_time:137134ms step_avg:88.13ms +step:1557/1680 train_time:137224ms step_avg:88.13ms +step:1558/1680 train_time:137313ms step_avg:88.13ms +step:1559/1680 train_time:137402ms step_avg:88.13ms +step:1560/1680 train_time:137491ms step_avg:88.14ms +step:1561/1680 train_time:137582ms step_avg:88.14ms +step:1562/1680 train_time:137672ms step_avg:88.14ms +step:1563/1680 train_time:137760ms step_avg:88.14ms +step:1564/1680 train_time:137849ms step_avg:88.14ms +step:1565/1680 train_time:137937ms step_avg:88.14ms +step:1566/1680 train_time:138026ms step_avg:88.14ms +step:1567/1680 train_time:138116ms step_avg:88.14ms +step:1568/1680 train_time:138205ms step_avg:88.14ms +step:1569/1680 train_time:138294ms step_avg:88.14ms +step:1570/1680 train_time:138384ms step_avg:88.14ms +step:1571/1680 train_time:138473ms step_avg:88.14ms +step:1572/1680 train_time:138562ms step_avg:88.14ms +step:1573/1680 train_time:138651ms step_avg:88.14ms +step:1574/1680 train_time:138741ms step_avg:88.15ms +step:1575/1680 train_time:138830ms step_avg:88.15ms +step:1576/1680 train_time:138919ms step_avg:88.15ms +step:1577/1680 train_time:139008ms step_avg:88.15ms +step:1578/1680 train_time:139098ms step_avg:88.15ms +step:1579/1680 train_time:139187ms step_avg:88.15ms +step:1580/1680 train_time:139276ms step_avg:88.15ms +step:1581/1680 train_time:139364ms step_avg:88.15ms +step:1582/1680 train_time:139453ms step_avg:88.15ms +step:1583/1680 train_time:139542ms step_avg:88.15ms +step:1584/1680 train_time:139631ms step_avg:88.15ms +step:1585/1680 train_time:139720ms step_avg:88.15ms +step:1586/1680 train_time:139810ms step_avg:88.15ms +step:1587/1680 train_time:139899ms step_avg:88.15ms +step:1588/1680 train_time:139988ms step_avg:88.15ms +step:1589/1680 train_time:140078ms step_avg:88.15ms +step:1590/1680 train_time:140166ms step_avg:88.15ms +step:1591/1680 train_time:140255ms step_avg:88.16ms +step:1592/1680 train_time:140344ms step_avg:88.16ms +step:1593/1680 train_time:140434ms step_avg:88.16ms +step:1594/1680 train_time:140523ms step_avg:88.16ms +step:1595/1680 train_time:140612ms step_avg:88.16ms +step:1596/1680 train_time:140702ms step_avg:88.16ms +step:1597/1680 train_time:140791ms step_avg:88.16ms +step:1598/1680 train_time:140880ms step_avg:88.16ms +step:1599/1680 train_time:140970ms step_avg:88.16ms +step:1600/1680 train_time:141059ms step_avg:88.16ms +step:1601/1680 train_time:141148ms step_avg:88.16ms +step:1602/1680 train_time:141237ms step_avg:88.16ms +step:1603/1680 train_time:141326ms step_avg:88.16ms +step:1604/1680 train_time:141415ms step_avg:88.16ms +step:1605/1680 train_time:141504ms step_avg:88.16ms +step:1606/1680 train_time:141593ms step_avg:88.16ms +step:1607/1680 train_time:141682ms step_avg:88.17ms +step:1608/1680 train_time:141771ms step_avg:88.17ms +step:1609/1680 train_time:141860ms step_avg:88.17ms +step:1610/1680 train_time:141950ms step_avg:88.17ms +step:1611/1680 train_time:142039ms step_avg:88.17ms +step:1612/1680 train_time:142128ms step_avg:88.17ms +step:1613/1680 train_time:142218ms step_avg:88.17ms 
+step:1614/1680 train_time:142307ms step_avg:88.17ms +step:1615/1680 train_time:142396ms step_avg:88.17ms +step:1616/1680 train_time:142484ms step_avg:88.17ms +step:1617/1680 train_time:142573ms step_avg:88.17ms +step:1618/1680 train_time:142662ms step_avg:88.17ms +step:1619/1680 train_time:142752ms step_avg:88.17ms +step:1620/1680 train_time:142841ms step_avg:88.17ms +step:1621/1680 train_time:142930ms step_avg:88.17ms +step:1622/1680 train_time:143020ms step_avg:88.18ms +step:1623/1680 train_time:143110ms step_avg:88.18ms +step:1624/1680 train_time:143199ms step_avg:88.18ms +step:1625/1680 train_time:143288ms step_avg:88.18ms +step:1625/1680 val_loss:3.2912 train_time:143378ms step_avg:88.23ms +step:1626/1680 train_time:143397ms step_avg:88.19ms +step:1627/1680 train_time:143472ms step_avg:88.18ms +step:1628/1680 train_time:143565ms step_avg:88.19ms +step:1629/1680 train_time:143656ms step_avg:88.19ms +step:1630/1680 train_time:143744ms step_avg:88.19ms +step:1631/1680 train_time:143831ms step_avg:88.19ms +step:1632/1680 train_time:143919ms step_avg:88.19ms +step:1633/1680 train_time:144007ms step_avg:88.19ms +step:1634/1680 train_time:144095ms step_avg:88.19ms +step:1635/1680 train_time:144183ms step_avg:88.19ms +step:1636/1680 train_time:144271ms step_avg:88.19ms +step:1637/1680 train_time:144363ms step_avg:88.19ms +step:1638/1680 train_time:144455ms step_avg:88.19ms +step:1639/1680 train_time:144546ms step_avg:88.19ms +step:1640/1680 train_time:144637ms step_avg:88.19ms +step:1641/1680 train_time:144726ms step_avg:88.19ms +step:1642/1680 train_time:144815ms step_avg:88.19ms +step:1643/1680 train_time:144903ms step_avg:88.19ms +step:1644/1680 train_time:144992ms step_avg:88.19ms +step:1645/1680 train_time:145080ms step_avg:88.19ms +step:1646/1680 train_time:145167ms step_avg:88.19ms +step:1647/1680 train_time:145256ms step_avg:88.19ms +step:1648/1680 train_time:145345ms step_avg:88.19ms +step:1649/1680 train_time:145435ms step_avg:88.20ms +step:1650/1680 train_time:145526ms step_avg:88.20ms +step:1651/1680 train_time:145615ms step_avg:88.20ms +step:1652/1680 train_time:145705ms step_avg:88.20ms +step:1653/1680 train_time:145793ms step_avg:88.20ms +step:1654/1680 train_time:145882ms step_avg:88.20ms +step:1655/1680 train_time:145971ms step_avg:88.20ms +step:1656/1680 train_time:146059ms step_avg:88.20ms +step:1657/1680 train_time:146147ms step_avg:88.20ms +step:1658/1680 train_time:146236ms step_avg:88.20ms +step:1659/1680 train_time:146326ms step_avg:88.20ms +step:1660/1680 train_time:146416ms step_avg:88.20ms +step:1661/1680 train_time:146506ms step_avg:88.20ms +step:1662/1680 train_time:146596ms step_avg:88.20ms +step:1663/1680 train_time:146686ms step_avg:88.21ms +step:1664/1680 train_time:146774ms step_avg:88.21ms +step:1665/1680 train_time:146863ms step_avg:88.21ms +step:1666/1680 train_time:146952ms step_avg:88.21ms +step:1667/1680 train_time:147040ms step_avg:88.21ms +step:1668/1680 train_time:147129ms step_avg:88.21ms +step:1669/1680 train_time:147218ms step_avg:88.21ms +step:1670/1680 train_time:147307ms step_avg:88.21ms +step:1671/1680 train_time:147396ms step_avg:88.21ms +step:1672/1680 train_time:147485ms step_avg:88.21ms +step:1673/1680 train_time:147574ms step_avg:88.21ms +step:1674/1680 train_time:147664ms step_avg:88.21ms +step:1675/1680 train_time:147753ms step_avg:88.21ms +step:1676/1680 train_time:147843ms step_avg:88.21ms +step:1677/1680 train_time:147932ms step_avg:88.21ms +step:1678/1680 train_time:148020ms step_avg:88.21ms +step:1679/1680 train_time:148109ms 
step_avg:88.21ms +step:1680/1680 train_time:148198ms step_avg:88.21ms +step:1680/1680 val_loss:3.2805 train_time:148289ms step_avg:88.27ms +peak memory allocated: 30760 MiB reserved: 45834 MiB diff --git a/records/092725_BF16CE/b21c8cc7-c09c-401f-b654-23e947ad3e38.txt b/records/092725_BF16CE/b21c8cc7-c09c-401f-b654-23e947ad3e38.txt new file mode 100644 index 000000000..4b14488b6 --- /dev/null +++ b/records/092725_BF16CE/b21c8cc7-c09c-401f-b654-23e947ad3e38.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
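+ # (Only blocks touching or on the kept side of the diagonal reach this point;
+ # the early return above skipped the redundant half. Each surviving block is
+ # also written to its mirror position below, so the full symmetric C = A @ A.T
+ # comes out of roughly half the tl.dot work.)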
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+    Empirically, though, small 1D params still perform well here:
+        NS approximately performs a magnitude normalization of the grad, and
+        this hyper-optimized class has a faster execution time than the current impl of Adam for small params.
+
+    Custom distributed sizing:
+        The model stores all attn and mlp weights in the same shape, and then updates the view as
+        needed on the forward pass. This enables attn and mlp weights to be contained within the same
+        dist.reduce_scatter_tensor() call. The model architecture has been customized so that
+        (n_attn_layers + n_mlp_layers*2) % 4 == 0, which allows batching across 8 GPUs with zero padding on mlp and attn.
+        The scheduling is:
+            1. reduce scatter smear_gate (1 param, 7 padding params)
+            2. reduce scatter attn_gate (10 params, 6 padding params)
+            3. reduce scatter attn/mlp round 1 (10 attn params, 6 mlp params)
+            4. reduce scatter attn/mlp round 2 (16 mlp params)
+            5. wait on step 1, then compute NS of 1 and schedule its all gather
+            6. wait on step 2, then compute NS of 2 and schedule its all gather
+            7. wait on step 3, then compute NS of 3 and schedule its all gather
+                GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP]
+                GPUs that receive params of type attn reshape them before NS
+            8. wait on step 4, then compute NS of 4 and schedule its all gather
+            9. wait for each all gather to complete and update the params
+        Empirically, leading with the small params provides an additional 0.2s improvement.
+    """
+    def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
+        # custom sizing requires exactly 8 GPUs
+        if custom_sizing and dist.get_world_size() == 8:
+            param_groups = self.generate_custom_param_groups(params)
+        else:
+            param_groups = self.generate_standard_param_groups(params)
+        super().__init__(param_groups, defaults)
+
+    def generate_standard_param_groups(self, params):
+        """
+        Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules.
+        Creates one param group per size, while giving attn its own param group for the resize op.
+        """
+        params = list(params)
+        param_groups = []
+        attn_subset = [p for p in params if p.module == 'attn']
+        non_attn_subset = [p for p in params if p.module != 'attn']
+        param_groups.append(dict(params=attn_subset))
+
+        sizes = {p.shape for p in non_attn_subset}
+        for size in sizes:
+            group_params = [p for p in non_attn_subset if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        return param_groups
+
+    def generate_custom_param_groups(self, params):
+        """
+        Implementation requires that a single GPU does not receive both attn
+        and mlp params when a param group is split across GPUs.
+        """
+        module_ranks = {
+            'smear_gate': 1, # 1 param
+            'attn_gate': 2, # 10 params
+            'attn': 3, # 10 params
+            'mlp': 4, # 22 params
+        }
+        params = list(params)
+        params.sort(key=lambda x: module_ranks.get(x.module))
+        idx = 0
+        group_sizes = [1, 10, 16, 16]
+        assert len(params) == sum(group_sizes)
+        param_groups = []
+        for size in group_sizes:
+            group_params = params[idx:idx + size]
+            param_groups.append(dict(params=group_params))
+            idx += size
+        return param_groups
+
+    @torch.no_grad()
+    def step(self):
+        # Efficient systems-wise implementation of step developed by @YouJiacheng,
+        # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad,
+        # @ryanyang0, and @vagrawal.
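+        # Worked example of the chunking below under the custom sizing on 8 GPUs
+        # (numbers taken from the class docstring): the attn_gate group holds 10
+        # params, padded to 16, so chunk_size == 2 and rank r reduce-scatters params
+        # [2r, 2r+1], with ranks 5-7 receiving only zero padding; the mixed attn/mlp
+        # group holds 10 + 6 = 16 params, so every rank's chunk is homogeneous
+        # (all-attn or all-mlp) and can be reshaped for NS as one batch.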
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
updated_param_chunk[i].copy_(p)
+
+                # Apply weight decay directly to the buffer.
+                updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+            original_shape = batched_update_grads.shape
+            # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O.
+            # Each rank's chunk holds params of a single module type; fall back to params[0]
+            # when the chunk is pure padding (padding entries share the group's shape).
+            module_idx = start_idx if start_idx < len(params) else 0
+            if getattr(params[module_idx], "module", None) == 'attn':
+                batched_update_grads = batched_update_grads.view(-1, original_shape[-2], original_shape[-1] // 4)
+            v_chunk = newton_schulz_triton(batched_update_grads)
+            v_chunk = v_chunk.view(original_shape)
+
+            # Third pass: apply the orthogonalized updates to this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+                if param_idx >= len(params): # Skip padded entries again.
+                    continue
+
+                # Add the computed zeropower update to the parameter in the buffer.
+                updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val)
+
+            stacked_params = torch.empty(
+                (info["padded_num_params"], *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+            gather_future = dist.all_gather_into_tensor(
+                stacked_params, updated_param_chunk, async_op=True
+            ).get_future()
+
+            all_gather_infos.append(
+                {
+                    "gather_future": gather_future,
+                    "stacked_params": stacked_params,
+                    "orig_params": params,
+                }
+            )
+
+        # Final pass: wait for all_gather to complete and copy results back into original parameter tensors.
+        for info in all_gather_infos:
+            info["gather_future"].wait()
+            stacked_params = info["stacked_params"]
+            orig_params = info["orig_params"]
+
+            unstacked_params = torch.unbind(stacked_params)
+            for i, p in enumerate(orig_params):
+                p.copy_(unstacked_params[i], non_blocking=True)
+
+
+class DistAdam(torch.optim.Optimizer):
+    def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        params = list(params)
+        sizes = {p.shape for p in params}
+        # create one param group per unique parameter shape
+        param_groups = []
+        for size in sizes:
+            group_params = [p for p in params if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        super().__init__(param_groups, defaults)
+        # DistributedAdam implementation by @vagrawal
+
+    @torch.compile
+    @torch.no_grad()
+    def step(self):
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        reduce_scatter_futures: list[torch.Future] = []
+        all_gather_futures: list[torch.Future] = []
+        grad_slices = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            for base_i in range(len(params)):
+                grad = params[base_i].grad
+                rank_size = grad.shape[0] // world_size
+                grad_slice = torch.empty_like(grad[:rank_size])
+                reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
+                grad_slices.append(grad_slice)
+
+        idx = 0
+        for group in self.param_groups:
+            beta1, beta2 = group['betas']
+            eps = group['eps']
+            wd = group['weight_decay']
+            params = group['params']
+            for base in range(len(params)):
+                reduce_scatter_futures[idx].wait()
+                p = params[base]
+                rank_size = p.shape[0] // world_size
+                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
+                lr = group['lr'] * getattr(p, "lr_mul", 1.0)
+                state = self.state[p]
+                g_slice = grad_slices[idx]
+                # State init
+                if not state:
+                    state["step"] = torch.tensor(
+                        0, dtype=torch.int64, device=p.device
+                    )
+                    state["exp_avg"] = torch.zeros(
+                        p_slice.shape,
+                        dtype=torch.bfloat16,
+                        device=p_slice.device,
+                    )
+                    state["exp_avg_sq"] = torch.zeros(
+                        p_slice.shape,
+                        dtype=torch.bfloat16,
+                        device=p_slice.device,
+                    )
+                exp_avg = state["exp_avg"]
+                exp_avg_sq =
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1)
+    y1 = x1 * cos + x2 * sin
+    y2 = x1 * (-sin) + x2 * cos
+    return torch.cat((y1, y2), 3)
+
+@dataclass
+class AttnArgs:
+    ve: torch.Tensor
+    sa_lambdas: torch.Tensor
+    seqlens: torch.Tensor
+    bm_size: int
+    cos: torch.Tensor
+    sin: torch.Tensor
+    attn_scale: float
+
+class CausalSelfAttention(nn.Module):
+    def __init__(self, dim: int, head_dim: int, num_heads: int):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.dim = dim
+        self.hdim = num_heads * head_dim
+
+        assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim"
+        std = 0.5 * (self.dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
+        # https://x.com/hi_tysam/status/1879699187107033311
+        # make matrices the same shape as MLP to enable batched call in optimizer
+        self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4))
+        # label module to enable custom optimizer sizing
+        self.qkvo_w.module = 'attn'
+        with torch.no_grad():
+            self.qkvo_w.view(4, self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights
+            self.qkvo_w.view(4, self.hdim, self.dim)[3].zero_() # init output weights to zero
+
+        # sparse gated attention to enable a context-based no-op by @classiclarryd
+        self.attn_gate = CastedLinear(12, num_heads)
+        # label module to enable custom optimizer sizing
+        self.attn_gate.weight.module = 'attn_gate'
+        self.attn_gate.weight.detach().zero_()
+
+    def forward(self, x: Tensor, attn_args: AttnArgs):
+        B, T = x.size(0), x.size(1) # batch size, sequence length
+        assert B == 1, "varlen sequences require B == 1"
+        assert T % 16 == 0
+        # unpack attention args
+        cos, sin = attn_args.cos, attn_args.sin
+        ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas
+        seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size
+
+        q, k, v = F.linear(x, self.qkvo_w.view(4, self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
+        q, k = norm(q), norm(k) # QK norm @Grad62304977
+        q, k = rotary(q, cos, sin), rotary(k, cos, sin)
+        if ve is not None:
+            v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
+        else: # skip mid-layers token value embeddings by @YouJiacheng
+            v = sa_lambdas[0] * v
+
+        max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size))
+
+        # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4,self.hdim, self.dim)[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module='mlp' + self.c_proj.module='mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
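+        # Context for the fp8 scale constants below: float8_e4m3fn tops out at 448,
+        # so the forward pass divides activations by x_s = sqrt(model_dim)/448 and
+        # weights by w_s = 2**-9 before quantizing, then hands the same scalars to
+        # torch._scaled_mm so the bf16 output is rescaled back; grad_s = 1/448 plays
+        # the same role for the float8_e5m2 gradients in the backward pass.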
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the 
data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at document beginnings, by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan the first 4 million tokens, then kick off an async thread to scan the rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS tokens ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading the next shard and indexing its BOS tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning-of-Sequence token, and sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for the unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted; load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # trim the last document by one token to account for the _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40 # number of iterations to continue training at the final cooldown and window size
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 
0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable, then decay
+def get_lr(step: int):
+    x = min(0.9999, step / args.num_iterations)
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+def get_ws(step: int):
+    if step == args.num_iterations + args.iteration_extension:
+        return args.ws_validate//2, args.ws_validate
+    x = min(step / (1 + args.num_iterations), 0.9999)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx]//2, args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws_long = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws_long > ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    elif new_ws_long < ws_long: # window shrinks again at the start of each schedule cycle
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    model(inputs, targets, cum_seqlens, ws_long//2, ws_long).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+
+# restore the pre-warmup state
+model.yarn.reset()
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del initial_state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+
+########################################
+#        Training and validation       #
+########################################
+
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+train_steps = args.num_iterations + args.iteration_extension
+ws_short, ws_long = get_ws(0)
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+
+    # advance the window-size schedule, stretching RoPE via YaRN on each change
+    new_ws_short, new_ws_long = get_ws(step)
+    if new_ws_long != ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+    ws_short, ws_long = new_ws_short, new_ws_long
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        if last_step:
+            ws_long = args.ws_long_validate
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = torch.zeros((), device=device, dtype=torch.float32)
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
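+        # With torchrun on 8 GPUs, grad_accum_steps == 8 // world_size == 1 and this
+        # loop body runs once per step; on fewer GPUs the micro-batch gradients
+        # accumulate into .grad here and the optimizers average them across ranks.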
inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Sep 27 13:03:08 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 27C P0 121W / 700W | 5856MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 25C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 22C P0 115W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 27C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 27C P0 119W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 25C P0 116W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 28C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 24C P0 119W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 167692 C /usr/bin/python 0MiB | +| 0 N/A N/A 167693 C /usr/bin/python 0MiB | +| 0 N/A N/A 167694 C /usr/bin/python 0MiB | +| 0 N/A N/A 167695 C /usr/bin/python 0MiB | +| 0 N/A N/A 167696 C /usr/bin/python 0MiB | +| 0 N/A N/A 167697 C /usr/bin/python 0MiB | +| 0 N/A N/A 167698 C /usr/bin/python 0MiB | +| 0 N/A N/A 167699 C /usr/bin/python 0MiB | +| 1 N/A N/A 167693 C /usr/bin/python 0MiB | +| 2 N/A N/A 167694 C /usr/bin/python 0MiB | +| 3 N/A N/A 167695 C /usr/bin/python 0MiB | +| 4 N/A N/A 167696 C /usr/bin/python 0MiB | +| 5 N/A N/A 167697 C /usr/bin/python 0MiB | +| 6 N/A N/A 167698 C /usr/bin/python 0MiB | +| 7 N/A N/A 167699 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1680 train_time:140ms step_avg:140.18ms +step:2/1680 train_time:159ms step_avg:79.71ms +step:3/1680 train_time:224ms step_avg:74.72ms +step:4/1680 train_time:309ms step_avg:77.29ms +step:5/1680 train_time:396ms step_avg:79.12ms +step:6/1680 train_time:481ms step_avg:80.22ms +step:7/1680 train_time:567ms step_avg:81.04ms +step:8/1680 train_time:653ms step_avg:81.64ms +step:9/1680 train_time:739ms step_avg:82.13ms +step:10/1680 train_time:825ms step_avg:82.52ms +step:11/1680 train_time:912ms step_avg:82.87ms +step:12/1680 train_time:999ms step_avg:83.28ms +step:13/1680 train_time:1091ms step_avg:83.89ms +step:14/1680 train_time:1179ms step_avg:84.22ms +step:15/1680 train_time:1267ms step_avg:84.44ms +step:16/1680 train_time:1354ms step_avg:84.62ms +step:17/1680 train_time:1441ms step_avg:84.74ms +step:18/1680 train_time:1527ms step_avg:84.83ms +step:19/1680 train_time:1614ms step_avg:84.93ms +step:20/1680 train_time:1700ms step_avg:85.01ms +step:21/1680 train_time:1787ms step_avg:85.08ms +step:22/1680 train_time:1873ms step_avg:85.14ms +step:23/1680 train_time:1960ms step_avg:85.21ms +step:24/1680 train_time:2049ms step_avg:85.36ms +step:25/1680 train_time:2137ms step_avg:85.47ms +step:26/1680 train_time:2226ms step_avg:85.60ms +step:27/1680 train_time:2314ms step_avg:85.69ms +step:28/1680 train_time:2401ms step_avg:85.75ms +step:29/1680 train_time:2488ms step_avg:85.79ms +step:30/1680 train_time:2575ms step_avg:85.84ms +step:31/1680 train_time:2662ms step_avg:85.89ms +step:32/1680 train_time:2749ms step_avg:85.91ms +step:33/1680 train_time:2836ms step_avg:85.95ms +step:34/1680 train_time:2923ms step_avg:85.97ms +step:35/1680 train_time:3010ms step_avg:86.00ms +step:36/1680 train_time:3098ms step_avg:86.05ms +step:37/1680 train_time:3185ms step_avg:86.09ms +step:38/1680 train_time:3273ms step_avg:86.13ms +step:39/1680 train_time:3360ms step_avg:86.16ms +step:40/1680 train_time:3448ms step_avg:86.19ms +step:41/1680 train_time:3534ms step_avg:86.20ms +step:42/1680 train_time:3621ms step_avg:86.21ms +step:43/1680 train_time:3708ms step_avg:86.23ms +step:44/1680 train_time:3795ms step_avg:86.26ms +step:45/1680 train_time:3882ms step_avg:86.28ms +step:46/1680 train_time:3969ms step_avg:86.29ms +step:47/1680 
train_time:4057ms step_avg:86.31ms +step:48/1680 train_time:4144ms step_avg:86.34ms +step:49/1680 train_time:4232ms step_avg:86.37ms +step:50/1680 train_time:4320ms step_avg:86.39ms +step:51/1680 train_time:4407ms step_avg:86.41ms +step:52/1680 train_time:4494ms step_avg:86.43ms +step:53/1680 train_time:4582ms step_avg:86.44ms +step:54/1680 train_time:4668ms step_avg:86.45ms +step:55/1680 train_time:4756ms step_avg:86.47ms +step:56/1680 train_time:4843ms step_avg:86.48ms +step:57/1680 train_time:4929ms step_avg:86.48ms +step:58/1680 train_time:5017ms step_avg:86.49ms +step:59/1680 train_time:5104ms step_avg:86.51ms +step:60/1680 train_time:5192ms step_avg:86.53ms +step:61/1680 train_time:5279ms step_avg:86.54ms +step:62/1680 train_time:5366ms step_avg:86.56ms +step:63/1680 train_time:5454ms step_avg:86.57ms +step:64/1680 train_time:5541ms step_avg:86.57ms +step:65/1680 train_time:5627ms step_avg:86.57ms +step:66/1680 train_time:5715ms step_avg:86.58ms +step:67/1680 train_time:5801ms step_avg:86.59ms +step:68/1680 train_time:5888ms step_avg:86.59ms +step:69/1680 train_time:5975ms step_avg:86.60ms +step:70/1680 train_time:6063ms step_avg:86.61ms +step:71/1680 train_time:6150ms step_avg:86.62ms +step:72/1680 train_time:6237ms step_avg:86.63ms +step:73/1680 train_time:6325ms step_avg:86.65ms +step:74/1680 train_time:6413ms step_avg:86.66ms +step:75/1680 train_time:6500ms step_avg:86.66ms +step:76/1680 train_time:6587ms step_avg:86.67ms +step:77/1680 train_time:6674ms step_avg:86.68ms +step:78/1680 train_time:6761ms step_avg:86.67ms +step:79/1680 train_time:6847ms step_avg:86.68ms +step:80/1680 train_time:6935ms step_avg:86.69ms +step:81/1680 train_time:7022ms step_avg:86.69ms +step:82/1680 train_time:7109ms step_avg:86.70ms +step:83/1680 train_time:7196ms step_avg:86.70ms +step:84/1680 train_time:7284ms step_avg:86.72ms +step:85/1680 train_time:7371ms step_avg:86.72ms +step:86/1680 train_time:7459ms step_avg:86.73ms +step:87/1680 train_time:7546ms step_avg:86.73ms +step:88/1680 train_time:7633ms step_avg:86.74ms +step:89/1680 train_time:7720ms step_avg:86.74ms +step:90/1680 train_time:7807ms step_avg:86.74ms +step:91/1680 train_time:7894ms step_avg:86.75ms +step:92/1680 train_time:7981ms step_avg:86.75ms +step:93/1680 train_time:8068ms step_avg:86.76ms +step:94/1680 train_time:8155ms step_avg:86.76ms +step:95/1680 train_time:8242ms step_avg:86.76ms +step:96/1680 train_time:8330ms step_avg:86.77ms +step:97/1680 train_time:8417ms step_avg:86.77ms +step:98/1680 train_time:8504ms step_avg:86.78ms +step:99/1680 train_time:8592ms step_avg:86.79ms +step:100/1680 train_time:8679ms step_avg:86.79ms +step:101/1680 train_time:8766ms step_avg:86.79ms +step:102/1680 train_time:8853ms step_avg:86.79ms +step:103/1680 train_time:8940ms step_avg:86.79ms +step:104/1680 train_time:9027ms step_avg:86.80ms +step:105/1680 train_time:9115ms step_avg:86.81ms +step:106/1680 train_time:9202ms step_avg:86.81ms +step:107/1680 train_time:9289ms step_avg:86.81ms +step:108/1680 train_time:9376ms step_avg:86.82ms +step:109/1680 train_time:9464ms step_avg:86.82ms +step:110/1680 train_time:9550ms step_avg:86.82ms +step:111/1680 train_time:9638ms step_avg:86.83ms +step:112/1680 train_time:9725ms step_avg:86.83ms +step:113/1680 train_time:9812ms step_avg:86.83ms +step:114/1680 train_time:9899ms step_avg:86.83ms +step:115/1680 train_time:9985ms step_avg:86.83ms +step:116/1680 train_time:10073ms step_avg:86.83ms +step:117/1680 train_time:10160ms step_avg:86.83ms +step:118/1680 train_time:10246ms step_avg:86.83ms +step:119/1680 
train_time:10334ms step_avg:86.84ms +step:120/1680 train_time:10421ms step_avg:86.84ms +step:121/1680 train_time:10507ms step_avg:86.84ms +step:122/1680 train_time:10595ms step_avg:86.84ms +step:123/1680 train_time:10682ms step_avg:86.85ms +step:124/1680 train_time:10769ms step_avg:86.85ms +step:125/1680 train_time:10856ms step_avg:86.85ms +step:125/1680 val_loss:4.2950 train_time:10944ms step_avg:87.55ms +step:126/1680 train_time:10965ms step_avg:87.02ms +step:127/1680 train_time:11036ms step_avg:86.89ms +step:128/1680 train_time:11130ms step_avg:86.95ms +step:129/1680 train_time:11222ms step_avg:86.99ms +step:130/1680 train_time:11309ms step_avg:86.99ms +step:131/1680 train_time:11396ms step_avg:86.99ms +step:132/1680 train_time:11482ms step_avg:86.99ms +step:133/1680 train_time:11568ms step_avg:86.98ms +step:134/1680 train_time:11654ms step_avg:86.97ms +step:135/1680 train_time:11740ms step_avg:86.96ms +step:136/1680 train_time:11826ms step_avg:86.96ms +step:137/1680 train_time:11912ms step_avg:86.95ms +step:138/1680 train_time:11999ms step_avg:86.95ms +step:139/1680 train_time:12089ms step_avg:86.97ms +step:140/1680 train_time:12178ms step_avg:86.99ms +step:141/1680 train_time:12266ms step_avg:86.99ms +step:142/1680 train_time:12354ms step_avg:87.00ms +step:143/1680 train_time:12441ms step_avg:87.00ms +step:144/1680 train_time:12527ms step_avg:86.99ms +step:145/1680 train_time:12613ms step_avg:86.99ms +step:146/1680 train_time:12700ms step_avg:86.99ms +step:147/1680 train_time:12786ms step_avg:86.98ms +step:148/1680 train_time:12873ms step_avg:86.98ms +step:149/1680 train_time:12960ms step_avg:86.98ms +step:150/1680 train_time:13048ms step_avg:86.99ms +step:151/1680 train_time:13136ms step_avg:87.00ms +step:152/1680 train_time:13225ms step_avg:87.01ms +step:153/1680 train_time:13312ms step_avg:87.01ms +step:154/1680 train_time:13400ms step_avg:87.01ms +step:155/1680 train_time:13487ms step_avg:87.01ms +step:156/1680 train_time:13574ms step_avg:87.01ms +step:157/1680 train_time:13660ms step_avg:87.01ms +step:158/1680 train_time:13747ms step_avg:87.00ms +step:159/1680 train_time:13834ms step_avg:87.01ms +step:160/1680 train_time:13921ms step_avg:87.00ms +step:161/1680 train_time:14008ms step_avg:87.01ms +step:162/1680 train_time:14096ms step_avg:87.01ms +step:163/1680 train_time:14184ms step_avg:87.02ms +step:164/1680 train_time:14271ms step_avg:87.02ms +step:165/1680 train_time:14359ms step_avg:87.02ms +step:166/1680 train_time:14446ms step_avg:87.02ms +step:167/1680 train_time:14533ms step_avg:87.02ms +step:168/1680 train_time:14620ms step_avg:87.03ms +step:169/1680 train_time:14707ms step_avg:87.02ms +step:170/1680 train_time:14793ms step_avg:87.02ms +step:171/1680 train_time:14879ms step_avg:87.01ms +step:172/1680 train_time:14967ms step_avg:87.02ms +step:173/1680 train_time:15054ms step_avg:87.02ms +step:174/1680 train_time:15141ms step_avg:87.02ms +step:175/1680 train_time:15228ms step_avg:87.02ms +step:176/1680 train_time:15316ms step_avg:87.02ms +step:177/1680 train_time:15403ms step_avg:87.02ms +step:178/1680 train_time:15489ms step_avg:87.02ms +step:179/1680 train_time:15577ms step_avg:87.02ms +step:180/1680 train_time:15663ms step_avg:87.02ms +step:181/1680 train_time:15750ms step_avg:87.01ms +step:182/1680 train_time:15836ms step_avg:87.01ms +step:183/1680 train_time:15922ms step_avg:87.01ms +step:184/1680 train_time:16009ms step_avg:87.01ms +step:185/1680 train_time:16096ms step_avg:87.01ms +step:186/1680 train_time:16183ms step_avg:87.01ms +step:187/1680 train_time:16271ms 
step_avg:87.01ms
+step:188/1680 train_time:16358ms step_avg:87.01ms
+step:189/1680 train_time:16445ms step_avg:87.01ms
+step:190/1680 train_time:16532ms step_avg:87.01ms
+step:191/1680 train_time:16619ms step_avg:87.01ms
+step:192/1680 train_time:16705ms step_avg:87.00ms
+step:193/1680 train_time:16792ms step_avg:87.00ms
+step:194/1680 train_time:16878ms step_avg:87.00ms
+step:195/1680 train_time:16965ms step_avg:87.00ms
+step:196/1680 train_time:17053ms step_avg:87.00ms
+step:197/1680 train_time:17139ms step_avg:87.00ms
+step:198/1680 train_time:17227ms step_avg:87.00ms
+step:199/1680 train_time:17314ms step_avg:87.01ms
+step:200/1680 train_time:17401ms step_avg:87.01ms
+step:201/1680 train_time:17488ms step_avg:87.01ms
+step:202/1680 train_time:17575ms step_avg:87.01ms
+step:203/1680 train_time:17662ms step_avg:87.00ms
+step:204/1680 train_time:17750ms step_avg:87.01ms
+step:205/1680 train_time:17837ms step_avg:87.01ms
+step:206/1680 train_time:17923ms step_avg:87.00ms
+step:207/1680 train_time:18010ms step_avg:87.00ms
+step:208/1680 train_time:18097ms step_avg:87.00ms
+step:209/1680 train_time:18184ms step_avg:87.00ms
+step:210/1680 train_time:18271ms step_avg:87.01ms
+step:211/1680 train_time:18358ms step_avg:87.00ms
+step:212/1680 train_time:18445ms step_avg:87.00ms
+step:213/1680 train_time:18532ms step_avg:87.00ms
+step:214/1680 train_time:18619ms step_avg:87.00ms
+step:215/1680 train_time:18706ms step_avg:87.00ms
+step:216/1680 train_time:18793ms step_avg:87.00ms
+step:217/1680 train_time:18880ms step_avg:87.00ms
+step:218/1680 train_time:18967ms step_avg:87.00ms
+step:219/1680 train_time:19054ms step_avg:87.01ms
+step:220/1680 train_time:19141ms step_avg:87.00ms
+step:221/1680 train_time:19228ms step_avg:87.00ms
+step:222/1680 train_time:19315ms step_avg:87.00ms
+step:223/1680 train_time:19401ms step_avg:87.00ms
+step:224/1680 train_time:19488ms step_avg:87.00ms
+step:225/1680 train_time:19576ms step_avg:87.00ms
+step:226/1680 train_time:19663ms step_avg:87.00ms
+step:227/1680 train_time:19750ms step_avg:87.00ms
+step:228/1680 train_time:19837ms step_avg:87.00ms
+step:229/1680 train_time:19923ms step_avg:87.00ms
+step:230/1680 train_time:20010ms step_avg:87.00ms
+step:231/1680 train_time:20096ms step_avg:87.00ms
+step:232/1680 train_time:20184ms step_avg:87.00ms
+step:233/1680 train_time:20271ms step_avg:87.00ms
+step:234/1680 train_time:20357ms step_avg:87.00ms
+step:235/1680 train_time:20445ms step_avg:87.00ms
+step:236/1680 train_time:20532ms step_avg:87.00ms
+step:237/1680 train_time:20619ms step_avg:87.00ms
+step:238/1680 train_time:20706ms step_avg:87.00ms
+step:239/1680 train_time:20793ms step_avg:87.00ms
+step:240/1680 train_time:20881ms step_avg:87.00ms
+step:241/1680 train_time:20968ms step_avg:87.00ms
+step:242/1680 train_time:21054ms step_avg:87.00ms
+step:243/1680 train_time:21142ms step_avg:87.00ms
+step:244/1680 train_time:21229ms step_avg:87.00ms
+step:245/1680 train_time:21315ms step_avg:87.00ms
+step:246/1680 train_time:21402ms step_avg:87.00ms
+step:247/1680 train_time:21490ms step_avg:87.00ms
+step:248/1680 train_time:21577ms step_avg:87.00ms
+step:249/1680 train_time:21664ms step_avg:87.00ms
+step:250/1680 train_time:21751ms step_avg:87.00ms
+step:250/1680 val_loss:3.9698 train_time:21839ms step_avg:87.36ms
+step:251/1680 train_time:21859ms step_avg:87.09ms
+step:252/1680 train_time:21930ms step_avg:87.02ms
+step:253/1680 train_time:22020ms step_avg:87.04ms
+step:254/1680 train_time:22108ms step_avg:87.04ms
+step:255/1680 train_time:22195ms step_avg:87.04ms
+step:256/1680 train_time:22281ms step_avg:87.04ms
+step:257/1680 train_time:22367ms step_avg:87.03ms
+step:258/1680 train_time:22453ms step_avg:87.03ms
+step:259/1680 train_time:22539ms step_avg:87.02ms
+step:260/1680 train_time:22625ms step_avg:87.02ms
+step:261/1680 train_time:22711ms step_avg:87.02ms
+step:262/1680 train_time:22799ms step_avg:87.02ms
+step:263/1680 train_time:22886ms step_avg:87.02ms
+step:264/1680 train_time:22975ms step_avg:87.02ms
+step:265/1680 train_time:23062ms step_avg:87.03ms
+step:266/1680 train_time:23150ms step_avg:87.03ms
+step:267/1680 train_time:23237ms step_avg:87.03ms
+step:268/1680 train_time:23324ms step_avg:87.03ms
+step:269/1680 train_time:23410ms step_avg:87.03ms
+step:270/1680 train_time:23496ms step_avg:87.02ms
+step:271/1680 train_time:23583ms step_avg:87.02ms
+step:272/1680 train_time:23669ms step_avg:87.02ms
+step:273/1680 train_time:23756ms step_avg:87.02ms
+step:274/1680 train_time:23843ms step_avg:87.02ms
+step:275/1680 train_time:23931ms step_avg:87.02ms
+step:276/1680 train_time:24019ms step_avg:87.03ms
+step:277/1680 train_time:24107ms step_avg:87.03ms
+step:278/1680 train_time:24193ms step_avg:87.03ms
+step:279/1680 train_time:24280ms step_avg:87.03ms
+step:280/1680 train_time:24367ms step_avg:87.02ms
+step:281/1680 train_time:24454ms step_avg:87.02ms
+step:282/1680 train_time:24540ms step_avg:87.02ms
+step:283/1680 train_time:24627ms step_avg:87.02ms
+step:284/1680 train_time:24713ms step_avg:87.02ms
+step:285/1680 train_time:24801ms step_avg:87.02ms
+step:286/1680 train_time:24888ms step_avg:87.02ms
+step:287/1680 train_time:24975ms step_avg:87.02ms
+step:288/1680 train_time:25062ms step_avg:87.02ms
+step:289/1680 train_time:25149ms step_avg:87.02ms
+step:290/1680 train_time:25236ms step_avg:87.02ms
+step:291/1680 train_time:25323ms step_avg:87.02ms
+step:292/1680 train_time:25410ms step_avg:87.02ms
+step:293/1680 train_time:25497ms step_avg:87.02ms
+step:294/1680 train_time:25583ms step_avg:87.02ms
+step:295/1680 train_time:25670ms step_avg:87.02ms
+step:296/1680 train_time:25757ms step_avg:87.02ms
+step:297/1680 train_time:25845ms step_avg:87.02ms
+step:298/1680 train_time:25932ms step_avg:87.02ms
+step:299/1680 train_time:26019ms step_avg:87.02ms
+step:300/1680 train_time:26106ms step_avg:87.02ms
+step:301/1680 train_time:26194ms step_avg:87.02ms
+step:302/1680 train_time:26281ms step_avg:87.02ms
+step:303/1680 train_time:26368ms step_avg:87.02ms
+step:304/1680 train_time:26455ms step_avg:87.02ms
+step:305/1680 train_time:26541ms step_avg:87.02ms
+step:306/1680 train_time:26628ms step_avg:87.02ms
+step:307/1680 train_time:26715ms step_avg:87.02ms
+step:308/1680 train_time:26802ms step_avg:87.02ms
+step:309/1680 train_time:26889ms step_avg:87.02ms
+step:310/1680 train_time:26976ms step_avg:87.02ms
+step:311/1680 train_time:27063ms step_avg:87.02ms
+step:312/1680 train_time:27151ms step_avg:87.02ms
+step:313/1680 train_time:27238ms step_avg:87.02ms
+step:314/1680 train_time:27325ms step_avg:87.02ms
+step:315/1680 train_time:27412ms step_avg:87.02ms
+step:316/1680 train_time:27499ms step_avg:87.02ms
+step:317/1680 train_time:27586ms step_avg:87.02ms
+step:318/1680 train_time:27673ms step_avg:87.02ms
+step:319/1680 train_time:27760ms step_avg:87.02ms
+step:320/1680 train_time:27847ms step_avg:87.02ms
+step:321/1680 train_time:27934ms step_avg:87.02ms
+step:322/1680 train_time:28021ms step_avg:87.02ms
+step:323/1680 train_time:28108ms step_avg:87.02ms
+step:324/1680 train_time:28196ms step_avg:87.02ms
+step:325/1680 train_time:28283ms step_avg:87.02ms
+step:326/1680 train_time:28370ms step_avg:87.02ms
+step:327/1680 train_time:28457ms step_avg:87.02ms
+step:328/1680 train_time:28544ms step_avg:87.03ms
+step:329/1680 train_time:28631ms step_avg:87.02ms
+step:330/1680 train_time:28718ms step_avg:87.02ms
+step:331/1680 train_time:28805ms step_avg:87.03ms
+step:332/1680 train_time:28892ms step_avg:87.02ms
+step:333/1680 train_time:28980ms step_avg:87.03ms
+step:334/1680 train_time:29066ms step_avg:87.03ms
+step:335/1680 train_time:29153ms step_avg:87.02ms
+step:336/1680 train_time:29240ms step_avg:87.02ms
+step:337/1680 train_time:29328ms step_avg:87.03ms
+step:338/1680 train_time:29415ms step_avg:87.03ms
+step:339/1680 train_time:29502ms step_avg:87.03ms
+step:340/1680 train_time:29589ms step_avg:87.03ms
+step:341/1680 train_time:29677ms step_avg:87.03ms
+step:342/1680 train_time:29764ms step_avg:87.03ms
+step:343/1680 train_time:29851ms step_avg:87.03ms
+step:344/1680 train_time:29938ms step_avg:87.03ms
+step:345/1680 train_time:30025ms step_avg:87.03ms
+step:346/1680 train_time:30113ms step_avg:87.03ms
+step:347/1680 train_time:30199ms step_avg:87.03ms
+step:348/1680 train_time:30286ms step_avg:87.03ms
+step:349/1680 train_time:30373ms step_avg:87.03ms
+step:350/1680 train_time:30460ms step_avg:87.03ms
+step:351/1680 train_time:30548ms step_avg:87.03ms
+step:352/1680 train_time:30635ms step_avg:87.03ms
+step:353/1680 train_time:30722ms step_avg:87.03ms
+step:354/1680 train_time:30809ms step_avg:87.03ms
+step:355/1680 train_time:30895ms step_avg:87.03ms
+step:356/1680 train_time:30983ms step_avg:87.03ms
+step:357/1680 train_time:31070ms step_avg:87.03ms
+step:358/1680 train_time:31157ms step_avg:87.03ms
+step:359/1680 train_time:31244ms step_avg:87.03ms
+step:360/1680 train_time:31331ms step_avg:87.03ms
+step:361/1680 train_time:31418ms step_avg:87.03ms
+step:362/1680 train_time:31505ms step_avg:87.03ms
+step:363/1680 train_time:31592ms step_avg:87.03ms
+step:364/1680 train_time:31678ms step_avg:87.03ms
+step:365/1680 train_time:31766ms step_avg:87.03ms
+step:366/1680 train_time:31852ms step_avg:87.03ms
+step:367/1680 train_time:31940ms step_avg:87.03ms
+step:368/1680 train_time:32027ms step_avg:87.03ms
+step:369/1680 train_time:32114ms step_avg:87.03ms
+step:370/1680 train_time:32201ms step_avg:87.03ms
+step:371/1680 train_time:32288ms step_avg:87.03ms
+step:372/1680 train_time:32375ms step_avg:87.03ms
+step:373/1680 train_time:32463ms step_avg:87.03ms
+step:374/1680 train_time:32550ms step_avg:87.03ms
+step:375/1680 train_time:32637ms step_avg:87.03ms
+step:375/1680 val_loss:3.8132 train_time:32725ms step_avg:87.27ms
+step:376/1680 train_time:32745ms step_avg:87.09ms
+step:377/1680 train_time:32814ms step_avg:87.04ms
+step:378/1680 train_time:32906ms step_avg:87.05ms
+step:379/1680 train_time:32994ms step_avg:87.06ms
+step:380/1680 train_time:33081ms step_avg:87.06ms
+step:381/1680 train_time:33168ms step_avg:87.06ms
+step:382/1680 train_time:33254ms step_avg:87.05ms
+step:383/1680 train_time:33340ms step_avg:87.05ms
+step:384/1680 train_time:33427ms step_avg:87.05ms
+step:385/1680 train_time:33513ms step_avg:87.05ms
+step:386/1680 train_time:33600ms step_avg:87.05ms
+step:387/1680 train_time:33687ms step_avg:87.05ms
+step:388/1680 train_time:33776ms step_avg:87.05ms
+step:389/1680 train_time:33864ms step_avg:87.05ms
+step:390/1680 train_time:33953ms step_avg:87.06ms
+step:391/1680 train_time:34040ms step_avg:87.06ms
+step:392/1680 train_time:34128ms step_avg:87.06ms
+step:393/1680 train_time:34214ms step_avg:87.06ms
+step:394/1680 train_time:34302ms step_avg:87.06ms
+step:395/1680 train_time:34388ms step_avg:87.06ms
+step:396/1680 train_time:34475ms step_avg:87.06ms
+step:397/1680 train_time:34560ms step_avg:87.05ms
+step:398/1680 train_time:34648ms step_avg:87.05ms
+step:399/1680 train_time:34735ms step_avg:87.06ms
+step:400/1680 train_time:34823ms step_avg:87.06ms
+step:401/1680 train_time:34911ms step_avg:87.06ms
+step:402/1680 train_time:34999ms step_avg:87.06ms
+step:403/1680 train_time:35086ms step_avg:87.06ms
+step:404/1680 train_time:35173ms step_avg:87.06ms
+step:405/1680 train_time:35260ms step_avg:87.06ms
+step:406/1680 train_time:35347ms step_avg:87.06ms
+step:407/1680 train_time:35433ms step_avg:87.06ms
+step:408/1680 train_time:35520ms step_avg:87.06ms
+step:409/1680 train_time:35607ms step_avg:87.06ms
+step:410/1680 train_time:35694ms step_avg:87.06ms
+step:411/1680 train_time:35781ms step_avg:87.06ms
+step:412/1680 train_time:35868ms step_avg:87.06ms
+step:413/1680 train_time:35956ms step_avg:87.06ms
+step:414/1680 train_time:36045ms step_avg:87.07ms
+step:415/1680 train_time:36132ms step_avg:87.06ms
+step:416/1680 train_time:36219ms step_avg:87.06ms
+step:417/1680 train_time:36306ms step_avg:87.06ms
+step:418/1680 train_time:36393ms step_avg:87.06ms
+step:419/1680 train_time:36480ms step_avg:87.06ms
+step:420/1680 train_time:36566ms step_avg:87.06ms
+step:421/1680 train_time:36653ms step_avg:87.06ms
+step:422/1680 train_time:36740ms step_avg:87.06ms
+step:423/1680 train_time:36828ms step_avg:87.06ms
+step:424/1680 train_time:36916ms step_avg:87.06ms
+step:425/1680 train_time:37004ms step_avg:87.07ms
+step:426/1680 train_time:37091ms step_avg:87.07ms
+step:427/1680 train_time:37179ms step_avg:87.07ms
+step:428/1680 train_time:37266ms step_avg:87.07ms
+step:429/1680 train_time:37353ms step_avg:87.07ms
+step:430/1680 train_time:37440ms step_avg:87.07ms
+step:431/1680 train_time:37526ms step_avg:87.07ms
+step:432/1680 train_time:37614ms step_avg:87.07ms
+step:433/1680 train_time:37700ms step_avg:87.07ms
+step:434/1680 train_time:37786ms step_avg:87.07ms
+step:435/1680 train_time:37873ms step_avg:87.06ms
+step:436/1680 train_time:37961ms step_avg:87.07ms
+step:437/1680 train_time:38048ms step_avg:87.07ms
+step:438/1680 train_time:38136ms step_avg:87.07ms
+step:439/1680 train_time:38223ms step_avg:87.07ms
+step:440/1680 train_time:38310ms step_avg:87.07ms
+step:441/1680 train_time:38397ms step_avg:87.07ms
+step:442/1680 train_time:38484ms step_avg:87.07ms
+step:443/1680 train_time:38571ms step_avg:87.07ms
+step:444/1680 train_time:38657ms step_avg:87.07ms
+step:445/1680 train_time:38744ms step_avg:87.07ms
+step:446/1680 train_time:38831ms step_avg:87.07ms
+step:447/1680 train_time:38918ms step_avg:87.06ms
+step:448/1680 train_time:39005ms step_avg:87.06ms
+step:449/1680 train_time:39093ms step_avg:87.07ms
+step:450/1680 train_time:39181ms step_avg:87.07ms
+step:451/1680 train_time:39268ms step_avg:87.07ms
+step:452/1680 train_time:39356ms step_avg:87.07ms
+step:453/1680 train_time:39442ms step_avg:87.07ms
+step:454/1680 train_time:39529ms step_avg:87.07ms
+step:455/1680 train_time:39616ms step_avg:87.07ms
+step:456/1680 train_time:39703ms step_avg:87.07ms
+step:457/1680 train_time:39790ms step_avg:87.07ms
+step:458/1680 train_time:39877ms step_avg:87.07ms
+step:459/1680 train_time:39964ms step_avg:87.07ms
+step:460/1680 train_time:40051ms step_avg:87.07ms
+step:461/1680 train_time:40139ms step_avg:87.07ms
+step:462/1680 train_time:40226ms step_avg:87.07ms
+step:463/1680 train_time:40313ms step_avg:87.07ms
+step:464/1680 train_time:40401ms step_avg:87.07ms
+step:465/1680 train_time:40487ms step_avg:87.07ms
+step:466/1680 train_time:40574ms step_avg:87.07ms
+step:467/1680 train_time:40661ms step_avg:87.07ms
+step:468/1680 train_time:40748ms step_avg:87.07ms
+step:469/1680 train_time:40835ms step_avg:87.07ms
+step:470/1680 train_time:40922ms step_avg:87.07ms
+step:471/1680 train_time:41010ms step_avg:87.07ms
+step:472/1680 train_time:41097ms step_avg:87.07ms
+step:473/1680 train_time:41184ms step_avg:87.07ms
+step:474/1680 train_time:41272ms step_avg:87.07ms
+step:475/1680 train_time:41358ms step_avg:87.07ms
+step:476/1680 train_time:41445ms step_avg:87.07ms
+step:477/1680 train_time:41532ms step_avg:87.07ms
+step:478/1680 train_time:41620ms step_avg:87.07ms
+step:479/1680 train_time:41706ms step_avg:87.07ms
+step:480/1680 train_time:41793ms step_avg:87.07ms
+step:481/1680 train_time:41879ms step_avg:87.07ms
+step:482/1680 train_time:41966ms step_avg:87.07ms
+step:483/1680 train_time:42053ms step_avg:87.07ms
+step:484/1680 train_time:42141ms step_avg:87.07ms
+step:485/1680 train_time:42229ms step_avg:87.07ms
+step:486/1680 train_time:42316ms step_avg:87.07ms
+step:487/1680 train_time:42403ms step_avg:87.07ms
+step:488/1680 train_time:42490ms step_avg:87.07ms
+step:489/1680 train_time:42576ms step_avg:87.07ms
+step:490/1680 train_time:42662ms step_avg:87.07ms
+step:491/1680 train_time:42749ms step_avg:87.07ms
+step:492/1680 train_time:42837ms step_avg:87.07ms
+step:493/1680 train_time:42924ms step_avg:87.07ms
+step:494/1680 train_time:43012ms step_avg:87.07ms
+step:495/1680 train_time:43099ms step_avg:87.07ms
+step:496/1680 train_time:43187ms step_avg:87.07ms
+step:497/1680 train_time:43273ms step_avg:87.07ms
+step:498/1680 train_time:43360ms step_avg:87.07ms
+step:499/1680 train_time:43447ms step_avg:87.07ms
+step:500/1680 train_time:43534ms step_avg:87.07ms
+step:500/1680 val_loss:3.7162 train_time:43623ms step_avg:87.25ms
+step:501/1680 train_time:43642ms step_avg:87.11ms
+step:502/1680 train_time:43712ms step_avg:87.08ms
+step:503/1680 train_time:43801ms step_avg:87.08ms
+step:504/1680 train_time:43888ms step_avg:87.08ms
+step:505/1680 train_time:43975ms step_avg:87.08ms
+step:506/1680 train_time:44061ms step_avg:87.08ms
+step:507/1680 train_time:44148ms step_avg:87.08ms
+step:508/1680 train_time:44234ms step_avg:87.08ms
+step:509/1680 train_time:44320ms step_avg:87.07ms
+step:510/1680 train_time:44407ms step_avg:87.07ms
+step:511/1680 train_time:44493ms step_avg:87.07ms
+step:512/1680 train_time:44581ms step_avg:87.07ms
+step:513/1680 train_time:44669ms step_avg:87.07ms
+step:514/1680 train_time:44758ms step_avg:87.08ms
+step:515/1680 train_time:44846ms step_avg:87.08ms
+step:516/1680 train_time:44933ms step_avg:87.08ms
+step:517/1680 train_time:45020ms step_avg:87.08ms
+step:518/1680 train_time:45106ms step_avg:87.08ms
+step:519/1680 train_time:45193ms step_avg:87.08ms
+step:520/1680 train_time:45280ms step_avg:87.08ms
+step:521/1680 train_time:45366ms step_avg:87.07ms
+step:522/1680 train_time:45452ms step_avg:87.07ms
+step:523/1680 train_time:45539ms step_avg:87.07ms
+step:524/1680 train_time:45626ms step_avg:87.07ms
+step:525/1680 train_time:45713ms step_avg:87.07ms
+step:526/1680 train_time:45801ms step_avg:87.07ms
+step:527/1680 train_time:45888ms step_avg:87.07ms
+step:528/1680 train_time:45975ms step_avg:87.07ms
+step:529/1680 train_time:46063ms step_avg:87.07ms
+step:530/1680 train_time:46149ms step_avg:87.07ms
+step:531/1680 train_time:46236ms step_avg:87.07ms
+step:532/1680 train_time:46322ms step_avg:87.07ms
+step:533/1680 train_time:46409ms step_avg:87.07ms
+step:534/1680 train_time:46496ms step_avg:87.07ms
+step:535/1680 train_time:46583ms step_avg:87.07ms
+step:536/1680 train_time:46671ms step_avg:87.07ms
+step:537/1680 train_time:46758ms step_avg:87.07ms
+step:538/1680 train_time:46846ms step_avg:87.07ms
+step:539/1680 train_time:46933ms step_avg:87.07ms
+step:540/1680 train_time:47020ms step_avg:87.07ms
+step:541/1680 train_time:47107ms step_avg:87.07ms
+step:542/1680 train_time:47194ms step_avg:87.07ms
+step:543/1680 train_time:47281ms step_avg:87.07ms
+step:544/1680 train_time:47368ms step_avg:87.07ms
+step:545/1680 train_time:47455ms step_avg:87.07ms
+step:546/1680 train_time:47542ms step_avg:87.07ms
+step:547/1680 train_time:47630ms step_avg:87.07ms
+step:548/1680 train_time:47717ms step_avg:87.08ms
+step:549/1680 train_time:47806ms step_avg:87.08ms
+step:550/1680 train_time:47894ms step_avg:87.08ms
+step:551/1680 train_time:47983ms step_avg:87.08ms
+step:552/1680 train_time:48071ms step_avg:87.08ms
+step:553/1680 train_time:48158ms step_avg:87.09ms
+step:554/1680 train_time:48247ms step_avg:87.09ms
+step:555/1680 train_time:48335ms step_avg:87.09ms
+step:556/1680 train_time:48423ms step_avg:87.09ms
+step:557/1680 train_time:48511ms step_avg:87.09ms
+step:558/1680 train_time:48599ms step_avg:87.10ms
+step:559/1680 train_time:48688ms step_avg:87.10ms
+step:560/1680 train_time:48777ms step_avg:87.10ms
+step:561/1680 train_time:48866ms step_avg:87.10ms
+step:562/1680 train_time:48954ms step_avg:87.11ms
+step:563/1680 train_time:49041ms step_avg:87.11ms
+step:564/1680 train_time:49129ms step_avg:87.11ms
+step:565/1680 train_time:49217ms step_avg:87.11ms
+step:566/1680 train_time:49305ms step_avg:87.11ms
+step:567/1680 train_time:49393ms step_avg:87.11ms
+step:568/1680 train_time:49480ms step_avg:87.11ms
+step:569/1680 train_time:49568ms step_avg:87.11ms
+step:570/1680 train_time:49657ms step_avg:87.12ms
+step:571/1680 train_time:49746ms step_avg:87.12ms
+step:572/1680 train_time:49835ms step_avg:87.12ms
+step:573/1680 train_time:49923ms step_avg:87.13ms
+step:574/1680 train_time:50011ms step_avg:87.13ms
+step:575/1680 train_time:50099ms step_avg:87.13ms
+step:576/1680 train_time:50187ms step_avg:87.13ms
+step:577/1680 train_time:50276ms step_avg:87.13ms
+step:578/1680 train_time:50364ms step_avg:87.13ms
+step:579/1680 train_time:50452ms step_avg:87.14ms
+step:580/1680 train_time:50540ms step_avg:87.14ms
+step:581/1680 train_time:50628ms step_avg:87.14ms
+step:582/1680 train_time:50717ms step_avg:87.14ms
+step:583/1680 train_time:50806ms step_avg:87.15ms
+step:584/1680 train_time:50896ms step_avg:87.15ms
+step:585/1680 train_time:50983ms step_avg:87.15ms
+step:586/1680 train_time:51071ms step_avg:87.15ms
+step:587/1680 train_time:51159ms step_avg:87.15ms
+step:588/1680 train_time:51247ms step_avg:87.15ms
+step:589/1680 train_time:51336ms step_avg:87.16ms
+step:590/1680 train_time:51423ms step_avg:87.16ms
+step:591/1680 train_time:51511ms step_avg:87.16ms
+step:592/1680 train_time:51599ms step_avg:87.16ms
+step:593/1680 train_time:51687ms step_avg:87.16ms
+step:594/1680 train_time:51776ms step_avg:87.16ms
+step:595/1680 train_time:51864ms step_avg:87.17ms
+step:596/1680 train_time:51953ms step_avg:87.17ms
+step:597/1680 train_time:52042ms step_avg:87.17ms
+step:598/1680 train_time:52129ms step_avg:87.17ms
+step:599/1680 train_time:52218ms step_avg:87.17ms
+step:600/1680 train_time:52306ms step_avg:87.18ms
+step:601/1680 train_time:52394ms step_avg:87.18ms
+step:602/1680 train_time:52482ms step_avg:87.18ms
+step:603/1680 train_time:52571ms step_avg:87.18ms
+step:604/1680 train_time:52659ms step_avg:87.18ms
+step:605/1680 train_time:52748ms step_avg:87.19ms
+step:606/1680 train_time:52836ms step_avg:87.19ms
+step:607/1680 train_time:52925ms step_avg:87.19ms
+step:608/1680 train_time:53013ms step_avg:87.19ms
+step:609/1680 train_time:53101ms step_avg:87.19ms
+step:610/1680 train_time:53188ms step_avg:87.19ms
+step:611/1680 train_time:53277ms step_avg:87.20ms
+step:612/1680 train_time:53366ms step_avg:87.20ms
+step:613/1680 train_time:53454ms step_avg:87.20ms
+step:614/1680 train_time:53542ms step_avg:87.20ms
+step:615/1680 train_time:53631ms step_avg:87.21ms
+step:616/1680 train_time:53719ms step_avg:87.21ms
+step:617/1680 train_time:53807ms step_avg:87.21ms
+step:618/1680 train_time:53895ms step_avg:87.21ms
+step:619/1680 train_time:53983ms step_avg:87.21ms
+step:620/1680 train_time:54071ms step_avg:87.21ms
+step:621/1680 train_time:54159ms step_avg:87.21ms
+step:622/1680 train_time:54247ms step_avg:87.21ms
+step:623/1680 train_time:54336ms step_avg:87.22ms
+step:624/1680 train_time:54424ms step_avg:87.22ms
+step:625/1680 train_time:54512ms step_avg:87.22ms
+step:625/1680 val_loss:3.6162 train_time:54602ms step_avg:87.36ms
+step:626/1680 train_time:54621ms step_avg:87.25ms
+step:627/1680 train_time:54693ms step_avg:87.23ms
+step:628/1680 train_time:54780ms step_avg:87.23ms
+step:629/1680 train_time:54870ms step_avg:87.23ms
+step:630/1680 train_time:54960ms step_avg:87.24ms
+step:631/1680 train_time:55048ms step_avg:87.24ms
+step:632/1680 train_time:55136ms step_avg:87.24ms
+step:633/1680 train_time:55223ms step_avg:87.24ms
+step:634/1680 train_time:55310ms step_avg:87.24ms
+step:635/1680 train_time:55397ms step_avg:87.24ms
+step:636/1680 train_time:55485ms step_avg:87.24ms
+step:637/1680 train_time:55579ms step_avg:87.25ms
+step:638/1680 train_time:55671ms step_avg:87.26ms
+step:639/1680 train_time:55760ms step_avg:87.26ms
+step:640/1680 train_time:55850ms step_avg:87.27ms
+step:641/1680 train_time:55937ms step_avg:87.27ms
+step:642/1680 train_time:56025ms step_avg:87.27ms
+step:643/1680 train_time:56112ms step_avg:87.27ms
+step:644/1680 train_time:56199ms step_avg:87.27ms
+step:645/1680 train_time:56287ms step_avg:87.27ms
+step:646/1680 train_time:56374ms step_avg:87.27ms
+step:647/1680 train_time:56462ms step_avg:87.27ms
+step:648/1680 train_time:56552ms step_avg:87.27ms
+step:649/1680 train_time:56641ms step_avg:87.27ms
+step:650/1680 train_time:56730ms step_avg:87.28ms
+step:651/1680 train_time:56818ms step_avg:87.28ms
+step:652/1680 train_time:56906ms step_avg:87.28ms
+step:653/1680 train_time:56994ms step_avg:87.28ms
+step:654/1680 train_time:57082ms step_avg:87.28ms
+step:655/1680 train_time:57169ms step_avg:87.28ms
+step:656/1680 train_time:57257ms step_avg:87.28ms
+step:657/1680 train_time:57344ms step_avg:87.28ms
+step:658/1680 train_time:57432ms step_avg:87.28ms
+step:659/1680 train_time:57521ms step_avg:87.28ms
+step:660/1680 train_time:57609ms step_avg:87.29ms
+step:661/1680 train_time:57698ms step_avg:87.29ms
+step:662/1680 train_time:57788ms step_avg:87.29ms
+step:663/1680 train_time:57876ms step_avg:87.29ms
+step:664/1680 train_time:57964ms step_avg:87.30ms
+step:665/1680 train_time:58052ms step_avg:87.30ms
+step:666/1680 train_time:58140ms step_avg:87.30ms
+step:667/1680 train_time:58229ms step_avg:87.30ms
+step:668/1680 train_time:58316ms step_avg:87.30ms
+step:669/1680 train_time:58404ms step_avg:87.30ms
+step:670/1680 train_time:58492ms step_avg:87.30ms
+step:671/1680 train_time:58580ms step_avg:87.30ms
+step:672/1680 train_time:58669ms step_avg:87.31ms
+step:673/1680 train_time:58758ms step_avg:87.31ms
+step:674/1680 train_time:58846ms step_avg:87.31ms
+step:675/1680 train_time:58934ms step_avg:87.31ms
+step:676/1680 train_time:59022ms step_avg:87.31ms
+step:677/1680 train_time:59111ms step_avg:87.31ms
+step:678/1680 train_time:59199ms step_avg:87.31ms
+step:679/1680 train_time:59286ms step_avg:87.31ms
+step:680/1680 train_time:59375ms step_avg:87.32ms
+step:681/1680 train_time:59462ms step_avg:87.32ms
+step:682/1680 train_time:59551ms step_avg:87.32ms
+step:683/1680 train_time:59640ms step_avg:87.32ms
+step:684/1680 train_time:59729ms step_avg:87.32ms
+step:685/1680 train_time:59817ms step_avg:87.32ms
+step:686/1680 train_time:59905ms step_avg:87.33ms
+step:687/1680 train_time:59993ms step_avg:87.33ms
+step:688/1680 train_time:60081ms step_avg:87.33ms
+step:689/1680 train_time:60170ms step_avg:87.33ms
+step:690/1680 train_time:60258ms step_avg:87.33ms
+step:691/1680 train_time:60346ms step_avg:87.33ms
+step:692/1680 train_time:60434ms step_avg:87.33ms
+step:693/1680 train_time:60522ms step_avg:87.33ms
+step:694/1680 train_time:60610ms step_avg:87.33ms
+step:695/1680 train_time:60699ms step_avg:87.34ms
+step:696/1680 train_time:60788ms step_avg:87.34ms
+step:697/1680 train_time:60876ms step_avg:87.34ms
+step:698/1680 train_time:60965ms step_avg:87.34ms
+step:699/1680 train_time:61052ms step_avg:87.34ms
+step:700/1680 train_time:61140ms step_avg:87.34ms
+step:701/1680 train_time:61228ms step_avg:87.34ms
+step:702/1680 train_time:61316ms step_avg:87.34ms
+step:703/1680 train_time:61403ms step_avg:87.34ms
+step:704/1680 train_time:61492ms step_avg:87.35ms
+step:705/1680 train_time:61581ms step_avg:87.35ms
+step:706/1680 train_time:61669ms step_avg:87.35ms
+step:707/1680 train_time:61758ms step_avg:87.35ms
+step:708/1680 train_time:61846ms step_avg:87.35ms
+step:709/1680 train_time:61935ms step_avg:87.36ms
+step:710/1680 train_time:62022ms step_avg:87.36ms
+step:711/1680 train_time:62111ms step_avg:87.36ms
+step:712/1680 train_time:62198ms step_avg:87.36ms
+step:713/1680 train_time:62286ms step_avg:87.36ms
+step:714/1680 train_time:62374ms step_avg:87.36ms
+step:715/1680 train_time:62462ms step_avg:87.36ms
+step:716/1680 train_time:62551ms step_avg:87.36ms
+step:717/1680 train_time:62639ms step_avg:87.36ms
+step:718/1680 train_time:62728ms step_avg:87.37ms
+step:719/1680 train_time:62816ms step_avg:87.37ms
+step:720/1680 train_time:62904ms step_avg:87.37ms
+step:721/1680 train_time:62993ms step_avg:87.37ms
+step:722/1680 train_time:63081ms step_avg:87.37ms
+step:723/1680 train_time:63169ms step_avg:87.37ms
+step:724/1680 train_time:63257ms step_avg:87.37ms
+step:725/1680 train_time:63345ms step_avg:87.37ms
+step:726/1680 train_time:63433ms step_avg:87.37ms
+step:727/1680 train_time:63522ms step_avg:87.37ms
+step:728/1680 train_time:63609ms step_avg:87.38ms
+step:729/1680 train_time:63698ms step_avg:87.38ms
+step:730/1680 train_time:63787ms step_avg:87.38ms
+step:731/1680 train_time:63875ms step_avg:87.38ms
+step:732/1680 train_time:63964ms step_avg:87.38ms
+step:733/1680 train_time:64052ms step_avg:87.38ms
+step:734/1680 train_time:64140ms step_avg:87.38ms
+step:735/1680 train_time:64228ms step_avg:87.39ms
+step:736/1680 train_time:64316ms step_avg:87.39ms
+step:737/1680 train_time:64404ms step_avg:87.39ms
+step:738/1680 train_time:64492ms step_avg:87.39ms
+step:739/1680 train_time:64579ms step_avg:87.39ms
+step:740/1680 train_time:64668ms step_avg:87.39ms
+step:741/1680 train_time:64756ms step_avg:87.39ms
+step:742/1680 train_time:64845ms step_avg:87.39ms
+step:743/1680 train_time:64934ms step_avg:87.39ms
+step:744/1680 train_time:65021ms step_avg:87.39ms
+step:745/1680 train_time:65109ms step_avg:87.39ms
+step:746/1680 train_time:65198ms step_avg:87.40ms
+step:747/1680 train_time:65286ms step_avg:87.40ms
+step:748/1680 train_time:65375ms step_avg:87.40ms
+step:749/1680 train_time:65463ms step_avg:87.40ms
+step:750/1680 train_time:65551ms step_avg:87.40ms
+step:750/1680 val_loss:3.5639 train_time:65641ms step_avg:87.52ms
+step:751/1680 train_time:65659ms step_avg:87.43ms
+step:752/1680 train_time:65731ms step_avg:87.41ms
+step:753/1680 train_time:65824ms step_avg:87.42ms
+step:754/1680 train_time:65912ms step_avg:87.42ms
+step:755/1680 train_time:66000ms step_avg:87.42ms
+step:756/1680 train_time:66087ms step_avg:87.42ms
+step:757/1680 train_time:66174ms step_avg:87.42ms
+step:758/1680 train_time:66262ms step_avg:87.42ms
+step:759/1680 train_time:66349ms step_avg:87.42ms
+step:760/1680 train_time:66438ms step_avg:87.42ms
+step:761/1680 train_time:66526ms step_avg:87.42ms
+step:762/1680 train_time:66614ms step_avg:87.42ms
+step:763/1680 train_time:66705ms step_avg:87.42ms
+step:764/1680 train_time:66794ms step_avg:87.43ms
+step:765/1680 train_time:66883ms step_avg:87.43ms
+step:766/1680 train_time:66972ms step_avg:87.43ms
+step:767/1680 train_time:67060ms step_avg:87.43ms
+step:768/1680 train_time:67147ms step_avg:87.43ms
+step:769/1680 train_time:67235ms step_avg:87.43ms
+step:770/1680 train_time:67322ms step_avg:87.43ms
+step:771/1680 train_time:67409ms step_avg:87.43ms
+step:772/1680 train_time:67497ms step_avg:87.43ms
+step:773/1680 train_time:67586ms step_avg:87.43ms
+step:774/1680 train_time:67675ms step_avg:87.44ms
+step:775/1680 train_time:67764ms step_avg:87.44ms
+step:776/1680 train_time:67854ms step_avg:87.44ms
+step:777/1680 train_time:67943ms step_avg:87.44ms
+step:778/1680 train_time:68031ms step_avg:87.44ms
+step:779/1680 train_time:68119ms step_avg:87.44ms
+step:780/1680 train_time:68206ms step_avg:87.44ms
+step:781/1680 train_time:68294ms step_avg:87.44ms
+step:782/1680 train_time:68381ms step_avg:87.44ms
+step:783/1680 train_time:68469ms step_avg:87.44ms
+step:784/1680 train_time:68556ms step_avg:87.44ms
+step:785/1680 train_time:68645ms step_avg:87.45ms
+step:786/1680 train_time:68734ms step_avg:87.45ms
+step:787/1680 train_time:68823ms step_avg:87.45ms
+step:788/1680 train_time:68911ms step_avg:87.45ms
+step:789/1680 train_time:69000ms step_avg:87.45ms
+step:790/1680 train_time:69088ms step_avg:87.45ms
+step:791/1680 train_time:69176ms step_avg:87.45ms
+step:792/1680 train_time:69264ms step_avg:87.45ms
+step:793/1680 train_time:69352ms step_avg:87.45ms
+step:794/1680 train_time:69440ms step_avg:87.46ms
+step:795/1680 train_time:69527ms step_avg:87.46ms
+step:796/1680 train_time:69615ms step_avg:87.46ms
+step:797/1680 train_time:69705ms step_avg:87.46ms
+step:798/1680 train_time:69794ms step_avg:87.46ms
+step:799/1680 train_time:69882ms step_avg:87.46ms
+step:800/1680 train_time:69970ms step_avg:87.46ms
+step:801/1680 train_time:70058ms step_avg:87.46ms
+step:802/1680 train_time:70146ms step_avg:87.46ms
+step:803/1680 train_time:70234ms step_avg:87.46ms
+step:804/1680 train_time:70322ms step_avg:87.47ms
+step:805/1680 train_time:70410ms step_avg:87.47ms
+step:806/1680 train_time:70498ms step_avg:87.47ms
+step:807/1680 train_time:70586ms step_avg:87.47ms
+step:808/1680 train_time:70674ms step_avg:87.47ms
+step:809/1680 train_time:70763ms step_avg:87.47ms
+step:810/1680 train_time:70852ms step_avg:87.47ms
+step:811/1680 train_time:70940ms step_avg:87.47ms
+step:812/1680 train_time:71029ms step_avg:87.47ms
+step:813/1680 train_time:71118ms step_avg:87.48ms
+step:814/1680 train_time:71206ms step_avg:87.48ms
+step:815/1680 train_time:71294ms step_avg:87.48ms
+step:816/1680 train_time:71382ms step_avg:87.48ms
+step:817/1680 train_time:71471ms step_avg:87.48ms
+step:818/1680 train_time:71558ms step_avg:87.48ms
+step:819/1680 train_time:71647ms step_avg:87.48ms
+step:820/1680 train_time:71735ms step_avg:87.48ms
+step:821/1680 train_time:71824ms step_avg:87.48ms
+step:822/1680 train_time:71912ms step_avg:87.48ms
+step:823/1680 train_time:72001ms step_avg:87.49ms
+step:824/1680 train_time:72089ms step_avg:87.49ms
+step:825/1680 train_time:72177ms step_avg:87.49ms
+step:826/1680 train_time:72264ms step_avg:87.49ms
+step:827/1680 train_time:72352ms step_avg:87.49ms
+step:828/1680 train_time:72440ms step_avg:87.49ms
+step:829/1680 train_time:72528ms step_avg:87.49ms
+step:830/1680 train_time:72616ms step_avg:87.49ms
+step:831/1680 train_time:72705ms step_avg:87.49ms
+step:832/1680 train_time:72793ms step_avg:87.49ms
+step:833/1680 train_time:72881ms step_avg:87.49ms
+step:834/1680 train_time:72970ms step_avg:87.49ms
+step:835/1680 train_time:73058ms step_avg:87.49ms
+step:836/1680 train_time:73146ms step_avg:87.50ms
+step:837/1680 train_time:73234ms step_avg:87.50ms
+step:838/1680 train_time:73322ms step_avg:87.50ms
+step:839/1680 train_time:73409ms step_avg:87.50ms
+step:840/1680 train_time:73498ms step_avg:87.50ms
+step:841/1680 train_time:73585ms step_avg:87.50ms
+step:842/1680 train_time:73673ms step_avg:87.50ms
+step:843/1680 train_time:73762ms step_avg:87.50ms
+step:844/1680 train_time:73850ms step_avg:87.50ms
+step:845/1680 train_time:73938ms step_avg:87.50ms
+step:846/1680 train_time:74026ms step_avg:87.50ms
+step:847/1680 train_time:74115ms step_avg:87.50ms
+step:848/1680 train_time:74204ms step_avg:87.50ms
+step:849/1680 train_time:74291ms step_avg:87.50ms
+step:850/1680 train_time:74379ms step_avg:87.50ms
+step:851/1680 train_time:74467ms step_avg:87.50ms
+step:852/1680 train_time:74554ms step_avg:87.51ms
+step:853/1680 train_time:74643ms step_avg:87.51ms
+step:854/1680 train_time:74732ms step_avg:87.51ms
+step:855/1680 train_time:74821ms step_avg:87.51ms
+step:856/1680 train_time:74909ms step_avg:87.51ms
+step:857/1680 train_time:74998ms step_avg:87.51ms
+step:858/1680 train_time:75086ms step_avg:87.51ms
+step:859/1680 train_time:75175ms step_avg:87.51ms
+step:860/1680 train_time:75263ms step_avg:87.52ms
+step:861/1680 train_time:75351ms step_avg:87.52ms
+step:862/1680 train_time:75438ms step_avg:87.52ms
+step:863/1680 train_time:75526ms step_avg:87.52ms
+step:864/1680 train_time:75614ms step_avg:87.52ms
+step:865/1680 train_time:75703ms step_avg:87.52ms
+step:866/1680 train_time:75791ms step_avg:87.52ms
+step:867/1680 train_time:75880ms step_avg:87.52ms
+step:868/1680 train_time:75968ms step_avg:87.52ms
+step:869/1680 train_time:76056ms step_avg:87.52ms
+step:870/1680 train_time:76144ms step_avg:87.52ms
+step:871/1680 train_time:76232ms step_avg:87.52ms
+step:872/1680 train_time:76320ms step_avg:87.52ms
+step:873/1680 train_time:76407ms step_avg:87.52ms
+step:874/1680 train_time:76495ms step_avg:87.52ms
+step:875/1680 train_time:76583ms step_avg:87.52ms
+step:875/1680 val_loss:3.5171 train_time:76673ms step_avg:87.63ms
+step:876/1680 train_time:76692ms step_avg:87.55ms
+step:877/1680 train_time:76765ms step_avg:87.53ms
+step:878/1680 train_time:76857ms step_avg:87.54ms
+step:879/1680 train_time:76946ms step_avg:87.54ms
+step:880/1680 train_time:77033ms step_avg:87.54ms
+step:881/1680 train_time:77120ms step_avg:87.54ms
+step:882/1680 train_time:77207ms step_avg:87.54ms
+step:883/1680 train_time:77294ms step_avg:87.54ms
+step:884/1680 train_time:77380ms step_avg:87.53ms
+step:885/1680 train_time:77469ms step_avg:87.54ms
+step:886/1680 train_time:77556ms step_avg:87.54ms
+step:887/1680 train_time:77646ms step_avg:87.54ms
+step:888/1680 train_time:77736ms step_avg:87.54ms
+step:889/1680 train_time:77827ms step_avg:87.54ms
+step:890/1680 train_time:77916ms step_avg:87.55ms
+step:891/1680 train_time:78005ms step_avg:87.55ms
+step:892/1680 train_time:78094ms step_avg:87.55ms
+step:893/1680 train_time:78181ms step_avg:87.55ms
+step:894/1680 train_time:78269ms step_avg:87.55ms
+step:895/1680 train_time:78356ms step_avg:87.55ms
+step:896/1680 train_time:78443ms step_avg:87.55ms
+step:897/1680 train_time:78532ms step_avg:87.55ms
+step:898/1680 train_time:78620ms step_avg:87.55ms
+step:899/1680 train_time:78708ms step_avg:87.55ms
+step:900/1680 train_time:78798ms step_avg:87.55ms
+step:901/1680 train_time:78888ms step_avg:87.56ms
+step:902/1680 train_time:78977ms step_avg:87.56ms
+step:903/1680 train_time:79066ms step_avg:87.56ms
+step:904/1680 train_time:79154ms step_avg:87.56ms
+step:905/1680 train_time:79242ms step_avg:87.56ms
+step:906/1680 train_time:79329ms step_avg:87.56ms
+step:907/1680 train_time:79416ms step_avg:87.56ms
+step:908/1680 train_time:79504ms step_avg:87.56ms
+step:909/1680 train_time:79592ms step_avg:87.56ms
+step:910/1680 train_time:79681ms step_avg:87.56ms
+step:911/1680 train_time:79770ms step_avg:87.56ms
+step:912/1680 train_time:79859ms step_avg:87.57ms
+step:913/1680 train_time:79948ms step_avg:87.57ms
+step:914/1680 train_time:80037ms step_avg:87.57ms
+step:915/1680 train_time:80125ms step_avg:87.57ms
+step:916/1680 train_time:80213ms step_avg:87.57ms
+step:917/1680 train_time:80301ms step_avg:87.57ms
+step:918/1680 train_time:80389ms step_avg:87.57ms
+step:919/1680 train_time:80477ms step_avg:87.57ms
+step:920/1680 train_time:80565ms step_avg:87.57ms
+step:921/1680 train_time:80654ms step_avg:87.57ms
+step:922/1680 train_time:80742ms step_avg:87.57ms
+step:923/1680 train_time:80830ms step_avg:87.57ms
+step:924/1680 train_time:80919ms step_avg:87.57ms
+step:925/1680 train_time:81007ms step_avg:87.58ms
+step:926/1680 train_time:81096ms step_avg:87.58ms
+step:927/1680 train_time:81184ms step_avg:87.58ms
+step:928/1680 train_time:81273ms step_avg:87.58ms
+step:929/1680 train_time:81360ms step_avg:87.58ms
+step:930/1680 train_time:81449ms step_avg:87.58ms
+step:931/1680 train_time:81537ms step_avg:87.58ms
+step:932/1680 train_time:81625ms step_avg:87.58ms
+step:933/1680 train_time:81713ms step_avg:87.58ms
+step:934/1680 train_time:81801ms step_avg:87.58ms
+step:935/1680 train_time:81890ms step_avg:87.58ms
+step:936/1680 train_time:81979ms step_avg:87.58ms
+step:937/1680 train_time:82068ms step_avg:87.59ms
+step:938/1680 train_time:82157ms step_avg:87.59ms
+step:939/1680 train_time:82245ms step_avg:87.59ms
+step:940/1680 train_time:82333ms step_avg:87.59ms
+step:941/1680 train_time:82420ms step_avg:87.59ms
+step:942/1680 train_time:82508ms step_avg:87.59ms
+step:943/1680 train_time:82596ms step_avg:87.59ms
+step:944/1680 train_time:82684ms step_avg:87.59ms
+step:945/1680 train_time:82773ms step_avg:87.59ms
+step:946/1680 train_time:82861ms step_avg:87.59ms
+step:947/1680 train_time:82950ms step_avg:87.59ms
+step:948/1680 train_time:83038ms step_avg:87.59ms
+step:949/1680 train_time:83127ms step_avg:87.59ms
+step:950/1680 train_time:83215ms step_avg:87.59ms
+step:951/1680 train_time:83303ms step_avg:87.59ms
+step:952/1680 train_time:83391ms step_avg:87.60ms
+step:953/1680 train_time:83479ms step_avg:87.60ms
+step:954/1680 train_time:83567ms step_avg:87.60ms
+step:955/1680 train_time:83656ms step_avg:87.60ms
+step:956/1680 train_time:83745ms step_avg:87.60ms
+step:957/1680 train_time:83833ms step_avg:87.60ms
+step:958/1680 train_time:83921ms step_avg:87.60ms
+step:959/1680 train_time:84010ms step_avg:87.60ms
+step:960/1680 train_time:84099ms step_avg:87.60ms
+step:961/1680 train_time:84189ms step_avg:87.61ms
+step:962/1680 train_time:84277ms step_avg:87.61ms
+step:963/1680 train_time:84365ms step_avg:87.61ms
+step:964/1680 train_time:84454ms step_avg:87.61ms
+step:965/1680 train_time:84541ms step_avg:87.61ms
+step:966/1680 train_time:84630ms step_avg:87.61ms
+step:967/1680 train_time:84717ms step_avg:87.61ms
+step:968/1680 train_time:84806ms step_avg:87.61ms
+step:969/1680 train_time:84894ms step_avg:87.61ms
+step:970/1680 train_time:84982ms step_avg:87.61ms
+step:971/1680 train_time:85071ms step_avg:87.61ms
+step:972/1680 train_time:85160ms step_avg:87.61ms
+step:973/1680 train_time:85248ms step_avg:87.61ms
+step:974/1680 train_time:85336ms step_avg:87.61ms
+step:975/1680 train_time:85424ms step_avg:87.61ms
+step:976/1680 train_time:85512ms step_avg:87.61ms
+step:977/1680 train_time:85600ms step_avg:87.61ms
+step:978/1680 train_time:85688ms step_avg:87.62ms
+step:979/1680 train_time:85777ms step_avg:87.62ms
+step:980/1680 train_time:85866ms step_avg:87.62ms
+step:981/1680 train_time:85954ms step_avg:87.62ms
+step:982/1680 train_time:86042ms step_avg:87.62ms
+step:983/1680 train_time:86130ms step_avg:87.62ms
+step:984/1680 train_time:86218ms step_avg:87.62ms
+step:985/1680 train_time:86306ms step_avg:87.62ms
+step:986/1680 train_time:86395ms step_avg:87.62ms
+step:987/1680 train_time:86483ms step_avg:87.62ms
+step:988/1680 train_time:86571ms step_avg:87.62ms
+step:989/1680 train_time:86659ms step_avg:87.62ms
+step:990/1680 train_time:86747ms step_avg:87.62ms
+step:991/1680 train_time:86835ms step_avg:87.62ms
+step:992/1680 train_time:86923ms step_avg:87.62ms
+step:993/1680 train_time:87011ms step_avg:87.62ms
+step:994/1680 train_time:87099ms step_avg:87.62ms
+step:995/1680 train_time:87187ms step_avg:87.63ms
+step:996/1680 train_time:87275ms step_avg:87.63ms
+step:997/1680 train_time:87364ms step_avg:87.63ms
+step:998/1680 train_time:87452ms step_avg:87.63ms
+step:999/1680 train_time:87540ms step_avg:87.63ms
+step:1000/1680 train_time:87629ms step_avg:87.63ms
+step:1000/1680 val_loss:3.4680 train_time:87718ms step_avg:87.72ms
+step:1001/1680 train_time:87738ms step_avg:87.65ms
+step:1002/1680 train_time:87813ms step_avg:87.64ms
+step:1003/1680 train_time:87903ms step_avg:87.64ms
+step:1004/1680 train_time:87992ms step_avg:87.64ms
+step:1005/1680 train_time:88079ms step_avg:87.64ms
+step:1006/1680 train_time:88166ms step_avg:87.64ms
+step:1007/1680 train_time:88253ms step_avg:87.64ms
+step:1008/1680 train_time:88341ms step_avg:87.64ms
+step:1009/1680 train_time:88429ms step_avg:87.64ms
+step:1010/1680 train_time:88516ms step_avg:87.64ms
+step:1011/1680 train_time:88604ms step_avg:87.64ms
+step:1012/1680 train_time:88693ms step_avg:87.64ms
+step:1013/1680 train_time:88783ms step_avg:87.64ms
+step:1014/1680 train_time:88872ms step_avg:87.65ms
+step:1015/1680 train_time:88961ms step_avg:87.65ms
+step:1016/1680 train_time:89049ms step_avg:87.65ms
+step:1017/1680 train_time:89137ms step_avg:87.65ms
+step:1018/1680 train_time:89225ms step_avg:87.65ms
+step:1019/1680 train_time:89312ms step_avg:87.65ms
+step:1020/1680 train_time:89399ms step_avg:87.65ms
+step:1021/1680 train_time:89487ms step_avg:87.65ms
+step:1022/1680 train_time:89575ms step_avg:87.65ms
+step:1023/1680 train_time:89663ms step_avg:87.65ms
+step:1024/1680 train_time:89752ms step_avg:87.65ms
+step:1025/1680 train_time:89841ms step_avg:87.65ms
+step:1026/1680 train_time:89929ms step_avg:87.65ms
+step:1027/1680 train_time:90019ms step_avg:87.65ms
+step:1028/1680 train_time:90107ms step_avg:87.65ms
+step:1029/1680 train_time:90194ms step_avg:87.65ms
+step:1030/1680 train_time:90283ms step_avg:87.65ms
+step:1031/1680 train_time:90370ms step_avg:87.65ms
+step:1032/1680 train_time:90457ms step_avg:87.65ms
+step:1033/1680 train_time:90546ms step_avg:87.65ms
+step:1034/1680 train_time:90634ms step_avg:87.65ms
+step:1035/1680 train_time:90723ms step_avg:87.65ms
+step:1036/1680 train_time:90811ms step_avg:87.66ms
+step:1037/1680 train_time:90901ms step_avg:87.66ms
+step:1038/1680 train_time:90989ms step_avg:87.66ms
+step:1039/1680 train_time:91078ms step_avg:87.66ms
+step:1040/1680 train_time:91166ms step_avg:87.66ms
+step:1041/1680 train_time:91254ms step_avg:87.66ms
+step:1042/1680 train_time:91342ms step_avg:87.66ms
+step:1043/1680 train_time:91430ms step_avg:87.66ms
+step:1044/1680 train_time:91518ms step_avg:87.66ms
+step:1045/1680 train_time:91607ms step_avg:87.66ms
+step:1046/1680 train_time:91696ms step_avg:87.66ms
+step:1047/1680 train_time:91785ms step_avg:87.66ms
+step:1048/1680 train_time:91873ms step_avg:87.67ms
+step:1049/1680 train_time:91962ms step_avg:87.67ms
+step:1050/1680 train_time:92050ms step_avg:87.67ms
+step:1051/1680 train_time:92138ms step_avg:87.67ms
+step:1052/1680 train_time:92227ms step_avg:87.67ms
+step:1053/1680 train_time:92315ms step_avg:87.67ms
+step:1054/1680 train_time:92403ms step_avg:87.67ms
+step:1055/1680 train_time:92491ms step_avg:87.67ms
+step:1056/1680 train_time:92579ms step_avg:87.67ms
+step:1057/1680 train_time:92668ms step_avg:87.67ms
+step:1058/1680 train_time:92756ms step_avg:87.67ms
+step:1059/1680 train_time:92846ms step_avg:87.67ms
+step:1060/1680 train_time:92935ms step_avg:87.67ms
+step:1061/1680 train_time:93023ms step_avg:87.68ms
+step:1062/1680 train_time:93111ms step_avg:87.67ms
+step:1063/1680 train_time:93200ms step_avg:87.68ms
+step:1064/1680 train_time:93288ms step_avg:87.68ms
+step:1065/1680 train_time:93377ms step_avg:87.68ms
+step:1066/1680 train_time:93465ms step_avg:87.68ms
+step:1067/1680 train_time:93553ms step_avg:87.68ms
+step:1068/1680 train_time:93641ms step_avg:87.68ms
+step:1069/1680 train_time:93730ms step_avg:87.68ms
+step:1070/1680 train_time:93819ms step_avg:87.68ms
+step:1071/1680 train_time:93908ms step_avg:87.68ms
+step:1072/1680 train_time:93997ms step_avg:87.68ms
+step:1073/1680 train_time:94086ms step_avg:87.68ms
+step:1074/1680 train_time:94173ms step_avg:87.68ms
+step:1075/1680 train_time:94262ms step_avg:87.69ms
+step:1076/1680 train_time:94350ms step_avg:87.69ms
+step:1077/1680 train_time:94438ms step_avg:87.69ms
+step:1078/1680 train_time:94527ms step_avg:87.69ms
+step:1079/1680 train_time:94615ms step_avg:87.69ms
+step:1080/1680 train_time:94704ms step_avg:87.69ms
+step:1081/1680 train_time:94792ms step_avg:87.69ms
+step:1082/1680 train_time:94882ms step_avg:87.69ms
+step:1083/1680 train_time:94970ms step_avg:87.69ms
+step:1084/1680 train_time:95058ms step_avg:87.69ms
+step:1085/1680 train_time:95147ms step_avg:87.69ms
+step:1086/1680 train_time:95235ms step_avg:87.69ms
+step:1087/1680 train_time:95323ms step_avg:87.69ms
+step:1088/1680 train_time:95411ms step_avg:87.69ms
+step:1089/1680 train_time:95500ms step_avg:87.70ms
+step:1090/1680 train_time:95588ms step_avg:87.70ms
+step:1091/1680 train_time:95676ms step_avg:87.70ms
+step:1092/1680 train_time:95764ms step_avg:87.70ms
+step:1093/1680 train_time:95852ms step_avg:87.70ms
+step:1094/1680 train_time:95941ms step_avg:87.70ms
+step:1095/1680 train_time:96030ms step_avg:87.70ms
+step:1096/1680 train_time:96118ms step_avg:87.70ms
+step:1097/1680 train_time:96208ms step_avg:87.70ms
+step:1098/1680 train_time:96297ms step_avg:87.70ms
+step:1099/1680 train_time:96386ms step_avg:87.70ms
+step:1100/1680 train_time:96474ms step_avg:87.70ms
+step:1101/1680 train_time:96563ms step_avg:87.70ms
+step:1102/1680 train_time:96651ms step_avg:87.71ms
+step:1103/1680 train_time:96741ms step_avg:87.71ms
+step:1104/1680 train_time:96830ms step_avg:87.71ms
+step:1105/1680 train_time:96919ms step_avg:87.71ms
+step:1106/1680 train_time:97008ms step_avg:87.71ms
+step:1107/1680 train_time:97097ms step_avg:87.71ms
+step:1108/1680 train_time:97187ms step_avg:87.71ms
+step:1109/1680 train_time:97275ms step_avg:87.71ms
+step:1110/1680 train_time:97364ms step_avg:87.72ms
+step:1111/1680 train_time:97452ms step_avg:87.72ms
+step:1112/1680 train_time:97542ms step_avg:87.72ms
+step:1113/1680 train_time:97631ms step_avg:87.72ms
+step:1114/1680 train_time:97719ms step_avg:87.72ms
+step:1115/1680 train_time:97809ms step_avg:87.72ms
+step:1116/1680 train_time:97898ms step_avg:87.72ms
+step:1117/1680 train_time:97988ms step_avg:87.72ms
+step:1118/1680 train_time:98078ms step_avg:87.73ms
+step:1119/1680 train_time:98167ms step_avg:87.73ms
+step:1120/1680 train_time:98256ms step_avg:87.73ms
+step:1121/1680 train_time:98346ms step_avg:87.73ms
+step:1122/1680 train_time:98435ms step_avg:87.73ms
+step:1123/1680 train_time:98525ms step_avg:87.73ms
+step:1124/1680 train_time:98613ms step_avg:87.73ms
+step:1125/1680 train_time:98702ms step_avg:87.73ms
+step:1125/1680 val_loss:3.4147 train_time:98792ms step_avg:87.82ms
+step:1126/1680 train_time:98812ms step_avg:87.75ms
+step:1127/1680 train_time:98882ms step_avg:87.74ms
+step:1128/1680 train_time:98971ms step_avg:87.74ms
+step:1129/1680 train_time:99063ms step_avg:87.74ms
+step:1130/1680 train_time:99151ms step_avg:87.74ms
+step:1131/1680 train_time:99238ms step_avg:87.74ms
+step:1132/1680 train_time:99326ms step_avg:87.74ms
+step:1133/1680 train_time:99414ms step_avg:87.74ms
+step:1134/1680 train_time:99502ms step_avg:87.74ms
+step:1135/1680 train_time:99589ms step_avg:87.74ms
+step:1136/1680 train_time:99679ms step_avg:87.75ms
+step:1137/1680 train_time:99772ms step_avg:87.75ms
+step:1138/1680 train_time:99863ms step_avg:87.75ms
+step:1139/1680 train_time:99952ms step_avg:87.75ms
+step:1140/1680 train_time:100043ms step_avg:87.76ms
+step:1141/1680 train_time:100132ms step_avg:87.76ms
+step:1142/1680 train_time:100220ms step_avg:87.76ms
+step:1143/1680 train_time:100308ms step_avg:87.76ms
+step:1144/1680 train_time:100397ms step_avg:87.76ms
+step:1145/1680 train_time:100484ms step_avg:87.76ms
+step:1146/1680 train_time:100573ms step_avg:87.76ms
+step:1147/1680 train_time:100662ms step_avg:87.76ms
+step:1148/1680 train_time:100752ms step_avg:87.76ms
+step:1149/1680 train_time:100841ms step_avg:87.76ms
+step:1150/1680 train_time:100931ms step_avg:87.77ms
+step:1151/1680 train_time:101021ms step_avg:87.77ms
+step:1152/1680 train_time:101109ms step_avg:87.77ms
+step:1153/1680 train_time:101198ms step_avg:87.77ms
+step:1154/1680 train_time:101286ms step_avg:87.77ms
+step:1155/1680 train_time:101375ms step_avg:87.77ms
+step:1156/1680 train_time:101463ms step_avg:87.77ms
+step:1157/1680 train_time:101551ms step_avg:87.77ms
+step:1158/1680 train_time:101640ms step_avg:87.77ms
+step:1159/1680 train_time:101730ms step_avg:87.77ms
+step:1160/1680 train_time:101820ms step_avg:87.78ms
+step:1161/1680 train_time:101909ms step_avg:87.78ms
+step:1162/1680 train_time:101998ms step_avg:87.78ms
+step:1163/1680 train_time:102088ms step_avg:87.78ms
+step:1164/1680 train_time:102176ms step_avg:87.78ms
+step:1165/1680 train_time:102265ms step_avg:87.78ms
+step:1166/1680 train_time:102354ms step_avg:87.78ms
+step:1167/1680 train_time:102442ms step_avg:87.78ms
+step:1168/1680 train_time:102531ms step_avg:87.78ms
+step:1169/1680 train_time:102620ms step_avg:87.78ms
+step:1170/1680 train_time:102709ms step_avg:87.79ms
+step:1171/1680 train_time:102798ms step_avg:87.79ms
+step:1172/1680 train_time:102887ms step_avg:87.79ms
+step:1173/1680 train_time:102978ms step_avg:87.79ms
+step:1174/1680 train_time:103067ms step_avg:87.79ms
+step:1175/1680 train_time:103156ms step_avg:87.79ms
+step:1176/1680 train_time:103245ms step_avg:87.79ms
+step:1177/1680 train_time:103334ms step_avg:87.79ms
+step:1178/1680 train_time:103423ms step_avg:87.80ms
+step:1179/1680 train_time:103511ms step_avg:87.80ms
+step:1180/1680 train_time:103600ms step_avg:87.80ms
+step:1181/1680 train_time:103689ms step_avg:87.80ms
+step:1182/1680 train_time:103778ms step_avg:87.80ms
+step:1183/1680 train_time:103867ms step_avg:87.80ms
+step:1184/1680 train_time:103956ms step_avg:87.80ms
+step:1185/1680 train_time:104045ms step_avg:87.80ms
+step:1186/1680 train_time:104134ms step_avg:87.80ms
+step:1187/1680 train_time:104223ms step_avg:87.80ms
+step:1188/1680 train_time:104312ms step_avg:87.80ms
+step:1189/1680 train_time:104400ms step_avg:87.81ms
+step:1190/1680 train_time:104489ms step_avg:87.81ms
+step:1191/1680 train_time:104578ms step_avg:87.81ms
+step:1192/1680 train_time:104666ms step_avg:87.81ms
+step:1193/1680 train_time:104755ms step_avg:87.81ms
+step:1194/1680 train_time:104844ms step_avg:87.81ms
+step:1195/1680 train_time:104933ms step_avg:87.81ms
+step:1196/1680 train_time:105023ms step_avg:87.81ms
+step:1197/1680 train_time:105111ms step_avg:87.81ms
+step:1198/1680 train_time:105200ms step_avg:87.81ms
+step:1199/1680 train_time:105288ms step_avg:87.81ms
+step:1200/1680 train_time:105377ms step_avg:87.81ms
+step:1201/1680 train_time:105466ms step_avg:87.81ms
+step:1202/1680 train_time:105555ms step_avg:87.82ms
+step:1203/1680 train_time:105643ms step_avg:87.82ms
+step:1204/1680 train_time:105732ms step_avg:87.82ms
+step:1205/1680 train_time:105821ms step_avg:87.82ms
+step:1206/1680 train_time:105910ms step_avg:87.82ms
+step:1207/1680 train_time:105999ms step_avg:87.82ms
+step:1208/1680 train_time:106088ms step_avg:87.82ms
+step:1209/1680 train_time:106177ms step_avg:87.82ms
+step:1210/1680 train_time:106266ms step_avg:87.82ms
+step:1211/1680 train_time:106355ms step_avg:87.82ms
+step:1212/1680 train_time:106444ms step_avg:87.83ms
+step:1213/1680 train_time:106533ms step_avg:87.83ms
+step:1214/1680 train_time:106623ms step_avg:87.83ms
+step:1215/1680 train_time:106712ms step_avg:87.83ms
+step:1216/1680 train_time:106801ms step_avg:87.83ms
+step:1217/1680 train_time:106890ms step_avg:87.83ms
+step:1218/1680 train_time:106979ms step_avg:87.83ms
+step:1219/1680 train_time:107069ms step_avg:87.83ms
+step:1220/1680 train_time:107158ms step_avg:87.83ms
+step:1221/1680 train_time:107247ms step_avg:87.84ms
+step:1222/1680 train_time:107336ms step_avg:87.84ms
+step:1223/1680 train_time:107424ms step_avg:87.84ms
+step:1224/1680 train_time:107512ms step_avg:87.84ms
+step:1225/1680 train_time:107602ms step_avg:87.84ms
+step:1226/1680 train_time:107691ms step_avg:87.84ms
+step:1227/1680 train_time:107780ms step_avg:87.84ms
+step:1228/1680 train_time:107869ms step_avg:87.84ms
+step:1229/1680 train_time:107957ms step_avg:87.84ms
+step:1230/1680 train_time:108047ms step_avg:87.84ms
+step:1231/1680 train_time:108136ms step_avg:87.84ms
+step:1232/1680 train_time:108225ms step_avg:87.84ms
+step:1233/1680 train_time:108314ms step_avg:87.85ms
+step:1234/1680 train_time:108403ms step_avg:87.85ms
+step:1235/1680 train_time:108491ms step_avg:87.85ms
+step:1236/1680 train_time:108580ms step_avg:87.85ms
+step:1237/1680 train_time:108669ms step_avg:87.85ms
+step:1238/1680 train_time:108758ms step_avg:87.85ms
+step:1239/1680 train_time:108846ms step_avg:87.85ms
+step:1240/1680 train_time:108935ms step_avg:87.85ms
+step:1241/1680 train_time:109024ms step_avg:87.85ms
+step:1242/1680 train_time:109113ms step_avg:87.85ms
+step:1243/1680 train_time:109202ms step_avg:87.85ms
+step:1244/1680 train_time:109291ms step_avg:87.85ms
+step:1245/1680 train_time:109380ms step_avg:87.86ms
+step:1246/1680 train_time:109469ms step_avg:87.86ms
+step:1247/1680 train_time:109557ms step_avg:87.86ms
+step:1248/1680 train_time:109648ms step_avg:87.86ms
+step:1249/1680 train_time:109737ms step_avg:87.86ms
+step:1250/1680 train_time:109826ms step_avg:87.86ms
+step:1250/1680 val_loss:3.3760 train_time:109916ms step_avg:87.93ms
+step:1251/1680 train_time:109934ms step_avg:87.88ms
+step:1252/1680 train_time:110009ms step_avg:87.87ms
+step:1253/1680 train_time:110101ms step_avg:87.87ms
+step:1254/1680 train_time:110190ms step_avg:87.87ms
+step:1255/1680 train_time:110278ms step_avg:87.87ms
+step:1256/1680 train_time:110366ms step_avg:87.87ms
+step:1257/1680 train_time:110453ms step_avg:87.87ms
+step:1258/1680 train_time:110542ms step_avg:87.87ms
+step:1259/1680 train_time:110630ms step_avg:87.87ms
+step:1260/1680 train_time:110719ms step_avg:87.87ms
+step:1261/1680 train_time:110808ms step_avg:87.87ms
+step:1262/1680 train_time:110899ms step_avg:87.88ms
+step:1263/1680 train_time:110990ms step_avg:87.88ms
+step:1264/1680 train_time:111080ms step_avg:87.88ms
+step:1265/1680 train_time:111170ms step_avg:87.88ms
+step:1266/1680 train_time:111259ms step_avg:87.88ms
+step:1267/1680 train_time:111348ms step_avg:87.88ms
+step:1268/1680 train_time:111436ms step_avg:87.88ms
+step:1269/1680 train_time:111524ms step_avg:87.88ms
+step:1270/1680 train_time:111612ms step_avg:87.88ms
+step:1271/1680 train_time:111700ms step_avg:87.88ms
+step:1272/1680 train_time:111789ms step_avg:87.88ms
+step:1273/1680 train_time:111879ms step_avg:87.89ms
+step:1274/1680 train_time:111970ms step_avg:87.89ms
+step:1275/1680 train_time:112059ms step_avg:87.89ms
+step:1276/1680 train_time:112150ms step_avg:87.89ms
+step:1277/1680 train_time:112239ms step_avg:87.89ms
+step:1278/1680 train_time:112327ms step_avg:87.89ms
+step:1279/1680 train_time:112416ms step_avg:87.89ms
+step:1280/1680 train_time:112505ms step_avg:87.89ms
+step:1281/1680 train_time:112593ms step_avg:87.89ms
+step:1282/1680 train_time:112682ms step_avg:87.90ms
+step:1283/1680 train_time:112771ms step_avg:87.90ms
+step:1284/1680 train_time:112861ms step_avg:87.90ms
+step:1285/1680 train_time:112950ms step_avg:87.90ms
+step:1286/1680 train_time:113040ms step_avg:87.90ms
+step:1287/1680 train_time:113130ms step_avg:87.90ms
+step:1288/1680 train_time:113219ms step_avg:87.90ms
+step:1289/1680 train_time:113309ms step_avg:87.90ms
+step:1290/1680 train_time:113398ms step_avg:87.91ms
+step:1291/1680 train_time:113487ms step_avg:87.91ms
+step:1292/1680 train_time:113575ms step_avg:87.91ms
+step:1293/1680 train_time:113664ms step_avg:87.91ms
+step:1294/1680 train_time:113752ms step_avg:87.91ms
+step:1295/1680 train_time:113842ms step_avg:87.91ms
+step:1296/1680 train_time:113930ms step_avg:87.91ms
+step:1297/1680 train_time:114020ms step_avg:87.91ms
+step:1298/1680 train_time:114109ms step_avg:87.91ms
+step:1299/1680 train_time:114198ms step_avg:87.91ms
+step:1300/1680 train_time:114287ms step_avg:87.91ms
+step:1301/1680 train_time:114376ms step_avg:87.91ms
+step:1302/1680 train_time:114465ms step_avg:87.91ms
+step:1303/1680 train_time:114553ms step_avg:87.91ms
+step:1304/1680 train_time:114641ms step_avg:87.91ms
+step:1305/1680 train_time:114730ms step_avg:87.92ms
+step:1306/1680 train_time:114818ms step_avg:87.92ms
+step:1307/1680 train_time:114907ms step_avg:87.92ms
+step:1308/1680 train_time:114996ms step_avg:87.92ms
+step:1309/1680 train_time:115085ms step_avg:87.92ms
+step:1310/1680 train_time:115175ms step_avg:87.92ms
+step:1311/1680 train_time:115264ms step_avg:87.92ms
+step:1312/1680 train_time:115352ms step_avg:87.92ms
+step:1313/1680 train_time:115441ms step_avg:87.92ms
+step:1314/1680 train_time:115530ms step_avg:87.92ms
+step:1315/1680 train_time:115619ms step_avg:87.92ms
+step:1316/1680 train_time:115708ms step_avg:87.92ms
+step:1317/1680 train_time:115796ms step_avg:87.92ms
+step:1318/1680 train_time:115885ms step_avg:87.93ms
+step:1319/1680 train_time:115974ms step_avg:87.93ms
+step:1320/1680 train_time:116063ms step_avg:87.93ms
+step:1321/1680 train_time:116151ms step_avg:87.93ms
+step:1322/1680 train_time:116241ms step_avg:87.93ms
+step:1323/1680 train_time:116330ms step_avg:87.93ms
+step:1324/1680 train_time:116419ms step_avg:87.93ms
+step:1325/1680 train_time:116509ms step_avg:87.93ms
+step:1326/1680 train_time:116598ms step_avg:87.93ms
+step:1327/1680 train_time:116686ms step_avg:87.93ms
+step:1328/1680 train_time:116775ms step_avg:87.93ms
+step:1329/1680 train_time:116865ms step_avg:87.93ms
+step:1330/1680 train_time:116953ms step_avg:87.93ms
+step:1331/1680 train_time:117042ms step_avg:87.94ms
+step:1332/1680 train_time:117132ms step_avg:87.94ms
+step:1333/1680 train_time:117220ms step_avg:87.94ms
+step:1334/1680 train_time:117310ms step_avg:87.94ms
+step:1335/1680 train_time:117399ms step_avg:87.94ms
+step:1336/1680 train_time:117489ms step_avg:87.94ms
+step:1337/1680 train_time:117577ms step_avg:87.94ms
+step:1338/1680 train_time:117666ms step_avg:87.94ms
+step:1339/1680 train_time:117755ms step_avg:87.94ms
+step:1340/1680 train_time:117844ms step_avg:87.94ms
+step:1341/1680 train_time:117933ms step_avg:87.94ms
+step:1342/1680 train_time:118022ms step_avg:87.94ms
+step:1343/1680 train_time:118111ms step_avg:87.95ms
+step:1344/1680 train_time:118200ms step_avg:87.95ms
+step:1345/1680 train_time:118289ms step_avg:87.95ms
+step:1346/1680 train_time:118377ms step_avg:87.95ms
+step:1347/1680 train_time:118467ms step_avg:87.95ms
+step:1348/1680 train_time:118556ms step_avg:87.95ms
+step:1349/1680 train_time:118645ms step_avg:87.95ms
+step:1350/1680 train_time:118734ms step_avg:87.95ms
+step:1351/1680 train_time:118822ms step_avg:87.95ms
+step:1352/1680 train_time:118912ms step_avg:87.95ms
+step:1353/1680 train_time:119001ms step_avg:87.95ms
+step:1354/1680 train_time:119091ms step_avg:87.95ms
+step:1355/1680 train_time:119179ms step_avg:87.96ms
+step:1356/1680 train_time:119268ms step_avg:87.96ms
+step:1357/1680 train_time:119357ms step_avg:87.96ms
+step:1358/1680 train_time:119447ms step_avg:87.96ms
+step:1359/1680 train_time:119536ms step_avg:87.96ms
+step:1360/1680 train_time:119624ms step_avg:87.96ms
+step:1361/1680 train_time:119713ms step_avg:87.96ms
+step:1362/1680 train_time:119801ms step_avg:87.96ms
+step:1363/1680 train_time:119891ms step_avg:87.96ms
+step:1364/1680 train_time:119979ms step_avg:87.96ms
+step:1365/1680 train_time:120068ms step_avg:87.96ms
+step:1366/1680 train_time:120157ms step_avg:87.96ms
+step:1367/1680 train_time:120246ms step_avg:87.96ms
+step:1368/1680 train_time:120334ms step_avg:87.96ms
+step:1369/1680 train_time:120423ms step_avg:87.96ms
+step:1370/1680 train_time:120512ms step_avg:87.97ms
+step:1371/1680 train_time:120602ms step_avg:87.97ms
+step:1372/1680 train_time:120691ms step_avg:87.97ms
+step:1373/1680 train_time:120781ms step_avg:87.97ms
+step:1374/1680 train_time:120870ms step_avg:87.97ms
+step:1375/1680 train_time:120958ms step_avg:87.97ms
+step:1375/1680 val_loss:3.3416 train_time:121049ms step_avg:88.04ms
+step:1376/1680 train_time:121069ms step_avg:87.99ms
+step:1377/1680 train_time:121143ms step_avg:87.98ms
+step:1378/1680 train_time:121234ms step_avg:87.98ms
+step:1379/1680 train_time:121323ms step_avg:87.98ms
+step:1380/1680 train_time:121411ms step_avg:87.98ms
+step:1381/1680 train_time:121499ms step_avg:87.98ms
+step:1382/1680 train_time:121587ms step_avg:87.98ms
+step:1383/1680 train_time:121674ms step_avg:87.98ms
+step:1384/1680 train_time:121762ms step_avg:87.98ms
+step:1385/1680 train_time:121850ms step_avg:87.98ms
+step:1386/1680 train_time:121939ms step_avg:87.98ms
+step:1387/1680 train_time:122029ms step_avg:87.98ms
+step:1388/1680 train_time:122121ms step_avg:87.98ms
+step:1389/1680 train_time:122211ms step_avg:87.98ms
+step:1390/1680 train_time:122301ms step_avg:87.99ms
+step:1391/1680 train_time:122390ms step_avg:87.99ms
+step:1392/1680 train_time:122478ms step_avg:87.99ms
+step:1393/1680 train_time:122566ms step_avg:87.99ms
+step:1394/1680 train_time:122654ms step_avg:87.99ms
+step:1395/1680 train_time:122742ms step_avg:87.99ms
+step:1396/1680 train_time:122830ms step_avg:87.99ms
+step:1397/1680 train_time:122919ms step_avg:87.99ms
+step:1398/1680 train_time:123008ms step_avg:87.99ms
+step:1399/1680 train_time:123098ms step_avg:87.99ms
+step:1400/1680 train_time:123187ms step_avg:87.99ms
+step:1401/1680 train_time:123277ms step_avg:87.99ms
+step:1402/1680 train_time:123367ms step_avg:87.99ms
+step:1403/1680 train_time:123457ms step_avg:87.99ms
+step:1404/1680 train_time:123545ms step_avg:87.99ms
+step:1405/1680 train_time:123633ms step_avg:87.99ms
+step:1406/1680 train_time:123720ms step_avg:87.99ms
+step:1407/1680 train_time:123809ms step_avg:87.99ms
+step:1408/1680 train_time:123898ms step_avg:88.00ms
+step:1409/1680 train_time:123987ms step_avg:88.00ms
+step:1410/1680 train_time:124077ms step_avg:88.00ms
+step:1411/1680 train_time:124167ms step_avg:88.00ms
+step:1412/1680 train_time:124256ms step_avg:88.00ms
+step:1413/1680 train_time:124346ms step_avg:88.00ms
+step:1414/1680 train_time:124435ms step_avg:88.00ms
+step:1415/1680 train_time:124525ms step_avg:88.00ms
+step:1416/1680 train_time:124614ms step_avg:88.00ms
+step:1417/1680 train_time:124701ms step_avg:88.00ms
+step:1418/1680 train_time:124789ms step_avg:88.00ms
+step:1419/1680 train_time:124878ms step_avg:88.00ms
+step:1420/1680 train_time:124968ms step_avg:88.01ms
+step:1421/1680 train_time:125058ms step_avg:88.01ms
+step:1422/1680 train_time:125148ms step_avg:88.01ms
+step:1423/1680 train_time:125237ms step_avg:88.01ms
+step:1424/1680 train_time:125327ms step_avg:88.01ms
+step:1425/1680 train_time:125417ms step_avg:88.01ms
+step:1426/1680 train_time:125506ms step_avg:88.01ms
+step:1427/1680 train_time:125595ms step_avg:88.01ms
+step:1428/1680 train_time:125684ms step_avg:88.01ms
+step:1429/1680 train_time:125772ms step_avg:88.01ms
+step:1430/1680 train_time:125861ms step_avg:88.01ms
+step:1431/1680 train_time:125950ms step_avg:88.02ms
+step:1432/1680 train_time:126039ms step_avg:88.02ms
+step:1433/1680 train_time:126128ms step_avg:88.02ms
+step:1434/1680 train_time:126217ms step_avg:88.02ms
+step:1435/1680 train_time:126306ms step_avg:88.02ms
+step:1436/1680 train_time:126396ms step_avg:88.02ms
+step:1437/1680 train_time:126484ms step_avg:88.02ms
+step:1438/1680 train_time:126572ms step_avg:88.02ms
+step:1439/1680 train_time:126660ms step_avg:88.02ms
+step:1440/1680 train_time:126749ms step_avg:88.02ms
+step:1441/1680 train_time:126837ms step_avg:88.02ms
+step:1442/1680 train_time:126927ms step_avg:88.02ms
+step:1443/1680 train_time:127017ms step_avg:88.02ms
+step:1444/1680 train_time:127106ms step_avg:88.02ms
+step:1445/1680 train_time:127195ms step_avg:88.02ms
+step:1446/1680 train_time:127284ms step_avg:88.02ms
+step:1447/1680 train_time:127373ms step_avg:88.03ms
+step:1448/1680 train_time:127462ms step_avg:88.03ms
+step:1449/1680 train_time:127552ms step_avg:88.03ms
+step:1450/1680 train_time:127640ms step_avg:88.03ms
+step:1451/1680 train_time:127729ms step_avg:88.03ms
+step:1452/1680 train_time:127817ms step_avg:88.03ms
+step:1453/1680 train_time:127906ms step_avg:88.03ms
+step:1454/1680 train_time:127995ms step_avg:88.03ms
+step:1455/1680 train_time:128085ms step_avg:88.03ms
+step:1456/1680 train_time:128174ms step_avg:88.03ms
+step:1457/1680 train_time:128263ms step_avg:88.03ms
+step:1458/1680 train_time:128352ms step_avg:88.03ms
+step:1459/1680 train_time:128441ms step_avg:88.03ms
+step:1460/1680 train_time:128531ms step_avg:88.03ms
+step:1461/1680 train_time:128620ms step_avg:88.04ms
+step:1462/1680 train_time:128709ms step_avg:88.04ms
+step:1463/1680 train_time:128797ms step_avg:88.04ms
+step:1464/1680 train_time:128886ms step_avg:88.04ms
+step:1465/1680 train_time:128975ms step_avg:88.04ms
+step:1466/1680 train_time:129064ms step_avg:88.04ms
+step:1467/1680 train_time:129153ms step_avg:88.04ms
+step:1468/1680 train_time:129242ms step_avg:88.04ms
+step:1469/1680 train_time:129331ms step_avg:88.04ms
+step:1470/1680 train_time:129419ms step_avg:88.04ms
+step:1471/1680 train_time:129508ms step_avg:88.04ms
+step:1472/1680 train_time:129597ms step_avg:88.04ms
+step:1473/1680 train_time:129686ms step_avg:88.04ms
+step:1474/1680 train_time:129774ms step_avg:88.04ms
+step:1475/1680 train_time:129863ms step_avg:88.04ms
+step:1476/1680 train_time:129953ms step_avg:88.04ms
+step:1477/1680 train_time:130041ms step_avg:88.04ms
+step:1478/1680 train_time:130130ms step_avg:88.04ms
+step:1479/1680 train_time:130219ms step_avg:88.04ms
+step:1480/1680 train_time:130308ms step_avg:88.05ms
+step:1481/1680 train_time:130397ms step_avg:88.05ms
+step:1482/1680 train_time:130486ms step_avg:88.05ms
+step:1483/1680 train_time:130576ms step_avg:88.05ms
+step:1484/1680 train_time:130665ms step_avg:88.05ms
+step:1485/1680 train_time:130755ms step_avg:88.05ms
+step:1486/1680 train_time:130843ms step_avg:88.05ms
+step:1487/1680 train_time:130932ms step_avg:88.05ms
+step:1488/1680 train_time:131021ms step_avg:88.05ms
+step:1489/1680 train_time:131109ms step_avg:88.05ms
+step:1490/1680 train_time:131199ms step_avg:88.05ms
+step:1491/1680 train_time:131287ms step_avg:88.05ms
+step:1492/1680 train_time:131376ms step_avg:88.05ms
+step:1493/1680 train_time:131464ms step_avg:88.05ms
+step:1494/1680 train_time:131554ms step_avg:88.05ms
+step:1495/1680 train_time:131643ms step_avg:88.06ms
+step:1496/1680 train_time:131732ms step_avg:88.06ms
+step:1497/1680 train_time:131821ms step_avg:88.06ms
+step:1498/1680 train_time:131911ms step_avg:88.06ms
+step:1499/1680 train_time:132000ms step_avg:88.06ms
+step:1500/1680 train_time:132089ms step_avg:88.06ms
+step:1500/1680 val_loss:3.3120 train_time:132179ms step_avg:88.12ms
+step:1501/1680 train_time:132198ms step_avg:88.07ms
+step:1502/1680 train_time:132273ms step_avg:88.06ms
+step:1503/1680 train_time:132364ms step_avg:88.07ms
+step:1504/1680 train_time:132453ms step_avg:88.07ms
+step:1505/1680 train_time:132541ms step_avg:88.07ms
+step:1506/1680 train_time:132629ms step_avg:88.07ms
+step:1507/1680 train_time:132717ms step_avg:88.07ms
+step:1508/1680 train_time:132805ms step_avg:88.07ms
+step:1509/1680 train_time:132893ms step_avg:88.07ms
+step:1510/1680 train_time:132981ms step_avg:88.07ms
+step:1511/1680 train_time:133070ms step_avg:88.07ms
+step:1512/1680 train_time:133161ms step_avg:88.07ms
+step:1513/1680 train_time:133251ms step_avg:88.07ms
+step:1514/1680 train_time:133342ms step_avg:88.07ms
+step:1515/1680 train_time:133431ms step_avg:88.07ms
+step:1516/1680 train_time:133520ms step_avg:88.07ms
+step:1517/1680 train_time:133609ms step_avg:88.07ms
+step:1518/1680 train_time:133697ms step_avg:88.07ms
+step:1519/1680 train_time:133785ms step_avg:88.07ms
+step:1520/1680 train_time:133874ms step_avg:88.08ms
+step:1521/1680 train_time:133962ms step_avg:88.07ms
+step:1522/1680 train_time:134051ms step_avg:88.08ms
+step:1523/1680 train_time:134140ms step_avg:88.08ms
+step:1524/1680 train_time:134230ms step_avg:88.08ms
+step:1525/1680 train_time:134320ms step_avg:88.08ms
+step:1526/1680 train_time:134409ms step_avg:88.08ms
+step:1527/1680 train_time:134499ms step_avg:88.08ms
+step:1528/1680 train_time:134587ms step_avg:88.08ms
+step:1529/1680 train_time:134676ms step_avg:88.08ms
+step:1530/1680 train_time:134765ms step_avg:88.08ms
+step:1531/1680 train_time:134853ms step_avg:88.08ms
+step:1532/1680 train_time:134941ms step_avg:88.08ms
+step:1533/1680 train_time:135030ms step_avg:88.08ms
+step:1534/1680 train_time:135119ms step_avg:88.08ms
+step:1535/1680 train_time:135209ms step_avg:88.08ms
+step:1536/1680 train_time:135299ms step_avg:88.09ms
+step:1537/1680 train_time:135389ms step_avg:88.09ms
+step:1538/1680 train_time:135479ms step_avg:88.09ms
+step:1539/1680 train_time:135568ms step_avg:88.09ms
+step:1540/1680 train_time:135657ms step_avg:88.09ms
+step:1541/1680 train_time:135745ms step_avg:88.09ms
+step:1542/1680 train_time:135834ms step_avg:88.09ms
+step:1543/1680 train_time:135922ms step_avg:88.09ms
+step:1544/1680 train_time:136012ms step_avg:88.09ms
+step:1545/1680 train_time:136101ms step_avg:88.09ms
+step:1546/1680 train_time:136190ms step_avg:88.09ms
+step:1547/1680 train_time:136280ms step_avg:88.09ms
+step:1548/1680 train_time:136370ms step_avg:88.09ms
+step:1549/1680 train_time:136459ms step_avg:88.09ms
+step:1550/1680 train_time:136549ms step_avg:88.10ms
+step:1551/1680 train_time:136637ms step_avg:88.10ms
+step:1552/1680 train_time:136726ms step_avg:88.10ms
+step:1553/1680 train_time:136815ms step_avg:88.10ms
+step:1554/1680 train_time:136903ms step_avg:88.10ms
+step:1555/1680 train_time:136993ms step_avg:88.10ms
+step:1556/1680 train_time:137081ms step_avg:88.10ms
+step:1557/1680 train_time:137171ms step_avg:88.10ms
+step:1558/1680 train_time:137260ms step_avg:88.10ms
+step:1559/1680 train_time:137350ms step_avg:88.10ms
+step:1560/1680 train_time:137439ms step_avg:88.10ms
+step:1561/1680 train_time:137528ms step_avg:88.10ms
+step:1562/1680 train_time:137617ms step_avg:88.10ms
+step:1563/1680 train_time:137706ms step_avg:88.10ms
+step:1564/1680 train_time:137795ms step_avg:88.10ms
+step:1565/1680 train_time:137883ms step_avg:88.10ms
+step:1566/1680 train_time:137972ms step_avg:88.10ms
+step:1567/1680 train_time:138061ms step_avg:88.11ms
+step:1568/1680 train_time:138150ms step_avg:88.11ms
+step:1569/1680 train_time:138238ms step_avg:88.11ms
+step:1570/1680 train_time:138327ms step_avg:88.11ms
+step:1571/1680 train_time:138416ms step_avg:88.11ms
+step:1572/1680 train_time:138505ms step_avg:88.11ms
+step:1573/1680 train_time:138595ms step_avg:88.11ms
+step:1574/1680 train_time:138684ms step_avg:88.11ms
+step:1575/1680 train_time:138773ms step_avg:88.11ms
+step:1576/1680 train_time:138861ms step_avg:88.11ms
+step:1577/1680 train_time:138949ms step_avg:88.11ms
+step:1578/1680 train_time:139038ms step_avg:88.11ms
+step:1579/1680 train_time:139128ms step_avg:88.11ms
+step:1580/1680 train_time:139217ms step_avg:88.11ms
+step:1581/1680 train_time:139305ms step_avg:88.11ms
+step:1582/1680 train_time:139394ms step_avg:88.11ms
+step:1583/1680 train_time:139483ms step_avg:88.11ms
+step:1584/1680 train_time:139573ms step_avg:88.11ms
+step:1585/1680 train_time:139662ms step_avg:88.11ms
+step:1586/1680 train_time:139751ms step_avg:88.12ms
+step:1587/1680 train_time:139839ms step_avg:88.12ms
+step:1588/1680 train_time:139928ms step_avg:88.12ms
+step:1589/1680 train_time:140017ms step_avg:88.12ms
+step:1590/1680 train_time:140106ms step_avg:88.12ms
+step:1591/1680 train_time:140196ms step_avg:88.12ms
+step:1592/1680 train_time:140284ms step_avg:88.12ms
+step:1593/1680 train_time:140373ms step_avg:88.12ms
+step:1594/1680 train_time:140462ms step_avg:88.12ms
+step:1595/1680 train_time:140552ms step_avg:88.12ms
+step:1596/1680 train_time:140641ms step_avg:88.12ms
+step:1597/1680 train_time:140731ms step_avg:88.12ms
+step:1598/1680 train_time:140819ms step_avg:88.12ms
+step:1599/1680 train_time:140907ms step_avg:88.12ms
+step:1600/1680 train_time:140996ms step_avg:88.12ms
+step:1601/1680 train_time:141085ms step_avg:88.12ms
+step:1602/1680 train_time:141174ms step_avg:88.12ms
+step:1603/1680 train_time:141263ms step_avg:88.12ms
+step:1604/1680 train_time:141352ms step_avg:88.12ms
+step:1605/1680 train_time:141440ms step_avg:88.12ms
+step:1606/1680 train_time:141529ms step_avg:88.13ms
+step:1607/1680 train_time:141618ms step_avg:88.13ms
+step:1608/1680 train_time:141707ms step_avg:88.13ms
+step:1609/1680 train_time:141796ms step_avg:88.13ms
+step:1610/1680 train_time:141885ms step_avg:88.13ms
+step:1611/1680 train_time:141974ms step_avg:88.13ms
+step:1612/1680 train_time:142064ms step_avg:88.13ms
+step:1613/1680 train_time:142153ms step_avg:88.13ms
+step:1614/1680 train_time:142242ms step_avg:88.13ms +step:1615/1680 train_time:142331ms step_avg:88.13ms +step:1616/1680 train_time:142419ms step_avg:88.13ms +step:1617/1680 train_time:142508ms step_avg:88.13ms +step:1618/1680 train_time:142597ms step_avg:88.13ms +step:1619/1680 train_time:142687ms step_avg:88.13ms +step:1620/1680 train_time:142776ms step_avg:88.13ms +step:1621/1680 train_time:142865ms step_avg:88.13ms +step:1622/1680 train_time:142955ms step_avg:88.13ms +step:1623/1680 train_time:143043ms step_avg:88.14ms +step:1624/1680 train_time:143132ms step_avg:88.14ms +step:1625/1680 train_time:143220ms step_avg:88.14ms +step:1625/1680 val_loss:3.2883 train_time:143310ms step_avg:88.19ms +step:1626/1680 train_time:143330ms step_avg:88.15ms +step:1627/1680 train_time:143402ms step_avg:88.14ms +step:1628/1680 train_time:143494ms step_avg:88.14ms +step:1629/1680 train_time:143585ms step_avg:88.14ms +step:1630/1680 train_time:143673ms step_avg:88.14ms +step:1631/1680 train_time:143761ms step_avg:88.14ms +step:1632/1680 train_time:143848ms step_avg:88.14ms +step:1633/1680 train_time:143936ms step_avg:88.14ms +step:1634/1680 train_time:144024ms step_avg:88.14ms +step:1635/1680 train_time:144113ms step_avg:88.14ms +step:1636/1680 train_time:144203ms step_avg:88.14ms +step:1637/1680 train_time:144292ms step_avg:88.14ms +step:1638/1680 train_time:144383ms step_avg:88.15ms +step:1639/1680 train_time:144473ms step_avg:88.15ms +step:1640/1680 train_time:144563ms step_avg:88.15ms +step:1641/1680 train_time:144652ms step_avg:88.15ms +step:1642/1680 train_time:144741ms step_avg:88.15ms +step:1643/1680 train_time:144829ms step_avg:88.15ms +step:1644/1680 train_time:144917ms step_avg:88.15ms +step:1645/1680 train_time:145006ms step_avg:88.15ms +step:1646/1680 train_time:145095ms step_avg:88.15ms +step:1647/1680 train_time:145184ms step_avg:88.15ms +step:1648/1680 train_time:145274ms step_avg:88.15ms +step:1649/1680 train_time:145364ms step_avg:88.15ms +step:1650/1680 train_time:145454ms step_avg:88.15ms +step:1651/1680 train_time:145545ms step_avg:88.16ms +step:1652/1680 train_time:145634ms step_avg:88.16ms +step:1653/1680 train_time:145723ms step_avg:88.16ms +step:1654/1680 train_time:145812ms step_avg:88.16ms +step:1655/1680 train_time:145900ms step_avg:88.16ms +step:1656/1680 train_time:145988ms step_avg:88.16ms +step:1657/1680 train_time:146077ms step_avg:88.16ms +step:1658/1680 train_time:146166ms step_avg:88.16ms +step:1659/1680 train_time:146255ms step_avg:88.16ms +step:1660/1680 train_time:146344ms step_avg:88.16ms +step:1661/1680 train_time:146434ms step_avg:88.16ms +step:1662/1680 train_time:146524ms step_avg:88.16ms +step:1663/1680 train_time:146614ms step_avg:88.16ms +step:1664/1680 train_time:146704ms step_avg:88.16ms +step:1665/1680 train_time:146792ms step_avg:88.16ms +step:1666/1680 train_time:146881ms step_avg:88.16ms +step:1667/1680 train_time:146969ms step_avg:88.16ms +step:1668/1680 train_time:147057ms step_avg:88.16ms +step:1669/1680 train_time:147147ms step_avg:88.16ms +step:1670/1680 train_time:147235ms step_avg:88.16ms +step:1671/1680 train_time:147325ms step_avg:88.17ms +step:1672/1680 train_time:147415ms step_avg:88.17ms +step:1673/1680 train_time:147506ms step_avg:88.17ms +step:1674/1680 train_time:147596ms step_avg:88.17ms +step:1675/1680 train_time:147686ms step_avg:88.17ms +step:1676/1680 train_time:147774ms step_avg:88.17ms +step:1677/1680 train_time:147863ms step_avg:88.17ms +step:1678/1680 train_time:147952ms step_avg:88.17ms +step:1679/1680 train_time:148041ms 
step_avg:88.17ms +step:1680/1680 train_time:148130ms step_avg:88.17ms +step:1680/1680 val_loss:3.2774 train_time:148220ms step_avg:88.23ms +peak memory allocated: 30760 MiB reserved: 46094 MiB diff --git a/records/092725_BF16CE/b54670db-06ce-4aa6-b50f-869bfe329c8b.txt b/records/092725_BF16CE/b54670db-06ce-4aa6-b50f-869bfe329c8b.txt new file mode 100644 index 000000000..07d17e364 --- /dev/null +++ b/records/092725_BF16CE/b54670db-06ce-4aa6-b50f-869bfe329c8b.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+    Though empirically, small 1D params perform efficiently here:
+        NS approximately performs a magnitude normalization of the grad, and
+        this hyper-optimized class has faster execution time than the current impl of Adam for small params.
+
+    Custom distributed sizing:
+        The model stores all attn and mlp weights in the same shape, and then updates the view as
+        needed on the forward pass. This enables attn and mlp weights to be contained within the same
+        dist.reduce_scatter_tensor() call. The model architecture has been customized to enable
+        (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn.
+        The scheduling is:
+            1. reduce scatter smear_gate (1 param, 7 padding params)
+            2. reduce scatter attn_gate (10 params, 6 padding params)
+            3. reduce scatter attn/mlp round 1 (10 attn params, 6 mlp params)
+            4. reduce scatter attn/mlp round 2 (16 mlp params)
+            5. wait on step 1, then compute NS of 1 and schedule all gather
+            6. wait on step 2, then compute NS of 2 and schedule all gather
+            7. wait on step 3, then compute NS of 3 and schedule all gather
+                GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP]
+                GPUs that receive params of type attn reshape before NS
+            8. wait on step 4, then compute NS of 4 and schedule all gather
+            9. wait for each all gather to complete and update params
+    Empirically, leading with small params provides an additional 0.2s improvement.
+    """
+    def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
+        # custom sizing requires 8 GPUs
+        if custom_sizing and dist.get_world_size() == 8:
+            param_groups = self.generate_custom_param_groups(params)
+        else:
+            param_groups = self.generate_standard_param_groups(params)
+        super().__init__(param_groups, defaults)
+
+    def generate_standard_param_groups(self, params):
+        """
+        Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules.
+        Creates one param group per size, while giving attn its own param group for the resize op.
+        """
+        params = list(params)
+        param_groups = []
+        attn_subset = [p for p in params if p.module == 'attn']
+        non_attn_subset = [p for p in params if p.module != 'attn']
+        param_groups.append(dict(params=attn_subset))
+
+        sizes = {p.shape for p in non_attn_subset}
+        for size in sizes:
+            group_params = [p for p in non_attn_subset if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        return param_groups
+
+    def generate_custom_param_groups(self, params):
+        """
+        Implementation requires that a single GPU does not receive both attn
+        and mlp params when a param group is split across GPUs.
+        """
+        module_ranks = {
+            'smear_gate': 1, # 1 param
+            'attn_gate': 2, # 10 params
+            'attn': 3, # 10 params
+            'mlp': 4, # 22 params
+        }
+        params = list(params)
+        params.sort(key=lambda x: module_ranks.get(x.module))
+        idx = 0
+        group_sizes = [1, 10, 16, 16]
+        assert len(params) == sum(group_sizes)
+        param_groups = []
+        for size in group_sizes:
+            group_params = params[idx:idx + size]
+            param_groups.append(dict(params=group_params))
+            idx += size
+        return param_groups
+
+    @torch.no_grad()
+    def step(self):
+        # Efficient systems-wise implementation of step developed by @YouJiacheng,
+        # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad,
+        # @ryanyang0, and @vagrawal.
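+        # A worked sharding example for orientation (hypothetical numbers, not taken
+        # from this run): with world_size=8 and a group of 10 same-shaped matrices,
+        #     padded_num_params = (10 + 8 - 1) // 8 * 8  # = 16
+        #     chunk_size = padded_num_params // 8        # = 2
+        #     start_idx = rank * chunk_size              # e.g. rank 3 owns params 6 and 7
+        # so each rank reduce-scatters into a 2-matrix shard, orthogonalizes it in one
+        # batched Newton-Schulz call, and all-gathers the full 16-slot stack
+        # (10 real + 6 zero-padded) back to every rank.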
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
updated_param_chunk[i].copy_(p)
+
+                # Apply weight decay directly to the buffer.
+                updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+            original_shape = batched_update_grads.shape
+            # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O
+            module_idx = start_idx if start_idx < len(params) else 0
+            if params[module_idx].module == 'attn':
+                batched_update_grads = batched_update_grads.view(-1, original_shape[-2], original_shape[-1] // 4)
+            v_chunk = newton_schulz_triton(batched_update_grads).reshape(original_shape)
+
+            # Apply the orthogonalized updates back to this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+                if param_idx >= len(params): # Skip padded entries again.
+                    continue
+
+                # Add the computed zeropower update to the parameter in the buffer.
+                updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val)
+
+            stacked_params = torch.empty(
+                (info["padded_num_params"], *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+            gather_future = dist.all_gather_into_tensor(
+                stacked_params, updated_param_chunk, async_op=True
+            ).get_future()
+
+            all_gather_infos.append(
+                {
+                    "gather_future": gather_future,
+                    "stacked_params": stacked_params,
+                    "orig_params": params,
+                }
+            )
+
+        # Final pass: wait for all_gather to complete and copy results back into original parameter tensors.
+        for info in all_gather_infos:
+            info["gather_future"].wait()
+            stacked_params = info["stacked_params"]
+            orig_params = info["orig_params"]
+
+            unstacked_params = torch.unbind(stacked_params)
+            for i, p in enumerate(orig_params):
+                p.copy_(unstacked_params[i], non_blocking=True)
+
+
+class DistAdam(torch.optim.Optimizer):
+    def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        params = list(params)
+        sizes = {p.shape for p in params}
+        # create one param group per unique parameter size
+        param_groups = []
+        for size in sizes:
+            group_params = [p for p in params if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        super().__init__(param_groups, defaults)
+        # DistributedAdam implementation by @vagrawal
+
+    @torch.compile
+    @torch.no_grad()
+    def step(self):
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        reduce_scatter_futures: list[torch.Future] = []
+        all_gather_futures: list[torch.Future] = []
+        grad_slices = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            for base_i in range(len(params)):
+                grad = params[base_i].grad
+                rank_size = grad.shape[0] // world_size
+                grad_slice = torch.empty_like(grad[:rank_size])
+                reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
+                grad_slices.append(grad_slice)
+
+        idx = 0
+        for group in self.param_groups:
+            beta1, beta2 = group['betas']
+            eps = group['eps']
+            wd = group['weight_decay']
+            params = group['params']
+            for base in range(len(params)):
+                reduce_scatter_futures[idx].wait()
+                p = params[base]
+                rank_size = p.shape[0] // world_size
+                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
+                lr = group['lr'] * getattr(p, "lr_mul", 1.0)
+                state = self.state[p]
+                g_slice = grad_slices[idx]
+                # State init
+                if not state:
+                    state["step"] = torch.tensor(
+                        0, dtype=torch.int64, device=p.device
+                    )
+                    state["exp_avg"] = torch.zeros(
+                        p_slice.shape,
+                        dtype=torch.bfloat16,
+                        device=p_slice.device,
+                    )
+                    state["exp_avg_sq"] = torch.zeros(
+                        p_slice.shape,
+                        dtype=torch.bfloat16,
+                        device=p_slice.device,
+                    )
+                exp_avg = state["exp_avg"]
+                exp_avg_sq =
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module='attn' + with torch.no_grad(): + self.qkvo_w.view(4,self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4,self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4,self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4,self.hdim, self.dim)[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module='mlp' + self.c_proj.module='mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
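+        # Illustrative check of that padding (hypothetical, just restating
+        # next_multiple_of_n above): next_multiple_of_n(50257, n=128) == 50304,
+        # i.e. 393 * 128, so lm_head carries 47 permanently-unused logit rows in
+        # exchange for matmul dimensions that tile evenly on the GPU.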
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the 
data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter==5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter+=1 + return starts, ends + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * 
grad_accum_steps) == 0, "Batch size must be divisible by world size"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document was too long to account for _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local, )
+            _targets = buf[1:].view(num_tokens_local, )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40 # number of iterations to continue training at final cooldown and window size
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss?
0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws_long = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws_long > ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    elif new_ws_long < ws_long:
+        model.yarn.apply(ws_long, new_ws_long) # wrap back around to the smallest window
+        ws_long = new_ws_long
+    model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+# re-initialize the state so the warmup doesn't count as training
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del initial_state
+
+train_steps = args.num_iterations + args.iteration_extension
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+ws_long = args.ws_schedule[0]
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    ws_short, new_ws_long = get_ws(step)
+    if new_ws_long != ws_long:
+        model.yarn.apply(ws_long, new_ws_long) # re-scale RoPE whenever the window schedule advances
+    ws_long = new_ws_long
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        if last_step:
+            ws_long = args.ws_long_validate
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = torch.zeros((), device=device, dtype=torch.float32)
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Sep 27 12:29:38 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 30C P0 124W / 700W | 5856MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 27C P0 119W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 25C P0 117W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 29C P0 123W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 30C P0 122W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 28C P0 116W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 30C P0 122W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 27C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 158005 C /usr/bin/python 0MiB | +| 0 N/A N/A 158006 C /usr/bin/python 0MiB | +| 0 N/A N/A 158007 C /usr/bin/python 0MiB | +| 0 N/A N/A 158008 C /usr/bin/python 0MiB | +| 0 N/A N/A 158009 C /usr/bin/python 0MiB | +| 0 N/A N/A 158010 C /usr/bin/python 0MiB | +| 0 N/A N/A 158011 C /usr/bin/python 0MiB | +| 0 N/A N/A 158012 C /usr/bin/python 0MiB | +| 1 N/A N/A 158006 C /usr/bin/python 0MiB | +| 2 N/A N/A 158007 C /usr/bin/python 0MiB | +| 3 N/A N/A 158008 C /usr/bin/python 0MiB | +| 4 N/A N/A 158009 C /usr/bin/python 0MiB | +| 5 N/A N/A 158010 C /usr/bin/python 0MiB | +| 6 N/A N/A 158011 C /usr/bin/python 0MiB | +| 7 N/A N/A 158012 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1680 train_time:147ms step_avg:147.10ms +step:2/1680 train_time:167ms step_avg:83.35ms +step:3/1680 train_time:231ms step_avg:76.88ms +step:4/1680 train_time:316ms step_avg:79.00ms +step:5/1680 train_time:402ms step_avg:80.35ms +step:6/1680 train_time:489ms step_avg:81.50ms +step:7/1680 train_time:575ms step_avg:82.17ms +step:8/1680 train_time:661ms step_avg:82.67ms +step:9/1680 train_time:747ms step_avg:83.03ms +step:10/1680 train_time:834ms step_avg:83.39ms +step:11/1680 train_time:920ms step_avg:83.63ms +step:12/1680 train_time:1009ms step_avg:84.07ms +step:13/1680 train_time:1101ms step_avg:84.67ms +step:14/1680 train_time:1190ms step_avg:85.03ms +step:15/1680 train_time:1278ms step_avg:85.22ms +step:16/1680 train_time:1365ms step_avg:85.34ms +step:17/1680 train_time:1453ms step_avg:85.46ms +step:18/1680 train_time:1539ms step_avg:85.52ms +step:19/1680 train_time:1626ms step_avg:85.57ms +step:20/1680 train_time:1712ms step_avg:85.62ms +step:21/1680 train_time:1799ms step_avg:85.65ms +step:22/1680 train_time:1885ms step_avg:85.69ms +step:23/1680 train_time:1972ms step_avg:85.75ms +step:24/1680 train_time:2061ms step_avg:85.88ms +step:25/1680 train_time:2149ms step_avg:85.97ms +step:26/1680 train_time:2238ms step_avg:86.06ms +step:27/1680 train_time:2326ms step_avg:86.14ms +step:28/1680 train_time:2413ms step_avg:86.18ms +step:29/1680 train_time:2500ms step_avg:86.20ms +step:30/1680 train_time:2587ms step_avg:86.23ms +step:31/1680 train_time:2674ms step_avg:86.25ms +step:32/1680 train_time:2760ms step_avg:86.26ms +step:33/1680 train_time:2847ms step_avg:86.28ms +step:34/1680 train_time:2935ms step_avg:86.31ms +step:35/1680 train_time:3023ms step_avg:86.39ms +step:36/1680 train_time:3111ms step_avg:86.42ms +step:37/1680 train_time:3199ms step_avg:86.46ms +step:38/1680 train_time:3287ms step_avg:86.50ms +step:39/1680 train_time:3375ms step_avg:86.53ms +step:40/1680 train_time:3462ms step_avg:86.55ms +step:41/1680 train_time:3549ms step_avg:86.57ms +step:42/1680 train_time:3636ms step_avg:86.58ms +step:43/1680 train_time:3724ms step_avg:86.60ms +step:44/1680 train_time:3810ms step_avg:86.59ms +step:45/1680 train_time:3897ms step_avg:86.60ms +step:46/1680 train_time:3985ms step_avg:86.63ms +step:47/1680 
train_time:4073ms step_avg:86.67ms +step:48/1680 train_time:4162ms step_avg:86.71ms +step:49/1680 train_time:4250ms step_avg:86.73ms +step:50/1680 train_time:4337ms step_avg:86.74ms +step:51/1680 train_time:4425ms step_avg:86.76ms +step:52/1680 train_time:4512ms step_avg:86.78ms +step:53/1680 train_time:4600ms step_avg:86.78ms +step:54/1680 train_time:4687ms step_avg:86.80ms +step:55/1680 train_time:4774ms step_avg:86.80ms +step:56/1680 train_time:4862ms step_avg:86.82ms +step:57/1680 train_time:4949ms step_avg:86.83ms +step:58/1680 train_time:5036ms step_avg:86.84ms +step:59/1680 train_time:5124ms step_avg:86.85ms +step:60/1680 train_time:5211ms step_avg:86.86ms +step:61/1680 train_time:5299ms step_avg:86.87ms +step:62/1680 train_time:5387ms step_avg:86.88ms +step:63/1680 train_time:5474ms step_avg:86.89ms +step:64/1680 train_time:5561ms step_avg:86.90ms +step:65/1680 train_time:5648ms step_avg:86.90ms +step:66/1680 train_time:5737ms step_avg:86.92ms +step:67/1680 train_time:5824ms step_avg:86.92ms +step:68/1680 train_time:5911ms step_avg:86.93ms +step:69/1680 train_time:5998ms step_avg:86.93ms +step:70/1680 train_time:6086ms step_avg:86.95ms +step:71/1680 train_time:6174ms step_avg:86.96ms +step:72/1680 train_time:6262ms step_avg:86.97ms +step:73/1680 train_time:6349ms step_avg:86.97ms +step:74/1680 train_time:6437ms step_avg:86.98ms +step:75/1680 train_time:6523ms step_avg:86.98ms +step:76/1680 train_time:6611ms step_avg:86.99ms +step:77/1680 train_time:6698ms step_avg:86.99ms +step:78/1680 train_time:6785ms step_avg:86.99ms +step:79/1680 train_time:6873ms step_avg:87.00ms +step:80/1680 train_time:6960ms step_avg:87.00ms +step:81/1680 train_time:7047ms step_avg:87.00ms +step:82/1680 train_time:7135ms step_avg:87.01ms +step:83/1680 train_time:7222ms step_avg:87.02ms +step:84/1680 train_time:7309ms step_avg:87.02ms +step:85/1680 train_time:7397ms step_avg:87.02ms +step:86/1680 train_time:7484ms step_avg:87.02ms +step:87/1680 train_time:7571ms step_avg:87.02ms +step:88/1680 train_time:7658ms step_avg:87.02ms +step:89/1680 train_time:7744ms step_avg:87.02ms +step:90/1680 train_time:7832ms step_avg:87.02ms +step:91/1680 train_time:7919ms step_avg:87.02ms +step:92/1680 train_time:8005ms step_avg:87.02ms +step:93/1680 train_time:8093ms step_avg:87.03ms +step:94/1680 train_time:8181ms step_avg:87.03ms +step:95/1680 train_time:8268ms step_avg:87.03ms +step:96/1680 train_time:8355ms step_avg:87.04ms +step:97/1680 train_time:8442ms step_avg:87.03ms +step:98/1680 train_time:8530ms step_avg:87.04ms +step:99/1680 train_time:8617ms step_avg:87.04ms +step:100/1680 train_time:8704ms step_avg:87.04ms +step:101/1680 train_time:8792ms step_avg:87.05ms +step:102/1680 train_time:8879ms step_avg:87.04ms +step:103/1680 train_time:8966ms step_avg:87.05ms +step:104/1680 train_time:9054ms step_avg:87.05ms +step:105/1680 train_time:9141ms step_avg:87.05ms +step:106/1680 train_time:9228ms step_avg:87.06ms +step:107/1680 train_time:9315ms step_avg:87.06ms +step:108/1680 train_time:9402ms step_avg:87.06ms +step:109/1680 train_time:9489ms step_avg:87.06ms +step:110/1680 train_time:9577ms step_avg:87.06ms +step:111/1680 train_time:9664ms step_avg:87.06ms +step:112/1680 train_time:9751ms step_avg:87.06ms +step:113/1680 train_time:9838ms step_avg:87.06ms +step:114/1680 train_time:9925ms step_avg:87.07ms +step:115/1680 train_time:10013ms step_avg:87.07ms +step:116/1680 train_time:10100ms step_avg:87.07ms +step:117/1680 train_time:10187ms step_avg:87.07ms +step:118/1680 train_time:10275ms step_avg:87.08ms +step:119/1680 
train_time:10362ms step_avg:87.08ms +step:120/1680 train_time:10449ms step_avg:87.08ms +step:121/1680 train_time:10537ms step_avg:87.08ms +step:122/1680 train_time:10625ms step_avg:87.09ms +step:123/1680 train_time:10711ms step_avg:87.08ms +step:124/1680 train_time:10798ms step_avg:87.08ms +step:125/1680 train_time:10885ms step_avg:87.08ms +step:125/1680 val_loss:4.3150 train_time:10973ms step_avg:87.78ms +step:126/1680 train_time:10992ms step_avg:87.24ms +step:127/1680 train_time:11064ms step_avg:87.12ms +step:128/1680 train_time:11159ms step_avg:87.18ms +step:129/1680 train_time:11252ms step_avg:87.22ms +step:130/1680 train_time:11340ms step_avg:87.23ms +step:131/1680 train_time:11427ms step_avg:87.23ms +step:132/1680 train_time:11514ms step_avg:87.22ms +step:133/1680 train_time:11599ms step_avg:87.21ms +step:134/1680 train_time:11685ms step_avg:87.20ms +step:135/1680 train_time:11771ms step_avg:87.19ms +step:136/1680 train_time:11857ms step_avg:87.19ms +step:137/1680 train_time:11944ms step_avg:87.18ms +step:138/1680 train_time:12032ms step_avg:87.19ms +step:139/1680 train_time:12120ms step_avg:87.20ms +step:140/1680 train_time:12209ms step_avg:87.21ms +step:141/1680 train_time:12298ms step_avg:87.22ms +step:142/1680 train_time:12386ms step_avg:87.22ms +step:143/1680 train_time:12473ms step_avg:87.22ms +step:144/1680 train_time:12560ms step_avg:87.22ms +step:145/1680 train_time:12646ms step_avg:87.22ms +step:146/1680 train_time:12732ms step_avg:87.21ms +step:147/1680 train_time:12818ms step_avg:87.20ms +step:148/1680 train_time:12904ms step_avg:87.19ms +step:149/1680 train_time:12991ms step_avg:87.19ms +step:150/1680 train_time:13079ms step_avg:87.19ms +step:151/1680 train_time:13167ms step_avg:87.20ms +step:152/1680 train_time:13255ms step_avg:87.20ms +step:153/1680 train_time:13342ms step_avg:87.20ms +step:154/1680 train_time:13429ms step_avg:87.20ms +step:155/1680 train_time:13516ms step_avg:87.20ms +step:156/1680 train_time:13603ms step_avg:87.20ms +step:157/1680 train_time:13690ms step_avg:87.19ms +step:158/1680 train_time:13776ms step_avg:87.19ms +step:159/1680 train_time:13862ms step_avg:87.18ms +step:160/1680 train_time:13949ms step_avg:87.18ms +step:161/1680 train_time:14036ms step_avg:87.18ms +step:162/1680 train_time:14124ms step_avg:87.18ms +step:163/1680 train_time:14212ms step_avg:87.19ms +step:164/1680 train_time:14300ms step_avg:87.19ms +step:165/1680 train_time:14387ms step_avg:87.20ms +step:166/1680 train_time:14475ms step_avg:87.20ms +step:167/1680 train_time:14562ms step_avg:87.20ms +step:168/1680 train_time:14649ms step_avg:87.20ms +step:169/1680 train_time:14736ms step_avg:87.20ms +step:170/1680 train_time:14822ms step_avg:87.19ms +step:171/1680 train_time:14909ms step_avg:87.19ms +step:172/1680 train_time:14996ms step_avg:87.19ms +step:173/1680 train_time:15083ms step_avg:87.18ms +step:174/1680 train_time:15171ms step_avg:87.19ms +step:175/1680 train_time:15259ms step_avg:87.19ms +step:176/1680 train_time:15346ms step_avg:87.19ms +step:177/1680 train_time:15434ms step_avg:87.20ms +step:178/1680 train_time:15521ms step_avg:87.19ms +step:179/1680 train_time:15608ms step_avg:87.19ms +step:180/1680 train_time:15695ms step_avg:87.19ms +step:181/1680 train_time:15781ms step_avg:87.19ms +step:182/1680 train_time:15868ms step_avg:87.19ms +step:183/1680 train_time:15955ms step_avg:87.19ms +step:184/1680 train_time:16042ms step_avg:87.18ms +step:185/1680 train_time:16130ms step_avg:87.19ms +step:186/1680 train_time:16217ms step_avg:87.19ms +step:187/1680 train_time:16304ms 
step_avg:87.19ms +step:188/1680 train_time:16391ms step_avg:87.19ms +step:189/1680 train_time:16479ms step_avg:87.19ms +step:190/1680 train_time:16566ms step_avg:87.19ms +step:191/1680 train_time:16654ms step_avg:87.19ms +step:192/1680 train_time:16740ms step_avg:87.19ms +step:193/1680 train_time:16827ms step_avg:87.19ms +step:194/1680 train_time:16913ms step_avg:87.18ms +step:195/1680 train_time:17000ms step_avg:87.18ms +step:196/1680 train_time:17087ms step_avg:87.18ms +step:197/1680 train_time:17174ms step_avg:87.18ms +step:198/1680 train_time:17261ms step_avg:87.18ms +step:199/1680 train_time:17349ms step_avg:87.18ms +step:200/1680 train_time:17436ms step_avg:87.18ms +step:201/1680 train_time:17523ms step_avg:87.18ms +step:202/1680 train_time:17610ms step_avg:87.18ms +step:203/1680 train_time:17698ms step_avg:87.18ms +step:204/1680 train_time:17785ms step_avg:87.18ms +step:205/1680 train_time:17871ms step_avg:87.18ms +step:206/1680 train_time:17958ms step_avg:87.18ms +step:207/1680 train_time:18045ms step_avg:87.17ms +step:208/1680 train_time:18132ms step_avg:87.17ms +step:209/1680 train_time:18219ms step_avg:87.17ms +step:210/1680 train_time:18307ms step_avg:87.18ms +step:211/1680 train_time:18394ms step_avg:87.18ms +step:212/1680 train_time:18482ms step_avg:87.18ms +step:213/1680 train_time:18569ms step_avg:87.18ms +step:214/1680 train_time:18656ms step_avg:87.18ms +step:215/1680 train_time:18743ms step_avg:87.18ms +step:216/1680 train_time:18830ms step_avg:87.18ms +step:217/1680 train_time:18917ms step_avg:87.18ms +step:218/1680 train_time:19004ms step_avg:87.17ms +step:219/1680 train_time:19091ms step_avg:87.17ms +step:220/1680 train_time:19178ms step_avg:87.17ms +step:221/1680 train_time:19265ms step_avg:87.17ms +step:222/1680 train_time:19353ms step_avg:87.17ms +step:223/1680 train_time:19440ms step_avg:87.18ms +step:224/1680 train_time:19527ms step_avg:87.17ms +step:225/1680 train_time:19615ms step_avg:87.18ms +step:226/1680 train_time:19701ms step_avg:87.17ms +step:227/1680 train_time:19788ms step_avg:87.17ms +step:228/1680 train_time:19874ms step_avg:87.17ms +step:229/1680 train_time:19961ms step_avg:87.17ms +step:230/1680 train_time:20049ms step_avg:87.17ms +step:231/1680 train_time:20136ms step_avg:87.17ms +step:232/1680 train_time:20223ms step_avg:87.17ms +step:233/1680 train_time:20311ms step_avg:87.17ms +step:234/1680 train_time:20398ms step_avg:87.17ms +step:235/1680 train_time:20486ms step_avg:87.17ms +step:236/1680 train_time:20573ms step_avg:87.17ms +step:237/1680 train_time:20660ms step_avg:87.17ms +step:238/1680 train_time:20747ms step_avg:87.17ms +step:239/1680 train_time:20834ms step_avg:87.17ms +step:240/1680 train_time:20921ms step_avg:87.17ms +step:241/1680 train_time:21008ms step_avg:87.17ms +step:242/1680 train_time:21095ms step_avg:87.17ms +step:243/1680 train_time:21182ms step_avg:87.17ms +step:244/1680 train_time:21269ms step_avg:87.17ms +step:245/1680 train_time:21356ms step_avg:87.17ms +step:246/1680 train_time:21444ms step_avg:87.17ms +step:247/1680 train_time:21531ms step_avg:87.17ms +step:248/1680 train_time:21618ms step_avg:87.17ms +step:249/1680 train_time:21705ms step_avg:87.17ms +step:250/1680 train_time:21792ms step_avg:87.17ms +step:250/1680 val_loss:3.9741 train_time:21881ms step_avg:87.52ms +step:251/1680 train_time:21899ms step_avg:87.25ms +step:252/1680 train_time:21969ms step_avg:87.18ms +step:253/1680 train_time:22061ms step_avg:87.20ms +step:254/1680 train_time:22150ms step_avg:87.21ms +step:255/1680 train_time:22238ms step_avg:87.21ms 
+step:256/1680 train_time:22325ms step_avg:87.21ms +step:257/1680 train_time:22411ms step_avg:87.20ms +step:258/1680 train_time:22497ms step_avg:87.20ms +step:259/1680 train_time:22583ms step_avg:87.19ms +step:260/1680 train_time:22670ms step_avg:87.19ms +step:261/1680 train_time:22756ms step_avg:87.19ms +step:262/1680 train_time:22843ms step_avg:87.19ms +step:263/1680 train_time:22931ms step_avg:87.19ms +step:264/1680 train_time:23020ms step_avg:87.20ms +step:265/1680 train_time:23108ms step_avg:87.20ms +step:266/1680 train_time:23196ms step_avg:87.20ms +step:267/1680 train_time:23282ms step_avg:87.20ms +step:268/1680 train_time:23369ms step_avg:87.20ms +step:269/1680 train_time:23456ms step_avg:87.20ms +step:270/1680 train_time:23543ms step_avg:87.19ms +step:271/1680 train_time:23629ms step_avg:87.19ms +step:272/1680 train_time:23717ms step_avg:87.19ms +step:273/1680 train_time:23803ms step_avg:87.19ms +step:274/1680 train_time:23891ms step_avg:87.19ms +step:275/1680 train_time:23979ms step_avg:87.20ms +step:276/1680 train_time:24066ms step_avg:87.20ms +step:277/1680 train_time:24153ms step_avg:87.20ms +step:278/1680 train_time:24242ms step_avg:87.20ms +step:279/1680 train_time:24329ms step_avg:87.20ms +step:280/1680 train_time:24416ms step_avg:87.20ms +step:281/1680 train_time:24503ms step_avg:87.20ms +step:282/1680 train_time:24590ms step_avg:87.20ms +step:283/1680 train_time:24677ms step_avg:87.20ms +step:284/1680 train_time:24763ms step_avg:87.19ms +step:285/1680 train_time:24851ms step_avg:87.20ms +step:286/1680 train_time:24938ms step_avg:87.19ms +step:287/1680 train_time:25025ms step_avg:87.20ms +step:288/1680 train_time:25113ms step_avg:87.20ms +step:289/1680 train_time:25200ms step_avg:87.20ms +step:290/1680 train_time:25288ms step_avg:87.20ms +step:291/1680 train_time:25375ms step_avg:87.20ms +step:292/1680 train_time:25463ms step_avg:87.20ms +step:293/1680 train_time:25549ms step_avg:87.20ms +step:294/1680 train_time:25636ms step_avg:87.20ms +step:295/1680 train_time:25722ms step_avg:87.19ms +step:296/1680 train_time:25809ms step_avg:87.19ms +step:297/1680 train_time:25896ms step_avg:87.19ms +step:298/1680 train_time:25984ms step_avg:87.19ms +step:299/1680 train_time:26071ms step_avg:87.19ms +step:300/1680 train_time:26158ms step_avg:87.19ms +step:301/1680 train_time:26247ms step_avg:87.20ms +step:302/1680 train_time:26334ms step_avg:87.20ms +step:303/1680 train_time:26421ms step_avg:87.20ms +step:304/1680 train_time:26508ms step_avg:87.20ms +step:305/1680 train_time:26595ms step_avg:87.20ms +step:306/1680 train_time:26682ms step_avg:87.20ms +step:307/1680 train_time:26769ms step_avg:87.20ms +step:308/1680 train_time:26856ms step_avg:87.20ms +step:309/1680 train_time:26943ms step_avg:87.20ms +step:310/1680 train_time:27030ms step_avg:87.19ms +step:311/1680 train_time:27118ms step_avg:87.20ms +step:312/1680 train_time:27205ms step_avg:87.20ms +step:313/1680 train_time:27292ms step_avg:87.20ms +step:314/1680 train_time:27379ms step_avg:87.20ms +step:315/1680 train_time:27467ms step_avg:87.20ms +step:316/1680 train_time:27554ms step_avg:87.20ms +step:317/1680 train_time:27640ms step_avg:87.19ms +step:318/1680 train_time:27727ms step_avg:87.19ms +step:319/1680 train_time:27814ms step_avg:87.19ms +step:320/1680 train_time:27901ms step_avg:87.19ms +step:321/1680 train_time:27989ms step_avg:87.19ms +step:322/1680 train_time:28076ms step_avg:87.19ms +step:323/1680 train_time:28163ms step_avg:87.19ms +step:324/1680 train_time:28250ms step_avg:87.19ms +step:325/1680 train_time:28338ms 
step_avg:87.19ms +step:326/1680 train_time:28425ms step_avg:87.19ms +step:327/1680 train_time:28512ms step_avg:87.19ms +step:328/1680 train_time:28599ms step_avg:87.19ms +step:329/1680 train_time:28685ms step_avg:87.19ms +step:330/1680 train_time:28772ms step_avg:87.19ms +step:331/1680 train_time:28859ms step_avg:87.19ms +step:332/1680 train_time:28946ms step_avg:87.19ms +step:333/1680 train_time:29034ms step_avg:87.19ms +step:334/1680 train_time:29121ms step_avg:87.19ms +step:335/1680 train_time:29208ms step_avg:87.19ms +step:336/1680 train_time:29296ms step_avg:87.19ms +step:337/1680 train_time:29383ms step_avg:87.19ms +step:338/1680 train_time:29470ms step_avg:87.19ms +step:339/1680 train_time:29557ms step_avg:87.19ms +step:340/1680 train_time:29644ms step_avg:87.19ms +step:341/1680 train_time:29731ms step_avg:87.19ms +step:342/1680 train_time:29818ms step_avg:87.19ms +step:343/1680 train_time:29905ms step_avg:87.19ms +step:344/1680 train_time:29992ms step_avg:87.19ms +step:345/1680 train_time:30080ms step_avg:87.19ms +step:346/1680 train_time:30167ms step_avg:87.19ms +step:347/1680 train_time:30254ms step_avg:87.19ms +step:348/1680 train_time:30341ms step_avg:87.19ms +step:349/1680 train_time:30428ms step_avg:87.19ms +step:350/1680 train_time:30516ms step_avg:87.19ms +step:351/1680 train_time:30603ms step_avg:87.19ms +step:352/1680 train_time:30690ms step_avg:87.19ms +step:353/1680 train_time:30777ms step_avg:87.19ms +step:354/1680 train_time:30865ms step_avg:87.19ms +step:355/1680 train_time:30952ms step_avg:87.19ms +step:356/1680 train_time:31039ms step_avg:87.19ms +step:357/1680 train_time:31127ms step_avg:87.19ms +step:358/1680 train_time:31214ms step_avg:87.19ms +step:359/1680 train_time:31301ms step_avg:87.19ms +step:360/1680 train_time:31388ms step_avg:87.19ms +step:361/1680 train_time:31475ms step_avg:87.19ms +step:362/1680 train_time:31562ms step_avg:87.19ms +step:363/1680 train_time:31649ms step_avg:87.19ms +step:364/1680 train_time:31736ms step_avg:87.19ms +step:365/1680 train_time:31823ms step_avg:87.19ms +step:366/1680 train_time:31911ms step_avg:87.19ms +step:367/1680 train_time:31998ms step_avg:87.19ms +step:368/1680 train_time:32086ms step_avg:87.19ms +step:369/1680 train_time:32173ms step_avg:87.19ms +step:370/1680 train_time:32260ms step_avg:87.19ms +step:371/1680 train_time:32347ms step_avg:87.19ms +step:372/1680 train_time:32434ms step_avg:87.19ms +step:373/1680 train_time:32521ms step_avg:87.19ms +step:374/1680 train_time:32608ms step_avg:87.19ms +step:375/1680 train_time:32695ms step_avg:87.19ms +step:375/1680 val_loss:3.8226 train_time:32783ms step_avg:87.42ms +step:376/1680 train_time:32802ms step_avg:87.24ms +step:377/1680 train_time:32876ms step_avg:87.20ms +step:378/1680 train_time:32967ms step_avg:87.21ms +step:379/1680 train_time:33055ms step_avg:87.22ms +step:380/1680 train_time:33142ms step_avg:87.22ms +step:381/1680 train_time:33228ms step_avg:87.21ms +step:382/1680 train_time:33315ms step_avg:87.21ms +step:383/1680 train_time:33401ms step_avg:87.21ms +step:384/1680 train_time:33487ms step_avg:87.21ms +step:385/1680 train_time:33574ms step_avg:87.20ms +step:386/1680 train_time:33660ms step_avg:87.20ms +step:387/1680 train_time:33747ms step_avg:87.20ms +step:388/1680 train_time:33836ms step_avg:87.21ms +step:389/1680 train_time:33925ms step_avg:87.21ms +step:390/1680 train_time:34014ms step_avg:87.21ms +step:391/1680 train_time:34102ms step_avg:87.22ms +step:392/1680 train_time:34188ms step_avg:87.22ms +step:393/1680 train_time:34275ms step_avg:87.21ms 
+step:394/1680 train_time:34362ms step_avg:87.21ms +step:395/1680 train_time:34448ms step_avg:87.21ms +step:396/1680 train_time:34534ms step_avg:87.21ms +step:397/1680 train_time:34621ms step_avg:87.21ms +step:398/1680 train_time:34708ms step_avg:87.20ms +step:399/1680 train_time:34795ms step_avg:87.21ms +step:400/1680 train_time:34882ms step_avg:87.21ms +step:401/1680 train_time:34970ms step_avg:87.21ms +step:402/1680 train_time:35057ms step_avg:87.21ms +step:403/1680 train_time:35145ms step_avg:87.21ms +step:404/1680 train_time:35232ms step_avg:87.21ms +step:405/1680 train_time:35319ms step_avg:87.21ms +step:406/1680 train_time:35405ms step_avg:87.21ms +step:407/1680 train_time:35492ms step_avg:87.20ms +step:408/1680 train_time:35579ms step_avg:87.20ms +step:409/1680 train_time:35665ms step_avg:87.20ms +step:410/1680 train_time:35753ms step_avg:87.20ms +step:411/1680 train_time:35840ms step_avg:87.20ms +step:412/1680 train_time:35927ms step_avg:87.20ms +step:413/1680 train_time:36015ms step_avg:87.20ms +step:414/1680 train_time:36102ms step_avg:87.20ms +step:415/1680 train_time:36189ms step_avg:87.20ms +step:416/1680 train_time:36276ms step_avg:87.20ms +step:417/1680 train_time:36363ms step_avg:87.20ms +step:418/1680 train_time:36450ms step_avg:87.20ms +step:419/1680 train_time:36538ms step_avg:87.20ms +step:420/1680 train_time:36624ms step_avg:87.20ms +step:421/1680 train_time:36711ms step_avg:87.20ms +step:422/1680 train_time:36798ms step_avg:87.20ms +step:423/1680 train_time:36885ms step_avg:87.20ms +step:424/1680 train_time:36973ms step_avg:87.20ms +step:425/1680 train_time:37061ms step_avg:87.20ms +step:426/1680 train_time:37148ms step_avg:87.20ms +step:427/1680 train_time:37234ms step_avg:87.20ms +step:428/1680 train_time:37321ms step_avg:87.20ms +step:429/1680 train_time:37408ms step_avg:87.20ms +step:430/1680 train_time:37495ms step_avg:87.20ms +step:431/1680 train_time:37582ms step_avg:87.20ms +step:432/1680 train_time:37669ms step_avg:87.20ms +step:433/1680 train_time:37756ms step_avg:87.20ms +step:434/1680 train_time:37842ms step_avg:87.19ms +step:435/1680 train_time:37930ms step_avg:87.19ms +step:436/1680 train_time:38017ms step_avg:87.20ms +step:437/1680 train_time:38105ms step_avg:87.20ms +step:438/1680 train_time:38192ms step_avg:87.20ms +step:439/1680 train_time:38280ms step_avg:87.20ms +step:440/1680 train_time:38368ms step_avg:87.20ms +step:441/1680 train_time:38454ms step_avg:87.20ms +step:442/1680 train_time:38541ms step_avg:87.20ms +step:443/1680 train_time:38627ms step_avg:87.20ms +step:444/1680 train_time:38715ms step_avg:87.20ms +step:445/1680 train_time:38802ms step_avg:87.20ms +step:446/1680 train_time:38889ms step_avg:87.19ms +step:447/1680 train_time:38977ms step_avg:87.20ms +step:448/1680 train_time:39065ms step_avg:87.20ms +step:449/1680 train_time:39152ms step_avg:87.20ms +step:450/1680 train_time:39239ms step_avg:87.20ms +step:451/1680 train_time:39326ms step_avg:87.20ms +step:452/1680 train_time:39414ms step_avg:87.20ms +step:453/1680 train_time:39500ms step_avg:87.20ms +step:454/1680 train_time:39587ms step_avg:87.20ms +step:455/1680 train_time:39674ms step_avg:87.20ms +step:456/1680 train_time:39761ms step_avg:87.19ms +step:457/1680 train_time:39848ms step_avg:87.19ms +step:458/1680 train_time:39935ms step_avg:87.19ms +step:459/1680 train_time:40022ms step_avg:87.19ms +step:460/1680 train_time:40110ms step_avg:87.20ms +step:461/1680 train_time:40198ms step_avg:87.20ms +step:462/1680 train_time:40285ms step_avg:87.20ms +step:463/1680 train_time:40372ms 
step_avg:87.20ms +step:464/1680 train_time:40459ms step_avg:87.20ms +step:465/1680 train_time:40546ms step_avg:87.19ms +step:466/1680 train_time:40633ms step_avg:87.19ms +step:467/1680 train_time:40720ms step_avg:87.19ms +step:468/1680 train_time:40808ms step_avg:87.20ms +step:469/1680 train_time:40895ms step_avg:87.20ms +step:470/1680 train_time:40982ms step_avg:87.20ms +step:471/1680 train_time:41069ms step_avg:87.20ms +step:472/1680 train_time:41156ms step_avg:87.20ms +step:473/1680 train_time:41244ms step_avg:87.20ms +step:474/1680 train_time:41331ms step_avg:87.20ms +step:475/1680 train_time:41417ms step_avg:87.19ms +step:476/1680 train_time:41504ms step_avg:87.19ms +step:477/1680 train_time:41591ms step_avg:87.19ms +step:478/1680 train_time:41679ms step_avg:87.19ms +step:479/1680 train_time:41767ms step_avg:87.20ms +step:480/1680 train_time:41854ms step_avg:87.19ms +step:481/1680 train_time:41940ms step_avg:87.19ms +step:482/1680 train_time:42028ms step_avg:87.19ms +step:483/1680 train_time:42115ms step_avg:87.19ms +step:484/1680 train_time:42202ms step_avg:87.19ms +step:485/1680 train_time:42289ms step_avg:87.19ms +step:486/1680 train_time:42377ms step_avg:87.19ms +step:487/1680 train_time:42463ms step_avg:87.19ms +step:488/1680 train_time:42550ms step_avg:87.19ms +step:489/1680 train_time:42637ms step_avg:87.19ms +step:490/1680 train_time:42725ms step_avg:87.19ms +step:491/1680 train_time:42812ms step_avg:87.19ms +step:492/1680 train_time:42900ms step_avg:87.20ms +step:493/1680 train_time:42987ms step_avg:87.19ms +step:494/1680 train_time:43074ms step_avg:87.19ms +step:495/1680 train_time:43161ms step_avg:87.19ms +step:496/1680 train_time:43248ms step_avg:87.19ms +step:497/1680 train_time:43335ms step_avg:87.19ms +step:498/1680 train_time:43422ms step_avg:87.19ms +step:499/1680 train_time:43509ms step_avg:87.19ms +step:500/1680 train_time:43596ms step_avg:87.19ms +step:500/1680 val_loss:3.7220 train_time:43685ms step_avg:87.37ms +step:501/1680 train_time:43704ms step_avg:87.23ms +step:502/1680 train_time:43773ms step_avg:87.20ms +step:503/1680 train_time:43864ms step_avg:87.21ms +step:504/1680 train_time:43956ms step_avg:87.21ms +step:505/1680 train_time:44042ms step_avg:87.21ms +step:506/1680 train_time:44129ms step_avg:87.21ms +step:507/1680 train_time:44215ms step_avg:87.21ms +step:508/1680 train_time:44301ms step_avg:87.21ms +step:509/1680 train_time:44387ms step_avg:87.20ms +step:510/1680 train_time:44474ms step_avg:87.20ms +step:511/1680 train_time:44560ms step_avg:87.20ms +step:512/1680 train_time:44646ms step_avg:87.20ms +step:513/1680 train_time:44735ms step_avg:87.20ms +step:514/1680 train_time:44825ms step_avg:87.21ms +step:515/1680 train_time:44914ms step_avg:87.21ms +step:516/1680 train_time:45002ms step_avg:87.21ms +step:517/1680 train_time:45089ms step_avg:87.21ms +step:518/1680 train_time:45175ms step_avg:87.21ms +step:519/1680 train_time:45262ms step_avg:87.21ms +step:520/1680 train_time:45348ms step_avg:87.21ms +step:521/1680 train_time:45434ms step_avg:87.21ms +step:522/1680 train_time:45520ms step_avg:87.20ms +step:523/1680 train_time:45607ms step_avg:87.20ms +step:524/1680 train_time:45694ms step_avg:87.20ms +step:525/1680 train_time:45782ms step_avg:87.20ms +step:526/1680 train_time:45870ms step_avg:87.21ms +step:527/1680 train_time:45959ms step_avg:87.21ms +step:528/1680 train_time:46047ms step_avg:87.21ms +step:529/1680 train_time:46134ms step_avg:87.21ms +step:530/1680 train_time:46221ms step_avg:87.21ms +step:531/1680 train_time:46308ms step_avg:87.21ms 
+step:532/1680 train_time:46395ms step_avg:87.21ms +step:533/1680 train_time:46481ms step_avg:87.21ms +step:534/1680 train_time:46568ms step_avg:87.21ms +step:535/1680 train_time:46655ms step_avg:87.20ms +step:536/1680 train_time:46742ms step_avg:87.20ms +step:537/1680 train_time:46829ms step_avg:87.21ms +step:538/1680 train_time:46917ms step_avg:87.21ms +step:539/1680 train_time:47004ms step_avg:87.21ms +step:540/1680 train_time:47092ms step_avg:87.21ms +step:541/1680 train_time:47179ms step_avg:87.21ms +step:542/1680 train_time:47266ms step_avg:87.21ms +step:543/1680 train_time:47352ms step_avg:87.20ms +step:544/1680 train_time:47439ms step_avg:87.20ms +step:545/1680 train_time:47526ms step_avg:87.20ms +step:546/1680 train_time:47613ms step_avg:87.20ms +step:547/1680 train_time:47700ms step_avg:87.20ms +step:548/1680 train_time:47787ms step_avg:87.20ms +step:549/1680 train_time:47876ms step_avg:87.21ms +step:550/1680 train_time:47965ms step_avg:87.21ms +step:551/1680 train_time:48053ms step_avg:87.21ms +step:552/1680 train_time:48141ms step_avg:87.21ms +step:553/1680 train_time:48230ms step_avg:87.22ms +step:554/1680 train_time:48318ms step_avg:87.22ms +step:555/1680 train_time:48406ms step_avg:87.22ms +step:556/1680 train_time:48493ms step_avg:87.22ms +step:557/1680 train_time:48582ms step_avg:87.22ms +step:558/1680 train_time:48669ms step_avg:87.22ms +step:559/1680 train_time:48758ms step_avg:87.22ms +step:560/1680 train_time:48847ms step_avg:87.23ms +step:561/1680 train_time:48935ms step_avg:87.23ms +step:562/1680 train_time:49023ms step_avg:87.23ms +step:563/1680 train_time:49112ms step_avg:87.23ms +step:564/1680 train_time:49201ms step_avg:87.24ms +step:565/1680 train_time:49289ms step_avg:87.24ms +step:566/1680 train_time:49377ms step_avg:87.24ms +step:567/1680 train_time:49466ms step_avg:87.24ms +step:568/1680 train_time:49554ms step_avg:87.24ms +step:569/1680 train_time:49643ms step_avg:87.25ms +step:570/1680 train_time:49731ms step_avg:87.25ms +step:571/1680 train_time:49820ms step_avg:87.25ms +step:572/1680 train_time:49908ms step_avg:87.25ms +step:573/1680 train_time:49996ms step_avg:87.25ms +step:574/1680 train_time:50085ms step_avg:87.26ms +step:575/1680 train_time:50173ms step_avg:87.26ms +step:576/1680 train_time:50262ms step_avg:87.26ms +step:577/1680 train_time:50350ms step_avg:87.26ms +step:578/1680 train_time:50438ms step_avg:87.26ms +step:579/1680 train_time:50526ms step_avg:87.26ms +step:580/1680 train_time:50614ms step_avg:87.27ms +step:581/1680 train_time:50703ms step_avg:87.27ms +step:582/1680 train_time:50791ms step_avg:87.27ms +step:583/1680 train_time:50879ms step_avg:87.27ms +step:584/1680 train_time:50968ms step_avg:87.27ms +step:585/1680 train_time:51057ms step_avg:87.28ms +step:586/1680 train_time:51145ms step_avg:87.28ms +step:587/1680 train_time:51234ms step_avg:87.28ms +step:588/1680 train_time:51322ms step_avg:87.28ms +step:589/1680 train_time:51410ms step_avg:87.28ms +step:590/1680 train_time:51497ms step_avg:87.28ms +step:591/1680 train_time:51586ms step_avg:87.29ms +step:592/1680 train_time:51674ms step_avg:87.29ms +step:593/1680 train_time:51762ms step_avg:87.29ms +step:594/1680 train_time:51850ms step_avg:87.29ms +step:595/1680 train_time:51938ms step_avg:87.29ms +step:596/1680 train_time:52027ms step_avg:87.29ms +step:597/1680 train_time:52116ms step_avg:87.30ms +step:598/1680 train_time:52204ms step_avg:87.30ms +step:599/1680 train_time:52292ms step_avg:87.30ms +step:600/1680 train_time:52381ms step_avg:87.30ms +step:601/1680 train_time:52469ms 
step_avg:87.30ms +step:602/1680 train_time:52556ms step_avg:87.30ms +step:603/1680 train_time:52644ms step_avg:87.30ms +step:604/1680 train_time:52732ms step_avg:87.30ms +step:605/1680 train_time:52821ms step_avg:87.31ms +step:606/1680 train_time:52909ms step_avg:87.31ms +step:607/1680 train_time:52998ms step_avg:87.31ms +step:608/1680 train_time:53086ms step_avg:87.31ms +step:609/1680 train_time:53174ms step_avg:87.31ms +step:610/1680 train_time:53262ms step_avg:87.32ms +step:611/1680 train_time:53351ms step_avg:87.32ms +step:612/1680 train_time:53440ms step_avg:87.32ms +step:613/1680 train_time:53528ms step_avg:87.32ms +step:614/1680 train_time:53616ms step_avg:87.32ms +step:615/1680 train_time:53704ms step_avg:87.32ms +step:616/1680 train_time:53792ms step_avg:87.32ms +step:617/1680 train_time:53880ms step_avg:87.33ms +step:618/1680 train_time:53968ms step_avg:87.33ms +step:619/1680 train_time:54057ms step_avg:87.33ms +step:620/1680 train_time:54145ms step_avg:87.33ms +step:621/1680 train_time:54233ms step_avg:87.33ms +step:622/1680 train_time:54322ms step_avg:87.33ms +step:623/1680 train_time:54410ms step_avg:87.34ms +step:624/1680 train_time:54498ms step_avg:87.34ms +step:625/1680 train_time:54586ms step_avg:87.34ms +step:625/1680 val_loss:3.6192 train_time:54676ms step_avg:87.48ms +step:626/1680 train_time:54696ms step_avg:87.37ms +step:627/1680 train_time:54767ms step_avg:87.35ms +step:628/1680 train_time:54857ms step_avg:87.35ms +step:629/1680 train_time:54949ms step_avg:87.36ms +step:630/1680 train_time:55037ms step_avg:87.36ms +step:631/1680 train_time:55124ms step_avg:87.36ms +step:632/1680 train_time:55211ms step_avg:87.36ms +step:633/1680 train_time:55298ms step_avg:87.36ms +step:634/1680 train_time:55385ms step_avg:87.36ms +step:635/1680 train_time:55473ms step_avg:87.36ms +step:636/1680 train_time:55561ms step_avg:87.36ms +step:637/1680 train_time:55651ms step_avg:87.36ms +step:638/1680 train_time:55740ms step_avg:87.37ms +step:639/1680 train_time:55831ms step_avg:87.37ms +step:640/1680 train_time:55921ms step_avg:87.38ms +step:641/1680 train_time:56009ms step_avg:87.38ms +step:642/1680 train_time:56096ms step_avg:87.38ms +step:643/1680 train_time:56184ms step_avg:87.38ms +step:644/1680 train_time:56272ms step_avg:87.38ms +step:645/1680 train_time:56359ms step_avg:87.38ms +step:646/1680 train_time:56447ms step_avg:87.38ms +step:647/1680 train_time:56535ms step_avg:87.38ms +step:648/1680 train_time:56624ms step_avg:87.38ms +step:649/1680 train_time:56712ms step_avg:87.38ms +step:650/1680 train_time:56801ms step_avg:87.39ms +step:651/1680 train_time:56890ms step_avg:87.39ms +step:652/1680 train_time:56979ms step_avg:87.39ms +step:653/1680 train_time:57067ms step_avg:87.39ms +step:654/1680 train_time:57156ms step_avg:87.39ms +step:655/1680 train_time:57244ms step_avg:87.40ms +step:656/1680 train_time:57332ms step_avg:87.40ms +step:657/1680 train_time:57419ms step_avg:87.40ms +step:658/1680 train_time:57507ms step_avg:87.40ms +step:659/1680 train_time:57596ms step_avg:87.40ms +step:660/1680 train_time:57685ms step_avg:87.40ms +step:661/1680 train_time:57775ms step_avg:87.41ms +step:662/1680 train_time:57864ms step_avg:87.41ms +step:663/1680 train_time:57953ms step_avg:87.41ms +step:664/1680 train_time:58041ms step_avg:87.41ms +step:665/1680 train_time:58130ms step_avg:87.41ms +step:666/1680 train_time:58218ms step_avg:87.41ms +step:667/1680 train_time:58306ms step_avg:87.41ms +step:668/1680 train_time:58394ms step_avg:87.42ms +step:669/1680 train_time:58481ms step_avg:87.42ms 
+step:670/1680 train_time:58569ms step_avg:87.42ms +step:671/1680 train_time:58658ms step_avg:87.42ms +step:672/1680 train_time:58746ms step_avg:87.42ms +step:673/1680 train_time:58836ms step_avg:87.42ms +step:674/1680 train_time:58925ms step_avg:87.43ms +step:675/1680 train_time:59014ms step_avg:87.43ms +step:676/1680 train_time:59102ms step_avg:87.43ms +step:677/1680 train_time:59189ms step_avg:87.43ms +step:678/1680 train_time:59277ms step_avg:87.43ms +step:679/1680 train_time:59366ms step_avg:87.43ms +step:680/1680 train_time:59454ms step_avg:87.43ms +step:681/1680 train_time:59541ms step_avg:87.43ms +step:682/1680 train_time:59630ms step_avg:87.43ms +step:683/1680 train_time:59718ms step_avg:87.44ms +step:684/1680 train_time:59807ms step_avg:87.44ms +step:685/1680 train_time:59896ms step_avg:87.44ms +step:686/1680 train_time:59985ms step_avg:87.44ms +step:687/1680 train_time:60073ms step_avg:87.44ms +step:688/1680 train_time:60161ms step_avg:87.44ms +step:689/1680 train_time:60250ms step_avg:87.45ms +step:690/1680 train_time:60338ms step_avg:87.45ms +step:691/1680 train_time:60426ms step_avg:87.45ms +step:692/1680 train_time:60514ms step_avg:87.45ms +step:693/1680 train_time:60602ms step_avg:87.45ms +step:694/1680 train_time:60691ms step_avg:87.45ms +step:695/1680 train_time:60779ms step_avg:87.45ms +step:696/1680 train_time:60868ms step_avg:87.45ms +step:697/1680 train_time:60957ms step_avg:87.46ms +step:698/1680 train_time:61045ms step_avg:87.46ms +step:699/1680 train_time:61133ms step_avg:87.46ms +step:700/1680 train_time:61222ms step_avg:87.46ms +step:701/1680 train_time:61310ms step_avg:87.46ms +step:702/1680 train_time:61398ms step_avg:87.46ms +step:703/1680 train_time:61486ms step_avg:87.46ms +step:704/1680 train_time:61575ms step_avg:87.46ms +step:705/1680 train_time:61663ms step_avg:87.47ms +step:706/1680 train_time:61752ms step_avg:87.47ms +step:707/1680 train_time:61840ms step_avg:87.47ms +step:708/1680 train_time:61929ms step_avg:87.47ms +step:709/1680 train_time:62018ms step_avg:87.47ms +step:710/1680 train_time:62106ms step_avg:87.47ms +step:711/1680 train_time:62194ms step_avg:87.47ms +step:712/1680 train_time:62283ms step_avg:87.48ms +step:713/1680 train_time:62371ms step_avg:87.48ms +step:714/1680 train_time:62459ms step_avg:87.48ms +step:715/1680 train_time:62547ms step_avg:87.48ms +step:716/1680 train_time:62635ms step_avg:87.48ms +step:717/1680 train_time:62724ms step_avg:87.48ms +step:718/1680 train_time:62813ms step_avg:87.48ms +step:719/1680 train_time:62900ms step_avg:87.48ms +step:720/1680 train_time:62988ms step_avg:87.48ms +step:721/1680 train_time:63076ms step_avg:87.48ms +step:722/1680 train_time:63165ms step_avg:87.49ms +step:723/1680 train_time:63254ms step_avg:87.49ms +step:724/1680 train_time:63341ms step_avg:87.49ms +step:725/1680 train_time:63430ms step_avg:87.49ms +step:726/1680 train_time:63518ms step_avg:87.49ms +step:727/1680 train_time:63606ms step_avg:87.49ms +step:728/1680 train_time:63695ms step_avg:87.49ms +step:729/1680 train_time:63784ms step_avg:87.50ms +step:730/1680 train_time:63873ms step_avg:87.50ms +step:731/1680 train_time:63960ms step_avg:87.50ms +step:732/1680 train_time:64049ms step_avg:87.50ms +step:733/1680 train_time:64137ms step_avg:87.50ms +step:734/1680 train_time:64225ms step_avg:87.50ms +step:735/1680 train_time:64314ms step_avg:87.50ms +step:736/1680 train_time:64402ms step_avg:87.50ms +step:737/1680 train_time:64489ms step_avg:87.50ms +step:738/1680 train_time:64578ms step_avg:87.50ms +step:739/1680 train_time:64666ms 
step_avg:87.51ms +step:740/1680 train_time:64755ms step_avg:87.51ms +step:741/1680 train_time:64844ms step_avg:87.51ms +step:742/1680 train_time:64932ms step_avg:87.51ms +step:743/1680 train_time:65020ms step_avg:87.51ms +step:744/1680 train_time:65109ms step_avg:87.51ms +step:745/1680 train_time:65197ms step_avg:87.51ms +step:746/1680 train_time:65286ms step_avg:87.51ms +step:747/1680 train_time:65374ms step_avg:87.51ms +step:748/1680 train_time:65462ms step_avg:87.52ms +step:749/1680 train_time:65549ms step_avg:87.52ms +step:750/1680 train_time:65637ms step_avg:87.52ms +step:750/1680 val_loss:3.5679 train_time:65727ms step_avg:87.64ms +step:751/1680 train_time:65745ms step_avg:87.54ms +step:752/1680 train_time:65818ms step_avg:87.52ms +step:753/1680 train_time:65914ms step_avg:87.53ms +step:754/1680 train_time:66003ms step_avg:87.54ms +step:755/1680 train_time:66090ms step_avg:87.54ms +step:756/1680 train_time:66178ms step_avg:87.54ms +step:757/1680 train_time:66265ms step_avg:87.54ms +step:758/1680 train_time:66352ms step_avg:87.54ms +step:759/1680 train_time:66439ms step_avg:87.53ms +step:760/1680 train_time:66526ms step_avg:87.53ms +step:761/1680 train_time:66613ms step_avg:87.53ms +step:762/1680 train_time:66703ms step_avg:87.54ms +step:763/1680 train_time:66792ms step_avg:87.54ms +step:764/1680 train_time:66884ms step_avg:87.54ms +step:765/1680 train_time:66974ms step_avg:87.55ms +step:766/1680 train_time:67063ms step_avg:87.55ms +step:767/1680 train_time:67151ms step_avg:87.55ms +step:768/1680 train_time:67239ms step_avg:87.55ms +step:769/1680 train_time:67327ms step_avg:87.55ms +step:770/1680 train_time:67414ms step_avg:87.55ms +step:771/1680 train_time:67502ms step_avg:87.55ms +step:772/1680 train_time:67589ms step_avg:87.55ms +step:773/1680 train_time:67677ms step_avg:87.55ms +step:774/1680 train_time:67766ms step_avg:87.55ms +step:775/1680 train_time:67856ms step_avg:87.56ms +step:776/1680 train_time:67947ms step_avg:87.56ms +step:777/1680 train_time:68035ms step_avg:87.56ms +step:778/1680 train_time:68123ms step_avg:87.56ms +step:779/1680 train_time:68211ms step_avg:87.56ms +step:780/1680 train_time:68299ms step_avg:87.56ms +step:781/1680 train_time:68387ms step_avg:87.56ms +step:782/1680 train_time:68475ms step_avg:87.56ms +step:783/1680 train_time:68562ms step_avg:87.56ms +step:784/1680 train_time:68651ms step_avg:87.56ms +step:785/1680 train_time:68739ms step_avg:87.57ms +step:786/1680 train_time:68827ms step_avg:87.57ms +step:787/1680 train_time:68917ms step_avg:87.57ms +step:788/1680 train_time:69007ms step_avg:87.57ms +step:789/1680 train_time:69096ms step_avg:87.57ms +step:790/1680 train_time:69184ms step_avg:87.57ms +step:791/1680 train_time:69272ms step_avg:87.57ms +step:792/1680 train_time:69359ms step_avg:87.58ms +step:793/1680 train_time:69447ms step_avg:87.58ms +step:794/1680 train_time:69536ms step_avg:87.58ms +step:795/1680 train_time:69625ms step_avg:87.58ms +step:796/1680 train_time:69713ms step_avg:87.58ms +step:797/1680 train_time:69803ms step_avg:87.58ms +step:798/1680 train_time:69891ms step_avg:87.58ms +step:799/1680 train_time:69979ms step_avg:87.58ms +step:800/1680 train_time:70067ms step_avg:87.58ms +step:801/1680 train_time:70155ms step_avg:87.58ms +step:802/1680 train_time:70243ms step_avg:87.59ms +step:803/1680 train_time:70331ms step_avg:87.59ms +step:804/1680 train_time:70418ms step_avg:87.59ms +step:805/1680 train_time:70507ms step_avg:87.59ms +step:806/1680 train_time:70595ms step_avg:87.59ms +step:807/1680 train_time:70683ms step_avg:87.59ms 
+step:808/1680 train_time:70771ms step_avg:87.59ms +step:809/1680 train_time:70859ms step_avg:87.59ms +step:810/1680 train_time:70947ms step_avg:87.59ms +step:811/1680 train_time:71036ms step_avg:87.59ms +step:812/1680 train_time:71123ms step_avg:87.59ms +step:813/1680 train_time:71212ms step_avg:87.59ms +step:814/1680 train_time:71300ms step_avg:87.59ms +step:815/1680 train_time:71388ms step_avg:87.59ms +step:816/1680 train_time:71476ms step_avg:87.59ms +step:817/1680 train_time:71564ms step_avg:87.59ms +step:818/1680 train_time:71651ms step_avg:87.59ms +step:819/1680 train_time:71741ms step_avg:87.60ms +step:820/1680 train_time:71828ms step_avg:87.60ms +step:821/1680 train_time:71917ms step_avg:87.60ms +step:822/1680 train_time:72006ms step_avg:87.60ms +step:823/1680 train_time:72094ms step_avg:87.60ms +step:824/1680 train_time:72182ms step_avg:87.60ms +step:825/1680 train_time:72270ms step_avg:87.60ms +step:826/1680 train_time:72358ms step_avg:87.60ms +step:827/1680 train_time:72446ms step_avg:87.60ms +step:828/1680 train_time:72534ms step_avg:87.60ms +step:829/1680 train_time:72622ms step_avg:87.60ms +step:830/1680 train_time:72710ms step_avg:87.60ms +step:831/1680 train_time:72798ms step_avg:87.60ms +step:832/1680 train_time:72886ms step_avg:87.60ms +step:833/1680 train_time:72975ms step_avg:87.60ms +step:834/1680 train_time:73063ms step_avg:87.60ms +step:835/1680 train_time:73151ms step_avg:87.61ms +step:836/1680 train_time:73239ms step_avg:87.61ms +step:837/1680 train_time:73327ms step_avg:87.61ms +step:838/1680 train_time:73416ms step_avg:87.61ms +step:839/1680 train_time:73504ms step_avg:87.61ms +step:840/1680 train_time:73592ms step_avg:87.61ms +step:841/1680 train_time:73680ms step_avg:87.61ms +step:842/1680 train_time:73768ms step_avg:87.61ms +step:843/1680 train_time:73857ms step_avg:87.61ms +step:844/1680 train_time:73945ms step_avg:87.61ms +step:845/1680 train_time:74033ms step_avg:87.61ms +step:846/1680 train_time:74121ms step_avg:87.61ms +step:847/1680 train_time:74209ms step_avg:87.61ms +step:848/1680 train_time:74298ms step_avg:87.62ms +step:849/1680 train_time:74386ms step_avg:87.62ms +step:850/1680 train_time:74474ms step_avg:87.62ms +step:851/1680 train_time:74563ms step_avg:87.62ms +step:852/1680 train_time:74652ms step_avg:87.62ms +step:853/1680 train_time:74740ms step_avg:87.62ms +step:854/1680 train_time:74829ms step_avg:87.62ms +step:855/1680 train_time:74917ms step_avg:87.62ms +step:856/1680 train_time:75006ms step_avg:87.62ms +step:857/1680 train_time:75095ms step_avg:87.62ms +step:858/1680 train_time:75183ms step_avg:87.63ms +step:859/1680 train_time:75271ms step_avg:87.63ms +step:860/1680 train_time:75359ms step_avg:87.63ms +step:861/1680 train_time:75447ms step_avg:87.63ms +step:862/1680 train_time:75535ms step_avg:87.63ms +step:863/1680 train_time:75623ms step_avg:87.63ms +step:864/1680 train_time:75711ms step_avg:87.63ms +step:865/1680 train_time:75799ms step_avg:87.63ms +step:866/1680 train_time:75887ms step_avg:87.63ms +step:867/1680 train_time:75976ms step_avg:87.63ms +step:868/1680 train_time:76065ms step_avg:87.63ms +step:869/1680 train_time:76153ms step_avg:87.63ms +step:870/1680 train_time:76242ms step_avg:87.63ms +step:871/1680 train_time:76330ms step_avg:87.63ms +step:872/1680 train_time:76418ms step_avg:87.64ms +step:873/1680 train_time:76506ms step_avg:87.64ms +step:874/1680 train_time:76594ms step_avg:87.64ms +step:875/1680 train_time:76682ms step_avg:87.64ms +step:875/1680 val_loss:3.5211 train_time:76771ms step_avg:87.74ms +step:876/1680 
train_time:76790ms step_avg:87.66ms +step:877/1680 train_time:76863ms step_avg:87.64ms +step:878/1680 train_time:76954ms step_avg:87.65ms +step:879/1680 train_time:77043ms step_avg:87.65ms +step:880/1680 train_time:77131ms step_avg:87.65ms +step:881/1680 train_time:77218ms step_avg:87.65ms +step:882/1680 train_time:77306ms step_avg:87.65ms +step:883/1680 train_time:77392ms step_avg:87.65ms +step:884/1680 train_time:77479ms step_avg:87.65ms +step:885/1680 train_time:77567ms step_avg:87.65ms +step:886/1680 train_time:77655ms step_avg:87.65ms +step:887/1680 train_time:77745ms step_avg:87.65ms +step:888/1680 train_time:77836ms step_avg:87.65ms +step:889/1680 train_time:77927ms step_avg:87.66ms +step:890/1680 train_time:78016ms step_avg:87.66ms +step:891/1680 train_time:78105ms step_avg:87.66ms +step:892/1680 train_time:78193ms step_avg:87.66ms +step:893/1680 train_time:78281ms step_avg:87.66ms +step:894/1680 train_time:78368ms step_avg:87.66ms +step:895/1680 train_time:78455ms step_avg:87.66ms +step:896/1680 train_time:78542ms step_avg:87.66ms +step:897/1680 train_time:78630ms step_avg:87.66ms +step:898/1680 train_time:78718ms step_avg:87.66ms +step:899/1680 train_time:78808ms step_avg:87.66ms +step:900/1680 train_time:78898ms step_avg:87.66ms +step:901/1680 train_time:78987ms step_avg:87.67ms +step:902/1680 train_time:79076ms step_avg:87.67ms +step:903/1680 train_time:79165ms step_avg:87.67ms +step:904/1680 train_time:79253ms step_avg:87.67ms +step:905/1680 train_time:79341ms step_avg:87.67ms +step:906/1680 train_time:79429ms step_avg:87.67ms +step:907/1680 train_time:79516ms step_avg:87.67ms +step:908/1680 train_time:79603ms step_avg:87.67ms +step:909/1680 train_time:79692ms step_avg:87.67ms +step:910/1680 train_time:79780ms step_avg:87.67ms +step:911/1680 train_time:79870ms step_avg:87.67ms +step:912/1680 train_time:79959ms step_avg:87.67ms +step:913/1680 train_time:80048ms step_avg:87.68ms +step:914/1680 train_time:80136ms step_avg:87.68ms +step:915/1680 train_time:80225ms step_avg:87.68ms +step:916/1680 train_time:80313ms step_avg:87.68ms +step:917/1680 train_time:80401ms step_avg:87.68ms +step:918/1680 train_time:80488ms step_avg:87.68ms +step:919/1680 train_time:80576ms step_avg:87.68ms +step:920/1680 train_time:80665ms step_avg:87.68ms +step:921/1680 train_time:80753ms step_avg:87.68ms +step:922/1680 train_time:80841ms step_avg:87.68ms +step:923/1680 train_time:80930ms step_avg:87.68ms +step:924/1680 train_time:81019ms step_avg:87.68ms +step:925/1680 train_time:81108ms step_avg:87.68ms +step:926/1680 train_time:81196ms step_avg:87.68ms +step:927/1680 train_time:81284ms step_avg:87.69ms +step:928/1680 train_time:81372ms step_avg:87.69ms +step:929/1680 train_time:81460ms step_avg:87.69ms +step:930/1680 train_time:81548ms step_avg:87.69ms +step:931/1680 train_time:81636ms step_avg:87.69ms +step:932/1680 train_time:81724ms step_avg:87.69ms +step:933/1680 train_time:81812ms step_avg:87.69ms +step:934/1680 train_time:81901ms step_avg:87.69ms +step:935/1680 train_time:81991ms step_avg:87.69ms +step:936/1680 train_time:82079ms step_avg:87.69ms +step:937/1680 train_time:82168ms step_avg:87.69ms +step:938/1680 train_time:82256ms step_avg:87.69ms +step:939/1680 train_time:82345ms step_avg:87.69ms +step:940/1680 train_time:82433ms step_avg:87.69ms +step:941/1680 train_time:82521ms step_avg:87.69ms +step:942/1680 train_time:82609ms step_avg:87.70ms +step:943/1680 train_time:82697ms step_avg:87.70ms +step:944/1680 train_time:82786ms step_avg:87.70ms +step:945/1680 train_time:82875ms step_avg:87.70ms 
+step:946/1680 train_time:82963ms step_avg:87.70ms +step:947/1680 train_time:83053ms step_avg:87.70ms +step:948/1680 train_time:83141ms step_avg:87.70ms +step:949/1680 train_time:83229ms step_avg:87.70ms +step:950/1680 train_time:83317ms step_avg:87.70ms +step:951/1680 train_time:83405ms step_avg:87.70ms +step:952/1680 train_time:83494ms step_avg:87.70ms +step:953/1680 train_time:83582ms step_avg:87.70ms +step:954/1680 train_time:83670ms step_avg:87.70ms +step:955/1680 train_time:83758ms step_avg:87.70ms +step:956/1680 train_time:83847ms step_avg:87.71ms +step:957/1680 train_time:83935ms step_avg:87.71ms +step:958/1680 train_time:84023ms step_avg:87.71ms +step:959/1680 train_time:84112ms step_avg:87.71ms +step:960/1680 train_time:84200ms step_avg:87.71ms +step:961/1680 train_time:84288ms step_avg:87.71ms +step:962/1680 train_time:84376ms step_avg:87.71ms +step:963/1680 train_time:84465ms step_avg:87.71ms +step:964/1680 train_time:84553ms step_avg:87.71ms +step:965/1680 train_time:84641ms step_avg:87.71ms +step:966/1680 train_time:84730ms step_avg:87.71ms +step:967/1680 train_time:84818ms step_avg:87.71ms +step:968/1680 train_time:84906ms step_avg:87.71ms +step:969/1680 train_time:84995ms step_avg:87.71ms +step:970/1680 train_time:85084ms step_avg:87.72ms +step:971/1680 train_time:85172ms step_avg:87.72ms +step:972/1680 train_time:85260ms step_avg:87.72ms +step:973/1680 train_time:85349ms step_avg:87.72ms +step:974/1680 train_time:85437ms step_avg:87.72ms +step:975/1680 train_time:85524ms step_avg:87.72ms +step:976/1680 train_time:85613ms step_avg:87.72ms +step:977/1680 train_time:85701ms step_avg:87.72ms +step:978/1680 train_time:85789ms step_avg:87.72ms +step:979/1680 train_time:85877ms step_avg:87.72ms +step:980/1680 train_time:85965ms step_avg:87.72ms +step:981/1680 train_time:86055ms step_avg:87.72ms +step:982/1680 train_time:86144ms step_avg:87.72ms +step:983/1680 train_time:86232ms step_avg:87.72ms +step:984/1680 train_time:86319ms step_avg:87.72ms +step:985/1680 train_time:86408ms step_avg:87.72ms +step:986/1680 train_time:86496ms step_avg:87.72ms +step:987/1680 train_time:86585ms step_avg:87.73ms +step:988/1680 train_time:86673ms step_avg:87.73ms +step:989/1680 train_time:86761ms step_avg:87.73ms +step:990/1680 train_time:86850ms step_avg:87.73ms +step:991/1680 train_time:86938ms step_avg:87.73ms +step:992/1680 train_time:87027ms step_avg:87.73ms +step:993/1680 train_time:87115ms step_avg:87.73ms +step:994/1680 train_time:87204ms step_avg:87.73ms +step:995/1680 train_time:87292ms step_avg:87.73ms +step:996/1680 train_time:87380ms step_avg:87.73ms +step:997/1680 train_time:87469ms step_avg:87.73ms +step:998/1680 train_time:87557ms step_avg:87.73ms +step:999/1680 train_time:87645ms step_avg:87.73ms +step:1000/1680 train_time:87733ms step_avg:87.73ms +step:1000/1680 val_loss:3.4714 train_time:87822ms step_avg:87.82ms +step:1001/1680 train_time:87841ms step_avg:87.75ms +step:1002/1680 train_time:87915ms step_avg:87.74ms +step:1003/1680 train_time:88007ms step_avg:87.74ms +step:1004/1680 train_time:88097ms step_avg:87.75ms +step:1005/1680 train_time:88185ms step_avg:87.75ms +step:1006/1680 train_time:88273ms step_avg:87.75ms +step:1007/1680 train_time:88359ms step_avg:87.75ms +step:1008/1680 train_time:88447ms step_avg:87.74ms +step:1009/1680 train_time:88534ms step_avg:87.74ms +step:1010/1680 train_time:88621ms step_avg:87.74ms +step:1011/1680 train_time:88709ms step_avg:87.74ms +step:1012/1680 train_time:88798ms step_avg:87.75ms +step:1013/1680 train_time:88888ms step_avg:87.75ms 
+step:1014/1680 train_time:88979ms step_avg:87.75ms +step:1015/1680 train_time:89069ms step_avg:87.75ms +step:1016/1680 train_time:89158ms step_avg:87.75ms +step:1017/1680 train_time:89246ms step_avg:87.75ms +step:1018/1680 train_time:89334ms step_avg:87.75ms +step:1019/1680 train_time:89421ms step_avg:87.75ms +step:1020/1680 train_time:89509ms step_avg:87.75ms +step:1021/1680 train_time:89597ms step_avg:87.75ms +step:1022/1680 train_time:89684ms step_avg:87.75ms +step:1023/1680 train_time:89773ms step_avg:87.75ms +step:1024/1680 train_time:89861ms step_avg:87.76ms +step:1025/1680 train_time:89950ms step_avg:87.76ms +step:1026/1680 train_time:90040ms step_avg:87.76ms +step:1027/1680 train_time:90128ms step_avg:87.76ms +step:1028/1680 train_time:90218ms step_avg:87.76ms +step:1029/1680 train_time:90306ms step_avg:87.76ms +step:1030/1680 train_time:90394ms step_avg:87.76ms +step:1031/1680 train_time:90481ms step_avg:87.76ms +step:1032/1680 train_time:90569ms step_avg:87.76ms +step:1033/1680 train_time:90657ms step_avg:87.76ms +step:1034/1680 train_time:90745ms step_avg:87.76ms +step:1035/1680 train_time:90833ms step_avg:87.76ms +step:1036/1680 train_time:90923ms step_avg:87.76ms +step:1037/1680 train_time:91012ms step_avg:87.76ms +step:1038/1680 train_time:91101ms step_avg:87.77ms +step:1039/1680 train_time:91189ms step_avg:87.77ms +step:1040/1680 train_time:91278ms step_avg:87.77ms +step:1041/1680 train_time:91366ms step_avg:87.77ms +step:1042/1680 train_time:91454ms step_avg:87.77ms +step:1043/1680 train_time:91542ms step_avg:87.77ms +step:1044/1680 train_time:91629ms step_avg:87.77ms +step:1045/1680 train_time:91717ms step_avg:87.77ms +step:1046/1680 train_time:91806ms step_avg:87.77ms +step:1047/1680 train_time:91895ms step_avg:87.77ms +step:1048/1680 train_time:91985ms step_avg:87.77ms +step:1049/1680 train_time:92074ms step_avg:87.77ms +step:1050/1680 train_time:92162ms step_avg:87.77ms +step:1051/1680 train_time:92251ms step_avg:87.77ms +step:1052/1680 train_time:92339ms step_avg:87.77ms +step:1053/1680 train_time:92427ms step_avg:87.77ms +step:1054/1680 train_time:92515ms step_avg:87.78ms +step:1055/1680 train_time:92603ms step_avg:87.78ms +step:1056/1680 train_time:92691ms step_avg:87.78ms +step:1057/1680 train_time:92779ms step_avg:87.78ms +step:1058/1680 train_time:92867ms step_avg:87.78ms +step:1059/1680 train_time:92956ms step_avg:87.78ms +step:1060/1680 train_time:93046ms step_avg:87.78ms +step:1061/1680 train_time:93134ms step_avg:87.78ms +step:1062/1680 train_time:93222ms step_avg:87.78ms +step:1063/1680 train_time:93311ms step_avg:87.78ms +step:1064/1680 train_time:93399ms step_avg:87.78ms +step:1065/1680 train_time:93487ms step_avg:87.78ms +step:1066/1680 train_time:93575ms step_avg:87.78ms +step:1067/1680 train_time:93663ms step_avg:87.78ms +step:1068/1680 train_time:93752ms step_avg:87.78ms +step:1069/1680 train_time:93840ms step_avg:87.78ms +step:1070/1680 train_time:93929ms step_avg:87.78ms +step:1071/1680 train_time:94017ms step_avg:87.78ms +step:1072/1680 train_time:94107ms step_avg:87.79ms +step:1073/1680 train_time:94196ms step_avg:87.79ms +step:1074/1680 train_time:94283ms step_avg:87.79ms +step:1075/1680 train_time:94372ms step_avg:87.79ms +step:1076/1680 train_time:94460ms step_avg:87.79ms +step:1077/1680 train_time:94548ms step_avg:87.79ms +step:1078/1680 train_time:94636ms step_avg:87.79ms +step:1079/1680 train_time:94725ms step_avg:87.79ms +step:1080/1680 train_time:94813ms step_avg:87.79ms +step:1081/1680 train_time:94901ms step_avg:87.79ms +step:1082/1680 
train_time:94990ms step_avg:87.79ms +step:1083/1680 train_time:95078ms step_avg:87.79ms +step:1084/1680 train_time:95167ms step_avg:87.79ms +step:1085/1680 train_time:95255ms step_avg:87.79ms +step:1086/1680 train_time:95343ms step_avg:87.79ms +step:1087/1680 train_time:95432ms step_avg:87.79ms +step:1088/1680 train_time:95520ms step_avg:87.79ms +step:1089/1680 train_time:95608ms step_avg:87.79ms +step:1090/1680 train_time:95697ms step_avg:87.80ms +step:1091/1680 train_time:95786ms step_avg:87.80ms +step:1092/1680 train_time:95874ms step_avg:87.80ms +step:1093/1680 train_time:95963ms step_avg:87.80ms +step:1094/1680 train_time:96051ms step_avg:87.80ms +step:1095/1680 train_time:96139ms step_avg:87.80ms +step:1096/1680 train_time:96229ms step_avg:87.80ms +step:1097/1680 train_time:96317ms step_avg:87.80ms +step:1098/1680 train_time:96406ms step_avg:87.80ms +step:1099/1680 train_time:96495ms step_avg:87.80ms +step:1100/1680 train_time:96584ms step_avg:87.80ms +step:1101/1680 train_time:96672ms step_avg:87.80ms +step:1102/1680 train_time:96761ms step_avg:87.81ms +step:1103/1680 train_time:96850ms step_avg:87.81ms +step:1104/1680 train_time:96939ms step_avg:87.81ms +step:1105/1680 train_time:97028ms step_avg:87.81ms +step:1106/1680 train_time:97117ms step_avg:87.81ms +step:1107/1680 train_time:97206ms step_avg:87.81ms +step:1108/1680 train_time:97297ms step_avg:87.81ms +step:1109/1680 train_time:97387ms step_avg:87.81ms +step:1110/1680 train_time:97475ms step_avg:87.82ms +step:1111/1680 train_time:97565ms step_avg:87.82ms +step:1112/1680 train_time:97654ms step_avg:87.82ms +step:1113/1680 train_time:97744ms step_avg:87.82ms +step:1114/1680 train_time:97832ms step_avg:87.82ms +step:1115/1680 train_time:97922ms step_avg:87.82ms +step:1116/1680 train_time:98011ms step_avg:87.82ms +step:1117/1680 train_time:98100ms step_avg:87.82ms +step:1118/1680 train_time:98188ms step_avg:87.83ms +step:1119/1680 train_time:98278ms step_avg:87.83ms +step:1120/1680 train_time:98366ms step_avg:87.83ms +step:1121/1680 train_time:98455ms step_avg:87.83ms +step:1122/1680 train_time:98544ms step_avg:87.83ms +step:1123/1680 train_time:98633ms step_avg:87.83ms +step:1124/1680 train_time:98722ms step_avg:87.83ms +step:1125/1680 train_time:98811ms step_avg:87.83ms +step:1125/1680 val_loss:3.4185 train_time:98901ms step_avg:87.91ms +step:1126/1680 train_time:98922ms step_avg:87.85ms +step:1127/1680 train_time:98993ms step_avg:87.84ms +step:1128/1680 train_time:99083ms step_avg:87.84ms +step:1129/1680 train_time:99174ms step_avg:87.84ms +step:1130/1680 train_time:99263ms step_avg:87.84ms +step:1131/1680 train_time:99351ms step_avg:87.84ms +step:1132/1680 train_time:99439ms step_avg:87.84ms +step:1133/1680 train_time:99527ms step_avg:87.84ms +step:1134/1680 train_time:99615ms step_avg:87.84ms +step:1135/1680 train_time:99703ms step_avg:87.84ms +step:1136/1680 train_time:99792ms step_avg:87.85ms +step:1137/1680 train_time:99883ms step_avg:87.85ms +step:1138/1680 train_time:99974ms step_avg:87.85ms +step:1139/1680 train_time:100064ms step_avg:87.85ms +step:1140/1680 train_time:100155ms step_avg:87.86ms +step:1141/1680 train_time:100243ms step_avg:87.86ms +step:1142/1680 train_time:100333ms step_avg:87.86ms +step:1143/1680 train_time:100422ms step_avg:87.86ms +step:1144/1680 train_time:100510ms step_avg:87.86ms +step:1145/1680 train_time:100598ms step_avg:87.86ms +step:1146/1680 train_time:100686ms step_avg:87.86ms +step:1147/1680 train_time:100775ms step_avg:87.86ms +step:1148/1680 train_time:100863ms step_avg:87.86ms 
+step:1149/1680 train_time:100953ms step_avg:87.86ms +step:1150/1680 train_time:101043ms step_avg:87.86ms +step:1151/1680 train_time:101132ms step_avg:87.86ms +step:1152/1680 train_time:101221ms step_avg:87.87ms +step:1153/1680 train_time:101311ms step_avg:87.87ms +step:1154/1680 train_time:101401ms step_avg:87.87ms +step:1155/1680 train_time:101491ms step_avg:87.87ms +step:1156/1680 train_time:101579ms step_avg:87.87ms +step:1157/1680 train_time:101667ms step_avg:87.87ms +step:1158/1680 train_time:101756ms step_avg:87.87ms +step:1159/1680 train_time:101845ms step_avg:87.87ms +step:1160/1680 train_time:101933ms step_avg:87.87ms +step:1161/1680 train_time:102024ms step_avg:87.88ms +step:1162/1680 train_time:102113ms step_avg:87.88ms +step:1163/1680 train_time:102202ms step_avg:87.88ms +step:1164/1680 train_time:102292ms step_avg:87.88ms +step:1165/1680 train_time:102380ms step_avg:87.88ms +step:1166/1680 train_time:102470ms step_avg:87.88ms +step:1167/1680 train_time:102559ms step_avg:87.88ms +step:1168/1680 train_time:102648ms step_avg:87.88ms +step:1169/1680 train_time:102737ms step_avg:87.88ms +step:1170/1680 train_time:102825ms step_avg:87.88ms +step:1171/1680 train_time:102914ms step_avg:87.89ms +step:1172/1680 train_time:103002ms step_avg:87.89ms +step:1173/1680 train_time:103091ms step_avg:87.89ms +step:1174/1680 train_time:103181ms step_avg:87.89ms +step:1175/1680 train_time:103270ms step_avg:87.89ms +step:1176/1680 train_time:103359ms step_avg:87.89ms +step:1177/1680 train_time:103448ms step_avg:87.89ms +step:1178/1680 train_time:103538ms step_avg:87.89ms +step:1179/1680 train_time:103626ms step_avg:87.89ms +step:1180/1680 train_time:103715ms step_avg:87.89ms +step:1181/1680 train_time:103804ms step_avg:87.89ms +step:1182/1680 train_time:103893ms step_avg:87.90ms +step:1183/1680 train_time:103982ms step_avg:87.90ms +step:1184/1680 train_time:104072ms step_avg:87.90ms +step:1185/1680 train_time:104161ms step_avg:87.90ms +step:1186/1680 train_time:104250ms step_avg:87.90ms +step:1187/1680 train_time:104339ms step_avg:87.90ms +step:1188/1680 train_time:104428ms step_avg:87.90ms +step:1189/1680 train_time:104516ms step_avg:87.90ms +step:1190/1680 train_time:104605ms step_avg:87.90ms +step:1191/1680 train_time:104694ms step_avg:87.90ms +step:1192/1680 train_time:104783ms step_avg:87.91ms +step:1193/1680 train_time:104873ms step_avg:87.91ms +step:1194/1680 train_time:104962ms step_avg:87.91ms +step:1195/1680 train_time:105051ms step_avg:87.91ms +step:1196/1680 train_time:105139ms step_avg:87.91ms +step:1197/1680 train_time:105228ms step_avg:87.91ms +step:1198/1680 train_time:105317ms step_avg:87.91ms +step:1199/1680 train_time:105407ms step_avg:87.91ms +step:1200/1680 train_time:105495ms step_avg:87.91ms +step:1201/1680 train_time:105584ms step_avg:87.91ms +step:1202/1680 train_time:105673ms step_avg:87.91ms +step:1203/1680 train_time:105761ms step_avg:87.91ms +step:1204/1680 train_time:105850ms step_avg:87.92ms +step:1205/1680 train_time:105939ms step_avg:87.92ms +step:1206/1680 train_time:106028ms step_avg:87.92ms +step:1207/1680 train_time:106117ms step_avg:87.92ms +step:1208/1680 train_time:106206ms step_avg:87.92ms +step:1209/1680 train_time:106295ms step_avg:87.92ms +step:1210/1680 train_time:106384ms step_avg:87.92ms +step:1211/1680 train_time:106472ms step_avg:87.92ms +step:1212/1680 train_time:106562ms step_avg:87.92ms +step:1213/1680 train_time:106650ms step_avg:87.92ms +step:1214/1680 train_time:106739ms step_avg:87.92ms +step:1215/1680 train_time:106828ms step_avg:87.92ms 
+step:1216/1680 train_time:106917ms step_avg:87.93ms +step:1217/1680 train_time:107006ms step_avg:87.93ms +step:1218/1680 train_time:107096ms step_avg:87.93ms +step:1219/1680 train_time:107185ms step_avg:87.93ms +step:1220/1680 train_time:107274ms step_avg:87.93ms +step:1221/1680 train_time:107363ms step_avg:87.93ms +step:1222/1680 train_time:107452ms step_avg:87.93ms +step:1223/1680 train_time:107541ms step_avg:87.93ms +step:1224/1680 train_time:107630ms step_avg:87.93ms +step:1225/1680 train_time:107719ms step_avg:87.93ms +step:1226/1680 train_time:107807ms step_avg:87.93ms +step:1227/1680 train_time:107896ms step_avg:87.93ms +step:1228/1680 train_time:107985ms step_avg:87.94ms +step:1229/1680 train_time:108074ms step_avg:87.94ms +step:1230/1680 train_time:108164ms step_avg:87.94ms +step:1231/1680 train_time:108253ms step_avg:87.94ms +step:1232/1680 train_time:108342ms step_avg:87.94ms +step:1233/1680 train_time:108431ms step_avg:87.94ms +step:1234/1680 train_time:108520ms step_avg:87.94ms +step:1235/1680 train_time:108610ms step_avg:87.94ms +step:1236/1680 train_time:108699ms step_avg:87.94ms +step:1237/1680 train_time:108787ms step_avg:87.94ms +step:1238/1680 train_time:108877ms step_avg:87.95ms +step:1239/1680 train_time:108966ms step_avg:87.95ms +step:1240/1680 train_time:109055ms step_avg:87.95ms +step:1241/1680 train_time:109144ms step_avg:87.95ms +step:1242/1680 train_time:109234ms step_avg:87.95ms +step:1243/1680 train_time:109322ms step_avg:87.95ms +step:1244/1680 train_time:109412ms step_avg:87.95ms +step:1245/1680 train_time:109500ms step_avg:87.95ms +step:1246/1680 train_time:109589ms step_avg:87.95ms +step:1247/1680 train_time:109678ms step_avg:87.95ms +step:1248/1680 train_time:109767ms step_avg:87.95ms +step:1249/1680 train_time:109856ms step_avg:87.96ms +step:1250/1680 train_time:109945ms step_avg:87.96ms +step:1250/1680 val_loss:3.3794 train_time:110036ms step_avg:88.03ms +step:1251/1680 train_time:110054ms step_avg:87.97ms +step:1252/1680 train_time:110128ms step_avg:87.96ms +step:1253/1680 train_time:110219ms step_avg:87.96ms +step:1254/1680 train_time:110308ms step_avg:87.96ms +step:1255/1680 train_time:110397ms step_avg:87.97ms +step:1256/1680 train_time:110485ms step_avg:87.97ms +step:1257/1680 train_time:110573ms step_avg:87.97ms +step:1258/1680 train_time:110661ms step_avg:87.97ms +step:1259/1680 train_time:110750ms step_avg:87.97ms +step:1260/1680 train_time:110838ms step_avg:87.97ms +step:1261/1680 train_time:110927ms step_avg:87.97ms +step:1262/1680 train_time:111017ms step_avg:87.97ms +step:1263/1680 train_time:111109ms step_avg:87.97ms +step:1264/1680 train_time:111198ms step_avg:87.97ms +step:1265/1680 train_time:111288ms step_avg:87.97ms +step:1266/1680 train_time:111377ms step_avg:87.98ms +step:1267/1680 train_time:111466ms step_avg:87.98ms +step:1268/1680 train_time:111554ms step_avg:87.98ms +step:1269/1680 train_time:111642ms step_avg:87.98ms +step:1270/1680 train_time:111730ms step_avg:87.98ms +step:1271/1680 train_time:111819ms step_avg:87.98ms +step:1272/1680 train_time:111908ms step_avg:87.98ms +step:1273/1680 train_time:112000ms step_avg:87.98ms +step:1274/1680 train_time:112089ms step_avg:87.98ms +step:1275/1680 train_time:112179ms step_avg:87.98ms +step:1276/1680 train_time:112269ms step_avg:87.99ms +step:1277/1680 train_time:112359ms step_avg:87.99ms +step:1278/1680 train_time:112448ms step_avg:87.99ms +step:1279/1680 train_time:112536ms step_avg:87.99ms +step:1280/1680 train_time:112625ms step_avg:87.99ms +step:1281/1680 train_time:112713ms 
step_avg:87.99ms +step:1282/1680 train_time:112802ms step_avg:87.99ms +step:1283/1680 train_time:112891ms step_avg:87.99ms +step:1284/1680 train_time:112981ms step_avg:87.99ms +step:1285/1680 train_time:113072ms step_avg:87.99ms +step:1286/1680 train_time:113161ms step_avg:87.99ms +step:1287/1680 train_time:113251ms step_avg:88.00ms +step:1288/1680 train_time:113340ms step_avg:88.00ms +step:1289/1680 train_time:113429ms step_avg:88.00ms +step:1290/1680 train_time:113519ms step_avg:88.00ms +step:1291/1680 train_time:113608ms step_avg:88.00ms +step:1292/1680 train_time:113696ms step_avg:88.00ms +step:1293/1680 train_time:113785ms step_avg:88.00ms +step:1294/1680 train_time:113874ms step_avg:88.00ms +step:1295/1680 train_time:113963ms step_avg:88.00ms +step:1296/1680 train_time:114053ms step_avg:88.00ms +step:1297/1680 train_time:114143ms step_avg:88.01ms +step:1298/1680 train_time:114233ms step_avg:88.01ms +step:1299/1680 train_time:114322ms step_avg:88.01ms +step:1300/1680 train_time:114412ms step_avg:88.01ms +step:1301/1680 train_time:114501ms step_avg:88.01ms +step:1302/1680 train_time:114589ms step_avg:88.01ms +step:1303/1680 train_time:114678ms step_avg:88.01ms +step:1304/1680 train_time:114767ms step_avg:88.01ms +step:1305/1680 train_time:114855ms step_avg:88.01ms +step:1306/1680 train_time:114944ms step_avg:88.01ms +step:1307/1680 train_time:115033ms step_avg:88.01ms +step:1308/1680 train_time:115122ms step_avg:88.01ms +step:1309/1680 train_time:115211ms step_avg:88.01ms +step:1310/1680 train_time:115301ms step_avg:88.02ms +step:1311/1680 train_time:115390ms step_avg:88.02ms +step:1312/1680 train_time:115479ms step_avg:88.02ms +step:1313/1680 train_time:115568ms step_avg:88.02ms +step:1314/1680 train_time:115657ms step_avg:88.02ms +step:1315/1680 train_time:115746ms step_avg:88.02ms +step:1316/1680 train_time:115835ms step_avg:88.02ms +step:1317/1680 train_time:115924ms step_avg:88.02ms +step:1318/1680 train_time:116013ms step_avg:88.02ms +step:1319/1680 train_time:116102ms step_avg:88.02ms +step:1320/1680 train_time:116192ms step_avg:88.02ms +step:1321/1680 train_time:116281ms step_avg:88.02ms +step:1322/1680 train_time:116370ms step_avg:88.03ms +step:1323/1680 train_time:116459ms step_avg:88.03ms +step:1324/1680 train_time:116548ms step_avg:88.03ms +step:1325/1680 train_time:116637ms step_avg:88.03ms +step:1326/1680 train_time:116725ms step_avg:88.03ms +step:1327/1680 train_time:116815ms step_avg:88.03ms +step:1328/1680 train_time:116904ms step_avg:88.03ms +step:1329/1680 train_time:116993ms step_avg:88.03ms +step:1330/1680 train_time:117082ms step_avg:88.03ms +step:1331/1680 train_time:117172ms step_avg:88.03ms +step:1332/1680 train_time:117262ms step_avg:88.03ms +step:1333/1680 train_time:117351ms step_avg:88.04ms +step:1334/1680 train_time:117440ms step_avg:88.04ms +step:1335/1680 train_time:117529ms step_avg:88.04ms +step:1336/1680 train_time:117619ms step_avg:88.04ms +step:1337/1680 train_time:117709ms step_avg:88.04ms +step:1338/1680 train_time:117799ms step_avg:88.04ms +step:1339/1680 train_time:117888ms step_avg:88.04ms +step:1340/1680 train_time:117977ms step_avg:88.04ms +step:1341/1680 train_time:118065ms step_avg:88.04ms +step:1342/1680 train_time:118154ms step_avg:88.04ms +step:1343/1680 train_time:118243ms step_avg:88.04ms +step:1344/1680 train_time:118333ms step_avg:88.05ms +step:1345/1680 train_time:118423ms step_avg:88.05ms +step:1346/1680 train_time:118513ms step_avg:88.05ms +step:1347/1680 train_time:118602ms step_avg:88.05ms +step:1348/1680 train_time:118692ms 
step_avg:88.05ms +step:1349/1680 train_time:118782ms step_avg:88.05ms +step:1350/1680 train_time:118871ms step_avg:88.05ms +step:1351/1680 train_time:118961ms step_avg:88.05ms +step:1352/1680 train_time:119051ms step_avg:88.06ms +step:1353/1680 train_time:119141ms step_avg:88.06ms +step:1354/1680 train_time:119229ms step_avg:88.06ms +step:1355/1680 train_time:119319ms step_avg:88.06ms +step:1356/1680 train_time:119408ms step_avg:88.06ms +step:1357/1680 train_time:119497ms step_avg:88.06ms +step:1358/1680 train_time:119586ms step_avg:88.06ms +step:1359/1680 train_time:119674ms step_avg:88.06ms +step:1360/1680 train_time:119763ms step_avg:88.06ms +step:1361/1680 train_time:119852ms step_avg:88.06ms +step:1362/1680 train_time:119941ms step_avg:88.06ms +step:1363/1680 train_time:120030ms step_avg:88.06ms +step:1364/1680 train_time:120119ms step_avg:88.06ms +step:1365/1680 train_time:120208ms step_avg:88.06ms +step:1366/1680 train_time:120299ms step_avg:88.07ms +step:1367/1680 train_time:120387ms step_avg:88.07ms +step:1368/1680 train_time:120477ms step_avg:88.07ms +step:1369/1680 train_time:120565ms step_avg:88.07ms +step:1370/1680 train_time:120654ms step_avg:88.07ms +step:1371/1680 train_time:120743ms step_avg:88.07ms +step:1372/1680 train_time:120832ms step_avg:88.07ms +step:1373/1680 train_time:120921ms step_avg:88.07ms +step:1374/1680 train_time:121011ms step_avg:88.07ms +step:1375/1680 train_time:121101ms step_avg:88.07ms +step:1375/1680 val_loss:3.3446 train_time:121191ms step_avg:88.14ms +step:1376/1680 train_time:121209ms step_avg:88.09ms +step:1377/1680 train_time:121282ms step_avg:88.08ms +step:1378/1680 train_time:121377ms step_avg:88.08ms +step:1379/1680 train_time:121467ms step_avg:88.08ms +step:1380/1680 train_time:121555ms step_avg:88.08ms +step:1381/1680 train_time:121643ms step_avg:88.08ms +step:1382/1680 train_time:121730ms step_avg:88.08ms +step:1383/1680 train_time:121818ms step_avg:88.08ms +step:1384/1680 train_time:121905ms step_avg:88.08ms +step:1385/1680 train_time:121993ms step_avg:88.08ms +step:1386/1680 train_time:122081ms step_avg:88.08ms +step:1387/1680 train_time:122171ms step_avg:88.08ms +step:1388/1680 train_time:122263ms step_avg:88.09ms +step:1389/1680 train_time:122355ms step_avg:88.09ms +step:1390/1680 train_time:122446ms step_avg:88.09ms +step:1391/1680 train_time:122536ms step_avg:88.09ms +step:1392/1680 train_time:122624ms step_avg:88.09ms +step:1393/1680 train_time:122712ms step_avg:88.09ms +step:1394/1680 train_time:122800ms step_avg:88.09ms +step:1395/1680 train_time:122889ms step_avg:88.09ms +step:1396/1680 train_time:122976ms step_avg:88.09ms +step:1397/1680 train_time:123065ms step_avg:88.09ms +step:1398/1680 train_time:123155ms step_avg:88.09ms +step:1399/1680 train_time:123246ms step_avg:88.10ms +step:1400/1680 train_time:123336ms step_avg:88.10ms +step:1401/1680 train_time:123425ms step_avg:88.10ms +step:1402/1680 train_time:123514ms step_avg:88.10ms +step:1403/1680 train_time:123603ms step_avg:88.10ms +step:1404/1680 train_time:123692ms step_avg:88.10ms +step:1405/1680 train_time:123781ms step_avg:88.10ms +step:1406/1680 train_time:123869ms step_avg:88.10ms +step:1407/1680 train_time:123957ms step_avg:88.10ms +step:1408/1680 train_time:124046ms step_avg:88.10ms +step:1409/1680 train_time:124135ms step_avg:88.10ms +step:1410/1680 train_time:124225ms step_avg:88.10ms +step:1411/1680 train_time:124315ms step_avg:88.10ms +step:1412/1680 train_time:124405ms step_avg:88.11ms +step:1413/1680 train_time:124495ms step_avg:88.11ms +step:1414/1680 
train_time:124584ms step_avg:88.11ms +step:1415/1680 train_time:124674ms step_avg:88.11ms +step:1416/1680 train_time:124763ms step_avg:88.11ms +step:1417/1680 train_time:124852ms step_avg:88.11ms +step:1418/1680 train_time:124941ms step_avg:88.11ms +step:1419/1680 train_time:125029ms step_avg:88.11ms +step:1420/1680 train_time:125118ms step_avg:88.11ms +step:1421/1680 train_time:125207ms step_avg:88.11ms +step:1422/1680 train_time:125297ms step_avg:88.11ms +step:1423/1680 train_time:125386ms step_avg:88.11ms +step:1424/1680 train_time:125476ms step_avg:88.11ms +step:1425/1680 train_time:125565ms step_avg:88.12ms +step:1426/1680 train_time:125654ms step_avg:88.12ms +step:1427/1680 train_time:125743ms step_avg:88.12ms +step:1428/1680 train_time:125832ms step_avg:88.12ms +step:1429/1680 train_time:125921ms step_avg:88.12ms +step:1430/1680 train_time:126009ms step_avg:88.12ms +step:1431/1680 train_time:126099ms step_avg:88.12ms +step:1432/1680 train_time:126188ms step_avg:88.12ms +step:1433/1680 train_time:126277ms step_avg:88.12ms +step:1434/1680 train_time:126366ms step_avg:88.12ms +step:1435/1680 train_time:126455ms step_avg:88.12ms +step:1436/1680 train_time:126545ms step_avg:88.12ms +step:1437/1680 train_time:126634ms step_avg:88.12ms +step:1438/1680 train_time:126723ms step_avg:88.12ms +step:1439/1680 train_time:126812ms step_avg:88.13ms +step:1440/1680 train_time:126901ms step_avg:88.13ms +step:1441/1680 train_time:126990ms step_avg:88.13ms +step:1442/1680 train_time:127078ms step_avg:88.13ms +step:1443/1680 train_time:127167ms step_avg:88.13ms +step:1444/1680 train_time:127256ms step_avg:88.13ms +step:1445/1680 train_time:127346ms step_avg:88.13ms +step:1446/1680 train_time:127435ms step_avg:88.13ms +step:1447/1680 train_time:127524ms step_avg:88.13ms +step:1448/1680 train_time:127614ms step_avg:88.13ms +step:1449/1680 train_time:127704ms step_avg:88.13ms +step:1450/1680 train_time:127793ms step_avg:88.13ms +step:1451/1680 train_time:127882ms step_avg:88.13ms +step:1452/1680 train_time:127970ms step_avg:88.13ms +step:1453/1680 train_time:128059ms step_avg:88.13ms +step:1454/1680 train_time:128148ms step_avg:88.14ms +step:1455/1680 train_time:128237ms step_avg:88.14ms +step:1456/1680 train_time:128327ms step_avg:88.14ms +step:1457/1680 train_time:128415ms step_avg:88.14ms +step:1458/1680 train_time:128505ms step_avg:88.14ms +step:1459/1680 train_time:128595ms step_avg:88.14ms +step:1460/1680 train_time:128684ms step_avg:88.14ms +step:1461/1680 train_time:128774ms step_avg:88.14ms +step:1462/1680 train_time:128863ms step_avg:88.14ms +step:1463/1680 train_time:128952ms step_avg:88.14ms +step:1464/1680 train_time:129041ms step_avg:88.14ms +step:1465/1680 train_time:129129ms step_avg:88.14ms +step:1466/1680 train_time:129218ms step_avg:88.14ms +step:1467/1680 train_time:129307ms step_avg:88.14ms +step:1468/1680 train_time:129396ms step_avg:88.14ms +step:1469/1680 train_time:129486ms step_avg:88.15ms +step:1470/1680 train_time:129574ms step_avg:88.15ms +step:1471/1680 train_time:129664ms step_avg:88.15ms +step:1472/1680 train_time:129753ms step_avg:88.15ms +step:1473/1680 train_time:129843ms step_avg:88.15ms +step:1474/1680 train_time:129932ms step_avg:88.15ms +step:1475/1680 train_time:130021ms step_avg:88.15ms +step:1476/1680 train_time:130109ms step_avg:88.15ms +step:1477/1680 train_time:130198ms step_avg:88.15ms +step:1478/1680 train_time:130287ms step_avg:88.15ms +step:1479/1680 train_time:130377ms step_avg:88.15ms +step:1480/1680 train_time:130467ms step_avg:88.15ms +step:1481/1680 
train_time:130556ms step_avg:88.15ms +step:1482/1680 train_time:130645ms step_avg:88.15ms +step:1483/1680 train_time:130735ms step_avg:88.16ms +step:1484/1680 train_time:130824ms step_avg:88.16ms +step:1485/1680 train_time:130914ms step_avg:88.16ms +step:1486/1680 train_time:131002ms step_avg:88.16ms +step:1487/1680 train_time:131091ms step_avg:88.16ms +step:1488/1680 train_time:131180ms step_avg:88.16ms +step:1489/1680 train_time:131269ms step_avg:88.16ms +step:1490/1680 train_time:131358ms step_avg:88.16ms +step:1491/1680 train_time:131447ms step_avg:88.16ms +step:1492/1680 train_time:131537ms step_avg:88.16ms +step:1493/1680 train_time:131626ms step_avg:88.16ms +step:1494/1680 train_time:131714ms step_avg:88.16ms +step:1495/1680 train_time:131803ms step_avg:88.16ms +step:1496/1680 train_time:131892ms step_avg:88.16ms +step:1497/1680 train_time:131981ms step_avg:88.16ms +step:1498/1680 train_time:132070ms step_avg:88.16ms +step:1499/1680 train_time:132158ms step_avg:88.16ms +step:1500/1680 train_time:132248ms step_avg:88.17ms +step:1500/1680 val_loss:3.3149 train_time:132337ms step_avg:88.22ms +step:1501/1680 train_time:132357ms step_avg:88.18ms +step:1502/1680 train_time:132432ms step_avg:88.17ms +step:1503/1680 train_time:132525ms step_avg:88.17ms +step:1504/1680 train_time:132615ms step_avg:88.17ms +step:1505/1680 train_time:132703ms step_avg:88.17ms +step:1506/1680 train_time:132791ms step_avg:88.17ms +step:1507/1680 train_time:132879ms step_avg:88.17ms +step:1508/1680 train_time:132967ms step_avg:88.17ms +step:1509/1680 train_time:133054ms step_avg:88.17ms +step:1510/1680 train_time:133143ms step_avg:88.17ms +step:1511/1680 train_time:133231ms step_avg:88.17ms +step:1512/1680 train_time:133321ms step_avg:88.18ms +step:1513/1680 train_time:133413ms step_avg:88.18ms +step:1514/1680 train_time:133504ms step_avg:88.18ms +step:1515/1680 train_time:133595ms step_avg:88.18ms +step:1516/1680 train_time:133685ms step_avg:88.18ms +step:1517/1680 train_time:133774ms step_avg:88.18ms +step:1518/1680 train_time:133862ms step_avg:88.18ms +step:1519/1680 train_time:133950ms step_avg:88.18ms +step:1520/1680 train_time:134039ms step_avg:88.18ms +step:1521/1680 train_time:134127ms step_avg:88.18ms +step:1522/1680 train_time:134216ms step_avg:88.18ms +step:1523/1680 train_time:134305ms step_avg:88.18ms +step:1524/1680 train_time:134395ms step_avg:88.19ms +step:1525/1680 train_time:134486ms step_avg:88.19ms +step:1526/1680 train_time:134576ms step_avg:88.19ms +step:1527/1680 train_time:134666ms step_avg:88.19ms +step:1528/1680 train_time:134756ms step_avg:88.19ms +step:1529/1680 train_time:134843ms step_avg:88.19ms +step:1530/1680 train_time:134932ms step_avg:88.19ms +step:1531/1680 train_time:135020ms step_avg:88.19ms +step:1532/1680 train_time:135108ms step_avg:88.19ms +step:1533/1680 train_time:135196ms step_avg:88.19ms +step:1534/1680 train_time:135286ms step_avg:88.19ms +step:1535/1680 train_time:135375ms step_avg:88.19ms +step:1536/1680 train_time:135465ms step_avg:88.19ms +step:1537/1680 train_time:135555ms step_avg:88.19ms +step:1538/1680 train_time:135644ms step_avg:88.19ms +step:1539/1680 train_time:135734ms step_avg:88.20ms +step:1540/1680 train_time:135823ms step_avg:88.20ms +step:1541/1680 train_time:135911ms step_avg:88.20ms +step:1542/1680 train_time:136000ms step_avg:88.20ms +step:1543/1680 train_time:136088ms step_avg:88.20ms +step:1544/1680 train_time:136176ms step_avg:88.20ms +step:1545/1680 train_time:136266ms step_avg:88.20ms +step:1546/1680 train_time:136357ms step_avg:88.20ms 
+step:1547/1680 train_time:136446ms step_avg:88.20ms +step:1548/1680 train_time:136536ms step_avg:88.20ms +step:1549/1680 train_time:136625ms step_avg:88.20ms +step:1550/1680 train_time:136715ms step_avg:88.20ms +step:1551/1680 train_time:136804ms step_avg:88.20ms +step:1552/1680 train_time:136893ms step_avg:88.20ms +step:1553/1680 train_time:136981ms step_avg:88.20ms +step:1554/1680 train_time:137070ms step_avg:88.20ms +step:1555/1680 train_time:137158ms step_avg:88.20ms +step:1556/1680 train_time:137248ms step_avg:88.21ms +step:1557/1680 train_time:137337ms step_avg:88.21ms +step:1558/1680 train_time:137426ms step_avg:88.21ms +step:1559/1680 train_time:137516ms step_avg:88.21ms +step:1560/1680 train_time:137606ms step_avg:88.21ms +step:1561/1680 train_time:137695ms step_avg:88.21ms +step:1562/1680 train_time:137785ms step_avg:88.21ms +step:1563/1680 train_time:137874ms step_avg:88.21ms +step:1564/1680 train_time:137962ms step_avg:88.21ms +step:1565/1680 train_time:138051ms step_avg:88.21ms +step:1566/1680 train_time:138140ms step_avg:88.21ms +step:1567/1680 train_time:138229ms step_avg:88.21ms +step:1568/1680 train_time:138319ms step_avg:88.21ms +step:1569/1680 train_time:138408ms step_avg:88.21ms +step:1570/1680 train_time:138497ms step_avg:88.21ms +step:1571/1680 train_time:138586ms step_avg:88.22ms +step:1572/1680 train_time:138675ms step_avg:88.22ms +step:1573/1680 train_time:138763ms step_avg:88.22ms +step:1574/1680 train_time:138853ms step_avg:88.22ms +step:1575/1680 train_time:138942ms step_avg:88.22ms +step:1576/1680 train_time:139031ms step_avg:88.22ms +step:1577/1680 train_time:139120ms step_avg:88.22ms +step:1578/1680 train_time:139208ms step_avg:88.22ms +step:1579/1680 train_time:139298ms step_avg:88.22ms +step:1580/1680 train_time:139387ms step_avg:88.22ms +step:1581/1680 train_time:139477ms step_avg:88.22ms +step:1582/1680 train_time:139566ms step_avg:88.22ms +step:1583/1680 train_time:139655ms step_avg:88.22ms +step:1584/1680 train_time:139744ms step_avg:88.22ms +step:1585/1680 train_time:139834ms step_avg:88.22ms +step:1586/1680 train_time:139923ms step_avg:88.22ms +step:1587/1680 train_time:140012ms step_avg:88.22ms +step:1588/1680 train_time:140101ms step_avg:88.22ms +step:1589/1680 train_time:140189ms step_avg:88.22ms +step:1590/1680 train_time:140278ms step_avg:88.23ms +step:1591/1680 train_time:140367ms step_avg:88.23ms +step:1592/1680 train_time:140457ms step_avg:88.23ms +step:1593/1680 train_time:140546ms step_avg:88.23ms +step:1594/1680 train_time:140636ms step_avg:88.23ms +step:1595/1680 train_time:140725ms step_avg:88.23ms +step:1596/1680 train_time:140813ms step_avg:88.23ms +step:1597/1680 train_time:140903ms step_avg:88.23ms +step:1598/1680 train_time:140992ms step_avg:88.23ms +step:1599/1680 train_time:141081ms step_avg:88.23ms +step:1600/1680 train_time:141170ms step_avg:88.23ms +step:1601/1680 train_time:141259ms step_avg:88.23ms +step:1602/1680 train_time:141348ms step_avg:88.23ms +step:1603/1680 train_time:141437ms step_avg:88.23ms +step:1604/1680 train_time:141527ms step_avg:88.23ms +step:1605/1680 train_time:141616ms step_avg:88.23ms +step:1606/1680 train_time:141704ms step_avg:88.23ms +step:1607/1680 train_time:141793ms step_avg:88.23ms +step:1608/1680 train_time:141883ms step_avg:88.24ms +step:1609/1680 train_time:141973ms step_avg:88.24ms +step:1610/1680 train_time:142061ms step_avg:88.24ms +step:1611/1680 train_time:142150ms step_avg:88.24ms +step:1612/1680 train_time:142240ms step_avg:88.24ms +step:1613/1680 train_time:142328ms step_avg:88.24ms 
+step:1614/1680 train_time:142418ms step_avg:88.24ms +step:1615/1680 train_time:142507ms step_avg:88.24ms +step:1616/1680 train_time:142596ms step_avg:88.24ms +step:1617/1680 train_time:142685ms step_avg:88.24ms +step:1618/1680 train_time:142774ms step_avg:88.24ms +step:1619/1680 train_time:142863ms step_avg:88.24ms +step:1620/1680 train_time:142952ms step_avg:88.24ms +step:1621/1680 train_time:143042ms step_avg:88.24ms +step:1622/1680 train_time:143131ms step_avg:88.24ms +step:1623/1680 train_time:143220ms step_avg:88.24ms +step:1624/1680 train_time:143309ms step_avg:88.24ms +step:1625/1680 train_time:143399ms step_avg:88.25ms +step:1625/1680 val_loss:3.2908 train_time:143489ms step_avg:88.30ms +step:1626/1680 train_time:143507ms step_avg:88.26ms +step:1627/1680 train_time:143581ms step_avg:88.25ms +step:1628/1680 train_time:143673ms step_avg:88.25ms +step:1629/1680 train_time:143763ms step_avg:88.25ms +step:1630/1680 train_time:143852ms step_avg:88.25ms +step:1631/1680 train_time:143942ms step_avg:88.25ms +step:1632/1680 train_time:144030ms step_avg:88.25ms +step:1633/1680 train_time:144118ms step_avg:88.25ms +step:1634/1680 train_time:144206ms step_avg:88.25ms +step:1635/1680 train_time:144294ms step_avg:88.25ms +step:1636/1680 train_time:144382ms step_avg:88.25ms +step:1637/1680 train_time:144472ms step_avg:88.25ms +step:1638/1680 train_time:144564ms step_avg:88.26ms +step:1639/1680 train_time:144655ms step_avg:88.26ms +step:1640/1680 train_time:144745ms step_avg:88.26ms +step:1641/1680 train_time:144835ms step_avg:88.26ms +step:1642/1680 train_time:144924ms step_avg:88.26ms +step:1643/1680 train_time:145012ms step_avg:88.26ms +step:1644/1680 train_time:145100ms step_avg:88.26ms +step:1645/1680 train_time:145188ms step_avg:88.26ms +step:1646/1680 train_time:145276ms step_avg:88.26ms +step:1647/1680 train_time:145365ms step_avg:88.26ms +step:1648/1680 train_time:145455ms step_avg:88.26ms +step:1649/1680 train_time:145544ms step_avg:88.26ms +step:1650/1680 train_time:145634ms step_avg:88.26ms +step:1651/1680 train_time:145725ms step_avg:88.26ms +step:1652/1680 train_time:145815ms step_avg:88.27ms +step:1653/1680 train_time:145904ms step_avg:88.27ms +step:1654/1680 train_time:145993ms step_avg:88.27ms +step:1655/1680 train_time:146083ms step_avg:88.27ms +step:1656/1680 train_time:146171ms step_avg:88.27ms +step:1657/1680 train_time:146260ms step_avg:88.27ms +step:1658/1680 train_time:146348ms step_avg:88.27ms +step:1659/1680 train_time:146438ms step_avg:88.27ms +step:1660/1680 train_time:146527ms step_avg:88.27ms +step:1661/1680 train_time:146618ms step_avg:88.27ms +step:1662/1680 train_time:146707ms step_avg:88.27ms +step:1663/1680 train_time:146797ms step_avg:88.27ms +step:1664/1680 train_time:146886ms step_avg:88.27ms +step:1665/1680 train_time:146975ms step_avg:88.27ms +step:1666/1680 train_time:147064ms step_avg:88.27ms +step:1667/1680 train_time:147153ms step_avg:88.27ms +step:1668/1680 train_time:147241ms step_avg:88.27ms +step:1669/1680 train_time:147330ms step_avg:88.27ms +step:1670/1680 train_time:147419ms step_avg:88.27ms +step:1671/1680 train_time:147508ms step_avg:88.28ms +step:1672/1680 train_time:147597ms step_avg:88.28ms +step:1673/1680 train_time:147686ms step_avg:88.28ms +step:1674/1680 train_time:147776ms step_avg:88.28ms +step:1675/1680 train_time:147865ms step_avg:88.28ms +step:1676/1680 train_time:147954ms step_avg:88.28ms +step:1677/1680 train_time:148044ms step_avg:88.28ms +step:1678/1680 train_time:148132ms step_avg:88.28ms +step:1679/1680 train_time:148221ms 
step_avg:88.28ms +step:1680/1680 train_time:148310ms step_avg:88.28ms +step:1680/1680 val_loss:3.2803 train_time:148400ms step_avg:88.33ms +peak memory allocated: 30760 MiB reserved: 45934 MiB diff --git a/records/092725_BF16CE/cb4e8b78-b9ab-4c83-9ff9-3fdfb8bb1b9b.txt b/records/092725_BF16CE/cb4e8b78-b9ab-4c83-9ff9-3fdfb8bb1b9b.txt new file mode 100644 index 000000000..d3d205def --- /dev/null +++ b/records/092725_BF16CE/cb4e8b78-b9ab-4c83-9ff9-3fdfb8bb1b9b.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+ Though empirically, small 1D params perform efficiently here: + NS approximately performs a magnitude normalization of the grad, and + this hyper-optimized class has faster execution time than the current impl of Adam for small params. + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized so that + (n_attn_layers+n_mlp_layers*2)%4==0, for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param, 7 padding params) + 2. reduce scatter attn_gate (10 params, 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params, 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive params of type attn reshape before NS + 8. wait on step 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size()==8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for the resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1,10,16,16] + assert len(params)==sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx+size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
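+ # Outline of the sharded update below: each param group's grads are stacked and + # reduce-scattered so every rank owns an averaged chunk of padded_num_params // world_size + # gradients; each rank then applies momentum, weight decay, and the batched Newton-Schulz + # orthogonalization to its own chunk only; finally the updated chunks are all-gathered and + # copied back into the full parameters. E.g. the attn_gate group holds 10 params, so on + # 8 GPUs it is padded to 16 and each rank owns a chunk of 2 (cf. the class docstring).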
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
+ updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + original_shape = batched_update_grads.shape + # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O + module_idx = start_idx if start_idx < len(params) else len(params) - 1 + if params[module_idx].module == 'attn': + batched_update_grads = batched_update_grads.view(-1, p_example.size(-2), p_example.size(-1) // 4) + v_chunk = newton_schulz_triton(batched_update_grads).view(original_shape) + + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = 
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module='attn' + with torch.no_grad(): + self.qkvo_w.view(4,self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4,self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4,self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4,self.hdim, self.dim)[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module='mlp' + self.c_proj.module='mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
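+        # e.g. next_multiple_of_n(50257, n=128) == 50304 == 393 * 128, so lm_head
+        # carries 50304 output rows even though only 50257 token ids ever occur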
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the 
data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan the first 4 million tokens, then kick off an async thread to scan the rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading next shard and indexing bos tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning of Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size * grad accum steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # drop the extra token that was read to provide the _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local, )
+            _targets = buf[1:].view(num_tokens_local, )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size * grad accum steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40 # number of iterations to continue training at final cooldown and window size
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 
0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999,step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations+args.iteration_extension: + return args.ws_validate//2, args.ws_validate + x = min(step / (1 + args.num_iterations),0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx]//2, args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws_long = args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params + if new_ws_long > ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + elif new_ws_long 0 and step % args.val_loss_every == 0): + if last_step: + ws_long = args.ws_long_validate + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = torch.zeros((), device=device, dtype=torch.float32) + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + 
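+        # with world_size == 8 we get grad_accum_steps == 8 // world_size == 1 and this
+        # loop body runs once per step; on fewer GPUs it replays micro-batches and
+        # .backward() accumulates into .grad before the optimizer steps below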
inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Sep 27 13:15:45 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 28C P0 123W / 700W | 5856MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 25C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 22C P0 115W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 27C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 27C P0 119W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 25C P0 114W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 28C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 24C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 171321 C /usr/bin/python 0MiB | +| 0 N/A N/A 171322 C /usr/bin/python 0MiB | +| 0 N/A N/A 171323 C /usr/bin/python 0MiB | +| 0 N/A N/A 171324 C /usr/bin/python 0MiB | +| 0 N/A N/A 171325 C /usr/bin/python 0MiB | +| 0 N/A N/A 171326 C /usr/bin/python 0MiB | +| 0 N/A N/A 171327 C /usr/bin/python 0MiB | +| 0 N/A N/A 171328 C /usr/bin/python 0MiB | +| 1 N/A N/A 171322 C /usr/bin/python 0MiB | +| 2 N/A N/A 171323 C /usr/bin/python 0MiB | +| 3 N/A N/A 171324 C /usr/bin/python 0MiB | +| 4 N/A N/A 171325 C /usr/bin/python 0MiB | +| 5 N/A N/A 171326 C /usr/bin/python 0MiB | +| 6 N/A N/A 171327 C /usr/bin/python 0MiB | +| 7 N/A N/A 171328 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1680 train_time:152ms step_avg:152.04ms +step:2/1680 train_time:172ms step_avg:85.87ms +step:3/1680 train_time:237ms step_avg:79.00ms +step:4/1680 train_time:323ms step_avg:80.87ms +step:5/1680 train_time:408ms step_avg:81.69ms +step:6/1680 train_time:495ms step_avg:82.43ms +step:7/1680 train_time:581ms step_avg:82.95ms +step:8/1680 train_time:667ms step_avg:83.36ms +step:9/1680 train_time:753ms step_avg:83.64ms +step:10/1680 train_time:839ms step_avg:83.90ms +step:11/1680 train_time:925ms step_avg:84.12ms +step:12/1680 train_time:1014ms step_avg:84.46ms +step:13/1680 train_time:1106ms step_avg:85.06ms +step:14/1680 train_time:1195ms step_avg:85.39ms +step:15/1680 train_time:1284ms step_avg:85.60ms +step:16/1680 train_time:1371ms step_avg:85.70ms +step:17/1680 train_time:1458ms step_avg:85.77ms +step:18/1680 train_time:1545ms step_avg:85.83ms +step:19/1680 train_time:1631ms step_avg:85.86ms +step:20/1680 train_time:1718ms step_avg:85.91ms +step:21/1680 train_time:1805ms step_avg:85.93ms +step:22/1680 train_time:1891ms step_avg:85.97ms +step:23/1680 train_time:1978ms step_avg:86.00ms +step:24/1680 train_time:2067ms step_avg:86.12ms +step:25/1680 train_time:2155ms step_avg:86.20ms +step:26/1680 train_time:2244ms step_avg:86.29ms +step:27/1680 train_time:2332ms step_avg:86.36ms +step:28/1680 train_time:2419ms step_avg:86.40ms +step:29/1680 train_time:2507ms step_avg:86.44ms +step:30/1680 train_time:2594ms step_avg:86.45ms +step:31/1680 train_time:2681ms step_avg:86.48ms +step:32/1680 train_time:2768ms step_avg:86.49ms +step:33/1680 train_time:2854ms step_avg:86.49ms +step:34/1680 train_time:2941ms step_avg:86.51ms +step:35/1680 train_time:3029ms step_avg:86.55ms +step:36/1680 train_time:3117ms step_avg:86.58ms +step:37/1680 train_time:3205ms step_avg:86.63ms +step:38/1680 train_time:3294ms step_avg:86.67ms +step:39/1680 train_time:3381ms step_avg:86.69ms +step:40/1680 train_time:3468ms step_avg:86.71ms +step:41/1680 train_time:3555ms step_avg:86.71ms +step:42/1680 train_time:3642ms step_avg:86.71ms +step:43/1680 train_time:3729ms step_avg:86.73ms +step:44/1680 train_time:3816ms step_avg:86.72ms +step:45/1680 train_time:3904ms step_avg:86.75ms +step:46/1680 train_time:3991ms step_avg:86.77ms +step:47/1680 
train_time:4078ms step_avg:86.77ms +step:48/1680 train_time:4167ms step_avg:86.81ms +step:49/1680 train_time:4254ms step_avg:86.82ms +step:50/1680 train_time:4342ms step_avg:86.84ms +step:51/1680 train_time:4429ms step_avg:86.84ms +step:52/1680 train_time:4516ms step_avg:86.84ms +step:53/1680 train_time:4603ms step_avg:86.85ms +step:54/1680 train_time:4690ms step_avg:86.85ms +step:55/1680 train_time:4777ms step_avg:86.85ms +step:56/1680 train_time:4864ms step_avg:86.85ms +step:57/1680 train_time:4951ms step_avg:86.86ms +step:58/1680 train_time:5039ms step_avg:86.87ms +step:59/1680 train_time:5126ms step_avg:86.88ms +step:60/1680 train_time:5213ms step_avg:86.88ms +step:61/1680 train_time:5300ms step_avg:86.88ms +step:62/1680 train_time:5387ms step_avg:86.89ms +step:63/1680 train_time:5474ms step_avg:86.89ms +step:64/1680 train_time:5561ms step_avg:86.89ms +step:65/1680 train_time:5648ms step_avg:86.90ms +step:66/1680 train_time:5735ms step_avg:86.89ms +step:67/1680 train_time:5822ms step_avg:86.90ms +step:68/1680 train_time:5910ms step_avg:86.91ms +step:69/1680 train_time:5996ms step_avg:86.90ms +step:70/1680 train_time:6084ms step_avg:86.91ms +step:71/1680 train_time:6171ms step_avg:86.91ms +step:72/1680 train_time:6258ms step_avg:86.91ms +step:73/1680 train_time:6345ms step_avg:86.92ms +step:74/1680 train_time:6432ms step_avg:86.92ms +step:75/1680 train_time:6519ms step_avg:86.92ms +step:76/1680 train_time:6606ms step_avg:86.92ms +step:77/1680 train_time:6693ms step_avg:86.92ms +step:78/1680 train_time:6780ms step_avg:86.92ms +step:79/1680 train_time:6867ms step_avg:86.92ms +step:80/1680 train_time:6953ms step_avg:86.92ms +step:81/1680 train_time:7041ms step_avg:86.92ms +step:82/1680 train_time:7128ms step_avg:86.93ms +step:83/1680 train_time:7215ms step_avg:86.93ms +step:84/1680 train_time:7304ms step_avg:86.95ms +step:85/1680 train_time:7391ms step_avg:86.96ms +step:86/1680 train_time:7478ms step_avg:86.95ms +step:87/1680 train_time:7565ms step_avg:86.96ms +step:88/1680 train_time:7653ms step_avg:86.96ms +step:89/1680 train_time:7740ms step_avg:86.96ms +step:90/1680 train_time:7827ms step_avg:86.96ms +step:91/1680 train_time:7914ms step_avg:86.96ms +step:92/1680 train_time:8001ms step_avg:86.97ms +step:93/1680 train_time:8088ms step_avg:86.97ms +step:94/1680 train_time:8175ms step_avg:86.96ms +step:95/1680 train_time:8262ms step_avg:86.97ms +step:96/1680 train_time:8350ms step_avg:86.98ms +step:97/1680 train_time:8436ms step_avg:86.97ms +step:98/1680 train_time:8524ms step_avg:86.97ms +step:99/1680 train_time:8611ms step_avg:86.98ms +step:100/1680 train_time:8697ms step_avg:86.97ms +step:101/1680 train_time:8784ms step_avg:86.97ms +step:102/1680 train_time:8871ms step_avg:86.97ms +step:103/1680 train_time:8958ms step_avg:86.97ms +step:104/1680 train_time:9045ms step_avg:86.97ms +step:105/1680 train_time:9133ms step_avg:86.98ms +step:106/1680 train_time:9220ms step_avg:86.98ms +step:107/1680 train_time:9308ms step_avg:86.99ms +step:108/1680 train_time:9395ms step_avg:86.99ms +step:109/1680 train_time:9482ms step_avg:86.99ms +step:110/1680 train_time:9568ms step_avg:86.99ms +step:111/1680 train_time:9655ms step_avg:86.98ms +step:112/1680 train_time:9742ms step_avg:86.98ms +step:113/1680 train_time:9830ms step_avg:86.99ms +step:114/1680 train_time:9916ms step_avg:86.99ms +step:115/1680 train_time:10003ms step_avg:86.99ms +step:116/1680 train_time:10090ms step_avg:86.98ms +step:117/1680 train_time:10177ms step_avg:86.98ms +step:118/1680 train_time:10265ms step_avg:86.99ms +step:119/1680 
train_time:10352ms step_avg:86.99ms +step:120/1680 train_time:10440ms step_avg:87.00ms +step:121/1680 train_time:10527ms step_avg:87.00ms +step:122/1680 train_time:10614ms step_avg:87.00ms +step:123/1680 train_time:10700ms step_avg:86.99ms +step:124/1680 train_time:10787ms step_avg:86.99ms +step:125/1680 train_time:10874ms step_avg:86.99ms +step:125/1680 val_loss:4.2881 train_time:10962ms step_avg:87.70ms +step:126/1680 train_time:10985ms step_avg:87.18ms +step:127/1680 train_time:11051ms step_avg:87.02ms +step:128/1680 train_time:11147ms step_avg:87.08ms +step:129/1680 train_time:11237ms step_avg:87.11ms +step:130/1680 train_time:11325ms step_avg:87.11ms +step:131/1680 train_time:11411ms step_avg:87.11ms +step:132/1680 train_time:11497ms step_avg:87.10ms +step:133/1680 train_time:11583ms step_avg:87.09ms +step:134/1680 train_time:11670ms step_avg:87.09ms +step:135/1680 train_time:11756ms step_avg:87.08ms +step:136/1680 train_time:11842ms step_avg:87.07ms +step:137/1680 train_time:11928ms step_avg:87.07ms +step:138/1680 train_time:12015ms step_avg:87.06ms +step:139/1680 train_time:12104ms step_avg:87.08ms +step:140/1680 train_time:12194ms step_avg:87.10ms +step:141/1680 train_time:12282ms step_avg:87.10ms +step:142/1680 train_time:12369ms step_avg:87.10ms +step:143/1680 train_time:12456ms step_avg:87.11ms +step:144/1680 train_time:12544ms step_avg:87.11ms +step:145/1680 train_time:12630ms step_avg:87.10ms +step:146/1680 train_time:12716ms step_avg:87.09ms +step:147/1680 train_time:12802ms step_avg:87.09ms +step:148/1680 train_time:12888ms step_avg:87.08ms +step:149/1680 train_time:12975ms step_avg:87.08ms +step:150/1680 train_time:13063ms step_avg:87.09ms +step:151/1680 train_time:13151ms step_avg:87.09ms +step:152/1680 train_time:13239ms step_avg:87.10ms +step:153/1680 train_time:13327ms step_avg:87.10ms +step:154/1680 train_time:13414ms step_avg:87.10ms +step:155/1680 train_time:13501ms step_avg:87.11ms +step:156/1680 train_time:13588ms step_avg:87.10ms +step:157/1680 train_time:13675ms step_avg:87.10ms +step:158/1680 train_time:13762ms step_avg:87.10ms +step:159/1680 train_time:13848ms step_avg:87.09ms +step:160/1680 train_time:13935ms step_avg:87.09ms +step:161/1680 train_time:14022ms step_avg:87.09ms +step:162/1680 train_time:14109ms step_avg:87.09ms +step:163/1680 train_time:14197ms step_avg:87.10ms +step:164/1680 train_time:14284ms step_avg:87.10ms +step:165/1680 train_time:14372ms step_avg:87.10ms +step:166/1680 train_time:14459ms step_avg:87.10ms +step:167/1680 train_time:14547ms step_avg:87.11ms +step:168/1680 train_time:14635ms step_avg:87.11ms +step:169/1680 train_time:14721ms step_avg:87.11ms +step:170/1680 train_time:14807ms step_avg:87.10ms +step:171/1680 train_time:14894ms step_avg:87.10ms +step:172/1680 train_time:14981ms step_avg:87.10ms +step:173/1680 train_time:15068ms step_avg:87.10ms +step:174/1680 train_time:15156ms step_avg:87.10ms +step:175/1680 train_time:15243ms step_avg:87.10ms +step:176/1680 train_time:15331ms step_avg:87.11ms +step:177/1680 train_time:15418ms step_avg:87.11ms +step:178/1680 train_time:15506ms step_avg:87.11ms +step:179/1680 train_time:15592ms step_avg:87.11ms +step:180/1680 train_time:15679ms step_avg:87.11ms +step:181/1680 train_time:15766ms step_avg:87.11ms +step:182/1680 train_time:15853ms step_avg:87.10ms +step:183/1680 train_time:15939ms step_avg:87.10ms +step:184/1680 train_time:16026ms step_avg:87.10ms +step:185/1680 train_time:16113ms step_avg:87.10ms +step:186/1680 train_time:16200ms step_avg:87.10ms +step:187/1680 train_time:16288ms 
step_avg:87.10ms +step:188/1680 train_time:16375ms step_avg:87.10ms +step:189/1680 train_time:16463ms step_avg:87.10ms +step:190/1680 train_time:16549ms step_avg:87.10ms +step:191/1680 train_time:16636ms step_avg:87.10ms +step:192/1680 train_time:16723ms step_avg:87.10ms +step:193/1680 train_time:16811ms step_avg:87.10ms +step:194/1680 train_time:16897ms step_avg:87.10ms +step:195/1680 train_time:16985ms step_avg:87.10ms +step:196/1680 train_time:17072ms step_avg:87.10ms +step:197/1680 train_time:17159ms step_avg:87.10ms +step:198/1680 train_time:17248ms step_avg:87.11ms +step:199/1680 train_time:17334ms step_avg:87.10ms +step:200/1680 train_time:17421ms step_avg:87.11ms +step:201/1680 train_time:17508ms step_avg:87.11ms +step:202/1680 train_time:17595ms step_avg:87.10ms +step:203/1680 train_time:17683ms step_avg:87.11ms +step:204/1680 train_time:17770ms step_avg:87.11ms +step:205/1680 train_time:17858ms step_avg:87.11ms +step:206/1680 train_time:17945ms step_avg:87.11ms +step:207/1680 train_time:18032ms step_avg:87.11ms +step:208/1680 train_time:18119ms step_avg:87.11ms +step:209/1680 train_time:18206ms step_avg:87.11ms +step:210/1680 train_time:18294ms step_avg:87.11ms +step:211/1680 train_time:18382ms step_avg:87.12ms +step:212/1680 train_time:18469ms step_avg:87.12ms +step:213/1680 train_time:18556ms step_avg:87.12ms +step:214/1680 train_time:18643ms step_avg:87.12ms +step:215/1680 train_time:18729ms step_avg:87.11ms +step:216/1680 train_time:18817ms step_avg:87.12ms +step:217/1680 train_time:18904ms step_avg:87.11ms +step:218/1680 train_time:18991ms step_avg:87.12ms +step:219/1680 train_time:19078ms step_avg:87.11ms +step:220/1680 train_time:19165ms step_avg:87.11ms +step:221/1680 train_time:19251ms step_avg:87.11ms +step:222/1680 train_time:19339ms step_avg:87.11ms +step:223/1680 train_time:19426ms step_avg:87.11ms +step:224/1680 train_time:19513ms step_avg:87.11ms +step:225/1680 train_time:19600ms step_avg:87.11ms +step:226/1680 train_time:19687ms step_avg:87.11ms +step:227/1680 train_time:19774ms step_avg:87.11ms +step:228/1680 train_time:19861ms step_avg:87.11ms +step:229/1680 train_time:19949ms step_avg:87.11ms +step:230/1680 train_time:20035ms step_avg:87.11ms +step:231/1680 train_time:20122ms step_avg:87.11ms +step:232/1680 train_time:20209ms step_avg:87.11ms +step:233/1680 train_time:20296ms step_avg:87.11ms +step:234/1680 train_time:20383ms step_avg:87.11ms +step:235/1680 train_time:20470ms step_avg:87.11ms +step:236/1680 train_time:20557ms step_avg:87.11ms +step:237/1680 train_time:20645ms step_avg:87.11ms +step:238/1680 train_time:20732ms step_avg:87.11ms +step:239/1680 train_time:20819ms step_avg:87.11ms +step:240/1680 train_time:20906ms step_avg:87.11ms +step:241/1680 train_time:20993ms step_avg:87.11ms +step:242/1680 train_time:21080ms step_avg:87.11ms +step:243/1680 train_time:21168ms step_avg:87.11ms +step:244/1680 train_time:21255ms step_avg:87.11ms +step:245/1680 train_time:21342ms step_avg:87.11ms +step:246/1680 train_time:21430ms step_avg:87.11ms +step:247/1680 train_time:21517ms step_avg:87.11ms +step:248/1680 train_time:21603ms step_avg:87.11ms +step:249/1680 train_time:21691ms step_avg:87.11ms +step:250/1680 train_time:21778ms step_avg:87.11ms +step:250/1680 val_loss:3.9626 train_time:21867ms step_avg:87.47ms +step:251/1680 train_time:21886ms step_avg:87.19ms +step:252/1680 train_time:21957ms step_avg:87.13ms +step:253/1680 train_time:22048ms step_avg:87.15ms +step:254/1680 train_time:22135ms step_avg:87.15ms +step:255/1680 train_time:22222ms step_avg:87.14ms 
+step:256/1680 train_time:22309ms step_avg:87.15ms +step:257/1680 train_time:22395ms step_avg:87.14ms +step:258/1680 train_time:22481ms step_avg:87.13ms +step:259/1680 train_time:22566ms step_avg:87.13ms +step:260/1680 train_time:22653ms step_avg:87.13ms +step:261/1680 train_time:22739ms step_avg:87.12ms +step:262/1680 train_time:22826ms step_avg:87.12ms +step:263/1680 train_time:22914ms step_avg:87.13ms +step:264/1680 train_time:23003ms step_avg:87.13ms +step:265/1680 train_time:23091ms step_avg:87.14ms +step:266/1680 train_time:23178ms step_avg:87.14ms +step:267/1680 train_time:23266ms step_avg:87.14ms +step:268/1680 train_time:23352ms step_avg:87.14ms +step:269/1680 train_time:23439ms step_avg:87.13ms +step:270/1680 train_time:23525ms step_avg:87.13ms +step:271/1680 train_time:23611ms step_avg:87.13ms +step:272/1680 train_time:23697ms step_avg:87.12ms +step:273/1680 train_time:23784ms step_avg:87.12ms +step:274/1680 train_time:23871ms step_avg:87.12ms +step:275/1680 train_time:23959ms step_avg:87.12ms +step:276/1680 train_time:24047ms step_avg:87.13ms +step:277/1680 train_time:24134ms step_avg:87.13ms +step:278/1680 train_time:24222ms step_avg:87.13ms +step:279/1680 train_time:24310ms step_avg:87.13ms +step:280/1680 train_time:24396ms step_avg:87.13ms +step:281/1680 train_time:24483ms step_avg:87.13ms +step:282/1680 train_time:24569ms step_avg:87.12ms +step:283/1680 train_time:24655ms step_avg:87.12ms +step:284/1680 train_time:24742ms step_avg:87.12ms +step:285/1680 train_time:24829ms step_avg:87.12ms +step:286/1680 train_time:24916ms step_avg:87.12ms +step:287/1680 train_time:25004ms step_avg:87.12ms +step:288/1680 train_time:25091ms step_avg:87.12ms +step:289/1680 train_time:25179ms step_avg:87.12ms +step:290/1680 train_time:25266ms step_avg:87.12ms +step:291/1680 train_time:25353ms step_avg:87.12ms +step:292/1680 train_time:25440ms step_avg:87.12ms +step:293/1680 train_time:25527ms step_avg:87.12ms +step:294/1680 train_time:25613ms step_avg:87.12ms +step:295/1680 train_time:25700ms step_avg:87.12ms +step:296/1680 train_time:25787ms step_avg:87.12ms +step:297/1680 train_time:25874ms step_avg:87.12ms +step:298/1680 train_time:25961ms step_avg:87.12ms +step:299/1680 train_time:26049ms step_avg:87.12ms +step:300/1680 train_time:26136ms step_avg:87.12ms +step:301/1680 train_time:26224ms step_avg:87.12ms +step:302/1680 train_time:26311ms step_avg:87.12ms +step:303/1680 train_time:26397ms step_avg:87.12ms +step:304/1680 train_time:26484ms step_avg:87.12ms +step:305/1680 train_time:26570ms step_avg:87.12ms +step:306/1680 train_time:26657ms step_avg:87.11ms +step:307/1680 train_time:26744ms step_avg:87.11ms +step:308/1680 train_time:26831ms step_avg:87.11ms +step:309/1680 train_time:26919ms step_avg:87.12ms +step:310/1680 train_time:27005ms step_avg:87.11ms +step:311/1680 train_time:27093ms step_avg:87.12ms +step:312/1680 train_time:27180ms step_avg:87.11ms +step:313/1680 train_time:27267ms step_avg:87.11ms +step:314/1680 train_time:27353ms step_avg:87.11ms +step:315/1680 train_time:27440ms step_avg:87.11ms +step:316/1680 train_time:27527ms step_avg:87.11ms +step:317/1680 train_time:27614ms step_avg:87.11ms +step:318/1680 train_time:27701ms step_avg:87.11ms +step:319/1680 train_time:27788ms step_avg:87.11ms +step:320/1680 train_time:27876ms step_avg:87.11ms +step:321/1680 train_time:27963ms step_avg:87.11ms +step:322/1680 train_time:28051ms step_avg:87.11ms +step:323/1680 train_time:28138ms step_avg:87.12ms +step:324/1680 train_time:28225ms step_avg:87.11ms +step:325/1680 train_time:28312ms 
step_avg:87.11ms +step:326/1680 train_time:28398ms step_avg:87.11ms +step:327/1680 train_time:28486ms step_avg:87.11ms +step:328/1680 train_time:28573ms step_avg:87.11ms +step:329/1680 train_time:28660ms step_avg:87.11ms +step:330/1680 train_time:28746ms step_avg:87.11ms +step:331/1680 train_time:28833ms step_avg:87.11ms +step:332/1680 train_time:28920ms step_avg:87.11ms +step:333/1680 train_time:29007ms step_avg:87.11ms +step:334/1680 train_time:29094ms step_avg:87.11ms +step:335/1680 train_time:29181ms step_avg:87.11ms +step:336/1680 train_time:29268ms step_avg:87.11ms +step:337/1680 train_time:29355ms step_avg:87.11ms +step:338/1680 train_time:29442ms step_avg:87.11ms +step:339/1680 train_time:29528ms step_avg:87.10ms +step:340/1680 train_time:29615ms step_avg:87.10ms +step:341/1680 train_time:29702ms step_avg:87.10ms +step:342/1680 train_time:29789ms step_avg:87.10ms +step:343/1680 train_time:29875ms step_avg:87.10ms +step:344/1680 train_time:29963ms step_avg:87.10ms +step:345/1680 train_time:30051ms step_avg:87.10ms +step:346/1680 train_time:30137ms step_avg:87.10ms +step:347/1680 train_time:30224ms step_avg:87.10ms +step:348/1680 train_time:30311ms step_avg:87.10ms +step:349/1680 train_time:30398ms step_avg:87.10ms +step:350/1680 train_time:30485ms step_avg:87.10ms +step:351/1680 train_time:30572ms step_avg:87.10ms +step:352/1680 train_time:30659ms step_avg:87.10ms +step:353/1680 train_time:30747ms step_avg:87.10ms +step:354/1680 train_time:30834ms step_avg:87.10ms +step:355/1680 train_time:30922ms step_avg:87.10ms +step:356/1680 train_time:31009ms step_avg:87.10ms +step:357/1680 train_time:31096ms step_avg:87.10ms +step:358/1680 train_time:31184ms step_avg:87.11ms +step:359/1680 train_time:31270ms step_avg:87.10ms +step:360/1680 train_time:31357ms step_avg:87.10ms +step:361/1680 train_time:31444ms step_avg:87.10ms +step:362/1680 train_time:31531ms step_avg:87.10ms +step:363/1680 train_time:31618ms step_avg:87.10ms +step:364/1680 train_time:31705ms step_avg:87.10ms +step:365/1680 train_time:31792ms step_avg:87.10ms +step:366/1680 train_time:31879ms step_avg:87.10ms +step:367/1680 train_time:31966ms step_avg:87.10ms +step:368/1680 train_time:32053ms step_avg:87.10ms +step:369/1680 train_time:32140ms step_avg:87.10ms +step:370/1680 train_time:32227ms step_avg:87.10ms +step:371/1680 train_time:32314ms step_avg:87.10ms +step:372/1680 train_time:32401ms step_avg:87.10ms +step:373/1680 train_time:32488ms step_avg:87.10ms +step:374/1680 train_time:32575ms step_avg:87.10ms +step:375/1680 train_time:32662ms step_avg:87.10ms +step:375/1680 val_loss:3.8158 train_time:32751ms step_avg:87.34ms +step:376/1680 train_time:32770ms step_avg:87.15ms +step:377/1680 train_time:32840ms step_avg:87.11ms +step:378/1680 train_time:32932ms step_avg:87.12ms +step:379/1680 train_time:33021ms step_avg:87.13ms +step:380/1680 train_time:33108ms step_avg:87.13ms +step:381/1680 train_time:33194ms step_avg:87.12ms +step:382/1680 train_time:33280ms step_avg:87.12ms +step:383/1680 train_time:33366ms step_avg:87.12ms +step:384/1680 train_time:33451ms step_avg:87.11ms +step:385/1680 train_time:33537ms step_avg:87.11ms +step:386/1680 train_time:33623ms step_avg:87.11ms +step:387/1680 train_time:33710ms step_avg:87.11ms +step:388/1680 train_time:33799ms step_avg:87.11ms +step:389/1680 train_time:33888ms step_avg:87.12ms +step:390/1680 train_time:33978ms step_avg:87.12ms +step:391/1680 train_time:34066ms step_avg:87.12ms +step:392/1680 train_time:34153ms step_avg:87.13ms +step:393/1680 train_time:34241ms step_avg:87.13ms 
+step:394/1680 train_time:34327ms step_avg:87.12ms +step:395/1680 train_time:34414ms step_avg:87.12ms +step:396/1680 train_time:34500ms step_avg:87.12ms +step:397/1680 train_time:34586ms step_avg:87.12ms +step:398/1680 train_time:34672ms step_avg:87.12ms +step:399/1680 train_time:34760ms step_avg:87.12ms +step:400/1680 train_time:34848ms step_avg:87.12ms +step:401/1680 train_time:34937ms step_avg:87.12ms +step:402/1680 train_time:35024ms step_avg:87.12ms +step:403/1680 train_time:35112ms step_avg:87.13ms +step:404/1680 train_time:35199ms step_avg:87.13ms +step:405/1680 train_time:35286ms step_avg:87.13ms +step:406/1680 train_time:35373ms step_avg:87.13ms +step:407/1680 train_time:35460ms step_avg:87.12ms +step:408/1680 train_time:35546ms step_avg:87.12ms +step:409/1680 train_time:35632ms step_avg:87.12ms +step:410/1680 train_time:35719ms step_avg:87.12ms +step:411/1680 train_time:35807ms step_avg:87.12ms +step:412/1680 train_time:35894ms step_avg:87.12ms +step:413/1680 train_time:35981ms step_avg:87.12ms +step:414/1680 train_time:36069ms step_avg:87.12ms +step:415/1680 train_time:36156ms step_avg:87.12ms +step:416/1680 train_time:36243ms step_avg:87.12ms +step:417/1680 train_time:36330ms step_avg:87.12ms +step:418/1680 train_time:36417ms step_avg:87.12ms +step:419/1680 train_time:36504ms step_avg:87.12ms +step:420/1680 train_time:36590ms step_avg:87.12ms +step:421/1680 train_time:36677ms step_avg:87.12ms +step:422/1680 train_time:36764ms step_avg:87.12ms +step:423/1680 train_time:36852ms step_avg:87.12ms +step:424/1680 train_time:36939ms step_avg:87.12ms +step:425/1680 train_time:37026ms step_avg:87.12ms +step:426/1680 train_time:37113ms step_avg:87.12ms +step:427/1680 train_time:37200ms step_avg:87.12ms +step:428/1680 train_time:37287ms step_avg:87.12ms +step:429/1680 train_time:37375ms step_avg:87.12ms +step:430/1680 train_time:37462ms step_avg:87.12ms +step:431/1680 train_time:37549ms step_avg:87.12ms +step:432/1680 train_time:37636ms step_avg:87.12ms +step:433/1680 train_time:37723ms step_avg:87.12ms +step:434/1680 train_time:37810ms step_avg:87.12ms +step:435/1680 train_time:37898ms step_avg:87.12ms +step:436/1680 train_time:37985ms step_avg:87.12ms +step:437/1680 train_time:38073ms step_avg:87.12ms +step:438/1680 train_time:38160ms step_avg:87.12ms +step:439/1680 train_time:38247ms step_avg:87.12ms +step:440/1680 train_time:38334ms step_avg:87.12ms +step:441/1680 train_time:38420ms step_avg:87.12ms +step:442/1680 train_time:38507ms step_avg:87.12ms +step:443/1680 train_time:38595ms step_avg:87.12ms +step:444/1680 train_time:38681ms step_avg:87.12ms +step:445/1680 train_time:38769ms step_avg:87.12ms +step:446/1680 train_time:38856ms step_avg:87.12ms +step:447/1680 train_time:38942ms step_avg:87.12ms +step:448/1680 train_time:39030ms step_avg:87.12ms +step:449/1680 train_time:39118ms step_avg:87.12ms +step:450/1680 train_time:39204ms step_avg:87.12ms +step:451/1680 train_time:39292ms step_avg:87.12ms +step:452/1680 train_time:39379ms step_avg:87.12ms +step:453/1680 train_time:39465ms step_avg:87.12ms +step:454/1680 train_time:39552ms step_avg:87.12ms +step:455/1680 train_time:39639ms step_avg:87.12ms +step:456/1680 train_time:39726ms step_avg:87.12ms +step:457/1680 train_time:39813ms step_avg:87.12ms +step:458/1680 train_time:39901ms step_avg:87.12ms +step:459/1680 train_time:39988ms step_avg:87.12ms +step:460/1680 train_time:40076ms step_avg:87.12ms +step:461/1680 train_time:40162ms step_avg:87.12ms +step:462/1680 train_time:40250ms step_avg:87.12ms +step:463/1680 train_time:40337ms 
step_avg:87.12ms +step:464/1680 train_time:40424ms step_avg:87.12ms +step:465/1680 train_time:40511ms step_avg:87.12ms +step:466/1680 train_time:40598ms step_avg:87.12ms +step:467/1680 train_time:40685ms step_avg:87.12ms +step:468/1680 train_time:40772ms step_avg:87.12ms +step:469/1680 train_time:40859ms step_avg:87.12ms +step:470/1680 train_time:40946ms step_avg:87.12ms +step:471/1680 train_time:41033ms step_avg:87.12ms +step:472/1680 train_time:41120ms step_avg:87.12ms +step:473/1680 train_time:41207ms step_avg:87.12ms +step:474/1680 train_time:41295ms step_avg:87.12ms +step:475/1680 train_time:41382ms step_avg:87.12ms +step:476/1680 train_time:41469ms step_avg:87.12ms +step:477/1680 train_time:41555ms step_avg:87.12ms +step:478/1680 train_time:41642ms step_avg:87.12ms +step:479/1680 train_time:41729ms step_avg:87.12ms +step:480/1680 train_time:41817ms step_avg:87.12ms +step:481/1680 train_time:41904ms step_avg:87.12ms +step:482/1680 train_time:41991ms step_avg:87.12ms +step:483/1680 train_time:42078ms step_avg:87.12ms +step:484/1680 train_time:42165ms step_avg:87.12ms +step:485/1680 train_time:42252ms step_avg:87.12ms +step:486/1680 train_time:42339ms step_avg:87.12ms +step:487/1680 train_time:42425ms step_avg:87.12ms +step:488/1680 train_time:42513ms step_avg:87.12ms +step:489/1680 train_time:42599ms step_avg:87.12ms +step:490/1680 train_time:42686ms step_avg:87.11ms +step:491/1680 train_time:42774ms step_avg:87.12ms +step:492/1680 train_time:42860ms step_avg:87.11ms +step:493/1680 train_time:42948ms step_avg:87.12ms +step:494/1680 train_time:43036ms step_avg:87.12ms +step:495/1680 train_time:43122ms step_avg:87.12ms +step:496/1680 train_time:43210ms step_avg:87.12ms +step:497/1680 train_time:43297ms step_avg:87.12ms +step:498/1680 train_time:43384ms step_avg:87.12ms +step:499/1680 train_time:43472ms step_avg:87.12ms +step:500/1680 train_time:43558ms step_avg:87.12ms +step:500/1680 val_loss:3.7135 train_time:43647ms step_avg:87.29ms +step:501/1680 train_time:43666ms step_avg:87.16ms +step:502/1680 train_time:43736ms step_avg:87.12ms +step:503/1680 train_time:43826ms step_avg:87.13ms +step:504/1680 train_time:43915ms step_avg:87.13ms +step:505/1680 train_time:44002ms step_avg:87.13ms +step:506/1680 train_time:44088ms step_avg:87.13ms +step:507/1680 train_time:44174ms step_avg:87.13ms +step:508/1680 train_time:44260ms step_avg:87.13ms +step:509/1680 train_time:44347ms step_avg:87.13ms +step:510/1680 train_time:44433ms step_avg:87.12ms +step:511/1680 train_time:44519ms step_avg:87.12ms +step:512/1680 train_time:44606ms step_avg:87.12ms +step:513/1680 train_time:44694ms step_avg:87.12ms +step:514/1680 train_time:44782ms step_avg:87.13ms +step:515/1680 train_time:44871ms step_avg:87.13ms +step:516/1680 train_time:44958ms step_avg:87.13ms +step:517/1680 train_time:45045ms step_avg:87.13ms +step:518/1680 train_time:45131ms step_avg:87.13ms +step:519/1680 train_time:45218ms step_avg:87.12ms +step:520/1680 train_time:45304ms step_avg:87.12ms +step:521/1680 train_time:45391ms step_avg:87.12ms +step:522/1680 train_time:45477ms step_avg:87.12ms +step:523/1680 train_time:45564ms step_avg:87.12ms +step:524/1680 train_time:45651ms step_avg:87.12ms +step:525/1680 train_time:45739ms step_avg:87.12ms +step:526/1680 train_time:45826ms step_avg:87.12ms +step:527/1680 train_time:45914ms step_avg:87.12ms +step:528/1680 train_time:46001ms step_avg:87.12ms +step:529/1680 train_time:46089ms step_avg:87.12ms +step:530/1680 train_time:46176ms step_avg:87.12ms +step:531/1680 train_time:46262ms step_avg:87.12ms 
+step:532/1680 train_time:46349ms step_avg:87.12ms +step:533/1680 train_time:46436ms step_avg:87.12ms +step:534/1680 train_time:46523ms step_avg:87.12ms +step:535/1680 train_time:46609ms step_avg:87.12ms +step:536/1680 train_time:46696ms step_avg:87.12ms +step:537/1680 train_time:46783ms step_avg:87.12ms +step:538/1680 train_time:46871ms step_avg:87.12ms +step:539/1680 train_time:46959ms step_avg:87.12ms +step:540/1680 train_time:47046ms step_avg:87.12ms +step:541/1680 train_time:47134ms step_avg:87.12ms +step:542/1680 train_time:47221ms step_avg:87.12ms +step:543/1680 train_time:47308ms step_avg:87.12ms +step:544/1680 train_time:47395ms step_avg:87.12ms +step:545/1680 train_time:47481ms step_avg:87.12ms +step:546/1680 train_time:47567ms step_avg:87.12ms +step:547/1680 train_time:47654ms step_avg:87.12ms +step:548/1680 train_time:47741ms step_avg:87.12ms +step:549/1680 train_time:47830ms step_avg:87.12ms +step:550/1680 train_time:47919ms step_avg:87.13ms +step:551/1680 train_time:48008ms step_avg:87.13ms +step:552/1680 train_time:48097ms step_avg:87.13ms +step:553/1680 train_time:48186ms step_avg:87.14ms +step:554/1680 train_time:48275ms step_avg:87.14ms +step:555/1680 train_time:48362ms step_avg:87.14ms +step:556/1680 train_time:48451ms step_avg:87.14ms +step:557/1680 train_time:48539ms step_avg:87.14ms +step:558/1680 train_time:48627ms step_avg:87.15ms +step:559/1680 train_time:48715ms step_avg:87.15ms +step:560/1680 train_time:48803ms step_avg:87.15ms +step:561/1680 train_time:48891ms step_avg:87.15ms +step:562/1680 train_time:48980ms step_avg:87.15ms +step:563/1680 train_time:49068ms step_avg:87.15ms +step:564/1680 train_time:49157ms step_avg:87.16ms +step:565/1680 train_time:49245ms step_avg:87.16ms +step:566/1680 train_time:49333ms step_avg:87.16ms +step:567/1680 train_time:49421ms step_avg:87.16ms +step:568/1680 train_time:49509ms step_avg:87.16ms +step:569/1680 train_time:49598ms step_avg:87.17ms +step:570/1680 train_time:49685ms step_avg:87.17ms +step:571/1680 train_time:49774ms step_avg:87.17ms +step:572/1680 train_time:49863ms step_avg:87.17ms +step:573/1680 train_time:49952ms step_avg:87.18ms +step:574/1680 train_time:50039ms step_avg:87.18ms +step:575/1680 train_time:50128ms step_avg:87.18ms +step:576/1680 train_time:50216ms step_avg:87.18ms +step:577/1680 train_time:50304ms step_avg:87.18ms +step:578/1680 train_time:50392ms step_avg:87.18ms +step:579/1680 train_time:50480ms step_avg:87.18ms +step:580/1680 train_time:50568ms step_avg:87.19ms +step:581/1680 train_time:50656ms step_avg:87.19ms +step:582/1680 train_time:50744ms step_avg:87.19ms +step:583/1680 train_time:50832ms step_avg:87.19ms +step:584/1680 train_time:50921ms step_avg:87.19ms +step:585/1680 train_time:51010ms step_avg:87.20ms +step:586/1680 train_time:51098ms step_avg:87.20ms +step:587/1680 train_time:51187ms step_avg:87.20ms +step:588/1680 train_time:51276ms step_avg:87.20ms +step:589/1680 train_time:51364ms step_avg:87.20ms +step:590/1680 train_time:51451ms step_avg:87.21ms +step:591/1680 train_time:51539ms step_avg:87.21ms +step:592/1680 train_time:51627ms step_avg:87.21ms +step:593/1680 train_time:51715ms step_avg:87.21ms +step:594/1680 train_time:51804ms step_avg:87.21ms +step:595/1680 train_time:51892ms step_avg:87.21ms +step:596/1680 train_time:51980ms step_avg:87.21ms +step:597/1680 train_time:52068ms step_avg:87.22ms +step:598/1680 train_time:52156ms step_avg:87.22ms +step:599/1680 train_time:52244ms step_avg:87.22ms +step:600/1680 train_time:52333ms step_avg:87.22ms +step:601/1680 train_time:52420ms 
step_avg:87.22ms +step:602/1680 train_time:52508ms step_avg:87.22ms +step:603/1680 train_time:52596ms step_avg:87.22ms +step:604/1680 train_time:52684ms step_avg:87.23ms +step:605/1680 train_time:52773ms step_avg:87.23ms +step:606/1680 train_time:52861ms step_avg:87.23ms +step:607/1680 train_time:52949ms step_avg:87.23ms +step:608/1680 train_time:53038ms step_avg:87.23ms +step:609/1680 train_time:53127ms step_avg:87.24ms +step:610/1680 train_time:53215ms step_avg:87.24ms +step:611/1680 train_time:53302ms step_avg:87.24ms +step:612/1680 train_time:53391ms step_avg:87.24ms +step:613/1680 train_time:53479ms step_avg:87.24ms +step:614/1680 train_time:53566ms step_avg:87.24ms +step:615/1680 train_time:53654ms step_avg:87.24ms +step:616/1680 train_time:53742ms step_avg:87.24ms +step:617/1680 train_time:53830ms step_avg:87.24ms +step:618/1680 train_time:53918ms step_avg:87.25ms +step:619/1680 train_time:54007ms step_avg:87.25ms +step:620/1680 train_time:54095ms step_avg:87.25ms +step:621/1680 train_time:54183ms step_avg:87.25ms +step:622/1680 train_time:54271ms step_avg:87.25ms +step:623/1680 train_time:54360ms step_avg:87.26ms +step:624/1680 train_time:54448ms step_avg:87.26ms +step:625/1680 train_time:54537ms step_avg:87.26ms +step:625/1680 val_loss:3.6147 train_time:54626ms step_avg:87.40ms +step:626/1680 train_time:54648ms step_avg:87.30ms +step:627/1680 train_time:54715ms step_avg:87.26ms +step:628/1680 train_time:54803ms step_avg:87.27ms +step:629/1680 train_time:54893ms step_avg:87.27ms +step:630/1680 train_time:54982ms step_avg:87.27ms +step:631/1680 train_time:55069ms step_avg:87.27ms +step:632/1680 train_time:55156ms step_avg:87.27ms +step:633/1680 train_time:55242ms step_avg:87.27ms +step:634/1680 train_time:55329ms step_avg:87.27ms +step:635/1680 train_time:55417ms step_avg:87.27ms +step:636/1680 train_time:55506ms step_avg:87.27ms +step:637/1680 train_time:55598ms step_avg:87.28ms +step:638/1680 train_time:55687ms step_avg:87.28ms +step:639/1680 train_time:55776ms step_avg:87.29ms +step:640/1680 train_time:55864ms step_avg:87.29ms +step:641/1680 train_time:55953ms step_avg:87.29ms +step:642/1680 train_time:56041ms step_avg:87.29ms +step:643/1680 train_time:56129ms step_avg:87.29ms +step:644/1680 train_time:56216ms step_avg:87.29ms +step:645/1680 train_time:56303ms step_avg:87.29ms +step:646/1680 train_time:56391ms step_avg:87.29ms +step:647/1680 train_time:56479ms step_avg:87.29ms +step:648/1680 train_time:56569ms step_avg:87.30ms +step:649/1680 train_time:56658ms step_avg:87.30ms +step:650/1680 train_time:56747ms step_avg:87.30ms +step:651/1680 train_time:56836ms step_avg:87.31ms +step:652/1680 train_time:56923ms step_avg:87.31ms +step:653/1680 train_time:57012ms step_avg:87.31ms +step:654/1680 train_time:57101ms step_avg:87.31ms +step:655/1680 train_time:57189ms step_avg:87.31ms +step:656/1680 train_time:57276ms step_avg:87.31ms +step:657/1680 train_time:57363ms step_avg:87.31ms +step:658/1680 train_time:57450ms step_avg:87.31ms +step:659/1680 train_time:57539ms step_avg:87.31ms +step:660/1680 train_time:57627ms step_avg:87.31ms +step:661/1680 train_time:57717ms step_avg:87.32ms +step:662/1680 train_time:57805ms step_avg:87.32ms +step:663/1680 train_time:57894ms step_avg:87.32ms +step:664/1680 train_time:57982ms step_avg:87.32ms +step:665/1680 train_time:58071ms step_avg:87.32ms +step:666/1680 train_time:58159ms step_avg:87.33ms +step:667/1680 train_time:58246ms step_avg:87.33ms +step:668/1680 train_time:58334ms step_avg:87.33ms +step:669/1680 train_time:58422ms step_avg:87.33ms 
+step:670/1680 train_time:58510ms step_avg:87.33ms +step:671/1680 train_time:58599ms step_avg:87.33ms +step:672/1680 train_time:58687ms step_avg:87.33ms +step:673/1680 train_time:58775ms step_avg:87.33ms +step:674/1680 train_time:58864ms step_avg:87.34ms +step:675/1680 train_time:58952ms step_avg:87.34ms +step:676/1680 train_time:59040ms step_avg:87.34ms +step:677/1680 train_time:59128ms step_avg:87.34ms +step:678/1680 train_time:59217ms step_avg:87.34ms +step:679/1680 train_time:59305ms step_avg:87.34ms +step:680/1680 train_time:59393ms step_avg:87.34ms +step:681/1680 train_time:59481ms step_avg:87.34ms +step:682/1680 train_time:59570ms step_avg:87.35ms +step:683/1680 train_time:59659ms step_avg:87.35ms +step:684/1680 train_time:59747ms step_avg:87.35ms +step:685/1680 train_time:59835ms step_avg:87.35ms +step:686/1680 train_time:59924ms step_avg:87.35ms +step:687/1680 train_time:60012ms step_avg:87.35ms +step:688/1680 train_time:60100ms step_avg:87.35ms +step:689/1680 train_time:60188ms step_avg:87.36ms +step:690/1680 train_time:60275ms step_avg:87.36ms +step:691/1680 train_time:60363ms step_avg:87.36ms +step:692/1680 train_time:60451ms step_avg:87.36ms +step:693/1680 train_time:60540ms step_avg:87.36ms +step:694/1680 train_time:60628ms step_avg:87.36ms +step:695/1680 train_time:60717ms step_avg:87.36ms +step:696/1680 train_time:60806ms step_avg:87.36ms +step:697/1680 train_time:60894ms step_avg:87.37ms +step:698/1680 train_time:60982ms step_avg:87.37ms +step:699/1680 train_time:61070ms step_avg:87.37ms +step:700/1680 train_time:61158ms step_avg:87.37ms +step:701/1680 train_time:61247ms step_avg:87.37ms +step:702/1680 train_time:61335ms step_avg:87.37ms +step:703/1680 train_time:61423ms step_avg:87.37ms +step:704/1680 train_time:61511ms step_avg:87.37ms +step:705/1680 train_time:61599ms step_avg:87.37ms +step:706/1680 train_time:61687ms step_avg:87.38ms +step:707/1680 train_time:61776ms step_avg:87.38ms +step:708/1680 train_time:61864ms step_avg:87.38ms +step:709/1680 train_time:61952ms step_avg:87.38ms +step:710/1680 train_time:62041ms step_avg:87.38ms +step:711/1680 train_time:62130ms step_avg:87.38ms +step:712/1680 train_time:62218ms step_avg:87.39ms +step:713/1680 train_time:62306ms step_avg:87.39ms +step:714/1680 train_time:62394ms step_avg:87.39ms +step:715/1680 train_time:62483ms step_avg:87.39ms +step:716/1680 train_time:62570ms step_avg:87.39ms +step:717/1680 train_time:62658ms step_avg:87.39ms +step:718/1680 train_time:62746ms step_avg:87.39ms +step:719/1680 train_time:62835ms step_avg:87.39ms +step:720/1680 train_time:62923ms step_avg:87.39ms +step:721/1680 train_time:63011ms step_avg:87.39ms +step:722/1680 train_time:63100ms step_avg:87.40ms +step:723/1680 train_time:63188ms step_avg:87.40ms +step:724/1680 train_time:63276ms step_avg:87.40ms +step:725/1680 train_time:63364ms step_avg:87.40ms +step:726/1680 train_time:63453ms step_avg:87.40ms +step:727/1680 train_time:63540ms step_avg:87.40ms +step:728/1680 train_time:63628ms step_avg:87.40ms +step:729/1680 train_time:63716ms step_avg:87.40ms +step:730/1680 train_time:63804ms step_avg:87.40ms +step:731/1680 train_time:63891ms step_avg:87.40ms +step:732/1680 train_time:63980ms step_avg:87.40ms +step:733/1680 train_time:64068ms step_avg:87.41ms +step:734/1680 train_time:64156ms step_avg:87.41ms +step:735/1680 train_time:64245ms step_avg:87.41ms +step:736/1680 train_time:64333ms step_avg:87.41ms +step:737/1680 train_time:64421ms step_avg:87.41ms +step:738/1680 train_time:64509ms step_avg:87.41ms +step:739/1680 train_time:64598ms 
step_avg:87.41ms +step:740/1680 train_time:64686ms step_avg:87.41ms +step:741/1680 train_time:64774ms step_avg:87.41ms +step:742/1680 train_time:64862ms step_avg:87.42ms +step:743/1680 train_time:64950ms step_avg:87.42ms +step:744/1680 train_time:65038ms step_avg:87.42ms +step:745/1680 train_time:65126ms step_avg:87.42ms +step:746/1680 train_time:65214ms step_avg:87.42ms +step:747/1680 train_time:65303ms step_avg:87.42ms +step:748/1680 train_time:65391ms step_avg:87.42ms +step:749/1680 train_time:65480ms step_avg:87.42ms +step:750/1680 train_time:65568ms step_avg:87.42ms +step:750/1680 val_loss:3.5637 train_time:65658ms step_avg:87.54ms +step:751/1680 train_time:65676ms step_avg:87.45ms +step:752/1680 train_time:65750ms step_avg:87.43ms +step:753/1680 train_time:65842ms step_avg:87.44ms +step:754/1680 train_time:65931ms step_avg:87.44ms +step:755/1680 train_time:66019ms step_avg:87.44ms +step:756/1680 train_time:66106ms step_avg:87.44ms +step:757/1680 train_time:66194ms step_avg:87.44ms +step:758/1680 train_time:66281ms step_avg:87.44ms +step:759/1680 train_time:66368ms step_avg:87.44ms +step:760/1680 train_time:66455ms step_avg:87.44ms +step:761/1680 train_time:66543ms step_avg:87.44ms +step:762/1680 train_time:66631ms step_avg:87.44ms +step:763/1680 train_time:66722ms step_avg:87.45ms +step:764/1680 train_time:66813ms step_avg:87.45ms +step:765/1680 train_time:66902ms step_avg:87.45ms +step:766/1680 train_time:66990ms step_avg:87.45ms +step:767/1680 train_time:67079ms step_avg:87.46ms +step:768/1680 train_time:67167ms step_avg:87.46ms +step:769/1680 train_time:67254ms step_avg:87.46ms +step:770/1680 train_time:67341ms step_avg:87.46ms +step:771/1680 train_time:67429ms step_avg:87.46ms +step:772/1680 train_time:67516ms step_avg:87.46ms +step:773/1680 train_time:67604ms step_avg:87.46ms +step:774/1680 train_time:67693ms step_avg:87.46ms +step:775/1680 train_time:67782ms step_avg:87.46ms +step:776/1680 train_time:67871ms step_avg:87.46ms +step:777/1680 train_time:67959ms step_avg:87.46ms +step:778/1680 train_time:68049ms step_avg:87.47ms +step:779/1680 train_time:68137ms step_avg:87.47ms +step:780/1680 train_time:68224ms step_avg:87.47ms +step:781/1680 train_time:68312ms step_avg:87.47ms +step:782/1680 train_time:68399ms step_avg:87.47ms +step:783/1680 train_time:68488ms step_avg:87.47ms +step:784/1680 train_time:68576ms step_avg:87.47ms +step:785/1680 train_time:68665ms step_avg:87.47ms +step:786/1680 train_time:68753ms step_avg:87.47ms +step:787/1680 train_time:68843ms step_avg:87.47ms +step:788/1680 train_time:68931ms step_avg:87.48ms +step:789/1680 train_time:69020ms step_avg:87.48ms +step:790/1680 train_time:69107ms step_avg:87.48ms +step:791/1680 train_time:69195ms step_avg:87.48ms +step:792/1680 train_time:69283ms step_avg:87.48ms +step:793/1680 train_time:69371ms step_avg:87.48ms +step:794/1680 train_time:69459ms step_avg:87.48ms +step:795/1680 train_time:69546ms step_avg:87.48ms +step:796/1680 train_time:69634ms step_avg:87.48ms +step:797/1680 train_time:69723ms step_avg:87.48ms +step:798/1680 train_time:69811ms step_avg:87.48ms +step:799/1680 train_time:69899ms step_avg:87.48ms +step:800/1680 train_time:69988ms step_avg:87.49ms +step:801/1680 train_time:70077ms step_avg:87.49ms +step:802/1680 train_time:70164ms step_avg:87.49ms +step:803/1680 train_time:70252ms step_avg:87.49ms +step:804/1680 train_time:70340ms step_avg:87.49ms +step:805/1680 train_time:70428ms step_avg:87.49ms +step:806/1680 train_time:70516ms step_avg:87.49ms +step:807/1680 train_time:70604ms step_avg:87.49ms 
+step:808/1680 train_time:70693ms step_avg:87.49ms +step:809/1680 train_time:70782ms step_avg:87.49ms +step:810/1680 train_time:70871ms step_avg:87.50ms +step:811/1680 train_time:70960ms step_avg:87.50ms +step:812/1680 train_time:71048ms step_avg:87.50ms +step:813/1680 train_time:71136ms step_avg:87.50ms +step:814/1680 train_time:71224ms step_avg:87.50ms +step:815/1680 train_time:71312ms step_avg:87.50ms +step:816/1680 train_time:71399ms step_avg:87.50ms +step:817/1680 train_time:71487ms step_avg:87.50ms +step:818/1680 train_time:71575ms step_avg:87.50ms +step:819/1680 train_time:71664ms step_avg:87.50ms +step:820/1680 train_time:71753ms step_avg:87.50ms +step:821/1680 train_time:71841ms step_avg:87.50ms +step:822/1680 train_time:71930ms step_avg:87.51ms +step:823/1680 train_time:72019ms step_avg:87.51ms +step:824/1680 train_time:72107ms step_avg:87.51ms +step:825/1680 train_time:72195ms step_avg:87.51ms +step:826/1680 train_time:72283ms step_avg:87.51ms +step:827/1680 train_time:72371ms step_avg:87.51ms +step:828/1680 train_time:72459ms step_avg:87.51ms +step:829/1680 train_time:72548ms step_avg:87.51ms +step:830/1680 train_time:72636ms step_avg:87.51ms +step:831/1680 train_time:72723ms step_avg:87.51ms +step:832/1680 train_time:72812ms step_avg:87.51ms +step:833/1680 train_time:72900ms step_avg:87.51ms +step:834/1680 train_time:72989ms step_avg:87.52ms +step:835/1680 train_time:73077ms step_avg:87.52ms +step:836/1680 train_time:73166ms step_avg:87.52ms +step:837/1680 train_time:73254ms step_avg:87.52ms +step:838/1680 train_time:73342ms step_avg:87.52ms +step:839/1680 train_time:73430ms step_avg:87.52ms +step:840/1680 train_time:73519ms step_avg:87.52ms +step:841/1680 train_time:73608ms step_avg:87.52ms +step:842/1680 train_time:73696ms step_avg:87.52ms +step:843/1680 train_time:73784ms step_avg:87.53ms +step:844/1680 train_time:73872ms step_avg:87.53ms +step:845/1680 train_time:73961ms step_avg:87.53ms +step:846/1680 train_time:74050ms step_avg:87.53ms +step:847/1680 train_time:74138ms step_avg:87.53ms +step:848/1680 train_time:74225ms step_avg:87.53ms +step:849/1680 train_time:74313ms step_avg:87.53ms +step:850/1680 train_time:74401ms step_avg:87.53ms +step:851/1680 train_time:74490ms step_avg:87.53ms +step:852/1680 train_time:74578ms step_avg:87.53ms +step:853/1680 train_time:74666ms step_avg:87.53ms +step:854/1680 train_time:74755ms step_avg:87.53ms +step:855/1680 train_time:74842ms step_avg:87.54ms +step:856/1680 train_time:74931ms step_avg:87.54ms +step:857/1680 train_time:75020ms step_avg:87.54ms +step:858/1680 train_time:75108ms step_avg:87.54ms +step:859/1680 train_time:75196ms step_avg:87.54ms +step:860/1680 train_time:75284ms step_avg:87.54ms +step:861/1680 train_time:75372ms step_avg:87.54ms +step:862/1680 train_time:75460ms step_avg:87.54ms +step:863/1680 train_time:75549ms step_avg:87.54ms +step:864/1680 train_time:75637ms step_avg:87.54ms +step:865/1680 train_time:75724ms step_avg:87.54ms +step:866/1680 train_time:75812ms step_avg:87.54ms +step:867/1680 train_time:75900ms step_avg:87.54ms +step:868/1680 train_time:75989ms step_avg:87.55ms +step:869/1680 train_time:76077ms step_avg:87.55ms +step:870/1680 train_time:76165ms step_avg:87.55ms +step:871/1680 train_time:76254ms step_avg:87.55ms +step:872/1680 train_time:76342ms step_avg:87.55ms +step:873/1680 train_time:76430ms step_avg:87.55ms +step:874/1680 train_time:76518ms step_avg:87.55ms +step:875/1680 train_time:76605ms step_avg:87.55ms +step:875/1680 val_loss:3.5182 train_time:76695ms step_avg:87.65ms +step:876/1680 
train_time:76714ms step_avg:87.57ms +step:877/1680 train_time:76785ms step_avg:87.55ms +step:878/1680 train_time:76879ms step_avg:87.56ms +step:879/1680 train_time:76968ms step_avg:87.56ms +step:880/1680 train_time:77055ms step_avg:87.56ms +step:881/1680 train_time:77142ms step_avg:87.56ms +step:882/1680 train_time:77230ms step_avg:87.56ms +step:883/1680 train_time:77318ms step_avg:87.56ms +step:884/1680 train_time:77405ms step_avg:87.56ms +step:885/1680 train_time:77492ms step_avg:87.56ms +step:886/1680 train_time:77580ms step_avg:87.56ms +step:887/1680 train_time:77667ms step_avg:87.56ms +step:888/1680 train_time:77758ms step_avg:87.57ms +step:889/1680 train_time:77848ms step_avg:87.57ms +step:890/1680 train_time:77938ms step_avg:87.57ms +step:891/1680 train_time:78026ms step_avg:87.57ms +step:892/1680 train_time:78115ms step_avg:87.57ms +step:893/1680 train_time:78203ms step_avg:87.57ms +step:894/1680 train_time:78291ms step_avg:87.57ms +step:895/1680 train_time:78378ms step_avg:87.57ms +step:896/1680 train_time:78465ms step_avg:87.57ms +step:897/1680 train_time:78553ms step_avg:87.57ms +step:898/1680 train_time:78641ms step_avg:87.57ms +step:899/1680 train_time:78731ms step_avg:87.58ms +step:900/1680 train_time:78821ms step_avg:87.58ms +step:901/1680 train_time:78911ms step_avg:87.58ms +step:902/1680 train_time:79000ms step_avg:87.58ms +step:903/1680 train_time:79088ms step_avg:87.58ms +step:904/1680 train_time:79176ms step_avg:87.58ms +step:905/1680 train_time:79265ms step_avg:87.59ms +step:906/1680 train_time:79352ms step_avg:87.59ms +step:907/1680 train_time:79441ms step_avg:87.59ms +step:908/1680 train_time:79529ms step_avg:87.59ms +step:909/1680 train_time:79617ms step_avg:87.59ms +step:910/1680 train_time:79705ms step_avg:87.59ms +step:911/1680 train_time:79794ms step_avg:87.59ms +step:912/1680 train_time:79883ms step_avg:87.59ms +step:913/1680 train_time:79972ms step_avg:87.59ms +step:914/1680 train_time:80060ms step_avg:87.59ms +step:915/1680 train_time:80148ms step_avg:87.59ms +step:916/1680 train_time:80237ms step_avg:87.59ms +step:917/1680 train_time:80325ms step_avg:87.60ms +step:918/1680 train_time:80413ms step_avg:87.60ms +step:919/1680 train_time:80501ms step_avg:87.60ms +step:920/1680 train_time:80589ms step_avg:87.60ms +step:921/1680 train_time:80678ms step_avg:87.60ms +step:922/1680 train_time:80766ms step_avg:87.60ms +step:923/1680 train_time:80854ms step_avg:87.60ms +step:924/1680 train_time:80943ms step_avg:87.60ms +step:925/1680 train_time:81032ms step_avg:87.60ms +step:926/1680 train_time:81120ms step_avg:87.60ms +step:927/1680 train_time:81210ms step_avg:87.60ms +step:928/1680 train_time:81299ms step_avg:87.61ms +step:929/1680 train_time:81387ms step_avg:87.61ms +step:930/1680 train_time:81474ms step_avg:87.61ms +step:931/1680 train_time:81562ms step_avg:87.61ms +step:932/1680 train_time:81651ms step_avg:87.61ms +step:933/1680 train_time:81739ms step_avg:87.61ms +step:934/1680 train_time:81828ms step_avg:87.61ms +step:935/1680 train_time:81916ms step_avg:87.61ms +step:936/1680 train_time:82004ms step_avg:87.61ms +step:937/1680 train_time:82092ms step_avg:87.61ms +step:938/1680 train_time:82181ms step_avg:87.61ms +step:939/1680 train_time:82270ms step_avg:87.61ms +step:940/1680 train_time:82358ms step_avg:87.61ms +step:941/1680 train_time:82446ms step_avg:87.61ms +step:942/1680 train_time:82534ms step_avg:87.62ms +step:943/1680 train_time:82622ms step_avg:87.62ms +step:944/1680 train_time:82711ms step_avg:87.62ms +step:945/1680 train_time:82801ms step_avg:87.62ms 
+step:946/1680 train_time:82889ms step_avg:87.62ms +step:947/1680 train_time:82977ms step_avg:87.62ms +step:948/1680 train_time:83065ms step_avg:87.62ms +step:949/1680 train_time:83153ms step_avg:87.62ms +step:950/1680 train_time:83242ms step_avg:87.62ms +step:951/1680 train_time:83330ms step_avg:87.62ms +step:952/1680 train_time:83418ms step_avg:87.62ms +step:953/1680 train_time:83506ms step_avg:87.62ms +step:954/1680 train_time:83593ms step_avg:87.62ms +step:955/1680 train_time:83681ms step_avg:87.62ms +step:956/1680 train_time:83770ms step_avg:87.63ms +step:957/1680 train_time:83858ms step_avg:87.63ms +step:958/1680 train_time:83947ms step_avg:87.63ms +step:959/1680 train_time:84036ms step_avg:87.63ms +step:960/1680 train_time:84124ms step_avg:87.63ms +step:961/1680 train_time:84212ms step_avg:87.63ms +step:962/1680 train_time:84302ms step_avg:87.63ms +step:963/1680 train_time:84391ms step_avg:87.63ms +step:964/1680 train_time:84478ms step_avg:87.63ms +step:965/1680 train_time:84566ms step_avg:87.63ms +step:966/1680 train_time:84654ms step_avg:87.63ms +step:967/1680 train_time:84743ms step_avg:87.63ms +step:968/1680 train_time:84831ms step_avg:87.64ms +step:969/1680 train_time:84919ms step_avg:87.64ms +step:970/1680 train_time:85008ms step_avg:87.64ms +step:971/1680 train_time:85097ms step_avg:87.64ms +step:972/1680 train_time:85185ms step_avg:87.64ms +step:973/1680 train_time:85273ms step_avg:87.64ms +step:974/1680 train_time:85361ms step_avg:87.64ms +step:975/1680 train_time:85450ms step_avg:87.64ms +step:976/1680 train_time:85537ms step_avg:87.64ms +step:977/1680 train_time:85626ms step_avg:87.64ms +step:978/1680 train_time:85714ms step_avg:87.64ms +step:979/1680 train_time:85803ms step_avg:87.64ms +step:980/1680 train_time:85891ms step_avg:87.64ms +step:981/1680 train_time:85979ms step_avg:87.64ms +step:982/1680 train_time:86067ms step_avg:87.64ms +step:983/1680 train_time:86155ms step_avg:87.65ms +step:984/1680 train_time:86244ms step_avg:87.65ms +step:985/1680 train_time:86332ms step_avg:87.65ms +step:986/1680 train_time:86421ms step_avg:87.65ms +step:987/1680 train_time:86509ms step_avg:87.65ms +step:988/1680 train_time:86598ms step_avg:87.65ms +step:989/1680 train_time:86685ms step_avg:87.65ms +step:990/1680 train_time:86773ms step_avg:87.65ms +step:991/1680 train_time:86861ms step_avg:87.65ms +step:992/1680 train_time:86949ms step_avg:87.65ms +step:993/1680 train_time:87037ms step_avg:87.65ms +step:994/1680 train_time:87125ms step_avg:87.65ms +step:995/1680 train_time:87214ms step_avg:87.65ms +step:996/1680 train_time:87302ms step_avg:87.65ms +step:997/1680 train_time:87391ms step_avg:87.65ms +step:998/1680 train_time:87480ms step_avg:87.66ms +step:999/1680 train_time:87568ms step_avg:87.66ms +step:1000/1680 train_time:87655ms step_avg:87.66ms +step:1000/1680 val_loss:3.4694 train_time:87746ms step_avg:87.75ms +step:1001/1680 train_time:87764ms step_avg:87.68ms +step:1002/1680 train_time:87837ms step_avg:87.66ms +step:1003/1680 train_time:87930ms step_avg:87.67ms +step:1004/1680 train_time:88018ms step_avg:87.67ms +step:1005/1680 train_time:88105ms step_avg:87.67ms +step:1006/1680 train_time:88194ms step_avg:87.67ms +step:1007/1680 train_time:88281ms step_avg:87.67ms +step:1008/1680 train_time:88369ms step_avg:87.67ms +step:1009/1680 train_time:88456ms step_avg:87.67ms +step:1010/1680 train_time:88544ms step_avg:87.67ms +step:1011/1680 train_time:88631ms step_avg:87.67ms +step:1012/1680 train_time:88720ms step_avg:87.67ms +step:1013/1680 train_time:88810ms step_avg:87.67ms 
+step:1014/1680 train_time:88901ms step_avg:87.67ms +step:1015/1680 train_time:88990ms step_avg:87.67ms +step:1016/1680 train_time:89078ms step_avg:87.68ms +step:1017/1680 train_time:89166ms step_avg:87.68ms +step:1018/1680 train_time:89254ms step_avg:87.68ms +step:1019/1680 train_time:89341ms step_avg:87.68ms +step:1020/1680 train_time:89429ms step_avg:87.68ms +step:1021/1680 train_time:89516ms step_avg:87.67ms +step:1022/1680 train_time:89603ms step_avg:87.67ms +step:1023/1680 train_time:89693ms step_avg:87.68ms +step:1024/1680 train_time:89782ms step_avg:87.68ms +step:1025/1680 train_time:89872ms step_avg:87.68ms +step:1026/1680 train_time:89961ms step_avg:87.68ms +step:1027/1680 train_time:90050ms step_avg:87.68ms +step:1028/1680 train_time:90138ms step_avg:87.68ms +step:1029/1680 train_time:90225ms step_avg:87.68ms +step:1030/1680 train_time:90313ms step_avg:87.68ms +step:1031/1680 train_time:90400ms step_avg:87.68ms +step:1032/1680 train_time:90488ms step_avg:87.68ms +step:1033/1680 train_time:90575ms step_avg:87.68ms +step:1034/1680 train_time:90664ms step_avg:87.68ms +step:1035/1680 train_time:90752ms step_avg:87.68ms +step:1036/1680 train_time:90841ms step_avg:87.68ms +step:1037/1680 train_time:90931ms step_avg:87.69ms +step:1038/1680 train_time:91020ms step_avg:87.69ms +step:1039/1680 train_time:91108ms step_avg:87.69ms +step:1040/1680 train_time:91196ms step_avg:87.69ms +step:1041/1680 train_time:91285ms step_avg:87.69ms +step:1042/1680 train_time:91373ms step_avg:87.69ms +step:1043/1680 train_time:91461ms step_avg:87.69ms +step:1044/1680 train_time:91549ms step_avg:87.69ms +step:1045/1680 train_time:91636ms step_avg:87.69ms +step:1046/1680 train_time:91725ms step_avg:87.69ms +step:1047/1680 train_time:91814ms step_avg:87.69ms +step:1048/1680 train_time:91902ms step_avg:87.69ms +step:1049/1680 train_time:91992ms step_avg:87.69ms +step:1050/1680 train_time:92081ms step_avg:87.70ms +step:1051/1680 train_time:92170ms step_avg:87.70ms +step:1052/1680 train_time:92258ms step_avg:87.70ms +step:1053/1680 train_time:92346ms step_avg:87.70ms +step:1054/1680 train_time:92433ms step_avg:87.70ms +step:1055/1680 train_time:92521ms step_avg:87.70ms +step:1056/1680 train_time:92609ms step_avg:87.70ms +step:1057/1680 train_time:92697ms step_avg:87.70ms +step:1058/1680 train_time:92786ms step_avg:87.70ms +step:1059/1680 train_time:92875ms step_avg:87.70ms +step:1060/1680 train_time:92964ms step_avg:87.70ms +step:1061/1680 train_time:93053ms step_avg:87.70ms +step:1062/1680 train_time:93142ms step_avg:87.70ms +step:1063/1680 train_time:93230ms step_avg:87.70ms +step:1064/1680 train_time:93318ms step_avg:87.70ms +step:1065/1680 train_time:93406ms step_avg:87.71ms +step:1066/1680 train_time:93494ms step_avg:87.71ms +step:1067/1680 train_time:93582ms step_avg:87.71ms +step:1068/1680 train_time:93670ms step_avg:87.71ms +step:1069/1680 train_time:93758ms step_avg:87.71ms +step:1070/1680 train_time:93846ms step_avg:87.71ms +step:1071/1680 train_time:93935ms step_avg:87.71ms +step:1072/1680 train_time:94024ms step_avg:87.71ms +step:1073/1680 train_time:94113ms step_avg:87.71ms +step:1074/1680 train_time:94201ms step_avg:87.71ms +step:1075/1680 train_time:94291ms step_avg:87.71ms +step:1076/1680 train_time:94378ms step_avg:87.71ms +step:1077/1680 train_time:94467ms step_avg:87.71ms +step:1078/1680 train_time:94555ms step_avg:87.71ms +step:1079/1680 train_time:94643ms step_avg:87.71ms +step:1080/1680 train_time:94732ms step_avg:87.71ms +step:1081/1680 train_time:94820ms step_avg:87.72ms +step:1082/1680 
train_time:94908ms step_avg:87.72ms +step:1083/1680 train_time:94997ms step_avg:87.72ms +step:1084/1680 train_time:95085ms step_avg:87.72ms +step:1085/1680 train_time:95173ms step_avg:87.72ms +step:1086/1680 train_time:95262ms step_avg:87.72ms +step:1087/1680 train_time:95350ms step_avg:87.72ms +step:1088/1680 train_time:95438ms step_avg:87.72ms +step:1089/1680 train_time:95527ms step_avg:87.72ms +step:1090/1680 train_time:95615ms step_avg:87.72ms +step:1091/1680 train_time:95703ms step_avg:87.72ms +step:1092/1680 train_time:95792ms step_avg:87.72ms +step:1093/1680 train_time:95881ms step_avg:87.72ms +step:1094/1680 train_time:95969ms step_avg:87.72ms +step:1095/1680 train_time:96058ms step_avg:87.72ms +step:1096/1680 train_time:96147ms step_avg:87.73ms +step:1097/1680 train_time:96235ms step_avg:87.73ms +step:1098/1680 train_time:96324ms step_avg:87.73ms +step:1099/1680 train_time:96413ms step_avg:87.73ms +step:1100/1680 train_time:96501ms step_avg:87.73ms +step:1101/1680 train_time:96590ms step_avg:87.73ms +step:1102/1680 train_time:96679ms step_avg:87.73ms +step:1103/1680 train_time:96768ms step_avg:87.73ms +step:1104/1680 train_time:96857ms step_avg:87.73ms +step:1105/1680 train_time:96947ms step_avg:87.73ms +step:1106/1680 train_time:97036ms step_avg:87.74ms +step:1107/1680 train_time:97125ms step_avg:87.74ms +step:1108/1680 train_time:97214ms step_avg:87.74ms +step:1109/1680 train_time:97303ms step_avg:87.74ms +step:1110/1680 train_time:97391ms step_avg:87.74ms +step:1111/1680 train_time:97480ms step_avg:87.74ms +step:1112/1680 train_time:97570ms step_avg:87.74ms +step:1113/1680 train_time:97658ms step_avg:87.74ms +step:1114/1680 train_time:97747ms step_avg:87.74ms +step:1115/1680 train_time:97836ms step_avg:87.75ms +step:1116/1680 train_time:97926ms step_avg:87.75ms +step:1117/1680 train_time:98015ms step_avg:87.75ms +step:1118/1680 train_time:98104ms step_avg:87.75ms +step:1119/1680 train_time:98193ms step_avg:87.75ms +step:1120/1680 train_time:98281ms step_avg:87.75ms +step:1121/1680 train_time:98370ms step_avg:87.75ms +step:1122/1680 train_time:98459ms step_avg:87.75ms +step:1123/1680 train_time:98548ms step_avg:87.75ms +step:1124/1680 train_time:98636ms step_avg:87.75ms +step:1125/1680 train_time:98726ms step_avg:87.76ms +step:1125/1680 val_loss:3.4159 train_time:98816ms step_avg:87.84ms +step:1126/1680 train_time:98834ms step_avg:87.77ms +step:1127/1680 train_time:98907ms step_avg:87.76ms +step:1128/1680 train_time:98996ms step_avg:87.76ms +step:1129/1680 train_time:99089ms step_avg:87.77ms +step:1130/1680 train_time:99178ms step_avg:87.77ms +step:1131/1680 train_time:99267ms step_avg:87.77ms +step:1132/1680 train_time:99355ms step_avg:87.77ms +step:1133/1680 train_time:99443ms step_avg:87.77ms +step:1134/1680 train_time:99531ms step_avg:87.77ms +step:1135/1680 train_time:99619ms step_avg:87.77ms +step:1136/1680 train_time:99707ms step_avg:87.77ms +step:1137/1680 train_time:99797ms step_avg:87.77ms +step:1138/1680 train_time:99887ms step_avg:87.77ms +step:1139/1680 train_time:99978ms step_avg:87.78ms +step:1140/1680 train_time:100068ms step_avg:87.78ms +step:1141/1680 train_time:100156ms step_avg:87.78ms +step:1142/1680 train_time:100245ms step_avg:87.78ms +step:1143/1680 train_time:100334ms step_avg:87.78ms +step:1144/1680 train_time:100423ms step_avg:87.78ms +step:1145/1680 train_time:100511ms step_avg:87.78ms +step:1146/1680 train_time:100599ms step_avg:87.78ms +step:1147/1680 train_time:100688ms step_avg:87.78ms +step:1148/1680 train_time:100776ms step_avg:87.78ms 
+step:1149/1680 train_time:100866ms step_avg:87.79ms +step:1150/1680 train_time:100956ms step_avg:87.79ms +step:1151/1680 train_time:101046ms step_avg:87.79ms +step:1152/1680 train_time:101136ms step_avg:87.79ms +step:1153/1680 train_time:101225ms step_avg:87.79ms +step:1154/1680 train_time:101313ms step_avg:87.79ms +step:1155/1680 train_time:101402ms step_avg:87.79ms +step:1156/1680 train_time:101490ms step_avg:87.79ms +step:1157/1680 train_time:101578ms step_avg:87.79ms +step:1158/1680 train_time:101667ms step_avg:87.80ms +step:1159/1680 train_time:101755ms step_avg:87.80ms +step:1160/1680 train_time:101845ms step_avg:87.80ms +step:1161/1680 train_time:101934ms step_avg:87.80ms +step:1162/1680 train_time:102024ms step_avg:87.80ms +step:1163/1680 train_time:102113ms step_avg:87.80ms +step:1164/1680 train_time:102203ms step_avg:87.80ms +step:1165/1680 train_time:102292ms step_avg:87.80ms +step:1166/1680 train_time:102380ms step_avg:87.80ms +step:1167/1680 train_time:102469ms step_avg:87.81ms +step:1168/1680 train_time:102558ms step_avg:87.81ms +step:1169/1680 train_time:102646ms step_avg:87.81ms +step:1170/1680 train_time:102734ms step_avg:87.81ms +step:1171/1680 train_time:102823ms step_avg:87.81ms +step:1172/1680 train_time:102913ms step_avg:87.81ms +step:1173/1680 train_time:103002ms step_avg:87.81ms +step:1174/1680 train_time:103092ms step_avg:87.81ms +step:1175/1680 train_time:103181ms step_avg:87.81ms +step:1176/1680 train_time:103270ms step_avg:87.81ms +step:1177/1680 train_time:103359ms step_avg:87.82ms +step:1178/1680 train_time:103448ms step_avg:87.82ms +step:1179/1680 train_time:103536ms step_avg:87.82ms +step:1180/1680 train_time:103625ms step_avg:87.82ms +step:1181/1680 train_time:103714ms step_avg:87.82ms +step:1182/1680 train_time:103803ms step_avg:87.82ms +step:1183/1680 train_time:103892ms step_avg:87.82ms +step:1184/1680 train_time:103981ms step_avg:87.82ms +step:1185/1680 train_time:104070ms step_avg:87.82ms +step:1186/1680 train_time:104159ms step_avg:87.82ms +step:1187/1680 train_time:104249ms step_avg:87.83ms +step:1188/1680 train_time:104337ms step_avg:87.83ms +step:1189/1680 train_time:104426ms step_avg:87.83ms +step:1190/1680 train_time:104515ms step_avg:87.83ms +step:1191/1680 train_time:104603ms step_avg:87.83ms +step:1192/1680 train_time:104692ms step_avg:87.83ms +step:1193/1680 train_time:104780ms step_avg:87.83ms +step:1194/1680 train_time:104869ms step_avg:87.83ms +step:1195/1680 train_time:104959ms step_avg:87.83ms +step:1196/1680 train_time:105047ms step_avg:87.83ms +step:1197/1680 train_time:105136ms step_avg:87.83ms +step:1198/1680 train_time:105226ms step_avg:87.83ms +step:1199/1680 train_time:105315ms step_avg:87.84ms +step:1200/1680 train_time:105404ms step_avg:87.84ms +step:1201/1680 train_time:105493ms step_avg:87.84ms +step:1202/1680 train_time:105582ms step_avg:87.84ms +step:1203/1680 train_time:105670ms step_avg:87.84ms +step:1204/1680 train_time:105758ms step_avg:87.84ms +step:1205/1680 train_time:105849ms step_avg:87.84ms +step:1206/1680 train_time:105938ms step_avg:87.84ms +step:1207/1680 train_time:106027ms step_avg:87.84ms +step:1208/1680 train_time:106116ms step_avg:87.84ms +step:1209/1680 train_time:106205ms step_avg:87.85ms +step:1210/1680 train_time:106293ms step_avg:87.85ms +step:1211/1680 train_time:106383ms step_avg:87.85ms +step:1212/1680 train_time:106472ms step_avg:87.85ms +step:1213/1680 train_time:106562ms step_avg:87.85ms +step:1214/1680 train_time:106650ms step_avg:87.85ms +step:1215/1680 train_time:106739ms step_avg:87.85ms 
+step:1216/1680 train_time:106828ms step_avg:87.85ms +step:1217/1680 train_time:106917ms step_avg:87.85ms +step:1218/1680 train_time:107007ms step_avg:87.85ms +step:1219/1680 train_time:107095ms step_avg:87.85ms +step:1220/1680 train_time:107184ms step_avg:87.86ms +step:1221/1680 train_time:107274ms step_avg:87.86ms +step:1222/1680 train_time:107362ms step_avg:87.86ms +step:1223/1680 train_time:107451ms step_avg:87.86ms +step:1224/1680 train_time:107541ms step_avg:87.86ms +step:1225/1680 train_time:107630ms step_avg:87.86ms +step:1226/1680 train_time:107720ms step_avg:87.86ms +step:1227/1680 train_time:107810ms step_avg:87.86ms +step:1228/1680 train_time:107898ms step_avg:87.86ms +step:1229/1680 train_time:107986ms step_avg:87.87ms +step:1230/1680 train_time:108075ms step_avg:87.87ms +step:1231/1680 train_time:108164ms step_avg:87.87ms +step:1232/1680 train_time:108253ms step_avg:87.87ms +step:1233/1680 train_time:108342ms step_avg:87.87ms +step:1234/1680 train_time:108431ms step_avg:87.87ms +step:1235/1680 train_time:108520ms step_avg:87.87ms +step:1236/1680 train_time:108609ms step_avg:87.87ms +step:1237/1680 train_time:108697ms step_avg:87.87ms +step:1238/1680 train_time:108786ms step_avg:87.87ms +step:1239/1680 train_time:108875ms step_avg:87.87ms +step:1240/1680 train_time:108963ms step_avg:87.87ms +step:1241/1680 train_time:109053ms step_avg:87.87ms +step:1242/1680 train_time:109142ms step_avg:87.88ms +step:1243/1680 train_time:109231ms step_avg:87.88ms +step:1244/1680 train_time:109320ms step_avg:87.88ms +step:1245/1680 train_time:109409ms step_avg:87.88ms +step:1246/1680 train_time:109498ms step_avg:87.88ms +step:1247/1680 train_time:109587ms step_avg:87.88ms +step:1248/1680 train_time:109676ms step_avg:87.88ms +step:1249/1680 train_time:109765ms step_avg:87.88ms +step:1250/1680 train_time:109855ms step_avg:87.88ms +step:1250/1680 val_loss:3.3773 train_time:109945ms step_avg:87.96ms +step:1251/1680 train_time:109966ms step_avg:87.90ms +step:1252/1680 train_time:110039ms step_avg:87.89ms +step:1253/1680 train_time:110132ms step_avg:87.89ms +step:1254/1680 train_time:110221ms step_avg:87.90ms +step:1255/1680 train_time:110309ms step_avg:87.90ms +step:1256/1680 train_time:110397ms step_avg:87.90ms +step:1257/1680 train_time:110485ms step_avg:87.90ms +step:1258/1680 train_time:110573ms step_avg:87.90ms +step:1259/1680 train_time:110661ms step_avg:87.90ms +step:1260/1680 train_time:110749ms step_avg:87.90ms +step:1261/1680 train_time:110837ms step_avg:87.90ms +step:1262/1680 train_time:110927ms step_avg:87.90ms +step:1263/1680 train_time:111017ms step_avg:87.90ms +step:1264/1680 train_time:111108ms step_avg:87.90ms +step:1265/1680 train_time:111198ms step_avg:87.90ms +step:1266/1680 train_time:111287ms step_avg:87.90ms +step:1267/1680 train_time:111377ms step_avg:87.91ms +step:1268/1680 train_time:111464ms step_avg:87.91ms +step:1269/1680 train_time:111553ms step_avg:87.91ms +step:1270/1680 train_time:111641ms step_avg:87.91ms +step:1271/1680 train_time:111729ms step_avg:87.91ms +step:1272/1680 train_time:111817ms step_avg:87.91ms +step:1273/1680 train_time:111906ms step_avg:87.91ms +step:1274/1680 train_time:111996ms step_avg:87.91ms +step:1275/1680 train_time:112086ms step_avg:87.91ms +step:1276/1680 train_time:112176ms step_avg:87.91ms +step:1277/1680 train_time:112265ms step_avg:87.91ms +step:1278/1680 train_time:112354ms step_avg:87.91ms +step:1279/1680 train_time:112442ms step_avg:87.91ms +step:1280/1680 train_time:112530ms step_avg:87.91ms +step:1281/1680 train_time:112619ms 
step_avg:87.91ms +step:1282/1680 train_time:112707ms step_avg:87.92ms +step:1283/1680 train_time:112795ms step_avg:87.92ms +step:1284/1680 train_time:112885ms step_avg:87.92ms +step:1285/1680 train_time:112974ms step_avg:87.92ms +step:1286/1680 train_time:113063ms step_avg:87.92ms +step:1287/1680 train_time:113154ms step_avg:87.92ms +step:1288/1680 train_time:113243ms step_avg:87.92ms +step:1289/1680 train_time:113333ms step_avg:87.92ms +step:1290/1680 train_time:113422ms step_avg:87.92ms +step:1291/1680 train_time:113511ms step_avg:87.92ms +step:1292/1680 train_time:113599ms step_avg:87.93ms +step:1293/1680 train_time:113688ms step_avg:87.93ms +step:1294/1680 train_time:113776ms step_avg:87.93ms +step:1295/1680 train_time:113866ms step_avg:87.93ms +step:1296/1680 train_time:113954ms step_avg:87.93ms +step:1297/1680 train_time:114043ms step_avg:87.93ms +step:1298/1680 train_time:114132ms step_avg:87.93ms +step:1299/1680 train_time:114221ms step_avg:87.93ms +step:1300/1680 train_time:114309ms step_avg:87.93ms +step:1301/1680 train_time:114399ms step_avg:87.93ms +step:1302/1680 train_time:114487ms step_avg:87.93ms +step:1303/1680 train_time:114575ms step_avg:87.93ms +step:1304/1680 train_time:114664ms step_avg:87.93ms +step:1305/1680 train_time:114753ms step_avg:87.93ms +step:1306/1680 train_time:114842ms step_avg:87.93ms +step:1307/1680 train_time:114932ms step_avg:87.94ms +step:1308/1680 train_time:115020ms step_avg:87.94ms +step:1309/1680 train_time:115110ms step_avg:87.94ms +step:1310/1680 train_time:115200ms step_avg:87.94ms +step:1311/1680 train_time:115289ms step_avg:87.94ms +step:1312/1680 train_time:115378ms step_avg:87.94ms +step:1313/1680 train_time:115467ms step_avg:87.94ms +step:1314/1680 train_time:115557ms step_avg:87.94ms +step:1315/1680 train_time:115645ms step_avg:87.94ms +step:1316/1680 train_time:115734ms step_avg:87.94ms +step:1317/1680 train_time:115822ms step_avg:87.94ms +step:1318/1680 train_time:115911ms step_avg:87.94ms +step:1319/1680 train_time:116000ms step_avg:87.95ms +step:1320/1680 train_time:116089ms step_avg:87.95ms +step:1321/1680 train_time:116179ms step_avg:87.95ms +step:1322/1680 train_time:116269ms step_avg:87.95ms +step:1323/1680 train_time:116359ms step_avg:87.95ms +step:1324/1680 train_time:116448ms step_avg:87.95ms +step:1325/1680 train_time:116537ms step_avg:87.95ms +step:1326/1680 train_time:116626ms step_avg:87.95ms +step:1327/1680 train_time:116715ms step_avg:87.95ms +step:1328/1680 train_time:116804ms step_avg:87.95ms +step:1329/1680 train_time:116893ms step_avg:87.96ms +step:1330/1680 train_time:116982ms step_avg:87.96ms +step:1331/1680 train_time:117072ms step_avg:87.96ms +step:1332/1680 train_time:117162ms step_avg:87.96ms +step:1333/1680 train_time:117252ms step_avg:87.96ms +step:1334/1680 train_time:117341ms step_avg:87.96ms +step:1335/1680 train_time:117430ms step_avg:87.96ms +step:1336/1680 train_time:117520ms step_avg:87.96ms +step:1337/1680 train_time:117609ms step_avg:87.96ms +step:1338/1680 train_time:117698ms step_avg:87.97ms +step:1339/1680 train_time:117787ms step_avg:87.97ms +step:1340/1680 train_time:117876ms step_avg:87.97ms +step:1341/1680 train_time:117965ms step_avg:87.97ms +step:1342/1680 train_time:118054ms step_avg:87.97ms +step:1343/1680 train_time:118143ms step_avg:87.97ms +step:1344/1680 train_time:118232ms step_avg:87.97ms +step:1345/1680 train_time:118322ms step_avg:87.97ms +step:1346/1680 train_time:118411ms step_avg:87.97ms +step:1347/1680 train_time:118501ms step_avg:87.97ms +step:1348/1680 train_time:118590ms 
step_avg:87.97ms +step:1349/1680 train_time:118679ms step_avg:87.98ms +step:1350/1680 train_time:118768ms step_avg:87.98ms +step:1351/1680 train_time:118858ms step_avg:87.98ms +step:1352/1680 train_time:118948ms step_avg:87.98ms +step:1353/1680 train_time:119037ms step_avg:87.98ms +step:1354/1680 train_time:119125ms step_avg:87.98ms +step:1355/1680 train_time:119214ms step_avg:87.98ms +step:1356/1680 train_time:119303ms step_avg:87.98ms +step:1357/1680 train_time:119394ms step_avg:87.98ms +step:1358/1680 train_time:119483ms step_avg:87.98ms +step:1359/1680 train_time:119572ms step_avg:87.99ms +step:1360/1680 train_time:119660ms step_avg:87.99ms +step:1361/1680 train_time:119750ms step_avg:87.99ms +step:1362/1680 train_time:119838ms step_avg:87.99ms +step:1363/1680 train_time:119928ms step_avg:87.99ms +step:1364/1680 train_time:120016ms step_avg:87.99ms +step:1365/1680 train_time:120105ms step_avg:87.99ms +step:1366/1680 train_time:120194ms step_avg:87.99ms +step:1367/1680 train_time:120284ms step_avg:87.99ms +step:1368/1680 train_time:120373ms step_avg:87.99ms +step:1369/1680 train_time:120463ms step_avg:87.99ms +step:1370/1680 train_time:120552ms step_avg:87.99ms +step:1371/1680 train_time:120641ms step_avg:87.99ms +step:1372/1680 train_time:120730ms step_avg:88.00ms +step:1373/1680 train_time:120819ms step_avg:88.00ms +step:1374/1680 train_time:120907ms step_avg:88.00ms +step:1375/1680 train_time:120997ms step_avg:88.00ms +step:1375/1680 val_loss:3.3431 train_time:121087ms step_avg:88.06ms +step:1376/1680 train_time:121105ms step_avg:88.01ms +step:1377/1680 train_time:121177ms step_avg:88.00ms +step:1378/1680 train_time:121269ms step_avg:88.00ms +step:1379/1680 train_time:121359ms step_avg:88.00ms +step:1380/1680 train_time:121447ms step_avg:88.01ms +step:1381/1680 train_time:121535ms step_avg:88.00ms +step:1382/1680 train_time:121622ms step_avg:88.00ms +step:1383/1680 train_time:121710ms step_avg:88.00ms +step:1384/1680 train_time:121797ms step_avg:88.00ms +step:1385/1680 train_time:121886ms step_avg:88.00ms +step:1386/1680 train_time:121974ms step_avg:88.00ms +step:1387/1680 train_time:122065ms step_avg:88.01ms +step:1388/1680 train_time:122154ms step_avg:88.01ms +step:1389/1680 train_time:122246ms step_avg:88.01ms +step:1390/1680 train_time:122336ms step_avg:88.01ms +step:1391/1680 train_time:122425ms step_avg:88.01ms +step:1392/1680 train_time:122514ms step_avg:88.01ms +step:1393/1680 train_time:122602ms step_avg:88.01ms +step:1394/1680 train_time:122690ms step_avg:88.01ms +step:1395/1680 train_time:122778ms step_avg:88.01ms +step:1396/1680 train_time:122866ms step_avg:88.01ms +step:1397/1680 train_time:122955ms step_avg:88.01ms +step:1398/1680 train_time:123045ms step_avg:88.01ms +step:1399/1680 train_time:123134ms step_avg:88.02ms +step:1400/1680 train_time:123224ms step_avg:88.02ms +step:1401/1680 train_time:123314ms step_avg:88.02ms +step:1402/1680 train_time:123404ms step_avg:88.02ms +step:1403/1680 train_time:123493ms step_avg:88.02ms +step:1404/1680 train_time:123583ms step_avg:88.02ms +step:1405/1680 train_time:123671ms step_avg:88.02ms +step:1406/1680 train_time:123760ms step_avg:88.02ms +step:1407/1680 train_time:123848ms step_avg:88.02ms +step:1408/1680 train_time:123936ms step_avg:88.02ms +step:1409/1680 train_time:124025ms step_avg:88.02ms +step:1410/1680 train_time:124116ms step_avg:88.03ms +step:1411/1680 train_time:124206ms step_avg:88.03ms +step:1412/1680 train_time:124298ms step_avg:88.03ms +step:1413/1680 train_time:124388ms step_avg:88.03ms +step:1414/1680 
train_time:124477ms step_avg:88.03ms +step:1415/1680 train_time:124565ms step_avg:88.03ms +step:1416/1680 train_time:124654ms step_avg:88.03ms +step:1417/1680 train_time:124742ms step_avg:88.03ms +step:1418/1680 train_time:124830ms step_avg:88.03ms +step:1419/1680 train_time:124919ms step_avg:88.03ms +step:1420/1680 train_time:125008ms step_avg:88.03ms +step:1421/1680 train_time:125098ms step_avg:88.04ms +step:1422/1680 train_time:125187ms step_avg:88.04ms +step:1423/1680 train_time:125276ms step_avg:88.04ms +step:1424/1680 train_time:125366ms step_avg:88.04ms +step:1425/1680 train_time:125455ms step_avg:88.04ms +step:1426/1680 train_time:125544ms step_avg:88.04ms +step:1427/1680 train_time:125632ms step_avg:88.04ms +step:1428/1680 train_time:125721ms step_avg:88.04ms +step:1429/1680 train_time:125809ms step_avg:88.04ms +step:1430/1680 train_time:125898ms step_avg:88.04ms +step:1431/1680 train_time:125987ms step_avg:88.04ms +step:1432/1680 train_time:126077ms step_avg:88.04ms +step:1433/1680 train_time:126167ms step_avg:88.04ms +step:1434/1680 train_time:126257ms step_avg:88.05ms +step:1435/1680 train_time:126347ms step_avg:88.05ms +step:1436/1680 train_time:126436ms step_avg:88.05ms +step:1437/1680 train_time:126526ms step_avg:88.05ms +step:1438/1680 train_time:126614ms step_avg:88.05ms +step:1439/1680 train_time:126702ms step_avg:88.05ms +step:1440/1680 train_time:126791ms step_avg:88.05ms +step:1441/1680 train_time:126879ms step_avg:88.05ms +step:1442/1680 train_time:126968ms step_avg:88.05ms +step:1443/1680 train_time:127057ms step_avg:88.05ms +step:1444/1680 train_time:127146ms step_avg:88.05ms +step:1445/1680 train_time:127234ms step_avg:88.05ms +step:1446/1680 train_time:127324ms step_avg:88.05ms +step:1447/1680 train_time:127414ms step_avg:88.05ms +step:1448/1680 train_time:127504ms step_avg:88.06ms +step:1449/1680 train_time:127594ms step_avg:88.06ms +step:1450/1680 train_time:127684ms step_avg:88.06ms +step:1451/1680 train_time:127772ms step_avg:88.06ms +step:1452/1680 train_time:127861ms step_avg:88.06ms +step:1453/1680 train_time:127950ms step_avg:88.06ms +step:1454/1680 train_time:128038ms step_avg:88.06ms +step:1455/1680 train_time:128127ms step_avg:88.06ms +step:1456/1680 train_time:128216ms step_avg:88.06ms +step:1457/1680 train_time:128306ms step_avg:88.06ms +step:1458/1680 train_time:128395ms step_avg:88.06ms +step:1459/1680 train_time:128485ms step_avg:88.06ms +step:1460/1680 train_time:128575ms step_avg:88.07ms +step:1461/1680 train_time:128664ms step_avg:88.07ms +step:1462/1680 train_time:128754ms step_avg:88.07ms +step:1463/1680 train_time:128844ms step_avg:88.07ms +step:1464/1680 train_time:128933ms step_avg:88.07ms +step:1465/1680 train_time:129022ms step_avg:88.07ms +step:1466/1680 train_time:129111ms step_avg:88.07ms +step:1467/1680 train_time:129200ms step_avg:88.07ms +step:1468/1680 train_time:129289ms step_avg:88.07ms +step:1469/1680 train_time:129378ms step_avg:88.07ms +step:1470/1680 train_time:129467ms step_avg:88.07ms +step:1471/1680 train_time:129556ms step_avg:88.07ms +step:1472/1680 train_time:129646ms step_avg:88.07ms +step:1473/1680 train_time:129736ms step_avg:88.08ms +step:1474/1680 train_time:129825ms step_avg:88.08ms +step:1475/1680 train_time:129915ms step_avg:88.08ms +step:1476/1680 train_time:130004ms step_avg:88.08ms +step:1477/1680 train_time:130093ms step_avg:88.08ms +step:1478/1680 train_time:130183ms step_avg:88.08ms +step:1479/1680 train_time:130272ms step_avg:88.08ms +step:1480/1680 train_time:130361ms step_avg:88.08ms +step:1481/1680 
train_time:130449ms step_avg:88.08ms +step:1482/1680 train_time:130538ms step_avg:88.08ms +step:1483/1680 train_time:130628ms step_avg:88.08ms +step:1484/1680 train_time:130717ms step_avg:88.08ms +step:1485/1680 train_time:130806ms step_avg:88.09ms +step:1486/1680 train_time:130896ms step_avg:88.09ms +step:1487/1680 train_time:130985ms step_avg:88.09ms +step:1488/1680 train_time:131073ms step_avg:88.09ms +step:1489/1680 train_time:131162ms step_avg:88.09ms +step:1490/1680 train_time:131251ms step_avg:88.09ms +step:1491/1680 train_time:131341ms step_avg:88.09ms +step:1492/1680 train_time:131429ms step_avg:88.09ms +step:1493/1680 train_time:131518ms step_avg:88.09ms +step:1494/1680 train_time:131607ms step_avg:88.09ms +step:1495/1680 train_time:131696ms step_avg:88.09ms +step:1496/1680 train_time:131785ms step_avg:88.09ms +step:1497/1680 train_time:131874ms step_avg:88.09ms +step:1498/1680 train_time:131963ms step_avg:88.09ms +step:1499/1680 train_time:132051ms step_avg:88.09ms +step:1500/1680 train_time:132140ms step_avg:88.09ms +step:1500/1680 val_loss:3.3135 train_time:132230ms step_avg:88.15ms +step:1501/1680 train_time:132249ms step_avg:88.11ms +step:1502/1680 train_time:132322ms step_avg:88.10ms +step:1503/1680 train_time:132417ms step_avg:88.10ms +step:1504/1680 train_time:132508ms step_avg:88.10ms +step:1505/1680 train_time:132596ms step_avg:88.10ms +step:1506/1680 train_time:132684ms step_avg:88.10ms +step:1507/1680 train_time:132772ms step_avg:88.10ms +step:1508/1680 train_time:132861ms step_avg:88.10ms +step:1509/1680 train_time:132949ms step_avg:88.10ms +step:1510/1680 train_time:133037ms step_avg:88.10ms +step:1511/1680 train_time:133125ms step_avg:88.10ms +step:1512/1680 train_time:133215ms step_avg:88.10ms +step:1513/1680 train_time:133305ms step_avg:88.11ms +step:1514/1680 train_time:133397ms step_avg:88.11ms +step:1515/1680 train_time:133486ms step_avg:88.11ms +step:1516/1680 train_time:133575ms step_avg:88.11ms +step:1517/1680 train_time:133665ms step_avg:88.11ms +step:1518/1680 train_time:133752ms step_avg:88.11ms +step:1519/1680 train_time:133841ms step_avg:88.11ms +step:1520/1680 train_time:133931ms step_avg:88.11ms +step:1521/1680 train_time:134019ms step_avg:88.11ms +step:1522/1680 train_time:134107ms step_avg:88.11ms +step:1523/1680 train_time:134195ms step_avg:88.11ms +step:1524/1680 train_time:134285ms step_avg:88.11ms +step:1525/1680 train_time:134375ms step_avg:88.11ms +step:1526/1680 train_time:134466ms step_avg:88.12ms +step:1527/1680 train_time:134555ms step_avg:88.12ms +step:1528/1680 train_time:134644ms step_avg:88.12ms +step:1529/1680 train_time:134734ms step_avg:88.12ms +step:1530/1680 train_time:134822ms step_avg:88.12ms +step:1531/1680 train_time:134910ms step_avg:88.12ms +step:1532/1680 train_time:134998ms step_avg:88.12ms +step:1533/1680 train_time:135087ms step_avg:88.12ms +step:1534/1680 train_time:135175ms step_avg:88.12ms +step:1535/1680 train_time:135264ms step_avg:88.12ms +step:1536/1680 train_time:135355ms step_avg:88.12ms +step:1537/1680 train_time:135445ms step_avg:88.12ms +step:1538/1680 train_time:135535ms step_avg:88.12ms +step:1539/1680 train_time:135624ms step_avg:88.12ms +step:1540/1680 train_time:135713ms step_avg:88.13ms +step:1541/1680 train_time:135802ms step_avg:88.13ms +step:1542/1680 train_time:135890ms step_avg:88.13ms +step:1543/1680 train_time:135979ms step_avg:88.13ms +step:1544/1680 train_time:136068ms step_avg:88.13ms +step:1545/1680 train_time:136157ms step_avg:88.13ms +step:1546/1680 train_time:136246ms step_avg:88.13ms 
+step:1547/1680 train_time:136336ms step_avg:88.13ms +step:1548/1680 train_time:136426ms step_avg:88.13ms +step:1549/1680 train_time:136515ms step_avg:88.13ms +step:1550/1680 train_time:136606ms step_avg:88.13ms +step:1551/1680 train_time:136695ms step_avg:88.13ms +step:1552/1680 train_time:136783ms step_avg:88.13ms +step:1553/1680 train_time:136871ms step_avg:88.13ms +step:1554/1680 train_time:136960ms step_avg:88.13ms +step:1555/1680 train_time:137049ms step_avg:88.13ms +step:1556/1680 train_time:137139ms step_avg:88.14ms +step:1557/1680 train_time:137229ms step_avg:88.14ms +step:1558/1680 train_time:137318ms step_avg:88.14ms +step:1559/1680 train_time:137409ms step_avg:88.14ms +step:1560/1680 train_time:137499ms step_avg:88.14ms +step:1561/1680 train_time:137589ms step_avg:88.14ms +step:1562/1680 train_time:137679ms step_avg:88.14ms +step:1563/1680 train_time:137768ms step_avg:88.14ms +step:1564/1680 train_time:137856ms step_avg:88.14ms +step:1565/1680 train_time:137945ms step_avg:88.14ms +step:1566/1680 train_time:138034ms step_avg:88.14ms +step:1567/1680 train_time:138124ms step_avg:88.15ms +step:1568/1680 train_time:138212ms step_avg:88.15ms +step:1569/1680 train_time:138301ms step_avg:88.15ms +step:1570/1680 train_time:138390ms step_avg:88.15ms +step:1571/1680 train_time:138479ms step_avg:88.15ms +step:1572/1680 train_time:138568ms step_avg:88.15ms +step:1573/1680 train_time:138658ms step_avg:88.15ms +step:1574/1680 train_time:138747ms step_avg:88.15ms +step:1575/1680 train_time:138835ms step_avg:88.15ms +step:1576/1680 train_time:138924ms step_avg:88.15ms +step:1577/1680 train_time:139012ms step_avg:88.15ms +step:1578/1680 train_time:139101ms step_avg:88.15ms +step:1579/1680 train_time:139190ms step_avg:88.15ms +step:1580/1680 train_time:139279ms step_avg:88.15ms +step:1581/1680 train_time:139368ms step_avg:88.15ms +step:1582/1680 train_time:139458ms step_avg:88.15ms +step:1583/1680 train_time:139548ms step_avg:88.15ms +step:1584/1680 train_time:139637ms step_avg:88.15ms +step:1585/1680 train_time:139727ms step_avg:88.16ms +step:1586/1680 train_time:139815ms step_avg:88.16ms +step:1587/1680 train_time:139904ms step_avg:88.16ms +step:1588/1680 train_time:139993ms step_avg:88.16ms +step:1589/1680 train_time:140083ms step_avg:88.16ms +step:1590/1680 train_time:140171ms step_avg:88.16ms +step:1591/1680 train_time:140260ms step_avg:88.16ms +step:1592/1680 train_time:140349ms step_avg:88.16ms +step:1593/1680 train_time:140438ms step_avg:88.16ms +step:1594/1680 train_time:140528ms step_avg:88.16ms +step:1595/1680 train_time:140619ms step_avg:88.16ms +step:1596/1680 train_time:140708ms step_avg:88.16ms +step:1597/1680 train_time:140797ms step_avg:88.16ms +step:1598/1680 train_time:140886ms step_avg:88.16ms +step:1599/1680 train_time:140975ms step_avg:88.16ms +step:1600/1680 train_time:141064ms step_avg:88.17ms +step:1601/1680 train_time:141154ms step_avg:88.17ms +step:1602/1680 train_time:141242ms step_avg:88.17ms +step:1603/1680 train_time:141331ms step_avg:88.17ms +step:1604/1680 train_time:141420ms step_avg:88.17ms +step:1605/1680 train_time:141509ms step_avg:88.17ms +step:1606/1680 train_time:141599ms step_avg:88.17ms +step:1607/1680 train_time:141689ms step_avg:88.17ms +step:1608/1680 train_time:141777ms step_avg:88.17ms +step:1609/1680 train_time:141867ms step_avg:88.17ms +step:1610/1680 train_time:141956ms step_avg:88.17ms +step:1611/1680 train_time:142046ms step_avg:88.17ms +step:1612/1680 train_time:142135ms step_avg:88.17ms +step:1613/1680 train_time:142225ms step_avg:88.17ms 
+step:1614/1680 train_time:142314ms step_avg:88.17ms +step:1615/1680 train_time:142403ms step_avg:88.18ms +step:1616/1680 train_time:142493ms step_avg:88.18ms +step:1617/1680 train_time:142582ms step_avg:88.18ms +step:1618/1680 train_time:142672ms step_avg:88.18ms +step:1619/1680 train_time:142762ms step_avg:88.18ms +step:1620/1680 train_time:142851ms step_avg:88.18ms +step:1621/1680 train_time:142940ms step_avg:88.18ms +step:1622/1680 train_time:143030ms step_avg:88.18ms +step:1623/1680 train_time:143119ms step_avg:88.18ms +step:1624/1680 train_time:143208ms step_avg:88.18ms +step:1625/1680 train_time:143296ms step_avg:88.18ms +step:1625/1680 val_loss:3.2898 train_time:143387ms step_avg:88.24ms +step:1626/1680 train_time:143406ms step_avg:88.20ms +step:1627/1680 train_time:143479ms step_avg:88.19ms +step:1628/1680 train_time:143573ms step_avg:88.19ms +step:1629/1680 train_time:143662ms step_avg:88.19ms +step:1630/1680 train_time:143751ms step_avg:88.19ms +step:1631/1680 train_time:143840ms step_avg:88.19ms +step:1632/1680 train_time:143929ms step_avg:88.19ms +step:1633/1680 train_time:144017ms step_avg:88.19ms +step:1634/1680 train_time:144104ms step_avg:88.19ms +step:1635/1680 train_time:144192ms step_avg:88.19ms +step:1636/1680 train_time:144281ms step_avg:88.19ms +step:1637/1680 train_time:144371ms step_avg:88.19ms +step:1638/1680 train_time:144462ms step_avg:88.19ms +step:1639/1680 train_time:144553ms step_avg:88.20ms +step:1640/1680 train_time:144644ms step_avg:88.20ms +step:1641/1680 train_time:144733ms step_avg:88.20ms +step:1642/1680 train_time:144822ms step_avg:88.20ms +step:1643/1680 train_time:144911ms step_avg:88.20ms +step:1644/1680 train_time:144999ms step_avg:88.20ms +step:1645/1680 train_time:145088ms step_avg:88.20ms +step:1646/1680 train_time:145177ms step_avg:88.20ms +step:1647/1680 train_time:145266ms step_avg:88.20ms +step:1648/1680 train_time:145355ms step_avg:88.20ms +step:1649/1680 train_time:145444ms step_avg:88.20ms +step:1650/1680 train_time:145534ms step_avg:88.20ms +step:1651/1680 train_time:145623ms step_avg:88.20ms +step:1652/1680 train_time:145712ms step_avg:88.20ms +step:1653/1680 train_time:145801ms step_avg:88.20ms +step:1654/1680 train_time:145890ms step_avg:88.20ms +step:1655/1680 train_time:145980ms step_avg:88.21ms +step:1656/1680 train_time:146069ms step_avg:88.21ms +step:1657/1680 train_time:146157ms step_avg:88.21ms +step:1658/1680 train_time:146245ms step_avg:88.21ms +step:1659/1680 train_time:146334ms step_avg:88.21ms +step:1660/1680 train_time:146423ms step_avg:88.21ms +step:1661/1680 train_time:146513ms step_avg:88.21ms +step:1662/1680 train_time:146603ms step_avg:88.21ms +step:1663/1680 train_time:146692ms step_avg:88.21ms +step:1664/1680 train_time:146781ms step_avg:88.21ms +step:1665/1680 train_time:146870ms step_avg:88.21ms +step:1666/1680 train_time:146959ms step_avg:88.21ms +step:1667/1680 train_time:147048ms step_avg:88.21ms +step:1668/1680 train_time:147136ms step_avg:88.21ms +step:1669/1680 train_time:147224ms step_avg:88.21ms +step:1670/1680 train_time:147314ms step_avg:88.21ms +step:1671/1680 train_time:147403ms step_avg:88.21ms +step:1672/1680 train_time:147492ms step_avg:88.21ms +step:1673/1680 train_time:147582ms step_avg:88.21ms +step:1674/1680 train_time:147672ms step_avg:88.21ms +step:1675/1680 train_time:147761ms step_avg:88.22ms +step:1676/1680 train_time:147850ms step_avg:88.22ms +step:1677/1680 train_time:147939ms step_avg:88.22ms +step:1678/1680 train_time:148027ms step_avg:88.22ms +step:1679/1680 train_time:148116ms 
step_avg:88.22ms +step:1680/1680 train_time:148205ms step_avg:88.22ms +step:1680/1680 val_loss:3.2789 train_time:148296ms step_avg:88.27ms +peak memory allocated: 30760 MiB reserved: 46014 MiB diff --git a/records/092725_BF16CE/d89c0dc1-c0ce-4346-a405-af9e88ed79bc.txt b/records/092725_BF16CE/d89c0dc1-c0ce-4346-a405-af9e88ed79bc.txt new file mode 100644 index 000000000..b372cd8ba --- /dev/null +++ b/records/092725_BF16CE/d89c0dc1-c0ce-4346-a405-af9e88ed79bc.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
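+ # C = A @ A.T is symmetric: the early return above skips blocks strictly on one + # side of the diagonal, and each computed tile below is stored twice, once at + # (m, n) and once transposed at (n, m), roughly halving the matmul work.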
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+ Though empirically, small 1D params perform efficiently here: + NS approximately performs a magnitude normalization of the grad, + and this hyper-optimized class executes faster than the current impl of Adam for small params. + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized to enable + (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param, 7 padding params) + 2. reduce scatter attn_gate (10 params, 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params, 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive params of type attn reshape before NS + 8. wait on 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size()==8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for the resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1,10,16,16] + assert len(params)==sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx+size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
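+ # Shape of the algorithm, matching the class docstring: pass 1 launches one async + # reduce_scatter per param group so each rank owns the averaged grads for its chunk; + # pass 2 waits per group, runs one batched Newton-Schulz over the local chunk, and + # launches an async all_gather of the updated params; the final pass waits on the + # gathers and copies the results back into the original parameter tensors.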
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
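+ # Weight decay is applied to this copy first (mul by 1 - lr*wd, i.e. decoupled, + # AdamW-style decay), and the orthogonalized momentum update is added afterwards, + # once the batched Newton-Schulz call below has produced it.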
+ updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + original_shape = batched_update_grads.shape + # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O + module_idx = start_idx if start_idx < len(params) else len(params) - 1 + if params[module_idx].module == 'attn': + batched_update_grads = batched_update_grads.view(-1, original_shape[-2], original_shape[-1] // 4) + v_chunk = newton_schulz_triton(batched_update_grads) + v_chunk = v_chunk.view(original_shape) + + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = 
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module='attn' + with torch.no_grad(): + self.qkvo_w.view(4,self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4,self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4,self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4,self.hdim, self.dim)[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module='mlp' + self.c_proj.module='mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
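+ # e.g. next_multiple_of_n(50257, n=128) == 50304 == 393 * 128, i.e. 47 padding token + # ids that never occur in GPT-2-encoded data; the rounding keeps the lm_head matmul + # dimensions nicely divisible, which is what the efficiency comment above refers to.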
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the 
data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter==5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter+=1 + return starts, ends + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * 
grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world_size * grad_accum_steps" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1640 # number of iterations to run + iteration_extension: int = 40 # number of iterations to continue training at final cooldown and window size + cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 
0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999, step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations + args.iteration_extension: + return args.ws_validate//2, args.ws_validate + x = min(step / (1 + args.num_iterations), 0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx]//2, args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws_long = args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params + if new_ws_long > ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + elif new_ws_long < ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) + +# re-initialize the state saved above so the warmup steps aren't cheating +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del initial_state +model.yarn.reset() +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) + +train_steps = args.num_iterations + args.iteration_extension +ws_short, ws_long = get_ws(0) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +for step in range(train_steps + 1): + last_step = (step == train_steps) + new_ws_short, new_ws_long = get_ws(step) + if new_ws_long != ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_short, ws_long = new_ws_short, new_ws_long + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + if last_step: + ws_long = args.ws_long_validate + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = torch.zeros((), device=device, dtype=torch.float32) + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): 
inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Sep 27 12:42:08 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 30C P0 123W / 700W | 5856MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 27C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 24C P0 116W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 29C P0 122W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 30C P0 123W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 28C P0 115W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 30C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 27C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 161605 C /usr/bin/python 0MiB | +| 0 N/A N/A 161606 C /usr/bin/python 0MiB | +| 0 N/A N/A 161607 C /usr/bin/python 0MiB | +| 0 N/A N/A 161608 C /usr/bin/python 0MiB | +| 0 N/A N/A 161609 C /usr/bin/python 0MiB | +| 0 N/A N/A 161610 C /usr/bin/python 0MiB | +| 0 N/A N/A 161611 C /usr/bin/python 0MiB | +| 0 N/A N/A 161612 C /usr/bin/python 0MiB | +| 1 N/A N/A 161606 C /usr/bin/python 0MiB | +| 2 N/A N/A 161607 C /usr/bin/python 0MiB | +| 3 N/A N/A 161608 C /usr/bin/python 0MiB | +| 4 N/A N/A 161609 C /usr/bin/python 0MiB | +| 5 N/A N/A 161610 C /usr/bin/python 0MiB | +| 6 N/A N/A 161611 C /usr/bin/python 0MiB | +| 7 N/A N/A 161612 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1680 train_time:152ms step_avg:151.82ms +step:2/1680 train_time:171ms step_avg:85.56ms +step:3/1680 train_time:235ms step_avg:78.39ms +step:4/1680 train_time:320ms step_avg:80.03ms +step:5/1680 train_time:406ms step_avg:81.14ms +step:6/1680 train_time:492ms step_avg:81.95ms +step:7/1680 train_time:578ms step_avg:82.58ms +step:8/1680 train_time:664ms step_avg:83.05ms +step:9/1680 train_time:751ms step_avg:83.41ms +step:10/1680 train_time:837ms step_avg:83.72ms +step:11/1680 train_time:924ms step_avg:83.96ms +step:12/1680 train_time:1013ms step_avg:84.39ms +step:13/1680 train_time:1106ms step_avg:85.07ms +step:14/1680 train_time:1195ms step_avg:85.34ms +step:15/1680 train_time:1282ms step_avg:85.47ms +step:16/1680 train_time:1369ms step_avg:85.56ms +step:17/1680 train_time:1456ms step_avg:85.66ms +step:18/1680 train_time:1543ms step_avg:85.70ms +step:19/1680 train_time:1630ms step_avg:85.76ms +step:20/1680 train_time:1717ms step_avg:85.83ms +step:21/1680 train_time:1804ms step_avg:85.88ms +step:22/1680 train_time:1890ms step_avg:85.91ms +step:23/1680 train_time:1978ms step_avg:86.01ms +step:24/1680 train_time:2067ms step_avg:86.14ms +step:25/1680 train_time:2156ms step_avg:86.24ms +step:26/1680 train_time:2245ms step_avg:86.34ms +step:27/1680 train_time:2333ms step_avg:86.40ms +step:28/1680 train_time:2420ms step_avg:86.43ms +step:29/1680 train_time:2507ms step_avg:86.46ms +step:30/1680 train_time:2594ms step_avg:86.48ms +step:31/1680 train_time:2681ms step_avg:86.47ms +step:32/1680 train_time:2767ms step_avg:86.48ms +step:33/1680 train_time:2854ms step_avg:86.48ms +step:34/1680 train_time:2941ms step_avg:86.50ms +step:35/1680 train_time:3029ms step_avg:86.54ms +step:36/1680 train_time:3117ms step_avg:86.59ms +step:37/1680 train_time:3207ms step_avg:86.66ms +step:38/1680 train_time:3294ms step_avg:86.70ms +step:39/1680 train_time:3382ms step_avg:86.73ms +step:40/1680 train_time:3470ms step_avg:86.75ms +step:41/1680 train_time:3557ms step_avg:86.76ms +step:42/1680 train_time:3645ms step_avg:86.79ms +step:43/1680 train_time:3732ms step_avg:86.79ms +step:44/1680 train_time:3820ms step_avg:86.82ms +step:45/1680 train_time:3907ms step_avg:86.83ms +step:46/1680 train_time:3995ms step_avg:86.85ms +step:47/1680 
train_time:4083ms step_avg:86.86ms +step:48/1680 train_time:4170ms step_avg:86.88ms +step:49/1680 train_time:4258ms step_avg:86.91ms +step:50/1680 train_time:4346ms step_avg:86.92ms +step:51/1680 train_time:4434ms step_avg:86.94ms +step:52/1680 train_time:4521ms step_avg:86.94ms +step:53/1680 train_time:4608ms step_avg:86.95ms +step:54/1680 train_time:4696ms step_avg:86.96ms +step:55/1680 train_time:4783ms step_avg:86.96ms +step:56/1680 train_time:4870ms step_avg:86.96ms +step:57/1680 train_time:4957ms step_avg:86.97ms +step:58/1680 train_time:5045ms step_avg:86.99ms +step:59/1680 train_time:5133ms step_avg:87.01ms +step:60/1680 train_time:5221ms step_avg:87.02ms +step:61/1680 train_time:5308ms step_avg:87.02ms +step:62/1680 train_time:5396ms step_avg:87.03ms +step:63/1680 train_time:5483ms step_avg:87.03ms +step:64/1680 train_time:5570ms step_avg:87.03ms +step:65/1680 train_time:5658ms step_avg:87.04ms +step:66/1680 train_time:5745ms step_avg:87.05ms +step:67/1680 train_time:5832ms step_avg:87.04ms +step:68/1680 train_time:5919ms step_avg:87.05ms +step:69/1680 train_time:6007ms step_avg:87.06ms +step:70/1680 train_time:6094ms step_avg:87.05ms +step:71/1680 train_time:6181ms step_avg:87.06ms +step:72/1680 train_time:6269ms step_avg:87.06ms +step:73/1680 train_time:6356ms step_avg:87.07ms +step:74/1680 train_time:6444ms step_avg:87.08ms +step:75/1680 train_time:6532ms step_avg:87.09ms +step:76/1680 train_time:6619ms step_avg:87.10ms +step:77/1680 train_time:6706ms step_avg:87.09ms +step:78/1680 train_time:6793ms step_avg:87.09ms +step:79/1680 train_time:6880ms step_avg:87.09ms +step:80/1680 train_time:6968ms step_avg:87.09ms +step:81/1680 train_time:7054ms step_avg:87.09ms +step:82/1680 train_time:7142ms step_avg:87.09ms +step:83/1680 train_time:7229ms step_avg:87.09ms +step:84/1680 train_time:7317ms step_avg:87.10ms +step:85/1680 train_time:7405ms step_avg:87.11ms +step:86/1680 train_time:7492ms step_avg:87.12ms +step:87/1680 train_time:7579ms step_avg:87.12ms +step:88/1680 train_time:7667ms step_avg:87.12ms +step:89/1680 train_time:7754ms step_avg:87.13ms +step:90/1680 train_time:7841ms step_avg:87.12ms +step:91/1680 train_time:7928ms step_avg:87.12ms +step:92/1680 train_time:8016ms step_avg:87.13ms +step:93/1680 train_time:8104ms step_avg:87.14ms +step:94/1680 train_time:8191ms step_avg:87.14ms +step:95/1680 train_time:8279ms step_avg:87.15ms +step:96/1680 train_time:8367ms step_avg:87.15ms +step:97/1680 train_time:8453ms step_avg:87.15ms +step:98/1680 train_time:8541ms step_avg:87.15ms +step:99/1680 train_time:8628ms step_avg:87.15ms +step:100/1680 train_time:8715ms step_avg:87.15ms +step:101/1680 train_time:8801ms step_avg:87.14ms +step:102/1680 train_time:8888ms step_avg:87.14ms +step:103/1680 train_time:8976ms step_avg:87.14ms +step:104/1680 train_time:9063ms step_avg:87.15ms +step:105/1680 train_time:9151ms step_avg:87.15ms +step:106/1680 train_time:9239ms step_avg:87.16ms +step:107/1680 train_time:9327ms step_avg:87.17ms +step:108/1680 train_time:9414ms step_avg:87.17ms +step:109/1680 train_time:9502ms step_avg:87.17ms +step:110/1680 train_time:9589ms step_avg:87.17ms +step:111/1680 train_time:9677ms step_avg:87.18ms +step:112/1680 train_time:9765ms step_avg:87.19ms +step:113/1680 train_time:9851ms step_avg:87.18ms +step:114/1680 train_time:9939ms step_avg:87.18ms +step:115/1680 train_time:10026ms step_avg:87.19ms +step:116/1680 train_time:10114ms step_avg:87.19ms +step:117/1680 train_time:10201ms step_avg:87.19ms +step:118/1680 train_time:10288ms step_avg:87.19ms +step:119/1680 
train_time:10376ms step_avg:87.19ms +step:120/1680 train_time:10463ms step_avg:87.19ms +step:121/1680 train_time:10550ms step_avg:87.19ms +step:122/1680 train_time:10637ms step_avg:87.19ms +step:123/1680 train_time:10724ms step_avg:87.19ms +step:124/1680 train_time:10811ms step_avg:87.19ms +step:125/1680 train_time:10899ms step_avg:87.19ms +step:125/1680 val_loss:4.3047 train_time:10987ms step_avg:87.90ms +step:126/1680 train_time:11006ms step_avg:87.35ms +step:127/1680 train_time:11076ms step_avg:87.22ms +step:128/1680 train_time:11171ms step_avg:87.27ms +step:129/1680 train_time:11263ms step_avg:87.31ms +step:130/1680 train_time:11349ms step_avg:87.30ms +step:131/1680 train_time:11436ms step_avg:87.30ms +step:132/1680 train_time:11522ms step_avg:87.29ms +step:133/1680 train_time:11608ms step_avg:87.28ms +step:134/1680 train_time:11693ms step_avg:87.26ms +step:135/1680 train_time:11780ms step_avg:87.26ms +step:136/1680 train_time:11866ms step_avg:87.25ms +step:137/1680 train_time:11953ms step_avg:87.25ms +step:138/1680 train_time:12041ms step_avg:87.25ms +step:139/1680 train_time:12130ms step_avg:87.27ms +step:140/1680 train_time:12219ms step_avg:87.28ms +step:141/1680 train_time:12308ms step_avg:87.29ms +step:142/1680 train_time:12395ms step_avg:87.29ms +step:143/1680 train_time:12482ms step_avg:87.29ms +step:144/1680 train_time:12569ms step_avg:87.28ms +step:145/1680 train_time:12656ms step_avg:87.28ms +step:146/1680 train_time:12743ms step_avg:87.28ms +step:147/1680 train_time:12829ms step_avg:87.27ms +step:148/1680 train_time:12915ms step_avg:87.26ms +step:149/1680 train_time:13002ms step_avg:87.26ms +step:150/1680 train_time:13090ms step_avg:87.27ms +step:151/1680 train_time:13179ms step_avg:87.28ms +step:152/1680 train_time:13267ms step_avg:87.28ms +step:153/1680 train_time:13355ms step_avg:87.29ms +step:154/1680 train_time:13442ms step_avg:87.29ms +step:155/1680 train_time:13530ms step_avg:87.29ms +step:156/1680 train_time:13617ms step_avg:87.29ms +step:157/1680 train_time:13704ms step_avg:87.29ms +step:158/1680 train_time:13791ms step_avg:87.29ms +step:159/1680 train_time:13877ms step_avg:87.28ms +step:160/1680 train_time:13964ms step_avg:87.28ms +step:161/1680 train_time:14051ms step_avg:87.27ms +step:162/1680 train_time:14139ms step_avg:87.28ms +step:163/1680 train_time:14226ms step_avg:87.28ms +step:164/1680 train_time:14314ms step_avg:87.28ms +step:165/1680 train_time:14401ms step_avg:87.28ms +step:166/1680 train_time:14488ms step_avg:87.28ms +step:167/1680 train_time:14576ms step_avg:87.28ms +step:168/1680 train_time:14664ms step_avg:87.28ms +step:169/1680 train_time:14750ms step_avg:87.28ms +step:170/1680 train_time:14837ms step_avg:87.28ms +step:171/1680 train_time:14923ms step_avg:87.27ms +step:172/1680 train_time:15011ms step_avg:87.27ms +step:173/1680 train_time:15099ms step_avg:87.28ms +step:174/1680 train_time:15186ms step_avg:87.28ms +step:175/1680 train_time:15273ms step_avg:87.28ms +step:176/1680 train_time:15362ms step_avg:87.28ms +step:177/1680 train_time:15449ms step_avg:87.28ms +step:178/1680 train_time:15536ms step_avg:87.28ms +step:179/1680 train_time:15623ms step_avg:87.28ms +step:180/1680 train_time:15710ms step_avg:87.28ms +step:181/1680 train_time:15796ms step_avg:87.27ms +step:182/1680 train_time:15883ms step_avg:87.27ms +step:183/1680 train_time:15970ms step_avg:87.27ms +step:184/1680 train_time:16058ms step_avg:87.27ms +step:185/1680 train_time:16146ms step_avg:87.27ms +step:186/1680 train_time:16233ms step_avg:87.28ms +step:187/1680 train_time:16321ms 
step_avg:87.28ms +step:188/1680 train_time:16408ms step_avg:87.28ms +step:189/1680 train_time:16495ms step_avg:87.28ms +step:190/1680 train_time:16582ms step_avg:87.28ms +step:191/1680 train_time:16669ms step_avg:87.27ms +step:192/1680 train_time:16756ms step_avg:87.27ms +step:193/1680 train_time:16843ms step_avg:87.27ms +step:194/1680 train_time:16930ms step_avg:87.27ms +step:195/1680 train_time:17017ms step_avg:87.27ms +step:196/1680 train_time:17104ms step_avg:87.27ms +step:197/1680 train_time:17191ms step_avg:87.27ms +step:198/1680 train_time:17279ms step_avg:87.27ms +step:199/1680 train_time:17366ms step_avg:87.27ms +step:200/1680 train_time:17454ms step_avg:87.27ms +step:201/1680 train_time:17541ms step_avg:87.27ms +step:202/1680 train_time:17629ms step_avg:87.27ms +step:203/1680 train_time:17716ms step_avg:87.27ms +step:204/1680 train_time:17803ms step_avg:87.27ms +step:205/1680 train_time:17890ms step_avg:87.27ms +step:206/1680 train_time:17977ms step_avg:87.27ms +step:207/1680 train_time:18064ms step_avg:87.27ms +step:208/1680 train_time:18151ms step_avg:87.27ms +step:209/1680 train_time:18238ms step_avg:87.27ms +step:210/1680 train_time:18326ms step_avg:87.26ms +step:211/1680 train_time:18414ms step_avg:87.27ms +step:212/1680 train_time:18501ms step_avg:87.27ms +step:213/1680 train_time:18589ms step_avg:87.27ms +step:214/1680 train_time:18676ms step_avg:87.27ms +step:215/1680 train_time:18763ms step_avg:87.27ms +step:216/1680 train_time:18850ms step_avg:87.27ms +step:217/1680 train_time:18938ms step_avg:87.27ms +step:218/1680 train_time:19024ms step_avg:87.27ms +step:219/1680 train_time:19112ms step_avg:87.27ms +step:220/1680 train_time:19200ms step_avg:87.27ms +step:221/1680 train_time:19287ms step_avg:87.27ms +step:222/1680 train_time:19374ms step_avg:87.27ms +step:223/1680 train_time:19462ms step_avg:87.27ms +step:224/1680 train_time:19549ms step_avg:87.27ms +step:225/1680 train_time:19636ms step_avg:87.27ms +step:226/1680 train_time:19722ms step_avg:87.27ms +step:227/1680 train_time:19809ms step_avg:87.27ms +step:228/1680 train_time:19897ms step_avg:87.27ms +step:229/1680 train_time:19984ms step_avg:87.27ms +step:230/1680 train_time:20071ms step_avg:87.27ms +step:231/1680 train_time:20159ms step_avg:87.27ms +step:232/1680 train_time:20246ms step_avg:87.27ms +step:233/1680 train_time:20333ms step_avg:87.27ms +step:234/1680 train_time:20421ms step_avg:87.27ms +step:235/1680 train_time:20509ms step_avg:87.27ms +step:236/1680 train_time:20596ms step_avg:87.27ms +step:237/1680 train_time:20682ms step_avg:87.27ms +step:238/1680 train_time:20770ms step_avg:87.27ms +step:239/1680 train_time:20857ms step_avg:87.27ms +step:240/1680 train_time:20944ms step_avg:87.27ms +step:241/1680 train_time:21032ms step_avg:87.27ms +step:242/1680 train_time:21119ms step_avg:87.27ms +step:243/1680 train_time:21206ms step_avg:87.27ms +step:244/1680 train_time:21293ms step_avg:87.27ms +step:245/1680 train_time:21382ms step_avg:87.27ms +step:246/1680 train_time:21469ms step_avg:87.27ms +step:247/1680 train_time:21557ms step_avg:87.27ms +step:248/1680 train_time:21643ms step_avg:87.27ms +step:249/1680 train_time:21731ms step_avg:87.27ms +step:250/1680 train_time:21818ms step_avg:87.27ms +step:250/1680 val_loss:3.9650 train_time:21906ms step_avg:87.62ms +step:251/1680 train_time:21924ms step_avg:87.35ms +step:252/1680 train_time:21995ms step_avg:87.28ms +step:253/1680 train_time:22087ms step_avg:87.30ms +step:254/1680 train_time:22175ms step_avg:87.30ms +step:255/1680 train_time:22262ms step_avg:87.30ms 
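
The schedule updates in the training loop at the top of this record are small enough to restate. A minimal sketch follows; the Muon momentum ramp (0.85 to 0.95 over the first 300 steps) is copied verbatim from the loop, while get_lr is only assumed here to be the constant-then-linear-cooldown shape typical of these speedrun scripts (its actual definition appears earlier in the file):

def get_lr_sketch(step: int, train_steps: int = 1680, cooldown_frac: float = 0.4) -> float:
    # ASSUMED shape: flat at 1.0, then linear decay over the final cooldown fraction
    x = step / train_steps
    return 1.0 if x < 1 - cooldown_frac else (1 - x) / cooldown_frac

def muon_momentum(step: int) -> float:
    # momentum warmup for Muon, exactly as in the loop above
    frac = min(step / 300, 1)
    return (1 - frac) * 0.85 + frac * 0.95

assert muon_momentum(0) == 0.85 and muon_momentum(300) == 0.95
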
+step:256/1680 train_time:22349ms step_avg:87.30ms +step:257/1680 train_time:22435ms step_avg:87.29ms +step:258/1680 train_time:22520ms step_avg:87.29ms +step:259/1680 train_time:22606ms step_avg:87.28ms +step:260/1680 train_time:22693ms step_avg:87.28ms +step:261/1680 train_time:22778ms step_avg:87.27ms +step:262/1680 train_time:22866ms step_avg:87.27ms +step:263/1680 train_time:22956ms step_avg:87.28ms +step:264/1680 train_time:23045ms step_avg:87.29ms +step:265/1680 train_time:23134ms step_avg:87.30ms +step:266/1680 train_time:23221ms step_avg:87.30ms +step:267/1680 train_time:23309ms step_avg:87.30ms +step:268/1680 train_time:23396ms step_avg:87.30ms +step:269/1680 train_time:23482ms step_avg:87.29ms +step:270/1680 train_time:23569ms step_avg:87.29ms +step:271/1680 train_time:23655ms step_avg:87.29ms +step:272/1680 train_time:23741ms step_avg:87.28ms +step:273/1680 train_time:23829ms step_avg:87.28ms +step:274/1680 train_time:23917ms step_avg:87.29ms +step:275/1680 train_time:24005ms step_avg:87.29ms +step:276/1680 train_time:24095ms step_avg:87.30ms +step:277/1680 train_time:24182ms step_avg:87.30ms +step:278/1680 train_time:24269ms step_avg:87.30ms +step:279/1680 train_time:24357ms step_avg:87.30ms +step:280/1680 train_time:24443ms step_avg:87.30ms +step:281/1680 train_time:24530ms step_avg:87.30ms +step:282/1680 train_time:24616ms step_avg:87.29ms +step:283/1680 train_time:24703ms step_avg:87.29ms +step:284/1680 train_time:24790ms step_avg:87.29ms +step:285/1680 train_time:24877ms step_avg:87.29ms +step:286/1680 train_time:24965ms step_avg:87.29ms +step:287/1680 train_time:25055ms step_avg:87.30ms +step:288/1680 train_time:25142ms step_avg:87.30ms +step:289/1680 train_time:25230ms step_avg:87.30ms +step:290/1680 train_time:25317ms step_avg:87.30ms +step:291/1680 train_time:25404ms step_avg:87.30ms +step:292/1680 train_time:25491ms step_avg:87.30ms +step:293/1680 train_time:25577ms step_avg:87.29ms +step:294/1680 train_time:25664ms step_avg:87.29ms +step:295/1680 train_time:25750ms step_avg:87.29ms +step:296/1680 train_time:25838ms step_avg:87.29ms +step:297/1680 train_time:25925ms step_avg:87.29ms +step:298/1680 train_time:26014ms step_avg:87.29ms +step:299/1680 train_time:26101ms step_avg:87.29ms +step:300/1680 train_time:26189ms step_avg:87.30ms +step:301/1680 train_time:26277ms step_avg:87.30ms +step:302/1680 train_time:26364ms step_avg:87.30ms +step:303/1680 train_time:26452ms step_avg:87.30ms +step:304/1680 train_time:26538ms step_avg:87.30ms +step:305/1680 train_time:26625ms step_avg:87.29ms +step:306/1680 train_time:26712ms step_avg:87.29ms +step:307/1680 train_time:26799ms step_avg:87.29ms +step:308/1680 train_time:26887ms step_avg:87.29ms +step:309/1680 train_time:26974ms step_avg:87.30ms +step:310/1680 train_time:27061ms step_avg:87.29ms +step:311/1680 train_time:27150ms step_avg:87.30ms +step:312/1680 train_time:27238ms step_avg:87.30ms +step:313/1680 train_time:27325ms step_avg:87.30ms +step:314/1680 train_time:27413ms step_avg:87.30ms +step:315/1680 train_time:27500ms step_avg:87.30ms +step:316/1680 train_time:27587ms step_avg:87.30ms +step:317/1680 train_time:27674ms step_avg:87.30ms +step:318/1680 train_time:27760ms step_avg:87.30ms +step:319/1680 train_time:27847ms step_avg:87.30ms +step:320/1680 train_time:27935ms step_avg:87.30ms +step:321/1680 train_time:28021ms step_avg:87.29ms +step:322/1680 train_time:28109ms step_avg:87.30ms +step:323/1680 train_time:28197ms step_avg:87.30ms +step:324/1680 train_time:28285ms step_avg:87.30ms +step:325/1680 train_time:28373ms 
step_avg:87.30ms +step:326/1680 train_time:28460ms step_avg:87.30ms +step:327/1680 train_time:28547ms step_avg:87.30ms +step:328/1680 train_time:28634ms step_avg:87.30ms +step:329/1680 train_time:28721ms step_avg:87.30ms +step:330/1680 train_time:28808ms step_avg:87.30ms +step:331/1680 train_time:28895ms step_avg:87.30ms +step:332/1680 train_time:28983ms step_avg:87.30ms +step:333/1680 train_time:29069ms step_avg:87.30ms +step:334/1680 train_time:29158ms step_avg:87.30ms +step:335/1680 train_time:29245ms step_avg:87.30ms +step:336/1680 train_time:29333ms step_avg:87.30ms +step:337/1680 train_time:29420ms step_avg:87.30ms +step:338/1680 train_time:29507ms step_avg:87.30ms +step:339/1680 train_time:29594ms step_avg:87.30ms +step:340/1680 train_time:29681ms step_avg:87.30ms +step:341/1680 train_time:29769ms step_avg:87.30ms +step:342/1680 train_time:29855ms step_avg:87.30ms +step:343/1680 train_time:29943ms step_avg:87.30ms +step:344/1680 train_time:30030ms step_avg:87.30ms +step:345/1680 train_time:30118ms step_avg:87.30ms +step:346/1680 train_time:30205ms step_avg:87.30ms +step:347/1680 train_time:30292ms step_avg:87.30ms +step:348/1680 train_time:30379ms step_avg:87.30ms +step:349/1680 train_time:30467ms step_avg:87.30ms +step:350/1680 train_time:30554ms step_avg:87.30ms +step:351/1680 train_time:30641ms step_avg:87.30ms +step:352/1680 train_time:30729ms step_avg:87.30ms +step:353/1680 train_time:30816ms step_avg:87.30ms +step:354/1680 train_time:30903ms step_avg:87.30ms +step:355/1680 train_time:30991ms step_avg:87.30ms +step:356/1680 train_time:31078ms step_avg:87.30ms +step:357/1680 train_time:31165ms step_avg:87.30ms +step:358/1680 train_time:31252ms step_avg:87.30ms +step:359/1680 train_time:31339ms step_avg:87.29ms +step:360/1680 train_time:31426ms step_avg:87.29ms +step:361/1680 train_time:31513ms step_avg:87.29ms +step:362/1680 train_time:31600ms step_avg:87.29ms +step:363/1680 train_time:31688ms step_avg:87.29ms +step:364/1680 train_time:31775ms step_avg:87.29ms +step:365/1680 train_time:31862ms step_avg:87.29ms +step:366/1680 train_time:31950ms step_avg:87.29ms +step:367/1680 train_time:32037ms step_avg:87.30ms +step:368/1680 train_time:32124ms step_avg:87.29ms +step:369/1680 train_time:32212ms step_avg:87.30ms +step:370/1680 train_time:32299ms step_avg:87.30ms +step:371/1680 train_time:32387ms step_avg:87.30ms +step:372/1680 train_time:32474ms step_avg:87.30ms +step:373/1680 train_time:32562ms step_avg:87.30ms +step:374/1680 train_time:32649ms step_avg:87.30ms +step:375/1680 train_time:32737ms step_avg:87.30ms +step:375/1680 val_loss:3.8161 train_time:32825ms step_avg:87.53ms +step:376/1680 train_time:32846ms step_avg:87.36ms +step:377/1680 train_time:32915ms step_avg:87.31ms +step:378/1680 train_time:33007ms step_avg:87.32ms +step:379/1680 train_time:33097ms step_avg:87.33ms +step:380/1680 train_time:33183ms step_avg:87.32ms +step:381/1680 train_time:33270ms step_avg:87.32ms +step:382/1680 train_time:33356ms step_avg:87.32ms +step:383/1680 train_time:33442ms step_avg:87.32ms +step:384/1680 train_time:33528ms step_avg:87.31ms +step:385/1680 train_time:33615ms step_avg:87.31ms +step:386/1680 train_time:33701ms step_avg:87.31ms +step:387/1680 train_time:33788ms step_avg:87.31ms +step:388/1680 train_time:33877ms step_avg:87.31ms +step:389/1680 train_time:33967ms step_avg:87.32ms +step:390/1680 train_time:34055ms step_avg:87.32ms +step:391/1680 train_time:34143ms step_avg:87.32ms +step:392/1680 train_time:34230ms step_avg:87.32ms +step:393/1680 train_time:34317ms step_avg:87.32ms 
+step:394/1680 train_time:34403ms step_avg:87.32ms +step:395/1680 train_time:34490ms step_avg:87.32ms +step:396/1680 train_time:34577ms step_avg:87.32ms +step:397/1680 train_time:34664ms step_avg:87.31ms +step:398/1680 train_time:34750ms step_avg:87.31ms +step:399/1680 train_time:34838ms step_avg:87.31ms +step:400/1680 train_time:34927ms step_avg:87.32ms +step:401/1680 train_time:35015ms step_avg:87.32ms +step:402/1680 train_time:35103ms step_avg:87.32ms +step:403/1680 train_time:35190ms step_avg:87.32ms +step:404/1680 train_time:35277ms step_avg:87.32ms +step:405/1680 train_time:35365ms step_avg:87.32ms +step:406/1680 train_time:35451ms step_avg:87.32ms +step:407/1680 train_time:35538ms step_avg:87.32ms +step:408/1680 train_time:35624ms step_avg:87.31ms +step:409/1680 train_time:35712ms step_avg:87.31ms +step:410/1680 train_time:35800ms step_avg:87.32ms +step:411/1680 train_time:35888ms step_avg:87.32ms +step:412/1680 train_time:35976ms step_avg:87.32ms +step:413/1680 train_time:36063ms step_avg:87.32ms +step:414/1680 train_time:36151ms step_avg:87.32ms +step:415/1680 train_time:36238ms step_avg:87.32ms +step:416/1680 train_time:36325ms step_avg:87.32ms +step:417/1680 train_time:36412ms step_avg:87.32ms +step:418/1680 train_time:36498ms step_avg:87.32ms +step:419/1680 train_time:36585ms step_avg:87.31ms +step:420/1680 train_time:36672ms step_avg:87.31ms +step:421/1680 train_time:36759ms step_avg:87.31ms +step:422/1680 train_time:36846ms step_avg:87.31ms +step:423/1680 train_time:36933ms step_avg:87.31ms +step:424/1680 train_time:37021ms step_avg:87.31ms +step:425/1680 train_time:37109ms step_avg:87.32ms +step:426/1680 train_time:37197ms step_avg:87.32ms +step:427/1680 train_time:37284ms step_avg:87.32ms +step:428/1680 train_time:37372ms step_avg:87.32ms +step:429/1680 train_time:37458ms step_avg:87.31ms +step:430/1680 train_time:37545ms step_avg:87.31ms +step:431/1680 train_time:37632ms step_avg:87.31ms +step:432/1680 train_time:37720ms step_avg:87.31ms +step:433/1680 train_time:37807ms step_avg:87.31ms +step:434/1680 train_time:37894ms step_avg:87.31ms +step:435/1680 train_time:37982ms step_avg:87.31ms +step:436/1680 train_time:38069ms step_avg:87.32ms +step:437/1680 train_time:38157ms step_avg:87.32ms +step:438/1680 train_time:38244ms step_avg:87.32ms +step:439/1680 train_time:38332ms step_avg:87.32ms +step:440/1680 train_time:38420ms step_avg:87.32ms +step:441/1680 train_time:38506ms step_avg:87.32ms +step:442/1680 train_time:38593ms step_avg:87.31ms +step:443/1680 train_time:38680ms step_avg:87.31ms +step:444/1680 train_time:38767ms step_avg:87.31ms +step:445/1680 train_time:38855ms step_avg:87.31ms +step:446/1680 train_time:38942ms step_avg:87.31ms +step:447/1680 train_time:39029ms step_avg:87.31ms +step:448/1680 train_time:39117ms step_avg:87.31ms +step:449/1680 train_time:39204ms step_avg:87.31ms +step:450/1680 train_time:39292ms step_avg:87.32ms +step:451/1680 train_time:39380ms step_avg:87.32ms +step:452/1680 train_time:39467ms step_avg:87.32ms +step:453/1680 train_time:39554ms step_avg:87.32ms +step:454/1680 train_time:39641ms step_avg:87.31ms +step:455/1680 train_time:39728ms step_avg:87.32ms +step:456/1680 train_time:39816ms step_avg:87.32ms +step:457/1680 train_time:39902ms step_avg:87.31ms +step:458/1680 train_time:39990ms step_avg:87.31ms +step:459/1680 train_time:40077ms step_avg:87.31ms +step:460/1680 train_time:40165ms step_avg:87.31ms +step:461/1680 train_time:40253ms step_avg:87.32ms +step:462/1680 train_time:40340ms step_avg:87.32ms +step:463/1680 train_time:40428ms 
step_avg:87.32ms +step:464/1680 train_time:40515ms step_avg:87.32ms +step:465/1680 train_time:40602ms step_avg:87.32ms +step:466/1680 train_time:40689ms step_avg:87.32ms +step:467/1680 train_time:40776ms step_avg:87.32ms +step:468/1680 train_time:40863ms step_avg:87.31ms +step:469/1680 train_time:40951ms step_avg:87.32ms +step:470/1680 train_time:41038ms step_avg:87.31ms +step:471/1680 train_time:41125ms step_avg:87.31ms +step:472/1680 train_time:41213ms step_avg:87.32ms +step:473/1680 train_time:41300ms step_avg:87.31ms +step:474/1680 train_time:41388ms step_avg:87.32ms +step:475/1680 train_time:41475ms step_avg:87.32ms +step:476/1680 train_time:41562ms step_avg:87.31ms +step:477/1680 train_time:41649ms step_avg:87.31ms +step:478/1680 train_time:41736ms step_avg:87.31ms +step:479/1680 train_time:41823ms step_avg:87.31ms +step:480/1680 train_time:41911ms step_avg:87.31ms +step:481/1680 train_time:41998ms step_avg:87.31ms +step:482/1680 train_time:42085ms step_avg:87.31ms +step:483/1680 train_time:42173ms step_avg:87.31ms +step:484/1680 train_time:42259ms step_avg:87.31ms +step:485/1680 train_time:42347ms step_avg:87.31ms +step:486/1680 train_time:42434ms step_avg:87.31ms +step:487/1680 train_time:42521ms step_avg:87.31ms +step:488/1680 train_time:42608ms step_avg:87.31ms +step:489/1680 train_time:42695ms step_avg:87.31ms +step:490/1680 train_time:42782ms step_avg:87.31ms +step:491/1680 train_time:42869ms step_avg:87.31ms +step:492/1680 train_time:42957ms step_avg:87.31ms +step:493/1680 train_time:43044ms step_avg:87.31ms +step:494/1680 train_time:43132ms step_avg:87.31ms +step:495/1680 train_time:43220ms step_avg:87.31ms +step:496/1680 train_time:43307ms step_avg:87.31ms +step:497/1680 train_time:43394ms step_avg:87.31ms +step:498/1680 train_time:43481ms step_avg:87.31ms +step:499/1680 train_time:43568ms step_avg:87.31ms +step:500/1680 train_time:43655ms step_avg:87.31ms +step:500/1680 val_loss:3.7146 train_time:43743ms step_avg:87.49ms +step:501/1680 train_time:43762ms step_avg:87.35ms +step:502/1680 train_time:43834ms step_avg:87.32ms +step:503/1680 train_time:43925ms step_avg:87.33ms +step:504/1680 train_time:44013ms step_avg:87.33ms +step:505/1680 train_time:44099ms step_avg:87.33ms +step:506/1680 train_time:44186ms step_avg:87.32ms +step:507/1680 train_time:44272ms step_avg:87.32ms +step:508/1680 train_time:44358ms step_avg:87.32ms +step:509/1680 train_time:44444ms step_avg:87.32ms +step:510/1680 train_time:44531ms step_avg:87.32ms +step:511/1680 train_time:44617ms step_avg:87.31ms +step:512/1680 train_time:44704ms step_avg:87.31ms +step:513/1680 train_time:44794ms step_avg:87.32ms +step:514/1680 train_time:44883ms step_avg:87.32ms +step:515/1680 train_time:44972ms step_avg:87.32ms +step:516/1680 train_time:45059ms step_avg:87.32ms +step:517/1680 train_time:45146ms step_avg:87.32ms +step:518/1680 train_time:45233ms step_avg:87.32ms +step:519/1680 train_time:45319ms step_avg:87.32ms +step:520/1680 train_time:45406ms step_avg:87.32ms +step:521/1680 train_time:45493ms step_avg:87.32ms +step:522/1680 train_time:45579ms step_avg:87.32ms +step:523/1680 train_time:45666ms step_avg:87.32ms +step:524/1680 train_time:45754ms step_avg:87.32ms +step:525/1680 train_time:45843ms step_avg:87.32ms +step:526/1680 train_time:45931ms step_avg:87.32ms +step:527/1680 train_time:46018ms step_avg:87.32ms +step:528/1680 train_time:46106ms step_avg:87.32ms +step:529/1680 train_time:46193ms step_avg:87.32ms +step:530/1680 train_time:46280ms step_avg:87.32ms +step:531/1680 train_time:46367ms step_avg:87.32ms 
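
A note on reading these lines: the loop prints step+1 as the step index and divides the accumulated wall-clock time by that same step+1, so step_avg is exactly train_time over the printed step number. A quick self-check against one of the lines above:

import re

line = "step:500/1680 train_time:43655ms step_avg:87.31ms"
m = re.match(r"step:(\d+)/\d+ train_time:(\d+)ms step_avg:([\d.]+)ms", line)
step, train_ms, step_avg = int(m[1]), int(m[2]), float(m[3])
assert round(train_ms / step, 2) == step_avg  # 43655 / 500 == 87.31
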
+step:532/1680 train_time:46454ms step_avg:87.32ms +step:533/1680 train_time:46540ms step_avg:87.32ms +step:534/1680 train_time:46626ms step_avg:87.31ms +step:535/1680 train_time:46713ms step_avg:87.31ms +step:536/1680 train_time:46801ms step_avg:87.32ms +step:537/1680 train_time:46890ms step_avg:87.32ms +step:538/1680 train_time:46977ms step_avg:87.32ms +step:539/1680 train_time:47065ms step_avg:87.32ms +step:540/1680 train_time:47153ms step_avg:87.32ms +step:541/1680 train_time:47240ms step_avg:87.32ms +step:542/1680 train_time:47327ms step_avg:87.32ms +step:543/1680 train_time:47415ms step_avg:87.32ms +step:544/1680 train_time:47501ms step_avg:87.32ms +step:545/1680 train_time:47587ms step_avg:87.32ms +step:546/1680 train_time:47674ms step_avg:87.31ms +step:547/1680 train_time:47761ms step_avg:87.32ms +step:548/1680 train_time:47850ms step_avg:87.32ms +step:549/1680 train_time:47938ms step_avg:87.32ms +step:550/1680 train_time:48028ms step_avg:87.32ms +step:551/1680 train_time:48116ms step_avg:87.32ms +step:552/1680 train_time:48204ms step_avg:87.33ms +step:553/1680 train_time:48292ms step_avg:87.33ms +step:554/1680 train_time:48380ms step_avg:87.33ms +step:555/1680 train_time:48468ms step_avg:87.33ms +step:556/1680 train_time:48557ms step_avg:87.33ms +step:557/1680 train_time:48645ms step_avg:87.33ms +step:558/1680 train_time:48734ms step_avg:87.34ms +step:559/1680 train_time:48822ms step_avg:87.34ms +step:560/1680 train_time:48911ms step_avg:87.34ms +step:561/1680 train_time:48999ms step_avg:87.34ms +step:562/1680 train_time:49088ms step_avg:87.34ms +step:563/1680 train_time:49177ms step_avg:87.35ms +step:564/1680 train_time:49265ms step_avg:87.35ms +step:565/1680 train_time:49354ms step_avg:87.35ms +step:566/1680 train_time:49442ms step_avg:87.35ms +step:567/1680 train_time:49531ms step_avg:87.36ms +step:568/1680 train_time:49619ms step_avg:87.36ms +step:569/1680 train_time:49708ms step_avg:87.36ms +step:570/1680 train_time:49796ms step_avg:87.36ms +step:571/1680 train_time:49885ms step_avg:87.36ms +step:572/1680 train_time:49973ms step_avg:87.37ms +step:573/1680 train_time:50061ms step_avg:87.37ms +step:574/1680 train_time:50149ms step_avg:87.37ms +step:575/1680 train_time:50237ms step_avg:87.37ms +step:576/1680 train_time:50325ms step_avg:87.37ms +step:577/1680 train_time:50414ms step_avg:87.37ms +step:578/1680 train_time:50502ms step_avg:87.37ms +step:579/1680 train_time:50590ms step_avg:87.37ms +step:580/1680 train_time:50678ms step_avg:87.38ms +step:581/1680 train_time:50766ms step_avg:87.38ms +step:582/1680 train_time:50855ms step_avg:87.38ms +step:583/1680 train_time:50944ms step_avg:87.38ms +step:584/1680 train_time:51033ms step_avg:87.39ms +step:585/1680 train_time:51121ms step_avg:87.39ms +step:586/1680 train_time:51209ms step_avg:87.39ms +step:587/1680 train_time:51298ms step_avg:87.39ms +step:588/1680 train_time:51386ms step_avg:87.39ms +step:589/1680 train_time:51474ms step_avg:87.39ms +step:590/1680 train_time:51564ms step_avg:87.40ms +step:591/1680 train_time:51652ms step_avg:87.40ms +step:592/1680 train_time:51740ms step_avg:87.40ms +step:593/1680 train_time:51829ms step_avg:87.40ms +step:594/1680 train_time:51917ms step_avg:87.40ms +step:595/1680 train_time:52006ms step_avg:87.40ms +step:596/1680 train_time:52095ms step_avg:87.41ms +step:597/1680 train_time:52183ms step_avg:87.41ms +step:598/1680 train_time:52271ms step_avg:87.41ms +step:599/1680 train_time:52359ms step_avg:87.41ms +step:600/1680 train_time:52448ms step_avg:87.41ms +step:601/1680 train_time:52536ms 
step_avg:87.41ms +step:602/1680 train_time:52625ms step_avg:87.42ms +step:603/1680 train_time:52713ms step_avg:87.42ms +step:604/1680 train_time:52801ms step_avg:87.42ms +step:605/1680 train_time:52889ms step_avg:87.42ms +step:606/1680 train_time:52978ms step_avg:87.42ms +step:607/1680 train_time:53066ms step_avg:87.42ms +step:608/1680 train_time:53155ms step_avg:87.43ms +step:609/1680 train_time:53243ms step_avg:87.43ms +step:610/1680 train_time:53330ms step_avg:87.43ms +step:611/1680 train_time:53418ms step_avg:87.43ms +step:612/1680 train_time:53507ms step_avg:87.43ms +step:613/1680 train_time:53596ms step_avg:87.43ms +step:614/1680 train_time:53685ms step_avg:87.43ms +step:615/1680 train_time:53773ms step_avg:87.44ms +step:616/1680 train_time:53861ms step_avg:87.44ms +step:617/1680 train_time:53950ms step_avg:87.44ms +step:618/1680 train_time:54039ms step_avg:87.44ms +step:619/1680 train_time:54126ms step_avg:87.44ms +step:620/1680 train_time:54215ms step_avg:87.44ms +step:621/1680 train_time:54303ms step_avg:87.44ms +step:622/1680 train_time:54391ms step_avg:87.45ms +step:623/1680 train_time:54479ms step_avg:87.45ms +step:624/1680 train_time:54568ms step_avg:87.45ms +step:625/1680 train_time:54656ms step_avg:87.45ms +step:625/1680 val_loss:3.6165 train_time:54747ms step_avg:87.59ms +step:626/1680 train_time:54767ms step_avg:87.49ms +step:627/1680 train_time:54835ms step_avg:87.46ms +step:628/1680 train_time:54924ms step_avg:87.46ms +step:629/1680 train_time:55014ms step_avg:87.46ms +step:630/1680 train_time:55103ms step_avg:87.46ms +step:631/1680 train_time:55190ms step_avg:87.46ms +step:632/1680 train_time:55277ms step_avg:87.46ms +step:633/1680 train_time:55365ms step_avg:87.47ms +step:634/1680 train_time:55453ms step_avg:87.46ms +step:635/1680 train_time:55541ms step_avg:87.47ms +step:636/1680 train_time:55629ms step_avg:87.47ms +step:637/1680 train_time:55722ms step_avg:87.48ms +step:638/1680 train_time:55812ms step_avg:87.48ms +step:639/1680 train_time:55901ms step_avg:87.48ms +step:640/1680 train_time:55990ms step_avg:87.48ms +step:641/1680 train_time:56079ms step_avg:87.49ms +step:642/1680 train_time:56168ms step_avg:87.49ms +step:643/1680 train_time:56255ms step_avg:87.49ms +step:644/1680 train_time:56344ms step_avg:87.49ms +step:645/1680 train_time:56431ms step_avg:87.49ms +step:646/1680 train_time:56519ms step_avg:87.49ms +step:647/1680 train_time:56607ms step_avg:87.49ms +step:648/1680 train_time:56696ms step_avg:87.49ms +step:649/1680 train_time:56785ms step_avg:87.50ms +step:650/1680 train_time:56875ms step_avg:87.50ms +step:651/1680 train_time:56963ms step_avg:87.50ms +step:652/1680 train_time:57052ms step_avg:87.50ms +step:653/1680 train_time:57140ms step_avg:87.50ms +step:654/1680 train_time:57228ms step_avg:87.51ms +step:655/1680 train_time:57316ms step_avg:87.50ms +step:656/1680 train_time:57404ms step_avg:87.51ms +step:657/1680 train_time:57492ms step_avg:87.51ms +step:658/1680 train_time:57580ms step_avg:87.51ms +step:659/1680 train_time:57668ms step_avg:87.51ms +step:660/1680 train_time:57757ms step_avg:87.51ms +step:661/1680 train_time:57846ms step_avg:87.51ms +step:662/1680 train_time:57934ms step_avg:87.51ms +step:663/1680 train_time:58024ms step_avg:87.52ms +step:664/1680 train_time:58112ms step_avg:87.52ms +step:665/1680 train_time:58200ms step_avg:87.52ms +step:666/1680 train_time:58289ms step_avg:87.52ms +step:667/1680 train_time:58377ms step_avg:87.52ms +step:668/1680 train_time:58464ms step_avg:87.52ms +step:669/1680 train_time:58552ms step_avg:87.52ms 
+step:670/1680 train_time:58641ms step_avg:87.52ms +step:671/1680 train_time:58729ms step_avg:87.53ms +step:672/1680 train_time:58818ms step_avg:87.53ms +step:673/1680 train_time:58908ms step_avg:87.53ms +step:674/1680 train_time:58996ms step_avg:87.53ms +step:675/1680 train_time:59084ms step_avg:87.53ms +step:676/1680 train_time:59173ms step_avg:87.53ms +step:677/1680 train_time:59261ms step_avg:87.53ms +step:678/1680 train_time:59349ms step_avg:87.54ms +step:679/1680 train_time:59437ms step_avg:87.54ms +step:680/1680 train_time:59525ms step_avg:87.54ms +step:681/1680 train_time:59612ms step_avg:87.54ms +step:682/1680 train_time:59701ms step_avg:87.54ms +step:683/1680 train_time:59791ms step_avg:87.54ms +step:684/1680 train_time:59880ms step_avg:87.54ms +step:685/1680 train_time:59969ms step_avg:87.55ms +step:686/1680 train_time:60057ms step_avg:87.55ms +step:687/1680 train_time:60145ms step_avg:87.55ms +step:688/1680 train_time:60233ms step_avg:87.55ms +step:689/1680 train_time:60321ms step_avg:87.55ms +step:690/1680 train_time:60409ms step_avg:87.55ms +step:691/1680 train_time:60497ms step_avg:87.55ms +step:692/1680 train_time:60586ms step_avg:87.55ms +step:693/1680 train_time:60675ms step_avg:87.55ms +step:694/1680 train_time:60763ms step_avg:87.55ms +step:695/1680 train_time:60851ms step_avg:87.56ms +step:696/1680 train_time:60940ms step_avg:87.56ms +step:697/1680 train_time:61028ms step_avg:87.56ms +step:698/1680 train_time:61116ms step_avg:87.56ms +step:699/1680 train_time:61204ms step_avg:87.56ms +step:700/1680 train_time:61292ms step_avg:87.56ms +step:701/1680 train_time:61380ms step_avg:87.56ms +step:702/1680 train_time:61469ms step_avg:87.56ms +step:703/1680 train_time:61557ms step_avg:87.56ms +step:704/1680 train_time:61646ms step_avg:87.57ms +step:705/1680 train_time:61734ms step_avg:87.57ms +step:706/1680 train_time:61823ms step_avg:87.57ms +step:707/1680 train_time:61911ms step_avg:87.57ms +step:708/1680 train_time:61999ms step_avg:87.57ms +step:709/1680 train_time:62088ms step_avg:87.57ms +step:710/1680 train_time:62176ms step_avg:87.57ms +step:711/1680 train_time:62264ms step_avg:87.57ms +step:712/1680 train_time:62352ms step_avg:87.57ms +step:713/1680 train_time:62441ms step_avg:87.57ms +step:714/1680 train_time:62529ms step_avg:87.58ms +step:715/1680 train_time:62617ms step_avg:87.58ms +step:716/1680 train_time:62705ms step_avg:87.58ms +step:717/1680 train_time:62793ms step_avg:87.58ms +step:718/1680 train_time:62881ms step_avg:87.58ms +step:719/1680 train_time:62969ms step_avg:87.58ms +step:720/1680 train_time:63058ms step_avg:87.58ms +step:721/1680 train_time:63146ms step_avg:87.58ms +step:722/1680 train_time:63234ms step_avg:87.58ms +step:723/1680 train_time:63322ms step_avg:87.58ms +step:724/1680 train_time:63410ms step_avg:87.58ms +step:725/1680 train_time:63499ms step_avg:87.58ms +step:726/1680 train_time:63587ms step_avg:87.59ms +step:727/1680 train_time:63675ms step_avg:87.59ms +step:728/1680 train_time:63764ms step_avg:87.59ms +step:729/1680 train_time:63853ms step_avg:87.59ms +step:730/1680 train_time:63942ms step_avg:87.59ms +step:731/1680 train_time:64031ms step_avg:87.59ms +step:732/1680 train_time:64120ms step_avg:87.60ms +step:733/1680 train_time:64209ms step_avg:87.60ms +step:734/1680 train_time:64296ms step_avg:87.60ms +step:735/1680 train_time:64386ms step_avg:87.60ms +step:736/1680 train_time:64474ms step_avg:87.60ms +step:737/1680 train_time:64563ms step_avg:87.60ms +step:738/1680 train_time:64650ms step_avg:87.60ms +step:739/1680 train_time:64739ms 
step_avg:87.60ms +step:740/1680 train_time:64826ms step_avg:87.60ms +step:741/1680 train_time:64915ms step_avg:87.60ms +step:742/1680 train_time:65004ms step_avg:87.61ms +step:743/1680 train_time:65093ms step_avg:87.61ms +step:744/1680 train_time:65181ms step_avg:87.61ms +step:745/1680 train_time:65269ms step_avg:87.61ms +step:746/1680 train_time:65358ms step_avg:87.61ms +step:747/1680 train_time:65447ms step_avg:87.61ms +step:748/1680 train_time:65535ms step_avg:87.61ms +step:749/1680 train_time:65624ms step_avg:87.62ms +step:750/1680 train_time:65712ms step_avg:87.62ms +step:750/1680 val_loss:3.5642 train_time:65801ms step_avg:87.74ms +step:751/1680 train_time:65820ms step_avg:87.64ms +step:752/1680 train_time:65892ms step_avg:87.62ms +step:753/1680 train_time:65985ms step_avg:87.63ms +step:754/1680 train_time:66075ms step_avg:87.63ms +step:755/1680 train_time:66162ms step_avg:87.63ms +step:756/1680 train_time:66249ms step_avg:87.63ms +step:757/1680 train_time:66336ms step_avg:87.63ms +step:758/1680 train_time:66424ms step_avg:87.63ms +step:759/1680 train_time:66511ms step_avg:87.63ms +step:760/1680 train_time:66598ms step_avg:87.63ms +step:761/1680 train_time:66685ms step_avg:87.63ms +step:762/1680 train_time:66773ms step_avg:87.63ms +step:763/1680 train_time:66863ms step_avg:87.63ms +step:764/1680 train_time:66954ms step_avg:87.64ms +step:765/1680 train_time:67045ms step_avg:87.64ms +step:766/1680 train_time:67133ms step_avg:87.64ms +step:767/1680 train_time:67221ms step_avg:87.64ms +step:768/1680 train_time:67309ms step_avg:87.64ms +step:769/1680 train_time:67398ms step_avg:87.64ms +step:770/1680 train_time:67485ms step_avg:87.64ms +step:771/1680 train_time:67573ms step_avg:87.64ms +step:772/1680 train_time:67660ms step_avg:87.64ms +step:773/1680 train_time:67748ms step_avg:87.64ms +step:774/1680 train_time:67837ms step_avg:87.64ms +step:775/1680 train_time:67925ms step_avg:87.65ms +step:776/1680 train_time:68015ms step_avg:87.65ms +step:777/1680 train_time:68104ms step_avg:87.65ms +step:778/1680 train_time:68192ms step_avg:87.65ms +step:779/1680 train_time:68280ms step_avg:87.65ms +step:780/1680 train_time:68368ms step_avg:87.65ms +step:781/1680 train_time:68456ms step_avg:87.65ms +step:782/1680 train_time:68544ms step_avg:87.65ms +step:783/1680 train_time:68633ms step_avg:87.65ms +step:784/1680 train_time:68721ms step_avg:87.65ms +step:785/1680 train_time:68810ms step_avg:87.66ms +step:786/1680 train_time:68900ms step_avg:87.66ms +step:787/1680 train_time:68989ms step_avg:87.66ms +step:788/1680 train_time:69077ms step_avg:87.66ms +step:789/1680 train_time:69165ms step_avg:87.66ms +step:790/1680 train_time:69254ms step_avg:87.66ms +step:791/1680 train_time:69341ms step_avg:87.66ms +step:792/1680 train_time:69429ms step_avg:87.66ms +step:793/1680 train_time:69517ms step_avg:87.66ms +step:794/1680 train_time:69605ms step_avg:87.66ms +step:795/1680 train_time:69693ms step_avg:87.66ms +step:796/1680 train_time:69781ms step_avg:87.67ms +step:797/1680 train_time:69870ms step_avg:87.67ms +step:798/1680 train_time:69959ms step_avg:87.67ms +step:799/1680 train_time:70048ms step_avg:87.67ms +step:800/1680 train_time:70136ms step_avg:87.67ms +step:801/1680 train_time:70225ms step_avg:87.67ms +step:802/1680 train_time:70313ms step_avg:87.67ms +step:803/1680 train_time:70401ms step_avg:87.67ms +step:804/1680 train_time:70489ms step_avg:87.67ms +step:805/1680 train_time:70577ms step_avg:87.67ms +step:806/1680 train_time:70665ms step_avg:87.67ms +step:807/1680 train_time:70753ms step_avg:87.67ms 
+step:808/1680 train_time:70841ms step_avg:87.67ms +step:809/1680 train_time:70931ms step_avg:87.68ms +step:810/1680 train_time:71019ms step_avg:87.68ms +step:811/1680 train_time:71107ms step_avg:87.68ms +step:812/1680 train_time:71196ms step_avg:87.68ms +step:813/1680 train_time:71284ms step_avg:87.68ms +step:814/1680 train_time:71372ms step_avg:87.68ms +step:815/1680 train_time:71460ms step_avg:87.68ms +step:816/1680 train_time:71547ms step_avg:87.68ms +step:817/1680 train_time:71636ms step_avg:87.68ms +step:818/1680 train_time:71724ms step_avg:87.68ms +step:819/1680 train_time:71813ms step_avg:87.68ms +step:820/1680 train_time:71901ms step_avg:87.68ms +step:821/1680 train_time:71989ms step_avg:87.68ms +step:822/1680 train_time:72077ms step_avg:87.69ms +step:823/1680 train_time:72165ms step_avg:87.69ms +step:824/1680 train_time:72254ms step_avg:87.69ms +step:825/1680 train_time:72342ms step_avg:87.69ms +step:826/1680 train_time:72431ms step_avg:87.69ms +step:827/1680 train_time:72519ms step_avg:87.69ms +step:828/1680 train_time:72608ms step_avg:87.69ms +step:829/1680 train_time:72696ms step_avg:87.69ms +step:830/1680 train_time:72783ms step_avg:87.69ms +step:831/1680 train_time:72871ms step_avg:87.69ms +step:832/1680 train_time:72960ms step_avg:87.69ms +step:833/1680 train_time:73048ms step_avg:87.69ms +step:834/1680 train_time:73137ms step_avg:87.69ms +step:835/1680 train_time:73226ms step_avg:87.70ms +step:836/1680 train_time:73313ms step_avg:87.70ms +step:837/1680 train_time:73402ms step_avg:87.70ms +step:838/1680 train_time:73490ms step_avg:87.70ms +step:839/1680 train_time:73579ms step_avg:87.70ms +step:840/1680 train_time:73667ms step_avg:87.70ms +step:841/1680 train_time:73756ms step_avg:87.70ms +step:842/1680 train_time:73844ms step_avg:87.70ms +step:843/1680 train_time:73933ms step_avg:87.70ms +step:844/1680 train_time:74021ms step_avg:87.70ms +step:845/1680 train_time:74109ms step_avg:87.70ms +step:846/1680 train_time:74198ms step_avg:87.70ms +step:847/1680 train_time:74285ms step_avg:87.70ms +step:848/1680 train_time:74373ms step_avg:87.70ms +step:849/1680 train_time:74461ms step_avg:87.70ms +step:850/1680 train_time:74550ms step_avg:87.71ms +step:851/1680 train_time:74638ms step_avg:87.71ms +step:852/1680 train_time:74727ms step_avg:87.71ms +step:853/1680 train_time:74816ms step_avg:87.71ms +step:854/1680 train_time:74904ms step_avg:87.71ms +step:855/1680 train_time:74993ms step_avg:87.71ms +step:856/1680 train_time:75081ms step_avg:87.71ms +step:857/1680 train_time:75170ms step_avg:87.71ms +step:858/1680 train_time:75258ms step_avg:87.71ms +step:859/1680 train_time:75347ms step_avg:87.71ms +step:860/1680 train_time:75436ms step_avg:87.72ms +step:861/1680 train_time:75525ms step_avg:87.72ms +step:862/1680 train_time:75612ms step_avg:87.72ms +step:863/1680 train_time:75700ms step_avg:87.72ms +step:864/1680 train_time:75788ms step_avg:87.72ms +step:865/1680 train_time:75876ms step_avg:87.72ms +step:866/1680 train_time:75964ms step_avg:87.72ms +step:867/1680 train_time:76052ms step_avg:87.72ms +step:868/1680 train_time:76140ms step_avg:87.72ms +step:869/1680 train_time:76229ms step_avg:87.72ms +step:870/1680 train_time:76317ms step_avg:87.72ms +step:871/1680 train_time:76406ms step_avg:87.72ms +step:872/1680 train_time:76494ms step_avg:87.72ms +step:873/1680 train_time:76582ms step_avg:87.72ms +step:874/1680 train_time:76671ms step_avg:87.72ms +step:875/1680 train_time:76759ms step_avg:87.72ms +step:875/1680 val_loss:3.5185 train_time:76849ms step_avg:87.83ms +step:876/1680 
train_time:76867ms step_avg:87.75ms +step:877/1680 train_time:76940ms step_avg:87.73ms +step:878/1680 train_time:77034ms step_avg:87.74ms +step:879/1680 train_time:77124ms step_avg:87.74ms +step:880/1680 train_time:77212ms step_avg:87.74ms +step:881/1680 train_time:77299ms step_avg:87.74ms +step:882/1680 train_time:77385ms step_avg:87.74ms +step:883/1680 train_time:77472ms step_avg:87.74ms +step:884/1680 train_time:77560ms step_avg:87.74ms +step:885/1680 train_time:77648ms step_avg:87.74ms +step:886/1680 train_time:77735ms step_avg:87.74ms +step:887/1680 train_time:77824ms step_avg:87.74ms +step:888/1680 train_time:77914ms step_avg:87.74ms +step:889/1680 train_time:78005ms step_avg:87.74ms +step:890/1680 train_time:78094ms step_avg:87.75ms +step:891/1680 train_time:78184ms step_avg:87.75ms +step:892/1680 train_time:78272ms step_avg:87.75ms +step:893/1680 train_time:78360ms step_avg:87.75ms +step:894/1680 train_time:78448ms step_avg:87.75ms +step:895/1680 train_time:78535ms step_avg:87.75ms +step:896/1680 train_time:78622ms step_avg:87.75ms +step:897/1680 train_time:78709ms step_avg:87.75ms +step:898/1680 train_time:78797ms step_avg:87.75ms +step:899/1680 train_time:78886ms step_avg:87.75ms +step:900/1680 train_time:78975ms step_avg:87.75ms +step:901/1680 train_time:79065ms step_avg:87.75ms +step:902/1680 train_time:79153ms step_avg:87.75ms +step:903/1680 train_time:79242ms step_avg:87.75ms +step:904/1680 train_time:79330ms step_avg:87.75ms +step:905/1680 train_time:79418ms step_avg:87.75ms +step:906/1680 train_time:79506ms step_avg:87.75ms +step:907/1680 train_time:79593ms step_avg:87.75ms +step:908/1680 train_time:79681ms step_avg:87.75ms +step:909/1680 train_time:79769ms step_avg:87.75ms +step:910/1680 train_time:79858ms step_avg:87.76ms +step:911/1680 train_time:79948ms step_avg:87.76ms +step:912/1680 train_time:80037ms step_avg:87.76ms +step:913/1680 train_time:80125ms step_avg:87.76ms +step:914/1680 train_time:80214ms step_avg:87.76ms +step:915/1680 train_time:80302ms step_avg:87.76ms +step:916/1680 train_time:80389ms step_avg:87.76ms +step:917/1680 train_time:80477ms step_avg:87.76ms +step:918/1680 train_time:80566ms step_avg:87.76ms +step:919/1680 train_time:80653ms step_avg:87.76ms +step:920/1680 train_time:80742ms step_avg:87.76ms +step:921/1680 train_time:80830ms step_avg:87.76ms +step:922/1680 train_time:80919ms step_avg:87.77ms +step:923/1680 train_time:81009ms step_avg:87.77ms +step:924/1680 train_time:81098ms step_avg:87.77ms +step:925/1680 train_time:81187ms step_avg:87.77ms +step:926/1680 train_time:81275ms step_avg:87.77ms +step:927/1680 train_time:81364ms step_avg:87.77ms +step:928/1680 train_time:81452ms step_avg:87.77ms +step:929/1680 train_time:81540ms step_avg:87.77ms +step:930/1680 train_time:81628ms step_avg:87.77ms +step:931/1680 train_time:81716ms step_avg:87.77ms +step:932/1680 train_time:81804ms step_avg:87.77ms +step:933/1680 train_time:81892ms step_avg:87.77ms +step:934/1680 train_time:81981ms step_avg:87.77ms +step:935/1680 train_time:82069ms step_avg:87.77ms +step:936/1680 train_time:82158ms step_avg:87.78ms +step:937/1680 train_time:82247ms step_avg:87.78ms +step:938/1680 train_time:82334ms step_avg:87.78ms +step:939/1680 train_time:82421ms step_avg:87.78ms +step:940/1680 train_time:82510ms step_avg:87.78ms +step:941/1680 train_time:82598ms step_avg:87.78ms +step:942/1680 train_time:82687ms step_avg:87.78ms +step:943/1680 train_time:82775ms step_avg:87.78ms +step:944/1680 train_time:82863ms step_avg:87.78ms +step:945/1680 train_time:82952ms step_avg:87.78ms 
+step:946/1680 train_time:83041ms step_avg:87.78ms +step:947/1680 train_time:83129ms step_avg:87.78ms +step:948/1680 train_time:83218ms step_avg:87.78ms +step:949/1680 train_time:83306ms step_avg:87.78ms +step:950/1680 train_time:83395ms step_avg:87.78ms +step:951/1680 train_time:83482ms step_avg:87.78ms +step:952/1680 train_time:83571ms step_avg:87.78ms +step:953/1680 train_time:83659ms step_avg:87.79ms +step:954/1680 train_time:83747ms step_avg:87.79ms +step:955/1680 train_time:83836ms step_avg:87.79ms +step:956/1680 train_time:83924ms step_avg:87.79ms +step:957/1680 train_time:84012ms step_avg:87.79ms +step:958/1680 train_time:84100ms step_avg:87.79ms +step:959/1680 train_time:84188ms step_avg:87.79ms +step:960/1680 train_time:84277ms step_avg:87.79ms +step:961/1680 train_time:84365ms step_avg:87.79ms +step:962/1680 train_time:84453ms step_avg:87.79ms +step:963/1680 train_time:84541ms step_avg:87.79ms +step:964/1680 train_time:84629ms step_avg:87.79ms +step:965/1680 train_time:84717ms step_avg:87.79ms +step:966/1680 train_time:84805ms step_avg:87.79ms +step:967/1680 train_time:84894ms step_avg:87.79ms +step:968/1680 train_time:84982ms step_avg:87.79ms +step:969/1680 train_time:85071ms step_avg:87.79ms +step:970/1680 train_time:85159ms step_avg:87.79ms +step:971/1680 train_time:85248ms step_avg:87.79ms +step:972/1680 train_time:85337ms step_avg:87.80ms +step:973/1680 train_time:85425ms step_avg:87.80ms +step:974/1680 train_time:85513ms step_avg:87.80ms +step:975/1680 train_time:85601ms step_avg:87.80ms +step:976/1680 train_time:85689ms step_avg:87.80ms +step:977/1680 train_time:85777ms step_avg:87.80ms +step:978/1680 train_time:85866ms step_avg:87.80ms +step:979/1680 train_time:85954ms step_avg:87.80ms +step:980/1680 train_time:86043ms step_avg:87.80ms +step:981/1680 train_time:86131ms step_avg:87.80ms +step:982/1680 train_time:86220ms step_avg:87.80ms +step:983/1680 train_time:86309ms step_avg:87.80ms +step:984/1680 train_time:86398ms step_avg:87.80ms +step:985/1680 train_time:86486ms step_avg:87.80ms +step:986/1680 train_time:86574ms step_avg:87.80ms +step:987/1680 train_time:86662ms step_avg:87.80ms +step:988/1680 train_time:86751ms step_avg:87.80ms +step:989/1680 train_time:86838ms step_avg:87.80ms +step:990/1680 train_time:86927ms step_avg:87.80ms +step:991/1680 train_time:87015ms step_avg:87.80ms +step:992/1680 train_time:87103ms step_avg:87.81ms +step:993/1680 train_time:87192ms step_avg:87.81ms +step:994/1680 train_time:87280ms step_avg:87.81ms +step:995/1680 train_time:87369ms step_avg:87.81ms +step:996/1680 train_time:87457ms step_avg:87.81ms +step:997/1680 train_time:87547ms step_avg:87.81ms +step:998/1680 train_time:87635ms step_avg:87.81ms +step:999/1680 train_time:87723ms step_avg:87.81ms +step:1000/1680 train_time:87811ms step_avg:87.81ms +step:1000/1680 val_loss:3.4690 train_time:87901ms step_avg:87.90ms +step:1001/1680 train_time:87919ms step_avg:87.83ms +step:1002/1680 train_time:87993ms step_avg:87.82ms +step:1003/1680 train_time:88085ms step_avg:87.82ms +step:1004/1680 train_time:88174ms step_avg:87.82ms +step:1005/1680 train_time:88262ms step_avg:87.82ms +step:1006/1680 train_time:88349ms step_avg:87.82ms +step:1007/1680 train_time:88436ms step_avg:87.82ms +step:1008/1680 train_time:88524ms step_avg:87.82ms +step:1009/1680 train_time:88611ms step_avg:87.82ms +step:1010/1680 train_time:88698ms step_avg:87.82ms +step:1011/1680 train_time:88785ms step_avg:87.82ms +step:1012/1680 train_time:88873ms step_avg:87.82ms +step:1013/1680 train_time:88963ms step_avg:87.82ms 
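
val_loss is reported at step 0 and then every 125 steps in this run. For reading the record programmatically, a small extractor for the validation curve (the regex just mirrors the line format above):

import re

val_re = re.compile(r"step:(\d+)/\d+ val_loss:([\d.]+)")

def val_curve(log_text: str) -> list[tuple[int, float]]:
    return [(int(s), float(v)) for s, v in val_re.findall(log_text)]

# On the log so far this yields (0, 10.8258), (125, 4.3047), (250, 3.9650),
# (375, 3.8161), (500, 3.7146), (625, 3.6165), (750, 3.5642), (875, 3.5185),
# (1000, 3.4690), ...
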
+step:1014/1680 train_time:89053ms step_avg:87.82ms +step:1015/1680 train_time:89142ms step_avg:87.82ms +step:1016/1680 train_time:89230ms step_avg:87.83ms +step:1017/1680 train_time:89319ms step_avg:87.83ms +step:1018/1680 train_time:89407ms step_avg:87.83ms +step:1019/1680 train_time:89495ms step_avg:87.83ms +step:1020/1680 train_time:89582ms step_avg:87.83ms +step:1021/1680 train_time:89669ms step_avg:87.82ms +step:1022/1680 train_time:89757ms step_avg:87.82ms +step:1023/1680 train_time:89845ms step_avg:87.82ms +step:1024/1680 train_time:89934ms step_avg:87.83ms +step:1025/1680 train_time:90023ms step_avg:87.83ms +step:1026/1680 train_time:90111ms step_avg:87.83ms +step:1027/1680 train_time:90200ms step_avg:87.83ms +step:1028/1680 train_time:90288ms step_avg:87.83ms +step:1029/1680 train_time:90377ms step_avg:87.83ms +step:1030/1680 train_time:90465ms step_avg:87.83ms +step:1031/1680 train_time:90553ms step_avg:87.83ms +step:1032/1680 train_time:90641ms step_avg:87.83ms +step:1033/1680 train_time:90728ms step_avg:87.83ms +step:1034/1680 train_time:90817ms step_avg:87.83ms +step:1035/1680 train_time:90906ms step_avg:87.83ms +step:1036/1680 train_time:90996ms step_avg:87.83ms +step:1037/1680 train_time:91084ms step_avg:87.83ms +step:1038/1680 train_time:91173ms step_avg:87.84ms +step:1039/1680 train_time:91261ms step_avg:87.84ms +step:1040/1680 train_time:91349ms step_avg:87.84ms +step:1041/1680 train_time:91438ms step_avg:87.84ms +step:1042/1680 train_time:91526ms step_avg:87.84ms +step:1043/1680 train_time:91614ms step_avg:87.84ms +step:1044/1680 train_time:91702ms step_avg:87.84ms +step:1045/1680 train_time:91790ms step_avg:87.84ms +step:1046/1680 train_time:91878ms step_avg:87.84ms +step:1047/1680 train_time:91967ms step_avg:87.84ms +step:1048/1680 train_time:92056ms step_avg:87.84ms +step:1049/1680 train_time:92145ms step_avg:87.84ms +step:1050/1680 train_time:92234ms step_avg:87.84ms +step:1051/1680 train_time:92322ms step_avg:87.84ms +step:1052/1680 train_time:92410ms step_avg:87.84ms +step:1053/1680 train_time:92498ms step_avg:87.84ms +step:1054/1680 train_time:92586ms step_avg:87.84ms +step:1055/1680 train_time:92674ms step_avg:87.84ms +step:1056/1680 train_time:92762ms step_avg:87.84ms +step:1057/1680 train_time:92850ms step_avg:87.84ms +step:1058/1680 train_time:92939ms step_avg:87.84ms +step:1059/1680 train_time:93027ms step_avg:87.84ms +step:1060/1680 train_time:93116ms step_avg:87.84ms +step:1061/1680 train_time:93204ms step_avg:87.85ms +step:1062/1680 train_time:93293ms step_avg:87.85ms +step:1063/1680 train_time:93381ms step_avg:87.85ms +step:1064/1680 train_time:93469ms step_avg:87.85ms +step:1065/1680 train_time:93557ms step_avg:87.85ms +step:1066/1680 train_time:93646ms step_avg:87.85ms +step:1067/1680 train_time:93734ms step_avg:87.85ms +step:1068/1680 train_time:93822ms step_avg:87.85ms +step:1069/1680 train_time:93911ms step_avg:87.85ms +step:1070/1680 train_time:94000ms step_avg:87.85ms +step:1071/1680 train_time:94088ms step_avg:87.85ms +step:1072/1680 train_time:94176ms step_avg:87.85ms +step:1073/1680 train_time:94265ms step_avg:87.85ms +step:1074/1680 train_time:94353ms step_avg:87.85ms +step:1075/1680 train_time:94441ms step_avg:87.85ms +step:1076/1680 train_time:94529ms step_avg:87.85ms +step:1077/1680 train_time:94617ms step_avg:87.85ms +step:1078/1680 train_time:94706ms step_avg:87.85ms +step:1079/1680 train_time:94794ms step_avg:87.85ms +step:1080/1680 train_time:94882ms step_avg:87.85ms +step:1081/1680 train_time:94971ms step_avg:87.85ms +step:1082/1680 
train_time:95059ms step_avg:87.85ms +step:1083/1680 train_time:95148ms step_avg:87.86ms +step:1084/1680 train_time:95236ms step_avg:87.86ms +step:1085/1680 train_time:95325ms step_avg:87.86ms +step:1086/1680 train_time:95413ms step_avg:87.86ms +step:1087/1680 train_time:95501ms step_avg:87.86ms +step:1088/1680 train_time:95590ms step_avg:87.86ms +step:1089/1680 train_time:95677ms step_avg:87.86ms +step:1090/1680 train_time:95765ms step_avg:87.86ms +step:1091/1680 train_time:95854ms step_avg:87.86ms +step:1092/1680 train_time:95942ms step_avg:87.86ms +step:1093/1680 train_time:96031ms step_avg:87.86ms +step:1094/1680 train_time:96119ms step_avg:87.86ms +step:1095/1680 train_time:96208ms step_avg:87.86ms +step:1096/1680 train_time:96297ms step_avg:87.86ms +step:1097/1680 train_time:96386ms step_avg:87.86ms +step:1098/1680 train_time:96475ms step_avg:87.86ms +step:1099/1680 train_time:96564ms step_avg:87.87ms +step:1100/1680 train_time:96653ms step_avg:87.87ms +step:1101/1680 train_time:96742ms step_avg:87.87ms +step:1102/1680 train_time:96832ms step_avg:87.87ms +step:1103/1680 train_time:96921ms step_avg:87.87ms +step:1104/1680 train_time:97010ms step_avg:87.87ms +step:1105/1680 train_time:97099ms step_avg:87.87ms +step:1106/1680 train_time:97189ms step_avg:87.87ms +step:1107/1680 train_time:97278ms step_avg:87.88ms +step:1108/1680 train_time:97367ms step_avg:87.88ms +step:1109/1680 train_time:97455ms step_avg:87.88ms +step:1110/1680 train_time:97545ms step_avg:87.88ms +step:1111/1680 train_time:97635ms step_avg:87.88ms +step:1112/1680 train_time:97724ms step_avg:87.88ms +step:1113/1680 train_time:97813ms step_avg:87.88ms +step:1114/1680 train_time:97902ms step_avg:87.88ms +step:1115/1680 train_time:97991ms step_avg:87.88ms +step:1116/1680 train_time:98080ms step_avg:87.89ms +step:1117/1680 train_time:98169ms step_avg:87.89ms +step:1118/1680 train_time:98258ms step_avg:87.89ms +step:1119/1680 train_time:98347ms step_avg:87.89ms +step:1120/1680 train_time:98437ms step_avg:87.89ms +step:1121/1680 train_time:98526ms step_avg:87.89ms +step:1122/1680 train_time:98616ms step_avg:87.89ms +step:1123/1680 train_time:98705ms step_avg:87.89ms +step:1124/1680 train_time:98795ms step_avg:87.90ms +step:1125/1680 train_time:98884ms step_avg:87.90ms +step:1125/1680 val_loss:3.4153 train_time:98974ms step_avg:87.98ms +step:1126/1680 train_time:98994ms step_avg:87.92ms +step:1127/1680 train_time:99064ms step_avg:87.90ms +step:1128/1680 train_time:99157ms step_avg:87.90ms +step:1129/1680 train_time:99248ms step_avg:87.91ms +step:1130/1680 train_time:99337ms step_avg:87.91ms +step:1131/1680 train_time:99424ms step_avg:87.91ms +step:1132/1680 train_time:99513ms step_avg:87.91ms +step:1133/1680 train_time:99601ms step_avg:87.91ms +step:1134/1680 train_time:99689ms step_avg:87.91ms +step:1135/1680 train_time:99778ms step_avg:87.91ms +step:1136/1680 train_time:99867ms step_avg:87.91ms +step:1137/1680 train_time:99957ms step_avg:87.91ms +step:1138/1680 train_time:100047ms step_avg:87.91ms +step:1139/1680 train_time:100137ms step_avg:87.92ms +step:1140/1680 train_time:100228ms step_avg:87.92ms +step:1141/1680 train_time:100317ms step_avg:87.92ms +step:1142/1680 train_time:100406ms step_avg:87.92ms +step:1143/1680 train_time:100495ms step_avg:87.92ms +step:1144/1680 train_time:100584ms step_avg:87.92ms +step:1145/1680 train_time:100672ms step_avg:87.92ms +step:1146/1680 train_time:100760ms step_avg:87.92ms +step:1147/1680 train_time:100848ms step_avg:87.92ms +step:1148/1680 train_time:100938ms step_avg:87.93ms 
+step:1149/1680 train_time:101028ms step_avg:87.93ms +step:1150/1680 train_time:101118ms step_avg:87.93ms +step:1151/1680 train_time:101208ms step_avg:87.93ms +step:1152/1680 train_time:101298ms step_avg:87.93ms +step:1153/1680 train_time:101387ms step_avg:87.93ms +step:1154/1680 train_time:101476ms step_avg:87.93ms +step:1155/1680 train_time:101565ms step_avg:87.94ms +step:1156/1680 train_time:101654ms step_avg:87.94ms +step:1157/1680 train_time:101742ms step_avg:87.94ms +step:1158/1680 train_time:101830ms step_avg:87.94ms +step:1159/1680 train_time:101919ms step_avg:87.94ms +step:1160/1680 train_time:102008ms step_avg:87.94ms +step:1161/1680 train_time:102099ms step_avg:87.94ms +step:1162/1680 train_time:102189ms step_avg:87.94ms +step:1163/1680 train_time:102279ms step_avg:87.94ms +step:1164/1680 train_time:102369ms step_avg:87.95ms +step:1165/1680 train_time:102458ms step_avg:87.95ms +step:1166/1680 train_time:102547ms step_avg:87.95ms +step:1167/1680 train_time:102636ms step_avg:87.95ms +step:1168/1680 train_time:102725ms step_avg:87.95ms +step:1169/1680 train_time:102814ms step_avg:87.95ms +step:1170/1680 train_time:102903ms step_avg:87.95ms +step:1171/1680 train_time:102992ms step_avg:87.95ms +step:1172/1680 train_time:103081ms step_avg:87.95ms +step:1173/1680 train_time:103170ms step_avg:87.95ms +step:1174/1680 train_time:103259ms step_avg:87.95ms +step:1175/1680 train_time:103348ms step_avg:87.96ms +step:1176/1680 train_time:103437ms step_avg:87.96ms +step:1177/1680 train_time:103526ms step_avg:87.96ms +step:1178/1680 train_time:103615ms step_avg:87.96ms +step:1179/1680 train_time:103703ms step_avg:87.96ms +step:1180/1680 train_time:103791ms step_avg:87.96ms +step:1181/1680 train_time:103880ms step_avg:87.96ms +step:1182/1680 train_time:103970ms step_avg:87.96ms +step:1183/1680 train_time:104059ms step_avg:87.96ms +step:1184/1680 train_time:104149ms step_avg:87.96ms +step:1185/1680 train_time:104238ms step_avg:87.96ms +step:1186/1680 train_time:104326ms step_avg:87.96ms +step:1187/1680 train_time:104416ms step_avg:87.97ms +step:1188/1680 train_time:104504ms step_avg:87.97ms +step:1189/1680 train_time:104594ms step_avg:87.97ms +step:1190/1680 train_time:104684ms step_avg:87.97ms +step:1191/1680 train_time:104773ms step_avg:87.97ms +step:1192/1680 train_time:104863ms step_avg:87.97ms +step:1193/1680 train_time:104951ms step_avg:87.97ms +step:1194/1680 train_time:105040ms step_avg:87.97ms +step:1195/1680 train_time:105129ms step_avg:87.97ms +step:1196/1680 train_time:105218ms step_avg:87.97ms +step:1197/1680 train_time:105306ms step_avg:87.98ms +step:1198/1680 train_time:105396ms step_avg:87.98ms +step:1199/1680 train_time:105485ms step_avg:87.98ms +step:1200/1680 train_time:105575ms step_avg:87.98ms +step:1201/1680 train_time:105664ms step_avg:87.98ms +step:1202/1680 train_time:105752ms step_avg:87.98ms +step:1203/1680 train_time:105842ms step_avg:87.98ms +step:1204/1680 train_time:105930ms step_avg:87.98ms +step:1205/1680 train_time:106020ms step_avg:87.98ms +step:1206/1680 train_time:106109ms step_avg:87.98ms +step:1207/1680 train_time:106198ms step_avg:87.98ms +step:1208/1680 train_time:106287ms step_avg:87.99ms +step:1209/1680 train_time:106376ms step_avg:87.99ms +step:1210/1680 train_time:106465ms step_avg:87.99ms +step:1211/1680 train_time:106554ms step_avg:87.99ms +step:1212/1680 train_time:106643ms step_avg:87.99ms +step:1213/1680 train_time:106732ms step_avg:87.99ms +step:1214/1680 train_time:106821ms step_avg:87.99ms +step:1215/1680 train_time:106910ms step_avg:87.99ms 
+step:1216/1680 train_time:106999ms step_avg:87.99ms +step:1217/1680 train_time:107088ms step_avg:87.99ms +step:1218/1680 train_time:107176ms step_avg:87.99ms +step:1219/1680 train_time:107266ms step_avg:88.00ms +step:1220/1680 train_time:107356ms step_avg:88.00ms +step:1221/1680 train_time:107445ms step_avg:88.00ms +step:1222/1680 train_time:107535ms step_avg:88.00ms +step:1223/1680 train_time:107623ms step_avg:88.00ms +step:1224/1680 train_time:107711ms step_avg:88.00ms +step:1225/1680 train_time:107801ms step_avg:88.00ms +step:1226/1680 train_time:107890ms step_avg:88.00ms +step:1227/1680 train_time:107979ms step_avg:88.00ms +step:1228/1680 train_time:108069ms step_avg:88.00ms +step:1229/1680 train_time:108157ms step_avg:88.00ms +step:1230/1680 train_time:108247ms step_avg:88.01ms +step:1231/1680 train_time:108336ms step_avg:88.01ms +step:1232/1680 train_time:108425ms step_avg:88.01ms +step:1233/1680 train_time:108514ms step_avg:88.01ms +step:1234/1680 train_time:108603ms step_avg:88.01ms +step:1235/1680 train_time:108692ms step_avg:88.01ms +step:1236/1680 train_time:108782ms step_avg:88.01ms +step:1237/1680 train_time:108871ms step_avg:88.01ms +step:1238/1680 train_time:108962ms step_avg:88.01ms +step:1239/1680 train_time:109051ms step_avg:88.02ms +step:1240/1680 train_time:109139ms step_avg:88.02ms +step:1241/1680 train_time:109228ms step_avg:88.02ms +step:1242/1680 train_time:109317ms step_avg:88.02ms +step:1243/1680 train_time:109406ms step_avg:88.02ms +step:1244/1680 train_time:109495ms step_avg:88.02ms +step:1245/1680 train_time:109586ms step_avg:88.02ms +step:1246/1680 train_time:109675ms step_avg:88.02ms +step:1247/1680 train_time:109764ms step_avg:88.02ms +step:1248/1680 train_time:109854ms step_avg:88.02ms +step:1249/1680 train_time:109943ms step_avg:88.02ms +step:1250/1680 train_time:110031ms step_avg:88.03ms +step:1250/1680 val_loss:3.3777 train_time:110122ms step_avg:88.10ms +step:1251/1680 train_time:110140ms step_avg:88.04ms +step:1252/1680 train_time:110214ms step_avg:88.03ms +step:1253/1680 train_time:110309ms step_avg:88.04ms +step:1254/1680 train_time:110400ms step_avg:88.04ms +step:1255/1680 train_time:110488ms step_avg:88.04ms +step:1256/1680 train_time:110576ms step_avg:88.04ms +step:1257/1680 train_time:110664ms step_avg:88.04ms +step:1258/1680 train_time:110752ms step_avg:88.04ms +step:1259/1680 train_time:110840ms step_avg:88.04ms +step:1260/1680 train_time:110928ms step_avg:88.04ms +step:1261/1680 train_time:111016ms step_avg:88.04ms +step:1262/1680 train_time:111106ms step_avg:88.04ms +step:1263/1680 train_time:111197ms step_avg:88.04ms +step:1264/1680 train_time:111290ms step_avg:88.05ms +step:1265/1680 train_time:111380ms step_avg:88.05ms +step:1266/1680 train_time:111470ms step_avg:88.05ms +step:1267/1680 train_time:111559ms step_avg:88.05ms +step:1268/1680 train_time:111648ms step_avg:88.05ms +step:1269/1680 train_time:111736ms step_avg:88.05ms +step:1270/1680 train_time:111824ms step_avg:88.05ms +step:1271/1680 train_time:111911ms step_avg:88.05ms +step:1272/1680 train_time:112000ms step_avg:88.05ms +step:1273/1680 train_time:112090ms step_avg:88.05ms +step:1274/1680 train_time:112180ms step_avg:88.05ms +step:1275/1680 train_time:112270ms step_avg:88.06ms +step:1276/1680 train_time:112360ms step_avg:88.06ms +step:1277/1680 train_time:112449ms step_avg:88.06ms +step:1278/1680 train_time:112539ms step_avg:88.06ms +step:1279/1680 train_time:112627ms step_avg:88.06ms +step:1280/1680 train_time:112716ms step_avg:88.06ms +step:1281/1680 train_time:112805ms 
step_avg:88.06ms +step:1282/1680 train_time:112893ms step_avg:88.06ms +step:1283/1680 train_time:112982ms step_avg:88.06ms +step:1284/1680 train_time:113071ms step_avg:88.06ms +step:1285/1680 train_time:113160ms step_avg:88.06ms +step:1286/1680 train_time:113250ms step_avg:88.06ms +step:1287/1680 train_time:113340ms step_avg:88.07ms +step:1288/1680 train_time:113430ms step_avg:88.07ms +step:1289/1680 train_time:113520ms step_avg:88.07ms +step:1290/1680 train_time:113609ms step_avg:88.07ms +step:1291/1680 train_time:113698ms step_avg:88.07ms +step:1292/1680 train_time:113787ms step_avg:88.07ms +step:1293/1680 train_time:113875ms step_avg:88.07ms +step:1294/1680 train_time:113963ms step_avg:88.07ms +step:1295/1680 train_time:114052ms step_avg:88.07ms +step:1296/1680 train_time:114141ms step_avg:88.07ms +step:1297/1680 train_time:114232ms step_avg:88.07ms +step:1298/1680 train_time:114321ms step_avg:88.07ms +step:1299/1680 train_time:114410ms step_avg:88.08ms +step:1300/1680 train_time:114499ms step_avg:88.08ms +step:1301/1680 train_time:114588ms step_avg:88.08ms +step:1302/1680 train_time:114678ms step_avg:88.08ms +step:1303/1680 train_time:114767ms step_avg:88.08ms +step:1304/1680 train_time:114856ms step_avg:88.08ms +step:1305/1680 train_time:114946ms step_avg:88.08ms +step:1306/1680 train_time:115034ms step_avg:88.08ms +step:1307/1680 train_time:115123ms step_avg:88.08ms +step:1308/1680 train_time:115212ms step_avg:88.08ms +step:1309/1680 train_time:115301ms step_avg:88.08ms +step:1310/1680 train_time:115391ms step_avg:88.08ms +step:1311/1680 train_time:115480ms step_avg:88.09ms +step:1312/1680 train_time:115569ms step_avg:88.09ms +step:1313/1680 train_time:115658ms step_avg:88.09ms +step:1314/1680 train_time:115748ms step_avg:88.09ms +step:1315/1680 train_time:115837ms step_avg:88.09ms +step:1316/1680 train_time:115927ms step_avg:88.09ms +step:1317/1680 train_time:116015ms step_avg:88.09ms +step:1318/1680 train_time:116103ms step_avg:88.09ms +step:1319/1680 train_time:116192ms step_avg:88.09ms +step:1320/1680 train_time:116281ms step_avg:88.09ms +step:1321/1680 train_time:116370ms step_avg:88.09ms +step:1322/1680 train_time:116460ms step_avg:88.09ms +step:1323/1680 train_time:116549ms step_avg:88.09ms +step:1324/1680 train_time:116639ms step_avg:88.10ms +step:1325/1680 train_time:116728ms step_avg:88.10ms +step:1326/1680 train_time:116817ms step_avg:88.10ms +step:1327/1680 train_time:116907ms step_avg:88.10ms +step:1328/1680 train_time:116996ms step_avg:88.10ms +step:1329/1680 train_time:117086ms step_avg:88.10ms +step:1330/1680 train_time:117174ms step_avg:88.10ms +step:1331/1680 train_time:117264ms step_avg:88.10ms +step:1332/1680 train_time:117352ms step_avg:88.10ms +step:1333/1680 train_time:117441ms step_avg:88.10ms +step:1334/1680 train_time:117529ms step_avg:88.10ms +step:1335/1680 train_time:117619ms step_avg:88.10ms +step:1336/1680 train_time:117708ms step_avg:88.10ms +step:1337/1680 train_time:117796ms step_avg:88.10ms +step:1338/1680 train_time:117885ms step_avg:88.11ms +step:1339/1680 train_time:117974ms step_avg:88.11ms +step:1340/1680 train_time:118064ms step_avg:88.11ms +step:1341/1680 train_time:118153ms step_avg:88.11ms +step:1342/1680 train_time:118242ms step_avg:88.11ms +step:1343/1680 train_time:118331ms step_avg:88.11ms +step:1344/1680 train_time:118420ms step_avg:88.11ms +step:1345/1680 train_time:118510ms step_avg:88.11ms +step:1346/1680 train_time:118599ms step_avg:88.11ms +step:1347/1680 train_time:118688ms step_avg:88.11ms +step:1348/1680 train_time:118777ms 
step_avg:88.11ms +step:1349/1680 train_time:118867ms step_avg:88.11ms +step:1350/1680 train_time:118957ms step_avg:88.12ms +step:1351/1680 train_time:119047ms step_avg:88.12ms +step:1352/1680 train_time:119136ms step_avg:88.12ms +step:1353/1680 train_time:119226ms step_avg:88.12ms +step:1354/1680 train_time:119315ms step_avg:88.12ms +step:1355/1680 train_time:119404ms step_avg:88.12ms +step:1356/1680 train_time:119493ms step_avg:88.12ms +step:1357/1680 train_time:119582ms step_avg:88.12ms +step:1358/1680 train_time:119671ms step_avg:88.12ms +step:1359/1680 train_time:119760ms step_avg:88.12ms +step:1360/1680 train_time:119850ms step_avg:88.13ms +step:1361/1680 train_time:119939ms step_avg:88.13ms +step:1362/1680 train_time:120027ms step_avg:88.13ms +step:1363/1680 train_time:120117ms step_avg:88.13ms +step:1364/1680 train_time:120206ms step_avg:88.13ms +step:1365/1680 train_time:120295ms step_avg:88.13ms +step:1366/1680 train_time:120384ms step_avg:88.13ms +step:1367/1680 train_time:120473ms step_avg:88.13ms +step:1368/1680 train_time:120562ms step_avg:88.13ms +step:1369/1680 train_time:120651ms step_avg:88.13ms +step:1370/1680 train_time:120740ms step_avg:88.13ms +step:1371/1680 train_time:120829ms step_avg:88.13ms +step:1372/1680 train_time:120919ms step_avg:88.13ms +step:1373/1680 train_time:121009ms step_avg:88.13ms +step:1374/1680 train_time:121098ms step_avg:88.14ms +step:1375/1680 train_time:121187ms step_avg:88.14ms +step:1375/1680 val_loss:3.3433 train_time:121277ms step_avg:88.20ms +step:1376/1680 train_time:121296ms step_avg:88.15ms +step:1377/1680 train_time:121370ms step_avg:88.14ms +step:1378/1680 train_time:121462ms step_avg:88.14ms +step:1379/1680 train_time:121552ms step_avg:88.15ms +step:1380/1680 train_time:121640ms step_avg:88.14ms +step:1381/1680 train_time:121727ms step_avg:88.14ms +step:1382/1680 train_time:121815ms step_avg:88.14ms +step:1383/1680 train_time:121904ms step_avg:88.14ms +step:1384/1680 train_time:121991ms step_avg:88.14ms +step:1385/1680 train_time:122080ms step_avg:88.14ms +step:1386/1680 train_time:122168ms step_avg:88.14ms +step:1387/1680 train_time:122260ms step_avg:88.15ms +step:1388/1680 train_time:122352ms step_avg:88.15ms +step:1389/1680 train_time:122444ms step_avg:88.15ms +step:1390/1680 train_time:122536ms step_avg:88.16ms +step:1391/1680 train_time:122625ms step_avg:88.16ms +step:1392/1680 train_time:122713ms step_avg:88.16ms +step:1393/1680 train_time:122801ms step_avg:88.16ms +step:1394/1680 train_time:122889ms step_avg:88.16ms +step:1395/1680 train_time:122978ms step_avg:88.16ms +step:1396/1680 train_time:123066ms step_avg:88.16ms +step:1397/1680 train_time:123155ms step_avg:88.16ms +step:1398/1680 train_time:123244ms step_avg:88.16ms +step:1399/1680 train_time:123333ms step_avg:88.16ms +step:1400/1680 train_time:123423ms step_avg:88.16ms +step:1401/1680 train_time:123513ms step_avg:88.16ms +step:1402/1680 train_time:123603ms step_avg:88.16ms +step:1403/1680 train_time:123692ms step_avg:88.16ms +step:1404/1680 train_time:123780ms step_avg:88.16ms +step:1405/1680 train_time:123868ms step_avg:88.16ms +step:1406/1680 train_time:123956ms step_avg:88.16ms +step:1407/1680 train_time:124045ms step_avg:88.16ms +step:1408/1680 train_time:124134ms step_avg:88.16ms +step:1409/1680 train_time:124224ms step_avg:88.16ms +step:1410/1680 train_time:124313ms step_avg:88.17ms +step:1411/1680 train_time:124402ms step_avg:88.17ms +step:1412/1680 train_time:124492ms step_avg:88.17ms +step:1413/1680 train_time:124581ms step_avg:88.17ms +step:1414/1680 
train_time:124670ms step_avg:88.17ms +step:1415/1680 train_time:124759ms step_avg:88.17ms +step:1416/1680 train_time:124848ms step_avg:88.17ms +step:1417/1680 train_time:124937ms step_avg:88.17ms +step:1418/1680 train_time:125026ms step_avg:88.17ms +step:1419/1680 train_time:125114ms step_avg:88.17ms +step:1420/1680 train_time:125204ms step_avg:88.17ms +step:1421/1680 train_time:125293ms step_avg:88.17ms +step:1422/1680 train_time:125383ms step_avg:88.17ms +step:1423/1680 train_time:125472ms step_avg:88.17ms +step:1424/1680 train_time:125562ms step_avg:88.18ms +step:1425/1680 train_time:125650ms step_avg:88.18ms +step:1426/1680 train_time:125740ms step_avg:88.18ms +step:1427/1680 train_time:125828ms step_avg:88.18ms +step:1428/1680 train_time:125917ms step_avg:88.18ms +step:1429/1680 train_time:126006ms step_avg:88.18ms +step:1430/1680 train_time:126095ms step_avg:88.18ms +step:1431/1680 train_time:126184ms step_avg:88.18ms +step:1432/1680 train_time:126272ms step_avg:88.18ms +step:1433/1680 train_time:126362ms step_avg:88.18ms +step:1434/1680 train_time:126451ms step_avg:88.18ms +step:1435/1680 train_time:126542ms step_avg:88.18ms +step:1436/1680 train_time:126630ms step_avg:88.18ms +step:1437/1680 train_time:126720ms step_avg:88.18ms +step:1438/1680 train_time:126809ms step_avg:88.18ms +step:1439/1680 train_time:126897ms step_avg:88.18ms +step:1440/1680 train_time:126986ms step_avg:88.18ms +step:1441/1680 train_time:127075ms step_avg:88.19ms +step:1442/1680 train_time:127164ms step_avg:88.19ms +step:1443/1680 train_time:127253ms step_avg:88.19ms +step:1444/1680 train_time:127342ms step_avg:88.19ms +step:1445/1680 train_time:127431ms step_avg:88.19ms +step:1446/1680 train_time:127520ms step_avg:88.19ms +step:1447/1680 train_time:127609ms step_avg:88.19ms +step:1448/1680 train_time:127699ms step_avg:88.19ms +step:1449/1680 train_time:127787ms step_avg:88.19ms +step:1450/1680 train_time:127876ms step_avg:88.19ms +step:1451/1680 train_time:127966ms step_avg:88.19ms +step:1452/1680 train_time:128055ms step_avg:88.19ms +step:1453/1680 train_time:128144ms step_avg:88.19ms +step:1454/1680 train_time:128233ms step_avg:88.19ms +step:1455/1680 train_time:128323ms step_avg:88.19ms +step:1456/1680 train_time:128412ms step_avg:88.19ms +step:1457/1680 train_time:128501ms step_avg:88.20ms +step:1458/1680 train_time:128591ms step_avg:88.20ms +step:1459/1680 train_time:128681ms step_avg:88.20ms +step:1460/1680 train_time:128769ms step_avg:88.20ms +step:1461/1680 train_time:128859ms step_avg:88.20ms +step:1462/1680 train_time:128948ms step_avg:88.20ms +step:1463/1680 train_time:129037ms step_avg:88.20ms +step:1464/1680 train_time:129126ms step_avg:88.20ms +step:1465/1680 train_time:129215ms step_avg:88.20ms +step:1466/1680 train_time:129304ms step_avg:88.20ms +step:1467/1680 train_time:129393ms step_avg:88.20ms +step:1468/1680 train_time:129481ms step_avg:88.20ms +step:1469/1680 train_time:129571ms step_avg:88.20ms +step:1470/1680 train_time:129660ms step_avg:88.20ms +step:1471/1680 train_time:129750ms step_avg:88.21ms +step:1472/1680 train_time:129840ms step_avg:88.21ms +step:1473/1680 train_time:129929ms step_avg:88.21ms +step:1474/1680 train_time:130018ms step_avg:88.21ms +step:1475/1680 train_time:130107ms step_avg:88.21ms +step:1476/1680 train_time:130196ms step_avg:88.21ms +step:1477/1680 train_time:130286ms step_avg:88.21ms +step:1478/1680 train_time:130375ms step_avg:88.21ms +step:1479/1680 train_time:130463ms step_avg:88.21ms +step:1480/1680 train_time:130552ms step_avg:88.21ms +step:1481/1680 
train_time:130643ms step_avg:88.21ms +step:1482/1680 train_time:130732ms step_avg:88.21ms +step:1483/1680 train_time:130822ms step_avg:88.21ms +step:1484/1680 train_time:130911ms step_avg:88.21ms +step:1485/1680 train_time:131001ms step_avg:88.22ms +step:1486/1680 train_time:131090ms step_avg:88.22ms +step:1487/1680 train_time:131178ms step_avg:88.22ms +step:1488/1680 train_time:131268ms step_avg:88.22ms +step:1489/1680 train_time:131357ms step_avg:88.22ms +step:1490/1680 train_time:131447ms step_avg:88.22ms +step:1491/1680 train_time:131536ms step_avg:88.22ms +step:1492/1680 train_time:131625ms step_avg:88.22ms +step:1493/1680 train_time:131714ms step_avg:88.22ms +step:1494/1680 train_time:131803ms step_avg:88.22ms +step:1495/1680 train_time:131892ms step_avg:88.22ms +step:1496/1680 train_time:131982ms step_avg:88.22ms +step:1497/1680 train_time:132071ms step_avg:88.22ms +step:1498/1680 train_time:132160ms step_avg:88.22ms +step:1499/1680 train_time:132251ms step_avg:88.23ms +step:1500/1680 train_time:132340ms step_avg:88.23ms +step:1500/1680 val_loss:3.3137 train_time:132431ms step_avg:88.29ms +step:1501/1680 train_time:132449ms step_avg:88.24ms +step:1502/1680 train_time:132524ms step_avg:88.23ms +step:1503/1680 train_time:132618ms step_avg:88.24ms +step:1504/1680 train_time:132707ms step_avg:88.24ms +step:1505/1680 train_time:132796ms step_avg:88.24ms +step:1506/1680 train_time:132883ms step_avg:88.24ms +step:1507/1680 train_time:132972ms step_avg:88.24ms +step:1508/1680 train_time:133059ms step_avg:88.24ms +step:1509/1680 train_time:133147ms step_avg:88.24ms +step:1510/1680 train_time:133236ms step_avg:88.24ms +step:1511/1680 train_time:133324ms step_avg:88.24ms +step:1512/1680 train_time:133415ms step_avg:88.24ms +step:1513/1680 train_time:133506ms step_avg:88.24ms +step:1514/1680 train_time:133597ms step_avg:88.24ms +step:1515/1680 train_time:133688ms step_avg:88.24ms +step:1516/1680 train_time:133777ms step_avg:88.24ms +step:1517/1680 train_time:133866ms step_avg:88.24ms +step:1518/1680 train_time:133953ms step_avg:88.24ms +step:1519/1680 train_time:134042ms step_avg:88.24ms +step:1520/1680 train_time:134130ms step_avg:88.24ms +step:1521/1680 train_time:134218ms step_avg:88.24ms +step:1522/1680 train_time:134306ms step_avg:88.24ms +step:1523/1680 train_time:134395ms step_avg:88.24ms +step:1524/1680 train_time:134486ms step_avg:88.25ms +step:1525/1680 train_time:134577ms step_avg:88.25ms +step:1526/1680 train_time:134668ms step_avg:88.25ms +step:1527/1680 train_time:134757ms step_avg:88.25ms +step:1528/1680 train_time:134846ms step_avg:88.25ms +step:1529/1680 train_time:134935ms step_avg:88.25ms +step:1530/1680 train_time:135023ms step_avg:88.25ms +step:1531/1680 train_time:135111ms step_avg:88.25ms +step:1532/1680 train_time:135200ms step_avg:88.25ms +step:1533/1680 train_time:135289ms step_avg:88.25ms +step:1534/1680 train_time:135378ms step_avg:88.25ms +step:1535/1680 train_time:135467ms step_avg:88.25ms +step:1536/1680 train_time:135558ms step_avg:88.25ms +step:1537/1680 train_time:135649ms step_avg:88.26ms +step:1538/1680 train_time:135739ms step_avg:88.26ms +step:1539/1680 train_time:135828ms step_avg:88.26ms +step:1540/1680 train_time:135917ms step_avg:88.26ms +step:1541/1680 train_time:136006ms step_avg:88.26ms +step:1542/1680 train_time:136094ms step_avg:88.26ms +step:1543/1680 train_time:136184ms step_avg:88.26ms +step:1544/1680 train_time:136272ms step_avg:88.26ms +step:1545/1680 train_time:136361ms step_avg:88.26ms +step:1546/1680 train_time:136450ms step_avg:88.26ms 
+step:1547/1680 train_time:136540ms step_avg:88.26ms +step:1548/1680 train_time:136629ms step_avg:88.26ms +step:1549/1680 train_time:136719ms step_avg:88.26ms +step:1550/1680 train_time:136808ms step_avg:88.26ms +step:1551/1680 train_time:136898ms step_avg:88.26ms +step:1552/1680 train_time:136988ms step_avg:88.27ms +step:1553/1680 train_time:137076ms step_avg:88.27ms +step:1554/1680 train_time:137164ms step_avg:88.27ms +step:1555/1680 train_time:137253ms step_avg:88.27ms +step:1556/1680 train_time:137343ms step_avg:88.27ms +step:1557/1680 train_time:137431ms step_avg:88.27ms +step:1558/1680 train_time:137522ms step_avg:88.27ms +step:1559/1680 train_time:137611ms step_avg:88.27ms +step:1560/1680 train_time:137702ms step_avg:88.27ms +step:1561/1680 train_time:137791ms step_avg:88.27ms +step:1562/1680 train_time:137880ms step_avg:88.27ms +step:1563/1680 train_time:137969ms step_avg:88.27ms +step:1564/1680 train_time:138058ms step_avg:88.27ms +step:1565/1680 train_time:138146ms step_avg:88.27ms +step:1566/1680 train_time:138235ms step_avg:88.27ms +step:1567/1680 train_time:138324ms step_avg:88.27ms +step:1568/1680 train_time:138413ms step_avg:88.27ms +step:1569/1680 train_time:138501ms step_avg:88.27ms +step:1570/1680 train_time:138590ms step_avg:88.27ms +step:1571/1680 train_time:138680ms step_avg:88.27ms +step:1572/1680 train_time:138768ms step_avg:88.28ms +step:1573/1680 train_time:138858ms step_avg:88.28ms +step:1574/1680 train_time:138948ms step_avg:88.28ms +step:1575/1680 train_time:139037ms step_avg:88.28ms +step:1576/1680 train_time:139125ms step_avg:88.28ms +step:1577/1680 train_time:139214ms step_avg:88.28ms +step:1578/1680 train_time:139303ms step_avg:88.28ms +step:1579/1680 train_time:139392ms step_avg:88.28ms +step:1580/1680 train_time:139482ms step_avg:88.28ms +step:1581/1680 train_time:139571ms step_avg:88.28ms +step:1582/1680 train_time:139662ms step_avg:88.28ms +step:1583/1680 train_time:139751ms step_avg:88.28ms +step:1584/1680 train_time:139840ms step_avg:88.28ms +step:1585/1680 train_time:139929ms step_avg:88.28ms +step:1586/1680 train_time:140019ms step_avg:88.28ms +step:1587/1680 train_time:140108ms step_avg:88.28ms +step:1588/1680 train_time:140197ms step_avg:88.29ms +step:1589/1680 train_time:140287ms step_avg:88.29ms +step:1590/1680 train_time:140375ms step_avg:88.29ms +step:1591/1680 train_time:140464ms step_avg:88.29ms +step:1592/1680 train_time:140552ms step_avg:88.29ms +step:1593/1680 train_time:140642ms step_avg:88.29ms +step:1594/1680 train_time:140731ms step_avg:88.29ms +step:1595/1680 train_time:140820ms step_avg:88.29ms +step:1596/1680 train_time:140909ms step_avg:88.29ms +step:1597/1680 train_time:140999ms step_avg:88.29ms +step:1598/1680 train_time:141088ms step_avg:88.29ms +step:1599/1680 train_time:141178ms step_avg:88.29ms +step:1600/1680 train_time:141267ms step_avg:88.29ms +step:1601/1680 train_time:141356ms step_avg:88.29ms +step:1602/1680 train_time:141445ms step_avg:88.29ms +step:1603/1680 train_time:141532ms step_avg:88.29ms +step:1604/1680 train_time:141623ms step_avg:88.29ms +step:1605/1680 train_time:141712ms step_avg:88.29ms +step:1606/1680 train_time:141801ms step_avg:88.29ms +step:1607/1680 train_time:141889ms step_avg:88.29ms +step:1608/1680 train_time:141978ms step_avg:88.29ms +step:1609/1680 train_time:142068ms step_avg:88.30ms +step:1610/1680 train_time:142158ms step_avg:88.30ms +step:1611/1680 train_time:142247ms step_avg:88.30ms +step:1612/1680 train_time:142336ms step_avg:88.30ms +step:1613/1680 train_time:142426ms step_avg:88.30ms 
+step:1614/1680 train_time:142515ms step_avg:88.30ms +step:1615/1680 train_time:142605ms step_avg:88.30ms +step:1616/1680 train_time:142694ms step_avg:88.30ms +step:1617/1680 train_time:142784ms step_avg:88.30ms +step:1618/1680 train_time:142872ms step_avg:88.30ms +step:1619/1680 train_time:142961ms step_avg:88.30ms +step:1620/1680 train_time:143050ms step_avg:88.30ms +step:1621/1680 train_time:143140ms step_avg:88.30ms +step:1622/1680 train_time:143230ms step_avg:88.30ms +step:1623/1680 train_time:143320ms step_avg:88.31ms +step:1624/1680 train_time:143409ms step_avg:88.31ms +step:1625/1680 train_time:143498ms step_avg:88.31ms +step:1625/1680 val_loss:3.2902 train_time:143588ms step_avg:88.36ms +step:1626/1680 train_time:143606ms step_avg:88.32ms +step:1627/1680 train_time:143680ms step_avg:88.31ms +step:1628/1680 train_time:143773ms step_avg:88.31ms +step:1629/1680 train_time:143864ms step_avg:88.31ms +step:1630/1680 train_time:143953ms step_avg:88.31ms +step:1631/1680 train_time:144042ms step_avg:88.32ms +step:1632/1680 train_time:144130ms step_avg:88.32ms +step:1633/1680 train_time:144218ms step_avg:88.32ms +step:1634/1680 train_time:144307ms step_avg:88.32ms +step:1635/1680 train_time:144395ms step_avg:88.32ms +step:1636/1680 train_time:144484ms step_avg:88.32ms +step:1637/1680 train_time:144574ms step_avg:88.32ms +step:1638/1680 train_time:144665ms step_avg:88.32ms +step:1639/1680 train_time:144755ms step_avg:88.32ms +step:1640/1680 train_time:144845ms step_avg:88.32ms +step:1641/1680 train_time:144937ms step_avg:88.32ms +step:1642/1680 train_time:145026ms step_avg:88.32ms +step:1643/1680 train_time:145114ms step_avg:88.32ms +step:1644/1680 train_time:145203ms step_avg:88.32ms +step:1645/1680 train_time:145291ms step_avg:88.32ms +step:1646/1680 train_time:145379ms step_avg:88.32ms +step:1647/1680 train_time:145467ms step_avg:88.32ms +step:1648/1680 train_time:145556ms step_avg:88.32ms +step:1649/1680 train_time:145647ms step_avg:88.32ms +step:1650/1680 train_time:145736ms step_avg:88.33ms +step:1651/1680 train_time:145826ms step_avg:88.33ms +step:1652/1680 train_time:145917ms step_avg:88.33ms +step:1653/1680 train_time:146006ms step_avg:88.33ms +step:1654/1680 train_time:146094ms step_avg:88.33ms +step:1655/1680 train_time:146183ms step_avg:88.33ms +step:1656/1680 train_time:146271ms step_avg:88.33ms +step:1657/1680 train_time:146361ms step_avg:88.33ms +step:1658/1680 train_time:146449ms step_avg:88.33ms +step:1659/1680 train_time:146538ms step_avg:88.33ms +step:1660/1680 train_time:146628ms step_avg:88.33ms +step:1661/1680 train_time:146717ms step_avg:88.33ms +step:1662/1680 train_time:146807ms step_avg:88.33ms +step:1663/1680 train_time:146896ms step_avg:88.33ms +step:1664/1680 train_time:146986ms step_avg:88.33ms +step:1665/1680 train_time:147075ms step_avg:88.33ms +step:1666/1680 train_time:147164ms step_avg:88.33ms +step:1667/1680 train_time:147253ms step_avg:88.33ms +step:1668/1680 train_time:147342ms step_avg:88.33ms +step:1669/1680 train_time:147431ms step_avg:88.33ms +step:1670/1680 train_time:147521ms step_avg:88.34ms +step:1671/1680 train_time:147610ms step_avg:88.34ms +step:1672/1680 train_time:147699ms step_avg:88.34ms +step:1673/1680 train_time:147789ms step_avg:88.34ms +step:1674/1680 train_time:147879ms step_avg:88.34ms +step:1675/1680 train_time:147969ms step_avg:88.34ms +step:1676/1680 train_time:148059ms step_avg:88.34ms +step:1677/1680 train_time:148149ms step_avg:88.34ms +step:1678/1680 train_time:148238ms step_avg:88.34ms +step:1679/1680 train_time:148327ms 
step_avg:88.34ms +step:1680/1680 train_time:148416ms step_avg:88.34ms +step:1680/1680 val_loss:3.2805 train_time:148507ms step_avg:88.40ms +peak memory allocated: 30760 MiB reserved: 45934 MiB diff --git a/records/092725_BF16CE/e6622691-5ab5-4066-995d-41dada989dab.txt b/records/092725_BF16CE/e6622691-5ab5-4066-995d-41dada989dab.txt new file mode 100644 index 000000000..820cb3fa5 --- /dev/null +++ b/records/092725_BF16CE/e6622691-5ab5-4066-995d-41dada989dab.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
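+ # (Descriptive note: tiles strictly above the diagonal were skipped at the top of this kernel; + # each computed tile is stored twice below, once as-is and once mirrored across the diagonal, + # so C = A @ A.T costs roughly half the FLOPs of a general matmul.)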
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+ Empirically, though, small 1D params do fine here: + NS approximately performs a magnitude normalization of the grad, + and this hyper-optimized class runs faster than the current Adam impl for small params. + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized so that + (n_attn_layers+n_mlp_layers*2)%4==0, enabling batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param, 7 padding params) + 2. reduce scatter attn_gate (10 params, 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params, 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive attn params reshape them before NS + 8. wait on step 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """
+ def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size()==8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults)
+ + def generate_standard_param_groups(self, params): + """ + Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for the resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups
+ + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1,10,16,16] + assert len(params)==sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx+size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups
+ + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal.
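+ # The step is organized as a three-phase pipeline that overlaps communication with compute: + # (1) per group, launch an async reduce_scatter so each rank owns the averaged grads for its + # chunk of the stacked params; (2) per group, wait on the scatter, apply momentum and one + # batched Newton-Schulz on the local chunk only, and launch an async all_gather of the + # updated params; (3) wait on the gathers and copy results back into the live parameter tensors.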
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
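+ # (All updates below are applied to this buffer rather than to p itself, so the later + # all_gather ships fully updated parameters back to every rank in one shot.)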
+ updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + original_shape = batched_update_grads.shape + # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O + module_idx = start_idx if start_idx < len(params) else len(params) - 1 + if params[module_idx].module == 'attn': + batched_update_grads = batched_update_grads.view(-1, 4, original_shape[-2], original_shape[-1] // 4).flatten(end_dim=1) + v_chunk = newton_schulz_triton(batched_update_grads).reshape(original_shape)
+ + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val)
+ + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + )
+ + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True)
+ + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal
+ + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice)
+ + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq =
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module='attn' + with torch.no_grad(): + self.qkvo_w.view(4,self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4,self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4,self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4,self.hdim, self.dim)[3].type_as(y)) + return y
+ +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module='mlp' + self.c_proj.module='mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977
+ + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x
+ +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng; blocks.0 also gets attn=None, since layer 0 is skipped entirely in forward + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip the MLP block of the first layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None
+ + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x
+ +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
+ +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to the nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
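+ # FP8 head: the scales below are presumably chosen to keep activations, weights, and grads + # inside FP8's representable range for the custom nanogpt::mm op (e4m3fn tops out at 448, + # e5m2 at 57344); setting the DISABLE_FP8 env var falls back to the plain bf16 F.linear path + # in CastedLinear.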
+        use_fp8 = not os.environ.get("DISABLE_FP8", False)
+        self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448)
+        self.lm_head.weight.detach().zero_() # @Grad62304977
+        # Add learnable skip connection weights for decoder layers
+        assert num_layers % 2 == 0
+        pad = (-num_layers * 6) % dist.get_world_size()
+        self.scalars = nn.Parameter(
+            torch.cat(
+                [
+                    -1.5
+                    * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18
+                    *[
+                        torch.tensor([1.0, 0.0]) for _ in range(num_layers)
+                    ], # block lambdas
+                    *[
+                        torch.tensor([0.5, 0.5]) for _ in range(num_layers)
+                    ], # SA lambdas
+                    torch.zeros(num_layers), # extra zero params for smear_lambda
+                    torch.ones(pad),
+                ]
+            )
+        )
+        # set learning rates
+        for param in self.embed.parameters():
+            param.lr_mul = 75.0
+        for param in self.value_embeds.parameters():
+            param.lr_mul = 75.0
+        self.lm_head.weight.lr_mul = 1.0
+        self.scalars.lr_mul = 5.0
+
+    def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int):
+        assert input_seq.ndim == 1
+
+        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
+        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        short_bm = ws_short * args.block_size
+        long_bm = ws_long * args.block_size
+        bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = self.embed(input_seq)
+
+        # smear token embed forward 1 position @classiclarryd
+        smear_lambda = self.scalars[5 * len(self.blocks)]
+        smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)]))
+        x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]])
+        x = x0 = norm(x[None])
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        # skip layer zero
+        for i in range(1, len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                cos=self.yarn.cos,
+                sin=self.yarn.sin,
+                attn_scale=self.yarn.attn_scale
+            )
+            if i >= n and i < 11:
+                gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1)
+                x = x + gate * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x)
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0)
+        logits_for_loss = logits.float() if not self.training else logits
+        loss = F.cross_entropy(
+            logits_for_loss.view(-1, logits_for_loss.size(-1)),
+            target_seq,
+            reduction="sum" if self.training else "mean",
+        )
+        return loss
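+
+# The softcap above is a shifted tanh cap: since tanh(a) + 1 == 2*sigmoid(2*a),
+# 30*sigmoid(x/7.5) == 15*tanh(x/15) + 15, i.e. logits are squashed into (0, 30).
+# Minimal illustrative check (never called during the run):
+def _softcap_identity_check():
+    x = torch.linspace(-100.0, 100.0, steps=1001)
+    assert torch.allclose(torch.sigmoid(x / 7.5) * 30.0, 15.0 * torch.tanh(x / 15.0) + 15.0, atol=1e-4)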
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4) # skip the 1024-byte header
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan the first 4 million tokens, then kick off an async thread to scan the rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS tokens ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading the next shard and indexing BOS tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
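+
+# Illustrative sketch (not called during the run): on a toy shard, BOSFinder yields
+# document-aligned [start, end) ranges totalling num_tokens_local + 1 tokens, the
+# extra token supplying the shifted training target.
+def _bosfinder_example():
+    toy = torch.tensor([BOS_ID, 1, 2, BOS_ID, 3, 4, 5, BOS_ID, 6], dtype=torch.int32)
+    starts, ends = BOSFinder(toy).next_batch(num_tokens_local=6, max_seq_len=8)
+    assert starts == [[0, 3]] and ends == [[3, 7]] # two docs, 3 + 4 = 7 == 6 + 1 tokens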
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning of Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for the unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document was one token too long; account for the _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # lets the generator receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
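+
+# Illustrative sketch (hypothetical helper, never called in this run): because the
+# loop above assigns the value of `yield`, the batch geometry can be retargeted
+# mid-run; .send() both delivers the new triple and returns the next batch built
+# with the new sizes.
+def _resize_loader(loader, num_tokens: int, max_seq_len: int, grad_accum_steps: int):
+    return loader.send((num_tokens, max_seq_len, grad_accum_steps))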
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40 # number of iterations to continue training at the final cooldown and window size
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increased final validation ws, used for YaRN extension and short window size @classiclarryd
+    ws_long_validate: int = 20 # extend long windows out even further
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, (nn.Embedding, nn.Linear)):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+gate_params = [p for n, p in model.named_parameters() if "gate" in n]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = min(0.9999, step / args.num_iterations)
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+def get_ws(step: int):
+    if step == args.num_iterations + args.iteration_extension:
+        return args.ws_validate // 2, args.ws_validate
+    x = min(step / (1 + args.num_iterations), 0.9999)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx]
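+
+# Worked example (illustrative, never called during the run): the LR multiplier is
+# trapezoidal and the window schedule steps through thirds of training.
+def _schedule_preview():
+    # get_lr: flat at 1.0 for the first half, then linear decay toward 0.1
+    assert get_lr(0) == 1.0 and get_lr(820) == 1.0
+    assert abs(get_lr(1230) - 0.55) < 1e-6 # halfway through cooldown
+    # get_ws: (short, long) window sizes in blocks, stepping 3 -> 7 -> 11
+    assert get_ws(0) == (1, 3) and get_ws(600) == (3, 7) and get_ws(1200) == (5, 11)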
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws_long = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws_long > ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    elif new_ws_long < ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del initial_state
+
+########################################
+#        Training and validation       #
+########################################
+
+train_steps = args.num_iterations + args.iteration_extension
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    ws_short, ws_long = get_ws(step)
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        if last_step:
+            ws_long = args.ws_long_validate
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = torch.zeros((), device=device, dtype=torch.float32)
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    # each micro-batch accumulates into .grad; with world_size == 8, grad_accum_steps == 8 // world_size == 1
+    for _ in range(grad_accum_steps):
+        
inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Sep 27 12:37:57 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 30C P0 122W / 700W | 5856MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 27C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 24C P0 116W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 28C P0 122W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 29C P0 123W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 28C P0 115W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 30C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 26C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 160405 C /usr/bin/python 0MiB | +| 0 N/A N/A 160406 C /usr/bin/python 0MiB | +| 0 N/A N/A 160407 C /usr/bin/python 0MiB | +| 0 N/A N/A 160408 C /usr/bin/python 0MiB | +| 0 N/A N/A 160409 C /usr/bin/python 0MiB | +| 0 N/A N/A 160410 C /usr/bin/python 0MiB | +| 0 N/A N/A 160411 C /usr/bin/python 0MiB | +| 0 N/A N/A 160412 C /usr/bin/python 0MiB | +| 1 N/A N/A 160406 C /usr/bin/python 0MiB | +| 2 N/A N/A 160407 C /usr/bin/python 0MiB | +| 3 N/A N/A 160408 C /usr/bin/python 0MiB | +| 4 N/A N/A 160409 C /usr/bin/python 0MiB | +| 5 N/A N/A 160410 C /usr/bin/python 0MiB | +| 6 N/A N/A 160411 C /usr/bin/python 0MiB | +| 7 N/A N/A 160412 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1680 train_time:149ms step_avg:149.48ms +step:2/1680 train_time:169ms step_avg:84.48ms +step:3/1680 train_time:236ms step_avg:78.63ms +step:4/1680 train_time:321ms step_avg:80.22ms +step:5/1680 train_time:407ms step_avg:81.38ms +step:6/1680 train_time:493ms step_avg:82.16ms +step:7/1680 train_time:579ms step_avg:82.71ms +step:8/1680 train_time:665ms step_avg:83.17ms +step:9/1680 train_time:752ms step_avg:83.53ms +step:10/1680 train_time:838ms step_avg:83.82ms +step:11/1680 train_time:925ms step_avg:84.05ms +step:12/1680 train_time:1012ms step_avg:84.34ms +step:13/1680 train_time:1104ms step_avg:84.95ms +step:14/1680 train_time:1193ms step_avg:85.23ms +step:15/1680 train_time:1281ms step_avg:85.39ms +step:16/1680 train_time:1369ms step_avg:85.55ms +step:17/1680 train_time:1455ms step_avg:85.61ms +step:18/1680 train_time:1542ms step_avg:85.64ms +step:19/1680 train_time:1628ms step_avg:85.68ms +step:20/1680 train_time:1715ms step_avg:85.73ms +step:21/1680 train_time:1801ms step_avg:85.78ms +step:22/1680 train_time:1888ms step_avg:85.82ms +step:23/1680 train_time:1976ms step_avg:85.92ms +step:24/1680 train_time:2065ms step_avg:86.06ms +step:25/1680 train_time:2154ms step_avg:86.17ms +step:26/1680 train_time:2243ms step_avg:86.26ms +step:27/1680 train_time:2330ms step_avg:86.30ms +step:28/1680 train_time:2418ms step_avg:86.35ms +step:29/1680 train_time:2505ms step_avg:86.38ms +step:30/1680 train_time:2592ms step_avg:86.41ms +step:31/1680 train_time:2679ms step_avg:86.42ms +step:32/1680 train_time:2765ms step_avg:86.41ms +step:33/1680 train_time:2852ms step_avg:86.44ms +step:34/1680 train_time:2939ms step_avg:86.45ms +step:35/1680 train_time:3028ms step_avg:86.51ms +step:36/1680 train_time:3117ms step_avg:86.57ms +step:37/1680 train_time:3205ms step_avg:86.63ms +step:38/1680 train_time:3294ms step_avg:86.67ms +step:39/1680 train_time:3381ms step_avg:86.68ms +step:40/1680 train_time:3468ms step_avg:86.71ms +step:41/1680 train_time:3556ms step_avg:86.73ms +step:42/1680 train_time:3643ms step_avg:86.74ms +step:43/1680 train_time:3729ms step_avg:86.73ms +step:44/1680 train_time:3817ms step_avg:86.74ms +step:45/1680 train_time:3905ms step_avg:86.78ms +step:46/1680 train_time:3993ms step_avg:86.80ms +step:47/1680 
train_time:4082ms step_avg:86.85ms +step:48/1680 train_time:4170ms step_avg:86.88ms +step:49/1680 train_time:4258ms step_avg:86.91ms +step:50/1680 train_time:4346ms step_avg:86.93ms +step:51/1680 train_time:4434ms step_avg:86.94ms +step:52/1680 train_time:4521ms step_avg:86.94ms +step:53/1680 train_time:4608ms step_avg:86.94ms +step:54/1680 train_time:4695ms step_avg:86.94ms +step:55/1680 train_time:4781ms step_avg:86.93ms +step:56/1680 train_time:4869ms step_avg:86.94ms +step:57/1680 train_time:4956ms step_avg:86.95ms +step:58/1680 train_time:5045ms step_avg:86.98ms +step:59/1680 train_time:5132ms step_avg:86.99ms +step:60/1680 train_time:5221ms step_avg:87.01ms +step:61/1680 train_time:5308ms step_avg:87.01ms +step:62/1680 train_time:5395ms step_avg:87.02ms +step:63/1680 train_time:5483ms step_avg:87.04ms +step:64/1680 train_time:5570ms step_avg:87.03ms +step:65/1680 train_time:5657ms step_avg:87.03ms +step:66/1680 train_time:5743ms step_avg:87.02ms +step:67/1680 train_time:5830ms step_avg:87.02ms +step:68/1680 train_time:5918ms step_avg:87.03ms +step:69/1680 train_time:6006ms step_avg:87.04ms +step:70/1680 train_time:6093ms step_avg:87.04ms +step:71/1680 train_time:6181ms step_avg:87.05ms +step:72/1680 train_time:6267ms step_avg:87.05ms +step:73/1680 train_time:6355ms step_avg:87.05ms +step:74/1680 train_time:6442ms step_avg:87.06ms +step:75/1680 train_time:6530ms step_avg:87.06ms +step:76/1680 train_time:6617ms step_avg:87.06ms +step:77/1680 train_time:6704ms step_avg:87.06ms +step:78/1680 train_time:6790ms step_avg:87.05ms +step:79/1680 train_time:6878ms step_avg:87.06ms +step:80/1680 train_time:6965ms step_avg:87.07ms +step:81/1680 train_time:7052ms step_avg:87.07ms +step:82/1680 train_time:7140ms step_avg:87.07ms +step:83/1680 train_time:7228ms step_avg:87.08ms +step:84/1680 train_time:7315ms step_avg:87.09ms +step:85/1680 train_time:7402ms step_avg:87.09ms +step:86/1680 train_time:7489ms step_avg:87.09ms +step:87/1680 train_time:7577ms step_avg:87.09ms +step:88/1680 train_time:7664ms step_avg:87.09ms +step:89/1680 train_time:7751ms step_avg:87.09ms +step:90/1680 train_time:7838ms step_avg:87.09ms +step:91/1680 train_time:7926ms step_avg:87.10ms +step:92/1680 train_time:8013ms step_avg:87.10ms +step:93/1680 train_time:8100ms step_avg:87.10ms +step:94/1680 train_time:8187ms step_avg:87.09ms +step:95/1680 train_time:8274ms step_avg:87.09ms +step:96/1680 train_time:8361ms step_avg:87.09ms +step:97/1680 train_time:8448ms step_avg:87.09ms +step:98/1680 train_time:8535ms step_avg:87.09ms +step:99/1680 train_time:8623ms step_avg:87.10ms +step:100/1680 train_time:8710ms step_avg:87.10ms +step:101/1680 train_time:8799ms step_avg:87.11ms +step:102/1680 train_time:8885ms step_avg:87.11ms +step:103/1680 train_time:8972ms step_avg:87.11ms +step:104/1680 train_time:9059ms step_avg:87.11ms +step:105/1680 train_time:9147ms step_avg:87.11ms +step:106/1680 train_time:9234ms step_avg:87.12ms +step:107/1680 train_time:9321ms step_avg:87.11ms +step:108/1680 train_time:9408ms step_avg:87.11ms +step:109/1680 train_time:9497ms step_avg:87.12ms +step:110/1680 train_time:9584ms step_avg:87.13ms +step:111/1680 train_time:9671ms step_avg:87.13ms +step:112/1680 train_time:9759ms step_avg:87.13ms +step:113/1680 train_time:9845ms step_avg:87.13ms +step:114/1680 train_time:9933ms step_avg:87.13ms +step:115/1680 train_time:10020ms step_avg:87.13ms +step:116/1680 train_time:10107ms step_avg:87.13ms +step:117/1680 train_time:10195ms step_avg:87.14ms +step:118/1680 train_time:10282ms step_avg:87.14ms +step:119/1680 
train_time:10369ms step_avg:87.14ms +step:120/1680 train_time:10456ms step_avg:87.14ms +step:121/1680 train_time:10543ms step_avg:87.13ms +step:122/1680 train_time:10630ms step_avg:87.13ms +step:123/1680 train_time:10717ms step_avg:87.13ms +step:124/1680 train_time:10805ms step_avg:87.13ms +step:125/1680 train_time:10891ms step_avg:87.13ms +step:125/1680 val_loss:4.3060 train_time:10980ms step_avg:87.84ms +step:126/1680 train_time:11004ms step_avg:87.34ms +step:127/1680 train_time:11069ms step_avg:87.16ms +step:128/1680 train_time:11163ms step_avg:87.21ms +step:129/1680 train_time:11253ms step_avg:87.24ms +step:130/1680 train_time:11341ms step_avg:87.24ms +step:131/1680 train_time:11427ms step_avg:87.23ms +step:132/1680 train_time:11514ms step_avg:87.23ms +step:133/1680 train_time:11600ms step_avg:87.22ms +step:134/1680 train_time:11686ms step_avg:87.21ms +step:135/1680 train_time:11772ms step_avg:87.20ms +step:136/1680 train_time:11858ms step_avg:87.19ms +step:137/1680 train_time:11945ms step_avg:87.19ms +step:138/1680 train_time:12033ms step_avg:87.19ms +step:139/1680 train_time:12123ms step_avg:87.21ms +step:140/1680 train_time:12212ms step_avg:87.23ms +step:141/1680 train_time:12300ms step_avg:87.24ms +step:142/1680 train_time:12388ms step_avg:87.24ms +step:143/1680 train_time:12475ms step_avg:87.23ms +step:144/1680 train_time:12561ms step_avg:87.23ms +step:145/1680 train_time:12647ms step_avg:87.22ms +step:146/1680 train_time:12733ms step_avg:87.21ms +step:147/1680 train_time:12819ms step_avg:87.20ms +step:148/1680 train_time:12905ms step_avg:87.20ms +step:149/1680 train_time:12993ms step_avg:87.20ms +step:150/1680 train_time:13081ms step_avg:87.20ms +step:151/1680 train_time:13170ms step_avg:87.22ms +step:152/1680 train_time:13258ms step_avg:87.22ms +step:153/1680 train_time:13346ms step_avg:87.23ms +step:154/1680 train_time:13433ms step_avg:87.23ms +step:155/1680 train_time:13520ms step_avg:87.23ms +step:156/1680 train_time:13607ms step_avg:87.22ms +step:157/1680 train_time:13693ms step_avg:87.22ms +step:158/1680 train_time:13780ms step_avg:87.21ms +step:159/1680 train_time:13866ms step_avg:87.21ms +step:160/1680 train_time:13953ms step_avg:87.21ms +step:161/1680 train_time:14041ms step_avg:87.21ms +step:162/1680 train_time:14129ms step_avg:87.22ms +step:163/1680 train_time:14217ms step_avg:87.22ms +step:164/1680 train_time:14305ms step_avg:87.23ms +step:165/1680 train_time:14393ms step_avg:87.23ms +step:166/1680 train_time:14481ms step_avg:87.23ms +step:167/1680 train_time:14567ms step_avg:87.23ms +step:168/1680 train_time:14654ms step_avg:87.23ms +step:169/1680 train_time:14741ms step_avg:87.23ms +step:170/1680 train_time:14827ms step_avg:87.22ms +step:171/1680 train_time:14914ms step_avg:87.22ms +step:172/1680 train_time:15001ms step_avg:87.22ms +step:173/1680 train_time:15089ms step_avg:87.22ms +step:174/1680 train_time:15176ms step_avg:87.22ms +step:175/1680 train_time:15264ms step_avg:87.22ms +step:176/1680 train_time:15352ms step_avg:87.23ms +step:177/1680 train_time:15439ms step_avg:87.23ms +step:178/1680 train_time:15526ms step_avg:87.23ms +step:179/1680 train_time:15613ms step_avg:87.22ms +step:180/1680 train_time:15700ms step_avg:87.22ms +step:181/1680 train_time:15786ms step_avg:87.22ms +step:182/1680 train_time:15874ms step_avg:87.22ms +step:183/1680 train_time:15961ms step_avg:87.22ms +step:184/1680 train_time:16048ms step_avg:87.22ms +step:185/1680 train_time:16135ms step_avg:87.22ms +step:186/1680 train_time:16222ms step_avg:87.22ms +step:187/1680 train_time:16310ms 
step_avg:87.22ms +step:188/1680 train_time:16397ms step_avg:87.22ms +step:189/1680 train_time:16486ms step_avg:87.22ms +step:190/1680 train_time:16573ms step_avg:87.23ms +step:191/1680 train_time:16660ms step_avg:87.22ms +step:192/1680 train_time:16746ms step_avg:87.22ms +step:193/1680 train_time:16833ms step_avg:87.22ms +step:194/1680 train_time:16920ms step_avg:87.22ms +step:195/1680 train_time:17007ms step_avg:87.22ms +step:196/1680 train_time:17095ms step_avg:87.22ms +step:197/1680 train_time:17183ms step_avg:87.22ms +step:198/1680 train_time:17270ms step_avg:87.22ms +step:199/1680 train_time:17357ms step_avg:87.22ms +step:200/1680 train_time:17445ms step_avg:87.22ms +step:201/1680 train_time:17533ms step_avg:87.23ms +step:202/1680 train_time:17620ms step_avg:87.23ms +step:203/1680 train_time:17706ms step_avg:87.22ms +step:204/1680 train_time:17794ms step_avg:87.22ms +step:205/1680 train_time:17881ms step_avg:87.22ms +step:206/1680 train_time:17968ms step_avg:87.22ms +step:207/1680 train_time:18056ms step_avg:87.23ms +step:208/1680 train_time:18143ms step_avg:87.23ms +step:209/1680 train_time:18230ms step_avg:87.23ms +step:210/1680 train_time:18317ms step_avg:87.23ms +step:211/1680 train_time:18404ms step_avg:87.22ms +step:212/1680 train_time:18492ms step_avg:87.22ms +step:213/1680 train_time:18579ms step_avg:87.23ms +step:214/1680 train_time:18666ms step_avg:87.22ms +step:215/1680 train_time:18753ms step_avg:87.22ms +step:216/1680 train_time:18841ms step_avg:87.23ms +step:217/1680 train_time:18928ms step_avg:87.22ms +step:218/1680 train_time:19015ms step_avg:87.22ms +step:219/1680 train_time:19102ms step_avg:87.22ms +step:220/1680 train_time:19189ms step_avg:87.22ms +step:221/1680 train_time:19277ms step_avg:87.23ms +step:222/1680 train_time:19364ms step_avg:87.22ms +step:223/1680 train_time:19451ms step_avg:87.23ms +step:224/1680 train_time:19539ms step_avg:87.23ms +step:225/1680 train_time:19626ms step_avg:87.23ms +step:226/1680 train_time:19713ms step_avg:87.23ms +step:227/1680 train_time:19801ms step_avg:87.23ms +step:228/1680 train_time:19888ms step_avg:87.23ms +step:229/1680 train_time:19975ms step_avg:87.23ms +step:230/1680 train_time:20062ms step_avg:87.22ms +step:231/1680 train_time:20149ms step_avg:87.23ms +step:232/1680 train_time:20236ms step_avg:87.23ms +step:233/1680 train_time:20323ms step_avg:87.23ms +step:234/1680 train_time:20411ms step_avg:87.22ms +step:235/1680 train_time:20498ms step_avg:87.22ms +step:236/1680 train_time:20585ms step_avg:87.22ms +step:237/1680 train_time:20672ms step_avg:87.22ms +step:238/1680 train_time:20759ms step_avg:87.22ms +step:239/1680 train_time:20846ms step_avg:87.22ms +step:240/1680 train_time:20933ms step_avg:87.22ms +step:241/1680 train_time:21020ms step_avg:87.22ms +step:242/1680 train_time:21107ms step_avg:87.22ms +step:243/1680 train_time:21195ms step_avg:87.22ms +step:244/1680 train_time:21282ms step_avg:87.22ms +step:245/1680 train_time:21369ms step_avg:87.22ms +step:246/1680 train_time:21457ms step_avg:87.22ms +step:247/1680 train_time:21544ms step_avg:87.22ms +step:248/1680 train_time:21631ms step_avg:87.22ms +step:249/1680 train_time:21718ms step_avg:87.22ms +step:250/1680 train_time:21805ms step_avg:87.22ms +step:250/1680 val_loss:3.9809 train_time:21894ms step_avg:87.58ms +step:251/1680 train_time:21917ms step_avg:87.32ms +step:252/1680 train_time:21984ms step_avg:87.24ms +step:253/1680 train_time:22072ms step_avg:87.24ms +step:254/1680 train_time:22160ms step_avg:87.24ms +step:255/1680 train_time:22247ms step_avg:87.24ms 
+step:256/1680 train_time:22333ms step_avg:87.24ms +step:257/1680 train_time:22420ms step_avg:87.24ms +step:258/1680 train_time:22506ms step_avg:87.23ms +step:259/1680 train_time:22592ms step_avg:87.23ms +step:260/1680 train_time:22679ms step_avg:87.23ms +step:261/1680 train_time:22765ms step_avg:87.22ms +step:262/1680 train_time:22853ms step_avg:87.23ms +step:263/1680 train_time:22944ms step_avg:87.24ms +step:264/1680 train_time:23031ms step_avg:87.24ms +step:265/1680 train_time:23119ms step_avg:87.24ms +step:266/1680 train_time:23208ms step_avg:87.25ms +step:267/1680 train_time:23294ms step_avg:87.24ms +step:268/1680 train_time:23381ms step_avg:87.24ms +step:269/1680 train_time:23468ms step_avg:87.24ms +step:270/1680 train_time:23555ms step_avg:87.24ms +step:271/1680 train_time:23641ms step_avg:87.24ms +step:272/1680 train_time:23728ms step_avg:87.23ms +step:273/1680 train_time:23815ms step_avg:87.24ms +step:274/1680 train_time:23903ms step_avg:87.24ms +step:275/1680 train_time:23991ms step_avg:87.24ms +step:276/1680 train_time:24079ms step_avg:87.24ms +step:277/1680 train_time:24167ms step_avg:87.24ms +step:278/1680 train_time:24254ms step_avg:87.25ms +step:279/1680 train_time:24341ms step_avg:87.24ms +step:280/1680 train_time:24428ms step_avg:87.24ms +step:281/1680 train_time:24515ms step_avg:87.24ms +step:282/1680 train_time:24602ms step_avg:87.24ms +step:283/1680 train_time:24689ms step_avg:87.24ms +step:284/1680 train_time:24776ms step_avg:87.24ms +step:285/1680 train_time:24864ms step_avg:87.24ms +step:286/1680 train_time:24953ms step_avg:87.25ms +step:287/1680 train_time:25040ms step_avg:87.25ms +step:288/1680 train_time:25127ms step_avg:87.25ms +step:289/1680 train_time:25216ms step_avg:87.25ms +step:290/1680 train_time:25304ms step_avg:87.25ms +step:291/1680 train_time:25390ms step_avg:87.25ms +step:292/1680 train_time:25477ms step_avg:87.25ms +step:293/1680 train_time:25564ms step_avg:87.25ms +step:294/1680 train_time:25650ms step_avg:87.25ms +step:295/1680 train_time:25737ms step_avg:87.24ms +step:296/1680 train_time:25824ms step_avg:87.24ms +step:297/1680 train_time:25911ms step_avg:87.24ms +step:298/1680 train_time:25999ms step_avg:87.25ms +step:299/1680 train_time:26087ms step_avg:87.25ms +step:300/1680 train_time:26174ms step_avg:87.25ms +step:301/1680 train_time:26262ms step_avg:87.25ms +step:302/1680 train_time:26350ms step_avg:87.25ms +step:303/1680 train_time:26437ms step_avg:87.25ms +step:304/1680 train_time:26525ms step_avg:87.25ms +step:305/1680 train_time:26611ms step_avg:87.25ms +step:306/1680 train_time:26698ms step_avg:87.25ms +step:307/1680 train_time:26785ms step_avg:87.25ms +step:308/1680 train_time:26872ms step_avg:87.25ms +step:309/1680 train_time:26959ms step_avg:87.25ms +step:310/1680 train_time:27047ms step_avg:87.25ms +step:311/1680 train_time:27135ms step_avg:87.25ms +step:312/1680 train_time:27223ms step_avg:87.25ms +step:313/1680 train_time:27309ms step_avg:87.25ms +step:314/1680 train_time:27397ms step_avg:87.25ms +step:315/1680 train_time:27484ms step_avg:87.25ms +step:316/1680 train_time:27571ms step_avg:87.25ms +step:317/1680 train_time:27658ms step_avg:87.25ms +step:318/1680 train_time:27746ms step_avg:87.25ms +step:319/1680 train_time:27833ms step_avg:87.25ms +step:320/1680 train_time:27920ms step_avg:87.25ms +step:321/1680 train_time:28008ms step_avg:87.25ms +step:322/1680 train_time:28095ms step_avg:87.25ms +step:323/1680 train_time:28182ms step_avg:87.25ms +step:324/1680 train_time:28269ms step_avg:87.25ms +step:325/1680 train_time:28357ms 
step_avg:87.25ms +step:326/1680 train_time:28445ms step_avg:87.26ms +step:327/1680 train_time:28532ms step_avg:87.25ms +step:328/1680 train_time:28619ms step_avg:87.25ms +step:329/1680 train_time:28706ms step_avg:87.25ms +step:330/1680 train_time:28793ms step_avg:87.25ms +step:331/1680 train_time:28880ms step_avg:87.25ms +step:332/1680 train_time:28968ms step_avg:87.25ms +step:333/1680 train_time:29055ms step_avg:87.25ms +step:334/1680 train_time:29143ms step_avg:87.25ms +step:335/1680 train_time:29230ms step_avg:87.25ms +step:336/1680 train_time:29318ms step_avg:87.26ms +step:337/1680 train_time:29406ms step_avg:87.26ms +step:338/1680 train_time:29492ms step_avg:87.26ms +step:339/1680 train_time:29579ms step_avg:87.26ms +step:340/1680 train_time:29667ms step_avg:87.26ms +step:341/1680 train_time:29754ms step_avg:87.25ms +step:342/1680 train_time:29841ms step_avg:87.25ms +step:343/1680 train_time:29928ms step_avg:87.25ms +step:344/1680 train_time:30015ms step_avg:87.25ms +step:345/1680 train_time:30103ms step_avg:87.25ms +step:346/1680 train_time:30190ms step_avg:87.26ms +step:347/1680 train_time:30278ms step_avg:87.26ms +step:348/1680 train_time:30365ms step_avg:87.26ms +step:349/1680 train_time:30452ms step_avg:87.26ms +step:350/1680 train_time:30539ms step_avg:87.26ms +step:351/1680 train_time:30626ms step_avg:87.25ms +step:352/1680 train_time:30714ms step_avg:87.25ms +step:353/1680 train_time:30801ms step_avg:87.25ms +step:354/1680 train_time:30888ms step_avg:87.25ms +step:355/1680 train_time:30975ms step_avg:87.25ms +step:356/1680 train_time:31063ms step_avg:87.26ms +step:357/1680 train_time:31150ms step_avg:87.26ms +step:358/1680 train_time:31238ms step_avg:87.26ms +step:359/1680 train_time:31325ms step_avg:87.26ms +step:360/1680 train_time:31412ms step_avg:87.26ms +step:361/1680 train_time:31500ms step_avg:87.26ms +step:362/1680 train_time:31587ms step_avg:87.26ms +step:363/1680 train_time:31674ms step_avg:87.26ms +step:364/1680 train_time:31761ms step_avg:87.26ms +step:365/1680 train_time:31848ms step_avg:87.25ms +step:366/1680 train_time:31935ms step_avg:87.25ms +step:367/1680 train_time:32022ms step_avg:87.25ms +step:368/1680 train_time:32109ms step_avg:87.25ms +step:369/1680 train_time:32197ms step_avg:87.25ms +step:370/1680 train_time:32284ms step_avg:87.25ms +step:371/1680 train_time:32371ms step_avg:87.25ms +step:372/1680 train_time:32458ms step_avg:87.25ms +step:373/1680 train_time:32545ms step_avg:87.25ms +step:374/1680 train_time:32633ms step_avg:87.25ms +step:375/1680 train_time:32720ms step_avg:87.25ms +step:375/1680 val_loss:3.8227 train_time:32809ms step_avg:87.49ms +step:376/1680 train_time:32830ms step_avg:87.31ms +step:377/1680 train_time:32901ms step_avg:87.27ms +step:378/1680 train_time:32990ms step_avg:87.27ms +step:379/1680 train_time:33077ms step_avg:87.27ms +step:380/1680 train_time:33163ms step_avg:87.27ms +step:381/1680 train_time:33250ms step_avg:87.27ms +step:382/1680 train_time:33336ms step_avg:87.27ms +step:383/1680 train_time:33423ms step_avg:87.27ms +step:384/1680 train_time:33509ms step_avg:87.26ms +step:385/1680 train_time:33595ms step_avg:87.26ms +step:386/1680 train_time:33682ms step_avg:87.26ms +step:387/1680 train_time:33770ms step_avg:87.26ms +step:388/1680 train_time:33859ms step_avg:87.27ms +step:389/1680 train_time:33948ms step_avg:87.27ms +step:390/1680 train_time:34037ms step_avg:87.27ms +step:391/1680 train_time:34124ms step_avg:87.27ms +step:392/1680 train_time:34211ms step_avg:87.27ms +step:393/1680 train_time:34299ms step_avg:87.27ms 
+step:394/1680 train_time:34385ms step_avg:87.27ms +step:395/1680 train_time:34471ms step_avg:87.27ms +step:396/1680 train_time:34558ms step_avg:87.27ms +step:397/1680 train_time:34645ms step_avg:87.27ms +step:398/1680 train_time:34731ms step_avg:87.26ms +step:399/1680 train_time:34819ms step_avg:87.26ms +step:400/1680 train_time:34907ms step_avg:87.27ms +step:401/1680 train_time:34995ms step_avg:87.27ms +step:402/1680 train_time:35084ms step_avg:87.27ms +step:403/1680 train_time:35171ms step_avg:87.27ms +step:404/1680 train_time:35259ms step_avg:87.27ms +step:405/1680 train_time:35345ms step_avg:87.27ms +step:406/1680 train_time:35432ms step_avg:87.27ms +step:407/1680 train_time:35518ms step_avg:87.27ms +step:408/1680 train_time:35605ms step_avg:87.27ms +step:409/1680 train_time:35692ms step_avg:87.27ms +step:410/1680 train_time:35779ms step_avg:87.27ms +step:411/1680 train_time:35868ms step_avg:87.27ms +step:412/1680 train_time:35956ms step_avg:87.27ms +step:413/1680 train_time:36043ms step_avg:87.27ms +step:414/1680 train_time:36131ms step_avg:87.27ms +step:415/1680 train_time:36219ms step_avg:87.27ms +step:416/1680 train_time:36306ms step_avg:87.27ms +step:417/1680 train_time:36393ms step_avg:87.27ms +step:418/1680 train_time:36480ms step_avg:87.27ms +step:419/1680 train_time:36567ms step_avg:87.27ms +step:420/1680 train_time:36653ms step_avg:87.27ms +step:421/1680 train_time:36740ms step_avg:87.27ms +step:422/1680 train_time:36828ms step_avg:87.27ms +step:423/1680 train_time:36916ms step_avg:87.27ms +step:424/1680 train_time:37003ms step_avg:87.27ms +step:425/1680 train_time:37091ms step_avg:87.27ms +step:426/1680 train_time:37178ms step_avg:87.27ms +step:427/1680 train_time:37265ms step_avg:87.27ms +step:428/1680 train_time:37353ms step_avg:87.27ms +step:429/1680 train_time:37440ms step_avg:87.27ms +step:430/1680 train_time:37527ms step_avg:87.27ms +step:431/1680 train_time:37614ms step_avg:87.27ms +step:432/1680 train_time:37700ms step_avg:87.27ms +step:433/1680 train_time:37787ms step_avg:87.27ms +step:434/1680 train_time:37875ms step_avg:87.27ms +step:435/1680 train_time:37962ms step_avg:87.27ms +step:436/1680 train_time:38050ms step_avg:87.27ms +step:437/1680 train_time:38138ms step_avg:87.27ms +step:438/1680 train_time:38225ms step_avg:87.27ms +step:439/1680 train_time:38312ms step_avg:87.27ms +step:440/1680 train_time:38399ms step_avg:87.27ms +step:441/1680 train_time:38487ms step_avg:87.27ms +step:442/1680 train_time:38574ms step_avg:87.27ms +step:443/1680 train_time:38661ms step_avg:87.27ms +step:444/1680 train_time:38748ms step_avg:87.27ms +step:445/1680 train_time:38835ms step_avg:87.27ms +step:446/1680 train_time:38922ms step_avg:87.27ms +step:447/1680 train_time:39010ms step_avg:87.27ms +step:448/1680 train_time:39098ms step_avg:87.27ms +step:449/1680 train_time:39185ms step_avg:87.27ms +step:450/1680 train_time:39272ms step_avg:87.27ms +step:451/1680 train_time:39360ms step_avg:87.27ms +step:452/1680 train_time:39447ms step_avg:87.27ms +step:453/1680 train_time:39535ms step_avg:87.27ms +step:454/1680 train_time:39621ms step_avg:87.27ms +step:455/1680 train_time:39709ms step_avg:87.27ms +step:456/1680 train_time:39796ms step_avg:87.27ms +step:457/1680 train_time:39883ms step_avg:87.27ms +step:458/1680 train_time:39970ms step_avg:87.27ms +step:459/1680 train_time:40058ms step_avg:87.27ms +step:460/1680 train_time:40145ms step_avg:87.27ms +step:461/1680 train_time:40232ms step_avg:87.27ms +step:462/1680 train_time:40319ms step_avg:87.27ms +step:463/1680 train_time:40407ms 
step_avg:87.27ms +step:464/1680 train_time:40495ms step_avg:87.27ms +step:465/1680 train_time:40581ms step_avg:87.27ms +step:466/1680 train_time:40668ms step_avg:87.27ms +step:467/1680 train_time:40755ms step_avg:87.27ms +step:468/1680 train_time:40843ms step_avg:87.27ms +step:469/1680 train_time:40930ms step_avg:87.27ms +step:470/1680 train_time:41018ms step_avg:87.27ms +step:471/1680 train_time:41105ms step_avg:87.27ms +step:472/1680 train_time:41192ms step_avg:87.27ms +step:473/1680 train_time:41280ms step_avg:87.27ms +step:474/1680 train_time:41367ms step_avg:87.27ms +step:475/1680 train_time:41454ms step_avg:87.27ms +step:476/1680 train_time:41541ms step_avg:87.27ms +step:477/1680 train_time:41629ms step_avg:87.27ms +step:478/1680 train_time:41716ms step_avg:87.27ms +step:479/1680 train_time:41804ms step_avg:87.27ms +step:480/1680 train_time:41890ms step_avg:87.27ms +step:481/1680 train_time:41978ms step_avg:87.27ms +step:482/1680 train_time:42066ms step_avg:87.27ms +step:483/1680 train_time:42153ms step_avg:87.27ms +step:484/1680 train_time:42240ms step_avg:87.27ms +step:485/1680 train_time:42328ms step_avg:87.27ms +step:486/1680 train_time:42415ms step_avg:87.27ms +step:487/1680 train_time:42502ms step_avg:87.27ms +step:488/1680 train_time:42589ms step_avg:87.27ms +step:489/1680 train_time:42676ms step_avg:87.27ms +step:490/1680 train_time:42763ms step_avg:87.27ms +step:491/1680 train_time:42850ms step_avg:87.27ms +step:492/1680 train_time:42937ms step_avg:87.27ms +step:493/1680 train_time:43024ms step_avg:87.27ms +step:494/1680 train_time:43112ms step_avg:87.27ms +step:495/1680 train_time:43199ms step_avg:87.27ms +step:496/1680 train_time:43287ms step_avg:87.27ms +step:497/1680 train_time:43374ms step_avg:87.27ms +step:498/1680 train_time:43461ms step_avg:87.27ms +step:499/1680 train_time:43548ms step_avg:87.27ms +step:500/1680 train_time:43635ms step_avg:87.27ms +step:500/1680 val_loss:3.7186 train_time:43724ms step_avg:87.45ms +step:501/1680 train_time:43745ms step_avg:87.31ms +step:502/1680 train_time:43812ms step_avg:87.27ms +step:503/1680 train_time:43904ms step_avg:87.28ms +step:504/1680 train_time:43993ms step_avg:87.29ms +step:505/1680 train_time:44079ms step_avg:87.29ms +step:506/1680 train_time:44165ms step_avg:87.28ms +step:507/1680 train_time:44251ms step_avg:87.28ms +step:508/1680 train_time:44338ms step_avg:87.28ms +step:509/1680 train_time:44425ms step_avg:87.28ms +step:510/1680 train_time:44512ms step_avg:87.28ms +step:511/1680 train_time:44598ms step_avg:87.28ms +step:512/1680 train_time:44686ms step_avg:87.28ms +step:513/1680 train_time:44774ms step_avg:87.28ms +step:514/1680 train_time:44864ms step_avg:87.28ms +step:515/1680 train_time:44952ms step_avg:87.29ms +step:516/1680 train_time:45040ms step_avg:87.29ms +step:517/1680 train_time:45126ms step_avg:87.29ms +step:518/1680 train_time:45214ms step_avg:87.29ms +step:519/1680 train_time:45301ms step_avg:87.28ms +step:520/1680 train_time:45387ms step_avg:87.28ms +step:521/1680 train_time:45474ms step_avg:87.28ms +step:522/1680 train_time:45560ms step_avg:87.28ms +step:523/1680 train_time:45648ms step_avg:87.28ms +step:524/1680 train_time:45735ms step_avg:87.28ms +step:525/1680 train_time:45824ms step_avg:87.28ms +step:526/1680 train_time:45911ms step_avg:87.28ms +step:527/1680 train_time:45999ms step_avg:87.28ms +step:528/1680 train_time:46086ms step_avg:87.28ms +step:529/1680 train_time:46174ms step_avg:87.28ms +step:530/1680 train_time:46261ms step_avg:87.29ms +step:531/1680 train_time:46348ms step_avg:87.28ms 
+step:532/1680 train_time:46435ms step_avg:87.28ms +step:533/1680 train_time:46521ms step_avg:87.28ms +step:534/1680 train_time:46608ms step_avg:87.28ms +step:535/1680 train_time:46695ms step_avg:87.28ms +step:536/1680 train_time:46783ms step_avg:87.28ms +step:537/1680 train_time:46871ms step_avg:87.28ms +step:538/1680 train_time:46959ms step_avg:87.28ms +step:539/1680 train_time:47047ms step_avg:87.29ms +step:540/1680 train_time:47134ms step_avg:87.29ms +step:541/1680 train_time:47222ms step_avg:87.29ms +step:542/1680 train_time:47309ms step_avg:87.29ms +step:543/1680 train_time:47395ms step_avg:87.28ms +step:544/1680 train_time:47483ms step_avg:87.28ms +step:545/1680 train_time:47570ms step_avg:87.28ms +step:546/1680 train_time:47657ms step_avg:87.28ms +step:547/1680 train_time:47744ms step_avg:87.28ms +step:548/1680 train_time:47832ms step_avg:87.29ms +step:549/1680 train_time:47921ms step_avg:87.29ms +step:550/1680 train_time:48010ms step_avg:87.29ms +step:551/1680 train_time:48098ms step_avg:87.29ms +step:552/1680 train_time:48187ms step_avg:87.30ms +step:553/1680 train_time:48276ms step_avg:87.30ms +step:554/1680 train_time:48364ms step_avg:87.30ms +step:555/1680 train_time:48452ms step_avg:87.30ms +step:556/1680 train_time:48539ms step_avg:87.30ms +step:557/1680 train_time:48628ms step_avg:87.30ms +step:558/1680 train_time:48716ms step_avg:87.30ms +step:559/1680 train_time:48804ms step_avg:87.31ms +step:560/1680 train_time:48893ms step_avg:87.31ms +step:561/1680 train_time:48982ms step_avg:87.31ms +step:562/1680 train_time:49071ms step_avg:87.32ms +step:563/1680 train_time:49160ms step_avg:87.32ms +step:564/1680 train_time:49248ms step_avg:87.32ms +step:565/1680 train_time:49337ms step_avg:87.32ms +step:566/1680 train_time:49427ms step_avg:87.33ms +step:567/1680 train_time:49515ms step_avg:87.33ms +step:568/1680 train_time:49603ms step_avg:87.33ms +step:569/1680 train_time:49692ms step_avg:87.33ms +step:570/1680 train_time:49781ms step_avg:87.33ms +step:571/1680 train_time:49869ms step_avg:87.34ms +step:572/1680 train_time:49958ms step_avg:87.34ms +step:573/1680 train_time:50047ms step_avg:87.34ms +step:574/1680 train_time:50136ms step_avg:87.34ms +step:575/1680 train_time:50225ms step_avg:87.35ms +step:576/1680 train_time:50313ms step_avg:87.35ms +step:577/1680 train_time:50402ms step_avg:87.35ms +step:578/1680 train_time:50490ms step_avg:87.35ms +step:579/1680 train_time:50580ms step_avg:87.36ms +step:580/1680 train_time:50668ms step_avg:87.36ms +step:581/1680 train_time:50756ms step_avg:87.36ms +step:582/1680 train_time:50845ms step_avg:87.36ms +step:583/1680 train_time:50934ms step_avg:87.37ms +step:584/1680 train_time:51023ms step_avg:87.37ms +step:585/1680 train_time:51112ms step_avg:87.37ms +step:586/1680 train_time:51200ms step_avg:87.37ms +step:587/1680 train_time:51288ms step_avg:87.37ms +step:588/1680 train_time:51377ms step_avg:87.38ms +step:589/1680 train_time:51465ms step_avg:87.38ms +step:590/1680 train_time:51555ms step_avg:87.38ms +step:591/1680 train_time:51643ms step_avg:87.38ms +step:592/1680 train_time:51732ms step_avg:87.39ms +step:593/1680 train_time:51820ms step_avg:87.39ms +step:594/1680 train_time:51909ms step_avg:87.39ms +step:595/1680 train_time:51997ms step_avg:87.39ms +step:596/1680 train_time:52086ms step_avg:87.39ms +step:597/1680 train_time:52174ms step_avg:87.39ms +step:598/1680 train_time:52264ms step_avg:87.40ms +step:599/1680 train_time:52353ms step_avg:87.40ms +step:600/1680 train_time:52441ms step_avg:87.40ms +step:601/1680 train_time:52530ms 
step_avg:87.40ms +step:602/1680 train_time:52618ms step_avg:87.41ms +step:603/1680 train_time:52707ms step_avg:87.41ms +step:604/1680 train_time:52795ms step_avg:87.41ms +step:605/1680 train_time:52884ms step_avg:87.41ms +step:606/1680 train_time:52971ms step_avg:87.41ms +step:607/1680 train_time:53060ms step_avg:87.41ms +step:608/1680 train_time:53147ms step_avg:87.41ms +step:609/1680 train_time:53236ms step_avg:87.42ms +step:610/1680 train_time:53325ms step_avg:87.42ms +step:611/1680 train_time:53413ms step_avg:87.42ms +step:612/1680 train_time:53501ms step_avg:87.42ms +step:613/1680 train_time:53590ms step_avg:87.42ms +step:614/1680 train_time:53679ms step_avg:87.43ms +step:615/1680 train_time:53768ms step_avg:87.43ms +step:616/1680 train_time:53856ms step_avg:87.43ms +step:617/1680 train_time:53945ms step_avg:87.43ms +step:618/1680 train_time:54034ms step_avg:87.43ms +step:619/1680 train_time:54122ms step_avg:87.43ms +step:620/1680 train_time:54210ms step_avg:87.44ms +step:621/1680 train_time:54298ms step_avg:87.44ms +step:622/1680 train_time:54387ms step_avg:87.44ms +step:623/1680 train_time:54475ms step_avg:87.44ms +step:624/1680 train_time:54565ms step_avg:87.44ms +step:625/1680 train_time:54653ms step_avg:87.44ms +step:625/1680 val_loss:3.6169 train_time:54743ms step_avg:87.59ms +step:626/1680 train_time:54765ms step_avg:87.48ms +step:627/1680 train_time:54834ms step_avg:87.45ms +step:628/1680 train_time:54923ms step_avg:87.46ms +step:629/1680 train_time:55016ms step_avg:87.47ms +step:630/1680 train_time:55107ms step_avg:87.47ms +step:631/1680 train_time:55194ms step_avg:87.47ms +step:632/1680 train_time:55281ms step_avg:87.47ms +step:633/1680 train_time:55368ms step_avg:87.47ms +step:634/1680 train_time:55455ms step_avg:87.47ms +step:635/1680 train_time:55541ms step_avg:87.47ms +step:636/1680 train_time:55631ms step_avg:87.47ms +step:637/1680 train_time:55724ms step_avg:87.48ms +step:638/1680 train_time:55814ms step_avg:87.48ms +step:639/1680 train_time:55904ms step_avg:87.49ms +step:640/1680 train_time:55993ms step_avg:87.49ms +step:641/1680 train_time:56081ms step_avg:87.49ms +step:642/1680 train_time:56169ms step_avg:87.49ms +step:643/1680 train_time:56258ms step_avg:87.49ms +step:644/1680 train_time:56345ms step_avg:87.49ms +step:645/1680 train_time:56433ms step_avg:87.49ms +step:646/1680 train_time:56520ms step_avg:87.49ms +step:647/1680 train_time:56608ms step_avg:87.49ms +step:648/1680 train_time:56697ms step_avg:87.50ms +step:649/1680 train_time:56786ms step_avg:87.50ms +step:650/1680 train_time:56875ms step_avg:87.50ms +step:651/1680 train_time:56964ms step_avg:87.50ms +step:652/1680 train_time:57053ms step_avg:87.50ms +step:653/1680 train_time:57141ms step_avg:87.51ms +step:654/1680 train_time:57229ms step_avg:87.51ms +step:655/1680 train_time:57317ms step_avg:87.51ms +step:656/1680 train_time:57405ms step_avg:87.51ms +step:657/1680 train_time:57493ms step_avg:87.51ms +step:658/1680 train_time:57581ms step_avg:87.51ms +step:659/1680 train_time:57671ms step_avg:87.51ms +step:660/1680 train_time:57761ms step_avg:87.52ms +step:661/1680 train_time:57850ms step_avg:87.52ms +step:662/1680 train_time:57938ms step_avg:87.52ms +step:663/1680 train_time:58028ms step_avg:87.52ms +step:664/1680 train_time:58115ms step_avg:87.52ms +step:665/1680 train_time:58203ms step_avg:87.52ms +step:666/1680 train_time:58292ms step_avg:87.52ms +step:667/1680 train_time:58379ms step_avg:87.52ms +step:668/1680 train_time:58467ms step_avg:87.53ms +step:669/1680 train_time:58556ms step_avg:87.53ms 
+step:670/1680 train_time:58643ms step_avg:87.53ms +step:671/1680 train_time:58732ms step_avg:87.53ms +step:672/1680 train_time:58821ms step_avg:87.53ms +step:673/1680 train_time:58909ms step_avg:87.53ms +step:674/1680 train_time:58999ms step_avg:87.54ms +step:675/1680 train_time:59087ms step_avg:87.54ms +step:676/1680 train_time:59175ms step_avg:87.54ms +step:677/1680 train_time:59263ms step_avg:87.54ms +step:678/1680 train_time:59352ms step_avg:87.54ms +step:679/1680 train_time:59439ms step_avg:87.54ms +step:680/1680 train_time:59527ms step_avg:87.54ms +step:681/1680 train_time:59615ms step_avg:87.54ms +step:682/1680 train_time:59703ms step_avg:87.54ms +step:683/1680 train_time:59792ms step_avg:87.54ms +step:684/1680 train_time:59881ms step_avg:87.55ms +step:685/1680 train_time:59970ms step_avg:87.55ms +step:686/1680 train_time:60059ms step_avg:87.55ms +step:687/1680 train_time:60146ms step_avg:87.55ms +step:688/1680 train_time:60234ms step_avg:87.55ms +step:689/1680 train_time:60322ms step_avg:87.55ms +step:690/1680 train_time:60411ms step_avg:87.55ms +step:691/1680 train_time:60498ms step_avg:87.55ms +step:692/1680 train_time:60586ms step_avg:87.55ms +step:693/1680 train_time:60674ms step_avg:87.55ms +step:694/1680 train_time:60762ms step_avg:87.55ms +step:695/1680 train_time:60851ms step_avg:87.55ms +step:696/1680 train_time:60939ms step_avg:87.56ms +step:697/1680 train_time:61028ms step_avg:87.56ms +step:698/1680 train_time:61117ms step_avg:87.56ms +step:699/1680 train_time:61204ms step_avg:87.56ms +step:700/1680 train_time:61294ms step_avg:87.56ms +step:701/1680 train_time:61382ms step_avg:87.56ms +step:702/1680 train_time:61470ms step_avg:87.56ms +step:703/1680 train_time:61559ms step_avg:87.57ms +step:704/1680 train_time:61647ms step_avg:87.57ms +step:705/1680 train_time:61735ms step_avg:87.57ms +step:706/1680 train_time:61823ms step_avg:87.57ms +step:707/1680 train_time:61913ms step_avg:87.57ms +step:708/1680 train_time:62003ms step_avg:87.57ms +step:709/1680 train_time:62091ms step_avg:87.58ms +step:710/1680 train_time:62179ms step_avg:87.58ms +step:711/1680 train_time:62267ms step_avg:87.58ms +step:712/1680 train_time:62356ms step_avg:87.58ms +step:713/1680 train_time:62444ms step_avg:87.58ms +step:714/1680 train_time:62533ms step_avg:87.58ms +step:715/1680 train_time:62621ms step_avg:87.58ms +step:716/1680 train_time:62710ms step_avg:87.58ms +step:717/1680 train_time:62799ms step_avg:87.59ms +step:718/1680 train_time:62887ms step_avg:87.59ms +step:719/1680 train_time:62976ms step_avg:87.59ms +step:720/1680 train_time:63065ms step_avg:87.59ms +step:721/1680 train_time:63153ms step_avg:87.59ms +step:722/1680 train_time:63241ms step_avg:87.59ms +step:723/1680 train_time:63330ms step_avg:87.59ms +step:724/1680 train_time:63418ms step_avg:87.59ms +step:725/1680 train_time:63506ms step_avg:87.59ms +step:726/1680 train_time:63594ms step_avg:87.60ms +step:727/1680 train_time:63682ms step_avg:87.60ms +step:728/1680 train_time:63771ms step_avg:87.60ms +step:729/1680 train_time:63859ms step_avg:87.60ms +step:730/1680 train_time:63947ms step_avg:87.60ms +step:731/1680 train_time:64036ms step_avg:87.60ms +step:732/1680 train_time:64123ms step_avg:87.60ms +step:733/1680 train_time:64212ms step_avg:87.60ms +step:734/1680 train_time:64301ms step_avg:87.60ms +step:735/1680 train_time:64389ms step_avg:87.60ms +step:736/1680 train_time:64477ms step_avg:87.60ms +step:737/1680 train_time:64565ms step_avg:87.61ms +step:738/1680 train_time:64653ms step_avg:87.61ms +step:739/1680 train_time:64741ms 
step_avg:87.61ms +step:740/1680 train_time:64830ms step_avg:87.61ms +step:741/1680 train_time:64919ms step_avg:87.61ms +step:742/1680 train_time:65007ms step_avg:87.61ms +step:743/1680 train_time:65096ms step_avg:87.61ms +step:744/1680 train_time:65185ms step_avg:87.61ms +step:745/1680 train_time:65273ms step_avg:87.61ms +step:746/1680 train_time:65361ms step_avg:87.62ms +step:747/1680 train_time:65449ms step_avg:87.62ms +step:748/1680 train_time:65538ms step_avg:87.62ms +step:749/1680 train_time:65626ms step_avg:87.62ms +step:750/1680 train_time:65714ms step_avg:87.62ms +step:750/1680 val_loss:3.5644 train_time:65804ms step_avg:87.74ms +step:751/1680 train_time:65823ms step_avg:87.65ms +step:752/1680 train_time:65895ms step_avg:87.63ms +step:753/1680 train_time:65989ms step_avg:87.63ms +step:754/1680 train_time:66078ms step_avg:87.64ms +step:755/1680 train_time:66166ms step_avg:87.64ms +step:756/1680 train_time:66253ms step_avg:87.64ms +step:757/1680 train_time:66341ms step_avg:87.64ms +step:758/1680 train_time:66428ms step_avg:87.64ms +step:759/1680 train_time:66516ms step_avg:87.64ms +step:760/1680 train_time:66604ms step_avg:87.64ms +step:761/1680 train_time:66692ms step_avg:87.64ms +step:762/1680 train_time:66781ms step_avg:87.64ms +step:763/1680 train_time:66870ms step_avg:87.64ms +step:764/1680 train_time:66960ms step_avg:87.64ms +step:765/1680 train_time:67050ms step_avg:87.65ms +step:766/1680 train_time:67139ms step_avg:87.65ms +step:767/1680 train_time:67228ms step_avg:87.65ms +step:768/1680 train_time:67315ms step_avg:87.65ms +step:769/1680 train_time:67403ms step_avg:87.65ms +step:770/1680 train_time:67490ms step_avg:87.65ms +step:771/1680 train_time:67577ms step_avg:87.65ms +step:772/1680 train_time:67665ms step_avg:87.65ms +step:773/1680 train_time:67754ms step_avg:87.65ms +step:774/1680 train_time:67843ms step_avg:87.65ms +step:775/1680 train_time:67932ms step_avg:87.65ms +step:776/1680 train_time:68022ms step_avg:87.66ms +step:777/1680 train_time:68111ms step_avg:87.66ms +step:778/1680 train_time:68199ms step_avg:87.66ms +step:779/1680 train_time:68287ms step_avg:87.66ms +step:780/1680 train_time:68375ms step_avg:87.66ms +step:781/1680 train_time:68463ms step_avg:87.66ms +step:782/1680 train_time:68551ms step_avg:87.66ms +step:783/1680 train_time:68638ms step_avg:87.66ms +step:784/1680 train_time:68726ms step_avg:87.66ms +step:785/1680 train_time:68816ms step_avg:87.66ms +step:786/1680 train_time:68904ms step_avg:87.66ms +step:787/1680 train_time:68994ms step_avg:87.67ms +step:788/1680 train_time:69083ms step_avg:87.67ms +step:789/1680 train_time:69171ms step_avg:87.67ms +step:790/1680 train_time:69259ms step_avg:87.67ms +step:791/1680 train_time:69348ms step_avg:87.67ms +step:792/1680 train_time:69436ms step_avg:87.67ms +step:793/1680 train_time:69524ms step_avg:87.67ms +step:794/1680 train_time:69613ms step_avg:87.67ms +step:795/1680 train_time:69700ms step_avg:87.67ms +step:796/1680 train_time:69789ms step_avg:87.67ms +step:797/1680 train_time:69878ms step_avg:87.68ms +step:798/1680 train_time:69968ms step_avg:87.68ms +step:799/1680 train_time:70057ms step_avg:87.68ms +step:800/1680 train_time:70146ms step_avg:87.68ms +step:801/1680 train_time:70234ms step_avg:87.68ms +step:802/1680 train_time:70322ms step_avg:87.68ms +step:803/1680 train_time:70410ms step_avg:87.68ms +step:804/1680 train_time:70498ms step_avg:87.68ms +step:805/1680 train_time:70586ms step_avg:87.68ms +step:806/1680 train_time:70674ms step_avg:87.69ms +step:807/1680 train_time:70762ms step_avg:87.68ms 
+step:808/1680 train_time:70850ms step_avg:87.69ms +step:809/1680 train_time:70939ms step_avg:87.69ms +step:810/1680 train_time:71027ms step_avg:87.69ms +step:811/1680 train_time:71116ms step_avg:87.69ms +step:812/1680 train_time:71205ms step_avg:87.69ms +step:813/1680 train_time:71294ms step_avg:87.69ms +step:814/1680 train_time:71381ms step_avg:87.69ms +step:815/1680 train_time:71470ms step_avg:87.69ms +step:816/1680 train_time:71558ms step_avg:87.69ms +step:817/1680 train_time:71647ms step_avg:87.69ms +step:818/1680 train_time:71735ms step_avg:87.70ms +step:819/1680 train_time:71823ms step_avg:87.70ms +step:820/1680 train_time:71911ms step_avg:87.70ms +step:821/1680 train_time:72000ms step_avg:87.70ms +step:822/1680 train_time:72088ms step_avg:87.70ms +step:823/1680 train_time:72177ms step_avg:87.70ms +step:824/1680 train_time:72266ms step_avg:87.70ms +step:825/1680 train_time:72355ms step_avg:87.70ms +step:826/1680 train_time:72443ms step_avg:87.70ms +step:827/1680 train_time:72532ms step_avg:87.71ms +step:828/1680 train_time:72620ms step_avg:87.71ms +step:829/1680 train_time:72709ms step_avg:87.71ms +step:830/1680 train_time:72797ms step_avg:87.71ms +step:831/1680 train_time:72886ms step_avg:87.71ms +step:832/1680 train_time:72974ms step_avg:87.71ms +step:833/1680 train_time:73062ms step_avg:87.71ms +step:834/1680 train_time:73151ms step_avg:87.71ms +step:835/1680 train_time:73239ms step_avg:87.71ms +step:836/1680 train_time:73328ms step_avg:87.71ms +step:837/1680 train_time:73417ms step_avg:87.71ms +step:838/1680 train_time:73505ms step_avg:87.71ms +step:839/1680 train_time:73592ms step_avg:87.71ms +step:840/1680 train_time:73681ms step_avg:87.71ms +step:841/1680 train_time:73769ms step_avg:87.72ms +step:842/1680 train_time:73858ms step_avg:87.72ms +step:843/1680 train_time:73947ms step_avg:87.72ms +step:844/1680 train_time:74035ms step_avg:87.72ms +step:845/1680 train_time:74124ms step_avg:87.72ms +step:846/1680 train_time:74212ms step_avg:87.72ms +step:847/1680 train_time:74300ms step_avg:87.72ms +step:848/1680 train_time:74390ms step_avg:87.72ms +step:849/1680 train_time:74478ms step_avg:87.72ms +step:850/1680 train_time:74567ms step_avg:87.73ms +step:851/1680 train_time:74656ms step_avg:87.73ms +step:852/1680 train_time:74744ms step_avg:87.73ms +step:853/1680 train_time:74832ms step_avg:87.73ms +step:854/1680 train_time:74921ms step_avg:87.73ms +step:855/1680 train_time:75010ms step_avg:87.73ms +step:856/1680 train_time:75098ms step_avg:87.73ms +step:857/1680 train_time:75187ms step_avg:87.73ms +step:858/1680 train_time:75276ms step_avg:87.73ms +step:859/1680 train_time:75365ms step_avg:87.74ms +step:860/1680 train_time:75453ms step_avg:87.74ms +step:861/1680 train_time:75541ms step_avg:87.74ms +step:862/1680 train_time:75629ms step_avg:87.74ms +step:863/1680 train_time:75718ms step_avg:87.74ms +step:864/1680 train_time:75806ms step_avg:87.74ms +step:865/1680 train_time:75895ms step_avg:87.74ms +step:866/1680 train_time:75984ms step_avg:87.74ms +step:867/1680 train_time:76072ms step_avg:87.74ms +step:868/1680 train_time:76161ms step_avg:87.74ms +step:869/1680 train_time:76249ms step_avg:87.74ms +step:870/1680 train_time:76338ms step_avg:87.74ms +step:871/1680 train_time:76426ms step_avg:87.75ms +step:872/1680 train_time:76514ms step_avg:87.75ms +step:873/1680 train_time:76602ms step_avg:87.75ms +step:874/1680 train_time:76690ms step_avg:87.75ms +step:875/1680 train_time:76778ms step_avg:87.75ms +step:875/1680 val_loss:3.5198 train_time:76869ms step_avg:87.85ms +step:876/1680 
train_time:76889ms step_avg:87.77ms +step:877/1680 train_time:76961ms step_avg:87.75ms +step:878/1680 train_time:77052ms step_avg:87.76ms +step:879/1680 train_time:77140ms step_avg:87.76ms +step:880/1680 train_time:77228ms step_avg:87.76ms +step:881/1680 train_time:77317ms step_avg:87.76ms +step:882/1680 train_time:77404ms step_avg:87.76ms +step:883/1680 train_time:77491ms step_avg:87.76ms +step:884/1680 train_time:77579ms step_avg:87.76ms +step:885/1680 train_time:77668ms step_avg:87.76ms +step:886/1680 train_time:77755ms step_avg:87.76ms +step:887/1680 train_time:77843ms step_avg:87.76ms +step:888/1680 train_time:77933ms step_avg:87.76ms +step:889/1680 train_time:78024ms step_avg:87.77ms +step:890/1680 train_time:78113ms step_avg:87.77ms +step:891/1680 train_time:78202ms step_avg:87.77ms +step:892/1680 train_time:78291ms step_avg:87.77ms +step:893/1680 train_time:78378ms step_avg:87.77ms +step:894/1680 train_time:78467ms step_avg:87.77ms +step:895/1680 train_time:78554ms step_avg:87.77ms +step:896/1680 train_time:78642ms step_avg:87.77ms +step:897/1680 train_time:78730ms step_avg:87.77ms +step:898/1680 train_time:78818ms step_avg:87.77ms +step:899/1680 train_time:78908ms step_avg:87.77ms +step:900/1680 train_time:78997ms step_avg:87.77ms +step:901/1680 train_time:79086ms step_avg:87.78ms +step:902/1680 train_time:79174ms step_avg:87.78ms +step:903/1680 train_time:79264ms step_avg:87.78ms +step:904/1680 train_time:79352ms step_avg:87.78ms +step:905/1680 train_time:79441ms step_avg:87.78ms +step:906/1680 train_time:79530ms step_avg:87.78ms +step:907/1680 train_time:79617ms step_avg:87.78ms +step:908/1680 train_time:79706ms step_avg:87.78ms +step:909/1680 train_time:79793ms step_avg:87.78ms +step:910/1680 train_time:79882ms step_avg:87.78ms +step:911/1680 train_time:79971ms step_avg:87.78ms +step:912/1680 train_time:80059ms step_avg:87.78ms +step:913/1680 train_time:80148ms step_avg:87.78ms +step:914/1680 train_time:80236ms step_avg:87.79ms +step:915/1680 train_time:80324ms step_avg:87.79ms +step:916/1680 train_time:80412ms step_avg:87.79ms +step:917/1680 train_time:80501ms step_avg:87.79ms +step:918/1680 train_time:80589ms step_avg:87.79ms +step:919/1680 train_time:80677ms step_avg:87.79ms +step:920/1680 train_time:80766ms step_avg:87.79ms +step:921/1680 train_time:80853ms step_avg:87.79ms +step:922/1680 train_time:80942ms step_avg:87.79ms +step:923/1680 train_time:81030ms step_avg:87.79ms +step:924/1680 train_time:81119ms step_avg:87.79ms +step:925/1680 train_time:81208ms step_avg:87.79ms +step:926/1680 train_time:81297ms step_avg:87.79ms +step:927/1680 train_time:81386ms step_avg:87.79ms +step:928/1680 train_time:81474ms step_avg:87.80ms +step:929/1680 train_time:81563ms step_avg:87.80ms +step:930/1680 train_time:81652ms step_avg:87.80ms +step:931/1680 train_time:81741ms step_avg:87.80ms +step:932/1680 train_time:81829ms step_avg:87.80ms +step:933/1680 train_time:81917ms step_avg:87.80ms +step:934/1680 train_time:82006ms step_avg:87.80ms +step:935/1680 train_time:82093ms step_avg:87.80ms +step:936/1680 train_time:82183ms step_avg:87.80ms +step:937/1680 train_time:82272ms step_avg:87.80ms +step:938/1680 train_time:82360ms step_avg:87.80ms +step:939/1680 train_time:82449ms step_avg:87.80ms +step:940/1680 train_time:82537ms step_avg:87.81ms +step:941/1680 train_time:82626ms step_avg:87.81ms +step:942/1680 train_time:82714ms step_avg:87.81ms +step:943/1680 train_time:82802ms step_avg:87.81ms +step:944/1680 train_time:82891ms step_avg:87.81ms +step:945/1680 train_time:82979ms step_avg:87.81ms 
+step:946/1680 train_time:83067ms step_avg:87.81ms +step:947/1680 train_time:83156ms step_avg:87.81ms +step:948/1680 train_time:83245ms step_avg:87.81ms +step:949/1680 train_time:83333ms step_avg:87.81ms +step:950/1680 train_time:83421ms step_avg:87.81ms +step:951/1680 train_time:83510ms step_avg:87.81ms +step:952/1680 train_time:83598ms step_avg:87.81ms +step:953/1680 train_time:83687ms step_avg:87.81ms +step:954/1680 train_time:83775ms step_avg:87.81ms +step:955/1680 train_time:83864ms step_avg:87.82ms +step:956/1680 train_time:83952ms step_avg:87.82ms +step:957/1680 train_time:84041ms step_avg:87.82ms +step:958/1680 train_time:84130ms step_avg:87.82ms +step:959/1680 train_time:84218ms step_avg:87.82ms +step:960/1680 train_time:84307ms step_avg:87.82ms +step:961/1680 train_time:84394ms step_avg:87.82ms +step:962/1680 train_time:84483ms step_avg:87.82ms +step:963/1680 train_time:84572ms step_avg:87.82ms +step:964/1680 train_time:84661ms step_avg:87.82ms +step:965/1680 train_time:84749ms step_avg:87.82ms +step:966/1680 train_time:84838ms step_avg:87.82ms +step:967/1680 train_time:84926ms step_avg:87.82ms +step:968/1680 train_time:85014ms step_avg:87.82ms +step:969/1680 train_time:85103ms step_avg:87.83ms +step:970/1680 train_time:85191ms step_avg:87.83ms +step:971/1680 train_time:85279ms step_avg:87.83ms +step:972/1680 train_time:85368ms step_avg:87.83ms +step:973/1680 train_time:85456ms step_avg:87.83ms +step:974/1680 train_time:85544ms step_avg:87.83ms +step:975/1680 train_time:85632ms step_avg:87.83ms +step:976/1680 train_time:85720ms step_avg:87.83ms +step:977/1680 train_time:85808ms step_avg:87.83ms +step:978/1680 train_time:85897ms step_avg:87.83ms +step:979/1680 train_time:85985ms step_avg:87.83ms +step:980/1680 train_time:86073ms step_avg:87.83ms +step:981/1680 train_time:86162ms step_avg:87.83ms +step:982/1680 train_time:86250ms step_avg:87.83ms +step:983/1680 train_time:86338ms step_avg:87.83ms +step:984/1680 train_time:86427ms step_avg:87.83ms +step:985/1680 train_time:86515ms step_avg:87.83ms +step:986/1680 train_time:86604ms step_avg:87.83ms +step:987/1680 train_time:86692ms step_avg:87.83ms +step:988/1680 train_time:86780ms step_avg:87.83ms +step:989/1680 train_time:86869ms step_avg:87.84ms +step:990/1680 train_time:86957ms step_avg:87.84ms +step:991/1680 train_time:87045ms step_avg:87.84ms +step:992/1680 train_time:87132ms step_avg:87.84ms +step:993/1680 train_time:87221ms step_avg:87.84ms +step:994/1680 train_time:87310ms step_avg:87.84ms +step:995/1680 train_time:87398ms step_avg:87.84ms +step:996/1680 train_time:87487ms step_avg:87.84ms +step:997/1680 train_time:87575ms step_avg:87.84ms +step:998/1680 train_time:87665ms step_avg:87.84ms +step:999/1680 train_time:87753ms step_avg:87.84ms +step:1000/1680 train_time:87841ms step_avg:87.84ms +step:1000/1680 val_loss:3.4705 train_time:87931ms step_avg:87.93ms +step:1001/1680 train_time:87950ms step_avg:87.86ms +step:1002/1680 train_time:88026ms step_avg:87.85ms +step:1003/1680 train_time:88115ms step_avg:87.85ms +step:1004/1680 train_time:88205ms step_avg:87.85ms +step:1005/1680 train_time:88292ms step_avg:87.85ms +step:1006/1680 train_time:88379ms step_avg:87.85ms +step:1007/1680 train_time:88466ms step_avg:87.85ms +step:1008/1680 train_time:88554ms step_avg:87.85ms +step:1009/1680 train_time:88641ms step_avg:87.85ms +step:1010/1680 train_time:88729ms step_avg:87.85ms +step:1011/1680 train_time:88816ms step_avg:87.85ms +step:1012/1680 train_time:88906ms step_avg:87.85ms +step:1013/1680 train_time:88996ms step_avg:87.85ms 
+step:1014/1680 train_time:89086ms step_avg:87.86ms +step:1015/1680 train_time:89175ms step_avg:87.86ms +step:1016/1680 train_time:89266ms step_avg:87.86ms +step:1017/1680 train_time:89354ms step_avg:87.86ms +step:1018/1680 train_time:89441ms step_avg:87.86ms +step:1019/1680 train_time:89529ms step_avg:87.86ms +step:1020/1680 train_time:89617ms step_avg:87.86ms +step:1021/1680 train_time:89705ms step_avg:87.86ms +step:1022/1680 train_time:89792ms step_avg:87.86ms +step:1023/1680 train_time:89881ms step_avg:87.86ms +step:1024/1680 train_time:89970ms step_avg:87.86ms +step:1025/1680 train_time:90059ms step_avg:87.86ms +step:1026/1680 train_time:90148ms step_avg:87.86ms +step:1027/1680 train_time:90237ms step_avg:87.86ms +step:1028/1680 train_time:90325ms step_avg:87.86ms +step:1029/1680 train_time:90413ms step_avg:87.86ms +step:1030/1680 train_time:90500ms step_avg:87.86ms +step:1031/1680 train_time:90588ms step_avg:87.86ms +step:1032/1680 train_time:90676ms step_avg:87.86ms +step:1033/1680 train_time:90764ms step_avg:87.86ms +step:1034/1680 train_time:90852ms step_avg:87.86ms +step:1035/1680 train_time:90941ms step_avg:87.87ms +step:1036/1680 train_time:91029ms step_avg:87.87ms +step:1037/1680 train_time:91118ms step_avg:87.87ms +step:1038/1680 train_time:91207ms step_avg:87.87ms +step:1039/1680 train_time:91296ms step_avg:87.87ms +step:1040/1680 train_time:91384ms step_avg:87.87ms +step:1041/1680 train_time:91472ms step_avg:87.87ms +step:1042/1680 train_time:91560ms step_avg:87.87ms +step:1043/1680 train_time:91649ms step_avg:87.87ms +step:1044/1680 train_time:91737ms step_avg:87.87ms +step:1045/1680 train_time:91825ms step_avg:87.87ms +step:1046/1680 train_time:91913ms step_avg:87.87ms +step:1047/1680 train_time:92002ms step_avg:87.87ms +step:1048/1680 train_time:92091ms step_avg:87.87ms +step:1049/1680 train_time:92180ms step_avg:87.87ms +step:1050/1680 train_time:92269ms step_avg:87.87ms +step:1051/1680 train_time:92357ms step_avg:87.88ms +step:1052/1680 train_time:92445ms step_avg:87.88ms +step:1053/1680 train_time:92533ms step_avg:87.88ms +step:1054/1680 train_time:92621ms step_avg:87.88ms +step:1055/1680 train_time:92709ms step_avg:87.88ms +step:1056/1680 train_time:92797ms step_avg:87.88ms +step:1057/1680 train_time:92885ms step_avg:87.88ms +step:1058/1680 train_time:92974ms step_avg:87.88ms +step:1059/1680 train_time:93062ms step_avg:87.88ms +step:1060/1680 train_time:93151ms step_avg:87.88ms +step:1061/1680 train_time:93241ms step_avg:87.88ms +step:1062/1680 train_time:93330ms step_avg:87.88ms +step:1063/1680 train_time:93418ms step_avg:87.88ms +step:1064/1680 train_time:93507ms step_avg:87.88ms +step:1065/1680 train_time:93595ms step_avg:87.88ms +step:1066/1680 train_time:93683ms step_avg:87.88ms +step:1067/1680 train_time:93771ms step_avg:87.88ms +step:1068/1680 train_time:93859ms step_avg:87.88ms +step:1069/1680 train_time:93948ms step_avg:87.88ms +step:1070/1680 train_time:94037ms step_avg:87.88ms +step:1071/1680 train_time:94126ms step_avg:87.89ms +step:1072/1680 train_time:94214ms step_avg:87.89ms +step:1073/1680 train_time:94303ms step_avg:87.89ms +step:1074/1680 train_time:94391ms step_avg:87.89ms +step:1075/1680 train_time:94480ms step_avg:87.89ms +step:1076/1680 train_time:94568ms step_avg:87.89ms +step:1077/1680 train_time:94657ms step_avg:87.89ms +step:1078/1680 train_time:94745ms step_avg:87.89ms +step:1079/1680 train_time:94833ms step_avg:87.89ms +step:1080/1680 train_time:94922ms step_avg:87.89ms +step:1081/1680 train_time:95010ms step_avg:87.89ms +step:1082/1680 
train_time:95098ms step_avg:87.89ms +step:1083/1680 train_time:95187ms step_avg:87.89ms +step:1084/1680 train_time:95276ms step_avg:87.89ms +step:1085/1680 train_time:95365ms step_avg:87.89ms +step:1086/1680 train_time:95453ms step_avg:87.89ms +step:1087/1680 train_time:95541ms step_avg:87.89ms +step:1088/1680 train_time:95629ms step_avg:87.89ms +step:1089/1680 train_time:95717ms step_avg:87.89ms +step:1090/1680 train_time:95805ms step_avg:87.89ms +step:1091/1680 train_time:95894ms step_avg:87.90ms +step:1092/1680 train_time:95982ms step_avg:87.90ms +step:1093/1680 train_time:96070ms step_avg:87.90ms +step:1094/1680 train_time:96160ms step_avg:87.90ms +step:1095/1680 train_time:96249ms step_avg:87.90ms +step:1096/1680 train_time:96339ms step_avg:87.90ms +step:1097/1680 train_time:96427ms step_avg:87.90ms +step:1098/1680 train_time:96517ms step_avg:87.90ms +step:1099/1680 train_time:96606ms step_avg:87.90ms +step:1100/1680 train_time:96695ms step_avg:87.90ms +step:1101/1680 train_time:96783ms step_avg:87.91ms +step:1102/1680 train_time:96872ms step_avg:87.91ms +step:1103/1680 train_time:96961ms step_avg:87.91ms +step:1104/1680 train_time:97050ms step_avg:87.91ms +step:1105/1680 train_time:97139ms step_avg:87.91ms +step:1106/1680 train_time:97228ms step_avg:87.91ms +step:1107/1680 train_time:97318ms step_avg:87.91ms +step:1108/1680 train_time:97407ms step_avg:87.91ms +step:1109/1680 train_time:97497ms step_avg:87.91ms +step:1110/1680 train_time:97586ms step_avg:87.92ms +step:1111/1680 train_time:97675ms step_avg:87.92ms +step:1112/1680 train_time:97764ms step_avg:87.92ms +step:1113/1680 train_time:97853ms step_avg:87.92ms +step:1114/1680 train_time:97941ms step_avg:87.92ms +step:1115/1680 train_time:98030ms step_avg:87.92ms +step:1116/1680 train_time:98121ms step_avg:87.92ms +step:1117/1680 train_time:98210ms step_avg:87.92ms +step:1118/1680 train_time:98299ms step_avg:87.92ms +step:1119/1680 train_time:98388ms step_avg:87.92ms +step:1120/1680 train_time:98476ms step_avg:87.93ms +step:1121/1680 train_time:98566ms step_avg:87.93ms +step:1122/1680 train_time:98656ms step_avg:87.93ms +step:1123/1680 train_time:98745ms step_avg:87.93ms +step:1124/1680 train_time:98833ms step_avg:87.93ms +step:1125/1680 train_time:98922ms step_avg:87.93ms +step:1125/1680 val_loss:3.4164 train_time:99012ms step_avg:88.01ms +step:1126/1680 train_time:99032ms step_avg:87.95ms +step:1127/1680 train_time:99104ms step_avg:87.94ms +step:1128/1680 train_time:99194ms step_avg:87.94ms +step:1129/1680 train_time:99284ms step_avg:87.94ms +step:1130/1680 train_time:99373ms step_avg:87.94ms +step:1131/1680 train_time:99461ms step_avg:87.94ms +step:1132/1680 train_time:99549ms step_avg:87.94ms +step:1133/1680 train_time:99637ms step_avg:87.94ms +step:1134/1680 train_time:99725ms step_avg:87.94ms +step:1135/1680 train_time:99813ms step_avg:87.94ms +step:1136/1680 train_time:99903ms step_avg:87.94ms +step:1137/1680 train_time:99996ms step_avg:87.95ms +step:1138/1680 train_time:100086ms step_avg:87.95ms +step:1139/1680 train_time:100177ms step_avg:87.95ms +step:1140/1680 train_time:100267ms step_avg:87.95ms +step:1141/1680 train_time:100355ms step_avg:87.95ms +step:1142/1680 train_time:100443ms step_avg:87.95ms +step:1143/1680 train_time:100531ms step_avg:87.95ms +step:1144/1680 train_time:100619ms step_avg:87.95ms +step:1145/1680 train_time:100708ms step_avg:87.95ms +step:1146/1680 train_time:100796ms step_avg:87.95ms +step:1147/1680 train_time:100884ms step_avg:87.96ms +step:1148/1680 train_time:100975ms step_avg:87.96ms 
+step:1149/1680 train_time:101066ms step_avg:87.96ms +step:1150/1680 train_time:101156ms step_avg:87.96ms +step:1151/1680 train_time:101246ms step_avg:87.96ms +step:1152/1680 train_time:101335ms step_avg:87.96ms +step:1153/1680 train_time:101423ms step_avg:87.96ms +step:1154/1680 train_time:101512ms step_avg:87.97ms +step:1155/1680 train_time:101600ms step_avg:87.97ms +step:1156/1680 train_time:101688ms step_avg:87.97ms +step:1157/1680 train_time:101777ms step_avg:87.97ms +step:1158/1680 train_time:101867ms step_avg:87.97ms +step:1159/1680 train_time:101957ms step_avg:87.97ms +step:1160/1680 train_time:102048ms step_avg:87.97ms +step:1161/1680 train_time:102139ms step_avg:87.98ms +step:1162/1680 train_time:102228ms step_avg:87.98ms +step:1163/1680 train_time:102317ms step_avg:87.98ms +step:1164/1680 train_time:102407ms step_avg:87.98ms +step:1165/1680 train_time:102495ms step_avg:87.98ms +step:1166/1680 train_time:102584ms step_avg:87.98ms +step:1167/1680 train_time:102672ms step_avg:87.98ms +step:1168/1680 train_time:102761ms step_avg:87.98ms +step:1169/1680 train_time:102851ms step_avg:87.98ms +step:1170/1680 train_time:102940ms step_avg:87.98ms +step:1171/1680 train_time:103030ms step_avg:87.98ms +step:1172/1680 train_time:103120ms step_avg:87.99ms +step:1173/1680 train_time:103211ms step_avg:87.99ms +step:1174/1680 train_time:103300ms step_avg:87.99ms +step:1175/1680 train_time:103389ms step_avg:87.99ms +step:1176/1680 train_time:103478ms step_avg:87.99ms +step:1177/1680 train_time:103566ms step_avg:87.99ms +step:1178/1680 train_time:103654ms step_avg:87.99ms +step:1179/1680 train_time:103743ms step_avg:87.99ms +step:1180/1680 train_time:103832ms step_avg:87.99ms +step:1181/1680 train_time:103921ms step_avg:87.99ms +step:1182/1680 train_time:104011ms step_avg:88.00ms +step:1183/1680 train_time:104100ms step_avg:88.00ms +step:1184/1680 train_time:104190ms step_avg:88.00ms +step:1185/1680 train_time:104280ms step_avg:88.00ms +step:1186/1680 train_time:104369ms step_avg:88.00ms +step:1187/1680 train_time:104458ms step_avg:88.00ms +step:1188/1680 train_time:104547ms step_avg:88.00ms +step:1189/1680 train_time:104636ms step_avg:88.00ms +step:1190/1680 train_time:104726ms step_avg:88.00ms +step:1191/1680 train_time:104815ms step_avg:88.01ms +step:1192/1680 train_time:104905ms step_avg:88.01ms +step:1193/1680 train_time:104994ms step_avg:88.01ms +step:1194/1680 train_time:105083ms step_avg:88.01ms +step:1195/1680 train_time:105173ms step_avg:88.01ms +step:1196/1680 train_time:105262ms step_avg:88.01ms +step:1197/1680 train_time:105351ms step_avg:88.01ms +step:1198/1680 train_time:105440ms step_avg:88.01ms +step:1199/1680 train_time:105529ms step_avg:88.01ms +step:1200/1680 train_time:105618ms step_avg:88.01ms +step:1201/1680 train_time:105707ms step_avg:88.02ms +step:1202/1680 train_time:105796ms step_avg:88.02ms +step:1203/1680 train_time:105885ms step_avg:88.02ms +step:1204/1680 train_time:105974ms step_avg:88.02ms +step:1205/1680 train_time:106063ms step_avg:88.02ms +step:1206/1680 train_time:106152ms step_avg:88.02ms +step:1207/1680 train_time:106241ms step_avg:88.02ms +step:1208/1680 train_time:106331ms step_avg:88.02ms +step:1209/1680 train_time:106419ms step_avg:88.02ms +step:1210/1680 train_time:106508ms step_avg:88.02ms +step:1211/1680 train_time:106597ms step_avg:88.02ms +step:1212/1680 train_time:106686ms step_avg:88.02ms +step:1213/1680 train_time:106775ms step_avg:88.03ms +step:1214/1680 train_time:106864ms step_avg:88.03ms +step:1215/1680 train_time:106953ms step_avg:88.03ms 
+step:1216/1680 train_time:107042ms step_avg:88.03ms +step:1217/1680 train_time:107131ms step_avg:88.03ms +step:1218/1680 train_time:107221ms step_avg:88.03ms +step:1219/1680 train_time:107310ms step_avg:88.03ms +step:1220/1680 train_time:107398ms step_avg:88.03ms +step:1221/1680 train_time:107488ms step_avg:88.03ms +step:1222/1680 train_time:107576ms step_avg:88.03ms +step:1223/1680 train_time:107666ms step_avg:88.03ms +step:1224/1680 train_time:107756ms step_avg:88.04ms +step:1225/1680 train_time:107845ms step_avg:88.04ms +step:1226/1680 train_time:107934ms step_avg:88.04ms +step:1227/1680 train_time:108023ms step_avg:88.04ms +step:1228/1680 train_time:108112ms step_avg:88.04ms +step:1229/1680 train_time:108201ms step_avg:88.04ms +step:1230/1680 train_time:108290ms step_avg:88.04ms +step:1231/1680 train_time:108379ms step_avg:88.04ms +step:1232/1680 train_time:108468ms step_avg:88.04ms +step:1233/1680 train_time:108557ms step_avg:88.04ms +step:1234/1680 train_time:108646ms step_avg:88.04ms +step:1235/1680 train_time:108735ms step_avg:88.04ms +step:1236/1680 train_time:108824ms step_avg:88.05ms +step:1237/1680 train_time:108914ms step_avg:88.05ms +step:1238/1680 train_time:109003ms step_avg:88.05ms +step:1239/1680 train_time:109092ms step_avg:88.05ms +step:1240/1680 train_time:109181ms step_avg:88.05ms +step:1241/1680 train_time:109271ms step_avg:88.05ms +step:1242/1680 train_time:109360ms step_avg:88.05ms +step:1243/1680 train_time:109449ms step_avg:88.05ms +step:1244/1680 train_time:109539ms step_avg:88.05ms +step:1245/1680 train_time:109628ms step_avg:88.05ms +step:1246/1680 train_time:109717ms step_avg:88.06ms +step:1247/1680 train_time:109806ms step_avg:88.06ms +step:1248/1680 train_time:109896ms step_avg:88.06ms +step:1249/1680 train_time:109984ms step_avg:88.06ms +step:1250/1680 train_time:110073ms step_avg:88.06ms +step:1250/1680 val_loss:3.3778 train_time:110163ms step_avg:88.13ms +step:1251/1680 train_time:110183ms step_avg:88.08ms +step:1252/1680 train_time:110255ms step_avg:88.06ms +step:1253/1680 train_time:110348ms step_avg:88.07ms +step:1254/1680 train_time:110437ms step_avg:88.07ms +step:1255/1680 train_time:110526ms step_avg:88.07ms +step:1256/1680 train_time:110614ms step_avg:88.07ms +step:1257/1680 train_time:110702ms step_avg:88.07ms +step:1258/1680 train_time:110791ms step_avg:88.07ms +step:1259/1680 train_time:110879ms step_avg:88.07ms +step:1260/1680 train_time:110967ms step_avg:88.07ms +step:1261/1680 train_time:111056ms step_avg:88.07ms +step:1262/1680 train_time:111145ms step_avg:88.07ms +step:1263/1680 train_time:111236ms step_avg:88.07ms +step:1264/1680 train_time:111327ms step_avg:88.07ms +step:1265/1680 train_time:111417ms step_avg:88.08ms +step:1266/1680 train_time:111507ms step_avg:88.08ms +step:1267/1680 train_time:111597ms step_avg:88.08ms +step:1268/1680 train_time:111684ms step_avg:88.08ms +step:1269/1680 train_time:111773ms step_avg:88.08ms +step:1270/1680 train_time:111861ms step_avg:88.08ms +step:1271/1680 train_time:111949ms step_avg:88.08ms +step:1272/1680 train_time:112038ms step_avg:88.08ms +step:1273/1680 train_time:112127ms step_avg:88.08ms +step:1274/1680 train_time:112217ms step_avg:88.08ms +step:1275/1680 train_time:112307ms step_avg:88.08ms +step:1276/1680 train_time:112398ms step_avg:88.09ms +step:1277/1680 train_time:112487ms step_avg:88.09ms +step:1278/1680 train_time:112577ms step_avg:88.09ms +step:1279/1680 train_time:112666ms step_avg:88.09ms +step:1280/1680 train_time:112754ms step_avg:88.09ms +step:1281/1680 train_time:112843ms 
step_avg:88.09ms +step:1282/1680 train_time:112932ms step_avg:88.09ms +step:1283/1680 train_time:113021ms step_avg:88.09ms +step:1284/1680 train_time:113110ms step_avg:88.09ms +step:1285/1680 train_time:113200ms step_avg:88.09ms +step:1286/1680 train_time:113290ms step_avg:88.09ms +step:1287/1680 train_time:113379ms step_avg:88.10ms +step:1288/1680 train_time:113469ms step_avg:88.10ms +step:1289/1680 train_time:113558ms step_avg:88.10ms +step:1290/1680 train_time:113647ms step_avg:88.10ms +step:1291/1680 train_time:113736ms step_avg:88.10ms +step:1292/1680 train_time:113825ms step_avg:88.10ms +step:1293/1680 train_time:113914ms step_avg:88.10ms +step:1294/1680 train_time:114002ms step_avg:88.10ms +step:1295/1680 train_time:114091ms step_avg:88.10ms +step:1296/1680 train_time:114180ms step_avg:88.10ms +step:1297/1680 train_time:114269ms step_avg:88.10ms +step:1298/1680 train_time:114359ms step_avg:88.10ms +step:1299/1680 train_time:114448ms step_avg:88.10ms +step:1300/1680 train_time:114538ms step_avg:88.11ms +step:1301/1680 train_time:114626ms step_avg:88.11ms +step:1302/1680 train_time:114715ms step_avg:88.11ms +step:1303/1680 train_time:114804ms step_avg:88.11ms +step:1304/1680 train_time:114893ms step_avg:88.11ms +step:1305/1680 train_time:114983ms step_avg:88.11ms +step:1306/1680 train_time:115072ms step_avg:88.11ms +step:1307/1680 train_time:115161ms step_avg:88.11ms +step:1308/1680 train_time:115250ms step_avg:88.11ms +step:1309/1680 train_time:115339ms step_avg:88.11ms +step:1310/1680 train_time:115429ms step_avg:88.11ms +step:1311/1680 train_time:115519ms step_avg:88.11ms +step:1312/1680 train_time:115607ms step_avg:88.12ms +step:1313/1680 train_time:115696ms step_avg:88.12ms +step:1314/1680 train_time:115785ms step_avg:88.12ms +step:1315/1680 train_time:115873ms step_avg:88.12ms +step:1316/1680 train_time:115963ms step_avg:88.12ms +step:1317/1680 train_time:116052ms step_avg:88.12ms +step:1318/1680 train_time:116142ms step_avg:88.12ms +step:1319/1680 train_time:116231ms step_avg:88.12ms +step:1320/1680 train_time:116321ms step_avg:88.12ms +step:1321/1680 train_time:116410ms step_avg:88.12ms +step:1322/1680 train_time:116500ms step_avg:88.12ms +step:1323/1680 train_time:116589ms step_avg:88.12ms +step:1324/1680 train_time:116678ms step_avg:88.13ms +step:1325/1680 train_time:116767ms step_avg:88.13ms +step:1326/1680 train_time:116856ms step_avg:88.13ms +step:1327/1680 train_time:116944ms step_avg:88.13ms +step:1328/1680 train_time:117034ms step_avg:88.13ms +step:1329/1680 train_time:117124ms step_avg:88.13ms +step:1330/1680 train_time:117214ms step_avg:88.13ms +step:1331/1680 train_time:117303ms step_avg:88.13ms +step:1332/1680 train_time:117392ms step_avg:88.13ms +step:1333/1680 train_time:117482ms step_avg:88.13ms +step:1334/1680 train_time:117570ms step_avg:88.13ms +step:1335/1680 train_time:117660ms step_avg:88.13ms +step:1336/1680 train_time:117749ms step_avg:88.14ms +step:1337/1680 train_time:117837ms step_avg:88.14ms +step:1338/1680 train_time:117926ms step_avg:88.14ms +step:1339/1680 train_time:118016ms step_avg:88.14ms +step:1340/1680 train_time:118105ms step_avg:88.14ms +step:1341/1680 train_time:118193ms step_avg:88.14ms +step:1342/1680 train_time:118282ms step_avg:88.14ms +step:1343/1680 train_time:118372ms step_avg:88.14ms +step:1344/1680 train_time:118461ms step_avg:88.14ms +step:1345/1680 train_time:118550ms step_avg:88.14ms +step:1346/1680 train_time:118639ms step_avg:88.14ms +step:1347/1680 train_time:118728ms step_avg:88.14ms +step:1348/1680 train_time:118817ms 
step_avg:88.14ms +step:1349/1680 train_time:118905ms step_avg:88.14ms +step:1350/1680 train_time:118995ms step_avg:88.14ms +step:1351/1680 train_time:119085ms step_avg:88.15ms +step:1352/1680 train_time:119174ms step_avg:88.15ms +step:1353/1680 train_time:119264ms step_avg:88.15ms +step:1354/1680 train_time:119353ms step_avg:88.15ms +step:1355/1680 train_time:119442ms step_avg:88.15ms +step:1356/1680 train_time:119531ms step_avg:88.15ms +step:1357/1680 train_time:119620ms step_avg:88.15ms +step:1358/1680 train_time:119709ms step_avg:88.15ms +step:1359/1680 train_time:119799ms step_avg:88.15ms +step:1360/1680 train_time:119888ms step_avg:88.15ms +step:1361/1680 train_time:119977ms step_avg:88.15ms +step:1362/1680 train_time:120065ms step_avg:88.15ms +step:1363/1680 train_time:120155ms step_avg:88.15ms +step:1364/1680 train_time:120244ms step_avg:88.16ms +step:1365/1680 train_time:120333ms step_avg:88.16ms +step:1366/1680 train_time:120422ms step_avg:88.16ms +step:1367/1680 train_time:120512ms step_avg:88.16ms +step:1368/1680 train_time:120601ms step_avg:88.16ms +step:1369/1680 train_time:120691ms step_avg:88.16ms +step:1370/1680 train_time:120779ms step_avg:88.16ms +step:1371/1680 train_time:120868ms step_avg:88.16ms +step:1372/1680 train_time:120957ms step_avg:88.16ms +step:1373/1680 train_time:121046ms step_avg:88.16ms +step:1374/1680 train_time:121135ms step_avg:88.16ms +step:1375/1680 train_time:121224ms step_avg:88.16ms +step:1375/1680 val_loss:3.3432 train_time:121315ms step_avg:88.23ms +step:1376/1680 train_time:121333ms step_avg:88.18ms +step:1377/1680 train_time:121406ms step_avg:88.17ms +step:1378/1680 train_time:121500ms step_avg:88.17ms +step:1379/1680 train_time:121590ms step_avg:88.17ms +step:1380/1680 train_time:121679ms step_avg:88.17ms +step:1381/1680 train_time:121767ms step_avg:88.17ms +step:1382/1680 train_time:121856ms step_avg:88.17ms +step:1383/1680 train_time:121945ms step_avg:88.17ms +step:1384/1680 train_time:122033ms step_avg:88.17ms +step:1385/1680 train_time:122121ms step_avg:88.17ms +step:1386/1680 train_time:122210ms step_avg:88.17ms +step:1387/1680 train_time:122300ms step_avg:88.18ms +step:1388/1680 train_time:122390ms step_avg:88.18ms +step:1389/1680 train_time:122481ms step_avg:88.18ms +step:1390/1680 train_time:122571ms step_avg:88.18ms +step:1391/1680 train_time:122660ms step_avg:88.18ms +step:1392/1680 train_time:122748ms step_avg:88.18ms +step:1393/1680 train_time:122838ms step_avg:88.18ms +step:1394/1680 train_time:122928ms step_avg:88.18ms +step:1395/1680 train_time:123016ms step_avg:88.18ms +step:1396/1680 train_time:123105ms step_avg:88.18ms +step:1397/1680 train_time:123193ms step_avg:88.18ms +step:1398/1680 train_time:123282ms step_avg:88.18ms +step:1399/1680 train_time:123372ms step_avg:88.19ms +step:1400/1680 train_time:123463ms step_avg:88.19ms +step:1401/1680 train_time:123553ms step_avg:88.19ms +step:1402/1680 train_time:123642ms step_avg:88.19ms +step:1403/1680 train_time:123731ms step_avg:88.19ms +step:1404/1680 train_time:123821ms step_avg:88.19ms +step:1405/1680 train_time:123909ms step_avg:88.19ms +step:1406/1680 train_time:123997ms step_avg:88.19ms +step:1407/1680 train_time:124086ms step_avg:88.19ms +step:1408/1680 train_time:124176ms step_avg:88.19ms +step:1409/1680 train_time:124266ms step_avg:88.19ms +step:1410/1680 train_time:124355ms step_avg:88.20ms +step:1411/1680 train_time:124446ms step_avg:88.20ms +step:1412/1680 train_time:124535ms step_avg:88.20ms +step:1413/1680 train_time:124624ms step_avg:88.20ms +step:1414/1680 
train_time:124713ms step_avg:88.20ms +step:1415/1680 train_time:124802ms step_avg:88.20ms +step:1416/1680 train_time:124891ms step_avg:88.20ms +step:1417/1680 train_time:124980ms step_avg:88.20ms +step:1418/1680 train_time:125068ms step_avg:88.20ms +step:1419/1680 train_time:125157ms step_avg:88.20ms +step:1420/1680 train_time:125246ms step_avg:88.20ms +step:1421/1680 train_time:125335ms step_avg:88.20ms +step:1422/1680 train_time:125426ms step_avg:88.20ms +step:1423/1680 train_time:125516ms step_avg:88.21ms +step:1424/1680 train_time:125605ms step_avg:88.21ms +step:1425/1680 train_time:125694ms step_avg:88.21ms +step:1426/1680 train_time:125784ms step_avg:88.21ms +step:1427/1680 train_time:125872ms step_avg:88.21ms +step:1428/1680 train_time:125962ms step_avg:88.21ms +step:1429/1680 train_time:126050ms step_avg:88.21ms +step:1430/1680 train_time:126140ms step_avg:88.21ms +step:1431/1680 train_time:126229ms step_avg:88.21ms +step:1432/1680 train_time:126318ms step_avg:88.21ms +step:1433/1680 train_time:126408ms step_avg:88.21ms +step:1434/1680 train_time:126498ms step_avg:88.21ms +step:1435/1680 train_time:126587ms step_avg:88.21ms +step:1436/1680 train_time:126676ms step_avg:88.21ms +step:1437/1680 train_time:126766ms step_avg:88.22ms +step:1438/1680 train_time:126855ms step_avg:88.22ms +step:1439/1680 train_time:126943ms step_avg:88.22ms +step:1440/1680 train_time:127032ms step_avg:88.22ms +step:1441/1680 train_time:127122ms step_avg:88.22ms +step:1442/1680 train_time:127211ms step_avg:88.22ms +step:1443/1680 train_time:127300ms step_avg:88.22ms +step:1444/1680 train_time:127388ms step_avg:88.22ms +step:1445/1680 train_time:127478ms step_avg:88.22ms +step:1446/1680 train_time:127568ms step_avg:88.22ms +step:1447/1680 train_time:127658ms step_avg:88.22ms +step:1448/1680 train_time:127747ms step_avg:88.22ms +step:1449/1680 train_time:127837ms step_avg:88.22ms +step:1450/1680 train_time:127926ms step_avg:88.22ms +step:1451/1680 train_time:128015ms step_avg:88.23ms +step:1452/1680 train_time:128104ms step_avg:88.23ms +step:1453/1680 train_time:128192ms step_avg:88.23ms +step:1454/1680 train_time:128282ms step_avg:88.23ms +step:1455/1680 train_time:128370ms step_avg:88.23ms +step:1456/1680 train_time:128460ms step_avg:88.23ms +step:1457/1680 train_time:128549ms step_avg:88.23ms +step:1458/1680 train_time:128639ms step_avg:88.23ms +step:1459/1680 train_time:128728ms step_avg:88.23ms +step:1460/1680 train_time:128818ms step_avg:88.23ms +step:1461/1680 train_time:128908ms step_avg:88.23ms +step:1462/1680 train_time:128997ms step_avg:88.23ms +step:1463/1680 train_time:129085ms step_avg:88.23ms +step:1464/1680 train_time:129175ms step_avg:88.23ms +step:1465/1680 train_time:129263ms step_avg:88.23ms +step:1466/1680 train_time:129353ms step_avg:88.24ms +step:1467/1680 train_time:129442ms step_avg:88.24ms +step:1468/1680 train_time:129531ms step_avg:88.24ms +step:1469/1680 train_time:129621ms step_avg:88.24ms +step:1470/1680 train_time:129710ms step_avg:88.24ms +step:1471/1680 train_time:129799ms step_avg:88.24ms +step:1472/1680 train_time:129888ms step_avg:88.24ms +step:1473/1680 train_time:129977ms step_avg:88.24ms +step:1474/1680 train_time:130066ms step_avg:88.24ms +step:1475/1680 train_time:130155ms step_avg:88.24ms +step:1476/1680 train_time:130245ms step_avg:88.24ms +step:1477/1680 train_time:130335ms step_avg:88.24ms +step:1478/1680 train_time:130424ms step_avg:88.24ms +step:1479/1680 train_time:130513ms step_avg:88.24ms +step:1480/1680 train_time:130603ms step_avg:88.25ms +step:1481/1680 
train_time:130693ms step_avg:88.25ms +step:1482/1680 train_time:130782ms step_avg:88.25ms +step:1483/1680 train_time:130871ms step_avg:88.25ms +step:1484/1680 train_time:130960ms step_avg:88.25ms +step:1485/1680 train_time:131050ms step_avg:88.25ms +step:1486/1680 train_time:131139ms step_avg:88.25ms +step:1487/1680 train_time:131228ms step_avg:88.25ms +step:1488/1680 train_time:131317ms step_avg:88.25ms +step:1489/1680 train_time:131406ms step_avg:88.25ms +step:1490/1680 train_time:131494ms step_avg:88.25ms +step:1491/1680 train_time:131584ms step_avg:88.25ms +step:1492/1680 train_time:131673ms step_avg:88.25ms +step:1493/1680 train_time:131762ms step_avg:88.25ms +step:1494/1680 train_time:131852ms step_avg:88.25ms +step:1495/1680 train_time:131940ms step_avg:88.25ms +step:1496/1680 train_time:132030ms step_avg:88.26ms +step:1497/1680 train_time:132119ms step_avg:88.26ms +step:1498/1680 train_time:132208ms step_avg:88.26ms +step:1499/1680 train_time:132298ms step_avg:88.26ms +step:1500/1680 train_time:132387ms step_avg:88.26ms +step:1500/1680 val_loss:3.3141 train_time:132478ms step_avg:88.32ms +step:1501/1680 train_time:132497ms step_avg:88.27ms +step:1502/1680 train_time:132570ms step_avg:88.26ms +step:1503/1680 train_time:132661ms step_avg:88.26ms +step:1504/1680 train_time:132750ms step_avg:88.26ms +step:1505/1680 train_time:132839ms step_avg:88.27ms +step:1506/1680 train_time:132927ms step_avg:88.27ms +step:1507/1680 train_time:133015ms step_avg:88.27ms +step:1508/1680 train_time:133104ms step_avg:88.27ms +step:1509/1680 train_time:133191ms step_avg:88.26ms +step:1510/1680 train_time:133280ms step_avg:88.27ms +step:1511/1680 train_time:133369ms step_avg:88.27ms +step:1512/1680 train_time:133459ms step_avg:88.27ms +step:1513/1680 train_time:133550ms step_avg:88.27ms +step:1514/1680 train_time:133643ms step_avg:88.27ms +step:1515/1680 train_time:133732ms step_avg:88.27ms +step:1516/1680 train_time:133822ms step_avg:88.27ms +step:1517/1680 train_time:133911ms step_avg:88.27ms +step:1518/1680 train_time:133999ms step_avg:88.27ms +step:1519/1680 train_time:134088ms step_avg:88.27ms +step:1520/1680 train_time:134176ms step_avg:88.27ms +step:1521/1680 train_time:134264ms step_avg:88.27ms +step:1522/1680 train_time:134353ms step_avg:88.27ms +step:1523/1680 train_time:134442ms step_avg:88.27ms +step:1524/1680 train_time:134532ms step_avg:88.28ms +step:1525/1680 train_time:134623ms step_avg:88.28ms +step:1526/1680 train_time:134713ms step_avg:88.28ms +step:1527/1680 train_time:134803ms step_avg:88.28ms +step:1528/1680 train_time:134892ms step_avg:88.28ms +step:1529/1680 train_time:134981ms step_avg:88.28ms +step:1530/1680 train_time:135071ms step_avg:88.28ms +step:1531/1680 train_time:135160ms step_avg:88.28ms +step:1532/1680 train_time:135248ms step_avg:88.28ms +step:1533/1680 train_time:135337ms step_avg:88.28ms +step:1534/1680 train_time:135426ms step_avg:88.28ms +step:1535/1680 train_time:135515ms step_avg:88.28ms +step:1536/1680 train_time:135606ms step_avg:88.29ms +step:1537/1680 train_time:135696ms step_avg:88.29ms +step:1538/1680 train_time:135785ms step_avg:88.29ms +step:1539/1680 train_time:135875ms step_avg:88.29ms +step:1540/1680 train_time:135964ms step_avg:88.29ms +step:1541/1680 train_time:136053ms step_avg:88.29ms +step:1542/1680 train_time:136142ms step_avg:88.29ms +step:1543/1680 train_time:136231ms step_avg:88.29ms +step:1544/1680 train_time:136319ms step_avg:88.29ms +step:1545/1680 train_time:136407ms step_avg:88.29ms +step:1546/1680 train_time:136497ms step_avg:88.29ms 
+step:1547/1680 train_time:136586ms step_avg:88.29ms +step:1548/1680 train_time:136676ms step_avg:88.29ms +step:1549/1680 train_time:136766ms step_avg:88.29ms +step:1550/1680 train_time:136855ms step_avg:88.29ms +step:1551/1680 train_time:136945ms step_avg:88.29ms +step:1552/1680 train_time:137036ms step_avg:88.30ms +step:1553/1680 train_time:137124ms step_avg:88.30ms +step:1554/1680 train_time:137214ms step_avg:88.30ms +step:1555/1680 train_time:137303ms step_avg:88.30ms +step:1556/1680 train_time:137393ms step_avg:88.30ms +step:1557/1680 train_time:137483ms step_avg:88.30ms +step:1558/1680 train_time:137572ms step_avg:88.30ms +step:1559/1680 train_time:137662ms step_avg:88.30ms +step:1560/1680 train_time:137752ms step_avg:88.30ms +step:1561/1680 train_time:137841ms step_avg:88.30ms +step:1562/1680 train_time:137931ms step_avg:88.30ms +step:1563/1680 train_time:138022ms step_avg:88.31ms +step:1564/1680 train_time:138112ms step_avg:88.31ms +step:1565/1680 train_time:138201ms step_avg:88.31ms +step:1566/1680 train_time:138290ms step_avg:88.31ms +step:1567/1680 train_time:138379ms step_avg:88.31ms +step:1568/1680 train_time:138470ms step_avg:88.31ms +step:1569/1680 train_time:138559ms step_avg:88.31ms +step:1570/1680 train_time:138648ms step_avg:88.31ms +step:1571/1680 train_time:138736ms step_avg:88.31ms +step:1572/1680 train_time:138825ms step_avg:88.31ms +step:1573/1680 train_time:138914ms step_avg:88.31ms +step:1574/1680 train_time:139003ms step_avg:88.31ms +step:1575/1680 train_time:139092ms step_avg:88.31ms +step:1576/1680 train_time:139183ms step_avg:88.31ms +step:1577/1680 train_time:139272ms step_avg:88.31ms +step:1578/1680 train_time:139361ms step_avg:88.32ms +step:1579/1680 train_time:139451ms step_avg:88.32ms +step:1580/1680 train_time:139540ms step_avg:88.32ms +step:1581/1680 train_time:139629ms step_avg:88.32ms +step:1582/1680 train_time:139719ms step_avg:88.32ms +step:1583/1680 train_time:139808ms step_avg:88.32ms +step:1584/1680 train_time:139897ms step_avg:88.32ms +step:1585/1680 train_time:139987ms step_avg:88.32ms +step:1586/1680 train_time:140076ms step_avg:88.32ms +step:1587/1680 train_time:140166ms step_avg:88.32ms +step:1588/1680 train_time:140256ms step_avg:88.32ms +step:1589/1680 train_time:140345ms step_avg:88.32ms +step:1590/1680 train_time:140434ms step_avg:88.32ms +step:1591/1680 train_time:140524ms step_avg:88.32ms +step:1592/1680 train_time:140613ms step_avg:88.32ms +step:1593/1680 train_time:140702ms step_avg:88.33ms +step:1594/1680 train_time:140792ms step_avg:88.33ms +step:1595/1680 train_time:140881ms step_avg:88.33ms +step:1596/1680 train_time:140971ms step_avg:88.33ms +step:1597/1680 train_time:141061ms step_avg:88.33ms +step:1598/1680 train_time:141150ms step_avg:88.33ms +step:1599/1680 train_time:141239ms step_avg:88.33ms +step:1600/1680 train_time:141328ms step_avg:88.33ms +step:1601/1680 train_time:141418ms step_avg:88.33ms +step:1602/1680 train_time:141507ms step_avg:88.33ms +step:1603/1680 train_time:141595ms step_avg:88.33ms +step:1604/1680 train_time:141684ms step_avg:88.33ms +step:1605/1680 train_time:141774ms step_avg:88.33ms +step:1606/1680 train_time:141864ms step_avg:88.33ms +step:1607/1680 train_time:141952ms step_avg:88.33ms +step:1608/1680 train_time:142041ms step_avg:88.33ms +step:1609/1680 train_time:142131ms step_avg:88.33ms +step:1610/1680 train_time:142220ms step_avg:88.34ms +step:1611/1680 train_time:142310ms step_avg:88.34ms +step:1612/1680 train_time:142399ms step_avg:88.34ms +step:1613/1680 train_time:142489ms step_avg:88.34ms 
+step:1614/1680 train_time:142579ms step_avg:88.34ms +step:1615/1680 train_time:142668ms step_avg:88.34ms +step:1616/1680 train_time:142757ms step_avg:88.34ms +step:1617/1680 train_time:142846ms step_avg:88.34ms +step:1618/1680 train_time:142936ms step_avg:88.34ms +step:1619/1680 train_time:143025ms step_avg:88.34ms +step:1620/1680 train_time:143113ms step_avg:88.34ms +step:1621/1680 train_time:143202ms step_avg:88.34ms +step:1622/1680 train_time:143291ms step_avg:88.34ms +step:1623/1680 train_time:143381ms step_avg:88.34ms +step:1624/1680 train_time:143470ms step_avg:88.34ms +step:1625/1680 train_time:143560ms step_avg:88.34ms +step:1625/1680 val_loss:3.2901 train_time:143650ms step_avg:88.40ms +step:1626/1680 train_time:143668ms step_avg:88.36ms +step:1627/1680 train_time:143742ms step_avg:88.35ms +step:1628/1680 train_time:143834ms step_avg:88.35ms +step:1629/1680 train_time:143924ms step_avg:88.35ms +step:1630/1680 train_time:144012ms step_avg:88.35ms +step:1631/1680 train_time:144100ms step_avg:88.35ms +step:1632/1680 train_time:144188ms step_avg:88.35ms +step:1633/1680 train_time:144275ms step_avg:88.35ms +step:1634/1680 train_time:144364ms step_avg:88.35ms +step:1635/1680 train_time:144453ms step_avg:88.35ms +step:1636/1680 train_time:144542ms step_avg:88.35ms +step:1637/1680 train_time:144633ms step_avg:88.35ms +step:1638/1680 train_time:144725ms step_avg:88.35ms +step:1639/1680 train_time:144816ms step_avg:88.36ms +step:1640/1680 train_time:144906ms step_avg:88.36ms +step:1641/1680 train_time:144995ms step_avg:88.36ms +step:1642/1680 train_time:145084ms step_avg:88.36ms +step:1643/1680 train_time:145173ms step_avg:88.36ms +step:1644/1680 train_time:145261ms step_avg:88.36ms +step:1645/1680 train_time:145349ms step_avg:88.36ms +step:1646/1680 train_time:145437ms step_avg:88.36ms +step:1647/1680 train_time:145527ms step_avg:88.36ms +step:1648/1680 train_time:145617ms step_avg:88.36ms +step:1649/1680 train_time:145707ms step_avg:88.36ms +step:1650/1680 train_time:145796ms step_avg:88.36ms +step:1651/1680 train_time:145886ms step_avg:88.36ms +step:1652/1680 train_time:145975ms step_avg:88.36ms +step:1653/1680 train_time:146064ms step_avg:88.36ms +step:1654/1680 train_time:146153ms step_avg:88.36ms +step:1655/1680 train_time:146242ms step_avg:88.36ms +step:1656/1680 train_time:146331ms step_avg:88.36ms +step:1657/1680 train_time:146419ms step_avg:88.36ms +step:1658/1680 train_time:146507ms step_avg:88.36ms +step:1659/1680 train_time:146597ms step_avg:88.36ms +step:1660/1680 train_time:146688ms step_avg:88.37ms +step:1661/1680 train_time:146777ms step_avg:88.37ms +step:1662/1680 train_time:146867ms step_avg:88.37ms +step:1663/1680 train_time:146957ms step_avg:88.37ms +step:1664/1680 train_time:147046ms step_avg:88.37ms +step:1665/1680 train_time:147135ms step_avg:88.37ms +step:1666/1680 train_time:147224ms step_avg:88.37ms +step:1667/1680 train_time:147313ms step_avg:88.37ms +step:1668/1680 train_time:147401ms step_avg:88.37ms +step:1669/1680 train_time:147490ms step_avg:88.37ms +step:1670/1680 train_time:147579ms step_avg:88.37ms +step:1671/1680 train_time:147668ms step_avg:88.37ms +step:1672/1680 train_time:147757ms step_avg:88.37ms +step:1673/1680 train_time:147847ms step_avg:88.37ms +step:1674/1680 train_time:147937ms step_avg:88.37ms +step:1675/1680 train_time:148026ms step_avg:88.37ms +step:1676/1680 train_time:148115ms step_avg:88.37ms +step:1677/1680 train_time:148204ms step_avg:88.37ms +step:1678/1680 train_time:148292ms step_avg:88.37ms +step:1679/1680 train_time:148381ms 
step_avg:88.37ms +step:1680/1680 train_time:148470ms step_avg:88.37ms +step:1680/1680 val_loss:3.2796 train_time:148560ms step_avg:88.43ms +peak memory allocated: 30760 MiB reserved: 46034 MiB diff --git a/records/092725_BF16CE/efa9ba5e-7c95-4d47-8873-ad23d1f28e80.txt b/records/092725_BF16CE/efa9ba5e-7c95-4d47-8873-ad23d1f28e80.txt new file mode 100644 index 000000000..91425b94a --- /dev/null +++ b/records/092725_BF16CE/efa9ba5e-7c95-4d47-8873-ad23d1f28e80.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
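+    Example (illustrative sketch only, not part of this record's training path;
+    it assumes params carry a `.module` label as assigned in this file's model
+    code, since the param grouping below relies on that attribute):
+
+        hidden_matrix_params = [p for n, p in model.blocks.named_parameters()
+                                if p.ndim >= 2 and "embed" not in n and "gate" not in n]
+        gate_params = [p for n, p in model.named_parameters() if "gate" in n]
+        opt = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+        loss.backward()
+        opt.step() # reduce-scatter grads -> batched Newton-Schulz -> all-gather params
+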
Despite that warning, small 1D params empirically perform efficiently here:
+    NS approximately performs a magnitude normalization of the grad, and this
+    hyper-optimized class has faster execution time than the current impl of
+    Adam for small params.
+
+    Custom distributed sizing:
+    The model stores all attn and mlp weights in the same shape, and then updates the view as
+    needed on the forward pass. This enables attn and mlp weights to be contained within the same
+    dist.reduce_scatter_tensor() call. The model architecture has been customized so that
+    (n_attn_layers + n_mlp_layers*2) % 4 == 0, for batching across 8 GPUs with zero padding on mlp and attn.
+    The scheduling is:
+    1. reduce scatter smear_gate (1 param, 7 padding params)
+    2. reduce scatter attn_gate (10 params, 6 padding params)
+    3. reduce scatter attn/mlp round 1 (10 attn params, 6 mlp params)
+    4. reduce scatter attn/mlp round 2 (16 mlp params)
+    5. wait on step 1, then compute NS of 1 and schedule all gather
+    6. wait on step 2, then compute NS of 2 and schedule all gather
+    7. wait on step 3, then compute NS of 3 and schedule all gather;
+       GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP],
+       and GPUs that receive params of type attn reshape before NS
+    8. wait on step 4, then compute NS of 4 and schedule all gather
+    9. wait for each all gather to complete and update params
+    Empirically, leading with small params provides an additional 0.2s improvement.
+    """
+    def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
+        # custom sizing requires 8 GPUs
+        if custom_sizing and dist.get_world_size() == 8:
+            param_groups = self.generate_custom_param_groups(params)
+        else:
+            param_groups = self.generate_standard_param_groups(params)
+        super().__init__(param_groups, defaults)
+
+    def generate_standard_param_groups(self, params):
+        """
+        Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules.
+        Creates one param group per size, while giving attn its own param group for the resize op.
+        """
+        params = list(params)
+        param_groups = []
+        attn_subset = [p for p in params if p.module == 'attn']
+        non_attn_subset = [p for p in params if p.module != 'attn']
+        param_groups.append(dict(params=attn_subset))
+
+        sizes = {p.shape for p in non_attn_subset}
+        for size in sizes:
+            group_params = [p for p in non_attn_subset if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        return param_groups
+
+    def generate_custom_param_groups(self, params):
+        """
+        Implementation requires that a single GPU does not receive both attn
+        and mlp params when a param group is split across GPUs.
+        """
+        module_ranks = {
+            'smear_gate': 1, # 1 param
+            'attn_gate': 2, # 10 params
+            'attn': 3, # 10 params
+            'mlp': 4, # 22 params
+        }
+        params = list(params)
+        params.sort(key=lambda x: module_ranks.get(x.module))
+        idx = 0
+        group_sizes = [1, 10, 16, 16]
+        assert len(params) == sum(group_sizes)
+        param_groups = []
+        for size in group_sizes:
+            group_params = params[idx:idx + size]
+            param_groups.append(dict(params=group_params))
+            idx += size
+        return param_groups
+
+    @torch.no_grad()
+    def step(self):
+        # Efficient systems-wise implementation of step developed by @YouJiacheng,
+        # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad,
+        # @ryanyang0, and @vagrawal.
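+        # Per 2D param this method computes, in effect (single-GPU sketch of
+        # the sharded/batched implementation below):
+        #     momentum_buffer.lerp_(grad, 1 - momentum)
+        #     update = grad.lerp(momentum_buffer, momentum)  # Nesterov-style mix
+        #     v = newton_schulz_triton(update)               # orthogonalize
+        #     p.mul_(1 - eff_weight_decay).add_(v, alpha=-eff_lr)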
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
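+                # (p itself is only overwritten at the very end, once the
+                # all_gather returns the fully updated stack from all ranks.)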
updated_param_chunk[i].copy_(p)
+
+                # Apply weight decay directly to the buffer.
+                updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+            original_shape = batched_update_grads.shape
+            # Reshape attn params from [hdim, dim*4] to [4, hdim, dim] to apply NS independently to Q,K,V,O
+            module_idx = start_idx if start_idx < len(params) else len(params) - 1
+            if getattr(params[module_idx], "module", None) == "attn":
+                batched_update_grads = batched_update_grads.view(-1, original_shape[-2], original_shape[-1] // 4)
+
+            v_chunk = newton_schulz_triton(batched_update_grads).view(original_shape)
+
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+                if param_idx >= len(params): # Skip padded entries again.
+                    continue
+
+                # Add the computed zeropower update to the parameter in the buffer.
+                updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val)
+
+            stacked_params = torch.empty(
+                (info["padded_num_params"], *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+            gather_future = dist.all_gather_into_tensor(
+                stacked_params, updated_param_chunk, async_op=True
+            ).get_future()
+
+            all_gather_infos.append(
+                {
+                    "gather_future": gather_future,
+                    "stacked_params": stacked_params,
+                    "orig_params": params,
+                }
+            )
+
+        # Final pass: wait for all_gather to complete and copy results back into original parameter tensors.
+        for info in all_gather_infos:
+            info["gather_future"].wait()
+            stacked_params = info["stacked_params"]
+            orig_params = info["orig_params"]
+
+            unstacked_params = torch.unbind(stacked_params)
+            for i, p in enumerate(orig_params):
+                p.copy_(unstacked_params[i], non_blocking=True)
+
+
+class DistAdam(torch.optim.Optimizer):
+    def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        params = list(params)
+        sizes = {p.shape for p in params}
+        # create one buffer per unique parameter-size
+        param_groups = []
+        for size in sizes:
+            group_params = [p for p in params if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        super().__init__(param_groups, defaults)
+        # DistributedAdam implementation by @vagrawal
+
+    @torch.compile
+    @torch.no_grad()
+    def step(self):
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        reduce_scatter_futures: list[torch.Future] = []
+        all_gather_futures: list[torch.Future] = []
+        grad_slices = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            for base_i in range(len(params)):
+                grad = params[base_i].grad
+                rank_size = grad.shape[0] // world_size
+                grad_slice = torch.empty_like(grad[:rank_size])
+                reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
+                grad_slices.append(grad_slice)
+
+        idx = 0
+        for group in self.param_groups:
+            beta1, beta2 = group['betas']
+            eps = group['eps']
+            wd = group['weight_decay']
+            params = group['params']
+            for base in range(len(params)):
+                reduce_scatter_futures[idx].wait()
+                p = params[base]
+                rank_size = p.shape[0] // world_size
+                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
+                lr = group['lr'] * getattr(p, "lr_mul", 1.0)
+                state = self.state[p]
+                g_slice = grad_slices[idx]
+                # State init
+                if not state:
+                    state["step"] = torch.tensor(
+                        0, dtype=torch.int64, device=p.device
+                    )
+                    state["exp_avg"] = torch.zeros(
+                        p_slice.shape,
+                        dtype=torch.bfloat16,
+                        device=p_slice.device,
+                    )
+                    state["exp_avg_sq"] = torch.zeros(
+                        p_slice.shape,
+                        dtype=torch.bfloat16,
+                        device=p_slice.device,
+                    )
+                exp_avg = state["exp_avg"]
+                exp_avg_sq =
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module='attn' + with torch.no_grad(): + self.qkvo_w.view(4,self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4,self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4,self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4,self.hdim, self.dim)[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module='mlp' + self.c_proj.module='mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
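+        # The fp8 scales below keep tensors inside float8 dynamic range before
+        # torch._scaled_mm, which multiplies the scales back out via its
+        # scale_a/scale_b args (see mm_op above): e4m3 saturates near 448,
+        # hence x_s=(model_dim**0.5)/448 and grad_s=1/448, while w_s=2**-9
+        # scales the small head weights up before the cast.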
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the 
data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter==5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter+=1 + return starts, ends + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * 
grad_accum_steps) == 0, "Batch size must be divisible by world size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document was too long, to account for _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40 # number of iterations to continue training at final cooldown and window size
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss?
0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = min(0.9999, step / args.num_iterations)
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+def get_ws(step: int):
+    if step == args.num_iterations + args.iteration_extension:
+        return args.ws_validate // 2, args.ws_validate
+    x = min(step / (1 + args.num_iterations), 0.9999)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws_long = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws_long > ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    elif new_ws_long < ws_long:
+        model.yarn.apply(ws_long, new_ws_long) # YaRN rescaling also handles shrinking windows
+        ws_long = new_ws_long
+    model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+
+# restore the initial state now that the kernels are warm
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del initial_state
+model.yarn.reset()
+
+########################################
+#        Training and validation       #
+########################################
+
+train_steps = args.num_iterations + args.iteration_extension
+ws_short, ws_long = get_ws(0)
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    new_ws_short, new_ws_long = get_ws(step)
+    if new_ws_long != ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+    ws_short, ws_long = new_ws_short, new_ws_long
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        if last_step:
+            ws_long = args.ws_long_validate
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = torch.zeros((), device=device, dtype=torch.float32)
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
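+        # each micro-batch backward() accumulates into .grad; the optimizers
+        # then average the summed grads across ranks via reduce_scatter, so no
+        # gradient sync is needed inside this loop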
inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Sep 27 12:17:03 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 30C P0 122W / 700W | 5856MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 27C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 25C P0 117W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 29C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 30C P0 122W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 28C P0 115W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 30C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 27C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 154404 C /usr/bin/python 0MiB | +| 0 N/A N/A 154405 C /usr/bin/python 0MiB | +| 0 N/A N/A 154406 C /usr/bin/python 0MiB | +| 0 N/A N/A 154407 C /usr/bin/python 0MiB | +| 0 N/A N/A 154408 C /usr/bin/python 0MiB | +| 0 N/A N/A 154409 C /usr/bin/python 0MiB | +| 0 N/A N/A 154410 C /usr/bin/python 0MiB | +| 0 N/A N/A 154411 C /usr/bin/python 0MiB | +| 1 N/A N/A 154405 C /usr/bin/python 0MiB | +| 2 N/A N/A 154406 C /usr/bin/python 0MiB | +| 3 N/A N/A 154407 C /usr/bin/python 0MiB | +| 4 N/A N/A 154408 C /usr/bin/python 0MiB | +| 5 N/A N/A 154409 C /usr/bin/python 0MiB | +| 6 N/A N/A 154410 C /usr/bin/python 0MiB | +| 7 N/A N/A 154411 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1680 train_time:138ms step_avg:137.66ms +step:2/1680 train_time:158ms step_avg:78.93ms +step:3/1680 train_time:222ms step_avg:74.06ms +step:4/1680 train_time:307ms step_avg:76.82ms +step:5/1680 train_time:393ms step_avg:78.65ms +step:6/1680 train_time:479ms step_avg:79.84ms +step:7/1680 train_time:565ms step_avg:80.70ms +step:8/1680 train_time:651ms step_avg:81.42ms +step:9/1680 train_time:737ms step_avg:81.93ms +step:10/1680 train_time:823ms step_avg:82.34ms +step:11/1680 train_time:910ms step_avg:82.71ms +step:12/1680 train_time:997ms step_avg:83.12ms +step:13/1680 train_time:1088ms step_avg:83.71ms +step:14/1680 train_time:1178ms step_avg:84.16ms +step:15/1680 train_time:1266ms step_avg:84.42ms +step:16/1680 train_time:1354ms step_avg:84.62ms +step:17/1680 train_time:1442ms step_avg:84.80ms +step:18/1680 train_time:1529ms step_avg:84.94ms +step:19/1680 train_time:1615ms step_avg:85.02ms +step:20/1680 train_time:1702ms step_avg:85.11ms +step:21/1680 train_time:1789ms step_avg:85.19ms +step:22/1680 train_time:1876ms step_avg:85.26ms +step:23/1680 train_time:1963ms step_avg:85.33ms +step:24/1680 train_time:2052ms step_avg:85.51ms +step:25/1680 train_time:2141ms step_avg:85.64ms +step:26/1680 train_time:2230ms step_avg:85.77ms +step:27/1680 train_time:2317ms step_avg:85.83ms +step:28/1680 train_time:2405ms step_avg:85.88ms +step:29/1680 train_time:2492ms step_avg:85.93ms +step:30/1680 train_time:2579ms step_avg:85.98ms +step:31/1680 train_time:2666ms step_avg:86.00ms +step:32/1680 train_time:2753ms step_avg:86.02ms +step:33/1680 train_time:2839ms step_avg:86.04ms +step:34/1680 train_time:2926ms step_avg:86.07ms +step:35/1680 train_time:3014ms step_avg:86.11ms +step:36/1680 train_time:3102ms step_avg:86.16ms +step:37/1680 train_time:3190ms step_avg:86.22ms +step:38/1680 train_time:3278ms step_avg:86.26ms +step:39/1680 train_time:3365ms step_avg:86.29ms +step:40/1680 train_time:3453ms step_avg:86.33ms +step:41/1680 train_time:3540ms step_avg:86.35ms +step:42/1680 train_time:3628ms step_avg:86.39ms +step:43/1680 train_time:3715ms step_avg:86.40ms +step:44/1680 train_time:3802ms step_avg:86.41ms +step:45/1680 train_time:3889ms step_avg:86.42ms +step:46/1680 train_time:3976ms step_avg:86.43ms +step:47/1680 
train_time:4063ms step_avg:86.46ms +step:48/1680 train_time:4152ms step_avg:86.50ms +step:49/1680 train_time:4240ms step_avg:86.54ms +step:50/1680 train_time:4329ms step_avg:86.58ms +step:51/1680 train_time:4419ms step_avg:86.64ms +step:52/1680 train_time:4506ms step_avg:86.65ms +step:53/1680 train_time:4593ms step_avg:86.66ms +step:54/1680 train_time:4680ms step_avg:86.67ms +step:55/1680 train_time:4768ms step_avg:86.68ms +step:56/1680 train_time:4854ms step_avg:86.68ms +step:57/1680 train_time:4942ms step_avg:86.69ms +step:58/1680 train_time:5029ms step_avg:86.71ms +step:59/1680 train_time:5116ms step_avg:86.71ms +step:60/1680 train_time:5204ms step_avg:86.73ms +step:61/1680 train_time:5291ms step_avg:86.75ms +step:62/1680 train_time:5379ms step_avg:86.76ms +step:63/1680 train_time:5467ms step_avg:86.77ms +step:64/1680 train_time:5553ms step_avg:86.77ms +step:65/1680 train_time:5641ms step_avg:86.78ms +step:66/1680 train_time:5728ms step_avg:86.79ms +step:67/1680 train_time:5816ms step_avg:86.80ms +step:68/1680 train_time:5903ms step_avg:86.80ms +step:69/1680 train_time:5989ms step_avg:86.80ms +step:70/1680 train_time:6077ms step_avg:86.81ms +step:71/1680 train_time:6164ms step_avg:86.82ms +step:72/1680 train_time:6252ms step_avg:86.83ms +step:73/1680 train_time:6340ms step_avg:86.84ms +step:74/1680 train_time:6427ms step_avg:86.86ms +step:75/1680 train_time:6515ms step_avg:86.87ms +step:76/1680 train_time:6602ms step_avg:86.87ms +step:77/1680 train_time:6689ms step_avg:86.88ms +step:78/1680 train_time:6777ms step_avg:86.89ms +step:79/1680 train_time:6864ms step_avg:86.88ms +step:80/1680 train_time:6951ms step_avg:86.89ms +step:81/1680 train_time:7038ms step_avg:86.89ms +step:82/1680 train_time:7125ms step_avg:86.89ms +step:83/1680 train_time:7212ms step_avg:86.89ms +step:84/1680 train_time:7300ms step_avg:86.91ms +step:85/1680 train_time:7388ms step_avg:86.91ms +step:86/1680 train_time:7474ms step_avg:86.91ms +step:87/1680 train_time:7562ms step_avg:86.92ms +step:88/1680 train_time:7650ms step_avg:86.93ms +step:89/1680 train_time:7738ms step_avg:86.94ms +step:90/1680 train_time:7825ms step_avg:86.94ms +step:91/1680 train_time:7912ms step_avg:86.94ms +step:92/1680 train_time:7999ms step_avg:86.94ms +step:93/1680 train_time:8086ms step_avg:86.94ms +step:94/1680 train_time:8173ms step_avg:86.94ms +step:95/1680 train_time:8260ms step_avg:86.95ms +step:96/1680 train_time:8347ms step_avg:86.95ms +step:97/1680 train_time:8434ms step_avg:86.95ms +step:98/1680 train_time:8521ms step_avg:86.95ms +step:99/1680 train_time:8609ms step_avg:86.96ms +step:100/1680 train_time:8696ms step_avg:86.96ms +step:101/1680 train_time:8783ms step_avg:86.96ms +step:102/1680 train_time:8871ms step_avg:86.97ms +step:103/1680 train_time:8958ms step_avg:86.97ms +step:104/1680 train_time:9045ms step_avg:86.97ms +step:105/1680 train_time:9133ms step_avg:86.98ms +step:106/1680 train_time:9220ms step_avg:86.98ms +step:107/1680 train_time:9307ms step_avg:86.98ms +step:108/1680 train_time:9396ms step_avg:87.00ms +step:109/1680 train_time:9483ms step_avg:87.00ms +step:110/1680 train_time:9570ms step_avg:87.00ms +step:111/1680 train_time:9657ms step_avg:87.00ms +step:112/1680 train_time:9745ms step_avg:87.01ms +step:113/1680 train_time:9833ms step_avg:87.01ms +step:114/1680 train_time:9920ms step_avg:87.01ms +step:115/1680 train_time:10006ms step_avg:87.01ms +step:116/1680 train_time:10094ms step_avg:87.02ms +step:117/1680 train_time:10181ms step_avg:87.01ms +step:118/1680 train_time:10268ms step_avg:87.02ms +step:119/1680 
train_time:10354ms step_avg:87.01ms +step:120/1680 train_time:10442ms step_avg:87.01ms +step:121/1680 train_time:10529ms step_avg:87.02ms +step:122/1680 train_time:10616ms step_avg:87.01ms +step:123/1680 train_time:10703ms step_avg:87.02ms +step:124/1680 train_time:10791ms step_avg:87.02ms +step:125/1680 train_time:10878ms step_avg:87.02ms +step:125/1680 val_loss:4.2984 train_time:10967ms step_avg:87.74ms +step:126/1680 train_time:10986ms step_avg:87.19ms +step:127/1680 train_time:11059ms step_avg:87.08ms +step:128/1680 train_time:11155ms step_avg:87.15ms +step:129/1680 train_time:11248ms step_avg:87.19ms +step:130/1680 train_time:11335ms step_avg:87.19ms +step:131/1680 train_time:11421ms step_avg:87.19ms +step:132/1680 train_time:11507ms step_avg:87.18ms +step:133/1680 train_time:11593ms step_avg:87.17ms +step:134/1680 train_time:11679ms step_avg:87.16ms +step:135/1680 train_time:11766ms step_avg:87.15ms +step:136/1680 train_time:11852ms step_avg:87.15ms +step:137/1680 train_time:11938ms step_avg:87.14ms +step:138/1680 train_time:12025ms step_avg:87.14ms +step:139/1680 train_time:12115ms step_avg:87.16ms +step:140/1680 train_time:12204ms step_avg:87.17ms +step:141/1680 train_time:12295ms step_avg:87.20ms +step:142/1680 train_time:12382ms step_avg:87.20ms +step:143/1680 train_time:12469ms step_avg:87.20ms +step:144/1680 train_time:12556ms step_avg:87.20ms +step:145/1680 train_time:12643ms step_avg:87.19ms +step:146/1680 train_time:12729ms step_avg:87.18ms +step:147/1680 train_time:12815ms step_avg:87.18ms +step:148/1680 train_time:12901ms step_avg:87.17ms +step:149/1680 train_time:12988ms step_avg:87.17ms +step:150/1680 train_time:13076ms step_avg:87.17ms +step:151/1680 train_time:13165ms step_avg:87.18ms +step:152/1680 train_time:13255ms step_avg:87.20ms +step:153/1680 train_time:13342ms step_avg:87.20ms +step:154/1680 train_time:13429ms step_avg:87.20ms +step:155/1680 train_time:13516ms step_avg:87.20ms +step:156/1680 train_time:13603ms step_avg:87.20ms +step:157/1680 train_time:13690ms step_avg:87.19ms +step:158/1680 train_time:13777ms step_avg:87.19ms +step:159/1680 train_time:13863ms step_avg:87.19ms +step:160/1680 train_time:13950ms step_avg:87.19ms +step:161/1680 train_time:14037ms step_avg:87.19ms +step:162/1680 train_time:14124ms step_avg:87.19ms +step:163/1680 train_time:14213ms step_avg:87.19ms +step:164/1680 train_time:14301ms step_avg:87.20ms +step:165/1680 train_time:14388ms step_avg:87.20ms +step:166/1680 train_time:14476ms step_avg:87.20ms +step:167/1680 train_time:14562ms step_avg:87.20ms +step:168/1680 train_time:14649ms step_avg:87.19ms +step:169/1680 train_time:14736ms step_avg:87.19ms +step:170/1680 train_time:14823ms step_avg:87.19ms +step:171/1680 train_time:14910ms step_avg:87.19ms +step:172/1680 train_time:14997ms step_avg:87.19ms +step:173/1680 train_time:15084ms step_avg:87.19ms +step:174/1680 train_time:15172ms step_avg:87.19ms +step:175/1680 train_time:15260ms step_avg:87.20ms +step:176/1680 train_time:15347ms step_avg:87.20ms +step:177/1680 train_time:15435ms step_avg:87.20ms +step:178/1680 train_time:15522ms step_avg:87.20ms +step:179/1680 train_time:15609ms step_avg:87.20ms +step:180/1680 train_time:15696ms step_avg:87.20ms +step:181/1680 train_time:15783ms step_avg:87.20ms +step:182/1680 train_time:15871ms step_avg:87.20ms +step:183/1680 train_time:15958ms step_avg:87.20ms +step:184/1680 train_time:16045ms step_avg:87.20ms +step:185/1680 train_time:16132ms step_avg:87.20ms +step:186/1680 train_time:16219ms step_avg:87.20ms +step:187/1680 train_time:16307ms 
step_avg:87.20ms +step:188/1680 train_time:16394ms step_avg:87.20ms +step:189/1680 train_time:16481ms step_avg:87.20ms +step:190/1680 train_time:16569ms step_avg:87.20ms +step:191/1680 train_time:16656ms step_avg:87.21ms +step:192/1680 train_time:16743ms step_avg:87.20ms +step:193/1680 train_time:16830ms step_avg:87.20ms +step:194/1680 train_time:16917ms step_avg:87.20ms +step:195/1680 train_time:17004ms step_avg:87.20ms +step:196/1680 train_time:17091ms step_avg:87.20ms +step:197/1680 train_time:17179ms step_avg:87.20ms +step:198/1680 train_time:17267ms step_avg:87.20ms +step:199/1680 train_time:17354ms step_avg:87.21ms +step:200/1680 train_time:17441ms step_avg:87.20ms +step:201/1680 train_time:17528ms step_avg:87.20ms +step:202/1680 train_time:17615ms step_avg:87.20ms +step:203/1680 train_time:17702ms step_avg:87.20ms +step:204/1680 train_time:17789ms step_avg:87.20ms +step:205/1680 train_time:17876ms step_avg:87.20ms +step:206/1680 train_time:17963ms step_avg:87.20ms +step:207/1680 train_time:18050ms step_avg:87.20ms +step:208/1680 train_time:18137ms step_avg:87.20ms +step:209/1680 train_time:18225ms step_avg:87.20ms +step:210/1680 train_time:18313ms step_avg:87.20ms +step:211/1680 train_time:18400ms step_avg:87.21ms +step:212/1680 train_time:18488ms step_avg:87.21ms +step:213/1680 train_time:18576ms step_avg:87.21ms +step:214/1680 train_time:18662ms step_avg:87.21ms +step:215/1680 train_time:18749ms step_avg:87.20ms +step:216/1680 train_time:18836ms step_avg:87.20ms +step:217/1680 train_time:18924ms step_avg:87.21ms +step:218/1680 train_time:19011ms step_avg:87.21ms +step:219/1680 train_time:19098ms step_avg:87.21ms +step:220/1680 train_time:19185ms step_avg:87.20ms +step:221/1680 train_time:19272ms step_avg:87.21ms +step:222/1680 train_time:19359ms step_avg:87.20ms +step:223/1680 train_time:19446ms step_avg:87.20ms +step:224/1680 train_time:19534ms step_avg:87.20ms +step:225/1680 train_time:19621ms step_avg:87.20ms +step:226/1680 train_time:19708ms step_avg:87.20ms +step:227/1680 train_time:19794ms step_avg:87.20ms +step:228/1680 train_time:19881ms step_avg:87.20ms +step:229/1680 train_time:19968ms step_avg:87.20ms +step:230/1680 train_time:20056ms step_avg:87.20ms +step:231/1680 train_time:20143ms step_avg:87.20ms +step:232/1680 train_time:20231ms step_avg:87.20ms +step:233/1680 train_time:20318ms step_avg:87.20ms +step:234/1680 train_time:20405ms step_avg:87.20ms +step:235/1680 train_time:20492ms step_avg:87.20ms +step:236/1680 train_time:20579ms step_avg:87.20ms +step:237/1680 train_time:20666ms step_avg:87.20ms +step:238/1680 train_time:20754ms step_avg:87.20ms +step:239/1680 train_time:20841ms step_avg:87.20ms +step:240/1680 train_time:20928ms step_avg:87.20ms +step:241/1680 train_time:21015ms step_avg:87.20ms +step:242/1680 train_time:21102ms step_avg:87.20ms +step:243/1680 train_time:21189ms step_avg:87.20ms +step:244/1680 train_time:21276ms step_avg:87.20ms +step:245/1680 train_time:21363ms step_avg:87.20ms +step:246/1680 train_time:21451ms step_avg:87.20ms +step:247/1680 train_time:21538ms step_avg:87.20ms +step:248/1680 train_time:21625ms step_avg:87.20ms +step:249/1680 train_time:21712ms step_avg:87.20ms +step:250/1680 train_time:21799ms step_avg:87.20ms +step:250/1680 val_loss:3.9788 train_time:21888ms step_avg:87.55ms +step:251/1680 train_time:21908ms step_avg:87.28ms +step:252/1680 train_time:21979ms step_avg:87.22ms +step:253/1680 train_time:22071ms step_avg:87.24ms +step:254/1680 train_time:22160ms step_avg:87.24ms +step:255/1680 train_time:22247ms step_avg:87.24ms 
+step:256/1680 train_time:22334ms step_avg:87.24ms +step:257/1680 train_time:22420ms step_avg:87.24ms +step:258/1680 train_time:22506ms step_avg:87.23ms +step:259/1680 train_time:22592ms step_avg:87.23ms +step:260/1680 train_time:22679ms step_avg:87.23ms +step:261/1680 train_time:22765ms step_avg:87.22ms +step:262/1680 train_time:22852ms step_avg:87.22ms +step:263/1680 train_time:22941ms step_avg:87.23ms +step:264/1680 train_time:23032ms step_avg:87.24ms +step:265/1680 train_time:23121ms step_avg:87.25ms +step:266/1680 train_time:23209ms step_avg:87.25ms +step:267/1680 train_time:23295ms step_avg:87.25ms +step:268/1680 train_time:23382ms step_avg:87.25ms +step:269/1680 train_time:23468ms step_avg:87.24ms +step:270/1680 train_time:23555ms step_avg:87.24ms +step:271/1680 train_time:23641ms step_avg:87.24ms +step:272/1680 train_time:23728ms step_avg:87.24ms +step:273/1680 train_time:23815ms step_avg:87.23ms +step:274/1680 train_time:23902ms step_avg:87.24ms +step:275/1680 train_time:23990ms step_avg:87.24ms +step:276/1680 train_time:24079ms step_avg:87.24ms +step:277/1680 train_time:24168ms step_avg:87.25ms +step:278/1680 train_time:24254ms step_avg:87.25ms +step:279/1680 train_time:24341ms step_avg:87.25ms +step:280/1680 train_time:24428ms step_avg:87.24ms +step:281/1680 train_time:24515ms step_avg:87.24ms +step:282/1680 train_time:24602ms step_avg:87.24ms +step:283/1680 train_time:24688ms step_avg:87.24ms +step:284/1680 train_time:24775ms step_avg:87.24ms +step:285/1680 train_time:24863ms step_avg:87.24ms +step:286/1680 train_time:24950ms step_avg:87.24ms +step:287/1680 train_time:25038ms step_avg:87.24ms +step:288/1680 train_time:25126ms step_avg:87.24ms +step:289/1680 train_time:25214ms step_avg:87.25ms +step:290/1680 train_time:25302ms step_avg:87.25ms +step:291/1680 train_time:25389ms step_avg:87.25ms +step:292/1680 train_time:25476ms step_avg:87.25ms +step:293/1680 train_time:25562ms step_avg:87.24ms +step:294/1680 train_time:25648ms step_avg:87.24ms +step:295/1680 train_time:25735ms step_avg:87.24ms +step:296/1680 train_time:25822ms step_avg:87.24ms +step:297/1680 train_time:25909ms step_avg:87.24ms +step:298/1680 train_time:25996ms step_avg:87.24ms +step:299/1680 train_time:26084ms step_avg:87.24ms +step:300/1680 train_time:26171ms step_avg:87.24ms +step:301/1680 train_time:26259ms step_avg:87.24ms +step:302/1680 train_time:26346ms step_avg:87.24ms +step:303/1680 train_time:26434ms step_avg:87.24ms +step:304/1680 train_time:26521ms step_avg:87.24ms +step:305/1680 train_time:26608ms step_avg:87.24ms +step:306/1680 train_time:26695ms step_avg:87.24ms +step:307/1680 train_time:26782ms step_avg:87.24ms +step:308/1680 train_time:26868ms step_avg:87.23ms +step:309/1680 train_time:26955ms step_avg:87.23ms +step:310/1680 train_time:27043ms step_avg:87.24ms +step:311/1680 train_time:27130ms step_avg:87.24ms +step:312/1680 train_time:27217ms step_avg:87.23ms +step:313/1680 train_time:27305ms step_avg:87.24ms +step:314/1680 train_time:27392ms step_avg:87.24ms +step:315/1680 train_time:27480ms step_avg:87.24ms +step:316/1680 train_time:27567ms step_avg:87.24ms +step:317/1680 train_time:27654ms step_avg:87.24ms +step:318/1680 train_time:27741ms step_avg:87.23ms +step:319/1680 train_time:27828ms step_avg:87.23ms +step:320/1680 train_time:27915ms step_avg:87.23ms +step:321/1680 train_time:28003ms step_avg:87.24ms +step:322/1680 train_time:28090ms step_avg:87.24ms +step:323/1680 train_time:28178ms step_avg:87.24ms +step:324/1680 train_time:28265ms step_avg:87.24ms +step:325/1680 train_time:28352ms 
step_avg:87.24ms +step:326/1680 train_time:28440ms step_avg:87.24ms +step:327/1680 train_time:28527ms step_avg:87.24ms +step:328/1680 train_time:28614ms step_avg:87.24ms +step:329/1680 train_time:28701ms step_avg:87.24ms +step:330/1680 train_time:28789ms step_avg:87.24ms +step:331/1680 train_time:28876ms step_avg:87.24ms +step:332/1680 train_time:28964ms step_avg:87.24ms +step:333/1680 train_time:29051ms step_avg:87.24ms +step:334/1680 train_time:29138ms step_avg:87.24ms +step:335/1680 train_time:29225ms step_avg:87.24ms +step:336/1680 train_time:29312ms step_avg:87.24ms +step:337/1680 train_time:29400ms step_avg:87.24ms +step:338/1680 train_time:29487ms step_avg:87.24ms +step:339/1680 train_time:29574ms step_avg:87.24ms +step:340/1680 train_time:29661ms step_avg:87.24ms +step:341/1680 train_time:29748ms step_avg:87.24ms +step:342/1680 train_time:29835ms step_avg:87.24ms +step:343/1680 train_time:29922ms step_avg:87.24ms +step:344/1680 train_time:30010ms step_avg:87.24ms +step:345/1680 train_time:30097ms step_avg:87.24ms +step:346/1680 train_time:30184ms step_avg:87.24ms +step:347/1680 train_time:30271ms step_avg:87.24ms +step:348/1680 train_time:30359ms step_avg:87.24ms +step:349/1680 train_time:30446ms step_avg:87.24ms +step:350/1680 train_time:30534ms step_avg:87.24ms +step:351/1680 train_time:30621ms step_avg:87.24ms +step:352/1680 train_time:30708ms step_avg:87.24ms +step:353/1680 train_time:30795ms step_avg:87.24ms +step:354/1680 train_time:30882ms step_avg:87.24ms +step:355/1680 train_time:30970ms step_avg:87.24ms +step:356/1680 train_time:31057ms step_avg:87.24ms +step:357/1680 train_time:31144ms step_avg:87.24ms +step:358/1680 train_time:31231ms step_avg:87.24ms +step:359/1680 train_time:31318ms step_avg:87.24ms +step:360/1680 train_time:31405ms step_avg:87.24ms +step:361/1680 train_time:31494ms step_avg:87.24ms +step:362/1680 train_time:31580ms step_avg:87.24ms +step:363/1680 train_time:31667ms step_avg:87.24ms +step:364/1680 train_time:31754ms step_avg:87.24ms +step:365/1680 train_time:31841ms step_avg:87.24ms +step:366/1680 train_time:31929ms step_avg:87.24ms +step:367/1680 train_time:32016ms step_avg:87.24ms +step:368/1680 train_time:32104ms step_avg:87.24ms +step:369/1680 train_time:32191ms step_avg:87.24ms +step:370/1680 train_time:32278ms step_avg:87.24ms +step:371/1680 train_time:32365ms step_avg:87.24ms +step:372/1680 train_time:32452ms step_avg:87.24ms +step:373/1680 train_time:32540ms step_avg:87.24ms +step:374/1680 train_time:32627ms step_avg:87.24ms +step:375/1680 train_time:32714ms step_avg:87.24ms +step:375/1680 val_loss:3.8188 train_time:32802ms step_avg:87.47ms +step:376/1680 train_time:32821ms step_avg:87.29ms +step:377/1680 train_time:32892ms step_avg:87.25ms +step:378/1680 train_time:32982ms step_avg:87.25ms +step:379/1680 train_time:33070ms step_avg:87.25ms +step:380/1680 train_time:33156ms step_avg:87.25ms +step:381/1680 train_time:33244ms step_avg:87.25ms +step:382/1680 train_time:33330ms step_avg:87.25ms +step:383/1680 train_time:33416ms step_avg:87.25ms +step:384/1680 train_time:33503ms step_avg:87.25ms +step:385/1680 train_time:33589ms step_avg:87.24ms +step:386/1680 train_time:33675ms step_avg:87.24ms +step:387/1680 train_time:33763ms step_avg:87.24ms +step:388/1680 train_time:33852ms step_avg:87.25ms +step:389/1680 train_time:33941ms step_avg:87.25ms +step:390/1680 train_time:34029ms step_avg:87.26ms +step:391/1680 train_time:34117ms step_avg:87.26ms +step:392/1680 train_time:34204ms step_avg:87.26ms +step:393/1680 train_time:34291ms step_avg:87.25ms 
+step:394/1680 train_time:34377ms step_avg:87.25ms +step:395/1680 train_time:34464ms step_avg:87.25ms +step:396/1680 train_time:34550ms step_avg:87.25ms +step:397/1680 train_time:34637ms step_avg:87.25ms +step:398/1680 train_time:34724ms step_avg:87.25ms +step:399/1680 train_time:34812ms step_avg:87.25ms +step:400/1680 train_time:34901ms step_avg:87.25ms +step:401/1680 train_time:34988ms step_avg:87.25ms +step:402/1680 train_time:35076ms step_avg:87.25ms +step:403/1680 train_time:35163ms step_avg:87.25ms +step:404/1680 train_time:35250ms step_avg:87.25ms +step:405/1680 train_time:35337ms step_avg:87.25ms +step:406/1680 train_time:35424ms step_avg:87.25ms +step:407/1680 train_time:35511ms step_avg:87.25ms +step:408/1680 train_time:35597ms step_avg:87.25ms +step:409/1680 train_time:35684ms step_avg:87.25ms +step:410/1680 train_time:35771ms step_avg:87.25ms +step:411/1680 train_time:35858ms step_avg:87.25ms +step:412/1680 train_time:35947ms step_avg:87.25ms +step:413/1680 train_time:36034ms step_avg:87.25ms +step:414/1680 train_time:36122ms step_avg:87.25ms +step:415/1680 train_time:36210ms step_avg:87.25ms +step:416/1680 train_time:36297ms step_avg:87.25ms +step:417/1680 train_time:36384ms step_avg:87.25ms +step:418/1680 train_time:36471ms step_avg:87.25ms +step:419/1680 train_time:36558ms step_avg:87.25ms +step:420/1680 train_time:36645ms step_avg:87.25ms +step:421/1680 train_time:36732ms step_avg:87.25ms +step:422/1680 train_time:36819ms step_avg:87.25ms +step:423/1680 train_time:36907ms step_avg:87.25ms +step:424/1680 train_time:36995ms step_avg:87.25ms +step:425/1680 train_time:37082ms step_avg:87.25ms +step:426/1680 train_time:37170ms step_avg:87.25ms +step:427/1680 train_time:37257ms step_avg:87.25ms +step:428/1680 train_time:37344ms step_avg:87.25ms +step:429/1680 train_time:37431ms step_avg:87.25ms +step:430/1680 train_time:37518ms step_avg:87.25ms +step:431/1680 train_time:37605ms step_avg:87.25ms +step:432/1680 train_time:37693ms step_avg:87.25ms +step:433/1680 train_time:37780ms step_avg:87.25ms +step:434/1680 train_time:37867ms step_avg:87.25ms +step:435/1680 train_time:37954ms step_avg:87.25ms +step:436/1680 train_time:38042ms step_avg:87.25ms +step:437/1680 train_time:38130ms step_avg:87.25ms +step:438/1680 train_time:38217ms step_avg:87.25ms +step:439/1680 train_time:38304ms step_avg:87.25ms +step:440/1680 train_time:38392ms step_avg:87.25ms +step:441/1680 train_time:38479ms step_avg:87.25ms +step:442/1680 train_time:38566ms step_avg:87.25ms +step:443/1680 train_time:38653ms step_avg:87.25ms +step:444/1680 train_time:38740ms step_avg:87.25ms +step:445/1680 train_time:38827ms step_avg:87.25ms +step:446/1680 train_time:38914ms step_avg:87.25ms +step:447/1680 train_time:39002ms step_avg:87.25ms +step:448/1680 train_time:39090ms step_avg:87.25ms +step:449/1680 train_time:39177ms step_avg:87.25ms +step:450/1680 train_time:39265ms step_avg:87.26ms +step:451/1680 train_time:39352ms step_avg:87.26ms +step:452/1680 train_time:39440ms step_avg:87.26ms +step:453/1680 train_time:39527ms step_avg:87.26ms +step:454/1680 train_time:39614ms step_avg:87.26ms +step:455/1680 train_time:39701ms step_avg:87.26ms +step:456/1680 train_time:39789ms step_avg:87.26ms +step:457/1680 train_time:39876ms step_avg:87.26ms +step:458/1680 train_time:39963ms step_avg:87.26ms +step:459/1680 train_time:40050ms step_avg:87.26ms +step:460/1680 train_time:40138ms step_avg:87.26ms +step:461/1680 train_time:40225ms step_avg:87.26ms +step:462/1680 train_time:40312ms step_avg:87.26ms +step:463/1680 train_time:40399ms 
step_avg:87.25ms +step:464/1680 train_time:40486ms step_avg:87.26ms +step:465/1680 train_time:40574ms step_avg:87.26ms +step:466/1680 train_time:40661ms step_avg:87.26ms +step:467/1680 train_time:40748ms step_avg:87.26ms +step:468/1680 train_time:40835ms step_avg:87.26ms +step:469/1680 train_time:40923ms step_avg:87.26ms +step:470/1680 train_time:41010ms step_avg:87.26ms +step:471/1680 train_time:41097ms step_avg:87.26ms +step:472/1680 train_time:41186ms step_avg:87.26ms +step:473/1680 train_time:41273ms step_avg:87.26ms +step:474/1680 train_time:41360ms step_avg:87.26ms +step:475/1680 train_time:41448ms step_avg:87.26ms +step:476/1680 train_time:41535ms step_avg:87.26ms +step:477/1680 train_time:41622ms step_avg:87.26ms +step:478/1680 train_time:41710ms step_avg:87.26ms +step:479/1680 train_time:41796ms step_avg:87.26ms +step:480/1680 train_time:41884ms step_avg:87.26ms +step:481/1680 train_time:41971ms step_avg:87.26ms +step:482/1680 train_time:42059ms step_avg:87.26ms +step:483/1680 train_time:42146ms step_avg:87.26ms +step:484/1680 train_time:42234ms step_avg:87.26ms +step:485/1680 train_time:42321ms step_avg:87.26ms +step:486/1680 train_time:42408ms step_avg:87.26ms +step:487/1680 train_time:42495ms step_avg:87.26ms +step:488/1680 train_time:42582ms step_avg:87.26ms +step:489/1680 train_time:42670ms step_avg:87.26ms +step:490/1680 train_time:42757ms step_avg:87.26ms +step:491/1680 train_time:42844ms step_avg:87.26ms +step:492/1680 train_time:42932ms step_avg:87.26ms +step:493/1680 train_time:43019ms step_avg:87.26ms +step:494/1680 train_time:43107ms step_avg:87.26ms +step:495/1680 train_time:43194ms step_avg:87.26ms +step:496/1680 train_time:43281ms step_avg:87.26ms +step:497/1680 train_time:43368ms step_avg:87.26ms +step:498/1680 train_time:43455ms step_avg:87.26ms +step:499/1680 train_time:43542ms step_avg:87.26ms +step:500/1680 train_time:43630ms step_avg:87.26ms +step:500/1680 val_loss:3.7174 train_time:43719ms step_avg:87.44ms +step:501/1680 train_time:43737ms step_avg:87.30ms +step:502/1680 train_time:43809ms step_avg:87.27ms +step:503/1680 train_time:43901ms step_avg:87.28ms +step:504/1680 train_time:43990ms step_avg:87.28ms +step:505/1680 train_time:44077ms step_avg:87.28ms +step:506/1680 train_time:44164ms step_avg:87.28ms +step:507/1680 train_time:44251ms step_avg:87.28ms +step:508/1680 train_time:44337ms step_avg:87.28ms +step:509/1680 train_time:44423ms step_avg:87.28ms +step:510/1680 train_time:44509ms step_avg:87.27ms +step:511/1680 train_time:44596ms step_avg:87.27ms +step:512/1680 train_time:44682ms step_avg:87.27ms +step:513/1680 train_time:44771ms step_avg:87.27ms +step:514/1680 train_time:44861ms step_avg:87.28ms +step:515/1680 train_time:44950ms step_avg:87.28ms +step:516/1680 train_time:45037ms step_avg:87.28ms +step:517/1680 train_time:45124ms step_avg:87.28ms +step:518/1680 train_time:45211ms step_avg:87.28ms +step:519/1680 train_time:45298ms step_avg:87.28ms +step:520/1680 train_time:45385ms step_avg:87.28ms +step:521/1680 train_time:45472ms step_avg:87.28ms +step:522/1680 train_time:45559ms step_avg:87.28ms +step:523/1680 train_time:45645ms step_avg:87.28ms +step:524/1680 train_time:45733ms step_avg:87.28ms +step:525/1680 train_time:45821ms step_avg:87.28ms +step:526/1680 train_time:45909ms step_avg:87.28ms +step:527/1680 train_time:45996ms step_avg:87.28ms +step:528/1680 train_time:46084ms step_avg:87.28ms +step:529/1680 train_time:46171ms step_avg:87.28ms +step:530/1680 train_time:46259ms step_avg:87.28ms +step:531/1680 train_time:46345ms step_avg:87.28ms 
+step:532/1680 train_time:46432ms step_avg:87.28ms +step:533/1680 train_time:46518ms step_avg:87.28ms +step:534/1680 train_time:46606ms step_avg:87.28ms +step:535/1680 train_time:46693ms step_avg:87.28ms +step:536/1680 train_time:46780ms step_avg:87.28ms +step:537/1680 train_time:46868ms step_avg:87.28ms +step:538/1680 train_time:46956ms step_avg:87.28ms +step:539/1680 train_time:47043ms step_avg:87.28ms +step:540/1680 train_time:47131ms step_avg:87.28ms +step:541/1680 train_time:47218ms step_avg:87.28ms +step:542/1680 train_time:47305ms step_avg:87.28ms +step:543/1680 train_time:47392ms step_avg:87.28ms +step:544/1680 train_time:47479ms step_avg:87.28ms +step:545/1680 train_time:47566ms step_avg:87.28ms +step:546/1680 train_time:47654ms step_avg:87.28ms +step:547/1680 train_time:47742ms step_avg:87.28ms +step:548/1680 train_time:47830ms step_avg:87.28ms +step:549/1680 train_time:47918ms step_avg:87.28ms +step:550/1680 train_time:48007ms step_avg:87.29ms +step:551/1680 train_time:48096ms step_avg:87.29ms +step:552/1680 train_time:48184ms step_avg:87.29ms +step:553/1680 train_time:48273ms step_avg:87.29ms +step:554/1680 train_time:48362ms step_avg:87.30ms +step:555/1680 train_time:48449ms step_avg:87.30ms +step:556/1680 train_time:48537ms step_avg:87.30ms +step:557/1680 train_time:48626ms step_avg:87.30ms +step:558/1680 train_time:48714ms step_avg:87.30ms +step:559/1680 train_time:48803ms step_avg:87.30ms +step:560/1680 train_time:48892ms step_avg:87.31ms +step:561/1680 train_time:48981ms step_avg:87.31ms +step:562/1680 train_time:49070ms step_avg:87.31ms +step:563/1680 train_time:49158ms step_avg:87.31ms +step:564/1680 train_time:49247ms step_avg:87.32ms +step:565/1680 train_time:49335ms step_avg:87.32ms +step:566/1680 train_time:49423ms step_avg:87.32ms +step:567/1680 train_time:49511ms step_avg:87.32ms +step:568/1680 train_time:49599ms step_avg:87.32ms +step:569/1680 train_time:49688ms step_avg:87.33ms +step:570/1680 train_time:49776ms step_avg:87.33ms +step:571/1680 train_time:49865ms step_avg:87.33ms +step:572/1680 train_time:49954ms step_avg:87.33ms +step:573/1680 train_time:50043ms step_avg:87.33ms +step:574/1680 train_time:50131ms step_avg:87.34ms +step:575/1680 train_time:50219ms step_avg:87.34ms +step:576/1680 train_time:50308ms step_avg:87.34ms +step:577/1680 train_time:50396ms step_avg:87.34ms +step:578/1680 train_time:50484ms step_avg:87.34ms +step:579/1680 train_time:50572ms step_avg:87.34ms +step:580/1680 train_time:50660ms step_avg:87.35ms +step:581/1680 train_time:50748ms step_avg:87.35ms +step:582/1680 train_time:50836ms step_avg:87.35ms +step:583/1680 train_time:50925ms step_avg:87.35ms +step:584/1680 train_time:51014ms step_avg:87.35ms +step:585/1680 train_time:51103ms step_avg:87.35ms +step:586/1680 train_time:51192ms step_avg:87.36ms +step:587/1680 train_time:51280ms step_avg:87.36ms +step:588/1680 train_time:51369ms step_avg:87.36ms +step:589/1680 train_time:51457ms step_avg:87.36ms +step:590/1680 train_time:51545ms step_avg:87.36ms +step:591/1680 train_time:51634ms step_avg:87.37ms +step:592/1680 train_time:51722ms step_avg:87.37ms +step:593/1680 train_time:51810ms step_avg:87.37ms +step:594/1680 train_time:51898ms step_avg:87.37ms +step:595/1680 train_time:51986ms step_avg:87.37ms +step:596/1680 train_time:52074ms step_avg:87.37ms +step:597/1680 train_time:52163ms step_avg:87.38ms +step:598/1680 train_time:52252ms step_avg:87.38ms +step:599/1680 train_time:52341ms step_avg:87.38ms +step:600/1680 train_time:52430ms step_avg:87.38ms +step:601/1680 train_time:52518ms 
step_avg:87.38ms +step:602/1680 train_time:52607ms step_avg:87.39ms +step:603/1680 train_time:52695ms step_avg:87.39ms +step:604/1680 train_time:52784ms step_avg:87.39ms +step:605/1680 train_time:52873ms step_avg:87.39ms +step:606/1680 train_time:52961ms step_avg:87.39ms +step:607/1680 train_time:53050ms step_avg:87.40ms +step:608/1680 train_time:53138ms step_avg:87.40ms +step:609/1680 train_time:53227ms step_avg:87.40ms +step:610/1680 train_time:53315ms step_avg:87.40ms +step:611/1680 train_time:53403ms step_avg:87.40ms +step:612/1680 train_time:53493ms step_avg:87.41ms +step:613/1680 train_time:53581ms step_avg:87.41ms +step:614/1680 train_time:53669ms step_avg:87.41ms +step:615/1680 train_time:53757ms step_avg:87.41ms +step:616/1680 train_time:53846ms step_avg:87.41ms +step:617/1680 train_time:53935ms step_avg:87.41ms +step:618/1680 train_time:54023ms step_avg:87.42ms +step:619/1680 train_time:54112ms step_avg:87.42ms +step:620/1680 train_time:54200ms step_avg:87.42ms +step:621/1680 train_time:54288ms step_avg:87.42ms +step:622/1680 train_time:54377ms step_avg:87.42ms +step:623/1680 train_time:54465ms step_avg:87.42ms +step:624/1680 train_time:54554ms step_avg:87.43ms +step:625/1680 train_time:54642ms step_avg:87.43ms +step:625/1680 val_loss:3.6182 train_time:54732ms step_avg:87.57ms +step:626/1680 train_time:54759ms step_avg:87.47ms +step:627/1680 train_time:54820ms step_avg:87.43ms +step:628/1680 train_time:54909ms step_avg:87.44ms +step:629/1680 train_time:55000ms step_avg:87.44ms +step:630/1680 train_time:55089ms step_avg:87.44ms +step:631/1680 train_time:55176ms step_avg:87.44ms +step:632/1680 train_time:55263ms step_avg:87.44ms +step:633/1680 train_time:55350ms step_avg:87.44ms +step:634/1680 train_time:55437ms step_avg:87.44ms +step:635/1680 train_time:55524ms step_avg:87.44ms +step:636/1680 train_time:55612ms step_avg:87.44ms +step:637/1680 train_time:55706ms step_avg:87.45ms +step:638/1680 train_time:55796ms step_avg:87.45ms +step:639/1680 train_time:55885ms step_avg:87.46ms +step:640/1680 train_time:55973ms step_avg:87.46ms +step:641/1680 train_time:56061ms step_avg:87.46ms +step:642/1680 train_time:56149ms step_avg:87.46ms +step:643/1680 train_time:56237ms step_avg:87.46ms +step:644/1680 train_time:56325ms step_avg:87.46ms +step:645/1680 train_time:56412ms step_avg:87.46ms +step:646/1680 train_time:56500ms step_avg:87.46ms +step:647/1680 train_time:56589ms step_avg:87.46ms +step:648/1680 train_time:56679ms step_avg:87.47ms +step:649/1680 train_time:56767ms step_avg:87.47ms +step:650/1680 train_time:56856ms step_avg:87.47ms +step:651/1680 train_time:56945ms step_avg:87.47ms +step:652/1680 train_time:57033ms step_avg:87.47ms +step:653/1680 train_time:57121ms step_avg:87.47ms +step:654/1680 train_time:57209ms step_avg:87.48ms +step:655/1680 train_time:57297ms step_avg:87.48ms +step:656/1680 train_time:57384ms step_avg:87.48ms +step:657/1680 train_time:57472ms step_avg:87.48ms +step:658/1680 train_time:57561ms step_avg:87.48ms +step:659/1680 train_time:57649ms step_avg:87.48ms +step:660/1680 train_time:57738ms step_avg:87.48ms +step:661/1680 train_time:57826ms step_avg:87.48ms +step:662/1680 train_time:57915ms step_avg:87.48ms +step:663/1680 train_time:58004ms step_avg:87.49ms +step:664/1680 train_time:58092ms step_avg:87.49ms +step:665/1680 train_time:58180ms step_avg:87.49ms +step:666/1680 train_time:58268ms step_avg:87.49ms +step:667/1680 train_time:58355ms step_avg:87.49ms +step:668/1680 train_time:58444ms step_avg:87.49ms +step:669/1680 train_time:58532ms step_avg:87.49ms 
+step:670/1680 train_time:58620ms step_avg:87.49ms +step:671/1680 train_time:58709ms step_avg:87.49ms +step:672/1680 train_time:58797ms step_avg:87.50ms +step:673/1680 train_time:58887ms step_avg:87.50ms +step:674/1680 train_time:58976ms step_avg:87.50ms +step:675/1680 train_time:59064ms step_avg:87.50ms +step:676/1680 train_time:59152ms step_avg:87.50ms +step:677/1680 train_time:59241ms step_avg:87.50ms +step:678/1680 train_time:59329ms step_avg:87.51ms +step:679/1680 train_time:59417ms step_avg:87.51ms +step:680/1680 train_time:59506ms step_avg:87.51ms +step:681/1680 train_time:59594ms step_avg:87.51ms +step:682/1680 train_time:59682ms step_avg:87.51ms +step:683/1680 train_time:59770ms step_avg:87.51ms +step:684/1680 train_time:59860ms step_avg:87.51ms +step:685/1680 train_time:59949ms step_avg:87.52ms +step:686/1680 train_time:60038ms step_avg:87.52ms +step:687/1680 train_time:60126ms step_avg:87.52ms +step:688/1680 train_time:60213ms step_avg:87.52ms +step:689/1680 train_time:60302ms step_avg:87.52ms +step:690/1680 train_time:60390ms step_avg:87.52ms +step:691/1680 train_time:60478ms step_avg:87.52ms +step:692/1680 train_time:60567ms step_avg:87.52ms +step:693/1680 train_time:60656ms step_avg:87.53ms +step:694/1680 train_time:60743ms step_avg:87.53ms +step:695/1680 train_time:60832ms step_avg:87.53ms +step:696/1680 train_time:60920ms step_avg:87.53ms +step:697/1680 train_time:61009ms step_avg:87.53ms +step:698/1680 train_time:61097ms step_avg:87.53ms +step:699/1680 train_time:61185ms step_avg:87.53ms +step:700/1680 train_time:61274ms step_avg:87.53ms +step:701/1680 train_time:61362ms step_avg:87.53ms +step:702/1680 train_time:61450ms step_avg:87.54ms +step:703/1680 train_time:61538ms step_avg:87.54ms +step:704/1680 train_time:61626ms step_avg:87.54ms +step:705/1680 train_time:61714ms step_avg:87.54ms +step:706/1680 train_time:61803ms step_avg:87.54ms +step:707/1680 train_time:61891ms step_avg:87.54ms +step:708/1680 train_time:61980ms step_avg:87.54ms +step:709/1680 train_time:62068ms step_avg:87.54ms +step:710/1680 train_time:62157ms step_avg:87.55ms +step:711/1680 train_time:62246ms step_avg:87.55ms +step:712/1680 train_time:62334ms step_avg:87.55ms +step:713/1680 train_time:62422ms step_avg:87.55ms +step:714/1680 train_time:62510ms step_avg:87.55ms +step:715/1680 train_time:62599ms step_avg:87.55ms +step:716/1680 train_time:62687ms step_avg:87.55ms +step:717/1680 train_time:62776ms step_avg:87.55ms +step:718/1680 train_time:62865ms step_avg:87.56ms +step:719/1680 train_time:62952ms step_avg:87.56ms +step:720/1680 train_time:63041ms step_avg:87.56ms +step:721/1680 train_time:63130ms step_avg:87.56ms +step:722/1680 train_time:63218ms step_avg:87.56ms +step:723/1680 train_time:63306ms step_avg:87.56ms +step:724/1680 train_time:63394ms step_avg:87.56ms +step:725/1680 train_time:63482ms step_avg:87.56ms +step:726/1680 train_time:63570ms step_avg:87.56ms +step:727/1680 train_time:63659ms step_avg:87.56ms +step:728/1680 train_time:63748ms step_avg:87.57ms +step:729/1680 train_time:63837ms step_avg:87.57ms +step:730/1680 train_time:63925ms step_avg:87.57ms +step:731/1680 train_time:64013ms step_avg:87.57ms +step:732/1680 train_time:64102ms step_avg:87.57ms +step:733/1680 train_time:64190ms step_avg:87.57ms +step:734/1680 train_time:64278ms step_avg:87.57ms +step:735/1680 train_time:64366ms step_avg:87.57ms +step:736/1680 train_time:64454ms step_avg:87.57ms +step:737/1680 train_time:64542ms step_avg:87.57ms +step:738/1680 train_time:64630ms step_avg:87.57ms +step:739/1680 train_time:64718ms 
step_avg:87.58ms +step:740/1680 train_time:64807ms step_avg:87.58ms +step:741/1680 train_time:64895ms step_avg:87.58ms +step:742/1680 train_time:64984ms step_avg:87.58ms +step:743/1680 train_time:65072ms step_avg:87.58ms +step:744/1680 train_time:65161ms step_avg:87.58ms +step:745/1680 train_time:65249ms step_avg:87.58ms +step:746/1680 train_time:65338ms step_avg:87.58ms +step:747/1680 train_time:65427ms step_avg:87.59ms +step:748/1680 train_time:65515ms step_avg:87.59ms +step:749/1680 train_time:65604ms step_avg:87.59ms +step:750/1680 train_time:65692ms step_avg:87.59ms +step:750/1680 val_loss:3.5644 train_time:65782ms step_avg:87.71ms +step:751/1680 train_time:65805ms step_avg:87.62ms +step:752/1680 train_time:65872ms step_avg:87.60ms +step:753/1680 train_time:65964ms step_avg:87.60ms +step:754/1680 train_time:66054ms step_avg:87.61ms +step:755/1680 train_time:66143ms step_avg:87.61ms +step:756/1680 train_time:66230ms step_avg:87.61ms +step:757/1680 train_time:66318ms step_avg:87.61ms +step:758/1680 train_time:66405ms step_avg:87.61ms +step:759/1680 train_time:66492ms step_avg:87.60ms +step:760/1680 train_time:66580ms step_avg:87.60ms +step:761/1680 train_time:66667ms step_avg:87.60ms +step:762/1680 train_time:66756ms step_avg:87.61ms +step:763/1680 train_time:66846ms step_avg:87.61ms +step:764/1680 train_time:66937ms step_avg:87.61ms +step:765/1680 train_time:67028ms step_avg:87.62ms +step:766/1680 train_time:67117ms step_avg:87.62ms +step:767/1680 train_time:67206ms step_avg:87.62ms +step:768/1680 train_time:67294ms step_avg:87.62ms +step:769/1680 train_time:67381ms step_avg:87.62ms +step:770/1680 train_time:67469ms step_avg:87.62ms +step:771/1680 train_time:67556ms step_avg:87.62ms +step:772/1680 train_time:67643ms step_avg:87.62ms +step:773/1680 train_time:67731ms step_avg:87.62ms +step:774/1680 train_time:67821ms step_avg:87.62ms +step:775/1680 train_time:67910ms step_avg:87.63ms +step:776/1680 train_time:68000ms step_avg:87.63ms +step:777/1680 train_time:68088ms step_avg:87.63ms +step:778/1680 train_time:68177ms step_avg:87.63ms +step:779/1680 train_time:68264ms step_avg:87.63ms +step:780/1680 train_time:68352ms step_avg:87.63ms +step:781/1680 train_time:68440ms step_avg:87.63ms +step:782/1680 train_time:68527ms step_avg:87.63ms +step:783/1680 train_time:68616ms step_avg:87.63ms +step:784/1680 train_time:68704ms step_avg:87.63ms +step:785/1680 train_time:68792ms step_avg:87.63ms +step:786/1680 train_time:68881ms step_avg:87.64ms +step:787/1680 train_time:68970ms step_avg:87.64ms +step:788/1680 train_time:69059ms step_avg:87.64ms +step:789/1680 train_time:69147ms step_avg:87.64ms +step:790/1680 train_time:69236ms step_avg:87.64ms +step:791/1680 train_time:69324ms step_avg:87.64ms +step:792/1680 train_time:69412ms step_avg:87.64ms +step:793/1680 train_time:69501ms step_avg:87.64ms +step:794/1680 train_time:69588ms step_avg:87.64ms +step:795/1680 train_time:69676ms step_avg:87.64ms +step:796/1680 train_time:69765ms step_avg:87.64ms +step:797/1680 train_time:69854ms step_avg:87.65ms +step:798/1680 train_time:69943ms step_avg:87.65ms +step:799/1680 train_time:70031ms step_avg:87.65ms +step:800/1680 train_time:70120ms step_avg:87.65ms +step:801/1680 train_time:70208ms step_avg:87.65ms +step:802/1680 train_time:70297ms step_avg:87.65ms +step:803/1680 train_time:70385ms step_avg:87.65ms +step:804/1680 train_time:70473ms step_avg:87.65ms +step:805/1680 train_time:70562ms step_avg:87.65ms +step:806/1680 train_time:70650ms step_avg:87.65ms +step:807/1680 train_time:70739ms step_avg:87.66ms 
+step:808/1680 train_time:70827ms step_avg:87.66ms +step:809/1680 train_time:70915ms step_avg:87.66ms +step:810/1680 train_time:71004ms step_avg:87.66ms +step:811/1680 train_time:71093ms step_avg:87.66ms +step:812/1680 train_time:71181ms step_avg:87.66ms +step:813/1680 train_time:71270ms step_avg:87.66ms +step:814/1680 train_time:71358ms step_avg:87.66ms +step:815/1680 train_time:71446ms step_avg:87.66ms +step:816/1680 train_time:71535ms step_avg:87.66ms +step:817/1680 train_time:71623ms step_avg:87.67ms +step:818/1680 train_time:71711ms step_avg:87.67ms +step:819/1680 train_time:71800ms step_avg:87.67ms +step:820/1680 train_time:71888ms step_avg:87.67ms +step:821/1680 train_time:71977ms step_avg:87.67ms +step:822/1680 train_time:72066ms step_avg:87.67ms +step:823/1680 train_time:72155ms step_avg:87.67ms +step:824/1680 train_time:72244ms step_avg:87.67ms +step:825/1680 train_time:72332ms step_avg:87.68ms +step:826/1680 train_time:72420ms step_avg:87.68ms +step:827/1680 train_time:72508ms step_avg:87.68ms +step:828/1680 train_time:72596ms step_avg:87.68ms +step:829/1680 train_time:72684ms step_avg:87.68ms +step:830/1680 train_time:72772ms step_avg:87.68ms +step:831/1680 train_time:72860ms step_avg:87.68ms +step:832/1680 train_time:72948ms step_avg:87.68ms +step:833/1680 train_time:73037ms step_avg:87.68ms +step:834/1680 train_time:73126ms step_avg:87.68ms +step:835/1680 train_time:73214ms step_avg:87.68ms +step:836/1680 train_time:73304ms step_avg:87.68ms +step:837/1680 train_time:73392ms step_avg:87.68ms +step:838/1680 train_time:73480ms step_avg:87.69ms +step:839/1680 train_time:73568ms step_avg:87.69ms +step:840/1680 train_time:73656ms step_avg:87.69ms +step:841/1680 train_time:73745ms step_avg:87.69ms +step:842/1680 train_time:73833ms step_avg:87.69ms +step:843/1680 train_time:73921ms step_avg:87.69ms +step:844/1680 train_time:74009ms step_avg:87.69ms +step:845/1680 train_time:74098ms step_avg:87.69ms +step:846/1680 train_time:74186ms step_avg:87.69ms +step:847/1680 train_time:74274ms step_avg:87.69ms +step:848/1680 train_time:74364ms step_avg:87.69ms +step:849/1680 train_time:74452ms step_avg:87.69ms +step:850/1680 train_time:74541ms step_avg:87.70ms +step:851/1680 train_time:74628ms step_avg:87.69ms +step:852/1680 train_time:74716ms step_avg:87.70ms +step:853/1680 train_time:74805ms step_avg:87.70ms +step:854/1680 train_time:74894ms step_avg:87.70ms +step:855/1680 train_time:74982ms step_avg:87.70ms +step:856/1680 train_time:75070ms step_avg:87.70ms +step:857/1680 train_time:75158ms step_avg:87.70ms +step:858/1680 train_time:75246ms step_avg:87.70ms +step:859/1680 train_time:75335ms step_avg:87.70ms +step:860/1680 train_time:75423ms step_avg:87.70ms +step:861/1680 train_time:75511ms step_avg:87.70ms +step:862/1680 train_time:75600ms step_avg:87.70ms +step:863/1680 train_time:75688ms step_avg:87.70ms +step:864/1680 train_time:75776ms step_avg:87.70ms +step:865/1680 train_time:75864ms step_avg:87.70ms +step:866/1680 train_time:75952ms step_avg:87.70ms +step:867/1680 train_time:76041ms step_avg:87.71ms +step:868/1680 train_time:76129ms step_avg:87.71ms +step:869/1680 train_time:76217ms step_avg:87.71ms +step:870/1680 train_time:76306ms step_avg:87.71ms +step:871/1680 train_time:76395ms step_avg:87.71ms +step:872/1680 train_time:76483ms step_avg:87.71ms +step:873/1680 train_time:76572ms step_avg:87.71ms +step:874/1680 train_time:76660ms step_avg:87.71ms +step:875/1680 train_time:76748ms step_avg:87.71ms +step:875/1680 val_loss:3.5182 train_time:76838ms step_avg:87.81ms +step:876/1680 
train_time:76856ms step_avg:87.74ms +step:877/1680 train_time:76931ms step_avg:87.72ms +step:878/1680 train_time:77022ms step_avg:87.72ms +step:879/1680 train_time:77113ms step_avg:87.73ms +step:880/1680 train_time:77201ms step_avg:87.73ms +step:881/1680 train_time:77289ms step_avg:87.73ms +step:882/1680 train_time:77375ms step_avg:87.73ms +step:883/1680 train_time:77463ms step_avg:87.73ms +step:884/1680 train_time:77551ms step_avg:87.73ms +step:885/1680 train_time:77639ms step_avg:87.73ms +step:886/1680 train_time:77727ms step_avg:87.73ms +step:887/1680 train_time:77817ms step_avg:87.73ms +step:888/1680 train_time:77906ms step_avg:87.73ms +step:889/1680 train_time:77997ms step_avg:87.74ms +step:890/1680 train_time:78087ms step_avg:87.74ms +step:891/1680 train_time:78177ms step_avg:87.74ms +step:892/1680 train_time:78264ms step_avg:87.74ms +step:893/1680 train_time:78352ms step_avg:87.74ms +step:894/1680 train_time:78439ms step_avg:87.74ms +step:895/1680 train_time:78527ms step_avg:87.74ms +step:896/1680 train_time:78614ms step_avg:87.74ms +step:897/1680 train_time:78702ms step_avg:87.74ms +step:898/1680 train_time:78791ms step_avg:87.74ms +step:899/1680 train_time:78879ms step_avg:87.74ms +step:900/1680 train_time:78970ms step_avg:87.74ms +step:901/1680 train_time:79060ms step_avg:87.75ms +step:902/1680 train_time:79149ms step_avg:87.75ms +step:903/1680 train_time:79238ms step_avg:87.75ms +step:904/1680 train_time:79326ms step_avg:87.75ms +step:905/1680 train_time:79414ms step_avg:87.75ms +step:906/1680 train_time:79501ms step_avg:87.75ms +step:907/1680 train_time:79590ms step_avg:87.75ms +step:908/1680 train_time:79678ms step_avg:87.75ms +step:909/1680 train_time:79766ms step_avg:87.75ms +step:910/1680 train_time:79854ms step_avg:87.75ms +step:911/1680 train_time:79943ms step_avg:87.75ms +step:912/1680 train_time:80032ms step_avg:87.75ms +step:913/1680 train_time:80121ms step_avg:87.76ms +step:914/1680 train_time:80209ms step_avg:87.76ms +step:915/1680 train_time:80299ms step_avg:87.76ms +step:916/1680 train_time:80387ms step_avg:87.76ms +step:917/1680 train_time:80475ms step_avg:87.76ms +step:918/1680 train_time:80563ms step_avg:87.76ms +step:919/1680 train_time:80652ms step_avg:87.76ms +step:920/1680 train_time:80740ms step_avg:87.76ms +step:921/1680 train_time:80828ms step_avg:87.76ms +step:922/1680 train_time:80917ms step_avg:87.76ms +step:923/1680 train_time:81005ms step_avg:87.76ms +step:924/1680 train_time:81094ms step_avg:87.76ms +step:925/1680 train_time:81182ms step_avg:87.76ms +step:926/1680 train_time:81272ms step_avg:87.77ms +step:927/1680 train_time:81361ms step_avg:87.77ms +step:928/1680 train_time:81449ms step_avg:87.77ms +step:929/1680 train_time:81537ms step_avg:87.77ms +step:930/1680 train_time:81626ms step_avg:87.77ms +step:931/1680 train_time:81715ms step_avg:87.77ms +step:932/1680 train_time:81803ms step_avg:87.77ms +step:933/1680 train_time:81893ms step_avg:87.77ms +step:934/1680 train_time:81982ms step_avg:87.77ms +step:935/1680 train_time:82070ms step_avg:87.78ms +step:936/1680 train_time:82160ms step_avg:87.78ms +step:937/1680 train_time:82248ms step_avg:87.78ms +step:938/1680 train_time:82336ms step_avg:87.78ms +step:939/1680 train_time:82425ms step_avg:87.78ms +step:940/1680 train_time:82513ms step_avg:87.78ms +step:941/1680 train_time:82601ms step_avg:87.78ms +step:942/1680 train_time:82689ms step_avg:87.78ms +step:943/1680 train_time:82778ms step_avg:87.78ms +step:944/1680 train_time:82866ms step_avg:87.78ms +step:945/1680 train_time:82956ms step_avg:87.78ms 
+step:946/1680 train_time:83044ms step_avg:87.78ms +step:947/1680 train_time:83133ms step_avg:87.79ms +step:948/1680 train_time:83221ms step_avg:87.79ms +step:949/1680 train_time:83311ms step_avg:87.79ms +step:950/1680 train_time:83399ms step_avg:87.79ms +step:951/1680 train_time:83488ms step_avg:87.79ms +step:952/1680 train_time:83577ms step_avg:87.79ms +step:953/1680 train_time:83665ms step_avg:87.79ms +step:954/1680 train_time:83754ms step_avg:87.79ms +step:955/1680 train_time:83842ms step_avg:87.79ms +step:956/1680 train_time:83931ms step_avg:87.79ms +step:957/1680 train_time:84020ms step_avg:87.79ms +step:958/1680 train_time:84108ms step_avg:87.80ms +step:959/1680 train_time:84197ms step_avg:87.80ms +step:960/1680 train_time:84286ms step_avg:87.80ms +step:961/1680 train_time:84374ms step_avg:87.80ms +step:962/1680 train_time:84462ms step_avg:87.80ms +step:963/1680 train_time:84551ms step_avg:87.80ms +step:964/1680 train_time:84639ms step_avg:87.80ms +step:965/1680 train_time:84728ms step_avg:87.80ms +step:966/1680 train_time:84816ms step_avg:87.80ms +step:967/1680 train_time:84905ms step_avg:87.80ms +step:968/1680 train_time:84994ms step_avg:87.80ms +step:969/1680 train_time:85083ms step_avg:87.80ms +step:970/1680 train_time:85172ms step_avg:87.81ms +step:971/1680 train_time:85261ms step_avg:87.81ms +step:972/1680 train_time:85351ms step_avg:87.81ms +step:973/1680 train_time:85439ms step_avg:87.81ms +step:974/1680 train_time:85528ms step_avg:87.81ms +step:975/1680 train_time:85616ms step_avg:87.81ms +step:976/1680 train_time:85704ms step_avg:87.81ms +step:977/1680 train_time:85792ms step_avg:87.81ms +step:978/1680 train_time:85880ms step_avg:87.81ms +step:979/1680 train_time:85969ms step_avg:87.81ms +step:980/1680 train_time:86058ms step_avg:87.81ms +step:981/1680 train_time:86147ms step_avg:87.82ms +step:982/1680 train_time:86236ms step_avg:87.82ms +step:983/1680 train_time:86324ms step_avg:87.82ms +step:984/1680 train_time:86413ms step_avg:87.82ms +step:985/1680 train_time:86501ms step_avg:87.82ms +step:986/1680 train_time:86591ms step_avg:87.82ms +step:987/1680 train_time:86679ms step_avg:87.82ms +step:988/1680 train_time:86768ms step_avg:87.82ms +step:989/1680 train_time:86858ms step_avg:87.82ms +step:990/1680 train_time:86946ms step_avg:87.82ms +step:991/1680 train_time:87034ms step_avg:87.82ms +step:992/1680 train_time:87123ms step_avg:87.83ms +step:993/1680 train_time:87212ms step_avg:87.83ms +step:994/1680 train_time:87301ms step_avg:87.83ms +step:995/1680 train_time:87389ms step_avg:87.83ms +step:996/1680 train_time:87477ms step_avg:87.83ms +step:997/1680 train_time:87566ms step_avg:87.83ms +step:998/1680 train_time:87653ms step_avg:87.83ms +step:999/1680 train_time:87742ms step_avg:87.83ms +step:1000/1680 train_time:87830ms step_avg:87.83ms +step:1000/1680 val_loss:3.4685 train_time:87920ms step_avg:87.92ms +step:1001/1680 train_time:87942ms step_avg:87.85ms +step:1002/1680 train_time:88013ms step_avg:87.84ms +step:1003/1680 train_time:88104ms step_avg:87.84ms +step:1004/1680 train_time:88192ms step_avg:87.84ms +step:1005/1680 train_time:88279ms step_avg:87.84ms +step:1006/1680 train_time:88367ms step_avg:87.84ms +step:1007/1680 train_time:88454ms step_avg:87.84ms +step:1008/1680 train_time:88542ms step_avg:87.84ms +step:1009/1680 train_time:88630ms step_avg:87.84ms +step:1010/1680 train_time:88717ms step_avg:87.84ms +step:1011/1680 train_time:88805ms step_avg:87.84ms +step:1012/1680 train_time:88894ms step_avg:87.84ms +step:1013/1680 train_time:88985ms step_avg:87.84ms 
+step:1014/1680 train_time:89074ms step_avg:87.84ms +step:1015/1680 train_time:89164ms step_avg:87.85ms +step:1016/1680 train_time:89252ms step_avg:87.85ms +step:1017/1680 train_time:89340ms step_avg:87.85ms +step:1018/1680 train_time:89428ms step_avg:87.85ms +step:1019/1680 train_time:89516ms step_avg:87.85ms +step:1020/1680 train_time:89604ms step_avg:87.85ms +step:1021/1680 train_time:89692ms step_avg:87.85ms +step:1022/1680 train_time:89781ms step_avg:87.85ms +step:1023/1680 train_time:89870ms step_avg:87.85ms +step:1024/1680 train_time:89959ms step_avg:87.85ms +step:1025/1680 train_time:90050ms step_avg:87.85ms +step:1026/1680 train_time:90139ms step_avg:87.85ms +step:1027/1680 train_time:90228ms step_avg:87.86ms +step:1028/1680 train_time:90317ms step_avg:87.86ms +step:1029/1680 train_time:90405ms step_avg:87.86ms +step:1030/1680 train_time:90493ms step_avg:87.86ms +step:1031/1680 train_time:90580ms step_avg:87.86ms +step:1032/1680 train_time:90668ms step_avg:87.86ms +step:1033/1680 train_time:90756ms step_avg:87.86ms +step:1034/1680 train_time:90845ms step_avg:87.86ms +step:1035/1680 train_time:90933ms step_avg:87.86ms +step:1036/1680 train_time:91023ms step_avg:87.86ms +step:1037/1680 train_time:91112ms step_avg:87.86ms +step:1038/1680 train_time:91200ms step_avg:87.86ms +step:1039/1680 train_time:91289ms step_avg:87.86ms +step:1040/1680 train_time:91377ms step_avg:87.86ms +step:1041/1680 train_time:91465ms step_avg:87.86ms +step:1042/1680 train_time:91554ms step_avg:87.86ms +step:1043/1680 train_time:91642ms step_avg:87.86ms +step:1044/1680 train_time:91732ms step_avg:87.87ms +step:1045/1680 train_time:91820ms step_avg:87.87ms +step:1046/1680 train_time:91909ms step_avg:87.87ms +step:1047/1680 train_time:91997ms step_avg:87.87ms +step:1048/1680 train_time:92087ms step_avg:87.87ms +step:1049/1680 train_time:92175ms step_avg:87.87ms +step:1050/1680 train_time:92264ms step_avg:87.87ms +step:1051/1680 train_time:92352ms step_avg:87.87ms +step:1052/1680 train_time:92442ms step_avg:87.87ms +step:1053/1680 train_time:92529ms step_avg:87.87ms +step:1054/1680 train_time:92617ms step_avg:87.87ms +step:1055/1680 train_time:92706ms step_avg:87.87ms +step:1056/1680 train_time:92794ms step_avg:87.87ms +step:1057/1680 train_time:92882ms step_avg:87.87ms +step:1058/1680 train_time:92971ms step_avg:87.87ms +step:1059/1680 train_time:93061ms step_avg:87.88ms +step:1060/1680 train_time:93150ms step_avg:87.88ms +step:1061/1680 train_time:93239ms step_avg:87.88ms +step:1062/1680 train_time:93328ms step_avg:87.88ms +step:1063/1680 train_time:93416ms step_avg:87.88ms +step:1064/1680 train_time:93504ms step_avg:87.88ms +step:1065/1680 train_time:93593ms step_avg:87.88ms +step:1066/1680 train_time:93682ms step_avg:87.88ms +step:1067/1680 train_time:93770ms step_avg:87.88ms +step:1068/1680 train_time:93858ms step_avg:87.88ms +step:1069/1680 train_time:93947ms step_avg:87.88ms +step:1070/1680 train_time:94036ms step_avg:87.88ms +step:1071/1680 train_time:94124ms step_avg:87.88ms +step:1072/1680 train_time:94214ms step_avg:87.89ms +step:1073/1680 train_time:94302ms step_avg:87.89ms +step:1074/1680 train_time:94391ms step_avg:87.89ms +step:1075/1680 train_time:94480ms step_avg:87.89ms +step:1076/1680 train_time:94568ms step_avg:87.89ms +step:1077/1680 train_time:94657ms step_avg:87.89ms +step:1078/1680 train_time:94744ms step_avg:87.89ms +step:1079/1680 train_time:94833ms step_avg:87.89ms +step:1080/1680 train_time:94922ms step_avg:87.89ms +step:1081/1680 train_time:95011ms step_avg:87.89ms +step:1082/1680 
train_time:95100ms step_avg:87.89ms +step:1083/1680 train_time:95188ms step_avg:87.89ms +step:1084/1680 train_time:95277ms step_avg:87.89ms +step:1085/1680 train_time:95366ms step_avg:87.89ms +step:1086/1680 train_time:95454ms step_avg:87.89ms +step:1087/1680 train_time:95543ms step_avg:87.90ms +step:1088/1680 train_time:95631ms step_avg:87.90ms +step:1089/1680 train_time:95720ms step_avg:87.90ms +step:1090/1680 train_time:95808ms step_avg:87.90ms +step:1091/1680 train_time:95896ms step_avg:87.90ms +step:1092/1680 train_time:95985ms step_avg:87.90ms +step:1093/1680 train_time:96073ms step_avg:87.90ms +step:1094/1680 train_time:96163ms step_avg:87.90ms +step:1095/1680 train_time:96251ms step_avg:87.90ms +step:1096/1680 train_time:96341ms step_avg:87.90ms +step:1097/1680 train_time:96430ms step_avg:87.90ms +step:1098/1680 train_time:96520ms step_avg:87.91ms +step:1099/1680 train_time:96609ms step_avg:87.91ms +step:1100/1680 train_time:96697ms step_avg:87.91ms +step:1101/1680 train_time:96786ms step_avg:87.91ms +step:1102/1680 train_time:96875ms step_avg:87.91ms +step:1103/1680 train_time:96964ms step_avg:87.91ms +step:1104/1680 train_time:97053ms step_avg:87.91ms +step:1105/1680 train_time:97142ms step_avg:87.91ms +step:1106/1680 train_time:97232ms step_avg:87.91ms +step:1107/1680 train_time:97322ms step_avg:87.92ms +step:1108/1680 train_time:97413ms step_avg:87.92ms +step:1109/1680 train_time:97503ms step_avg:87.92ms +step:1110/1680 train_time:97592ms step_avg:87.92ms +step:1111/1680 train_time:97682ms step_avg:87.92ms +step:1112/1680 train_time:97770ms step_avg:87.92ms +step:1113/1680 train_time:97859ms step_avg:87.92ms +step:1114/1680 train_time:97948ms step_avg:87.92ms +step:1115/1680 train_time:98037ms step_avg:87.93ms +step:1116/1680 train_time:98126ms step_avg:87.93ms +step:1117/1680 train_time:98215ms step_avg:87.93ms +step:1118/1680 train_time:98304ms step_avg:87.93ms +step:1119/1680 train_time:98393ms step_avg:87.93ms +step:1120/1680 train_time:98482ms step_avg:87.93ms +step:1121/1680 train_time:98571ms step_avg:87.93ms +step:1122/1680 train_time:98661ms step_avg:87.93ms +step:1123/1680 train_time:98750ms step_avg:87.93ms +step:1124/1680 train_time:98839ms step_avg:87.93ms +step:1125/1680 train_time:98929ms step_avg:87.94ms +step:1125/1680 val_loss:3.4150 train_time:99019ms step_avg:88.02ms +step:1126/1680 train_time:99041ms step_avg:87.96ms +step:1127/1680 train_time:99111ms step_avg:87.94ms +step:1128/1680 train_time:99203ms step_avg:87.95ms +step:1129/1680 train_time:99295ms step_avg:87.95ms +step:1130/1680 train_time:99384ms step_avg:87.95ms +step:1131/1680 train_time:99472ms step_avg:87.95ms +step:1132/1680 train_time:99560ms step_avg:87.95ms +step:1133/1680 train_time:99648ms step_avg:87.95ms +step:1134/1680 train_time:99736ms step_avg:87.95ms +step:1135/1680 train_time:99825ms step_avg:87.95ms +step:1136/1680 train_time:99914ms step_avg:87.95ms +step:1137/1680 train_time:100005ms step_avg:87.95ms +step:1138/1680 train_time:100096ms step_avg:87.96ms +step:1139/1680 train_time:100187ms step_avg:87.96ms +step:1140/1680 train_time:100279ms step_avg:87.96ms +step:1141/1680 train_time:100368ms step_avg:87.97ms +step:1142/1680 train_time:100457ms step_avg:87.97ms +step:1143/1680 train_time:100545ms step_avg:87.97ms +step:1144/1680 train_time:100633ms step_avg:87.97ms +step:1145/1680 train_time:100722ms step_avg:87.97ms +step:1146/1680 train_time:100810ms step_avg:87.97ms +step:1147/1680 train_time:100898ms step_avg:87.97ms +step:1148/1680 train_time:100988ms step_avg:87.97ms 
+step:1149/1680 train_time:101078ms step_avg:87.97ms +step:1150/1680 train_time:101168ms step_avg:87.97ms +step:1151/1680 train_time:101258ms step_avg:87.97ms +step:1152/1680 train_time:101347ms step_avg:87.98ms +step:1153/1680 train_time:101437ms step_avg:87.98ms +step:1154/1680 train_time:101526ms step_avg:87.98ms +step:1155/1680 train_time:101615ms step_avg:87.98ms +step:1156/1680 train_time:101703ms step_avg:87.98ms +step:1157/1680 train_time:101792ms step_avg:87.98ms +step:1158/1680 train_time:101881ms step_avg:87.98ms +step:1159/1680 train_time:101969ms step_avg:87.98ms +step:1160/1680 train_time:102059ms step_avg:87.98ms +step:1161/1680 train_time:102149ms step_avg:87.98ms +step:1162/1680 train_time:102238ms step_avg:87.98ms +step:1163/1680 train_time:102328ms step_avg:87.99ms +step:1164/1680 train_time:102417ms step_avg:87.99ms +step:1165/1680 train_time:102506ms step_avg:87.99ms +step:1166/1680 train_time:102595ms step_avg:87.99ms +step:1167/1680 train_time:102684ms step_avg:87.99ms +step:1168/1680 train_time:102773ms step_avg:87.99ms +step:1169/1680 train_time:102862ms step_avg:87.99ms +step:1170/1680 train_time:102950ms step_avg:87.99ms +step:1171/1680 train_time:103040ms step_avg:87.99ms +step:1172/1680 train_time:103130ms step_avg:87.99ms +step:1173/1680 train_time:103219ms step_avg:88.00ms +step:1174/1680 train_time:103309ms step_avg:88.00ms +step:1175/1680 train_time:103399ms step_avg:88.00ms +step:1176/1680 train_time:103487ms step_avg:88.00ms +step:1177/1680 train_time:103577ms step_avg:88.00ms +step:1178/1680 train_time:103666ms step_avg:88.00ms +step:1179/1680 train_time:103754ms step_avg:88.00ms +step:1180/1680 train_time:103843ms step_avg:88.00ms +step:1181/1680 train_time:103932ms step_avg:88.00ms +step:1182/1680 train_time:104021ms step_avg:88.00ms +step:1183/1680 train_time:104110ms step_avg:88.00ms +step:1184/1680 train_time:104199ms step_avg:88.01ms +step:1185/1680 train_time:104288ms step_avg:88.01ms +step:1186/1680 train_time:104378ms step_avg:88.01ms +step:1187/1680 train_time:104468ms step_avg:88.01ms +step:1188/1680 train_time:104557ms step_avg:88.01ms +step:1189/1680 train_time:104646ms step_avg:88.01ms +step:1190/1680 train_time:104735ms step_avg:88.01ms +step:1191/1680 train_time:104824ms step_avg:88.01ms +step:1192/1680 train_time:104913ms step_avg:88.01ms +step:1193/1680 train_time:105002ms step_avg:88.02ms +step:1194/1680 train_time:105091ms step_avg:88.02ms +step:1195/1680 train_time:105180ms step_avg:88.02ms +step:1196/1680 train_time:105269ms step_avg:88.02ms +step:1197/1680 train_time:105359ms step_avg:88.02ms +step:1198/1680 train_time:105448ms step_avg:88.02ms +step:1199/1680 train_time:105537ms step_avg:88.02ms +step:1200/1680 train_time:105625ms step_avg:88.02ms +step:1201/1680 train_time:105715ms step_avg:88.02ms +step:1202/1680 train_time:105804ms step_avg:88.02ms +step:1203/1680 train_time:105894ms step_avg:88.02ms +step:1204/1680 train_time:105983ms step_avg:88.03ms +step:1205/1680 train_time:106072ms step_avg:88.03ms +step:1206/1680 train_time:106161ms step_avg:88.03ms +step:1207/1680 train_time:106250ms step_avg:88.03ms +step:1208/1680 train_time:106340ms step_avg:88.03ms +step:1209/1680 train_time:106428ms step_avg:88.03ms +step:1210/1680 train_time:106517ms step_avg:88.03ms +step:1211/1680 train_time:106607ms step_avg:88.03ms +step:1212/1680 train_time:106696ms step_avg:88.03ms +step:1213/1680 train_time:106786ms step_avg:88.03ms +step:1214/1680 train_time:106875ms step_avg:88.04ms +step:1215/1680 train_time:106965ms step_avg:88.04ms 
+step:1216/1680 train_time:107053ms step_avg:88.04ms +step:1217/1680 train_time:107142ms step_avg:88.04ms +step:1218/1680 train_time:107231ms step_avg:88.04ms +step:1219/1680 train_time:107321ms step_avg:88.04ms +step:1220/1680 train_time:107411ms step_avg:88.04ms +step:1221/1680 train_time:107500ms step_avg:88.04ms +step:1222/1680 train_time:107590ms step_avg:88.04ms +step:1223/1680 train_time:107678ms step_avg:88.04ms +step:1224/1680 train_time:107767ms step_avg:88.05ms +step:1225/1680 train_time:107857ms step_avg:88.05ms +step:1226/1680 train_time:107946ms step_avg:88.05ms +step:1227/1680 train_time:108035ms step_avg:88.05ms +step:1228/1680 train_time:108124ms step_avg:88.05ms +step:1229/1680 train_time:108213ms step_avg:88.05ms +step:1230/1680 train_time:108302ms step_avg:88.05ms +step:1231/1680 train_time:108392ms step_avg:88.05ms +step:1232/1680 train_time:108481ms step_avg:88.05ms +step:1233/1680 train_time:108570ms step_avg:88.05ms +step:1234/1680 train_time:108659ms step_avg:88.05ms +step:1235/1680 train_time:108748ms step_avg:88.06ms +step:1236/1680 train_time:108838ms step_avg:88.06ms +step:1237/1680 train_time:108927ms step_avg:88.06ms +step:1238/1680 train_time:109016ms step_avg:88.06ms +step:1239/1680 train_time:109105ms step_avg:88.06ms +step:1240/1680 train_time:109194ms step_avg:88.06ms +step:1241/1680 train_time:109283ms step_avg:88.06ms +step:1242/1680 train_time:109372ms step_avg:88.06ms +step:1243/1680 train_time:109463ms step_avg:88.06ms +step:1244/1680 train_time:109552ms step_avg:88.06ms +step:1245/1680 train_time:109642ms step_avg:88.07ms +step:1246/1680 train_time:109731ms step_avg:88.07ms +step:1247/1680 train_time:109820ms step_avg:88.07ms +step:1248/1680 train_time:109909ms step_avg:88.07ms +step:1249/1680 train_time:109999ms step_avg:88.07ms +step:1250/1680 train_time:110089ms step_avg:88.07ms +step:1250/1680 val_loss:3.3770 train_time:110179ms step_avg:88.14ms +step:1251/1680 train_time:110200ms step_avg:88.09ms +step:1252/1680 train_time:110271ms step_avg:88.08ms +step:1253/1680 train_time:110361ms step_avg:88.08ms +step:1254/1680 train_time:110452ms step_avg:88.08ms +step:1255/1680 train_time:110540ms step_avg:88.08ms +step:1256/1680 train_time:110628ms step_avg:88.08ms +step:1257/1680 train_time:110716ms step_avg:88.08ms +step:1258/1680 train_time:110805ms step_avg:88.08ms +step:1259/1680 train_time:110893ms step_avg:88.08ms +step:1260/1680 train_time:110983ms step_avg:88.08ms +step:1261/1680 train_time:111072ms step_avg:88.08ms +step:1262/1680 train_time:111163ms step_avg:88.08ms +step:1263/1680 train_time:111253ms step_avg:88.09ms +step:1264/1680 train_time:111343ms step_avg:88.09ms +step:1265/1680 train_time:111433ms step_avg:88.09ms +step:1266/1680 train_time:111522ms step_avg:88.09ms +step:1267/1680 train_time:111611ms step_avg:88.09ms +step:1268/1680 train_time:111700ms step_avg:88.09ms +step:1269/1680 train_time:111789ms step_avg:88.09ms +step:1270/1680 train_time:111878ms step_avg:88.09ms +step:1271/1680 train_time:111967ms step_avg:88.09ms +step:1272/1680 train_time:112056ms step_avg:88.09ms +step:1273/1680 train_time:112146ms step_avg:88.10ms +step:1274/1680 train_time:112237ms step_avg:88.10ms +step:1275/1680 train_time:112326ms step_avg:88.10ms +step:1276/1680 train_time:112416ms step_avg:88.10ms +step:1277/1680 train_time:112505ms step_avg:88.10ms +step:1278/1680 train_time:112594ms step_avg:88.10ms +step:1279/1680 train_time:112683ms step_avg:88.10ms +step:1280/1680 train_time:112771ms step_avg:88.10ms +step:1281/1680 train_time:112860ms 
step_avg:88.10ms +step:1282/1680 train_time:112949ms step_avg:88.10ms +step:1283/1680 train_time:113038ms step_avg:88.10ms +step:1284/1680 train_time:113128ms step_avg:88.11ms +step:1285/1680 train_time:113218ms step_avg:88.11ms +step:1286/1680 train_time:113308ms step_avg:88.11ms +step:1287/1680 train_time:113398ms step_avg:88.11ms +step:1288/1680 train_time:113487ms step_avg:88.11ms +step:1289/1680 train_time:113576ms step_avg:88.11ms +step:1290/1680 train_time:113665ms step_avg:88.11ms +step:1291/1680 train_time:113754ms step_avg:88.11ms +step:1292/1680 train_time:113844ms step_avg:88.11ms +step:1293/1680 train_time:113933ms step_avg:88.12ms +step:1294/1680 train_time:114023ms step_avg:88.12ms +step:1295/1680 train_time:114112ms step_avg:88.12ms +step:1296/1680 train_time:114203ms step_avg:88.12ms +step:1297/1680 train_time:114291ms step_avg:88.12ms +step:1298/1680 train_time:114381ms step_avg:88.12ms +step:1299/1680 train_time:114471ms step_avg:88.12ms +step:1300/1680 train_time:114560ms step_avg:88.12ms +step:1301/1680 train_time:114650ms step_avg:88.12ms +step:1302/1680 train_time:114738ms step_avg:88.12ms +step:1303/1680 train_time:114827ms step_avg:88.13ms +step:1304/1680 train_time:114918ms step_avg:88.13ms +step:1305/1680 train_time:115007ms step_avg:88.13ms +step:1306/1680 train_time:115096ms step_avg:88.13ms +step:1307/1680 train_time:115185ms step_avg:88.13ms +step:1308/1680 train_time:115274ms step_avg:88.13ms +step:1309/1680 train_time:115364ms step_avg:88.13ms +step:1310/1680 train_time:115453ms step_avg:88.13ms +step:1311/1680 train_time:115543ms step_avg:88.13ms +step:1312/1680 train_time:115632ms step_avg:88.13ms +step:1313/1680 train_time:115722ms step_avg:88.14ms +step:1314/1680 train_time:115811ms step_avg:88.14ms +step:1315/1680 train_time:115900ms step_avg:88.14ms +step:1316/1680 train_time:115989ms step_avg:88.14ms +step:1317/1680 train_time:116078ms step_avg:88.14ms +step:1318/1680 train_time:116168ms step_avg:88.14ms +step:1319/1680 train_time:116258ms step_avg:88.14ms +step:1320/1680 train_time:116347ms step_avg:88.14ms +step:1321/1680 train_time:116437ms step_avg:88.14ms +step:1322/1680 train_time:116526ms step_avg:88.14ms +step:1323/1680 train_time:116616ms step_avg:88.15ms +step:1324/1680 train_time:116705ms step_avg:88.15ms +step:1325/1680 train_time:116795ms step_avg:88.15ms +step:1326/1680 train_time:116884ms step_avg:88.15ms +step:1327/1680 train_time:116973ms step_avg:88.15ms +step:1328/1680 train_time:117061ms step_avg:88.15ms +step:1329/1680 train_time:117150ms step_avg:88.15ms +step:1330/1680 train_time:117240ms step_avg:88.15ms +step:1331/1680 train_time:117330ms step_avg:88.15ms +step:1332/1680 train_time:117420ms step_avg:88.15ms +step:1333/1680 train_time:117509ms step_avg:88.15ms +step:1334/1680 train_time:117598ms step_avg:88.15ms +step:1335/1680 train_time:117687ms step_avg:88.16ms +step:1336/1680 train_time:117777ms step_avg:88.16ms +step:1337/1680 train_time:117867ms step_avg:88.16ms +step:1338/1680 train_time:117956ms step_avg:88.16ms +step:1339/1680 train_time:118045ms step_avg:88.16ms +step:1340/1680 train_time:118134ms step_avg:88.16ms +step:1341/1680 train_time:118223ms step_avg:88.16ms +step:1342/1680 train_time:118313ms step_avg:88.16ms +step:1343/1680 train_time:118402ms step_avg:88.16ms +step:1344/1680 train_time:118492ms step_avg:88.16ms +step:1345/1680 train_time:118581ms step_avg:88.16ms +step:1346/1680 train_time:118669ms step_avg:88.16ms +step:1347/1680 train_time:118759ms step_avg:88.17ms +step:1348/1680 train_time:118848ms 
step_avg:88.17ms +step:1349/1680 train_time:118938ms step_avg:88.17ms +step:1350/1680 train_time:119027ms step_avg:88.17ms +step:1351/1680 train_time:119116ms step_avg:88.17ms +step:1352/1680 train_time:119206ms step_avg:88.17ms +step:1353/1680 train_time:119296ms step_avg:88.17ms +step:1354/1680 train_time:119385ms step_avg:88.17ms +step:1355/1680 train_time:119474ms step_avg:88.17ms +step:1356/1680 train_time:119563ms step_avg:88.17ms +step:1357/1680 train_time:119652ms step_avg:88.17ms +step:1358/1680 train_time:119741ms step_avg:88.17ms +step:1359/1680 train_time:119830ms step_avg:88.18ms +step:1360/1680 train_time:119919ms step_avg:88.18ms +step:1361/1680 train_time:120008ms step_avg:88.18ms +step:1362/1680 train_time:120097ms step_avg:88.18ms +step:1363/1680 train_time:120186ms step_avg:88.18ms +step:1364/1680 train_time:120276ms step_avg:88.18ms +step:1365/1680 train_time:120365ms step_avg:88.18ms +step:1366/1680 train_time:120454ms step_avg:88.18ms +step:1367/1680 train_time:120543ms step_avg:88.18ms +step:1368/1680 train_time:120633ms step_avg:88.18ms +step:1369/1680 train_time:120723ms step_avg:88.18ms +step:1370/1680 train_time:120813ms step_avg:88.18ms +step:1371/1680 train_time:120902ms step_avg:88.19ms +step:1372/1680 train_time:120992ms step_avg:88.19ms +step:1373/1680 train_time:121080ms step_avg:88.19ms +step:1374/1680 train_time:121170ms step_avg:88.19ms +step:1375/1680 train_time:121258ms step_avg:88.19ms +step:1375/1680 val_loss:3.3421 train_time:121349ms step_avg:88.25ms +step:1376/1680 train_time:121369ms step_avg:88.20ms +step:1377/1680 train_time:121444ms step_avg:88.19ms +step:1378/1680 train_time:121534ms step_avg:88.20ms +step:1379/1680 train_time:121623ms step_avg:88.20ms +step:1380/1680 train_time:121711ms step_avg:88.20ms +step:1381/1680 train_time:121800ms step_avg:88.20ms +step:1382/1680 train_time:121888ms step_avg:88.20ms +step:1383/1680 train_time:121976ms step_avg:88.20ms +step:1384/1680 train_time:122065ms step_avg:88.20ms +step:1385/1680 train_time:122154ms step_avg:88.20ms +step:1386/1680 train_time:122242ms step_avg:88.20ms +step:1387/1680 train_time:122332ms step_avg:88.20ms +step:1388/1680 train_time:122424ms step_avg:88.20ms +step:1389/1680 train_time:122514ms step_avg:88.20ms +step:1390/1680 train_time:122604ms step_avg:88.20ms +step:1391/1680 train_time:122693ms step_avg:88.20ms +step:1392/1680 train_time:122782ms step_avg:88.21ms +step:1393/1680 train_time:122871ms step_avg:88.21ms +step:1394/1680 train_time:122960ms step_avg:88.21ms +step:1395/1680 train_time:123048ms step_avg:88.21ms +step:1396/1680 train_time:123137ms step_avg:88.21ms +step:1397/1680 train_time:123226ms step_avg:88.21ms +step:1398/1680 train_time:123315ms step_avg:88.21ms +step:1399/1680 train_time:123405ms step_avg:88.21ms +step:1400/1680 train_time:123495ms step_avg:88.21ms +step:1401/1680 train_time:123585ms step_avg:88.21ms +step:1402/1680 train_time:123674ms step_avg:88.21ms +step:1403/1680 train_time:123765ms step_avg:88.21ms +step:1404/1680 train_time:123854ms step_avg:88.21ms +step:1405/1680 train_time:123942ms step_avg:88.22ms +step:1406/1680 train_time:124031ms step_avg:88.22ms +step:1407/1680 train_time:124119ms step_avg:88.22ms +step:1408/1680 train_time:124208ms step_avg:88.22ms +step:1409/1680 train_time:124298ms step_avg:88.22ms +step:1410/1680 train_time:124387ms step_avg:88.22ms +step:1411/1680 train_time:124477ms step_avg:88.22ms +step:1412/1680 train_time:124567ms step_avg:88.22ms +step:1413/1680 train_time:124657ms step_avg:88.22ms +step:1414/1680 
train_time:124747ms step_avg:88.22ms +step:1415/1680 train_time:124836ms step_avg:88.22ms +step:1416/1680 train_time:124925ms step_avg:88.22ms +step:1417/1680 train_time:125014ms step_avg:88.22ms +step:1418/1680 train_time:125103ms step_avg:88.22ms +step:1419/1680 train_time:125191ms step_avg:88.22ms +step:1420/1680 train_time:125280ms step_avg:88.23ms +step:1421/1680 train_time:125370ms step_avg:88.23ms +step:1422/1680 train_time:125460ms step_avg:88.23ms +step:1423/1680 train_time:125549ms step_avg:88.23ms +step:1424/1680 train_time:125638ms step_avg:88.23ms +step:1425/1680 train_time:125728ms step_avg:88.23ms +step:1426/1680 train_time:125817ms step_avg:88.23ms +step:1427/1680 train_time:125906ms step_avg:88.23ms +step:1428/1680 train_time:125995ms step_avg:88.23ms +step:1429/1680 train_time:126085ms step_avg:88.23ms +step:1430/1680 train_time:126174ms step_avg:88.23ms +step:1431/1680 train_time:126262ms step_avg:88.23ms +step:1432/1680 train_time:126351ms step_avg:88.23ms +step:1433/1680 train_time:126441ms step_avg:88.24ms +step:1434/1680 train_time:126531ms step_avg:88.24ms +step:1435/1680 train_time:126620ms step_avg:88.24ms +step:1436/1680 train_time:126709ms step_avg:88.24ms +step:1437/1680 train_time:126799ms step_avg:88.24ms +step:1438/1680 train_time:126887ms step_avg:88.24ms +step:1439/1680 train_time:126976ms step_avg:88.24ms +step:1440/1680 train_time:127065ms step_avg:88.24ms +step:1441/1680 train_time:127155ms step_avg:88.24ms +step:1442/1680 train_time:127245ms step_avg:88.24ms +step:1443/1680 train_time:127333ms step_avg:88.24ms +step:1444/1680 train_time:127423ms step_avg:88.24ms +step:1445/1680 train_time:127513ms step_avg:88.24ms +step:1446/1680 train_time:127602ms step_avg:88.24ms +step:1447/1680 train_time:127692ms step_avg:88.25ms +step:1448/1680 train_time:127781ms step_avg:88.25ms +step:1449/1680 train_time:127871ms step_avg:88.25ms +step:1450/1680 train_time:127960ms step_avg:88.25ms +step:1451/1680 train_time:128049ms step_avg:88.25ms +step:1452/1680 train_time:128138ms step_avg:88.25ms +step:1453/1680 train_time:128227ms step_avg:88.25ms +step:1454/1680 train_time:128317ms step_avg:88.25ms +step:1455/1680 train_time:128406ms step_avg:88.25ms +step:1456/1680 train_time:128497ms step_avg:88.25ms +step:1457/1680 train_time:128586ms step_avg:88.25ms +step:1458/1680 train_time:128674ms step_avg:88.25ms +step:1459/1680 train_time:128763ms step_avg:88.25ms +step:1460/1680 train_time:128853ms step_avg:88.26ms +step:1461/1680 train_time:128944ms step_avg:88.26ms +step:1462/1680 train_time:129033ms step_avg:88.26ms +step:1463/1680 train_time:129121ms step_avg:88.26ms +step:1464/1680 train_time:129211ms step_avg:88.26ms +step:1465/1680 train_time:129300ms step_avg:88.26ms +step:1466/1680 train_time:129389ms step_avg:88.26ms +step:1467/1680 train_time:129479ms step_avg:88.26ms +step:1468/1680 train_time:129568ms step_avg:88.26ms +step:1469/1680 train_time:129658ms step_avg:88.26ms +step:1470/1680 train_time:129746ms step_avg:88.26ms +step:1471/1680 train_time:129835ms step_avg:88.26ms +step:1472/1680 train_time:129924ms step_avg:88.26ms +step:1473/1680 train_time:130013ms step_avg:88.26ms +step:1474/1680 train_time:130102ms step_avg:88.26ms +step:1475/1680 train_time:130191ms step_avg:88.27ms +step:1476/1680 train_time:130280ms step_avg:88.27ms +step:1477/1680 train_time:130370ms step_avg:88.27ms +step:1478/1680 train_time:130460ms step_avg:88.27ms +step:1479/1680 train_time:130549ms step_avg:88.27ms +step:1480/1680 train_time:130638ms step_avg:88.27ms +step:1481/1680 
train_time:130728ms step_avg:88.27ms +step:1482/1680 train_time:130818ms step_avg:88.27ms +step:1483/1680 train_time:130906ms step_avg:88.27ms +step:1484/1680 train_time:130996ms step_avg:88.27ms +step:1485/1680 train_time:131085ms step_avg:88.27ms +step:1486/1680 train_time:131174ms step_avg:88.27ms +step:1487/1680 train_time:131262ms step_avg:88.27ms +step:1488/1680 train_time:131352ms step_avg:88.27ms +step:1489/1680 train_time:131441ms step_avg:88.27ms +step:1490/1680 train_time:131531ms step_avg:88.28ms +step:1491/1680 train_time:131620ms step_avg:88.28ms +step:1492/1680 train_time:131709ms step_avg:88.28ms +step:1493/1680 train_time:131800ms step_avg:88.28ms +step:1494/1680 train_time:131889ms step_avg:88.28ms +step:1495/1680 train_time:131978ms step_avg:88.28ms +step:1496/1680 train_time:132067ms step_avg:88.28ms +step:1497/1680 train_time:132157ms step_avg:88.28ms +step:1498/1680 train_time:132246ms step_avg:88.28ms +step:1499/1680 train_time:132335ms step_avg:88.28ms +step:1500/1680 train_time:132424ms step_avg:88.28ms +step:1500/1680 val_loss:3.3125 train_time:132514ms step_avg:88.34ms +step:1501/1680 train_time:132534ms step_avg:88.30ms +step:1502/1680 train_time:132606ms step_avg:88.29ms +step:1503/1680 train_time:132699ms step_avg:88.29ms +step:1504/1680 train_time:132788ms step_avg:88.29ms +step:1505/1680 train_time:132877ms step_avg:88.29ms +step:1506/1680 train_time:132965ms step_avg:88.29ms +step:1507/1680 train_time:133053ms step_avg:88.29ms +step:1508/1680 train_time:133141ms step_avg:88.29ms +step:1509/1680 train_time:133229ms step_avg:88.29ms +step:1510/1680 train_time:133318ms step_avg:88.29ms +step:1511/1680 train_time:133406ms step_avg:88.29ms +step:1512/1680 train_time:133497ms step_avg:88.29ms +step:1513/1680 train_time:133589ms step_avg:88.29ms +step:1514/1680 train_time:133681ms step_avg:88.30ms +step:1515/1680 train_time:133771ms step_avg:88.30ms +step:1516/1680 train_time:133860ms step_avg:88.30ms +step:1517/1680 train_time:133949ms step_avg:88.30ms +step:1518/1680 train_time:134037ms step_avg:88.30ms +step:1519/1680 train_time:134125ms step_avg:88.30ms +step:1520/1680 train_time:134213ms step_avg:88.30ms +step:1521/1680 train_time:134302ms step_avg:88.30ms +step:1522/1680 train_time:134391ms step_avg:88.30ms +step:1523/1680 train_time:134480ms step_avg:88.30ms +step:1524/1680 train_time:134571ms step_avg:88.30ms +step:1525/1680 train_time:134661ms step_avg:88.30ms +step:1526/1680 train_time:134751ms step_avg:88.30ms +step:1527/1680 train_time:134841ms step_avg:88.30ms +step:1528/1680 train_time:134931ms step_avg:88.31ms +step:1529/1680 train_time:135020ms step_avg:88.31ms +step:1530/1680 train_time:135108ms step_avg:88.31ms +step:1531/1680 train_time:135196ms step_avg:88.31ms +step:1532/1680 train_time:135284ms step_avg:88.31ms +step:1533/1680 train_time:135373ms step_avg:88.31ms +step:1534/1680 train_time:135462ms step_avg:88.31ms +step:1535/1680 train_time:135552ms step_avg:88.31ms +step:1536/1680 train_time:135642ms step_avg:88.31ms +step:1537/1680 train_time:135732ms step_avg:88.31ms +step:1538/1680 train_time:135823ms step_avg:88.31ms +step:1539/1680 train_time:135912ms step_avg:88.31ms +step:1540/1680 train_time:136002ms step_avg:88.31ms +step:1541/1680 train_time:136090ms step_avg:88.31ms +step:1542/1680 train_time:136180ms step_avg:88.31ms +step:1543/1680 train_time:136269ms step_avg:88.31ms +step:1544/1680 train_time:136357ms step_avg:88.31ms +step:1545/1680 train_time:136446ms step_avg:88.31ms +step:1546/1680 train_time:136535ms step_avg:88.32ms 
+step:1547/1680 train_time:136625ms step_avg:88.32ms +step:1548/1680 train_time:136714ms step_avg:88.32ms +step:1549/1680 train_time:136804ms step_avg:88.32ms +step:1550/1680 train_time:136893ms step_avg:88.32ms +step:1551/1680 train_time:136982ms step_avg:88.32ms +step:1552/1680 train_time:137071ms step_avg:88.32ms +step:1553/1680 train_time:137160ms step_avg:88.32ms +step:1554/1680 train_time:137249ms step_avg:88.32ms +step:1555/1680 train_time:137338ms step_avg:88.32ms +step:1556/1680 train_time:137427ms step_avg:88.32ms +step:1557/1680 train_time:137517ms step_avg:88.32ms +step:1558/1680 train_time:137606ms step_avg:88.32ms +step:1559/1680 train_time:137696ms step_avg:88.32ms +step:1560/1680 train_time:137786ms step_avg:88.32ms +step:1561/1680 train_time:137875ms step_avg:88.32ms +step:1562/1680 train_time:137964ms step_avg:88.33ms +step:1563/1680 train_time:138053ms step_avg:88.33ms +step:1564/1680 train_time:138142ms step_avg:88.33ms +step:1565/1680 train_time:138231ms step_avg:88.33ms +step:1566/1680 train_time:138321ms step_avg:88.33ms +step:1567/1680 train_time:138410ms step_avg:88.33ms +step:1568/1680 train_time:138499ms step_avg:88.33ms +step:1569/1680 train_time:138588ms step_avg:88.33ms +step:1570/1680 train_time:138678ms step_avg:88.33ms +step:1571/1680 train_time:138767ms step_avg:88.33ms +step:1572/1680 train_time:138856ms step_avg:88.33ms +step:1573/1680 train_time:138945ms step_avg:88.33ms +step:1574/1680 train_time:139034ms step_avg:88.33ms +step:1575/1680 train_time:139122ms step_avg:88.33ms +step:1576/1680 train_time:139211ms step_avg:88.33ms +step:1577/1680 train_time:139299ms step_avg:88.33ms +step:1578/1680 train_time:139389ms step_avg:88.33ms +step:1579/1680 train_time:139479ms step_avg:88.33ms +step:1580/1680 train_time:139569ms step_avg:88.33ms +step:1581/1680 train_time:139658ms step_avg:88.34ms +step:1582/1680 train_time:139747ms step_avg:88.34ms +step:1583/1680 train_time:139836ms step_avg:88.34ms +step:1584/1680 train_time:139926ms step_avg:88.34ms +step:1585/1680 train_time:140016ms step_avg:88.34ms +step:1586/1680 train_time:140105ms step_avg:88.34ms +step:1587/1680 train_time:140195ms step_avg:88.34ms +step:1588/1680 train_time:140283ms step_avg:88.34ms +step:1589/1680 train_time:140372ms step_avg:88.34ms +step:1590/1680 train_time:140461ms step_avg:88.34ms +step:1591/1680 train_time:140550ms step_avg:88.34ms +step:1592/1680 train_time:140639ms step_avg:88.34ms +step:1593/1680 train_time:140728ms step_avg:88.34ms +step:1594/1680 train_time:140817ms step_avg:88.34ms +step:1595/1680 train_time:140906ms step_avg:88.34ms +step:1596/1680 train_time:140995ms step_avg:88.34ms +step:1597/1680 train_time:141085ms step_avg:88.34ms +step:1598/1680 train_time:141174ms step_avg:88.34ms +step:1599/1680 train_time:141263ms step_avg:88.34ms +step:1600/1680 train_time:141351ms step_avg:88.34ms +step:1601/1680 train_time:141440ms step_avg:88.34ms +step:1602/1680 train_time:141530ms step_avg:88.35ms +step:1603/1680 train_time:141619ms step_avg:88.35ms +step:1604/1680 train_time:141708ms step_avg:88.35ms +step:1605/1680 train_time:141798ms step_avg:88.35ms +step:1606/1680 train_time:141887ms step_avg:88.35ms +step:1607/1680 train_time:141975ms step_avg:88.35ms +step:1608/1680 train_time:142064ms step_avg:88.35ms +step:1609/1680 train_time:142153ms step_avg:88.35ms +step:1610/1680 train_time:142243ms step_avg:88.35ms +step:1611/1680 train_time:142332ms step_avg:88.35ms +step:1612/1680 train_time:142421ms step_avg:88.35ms +step:1613/1680 train_time:142510ms step_avg:88.35ms 
+step:1614/1680 train_time:142600ms step_avg:88.35ms +step:1615/1680 train_time:142689ms step_avg:88.35ms +step:1616/1680 train_time:142778ms step_avg:88.35ms +step:1617/1680 train_time:142868ms step_avg:88.35ms +step:1618/1680 train_time:142957ms step_avg:88.35ms +step:1619/1680 train_time:143046ms step_avg:88.35ms +step:1620/1680 train_time:143136ms step_avg:88.36ms +step:1621/1680 train_time:143226ms step_avg:88.36ms +step:1622/1680 train_time:143315ms step_avg:88.36ms +step:1623/1680 train_time:143405ms step_avg:88.36ms +step:1624/1680 train_time:143494ms step_avg:88.36ms +step:1625/1680 train_time:143583ms step_avg:88.36ms +step:1625/1680 val_loss:3.2891 train_time:143674ms step_avg:88.41ms +step:1626/1680 train_time:143692ms step_avg:88.37ms +step:1627/1680 train_time:143768ms step_avg:88.36ms +step:1628/1680 train_time:143862ms step_avg:88.37ms +step:1629/1680 train_time:143951ms step_avg:88.37ms +step:1630/1680 train_time:144039ms step_avg:88.37ms +step:1631/1680 train_time:144127ms step_avg:88.37ms +step:1632/1680 train_time:144215ms step_avg:88.37ms +step:1633/1680 train_time:144304ms step_avg:88.37ms +step:1634/1680 train_time:144392ms step_avg:88.37ms +step:1635/1680 train_time:144480ms step_avg:88.37ms +step:1636/1680 train_time:144571ms step_avg:88.37ms +step:1637/1680 train_time:144661ms step_avg:88.37ms +step:1638/1680 train_time:144752ms step_avg:88.37ms +step:1639/1680 train_time:144845ms step_avg:88.37ms +step:1640/1680 train_time:144934ms step_avg:88.37ms +step:1641/1680 train_time:145023ms step_avg:88.37ms +step:1642/1680 train_time:145111ms step_avg:88.37ms +step:1643/1680 train_time:145200ms step_avg:88.37ms +step:1644/1680 train_time:145288ms step_avg:88.37ms +step:1645/1680 train_time:145377ms step_avg:88.37ms +step:1646/1680 train_time:145466ms step_avg:88.38ms +step:1647/1680 train_time:145554ms step_avg:88.38ms +step:1648/1680 train_time:145645ms step_avg:88.38ms +step:1649/1680 train_time:145736ms step_avg:88.38ms +step:1650/1680 train_time:145827ms step_avg:88.38ms +step:1651/1680 train_time:145916ms step_avg:88.38ms +step:1652/1680 train_time:146006ms step_avg:88.38ms +step:1653/1680 train_time:146095ms step_avg:88.38ms +step:1654/1680 train_time:146184ms step_avg:88.38ms +step:1655/1680 train_time:146273ms step_avg:88.38ms +step:1656/1680 train_time:146361ms step_avg:88.38ms +step:1657/1680 train_time:146451ms step_avg:88.38ms +step:1658/1680 train_time:146539ms step_avg:88.38ms +step:1659/1680 train_time:146629ms step_avg:88.38ms +step:1660/1680 train_time:146719ms step_avg:88.38ms +step:1661/1680 train_time:146809ms step_avg:88.39ms +step:1662/1680 train_time:146899ms step_avg:88.39ms +step:1663/1680 train_time:146988ms step_avg:88.39ms +step:1664/1680 train_time:147077ms step_avg:88.39ms +step:1665/1680 train_time:147166ms step_avg:88.39ms +step:1666/1680 train_time:147255ms step_avg:88.39ms +step:1667/1680 train_time:147344ms step_avg:88.39ms +step:1668/1680 train_time:147432ms step_avg:88.39ms +step:1669/1680 train_time:147521ms step_avg:88.39ms +step:1670/1680 train_time:147610ms step_avg:88.39ms +step:1671/1680 train_time:147699ms step_avg:88.39ms +step:1672/1680 train_time:147789ms step_avg:88.39ms +step:1673/1680 train_time:147878ms step_avg:88.39ms +step:1674/1680 train_time:147968ms step_avg:88.39ms +step:1675/1680 train_time:148056ms step_avg:88.39ms +step:1676/1680 train_time:148146ms step_avg:88.39ms +step:1677/1680 train_time:148236ms step_avg:88.39ms +step:1678/1680 train_time:148325ms step_avg:88.39ms +step:1679/1680 train_time:148413ms 
step_avg:88.39ms +step:1680/1680 train_time:148502ms step_avg:88.39ms +step:1680/1680 val_loss:3.2782 train_time:148593ms step_avg:88.45ms +peak memory allocated: 30760 MiB reserved: 46254 MiB diff --git a/records/092725_BF16CE/f26e4a90-074c-4ed4-b3e3-ce69223863c4.txt b/records/092725_BF16CE/f26e4a90-074c-4ed4-b3e3-ce69223863c4.txt new file mode 100644 index 000000000..0d905b366 --- /dev/null +++ b/records/092725_BF16CE/f26e4a90-074c-4ed4-b3e3-ce69223863c4.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
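+ # (blocks entirely on the skipped side of the diagonal returned early above; the mirrored store below fills them in)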
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+ Empirically, though, small 1D params also perform efficiently here: + NS approximately performs a magnitude normalization of the grad, and + this hyper-optimized class executes faster than the current Adam implementation for small params. + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized to enable + (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param, 7 padding params) + 2. reduce scatter attn_gate (10 params, 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params, 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive params of type attn reshape before NS + 8. wait on step 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size() == 8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on fewer than 8 GPUs or when experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for the resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1, 10, 16, 16] + assert len(params) == sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx+size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal.
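+ # The step is organized as three passes so that communication overlaps compute: + # 1) per-group async reduce_scatter of the stacked grads, so each rank owns one chunk + # 2) momentum + batched Newton-Schulz on the local chunk, then an async all_gather + # 3) wait on each gather and copy the updated params back into place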
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
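+ # (the parameter itself stays untouched until the gathered result is copied back in the final pass)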
+ updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + original_shape = batched_update_grads.shape + # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O + module_idx = start_idx if start_idx < len(params) else 0 + if getattr(params[module_idx], "module", None) == 'attn': + batched_update_grads = batched_update_grads.view(-1, original_shape[-2], original_shape[-1] // 4) + v_chunk = newton_schulz_triton(batched_update_grads).view(original_shape) + + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq =
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module='attn' + with torch.no_grad(): + self.qkvo_w.view(4,self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4,self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4,self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4,self.hdim, self.dim)[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module='mlp' + self.c_proj.module='mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
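+        # note on the FP8 scales below: 448 is the largest normal value of float8_e4m3fn, so the x_s/w_s/grad_s divisors keep the casted activations, weights, and grads within FP8 range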
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the 
data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents, by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens = tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kick off async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter == 5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS tokens ahead of index {idx}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter += 1 + return starts, ends + +class DataPreloader: + # Helper for asynchronously loading the next shard and indexing BOS tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with a Beginning of Sequence token; sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size *
grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1640 # number of iterations to run + iteration_extension: int = 40 # number of iterations to continue training at final cooldown and window size + cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss?
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40 # number of iterations to continue training at the final cooldown and window size
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase the final validation ws; used for YaRN extension and short window size @classiclarryd
+    ws_long_validate: int = 20 # extend long windows out even further
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top-level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, (nn.Embedding, nn.Linear)):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+gate_params = [p for n, p in model.named_parameters() if "gate" in n]
+
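+# Quick inspection sketch for the parameter split above (illustrative, not part of the
+# recorded run): counts how many tensors each optimizer group received.
+#   groups = dict(hidden=hidden_matrix_params, embed=embed_params, scalar=scalar_params,
+#                 head=head_params, gate=gate_params)
+#   print({k: len(v) for k, v in groups.items()}, "of", len(list(model.parameters())), "total")
+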
+# init the optimizer(s)
+# small Adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = min(0.9999, step / args.num_iterations)
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+# window size schedule: (short, long) windows, with a final extension for validation
+def get_ws(step: int):
+    if step == args.num_iterations + args.iteration_extension:
+        return args.ws_validate // 2, args.ws_validate
+    x = min(step / (1 + args.num_iterations), 0.9999)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws_long = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws_long > ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    elif new_ws_long < ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+# restore the saved state so the timed run starts from scratch
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del initial_state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+
+########################################
+#        Training and validation       #
+########################################
+
+train_steps = args.num_iterations + args.iteration_extension
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+ws_short, ws_long = get_ws(0)
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    new_ws_short, new_ws_long = get_ws(step)
+    if new_ws_long != ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+    ws_short, ws_long = new_ws_short, new_ws_long
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        if last_step:
+            ws_long = args.ws_long_validate
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = torch.zeros((), device=device, dtype=torch.float32)
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
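+    # Accumulation arithmetic (illustrative, derived from the settings above): each optimizer
+    # step consumes args.train_batch_size tokens in total, independent of world_size, e.g.
+    # world_size=8 -> grad_accum_steps=1 with one 49152-token micro-batch per rank, while
+    # world_size=4 -> grad_accum_steps=2 with two 49152-token micro-batches per rank.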
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws_short, ws_long).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1) # momentum warmup for Muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+
+====================================================================================================
+Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]
+Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Sat Sep 27 13:20:00 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
+|-----------------------------------------+------------------------+----------------------+
+| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
+| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
+|                                         |                        |               MIG M. |
+|=========================================+========================+======================|
+|   0  NVIDIA H100 80GB HBM3          On  |   00000000:19:00.0 Off |                    0 |
+| N/A   27C    P0            120W /  700W |    5856MiB /  81559MiB |      1%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   1  NVIDIA H100 80GB HBM3          On  |   00000000:3B:00.0 Off |                    0 |
+| N/A   25C    P0            118W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   2  NVIDIA H100 80GB HBM3          On  |   00000000:4C:00.0 Off |                    0 |
+| N/A   22C    P0            115W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   3  NVIDIA H100 80GB HBM3          On  |   00000000:5D:00.0 Off |                    0 |
+| N/A   26C    P0            120W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   4  NVIDIA H100 80GB HBM3          On  |   00000000:9B:00.0 Off |                    0 |
+| N/A   26C    P0            120W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   5  NVIDIA H100 80GB HBM3          On  |   00000000:BB:00.0 Off |                    0 |
+| N/A   25C    P0            114W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   6  NVIDIA H100 80GB HBM3          On  |   00000000:CB:00.0 Off |                    0 |
+| N/A   27C    P0            120W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   7  NVIDIA H100 80GB HBM3          On  |   00000000:DB:00.0 Off |                    0 |
+| N/A   24C    P0            120W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+
++-----------------------------------------------------------------------------------------+
+| Processes:                                                                              |
+|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
+|        ID   ID                                                               Usage      |
+|=========================================================================================|
+|    0   N/A  N/A          172549      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          172550      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          172551      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          172552      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          172553      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          172554      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          172555      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          172556      C   /usr/bin/python                           0MiB |
+|    1   N/A  N/A          172550      C   /usr/bin/python                           0MiB |
+|    2   N/A  N/A          172551      C   /usr/bin/python                           0MiB |
+|    3   N/A  N/A          172552      C   /usr/bin/python                           0MiB |
+|    4   N/A  N/A          172553      C   /usr/bin/python                           0MiB |
+|    5   N/A  N/A          172554      C   /usr/bin/python                           0MiB |
+|    6   N/A  N/A          172555      C   /usr/bin/python                           0MiB |
+|    7   N/A  N/A          172556      C   /usr/bin/python                           0MiB |
++-----------------------------------------------------------------------------------------+
+
+====================================================================================================
+step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.02ms
+step:1/1680 train_time:141ms step_avg:140.53ms
+step:2/1680 train_time:161ms step_avg:80.26ms
+step:3/1680 train_time:226ms step_avg:75.22ms
+step:4/1680 train_time:311ms step_avg:77.70ms
+step:5/1680 train_time:397ms step_avg:79.31ms
+step:6/1680 train_time:483ms step_avg:80.44ms
+step:7/1680 train_time:569ms step_avg:81.28ms
+step:8/1680 train_time:655ms step_avg:81.86ms
+step:9/1680 train_time:741ms step_avg:82.36ms
+step:10/1680 train_time:828ms step_avg:82.79ms
+step:11/1680 train_time:914ms step_avg:83.14ms
+step:12/1680 train_time:1003ms step_avg:83.55ms
+step:13/1680 train_time:1094ms step_avg:84.19ms
+step:14/1680 train_time:1184ms step_avg:84.60ms
+step:15/1680 train_time:1272ms step_avg:84.80ms
+step:16/1680 train_time:1359ms step_avg:84.93ms
+step:17/1680 train_time:1446ms step_avg:85.06ms
+step:18/1680 train_time:1533ms step_avg:85.15ms
+step:19/1680 train_time:1619ms step_avg:85.22ms
+step:20/1680 train_time:1706ms step_avg:85.28ms
+step:21/1680 train_time:1792ms step_avg:85.34ms
+step:22/1680 train_time:1879ms step_avg:85.39ms
+step:23/1680 train_time:1966ms step_avg:85.49ms
+step:24/1680 train_time:2054ms step_avg:85.60ms
+step:25/1680 train_time:2143ms step_avg:85.71ms
+step:26/1680 train_time:2231ms step_avg:85.83ms
+step:27/1680 train_time:2319ms step_avg:85.90ms
+step:28/1680 train_time:2407ms step_avg:85.96ms
+step:29/1680 train_time:2494ms step_avg:86.01ms
+step:30/1680 train_time:2582ms step_avg:86.06ms
+step:31/1680 train_time:2668ms step_avg:86.08ms
+step:32/1680 train_time:2755ms step_avg:86.09ms
+step:33/1680 train_time:2841ms step_avg:86.10ms
+step:34/1680 train_time:2928ms step_avg:86.13ms
+step:35/1680 train_time:3016ms step_avg:86.16ms
+step:36/1680 train_time:3104ms step_avg:86.21ms
+step:37/1680 train_time:3193ms step_avg:86.30ms
+step:38/1680 train_time:3281ms step_avg:86.35ms
+step:39/1680 train_time:3370ms step_avg:86.40ms
+step:40/1680 train_time:3457ms step_avg:86.41ms
+step:41/1680 train_time:3544ms step_avg:86.43ms
+step:42/1680 train_time:3631ms step_avg:86.45ms
+step:43/1680 train_time:3718ms step_avg:86.46ms
+step:44/1680 train_time:3805ms step_avg:86.47ms
+step:45/1680 train_time:3893ms step_avg:86.50ms
+step:46/1680 train_time:3980ms step_avg:86.51ms
+step:47/1680
train_time:4067ms step_avg:86.54ms +step:48/1680 train_time:4155ms step_avg:86.56ms +step:49/1680 train_time:4242ms step_avg:86.58ms +step:50/1680 train_time:4330ms step_avg:86.60ms +step:51/1680 train_time:4417ms step_avg:86.60ms +step:52/1680 train_time:4504ms step_avg:86.61ms +step:53/1680 train_time:4592ms step_avg:86.64ms +step:54/1680 train_time:4679ms step_avg:86.64ms +step:55/1680 train_time:4766ms step_avg:86.66ms +step:56/1680 train_time:4853ms step_avg:86.66ms +step:57/1680 train_time:4940ms step_avg:86.67ms +step:58/1680 train_time:5028ms step_avg:86.68ms +step:59/1680 train_time:5115ms step_avg:86.69ms +step:60/1680 train_time:5202ms step_avg:86.71ms +step:61/1680 train_time:5291ms step_avg:86.73ms +step:62/1680 train_time:5378ms step_avg:86.74ms +step:63/1680 train_time:5465ms step_avg:86.75ms +step:64/1680 train_time:5552ms step_avg:86.76ms +step:65/1680 train_time:5640ms step_avg:86.77ms +step:66/1680 train_time:5728ms step_avg:86.78ms +step:67/1680 train_time:5814ms step_avg:86.78ms +step:68/1680 train_time:5902ms step_avg:86.80ms +step:69/1680 train_time:5991ms step_avg:86.82ms +step:70/1680 train_time:6077ms step_avg:86.82ms +step:71/1680 train_time:6165ms step_avg:86.82ms +step:72/1680 train_time:6251ms step_avg:86.82ms +step:73/1680 train_time:6339ms step_avg:86.83ms +step:74/1680 train_time:6426ms step_avg:86.84ms +step:75/1680 train_time:6513ms step_avg:86.85ms +step:76/1680 train_time:6602ms step_avg:86.86ms +step:77/1680 train_time:6689ms step_avg:86.87ms +step:78/1680 train_time:6776ms step_avg:86.87ms +step:79/1680 train_time:6863ms step_avg:86.87ms +step:80/1680 train_time:6950ms step_avg:86.88ms +step:81/1680 train_time:7037ms step_avg:86.87ms +step:82/1680 train_time:7125ms step_avg:86.89ms +step:83/1680 train_time:7213ms step_avg:86.90ms +step:84/1680 train_time:7300ms step_avg:86.90ms +step:85/1680 train_time:7387ms step_avg:86.90ms +step:86/1680 train_time:7474ms step_avg:86.91ms +step:87/1680 train_time:7561ms step_avg:86.91ms +step:88/1680 train_time:7648ms step_avg:86.91ms +step:89/1680 train_time:7735ms step_avg:86.92ms +step:90/1680 train_time:7823ms step_avg:86.92ms +step:91/1680 train_time:7911ms step_avg:86.93ms +step:92/1680 train_time:7997ms step_avg:86.93ms +step:93/1680 train_time:8085ms step_avg:86.93ms +step:94/1680 train_time:8172ms step_avg:86.94ms +step:95/1680 train_time:8259ms step_avg:86.94ms +step:96/1680 train_time:8345ms step_avg:86.93ms +step:97/1680 train_time:8433ms step_avg:86.94ms +step:98/1680 train_time:8520ms step_avg:86.93ms +step:99/1680 train_time:8607ms step_avg:86.94ms +step:100/1680 train_time:8694ms step_avg:86.94ms +step:101/1680 train_time:8780ms step_avg:86.93ms +step:102/1680 train_time:8868ms step_avg:86.94ms +step:103/1680 train_time:8955ms step_avg:86.94ms +step:104/1680 train_time:9042ms step_avg:86.94ms +step:105/1680 train_time:9130ms step_avg:86.95ms +step:106/1680 train_time:9216ms step_avg:86.95ms +step:107/1680 train_time:9304ms step_avg:86.95ms +step:108/1680 train_time:9392ms step_avg:86.96ms +step:109/1680 train_time:9478ms step_avg:86.96ms +step:110/1680 train_time:9566ms step_avg:86.96ms +step:111/1680 train_time:9653ms step_avg:86.96ms +step:112/1680 train_time:9740ms step_avg:86.96ms +step:113/1680 train_time:9828ms step_avg:86.97ms +step:114/1680 train_time:9915ms step_avg:86.97ms +step:115/1680 train_time:10002ms step_avg:86.98ms +step:116/1680 train_time:10089ms step_avg:86.98ms +step:117/1680 train_time:10176ms step_avg:86.98ms +step:118/1680 train_time:10264ms step_avg:86.98ms +step:119/1680 
train_time:10351ms step_avg:86.98ms +step:120/1680 train_time:10438ms step_avg:86.98ms +step:121/1680 train_time:10525ms step_avg:86.99ms +step:122/1680 train_time:10612ms step_avg:86.99ms +step:123/1680 train_time:10701ms step_avg:87.00ms +step:124/1680 train_time:10788ms step_avg:87.00ms +step:125/1680 train_time:10875ms step_avg:87.00ms +step:125/1680 val_loss:4.3142 train_time:10964ms step_avg:87.71ms +step:126/1680 train_time:10983ms step_avg:87.17ms +step:127/1680 train_time:11055ms step_avg:87.05ms +step:128/1680 train_time:11152ms step_avg:87.12ms +step:129/1680 train_time:11243ms step_avg:87.16ms +step:130/1680 train_time:11330ms step_avg:87.16ms +step:131/1680 train_time:11416ms step_avg:87.15ms +step:132/1680 train_time:11503ms step_avg:87.14ms +step:133/1680 train_time:11588ms step_avg:87.13ms +step:134/1680 train_time:11674ms step_avg:87.12ms +step:135/1680 train_time:11760ms step_avg:87.11ms +step:136/1680 train_time:11846ms step_avg:87.10ms +step:137/1680 train_time:11932ms step_avg:87.09ms +step:138/1680 train_time:12019ms step_avg:87.10ms +step:139/1680 train_time:12110ms step_avg:87.12ms +step:140/1680 train_time:12198ms step_avg:87.13ms +step:141/1680 train_time:12287ms step_avg:87.14ms +step:142/1680 train_time:12374ms step_avg:87.14ms +step:143/1680 train_time:12461ms step_avg:87.14ms +step:144/1680 train_time:12547ms step_avg:87.13ms +step:145/1680 train_time:12633ms step_avg:87.12ms +step:146/1680 train_time:12719ms step_avg:87.12ms +step:147/1680 train_time:12805ms step_avg:87.11ms +step:148/1680 train_time:12892ms step_avg:87.11ms +step:149/1680 train_time:12979ms step_avg:87.10ms +step:150/1680 train_time:13067ms step_avg:87.11ms +step:151/1680 train_time:13156ms step_avg:87.13ms +step:152/1680 train_time:13245ms step_avg:87.14ms +step:153/1680 train_time:13332ms step_avg:87.14ms +step:154/1680 train_time:13419ms step_avg:87.13ms +step:155/1680 train_time:13505ms step_avg:87.13ms +step:156/1680 train_time:13592ms step_avg:87.13ms +step:157/1680 train_time:13678ms step_avg:87.12ms +step:158/1680 train_time:13765ms step_avg:87.12ms +step:159/1680 train_time:13851ms step_avg:87.12ms +step:160/1680 train_time:13938ms step_avg:87.11ms +step:161/1680 train_time:14025ms step_avg:87.11ms +step:162/1680 train_time:14112ms step_avg:87.11ms +step:163/1680 train_time:14200ms step_avg:87.12ms +step:164/1680 train_time:14289ms step_avg:87.13ms +step:165/1680 train_time:14378ms step_avg:87.14ms +step:166/1680 train_time:14465ms step_avg:87.14ms +step:167/1680 train_time:14551ms step_avg:87.13ms +step:168/1680 train_time:14638ms step_avg:87.13ms +step:169/1680 train_time:14725ms step_avg:87.13ms +step:170/1680 train_time:14811ms step_avg:87.12ms +step:171/1680 train_time:14897ms step_avg:87.12ms +step:172/1680 train_time:14985ms step_avg:87.12ms +step:173/1680 train_time:15072ms step_avg:87.12ms +step:174/1680 train_time:15160ms step_avg:87.12ms +step:175/1680 train_time:15248ms step_avg:87.13ms +step:176/1680 train_time:15336ms step_avg:87.14ms +step:177/1680 train_time:15423ms step_avg:87.14ms +step:178/1680 train_time:15510ms step_avg:87.14ms +step:179/1680 train_time:15597ms step_avg:87.14ms +step:180/1680 train_time:15684ms step_avg:87.13ms +step:181/1680 train_time:15771ms step_avg:87.13ms +step:182/1680 train_time:15858ms step_avg:87.13ms +step:183/1680 train_time:15945ms step_avg:87.13ms +step:184/1680 train_time:16031ms step_avg:87.13ms +step:185/1680 train_time:16118ms step_avg:87.13ms +step:186/1680 train_time:16206ms step_avg:87.13ms +step:187/1680 train_time:16293ms 
step_avg:87.13ms +step:188/1680 train_time:16380ms step_avg:87.13ms +step:189/1680 train_time:16468ms step_avg:87.13ms +step:190/1680 train_time:16554ms step_avg:87.13ms +step:191/1680 train_time:16642ms step_avg:87.13ms +step:192/1680 train_time:16729ms step_avg:87.13ms +step:193/1680 train_time:16815ms step_avg:87.13ms +step:194/1680 train_time:16903ms step_avg:87.13ms +step:195/1680 train_time:16990ms step_avg:87.13ms +step:196/1680 train_time:17076ms step_avg:87.12ms +step:197/1680 train_time:17164ms step_avg:87.13ms +step:198/1680 train_time:17252ms step_avg:87.13ms +step:199/1680 train_time:17338ms step_avg:87.13ms +step:200/1680 train_time:17426ms step_avg:87.13ms +step:201/1680 train_time:17512ms step_avg:87.13ms +step:202/1680 train_time:17600ms step_avg:87.13ms +step:203/1680 train_time:17687ms step_avg:87.13ms +step:204/1680 train_time:17774ms step_avg:87.13ms +step:205/1680 train_time:17861ms step_avg:87.13ms +step:206/1680 train_time:17949ms step_avg:87.13ms +step:207/1680 train_time:18036ms step_avg:87.13ms +step:208/1680 train_time:18123ms step_avg:87.13ms +step:209/1680 train_time:18211ms step_avg:87.13ms +step:210/1680 train_time:18299ms step_avg:87.14ms +step:211/1680 train_time:18385ms step_avg:87.13ms +step:212/1680 train_time:18473ms step_avg:87.14ms +step:213/1680 train_time:18561ms step_avg:87.14ms +step:214/1680 train_time:18647ms step_avg:87.14ms +step:215/1680 train_time:18734ms step_avg:87.13ms +step:216/1680 train_time:18821ms step_avg:87.13ms +step:217/1680 train_time:18908ms step_avg:87.13ms +step:218/1680 train_time:18995ms step_avg:87.13ms +step:219/1680 train_time:19082ms step_avg:87.13ms +step:220/1680 train_time:19169ms step_avg:87.13ms +step:221/1680 train_time:19256ms step_avg:87.13ms +step:222/1680 train_time:19343ms step_avg:87.13ms +step:223/1680 train_time:19430ms step_avg:87.13ms +step:224/1680 train_time:19517ms step_avg:87.13ms +step:225/1680 train_time:19604ms step_avg:87.13ms +step:226/1680 train_time:19691ms step_avg:87.13ms +step:227/1680 train_time:19778ms step_avg:87.13ms +step:228/1680 train_time:19865ms step_avg:87.13ms +step:229/1680 train_time:19952ms step_avg:87.13ms +step:230/1680 train_time:20038ms step_avg:87.12ms +step:231/1680 train_time:20125ms step_avg:87.12ms +step:232/1680 train_time:20212ms step_avg:87.12ms +step:233/1680 train_time:20299ms step_avg:87.12ms +step:234/1680 train_time:20387ms step_avg:87.12ms +step:235/1680 train_time:20474ms step_avg:87.12ms +step:236/1680 train_time:20561ms step_avg:87.12ms +step:237/1680 train_time:20648ms step_avg:87.12ms +step:238/1680 train_time:20735ms step_avg:87.12ms +step:239/1680 train_time:20823ms step_avg:87.12ms +step:240/1680 train_time:20910ms step_avg:87.12ms +step:241/1680 train_time:20997ms step_avg:87.12ms +step:242/1680 train_time:21084ms step_avg:87.12ms +step:243/1680 train_time:21171ms step_avg:87.12ms +step:244/1680 train_time:21258ms step_avg:87.12ms +step:245/1680 train_time:21346ms step_avg:87.12ms +step:246/1680 train_time:21432ms step_avg:87.12ms +step:247/1680 train_time:21519ms step_avg:87.12ms +step:248/1680 train_time:21606ms step_avg:87.12ms +step:249/1680 train_time:21694ms step_avg:87.12ms +step:250/1680 train_time:21780ms step_avg:87.12ms +step:250/1680 val_loss:3.9678 train_time:21869ms step_avg:87.48ms +step:251/1680 train_time:21888ms step_avg:87.20ms +step:252/1680 train_time:21958ms step_avg:87.13ms +step:253/1680 train_time:22048ms step_avg:87.15ms +step:254/1680 train_time:22136ms step_avg:87.15ms +step:255/1680 train_time:22224ms step_avg:87.15ms 
+step:256/1680 train_time:22311ms step_avg:87.15ms +step:257/1680 train_time:22397ms step_avg:87.15ms +step:258/1680 train_time:22483ms step_avg:87.14ms +step:259/1680 train_time:22569ms step_avg:87.14ms +step:260/1680 train_time:22655ms step_avg:87.14ms +step:261/1680 train_time:22741ms step_avg:87.13ms +step:262/1680 train_time:22828ms step_avg:87.13ms +step:263/1680 train_time:22917ms step_avg:87.14ms +step:264/1680 train_time:23007ms step_avg:87.15ms +step:265/1680 train_time:23095ms step_avg:87.15ms +step:266/1680 train_time:23183ms step_avg:87.15ms +step:267/1680 train_time:23270ms step_avg:87.15ms +step:268/1680 train_time:23357ms step_avg:87.15ms +step:269/1680 train_time:23443ms step_avg:87.15ms +step:270/1680 train_time:23530ms step_avg:87.15ms +step:271/1680 train_time:23617ms step_avg:87.15ms +step:272/1680 train_time:23703ms step_avg:87.14ms +step:273/1680 train_time:23790ms step_avg:87.14ms +step:274/1680 train_time:23877ms step_avg:87.14ms +step:275/1680 train_time:23965ms step_avg:87.14ms +step:276/1680 train_time:24053ms step_avg:87.15ms +step:277/1680 train_time:24141ms step_avg:87.15ms +step:278/1680 train_time:24228ms step_avg:87.15ms +step:279/1680 train_time:24315ms step_avg:87.15ms +step:280/1680 train_time:24402ms step_avg:87.15ms +step:281/1680 train_time:24489ms step_avg:87.15ms +step:282/1680 train_time:24575ms step_avg:87.14ms +step:283/1680 train_time:24661ms step_avg:87.14ms +step:284/1680 train_time:24748ms step_avg:87.14ms +step:285/1680 train_time:24835ms step_avg:87.14ms +step:286/1680 train_time:24922ms step_avg:87.14ms +step:287/1680 train_time:25010ms step_avg:87.14ms +step:288/1680 train_time:25097ms step_avg:87.14ms +step:289/1680 train_time:25184ms step_avg:87.14ms +step:290/1680 train_time:25272ms step_avg:87.14ms +step:291/1680 train_time:25359ms step_avg:87.14ms +step:292/1680 train_time:25446ms step_avg:87.14ms +step:293/1680 train_time:25533ms step_avg:87.15ms +step:294/1680 train_time:25621ms step_avg:87.14ms +step:295/1680 train_time:25707ms step_avg:87.14ms +step:296/1680 train_time:25795ms step_avg:87.14ms +step:297/1680 train_time:25881ms step_avg:87.14ms +step:298/1680 train_time:25968ms step_avg:87.14ms +step:299/1680 train_time:26056ms step_avg:87.15ms +step:300/1680 train_time:26144ms step_avg:87.15ms +step:301/1680 train_time:26232ms step_avg:87.15ms +step:302/1680 train_time:26319ms step_avg:87.15ms +step:303/1680 train_time:26406ms step_avg:87.15ms +step:304/1680 train_time:26493ms step_avg:87.15ms +step:305/1680 train_time:26579ms step_avg:87.14ms +step:306/1680 train_time:26666ms step_avg:87.14ms +step:307/1680 train_time:26753ms step_avg:87.14ms +step:308/1680 train_time:26840ms step_avg:87.14ms +step:309/1680 train_time:26927ms step_avg:87.14ms +step:310/1680 train_time:27015ms step_avg:87.14ms +step:311/1680 train_time:27101ms step_avg:87.14ms +step:312/1680 train_time:27188ms step_avg:87.14ms +step:313/1680 train_time:27277ms step_avg:87.15ms +step:314/1680 train_time:27364ms step_avg:87.15ms +step:315/1680 train_time:27451ms step_avg:87.15ms +step:316/1680 train_time:27538ms step_avg:87.15ms +step:317/1680 train_time:27625ms step_avg:87.15ms +step:318/1680 train_time:27712ms step_avg:87.14ms +step:319/1680 train_time:27799ms step_avg:87.14ms +step:320/1680 train_time:27886ms step_avg:87.14ms +step:321/1680 train_time:27974ms step_avg:87.15ms +step:322/1680 train_time:28060ms step_avg:87.14ms +step:323/1680 train_time:28147ms step_avg:87.14ms +step:324/1680 train_time:28234ms step_avg:87.14ms +step:325/1680 train_time:28322ms 
step_avg:87.14ms +step:326/1680 train_time:28409ms step_avg:87.14ms +step:327/1680 train_time:28496ms step_avg:87.14ms +step:328/1680 train_time:28582ms step_avg:87.14ms +step:329/1680 train_time:28670ms step_avg:87.14ms +step:330/1680 train_time:28757ms step_avg:87.14ms +step:331/1680 train_time:28844ms step_avg:87.14ms +step:332/1680 train_time:28930ms step_avg:87.14ms +step:333/1680 train_time:29017ms step_avg:87.14ms +step:334/1680 train_time:29104ms step_avg:87.14ms +step:335/1680 train_time:29191ms step_avg:87.14ms +step:336/1680 train_time:29278ms step_avg:87.14ms +step:337/1680 train_time:29365ms step_avg:87.14ms +step:338/1680 train_time:29452ms step_avg:87.14ms +step:339/1680 train_time:29539ms step_avg:87.14ms +step:340/1680 train_time:29626ms step_avg:87.14ms +step:341/1680 train_time:29714ms step_avg:87.14ms +step:342/1680 train_time:29800ms step_avg:87.14ms +step:343/1680 train_time:29888ms step_avg:87.14ms +step:344/1680 train_time:29975ms step_avg:87.14ms +step:345/1680 train_time:30062ms step_avg:87.14ms +step:346/1680 train_time:30149ms step_avg:87.14ms +step:347/1680 train_time:30237ms step_avg:87.14ms +step:348/1680 train_time:30324ms step_avg:87.14ms +step:349/1680 train_time:30412ms step_avg:87.14ms +step:350/1680 train_time:30499ms step_avg:87.14ms +step:351/1680 train_time:30585ms step_avg:87.14ms +step:352/1680 train_time:30672ms step_avg:87.14ms +step:353/1680 train_time:30759ms step_avg:87.14ms +step:354/1680 train_time:30846ms step_avg:87.13ms +step:355/1680 train_time:30932ms step_avg:87.13ms +step:356/1680 train_time:31019ms step_avg:87.13ms +step:357/1680 train_time:31106ms step_avg:87.13ms +step:358/1680 train_time:31193ms step_avg:87.13ms +step:359/1680 train_time:31280ms step_avg:87.13ms +step:360/1680 train_time:31367ms step_avg:87.13ms +step:361/1680 train_time:31455ms step_avg:87.13ms +step:362/1680 train_time:31542ms step_avg:87.13ms +step:363/1680 train_time:31630ms step_avg:87.13ms +step:364/1680 train_time:31718ms step_avg:87.14ms +step:365/1680 train_time:31805ms step_avg:87.14ms +step:366/1680 train_time:31891ms step_avg:87.13ms +step:367/1680 train_time:31978ms step_avg:87.13ms +step:368/1680 train_time:32065ms step_avg:87.13ms +step:369/1680 train_time:32152ms step_avg:87.13ms +step:370/1680 train_time:32240ms step_avg:87.14ms +step:371/1680 train_time:32327ms step_avg:87.13ms +step:372/1680 train_time:32414ms step_avg:87.13ms +step:373/1680 train_time:32500ms step_avg:87.13ms +step:374/1680 train_time:32587ms step_avg:87.13ms +step:375/1680 train_time:32675ms step_avg:87.13ms +step:375/1680 val_loss:3.8154 train_time:32764ms step_avg:87.37ms +step:376/1680 train_time:32783ms step_avg:87.19ms +step:377/1680 train_time:32852ms step_avg:87.14ms +step:378/1680 train_time:32944ms step_avg:87.15ms +step:379/1680 train_time:33033ms step_avg:87.16ms +step:380/1680 train_time:33120ms step_avg:87.16ms +step:381/1680 train_time:33206ms step_avg:87.15ms +step:382/1680 train_time:33292ms step_avg:87.15ms +step:383/1680 train_time:33378ms step_avg:87.15ms +step:384/1680 train_time:33464ms step_avg:87.15ms +step:385/1680 train_time:33551ms step_avg:87.14ms +step:386/1680 train_time:33636ms step_avg:87.14ms +step:387/1680 train_time:33724ms step_avg:87.14ms +step:388/1680 train_time:33812ms step_avg:87.14ms +step:389/1680 train_time:33901ms step_avg:87.15ms +step:390/1680 train_time:33989ms step_avg:87.15ms +step:391/1680 train_time:34077ms step_avg:87.15ms +step:392/1680 train_time:34164ms step_avg:87.15ms +step:393/1680 train_time:34251ms step_avg:87.15ms 
+step:394/1680 train_time:34337ms step_avg:87.15ms +step:395/1680 train_time:34423ms step_avg:87.15ms +step:396/1680 train_time:34509ms step_avg:87.14ms +step:397/1680 train_time:34595ms step_avg:87.14ms +step:398/1680 train_time:34682ms step_avg:87.14ms +step:399/1680 train_time:34769ms step_avg:87.14ms +step:400/1680 train_time:34857ms step_avg:87.14ms +step:401/1680 train_time:34945ms step_avg:87.15ms +step:402/1680 train_time:35033ms step_avg:87.15ms +step:403/1680 train_time:35121ms step_avg:87.15ms +step:404/1680 train_time:35208ms step_avg:87.15ms +step:405/1680 train_time:35294ms step_avg:87.15ms +step:406/1680 train_time:35381ms step_avg:87.15ms +step:407/1680 train_time:35467ms step_avg:87.14ms +step:408/1680 train_time:35554ms step_avg:87.14ms +step:409/1680 train_time:35640ms step_avg:87.14ms +step:410/1680 train_time:35727ms step_avg:87.14ms +step:411/1680 train_time:35815ms step_avg:87.14ms +step:412/1680 train_time:35903ms step_avg:87.14ms +step:413/1680 train_time:35990ms step_avg:87.14ms +step:414/1680 train_time:36079ms step_avg:87.15ms +step:415/1680 train_time:36166ms step_avg:87.15ms +step:416/1680 train_time:36253ms step_avg:87.15ms +step:417/1680 train_time:36340ms step_avg:87.15ms +step:418/1680 train_time:36427ms step_avg:87.15ms +step:419/1680 train_time:36514ms step_avg:87.14ms +step:420/1680 train_time:36600ms step_avg:87.14ms +step:421/1680 train_time:36687ms step_avg:87.14ms +step:422/1680 train_time:36774ms step_avg:87.14ms +step:423/1680 train_time:36861ms step_avg:87.14ms +step:424/1680 train_time:36949ms step_avg:87.14ms +step:425/1680 train_time:37036ms step_avg:87.14ms +step:426/1680 train_time:37123ms step_avg:87.14ms +step:427/1680 train_time:37211ms step_avg:87.15ms +step:428/1680 train_time:37298ms step_avg:87.15ms +step:429/1680 train_time:37385ms step_avg:87.14ms +step:430/1680 train_time:37472ms step_avg:87.14ms +step:431/1680 train_time:37559ms step_avg:87.14ms +step:432/1680 train_time:37645ms step_avg:87.14ms +step:433/1680 train_time:37732ms step_avg:87.14ms +step:434/1680 train_time:37819ms step_avg:87.14ms +step:435/1680 train_time:37906ms step_avg:87.14ms +step:436/1680 train_time:37994ms step_avg:87.14ms +step:437/1680 train_time:38081ms step_avg:87.14ms +step:438/1680 train_time:38168ms step_avg:87.14ms +step:439/1680 train_time:38255ms step_avg:87.14ms +step:440/1680 train_time:38342ms step_avg:87.14ms +step:441/1680 train_time:38429ms step_avg:87.14ms +step:442/1680 train_time:38517ms step_avg:87.14ms +step:443/1680 train_time:38603ms step_avg:87.14ms +step:444/1680 train_time:38690ms step_avg:87.14ms +step:445/1680 train_time:38778ms step_avg:87.14ms +step:446/1680 train_time:38865ms step_avg:87.14ms +step:447/1680 train_time:38953ms step_avg:87.14ms +step:448/1680 train_time:39040ms step_avg:87.14ms +step:449/1680 train_time:39127ms step_avg:87.14ms +step:450/1680 train_time:39215ms step_avg:87.14ms +step:451/1680 train_time:39302ms step_avg:87.14ms +step:452/1680 train_time:39389ms step_avg:87.14ms +step:453/1680 train_time:39477ms step_avg:87.15ms +step:454/1680 train_time:39564ms step_avg:87.14ms +step:455/1680 train_time:39651ms step_avg:87.14ms +step:456/1680 train_time:39737ms step_avg:87.14ms +step:457/1680 train_time:39824ms step_avg:87.14ms +step:458/1680 train_time:39911ms step_avg:87.14ms +step:459/1680 train_time:39998ms step_avg:87.14ms +step:460/1680 train_time:40085ms step_avg:87.14ms +step:461/1680 train_time:40172ms step_avg:87.14ms +step:462/1680 train_time:40259ms step_avg:87.14ms +step:463/1680 train_time:40346ms 
step_avg:87.14ms +step:464/1680 train_time:40433ms step_avg:87.14ms +step:465/1680 train_time:40520ms step_avg:87.14ms +step:466/1680 train_time:40607ms step_avg:87.14ms +step:467/1680 train_time:40694ms step_avg:87.14ms +step:468/1680 train_time:40782ms step_avg:87.14ms +step:469/1680 train_time:40868ms step_avg:87.14ms +step:470/1680 train_time:40955ms step_avg:87.14ms +step:471/1680 train_time:41042ms step_avg:87.14ms +step:472/1680 train_time:41129ms step_avg:87.14ms +step:473/1680 train_time:41217ms step_avg:87.14ms +step:474/1680 train_time:41304ms step_avg:87.14ms +step:475/1680 train_time:41391ms step_avg:87.14ms +step:476/1680 train_time:41478ms step_avg:87.14ms +step:477/1680 train_time:41565ms step_avg:87.14ms +step:478/1680 train_time:41652ms step_avg:87.14ms +step:479/1680 train_time:41739ms step_avg:87.14ms +step:480/1680 train_time:41826ms step_avg:87.14ms +step:481/1680 train_time:41913ms step_avg:87.14ms +step:482/1680 train_time:42000ms step_avg:87.14ms +step:483/1680 train_time:42087ms step_avg:87.14ms +step:484/1680 train_time:42174ms step_avg:87.14ms +step:485/1680 train_time:42261ms step_avg:87.14ms +step:486/1680 train_time:42348ms step_avg:87.14ms +step:487/1680 train_time:42435ms step_avg:87.13ms +step:488/1680 train_time:42522ms step_avg:87.13ms +step:489/1680 train_time:42608ms step_avg:87.13ms +step:490/1680 train_time:42696ms step_avg:87.13ms +step:491/1680 train_time:42783ms step_avg:87.14ms +step:492/1680 train_time:42870ms step_avg:87.13ms +step:493/1680 train_time:42957ms step_avg:87.13ms +step:494/1680 train_time:43043ms step_avg:87.13ms +step:495/1680 train_time:43130ms step_avg:87.13ms +step:496/1680 train_time:43218ms step_avg:87.13ms +step:497/1680 train_time:43305ms step_avg:87.13ms +step:498/1680 train_time:43392ms step_avg:87.13ms +step:499/1680 train_time:43479ms step_avg:87.13ms +step:500/1680 train_time:43566ms step_avg:87.13ms +step:500/1680 val_loss:3.7161 train_time:43655ms step_avg:87.31ms +step:501/1680 train_time:43674ms step_avg:87.17ms +step:502/1680 train_time:43746ms step_avg:87.14ms +step:503/1680 train_time:43836ms step_avg:87.15ms +step:504/1680 train_time:43924ms step_avg:87.15ms +step:505/1680 train_time:44010ms step_avg:87.15ms +step:506/1680 train_time:44097ms step_avg:87.15ms +step:507/1680 train_time:44183ms step_avg:87.15ms +step:508/1680 train_time:44269ms step_avg:87.14ms +step:509/1680 train_time:44355ms step_avg:87.14ms +step:510/1680 train_time:44441ms step_avg:87.14ms +step:511/1680 train_time:44528ms step_avg:87.14ms +step:512/1680 train_time:44614ms step_avg:87.14ms +step:513/1680 train_time:44703ms step_avg:87.14ms +step:514/1680 train_time:44791ms step_avg:87.14ms +step:515/1680 train_time:44881ms step_avg:87.15ms +step:516/1680 train_time:44968ms step_avg:87.15ms +step:517/1680 train_time:45055ms step_avg:87.15ms +step:518/1680 train_time:45142ms step_avg:87.15ms +step:519/1680 train_time:45228ms step_avg:87.14ms +step:520/1680 train_time:45314ms step_avg:87.14ms +step:521/1680 train_time:45401ms step_avg:87.14ms +step:522/1680 train_time:45488ms step_avg:87.14ms +step:523/1680 train_time:45574ms step_avg:87.14ms +step:524/1680 train_time:45661ms step_avg:87.14ms +step:525/1680 train_time:45749ms step_avg:87.14ms +step:526/1680 train_time:45837ms step_avg:87.14ms +step:527/1680 train_time:45924ms step_avg:87.14ms +step:528/1680 train_time:46012ms step_avg:87.14ms +step:529/1680 train_time:46099ms step_avg:87.14ms +step:530/1680 train_time:46186ms step_avg:87.14ms +step:531/1680 train_time:46273ms step_avg:87.14ms 
+step:532/1680 train_time:46360ms step_avg:87.14ms +step:533/1680 train_time:46446ms step_avg:87.14ms +step:534/1680 train_time:46532ms step_avg:87.14ms +step:535/1680 train_time:46619ms step_avg:87.14ms +step:536/1680 train_time:46706ms step_avg:87.14ms +step:537/1680 train_time:46793ms step_avg:87.14ms +step:538/1680 train_time:46882ms step_avg:87.14ms +step:539/1680 train_time:46969ms step_avg:87.14ms +step:540/1680 train_time:47057ms step_avg:87.14ms +step:541/1680 train_time:47145ms step_avg:87.14ms +step:542/1680 train_time:47232ms step_avg:87.14ms +step:543/1680 train_time:47319ms step_avg:87.14ms +step:544/1680 train_time:47405ms step_avg:87.14ms +step:545/1680 train_time:47492ms step_avg:87.14ms +step:546/1680 train_time:47579ms step_avg:87.14ms +step:547/1680 train_time:47667ms step_avg:87.14ms +step:548/1680 train_time:47754ms step_avg:87.14ms +step:549/1680 train_time:47842ms step_avg:87.14ms +step:550/1680 train_time:47930ms step_avg:87.15ms +step:551/1680 train_time:48019ms step_avg:87.15ms +step:552/1680 train_time:48107ms step_avg:87.15ms +step:553/1680 train_time:48196ms step_avg:87.15ms +step:554/1680 train_time:48284ms step_avg:87.15ms +step:555/1680 train_time:48373ms step_avg:87.16ms +step:556/1680 train_time:48461ms step_avg:87.16ms +step:557/1680 train_time:48549ms step_avg:87.16ms +step:558/1680 train_time:48637ms step_avg:87.16ms +step:559/1680 train_time:48725ms step_avg:87.16ms +step:560/1680 train_time:48813ms step_avg:87.17ms +step:561/1680 train_time:48901ms step_avg:87.17ms +step:562/1680 train_time:48990ms step_avg:87.17ms +step:563/1680 train_time:49078ms step_avg:87.17ms +step:564/1680 train_time:49167ms step_avg:87.18ms +step:565/1680 train_time:49256ms step_avg:87.18ms +step:566/1680 train_time:49345ms step_avg:87.18ms +step:567/1680 train_time:49433ms step_avg:87.18ms +step:568/1680 train_time:49521ms step_avg:87.18ms +step:569/1680 train_time:49608ms step_avg:87.18ms +step:570/1680 train_time:49696ms step_avg:87.19ms +step:571/1680 train_time:49784ms step_avg:87.19ms +step:572/1680 train_time:49872ms step_avg:87.19ms +step:573/1680 train_time:49960ms step_avg:87.19ms +step:574/1680 train_time:50048ms step_avg:87.19ms +step:575/1680 train_time:50136ms step_avg:87.19ms +step:576/1680 train_time:50225ms step_avg:87.20ms +step:577/1680 train_time:50314ms step_avg:87.20ms +step:578/1680 train_time:50402ms step_avg:87.20ms +step:579/1680 train_time:50490ms step_avg:87.20ms +step:580/1680 train_time:50579ms step_avg:87.20ms +step:581/1680 train_time:50666ms step_avg:87.21ms +step:582/1680 train_time:50755ms step_avg:87.21ms +step:583/1680 train_time:50845ms step_avg:87.21ms +step:584/1680 train_time:50933ms step_avg:87.21ms +step:585/1680 train_time:51022ms step_avg:87.22ms +step:586/1680 train_time:51110ms step_avg:87.22ms +step:587/1680 train_time:51197ms step_avg:87.22ms +step:588/1680 train_time:51286ms step_avg:87.22ms +step:589/1680 train_time:51374ms step_avg:87.22ms +step:590/1680 train_time:51462ms step_avg:87.22ms +step:591/1680 train_time:51549ms step_avg:87.22ms +step:592/1680 train_time:51638ms step_avg:87.23ms +step:593/1680 train_time:51726ms step_avg:87.23ms +step:594/1680 train_time:51815ms step_avg:87.23ms +step:595/1680 train_time:51903ms step_avg:87.23ms +step:596/1680 train_time:51991ms step_avg:87.23ms +step:597/1680 train_time:52080ms step_avg:87.24ms +step:598/1680 train_time:52168ms step_avg:87.24ms +step:599/1680 train_time:52256ms step_avg:87.24ms +step:600/1680 train_time:52345ms step_avg:87.24ms +step:601/1680 train_time:52432ms 
step_avg:87.24ms +step:602/1680 train_time:52520ms step_avg:87.24ms +step:603/1680 train_time:52609ms step_avg:87.25ms +step:604/1680 train_time:52697ms step_avg:87.25ms +step:605/1680 train_time:52785ms step_avg:87.25ms +step:606/1680 train_time:52874ms step_avg:87.25ms +step:607/1680 train_time:52962ms step_avg:87.25ms +step:608/1680 train_time:53050ms step_avg:87.25ms +step:609/1680 train_time:53139ms step_avg:87.26ms +step:610/1680 train_time:53227ms step_avg:87.26ms +step:611/1680 train_time:53315ms step_avg:87.26ms +step:612/1680 train_time:53404ms step_avg:87.26ms +step:613/1680 train_time:53493ms step_avg:87.26ms +step:614/1680 train_time:53581ms step_avg:87.27ms +step:615/1680 train_time:53670ms step_avg:87.27ms +step:616/1680 train_time:53758ms step_avg:87.27ms +step:617/1680 train_time:53847ms step_avg:87.27ms +step:618/1680 train_time:53934ms step_avg:87.27ms +step:619/1680 train_time:54022ms step_avg:87.27ms +step:620/1680 train_time:54111ms step_avg:87.28ms +step:621/1680 train_time:54199ms step_avg:87.28ms +step:622/1680 train_time:54287ms step_avg:87.28ms +step:623/1680 train_time:54375ms step_avg:87.28ms +step:624/1680 train_time:54463ms step_avg:87.28ms +step:625/1680 train_time:54551ms step_avg:87.28ms +step:625/1680 val_loss:3.6159 train_time:54641ms step_avg:87.43ms +step:626/1680 train_time:54660ms step_avg:87.32ms +step:627/1680 train_time:54731ms step_avg:87.29ms +step:628/1680 train_time:54821ms step_avg:87.30ms +step:629/1680 train_time:54911ms step_avg:87.30ms +step:630/1680 train_time:55000ms step_avg:87.30ms +step:631/1680 train_time:55087ms step_avg:87.30ms +step:632/1680 train_time:55174ms step_avg:87.30ms +step:633/1680 train_time:55262ms step_avg:87.30ms +step:634/1680 train_time:55348ms step_avg:87.30ms +step:635/1680 train_time:55435ms step_avg:87.30ms +step:636/1680 train_time:55523ms step_avg:87.30ms +step:637/1680 train_time:55614ms step_avg:87.31ms +step:638/1680 train_time:55704ms step_avg:87.31ms +step:639/1680 train_time:55793ms step_avg:87.31ms +step:640/1680 train_time:55884ms step_avg:87.32ms +step:641/1680 train_time:55973ms step_avg:87.32ms +step:642/1680 train_time:56062ms step_avg:87.32ms +step:643/1680 train_time:56148ms step_avg:87.32ms +step:644/1680 train_time:56235ms step_avg:87.32ms +step:645/1680 train_time:56323ms step_avg:87.32ms +step:646/1680 train_time:56410ms step_avg:87.32ms +step:647/1680 train_time:56498ms step_avg:87.32ms +step:648/1680 train_time:56587ms step_avg:87.33ms +step:649/1680 train_time:56675ms step_avg:87.33ms +step:650/1680 train_time:56765ms step_avg:87.33ms +step:651/1680 train_time:56854ms step_avg:87.33ms +step:652/1680 train_time:56942ms step_avg:87.33ms +step:653/1680 train_time:57030ms step_avg:87.34ms +step:654/1680 train_time:57119ms step_avg:87.34ms +step:655/1680 train_time:57206ms step_avg:87.34ms +step:656/1680 train_time:57294ms step_avg:87.34ms +step:657/1680 train_time:57382ms step_avg:87.34ms +step:658/1680 train_time:57469ms step_avg:87.34ms +step:659/1680 train_time:57558ms step_avg:87.34ms +step:660/1680 train_time:57647ms step_avg:87.34ms +step:661/1680 train_time:57736ms step_avg:87.35ms +step:662/1680 train_time:57824ms step_avg:87.35ms +step:663/1680 train_time:57913ms step_avg:87.35ms +step:664/1680 train_time:58002ms step_avg:87.35ms +step:665/1680 train_time:58090ms step_avg:87.35ms +step:666/1680 train_time:58177ms step_avg:87.35ms +step:667/1680 train_time:58265ms step_avg:87.35ms +step:668/1680 train_time:58353ms step_avg:87.35ms +step:669/1680 train_time:58440ms step_avg:87.35ms 
+step:670/1680 train_time:58528ms step_avg:87.36ms +step:671/1680 train_time:58617ms step_avg:87.36ms +step:672/1680 train_time:58706ms step_avg:87.36ms +step:673/1680 train_time:58795ms step_avg:87.36ms +step:674/1680 train_time:58884ms step_avg:87.36ms +step:675/1680 train_time:58972ms step_avg:87.37ms +step:676/1680 train_time:59062ms step_avg:87.37ms +step:677/1680 train_time:59150ms step_avg:87.37ms +step:678/1680 train_time:59237ms step_avg:87.37ms +step:679/1680 train_time:59325ms step_avg:87.37ms +step:680/1680 train_time:59413ms step_avg:87.37ms +step:681/1680 train_time:59501ms step_avg:87.37ms +step:682/1680 train_time:59589ms step_avg:87.37ms +step:683/1680 train_time:59679ms step_avg:87.38ms +step:684/1680 train_time:59767ms step_avg:87.38ms +step:685/1680 train_time:59856ms step_avg:87.38ms +step:686/1680 train_time:59944ms step_avg:87.38ms +step:687/1680 train_time:60033ms step_avg:87.38ms +step:688/1680 train_time:60121ms step_avg:87.39ms +step:689/1680 train_time:60208ms step_avg:87.39ms +step:690/1680 train_time:60296ms step_avg:87.39ms +step:691/1680 train_time:60384ms step_avg:87.39ms +step:692/1680 train_time:60472ms step_avg:87.39ms +step:693/1680 train_time:60561ms step_avg:87.39ms +step:694/1680 train_time:60649ms step_avg:87.39ms +step:695/1680 train_time:60737ms step_avg:87.39ms +step:696/1680 train_time:60826ms step_avg:87.39ms +step:697/1680 train_time:60914ms step_avg:87.40ms +step:698/1680 train_time:61004ms step_avg:87.40ms +step:699/1680 train_time:61091ms step_avg:87.40ms +step:700/1680 train_time:61179ms step_avg:87.40ms +step:701/1680 train_time:61267ms step_avg:87.40ms +step:702/1680 train_time:61355ms step_avg:87.40ms +step:703/1680 train_time:61443ms step_avg:87.40ms +step:704/1680 train_time:61532ms step_avg:87.40ms +step:705/1680 train_time:61620ms step_avg:87.40ms +step:706/1680 train_time:61708ms step_avg:87.41ms +step:707/1680 train_time:61796ms step_avg:87.41ms +step:708/1680 train_time:61885ms step_avg:87.41ms +step:709/1680 train_time:61974ms step_avg:87.41ms +step:710/1680 train_time:62063ms step_avg:87.41ms +step:711/1680 train_time:62151ms step_avg:87.41ms +step:712/1680 train_time:62239ms step_avg:87.41ms +step:713/1680 train_time:62327ms step_avg:87.41ms +step:714/1680 train_time:62415ms step_avg:87.42ms +step:715/1680 train_time:62503ms step_avg:87.42ms +step:716/1680 train_time:62591ms step_avg:87.42ms +step:717/1680 train_time:62679ms step_avg:87.42ms +step:718/1680 train_time:62767ms step_avg:87.42ms +step:719/1680 train_time:62855ms step_avg:87.42ms +step:720/1680 train_time:62943ms step_avg:87.42ms +step:721/1680 train_time:63032ms step_avg:87.42ms +step:722/1680 train_time:63120ms step_avg:87.42ms +step:723/1680 train_time:63208ms step_avg:87.42ms +step:724/1680 train_time:63296ms step_avg:87.43ms +step:725/1680 train_time:63384ms step_avg:87.43ms +step:726/1680 train_time:63473ms step_avg:87.43ms +step:727/1680 train_time:63562ms step_avg:87.43ms +step:728/1680 train_time:63649ms step_avg:87.43ms +step:729/1680 train_time:63737ms step_avg:87.43ms +step:730/1680 train_time:63825ms step_avg:87.43ms +step:731/1680 train_time:63914ms step_avg:87.43ms +step:732/1680 train_time:64003ms step_avg:87.44ms +step:733/1680 train_time:64090ms step_avg:87.44ms +step:734/1680 train_time:64178ms step_avg:87.44ms +step:735/1680 train_time:64266ms step_avg:87.44ms +step:736/1680 train_time:64354ms step_avg:87.44ms +step:737/1680 train_time:64443ms step_avg:87.44ms +step:738/1680 train_time:64531ms step_avg:87.44ms +step:739/1680 train_time:64619ms 
step_avg:87.44ms +step:740/1680 train_time:64707ms step_avg:87.44ms +step:741/1680 train_time:64795ms step_avg:87.44ms +step:742/1680 train_time:64883ms step_avg:87.44ms +step:743/1680 train_time:64972ms step_avg:87.44ms +step:744/1680 train_time:65060ms step_avg:87.45ms +step:745/1680 train_time:65147ms step_avg:87.45ms +step:746/1680 train_time:65236ms step_avg:87.45ms +step:747/1680 train_time:65323ms step_avg:87.45ms +step:748/1680 train_time:65412ms step_avg:87.45ms +step:749/1680 train_time:65499ms step_avg:87.45ms +step:750/1680 train_time:65588ms step_avg:87.45ms +step:750/1680 val_loss:3.5655 train_time:65678ms step_avg:87.57ms +step:751/1680 train_time:65697ms step_avg:87.48ms +step:752/1680 train_time:65769ms step_avg:87.46ms +step:753/1680 train_time:65862ms step_avg:87.47ms +step:754/1680 train_time:65951ms step_avg:87.47ms +step:755/1680 train_time:66038ms step_avg:87.47ms +step:756/1680 train_time:66125ms step_avg:87.47ms +step:757/1680 train_time:66212ms step_avg:87.47ms +step:758/1680 train_time:66300ms step_avg:87.47ms +step:759/1680 train_time:66387ms step_avg:87.47ms +step:760/1680 train_time:66476ms step_avg:87.47ms +step:761/1680 train_time:66563ms step_avg:87.47ms +step:762/1680 train_time:66652ms step_avg:87.47ms +step:763/1680 train_time:66741ms step_avg:87.47ms +step:764/1680 train_time:66832ms step_avg:87.48ms +step:765/1680 train_time:66921ms step_avg:87.48ms +step:766/1680 train_time:67009ms step_avg:87.48ms +step:767/1680 train_time:67096ms step_avg:87.48ms +step:768/1680 train_time:67184ms step_avg:87.48ms +step:769/1680 train_time:67271ms step_avg:87.48ms +step:770/1680 train_time:67359ms step_avg:87.48ms +step:771/1680 train_time:67447ms step_avg:87.48ms +step:772/1680 train_time:67535ms step_avg:87.48ms +step:773/1680 train_time:67623ms step_avg:87.48ms +step:774/1680 train_time:67713ms step_avg:87.48ms +step:775/1680 train_time:67802ms step_avg:87.49ms +step:776/1680 train_time:67891ms step_avg:87.49ms +step:777/1680 train_time:67980ms step_avg:87.49ms +step:778/1680 train_time:68068ms step_avg:87.49ms +step:779/1680 train_time:68156ms step_avg:87.49ms +step:780/1680 train_time:68243ms step_avg:87.49ms +step:781/1680 train_time:68331ms step_avg:87.49ms +step:782/1680 train_time:68419ms step_avg:87.49ms +step:783/1680 train_time:68507ms step_avg:87.49ms +step:784/1680 train_time:68594ms step_avg:87.49ms +step:785/1680 train_time:68683ms step_avg:87.49ms +step:786/1680 train_time:68771ms step_avg:87.50ms +step:787/1680 train_time:68860ms step_avg:87.50ms +step:788/1680 train_time:68948ms step_avg:87.50ms +step:789/1680 train_time:69037ms step_avg:87.50ms +step:790/1680 train_time:69125ms step_avg:87.50ms +step:791/1680 train_time:69214ms step_avg:87.50ms +step:792/1680 train_time:69301ms step_avg:87.50ms +step:793/1680 train_time:69389ms step_avg:87.50ms +step:794/1680 train_time:69478ms step_avg:87.50ms +step:795/1680 train_time:69566ms step_avg:87.50ms +step:796/1680 train_time:69654ms step_avg:87.51ms +step:797/1680 train_time:69743ms step_avg:87.51ms +step:798/1680 train_time:69831ms step_avg:87.51ms +step:799/1680 train_time:69919ms step_avg:87.51ms +step:800/1680 train_time:70007ms step_avg:87.51ms +step:801/1680 train_time:70095ms step_avg:87.51ms +step:802/1680 train_time:70183ms step_avg:87.51ms +step:803/1680 train_time:70271ms step_avg:87.51ms +step:804/1680 train_time:70359ms step_avg:87.51ms +step:805/1680 train_time:70446ms step_avg:87.51ms +step:806/1680 train_time:70534ms step_avg:87.51ms +step:807/1680 train_time:70623ms step_avg:87.51ms 
+step:808/1680 train_time:70711ms step_avg:87.51ms +step:809/1680 train_time:70800ms step_avg:87.52ms +step:810/1680 train_time:70888ms step_avg:87.52ms +step:811/1680 train_time:70976ms step_avg:87.52ms +step:812/1680 train_time:71064ms step_avg:87.52ms +step:813/1680 train_time:71154ms step_avg:87.52ms +step:814/1680 train_time:71242ms step_avg:87.52ms +step:815/1680 train_time:71330ms step_avg:87.52ms +step:816/1680 train_time:71418ms step_avg:87.52ms +step:817/1680 train_time:71506ms step_avg:87.52ms +step:818/1680 train_time:71594ms step_avg:87.52ms +step:819/1680 train_time:71682ms step_avg:87.52ms +step:820/1680 train_time:71770ms step_avg:87.52ms +step:821/1680 train_time:71859ms step_avg:87.53ms +step:822/1680 train_time:71947ms step_avg:87.53ms +step:823/1680 train_time:72035ms step_avg:87.53ms +step:824/1680 train_time:72124ms step_avg:87.53ms +step:825/1680 train_time:72213ms step_avg:87.53ms +step:826/1680 train_time:72302ms step_avg:87.53ms +step:827/1680 train_time:72389ms step_avg:87.53ms +step:828/1680 train_time:72477ms step_avg:87.53ms +step:829/1680 train_time:72565ms step_avg:87.53ms +step:830/1680 train_time:72653ms step_avg:87.53ms +step:831/1680 train_time:72742ms step_avg:87.54ms +step:832/1680 train_time:72830ms step_avg:87.54ms +step:833/1680 train_time:72918ms step_avg:87.54ms +step:834/1680 train_time:73006ms step_avg:87.54ms +step:835/1680 train_time:73094ms step_avg:87.54ms +step:836/1680 train_time:73182ms step_avg:87.54ms +step:837/1680 train_time:73271ms step_avg:87.54ms +step:838/1680 train_time:73360ms step_avg:87.54ms +step:839/1680 train_time:73448ms step_avg:87.54ms +step:840/1680 train_time:73536ms step_avg:87.54ms +step:841/1680 train_time:73623ms step_avg:87.54ms +step:842/1680 train_time:73712ms step_avg:87.54ms +step:843/1680 train_time:73800ms step_avg:87.54ms +step:844/1680 train_time:73888ms step_avg:87.55ms +step:845/1680 train_time:73977ms step_avg:87.55ms +step:846/1680 train_time:74065ms step_avg:87.55ms +step:847/1680 train_time:74153ms step_avg:87.55ms +step:848/1680 train_time:74242ms step_avg:87.55ms +step:849/1680 train_time:74330ms step_avg:87.55ms +step:850/1680 train_time:74419ms step_avg:87.55ms +step:851/1680 train_time:74507ms step_avg:87.55ms +step:852/1680 train_time:74595ms step_avg:87.55ms +step:853/1680 train_time:74683ms step_avg:87.55ms +step:854/1680 train_time:74771ms step_avg:87.55ms +step:855/1680 train_time:74860ms step_avg:87.56ms +step:856/1680 train_time:74948ms step_avg:87.56ms +step:857/1680 train_time:75037ms step_avg:87.56ms +step:858/1680 train_time:75125ms step_avg:87.56ms +step:859/1680 train_time:75213ms step_avg:87.56ms +step:860/1680 train_time:75302ms step_avg:87.56ms +step:861/1680 train_time:75390ms step_avg:87.56ms +step:862/1680 train_time:75478ms step_avg:87.56ms +step:863/1680 train_time:75567ms step_avg:87.56ms +step:864/1680 train_time:75655ms step_avg:87.56ms +step:865/1680 train_time:75744ms step_avg:87.56ms +step:866/1680 train_time:75832ms step_avg:87.57ms +step:867/1680 train_time:75920ms step_avg:87.57ms +step:868/1680 train_time:76008ms step_avg:87.57ms +step:869/1680 train_time:76096ms step_avg:87.57ms +step:870/1680 train_time:76184ms step_avg:87.57ms +step:871/1680 train_time:76272ms step_avg:87.57ms +step:872/1680 train_time:76360ms step_avg:87.57ms +step:873/1680 train_time:76449ms step_avg:87.57ms +step:874/1680 train_time:76537ms step_avg:87.57ms +step:875/1680 train_time:76625ms step_avg:87.57ms +step:875/1680 val_loss:3.5193 train_time:76714ms step_avg:87.67ms +step:876/1680 
train_time:76733ms step_avg:87.59ms +step:877/1680 train_time:76805ms step_avg:87.58ms +step:878/1680 train_time:76896ms step_avg:87.58ms +step:879/1680 train_time:76987ms step_avg:87.58ms +step:880/1680 train_time:77075ms step_avg:87.59ms +step:881/1680 train_time:77162ms step_avg:87.58ms +step:882/1680 train_time:77249ms step_avg:87.58ms +step:883/1680 train_time:77336ms step_avg:87.58ms +step:884/1680 train_time:77423ms step_avg:87.58ms +step:885/1680 train_time:77510ms step_avg:87.58ms +step:886/1680 train_time:77598ms step_avg:87.58ms +step:887/1680 train_time:77687ms step_avg:87.58ms +step:888/1680 train_time:77778ms step_avg:87.59ms +step:889/1680 train_time:77868ms step_avg:87.59ms +step:890/1680 train_time:77957ms step_avg:87.59ms +step:891/1680 train_time:78046ms step_avg:87.59ms +step:892/1680 train_time:78134ms step_avg:87.59ms +step:893/1680 train_time:78221ms step_avg:87.59ms +step:894/1680 train_time:78309ms step_avg:87.59ms +step:895/1680 train_time:78396ms step_avg:87.59ms +step:896/1680 train_time:78483ms step_avg:87.59ms +step:897/1680 train_time:78571ms step_avg:87.59ms +step:898/1680 train_time:78659ms step_avg:87.59ms +step:899/1680 train_time:78748ms step_avg:87.60ms +step:900/1680 train_time:78837ms step_avg:87.60ms +step:901/1680 train_time:78927ms step_avg:87.60ms +step:902/1680 train_time:79016ms step_avg:87.60ms +step:903/1680 train_time:79104ms step_avg:87.60ms +step:904/1680 train_time:79192ms step_avg:87.60ms +step:905/1680 train_time:79280ms step_avg:87.60ms +step:906/1680 train_time:79368ms step_avg:87.60ms +step:907/1680 train_time:79455ms step_avg:87.60ms +step:908/1680 train_time:79543ms step_avg:87.60ms +step:909/1680 train_time:79630ms step_avg:87.60ms +step:910/1680 train_time:79719ms step_avg:87.60ms +step:911/1680 train_time:79809ms step_avg:87.61ms +step:912/1680 train_time:79897ms step_avg:87.61ms +step:913/1680 train_time:79986ms step_avg:87.61ms +step:914/1680 train_time:80074ms step_avg:87.61ms +step:915/1680 train_time:80162ms step_avg:87.61ms +step:916/1680 train_time:80249ms step_avg:87.61ms +step:917/1680 train_time:80337ms step_avg:87.61ms +step:918/1680 train_time:80425ms step_avg:87.61ms +step:919/1680 train_time:80514ms step_avg:87.61ms +step:920/1680 train_time:80601ms step_avg:87.61ms +step:921/1680 train_time:80689ms step_avg:87.61ms +step:922/1680 train_time:80778ms step_avg:87.61ms +step:923/1680 train_time:80867ms step_avg:87.61ms +step:924/1680 train_time:80956ms step_avg:87.61ms +step:925/1680 train_time:81044ms step_avg:87.62ms +step:926/1680 train_time:81133ms step_avg:87.62ms +step:927/1680 train_time:81221ms step_avg:87.62ms +step:928/1680 train_time:81309ms step_avg:87.62ms +step:929/1680 train_time:81397ms step_avg:87.62ms +step:930/1680 train_time:81485ms step_avg:87.62ms +step:931/1680 train_time:81574ms step_avg:87.62ms +step:932/1680 train_time:81663ms step_avg:87.62ms +step:933/1680 train_time:81751ms step_avg:87.62ms +step:934/1680 train_time:81839ms step_avg:87.62ms +step:935/1680 train_time:81929ms step_avg:87.62ms +step:936/1680 train_time:82018ms step_avg:87.63ms +step:937/1680 train_time:82107ms step_avg:87.63ms +step:938/1680 train_time:82195ms step_avg:87.63ms +step:939/1680 train_time:82283ms step_avg:87.63ms +step:940/1680 train_time:82371ms step_avg:87.63ms +step:941/1680 train_time:82458ms step_avg:87.63ms +step:942/1680 train_time:82547ms step_avg:87.63ms +step:943/1680 train_time:82635ms step_avg:87.63ms +step:944/1680 train_time:82723ms step_avg:87.63ms +step:945/1680 train_time:82812ms step_avg:87.63ms 
+step:946/1680 train_time:82900ms step_avg:87.63ms +step:947/1680 train_time:82990ms step_avg:87.63ms +step:948/1680 train_time:83078ms step_avg:87.63ms +step:949/1680 train_time:83166ms step_avg:87.63ms +step:950/1680 train_time:83253ms step_avg:87.63ms +step:951/1680 train_time:83341ms step_avg:87.64ms +step:952/1680 train_time:83429ms step_avg:87.64ms +step:953/1680 train_time:83517ms step_avg:87.64ms +step:954/1680 train_time:83605ms step_avg:87.64ms +step:955/1680 train_time:83693ms step_avg:87.64ms +step:956/1680 train_time:83782ms step_avg:87.64ms +step:957/1680 train_time:83870ms step_avg:87.64ms +step:958/1680 train_time:83958ms step_avg:87.64ms +step:959/1680 train_time:84047ms step_avg:87.64ms +step:960/1680 train_time:84135ms step_avg:87.64ms +step:961/1680 train_time:84223ms step_avg:87.64ms +step:962/1680 train_time:84311ms step_avg:87.64ms +step:963/1680 train_time:84399ms step_avg:87.64ms +step:964/1680 train_time:84487ms step_avg:87.64ms +step:965/1680 train_time:84576ms step_avg:87.64ms +step:966/1680 train_time:84664ms step_avg:87.64ms +step:967/1680 train_time:84751ms step_avg:87.64ms +step:968/1680 train_time:84839ms step_avg:87.64ms +step:969/1680 train_time:84927ms step_avg:87.64ms +step:970/1680 train_time:85016ms step_avg:87.65ms +step:971/1680 train_time:85104ms step_avg:87.65ms +step:972/1680 train_time:85193ms step_avg:87.65ms +step:973/1680 train_time:85282ms step_avg:87.65ms +step:974/1680 train_time:85370ms step_avg:87.65ms +step:975/1680 train_time:85458ms step_avg:87.65ms +step:976/1680 train_time:85546ms step_avg:87.65ms +step:977/1680 train_time:85634ms step_avg:87.65ms +step:978/1680 train_time:85722ms step_avg:87.65ms +step:979/1680 train_time:85811ms step_avg:87.65ms +step:980/1680 train_time:85898ms step_avg:87.65ms +step:981/1680 train_time:85986ms step_avg:87.65ms +step:982/1680 train_time:86075ms step_avg:87.65ms +step:983/1680 train_time:86164ms step_avg:87.65ms +step:984/1680 train_time:86253ms step_avg:87.66ms +step:985/1680 train_time:86340ms step_avg:87.66ms +step:986/1680 train_time:86428ms step_avg:87.66ms +step:987/1680 train_time:86516ms step_avg:87.66ms +step:988/1680 train_time:86604ms step_avg:87.66ms +step:989/1680 train_time:86692ms step_avg:87.66ms +step:990/1680 train_time:86780ms step_avg:87.66ms +step:991/1680 train_time:86869ms step_avg:87.66ms +step:992/1680 train_time:86956ms step_avg:87.66ms +step:993/1680 train_time:87045ms step_avg:87.66ms +step:994/1680 train_time:87133ms step_avg:87.66ms +step:995/1680 train_time:87221ms step_avg:87.66ms +step:996/1680 train_time:87309ms step_avg:87.66ms +step:997/1680 train_time:87397ms step_avg:87.66ms +step:998/1680 train_time:87485ms step_avg:87.66ms +step:999/1680 train_time:87573ms step_avg:87.66ms +step:1000/1680 train_time:87660ms step_avg:87.66ms +step:1000/1680 val_loss:3.4679 train_time:87749ms step_avg:87.75ms +step:1001/1680 train_time:87768ms step_avg:87.68ms +step:1002/1680 train_time:87841ms step_avg:87.67ms +step:1003/1680 train_time:87932ms step_avg:87.67ms +step:1004/1680 train_time:88022ms step_avg:87.67ms +step:1005/1680 train_time:88110ms step_avg:87.67ms +step:1006/1680 train_time:88197ms step_avg:87.67ms +step:1007/1680 train_time:88284ms step_avg:87.67ms +step:1008/1680 train_time:88371ms step_avg:87.67ms +step:1009/1680 train_time:88458ms step_avg:87.67ms +step:1010/1680 train_time:88546ms step_avg:87.67ms +step:1011/1680 train_time:88633ms step_avg:87.67ms +step:1012/1680 train_time:88722ms step_avg:87.67ms +step:1013/1680 train_time:88812ms step_avg:87.67ms 
+step:1014/1680 train_time:88902ms step_avg:87.67ms +step:1015/1680 train_time:88991ms step_avg:87.68ms +step:1016/1680 train_time:89080ms step_avg:87.68ms +step:1017/1680 train_time:89168ms step_avg:87.68ms +step:1018/1680 train_time:89255ms step_avg:87.68ms +step:1019/1680 train_time:89343ms step_avg:87.68ms +step:1020/1680 train_time:89431ms step_avg:87.68ms +step:1021/1680 train_time:89518ms step_avg:87.68ms +step:1022/1680 train_time:89605ms step_avg:87.68ms +step:1023/1680 train_time:89693ms step_avg:87.68ms +step:1024/1680 train_time:89781ms step_avg:87.68ms +step:1025/1680 train_time:89871ms step_avg:87.68ms +step:1026/1680 train_time:89960ms step_avg:87.68ms +step:1027/1680 train_time:90050ms step_avg:87.68ms +step:1028/1680 train_time:90139ms step_avg:87.68ms +step:1029/1680 train_time:90227ms step_avg:87.68ms +step:1030/1680 train_time:90314ms step_avg:87.68ms +step:1031/1680 train_time:90402ms step_avg:87.68ms +step:1032/1680 train_time:90490ms step_avg:87.68ms +step:1033/1680 train_time:90577ms step_avg:87.68ms +step:1034/1680 train_time:90666ms step_avg:87.68ms +step:1035/1680 train_time:90754ms step_avg:87.68ms +step:1036/1680 train_time:90842ms step_avg:87.69ms +step:1037/1680 train_time:90931ms step_avg:87.69ms +step:1038/1680 train_time:91020ms step_avg:87.69ms +step:1039/1680 train_time:91109ms step_avg:87.69ms +step:1040/1680 train_time:91197ms step_avg:87.69ms +step:1041/1680 train_time:91285ms step_avg:87.69ms +step:1042/1680 train_time:91373ms step_avg:87.69ms +step:1043/1680 train_time:91461ms step_avg:87.69ms +step:1044/1680 train_time:91549ms step_avg:87.69ms +step:1045/1680 train_time:91638ms step_avg:87.69ms +step:1046/1680 train_time:91726ms step_avg:87.69ms +step:1047/1680 train_time:91814ms step_avg:87.69ms +step:1048/1680 train_time:91902ms step_avg:87.69ms +step:1049/1680 train_time:91990ms step_avg:87.69ms +step:1050/1680 train_time:92079ms step_avg:87.69ms +step:1051/1680 train_time:92168ms step_avg:87.70ms +step:1052/1680 train_time:92257ms step_avg:87.70ms +step:1053/1680 train_time:92344ms step_avg:87.70ms +step:1054/1680 train_time:92432ms step_avg:87.70ms +step:1055/1680 train_time:92520ms step_avg:87.70ms +step:1056/1680 train_time:92608ms step_avg:87.70ms +step:1057/1680 train_time:92697ms step_avg:87.70ms +step:1058/1680 train_time:92785ms step_avg:87.70ms +step:1059/1680 train_time:92873ms step_avg:87.70ms +step:1060/1680 train_time:92961ms step_avg:87.70ms +step:1061/1680 train_time:93051ms step_avg:87.70ms +step:1062/1680 train_time:93140ms step_avg:87.70ms +step:1063/1680 train_time:93229ms step_avg:87.70ms +step:1064/1680 train_time:93316ms step_avg:87.70ms +step:1065/1680 train_time:93404ms step_avg:87.70ms +step:1066/1680 train_time:93491ms step_avg:87.70ms +step:1067/1680 train_time:93579ms step_avg:87.70ms +step:1068/1680 train_time:93668ms step_avg:87.70ms +step:1069/1680 train_time:93756ms step_avg:87.70ms +step:1070/1680 train_time:93845ms step_avg:87.71ms +step:1071/1680 train_time:93933ms step_avg:87.71ms +step:1072/1680 train_time:94021ms step_avg:87.71ms +step:1073/1680 train_time:94110ms step_avg:87.71ms +step:1074/1680 train_time:94198ms step_avg:87.71ms +step:1075/1680 train_time:94287ms step_avg:87.71ms +step:1076/1680 train_time:94374ms step_avg:87.71ms +step:1077/1680 train_time:94463ms step_avg:87.71ms +step:1078/1680 train_time:94550ms step_avg:87.71ms +step:1079/1680 train_time:94639ms step_avg:87.71ms +step:1080/1680 train_time:94728ms step_avg:87.71ms +step:1081/1680 train_time:94816ms step_avg:87.71ms +step:1082/1680 
train_time:94904ms step_avg:87.71ms +step:1083/1680 train_time:94992ms step_avg:87.71ms +step:1084/1680 train_time:95081ms step_avg:87.71ms +step:1085/1680 train_time:95169ms step_avg:87.71ms +step:1086/1680 train_time:95257ms step_avg:87.71ms +step:1087/1680 train_time:95345ms step_avg:87.71ms +step:1088/1680 train_time:95433ms step_avg:87.71ms +step:1089/1680 train_time:95521ms step_avg:87.71ms +step:1090/1680 train_time:95609ms step_avg:87.71ms +step:1091/1680 train_time:95698ms step_avg:87.72ms +step:1092/1680 train_time:95786ms step_avg:87.72ms +step:1093/1680 train_time:95873ms step_avg:87.72ms +step:1094/1680 train_time:95962ms step_avg:87.72ms +step:1095/1680 train_time:96051ms step_avg:87.72ms +step:1096/1680 train_time:96140ms step_avg:87.72ms +step:1097/1680 train_time:96230ms step_avg:87.72ms +step:1098/1680 train_time:96318ms step_avg:87.72ms +step:1099/1680 train_time:96407ms step_avg:87.72ms +step:1100/1680 train_time:96495ms step_avg:87.72ms +step:1101/1680 train_time:96585ms step_avg:87.72ms +step:1102/1680 train_time:96674ms step_avg:87.73ms +step:1103/1680 train_time:96763ms step_avg:87.73ms +step:1104/1680 train_time:96852ms step_avg:87.73ms +step:1105/1680 train_time:96941ms step_avg:87.73ms +step:1106/1680 train_time:97030ms step_avg:87.73ms +step:1107/1680 train_time:97119ms step_avg:87.73ms +step:1108/1680 train_time:97209ms step_avg:87.73ms +step:1109/1680 train_time:97298ms step_avg:87.73ms +step:1110/1680 train_time:97386ms step_avg:87.74ms +step:1111/1680 train_time:97476ms step_avg:87.74ms +step:1112/1680 train_time:97564ms step_avg:87.74ms +step:1113/1680 train_time:97652ms step_avg:87.74ms +step:1114/1680 train_time:97741ms step_avg:87.74ms +step:1115/1680 train_time:97831ms step_avg:87.74ms +step:1116/1680 train_time:97920ms step_avg:87.74ms +step:1117/1680 train_time:98009ms step_avg:87.74ms +step:1118/1680 train_time:98098ms step_avg:87.74ms +step:1119/1680 train_time:98188ms step_avg:87.75ms +step:1120/1680 train_time:98277ms step_avg:87.75ms +step:1121/1680 train_time:98366ms step_avg:87.75ms +step:1122/1680 train_time:98454ms step_avg:87.75ms +step:1123/1680 train_time:98543ms step_avg:87.75ms +step:1124/1680 train_time:98632ms step_avg:87.75ms +step:1125/1680 train_time:98720ms step_avg:87.75ms +step:1125/1680 val_loss:3.4153 train_time:98811ms step_avg:87.83ms +step:1126/1680 train_time:98831ms step_avg:87.77ms +step:1127/1680 train_time:98900ms step_avg:87.75ms +step:1128/1680 train_time:98993ms step_avg:87.76ms +step:1129/1680 train_time:99084ms step_avg:87.76ms +step:1130/1680 train_time:99173ms step_avg:87.76ms +step:1131/1680 train_time:99262ms step_avg:87.76ms +step:1132/1680 train_time:99349ms step_avg:87.76ms +step:1133/1680 train_time:99437ms step_avg:87.76ms +step:1134/1680 train_time:99524ms step_avg:87.76ms +step:1135/1680 train_time:99612ms step_avg:87.76ms +step:1136/1680 train_time:99700ms step_avg:87.76ms +step:1137/1680 train_time:99790ms step_avg:87.77ms +step:1138/1680 train_time:99882ms step_avg:87.77ms +step:1139/1680 train_time:99973ms step_avg:87.77ms +step:1140/1680 train_time:100064ms step_avg:87.78ms +step:1141/1680 train_time:100153ms step_avg:87.78ms +step:1142/1680 train_time:100242ms step_avg:87.78ms +step:1143/1680 train_time:100330ms step_avg:87.78ms +step:1144/1680 train_time:100418ms step_avg:87.78ms +step:1145/1680 train_time:100507ms step_avg:87.78ms +step:1146/1680 train_time:100594ms step_avg:87.78ms +step:1147/1680 train_time:100683ms step_avg:87.78ms +step:1148/1680 train_time:100772ms step_avg:87.78ms 
+step:1149/1680 train_time:100862ms step_avg:87.78ms +step:1150/1680 train_time:100951ms step_avg:87.78ms +step:1151/1680 train_time:101042ms step_avg:87.79ms +step:1152/1680 train_time:101131ms step_avg:87.79ms +step:1153/1680 train_time:101220ms step_avg:87.79ms +step:1154/1680 train_time:101309ms step_avg:87.79ms +step:1155/1680 train_time:101399ms step_avg:87.79ms +step:1156/1680 train_time:101488ms step_avg:87.79ms +step:1157/1680 train_time:101575ms step_avg:87.79ms +step:1158/1680 train_time:101664ms step_avg:87.79ms +step:1159/1680 train_time:101752ms step_avg:87.79ms +step:1160/1680 train_time:101842ms step_avg:87.79ms +step:1161/1680 train_time:101931ms step_avg:87.80ms +step:1162/1680 train_time:102021ms step_avg:87.80ms +step:1163/1680 train_time:102110ms step_avg:87.80ms +step:1164/1680 train_time:102200ms step_avg:87.80ms +step:1165/1680 train_time:102290ms step_avg:87.80ms +step:1166/1680 train_time:102379ms step_avg:87.80ms +step:1167/1680 train_time:102467ms step_avg:87.80ms +step:1168/1680 train_time:102556ms step_avg:87.80ms +step:1169/1680 train_time:102645ms step_avg:87.81ms +step:1170/1680 train_time:102734ms step_avg:87.81ms +step:1171/1680 train_time:102823ms step_avg:87.81ms +step:1172/1680 train_time:102912ms step_avg:87.81ms +step:1173/1680 train_time:103001ms step_avg:87.81ms +step:1174/1680 train_time:103091ms step_avg:87.81ms +step:1175/1680 train_time:103180ms step_avg:87.81ms +step:1176/1680 train_time:103269ms step_avg:87.81ms +step:1177/1680 train_time:103359ms step_avg:87.82ms +step:1178/1680 train_time:103449ms step_avg:87.82ms +step:1179/1680 train_time:103537ms step_avg:87.82ms +step:1180/1680 train_time:103625ms step_avg:87.82ms +step:1181/1680 train_time:103715ms step_avg:87.82ms +step:1182/1680 train_time:103804ms step_avg:87.82ms +step:1183/1680 train_time:103893ms step_avg:87.82ms +step:1184/1680 train_time:103981ms step_avg:87.82ms +step:1185/1680 train_time:104070ms step_avg:87.82ms +step:1186/1680 train_time:104160ms step_avg:87.82ms +step:1187/1680 train_time:104249ms step_avg:87.83ms +step:1188/1680 train_time:104337ms step_avg:87.83ms +step:1189/1680 train_time:104427ms step_avg:87.83ms +step:1190/1680 train_time:104516ms step_avg:87.83ms +step:1191/1680 train_time:104604ms step_avg:87.83ms +step:1192/1680 train_time:104693ms step_avg:87.83ms +step:1193/1680 train_time:104782ms step_avg:87.83ms +step:1194/1680 train_time:104871ms step_avg:87.83ms +step:1195/1680 train_time:104960ms step_avg:87.83ms +step:1196/1680 train_time:105049ms step_avg:87.83ms +step:1197/1680 train_time:105138ms step_avg:87.83ms +step:1198/1680 train_time:105227ms step_avg:87.84ms +step:1199/1680 train_time:105315ms step_avg:87.84ms +step:1200/1680 train_time:105405ms step_avg:87.84ms +step:1201/1680 train_time:105493ms step_avg:87.84ms +step:1202/1680 train_time:105582ms step_avg:87.84ms +step:1203/1680 train_time:105670ms step_avg:87.84ms +step:1204/1680 train_time:105759ms step_avg:87.84ms +step:1205/1680 train_time:105849ms step_avg:87.84ms +step:1206/1680 train_time:105938ms step_avg:87.84ms +step:1207/1680 train_time:106027ms step_avg:87.84ms +step:1208/1680 train_time:106116ms step_avg:87.84ms +step:1209/1680 train_time:106204ms step_avg:87.84ms +step:1210/1680 train_time:106293ms step_avg:87.85ms +step:1211/1680 train_time:106382ms step_avg:87.85ms +step:1212/1680 train_time:106472ms step_avg:87.85ms +step:1213/1680 train_time:106561ms step_avg:87.85ms +step:1214/1680 train_time:106649ms step_avg:87.85ms +step:1215/1680 train_time:106738ms step_avg:87.85ms 
+step:1216/1680 train_time:106827ms step_avg:87.85ms +step:1217/1680 train_time:106916ms step_avg:87.85ms +step:1218/1680 train_time:107006ms step_avg:87.85ms +step:1219/1680 train_time:107095ms step_avg:87.85ms +step:1220/1680 train_time:107185ms step_avg:87.86ms +step:1221/1680 train_time:107273ms step_avg:87.86ms +step:1222/1680 train_time:107362ms step_avg:87.86ms +step:1223/1680 train_time:107451ms step_avg:87.86ms +step:1224/1680 train_time:107539ms step_avg:87.86ms +step:1225/1680 train_time:107629ms step_avg:87.86ms +step:1226/1680 train_time:107717ms step_avg:87.86ms +step:1227/1680 train_time:107806ms step_avg:87.86ms +step:1228/1680 train_time:107896ms step_avg:87.86ms +step:1229/1680 train_time:107985ms step_avg:87.86ms +step:1230/1680 train_time:108074ms step_avg:87.87ms +step:1231/1680 train_time:108164ms step_avg:87.87ms +step:1232/1680 train_time:108253ms step_avg:87.87ms +step:1233/1680 train_time:108342ms step_avg:87.87ms +step:1234/1680 train_time:108431ms step_avg:87.87ms +step:1235/1680 train_time:108520ms step_avg:87.87ms +step:1236/1680 train_time:108610ms step_avg:87.87ms +step:1237/1680 train_time:108699ms step_avg:87.87ms +step:1238/1680 train_time:108789ms step_avg:87.87ms +step:1239/1680 train_time:108878ms step_avg:87.88ms +step:1240/1680 train_time:108966ms step_avg:87.88ms +step:1241/1680 train_time:109055ms step_avg:87.88ms +step:1242/1680 train_time:109144ms step_avg:87.88ms +step:1243/1680 train_time:109232ms step_avg:87.88ms +step:1244/1680 train_time:109322ms step_avg:87.88ms +step:1245/1680 train_time:109411ms step_avg:87.88ms +step:1246/1680 train_time:109500ms step_avg:87.88ms +step:1247/1680 train_time:109589ms step_avg:87.88ms +step:1248/1680 train_time:109677ms step_avg:87.88ms +step:1249/1680 train_time:109767ms step_avg:87.88ms +step:1250/1680 train_time:109856ms step_avg:87.88ms +step:1250/1680 val_loss:3.3770 train_time:109946ms step_avg:87.96ms +step:1251/1680 train_time:109964ms step_avg:87.90ms +step:1252/1680 train_time:110038ms step_avg:87.89ms +step:1253/1680 train_time:110130ms step_avg:87.89ms +step:1254/1680 train_time:110220ms step_avg:87.89ms +step:1255/1680 train_time:110308ms step_avg:87.89ms +step:1256/1680 train_time:110396ms step_avg:87.89ms +step:1257/1680 train_time:110484ms step_avg:87.89ms +step:1258/1680 train_time:110572ms step_avg:87.90ms +step:1259/1680 train_time:110661ms step_avg:87.90ms +step:1260/1680 train_time:110749ms step_avg:87.90ms +step:1261/1680 train_time:110837ms step_avg:87.90ms +step:1262/1680 train_time:110927ms step_avg:87.90ms +step:1263/1680 train_time:111018ms step_avg:87.90ms +step:1264/1680 train_time:111109ms step_avg:87.90ms +step:1265/1680 train_time:111198ms step_avg:87.90ms +step:1266/1680 train_time:111286ms step_avg:87.90ms +step:1267/1680 train_time:111375ms step_avg:87.90ms +step:1268/1680 train_time:111463ms step_avg:87.90ms +step:1269/1680 train_time:111552ms step_avg:87.91ms +step:1270/1680 train_time:111640ms step_avg:87.91ms +step:1271/1680 train_time:111729ms step_avg:87.91ms +step:1272/1680 train_time:111817ms step_avg:87.91ms +step:1273/1680 train_time:111906ms step_avg:87.91ms +step:1274/1680 train_time:111997ms step_avg:87.91ms +step:1275/1680 train_time:112087ms step_avg:87.91ms +step:1276/1680 train_time:112177ms step_avg:87.91ms +step:1277/1680 train_time:112265ms step_avg:87.91ms +step:1278/1680 train_time:112355ms step_avg:87.91ms +step:1279/1680 train_time:112443ms step_avg:87.92ms +step:1280/1680 train_time:112532ms step_avg:87.92ms +step:1281/1680 train_time:112620ms 
step_avg:87.92ms +step:1282/1680 train_time:112709ms step_avg:87.92ms +step:1283/1680 train_time:112798ms step_avg:87.92ms +step:1284/1680 train_time:112887ms step_avg:87.92ms +step:1285/1680 train_time:112977ms step_avg:87.92ms +step:1286/1680 train_time:113066ms step_avg:87.92ms +step:1287/1680 train_time:113155ms step_avg:87.92ms +step:1288/1680 train_time:113245ms step_avg:87.92ms +step:1289/1680 train_time:113335ms step_avg:87.93ms +step:1290/1680 train_time:113425ms step_avg:87.93ms +step:1291/1680 train_time:113514ms step_avg:87.93ms +step:1292/1680 train_time:113602ms step_avg:87.93ms +step:1293/1680 train_time:113690ms step_avg:87.93ms +step:1294/1680 train_time:113778ms step_avg:87.93ms +step:1295/1680 train_time:113867ms step_avg:87.93ms +step:1296/1680 train_time:113956ms step_avg:87.93ms +step:1297/1680 train_time:114046ms step_avg:87.93ms +step:1298/1680 train_time:114135ms step_avg:87.93ms +step:1299/1680 train_time:114224ms step_avg:87.93ms +step:1300/1680 train_time:114313ms step_avg:87.93ms +step:1301/1680 train_time:114401ms step_avg:87.93ms +step:1302/1680 train_time:114490ms step_avg:87.93ms +step:1303/1680 train_time:114579ms step_avg:87.93ms +step:1304/1680 train_time:114669ms step_avg:87.94ms +step:1305/1680 train_time:114758ms step_avg:87.94ms +step:1306/1680 train_time:114846ms step_avg:87.94ms +step:1307/1680 train_time:114936ms step_avg:87.94ms +step:1308/1680 train_time:115024ms step_avg:87.94ms +step:1309/1680 train_time:115113ms step_avg:87.94ms +step:1310/1680 train_time:115203ms step_avg:87.94ms +step:1311/1680 train_time:115293ms step_avg:87.94ms +step:1312/1680 train_time:115381ms step_avg:87.94ms +step:1313/1680 train_time:115470ms step_avg:87.94ms +step:1314/1680 train_time:115559ms step_avg:87.94ms +step:1315/1680 train_time:115648ms step_avg:87.95ms +step:1316/1680 train_time:115737ms step_avg:87.95ms +step:1317/1680 train_time:115825ms step_avg:87.95ms +step:1318/1680 train_time:115914ms step_avg:87.95ms +step:1319/1680 train_time:116003ms step_avg:87.95ms +step:1320/1680 train_time:116092ms step_avg:87.95ms +step:1321/1680 train_time:116181ms step_avg:87.95ms +step:1322/1680 train_time:116271ms step_avg:87.95ms +step:1323/1680 train_time:116361ms step_avg:87.95ms +step:1324/1680 train_time:116451ms step_avg:87.95ms +step:1325/1680 train_time:116541ms step_avg:87.96ms +step:1326/1680 train_time:116630ms step_avg:87.96ms +step:1327/1680 train_time:116718ms step_avg:87.96ms +step:1328/1680 train_time:116807ms step_avg:87.96ms +step:1329/1680 train_time:116896ms step_avg:87.96ms +step:1330/1680 train_time:116985ms step_avg:87.96ms +step:1331/1680 train_time:117075ms step_avg:87.96ms +step:1332/1680 train_time:117165ms step_avg:87.96ms +step:1333/1680 train_time:117253ms step_avg:87.96ms +step:1334/1680 train_time:117343ms step_avg:87.96ms +step:1335/1680 train_time:117432ms step_avg:87.96ms +step:1336/1680 train_time:117521ms step_avg:87.96ms +step:1337/1680 train_time:117611ms step_avg:87.97ms +step:1338/1680 train_time:117700ms step_avg:87.97ms +step:1339/1680 train_time:117789ms step_avg:87.97ms +step:1340/1680 train_time:117877ms step_avg:87.97ms +step:1341/1680 train_time:117966ms step_avg:87.97ms +step:1342/1680 train_time:118055ms step_avg:87.97ms +step:1343/1680 train_time:118144ms step_avg:87.97ms +step:1344/1680 train_time:118234ms step_avg:87.97ms +step:1345/1680 train_time:118322ms step_avg:87.97ms +step:1346/1680 train_time:118411ms step_avg:87.97ms +step:1347/1680 train_time:118500ms step_avg:87.97ms +step:1348/1680 train_time:118589ms 
step_avg:87.97ms +step:1349/1680 train_time:118678ms step_avg:87.97ms +step:1350/1680 train_time:118767ms step_avg:87.98ms +step:1351/1680 train_time:118855ms step_avg:87.98ms +step:1352/1680 train_time:118944ms step_avg:87.98ms +step:1353/1680 train_time:119034ms step_avg:87.98ms +step:1354/1680 train_time:119123ms step_avg:87.98ms +step:1355/1680 train_time:119212ms step_avg:87.98ms +step:1356/1680 train_time:119301ms step_avg:87.98ms +step:1357/1680 train_time:119391ms step_avg:87.98ms +step:1358/1680 train_time:119479ms step_avg:87.98ms +step:1359/1680 train_time:119568ms step_avg:87.98ms +step:1360/1680 train_time:119658ms step_avg:87.98ms +step:1361/1680 train_time:119747ms step_avg:87.98ms +step:1362/1680 train_time:119835ms step_avg:87.98ms +step:1363/1680 train_time:119924ms step_avg:87.99ms +step:1364/1680 train_time:120012ms step_avg:87.99ms +step:1365/1680 train_time:120101ms step_avg:87.99ms +step:1366/1680 train_time:120190ms step_avg:87.99ms +step:1367/1680 train_time:120279ms step_avg:87.99ms +step:1368/1680 train_time:120369ms step_avg:87.99ms +step:1369/1680 train_time:120458ms step_avg:87.99ms +step:1370/1680 train_time:120548ms step_avg:87.99ms +step:1371/1680 train_time:120637ms step_avg:87.99ms +step:1372/1680 train_time:120727ms step_avg:87.99ms +step:1373/1680 train_time:120816ms step_avg:87.99ms +step:1374/1680 train_time:120904ms step_avg:87.99ms +step:1375/1680 train_time:120993ms step_avg:88.00ms +step:1375/1680 val_loss:3.3421 train_time:121084ms step_avg:88.06ms +step:1376/1680 train_time:121103ms step_avg:88.01ms +step:1377/1680 train_time:121175ms step_avg:88.00ms +step:1378/1680 train_time:121268ms step_avg:88.00ms +step:1379/1680 train_time:121356ms step_avg:88.00ms +step:1380/1680 train_time:121444ms step_avg:88.00ms +step:1381/1680 train_time:121533ms step_avg:88.00ms +step:1382/1680 train_time:121622ms step_avg:88.00ms +step:1383/1680 train_time:121709ms step_avg:88.00ms +step:1384/1680 train_time:121798ms step_avg:88.00ms +step:1385/1680 train_time:121887ms step_avg:88.00ms +step:1386/1680 train_time:121975ms step_avg:88.01ms +step:1387/1680 train_time:122064ms step_avg:88.01ms +step:1388/1680 train_time:122155ms step_avg:88.01ms +step:1389/1680 train_time:122246ms step_avg:88.01ms +step:1390/1680 train_time:122335ms step_avg:88.01ms +step:1391/1680 train_time:122424ms step_avg:88.01ms +step:1392/1680 train_time:122513ms step_avg:88.01ms +step:1393/1680 train_time:122601ms step_avg:88.01ms +step:1394/1680 train_time:122689ms step_avg:88.01ms +step:1395/1680 train_time:122778ms step_avg:88.01ms +step:1396/1680 train_time:122867ms step_avg:88.01ms +step:1397/1680 train_time:122956ms step_avg:88.01ms +step:1398/1680 train_time:123045ms step_avg:88.02ms +step:1399/1680 train_time:123135ms step_avg:88.02ms +step:1400/1680 train_time:123224ms step_avg:88.02ms +step:1401/1680 train_time:123314ms step_avg:88.02ms +step:1402/1680 train_time:123404ms step_avg:88.02ms +step:1403/1680 train_time:123492ms step_avg:88.02ms +step:1404/1680 train_time:123581ms step_avg:88.02ms +step:1405/1680 train_time:123669ms step_avg:88.02ms +step:1406/1680 train_time:123758ms step_avg:88.02ms +step:1407/1680 train_time:123846ms step_avg:88.02ms +step:1408/1680 train_time:123935ms step_avg:88.02ms +step:1409/1680 train_time:124025ms step_avg:88.02ms +step:1410/1680 train_time:124114ms step_avg:88.02ms +step:1411/1680 train_time:124204ms step_avg:88.03ms +step:1412/1680 train_time:124294ms step_avg:88.03ms +step:1413/1680 train_time:124384ms step_avg:88.03ms +step:1414/1680 
train_time:124473ms step_avg:88.03ms +step:1415/1680 train_time:124561ms step_avg:88.03ms +step:1416/1680 train_time:124649ms step_avg:88.03ms +step:1417/1680 train_time:124738ms step_avg:88.03ms +step:1418/1680 train_time:124826ms step_avg:88.03ms +step:1419/1680 train_time:124915ms step_avg:88.03ms +step:1420/1680 train_time:125004ms step_avg:88.03ms +step:1421/1680 train_time:125093ms step_avg:88.03ms +step:1422/1680 train_time:125183ms step_avg:88.03ms +step:1423/1680 train_time:125272ms step_avg:88.03ms +step:1424/1680 train_time:125361ms step_avg:88.03ms +step:1425/1680 train_time:125450ms step_avg:88.04ms +step:1426/1680 train_time:125540ms step_avg:88.04ms +step:1427/1680 train_time:125629ms step_avg:88.04ms +step:1428/1680 train_time:125718ms step_avg:88.04ms +step:1429/1680 train_time:125808ms step_avg:88.04ms +step:1430/1680 train_time:125898ms step_avg:88.04ms +step:1431/1680 train_time:125986ms step_avg:88.04ms +step:1432/1680 train_time:126076ms step_avg:88.04ms +step:1433/1680 train_time:126165ms step_avg:88.04ms +step:1434/1680 train_time:126255ms step_avg:88.04ms +step:1435/1680 train_time:126345ms step_avg:88.04ms +step:1436/1680 train_time:126434ms step_avg:88.05ms +step:1437/1680 train_time:126522ms step_avg:88.05ms +step:1438/1680 train_time:126611ms step_avg:88.05ms +step:1439/1680 train_time:126700ms step_avg:88.05ms +step:1440/1680 train_time:126790ms step_avg:88.05ms +step:1441/1680 train_time:126880ms step_avg:88.05ms +step:1442/1680 train_time:126968ms step_avg:88.05ms +step:1443/1680 train_time:127057ms step_avg:88.05ms +step:1444/1680 train_time:127147ms step_avg:88.05ms +step:1445/1680 train_time:127236ms step_avg:88.05ms +step:1446/1680 train_time:127325ms step_avg:88.05ms +step:1447/1680 train_time:127415ms step_avg:88.05ms +step:1448/1680 train_time:127504ms step_avg:88.06ms +step:1449/1680 train_time:127594ms step_avg:88.06ms +step:1450/1680 train_time:127683ms step_avg:88.06ms +step:1451/1680 train_time:127772ms step_avg:88.06ms +step:1452/1680 train_time:127862ms step_avg:88.06ms +step:1453/1680 train_time:127950ms step_avg:88.06ms +step:1454/1680 train_time:128038ms step_avg:88.06ms +step:1455/1680 train_time:128127ms step_avg:88.06ms +step:1456/1680 train_time:128217ms step_avg:88.06ms +step:1457/1680 train_time:128305ms step_avg:88.06ms +step:1458/1680 train_time:128395ms step_avg:88.06ms +step:1459/1680 train_time:128484ms step_avg:88.06ms +step:1460/1680 train_time:128573ms step_avg:88.06ms +step:1461/1680 train_time:128661ms step_avg:88.06ms +step:1462/1680 train_time:128750ms step_avg:88.06ms +step:1463/1680 train_time:128839ms step_avg:88.07ms +step:1464/1680 train_time:128928ms step_avg:88.07ms +step:1465/1680 train_time:129017ms step_avg:88.07ms +step:1466/1680 train_time:129106ms step_avg:88.07ms +step:1467/1680 train_time:129195ms step_avg:88.07ms +step:1468/1680 train_time:129284ms step_avg:88.07ms +step:1469/1680 train_time:129374ms step_avg:88.07ms +step:1470/1680 train_time:129462ms step_avg:88.07ms +step:1471/1680 train_time:129551ms step_avg:88.07ms +step:1472/1680 train_time:129640ms step_avg:88.07ms +step:1473/1680 train_time:129729ms step_avg:88.07ms +step:1474/1680 train_time:129819ms step_avg:88.07ms +step:1475/1680 train_time:129909ms step_avg:88.07ms +step:1476/1680 train_time:129998ms step_avg:88.07ms +step:1477/1680 train_time:130088ms step_avg:88.08ms +step:1478/1680 train_time:130176ms step_avg:88.08ms +step:1479/1680 train_time:130266ms step_avg:88.08ms +step:1480/1680 train_time:130356ms step_avg:88.08ms +step:1481/1680 
train_time:130445ms step_avg:88.08ms +step:1482/1680 train_time:130534ms step_avg:88.08ms +step:1483/1680 train_time:130622ms step_avg:88.08ms +step:1484/1680 train_time:130712ms step_avg:88.08ms +step:1485/1680 train_time:130801ms step_avg:88.08ms +step:1486/1680 train_time:130889ms step_avg:88.08ms +step:1487/1680 train_time:130979ms step_avg:88.08ms +step:1488/1680 train_time:131068ms step_avg:88.08ms +step:1489/1680 train_time:131156ms step_avg:88.08ms +step:1490/1680 train_time:131245ms step_avg:88.08ms +step:1491/1680 train_time:131334ms step_avg:88.08ms +step:1492/1680 train_time:131423ms step_avg:88.08ms +step:1493/1680 train_time:131513ms step_avg:88.09ms +step:1494/1680 train_time:131602ms step_avg:88.09ms +step:1495/1680 train_time:131691ms step_avg:88.09ms +step:1496/1680 train_time:131780ms step_avg:88.09ms +step:1497/1680 train_time:131869ms step_avg:88.09ms +step:1498/1680 train_time:131958ms step_avg:88.09ms +step:1499/1680 train_time:132047ms step_avg:88.09ms +step:1500/1680 train_time:132137ms step_avg:88.09ms +step:1500/1680 val_loss:3.3122 train_time:132227ms step_avg:88.15ms +step:1501/1680 train_time:132247ms step_avg:88.11ms +step:1502/1680 train_time:132318ms step_avg:88.09ms +step:1503/1680 train_time:132409ms step_avg:88.10ms +step:1504/1680 train_time:132498ms step_avg:88.10ms +step:1505/1680 train_time:132586ms step_avg:88.10ms +step:1506/1680 train_time:132674ms step_avg:88.10ms +step:1507/1680 train_time:132762ms step_avg:88.10ms +step:1508/1680 train_time:132850ms step_avg:88.10ms +step:1509/1680 train_time:132937ms step_avg:88.10ms +step:1510/1680 train_time:133026ms step_avg:88.10ms +step:1511/1680 train_time:133114ms step_avg:88.10ms +step:1512/1680 train_time:133204ms step_avg:88.10ms +step:1513/1680 train_time:133294ms step_avg:88.10ms +step:1514/1680 train_time:133385ms step_avg:88.10ms +step:1515/1680 train_time:133475ms step_avg:88.10ms +step:1516/1680 train_time:133565ms step_avg:88.10ms +step:1517/1680 train_time:133654ms step_avg:88.10ms +step:1518/1680 train_time:133742ms step_avg:88.10ms +step:1519/1680 train_time:133831ms step_avg:88.10ms +step:1520/1680 train_time:133920ms step_avg:88.11ms +step:1521/1680 train_time:134009ms step_avg:88.11ms +step:1522/1680 train_time:134097ms step_avg:88.11ms +step:1523/1680 train_time:134186ms step_avg:88.11ms +step:1524/1680 train_time:134275ms step_avg:88.11ms +step:1525/1680 train_time:134366ms step_avg:88.11ms +step:1526/1680 train_time:134457ms step_avg:88.11ms +step:1527/1680 train_time:134546ms step_avg:88.11ms +step:1528/1680 train_time:134636ms step_avg:88.11ms +step:1529/1680 train_time:134724ms step_avg:88.11ms +step:1530/1680 train_time:134813ms step_avg:88.11ms +step:1531/1680 train_time:134901ms step_avg:88.11ms +step:1532/1680 train_time:134989ms step_avg:88.11ms +step:1533/1680 train_time:135078ms step_avg:88.11ms +step:1534/1680 train_time:135167ms step_avg:88.11ms +step:1535/1680 train_time:135257ms step_avg:88.12ms +step:1536/1680 train_time:135347ms step_avg:88.12ms +step:1537/1680 train_time:135436ms step_avg:88.12ms +step:1538/1680 train_time:135526ms step_avg:88.12ms +step:1539/1680 train_time:135614ms step_avg:88.12ms +step:1540/1680 train_time:135704ms step_avg:88.12ms +step:1541/1680 train_time:135793ms step_avg:88.12ms +step:1542/1680 train_time:135881ms step_avg:88.12ms +step:1543/1680 train_time:135970ms step_avg:88.12ms +step:1544/1680 train_time:136058ms step_avg:88.12ms +step:1545/1680 train_time:136148ms step_avg:88.12ms +step:1546/1680 train_time:136237ms step_avg:88.12ms 
+step:1547/1680 train_time:136326ms step_avg:88.12ms +step:1548/1680 train_time:136416ms step_avg:88.12ms +step:1549/1680 train_time:136504ms step_avg:88.12ms +step:1550/1680 train_time:136593ms step_avg:88.12ms +step:1551/1680 train_time:136682ms step_avg:88.13ms +step:1552/1680 train_time:136770ms step_avg:88.13ms +step:1553/1680 train_time:136860ms step_avg:88.13ms +step:1554/1680 train_time:136948ms step_avg:88.13ms +step:1555/1680 train_time:137037ms step_avg:88.13ms +step:1556/1680 train_time:137126ms step_avg:88.13ms +step:1557/1680 train_time:137216ms step_avg:88.13ms +step:1558/1680 train_time:137305ms step_avg:88.13ms +step:1559/1680 train_time:137395ms step_avg:88.13ms +step:1560/1680 train_time:137484ms step_avg:88.13ms +step:1561/1680 train_time:137573ms step_avg:88.13ms +step:1562/1680 train_time:137662ms step_avg:88.13ms +step:1563/1680 train_time:137751ms step_avg:88.13ms +step:1564/1680 train_time:137840ms step_avg:88.13ms +step:1565/1680 train_time:137928ms step_avg:88.13ms +step:1566/1680 train_time:138018ms step_avg:88.13ms +step:1567/1680 train_time:138106ms step_avg:88.13ms +step:1568/1680 train_time:138196ms step_avg:88.14ms +step:1569/1680 train_time:138285ms step_avg:88.14ms +step:1570/1680 train_time:138376ms step_avg:88.14ms +step:1571/1680 train_time:138464ms step_avg:88.14ms +step:1572/1680 train_time:138553ms step_avg:88.14ms +step:1573/1680 train_time:138643ms step_avg:88.14ms +step:1574/1680 train_time:138732ms step_avg:88.14ms +step:1575/1680 train_time:138821ms step_avg:88.14ms +step:1576/1680 train_time:138910ms step_avg:88.14ms +step:1577/1680 train_time:139000ms step_avg:88.14ms +step:1578/1680 train_time:139089ms step_avg:88.14ms +step:1579/1680 train_time:139178ms step_avg:88.14ms +step:1580/1680 train_time:139267ms step_avg:88.14ms +step:1581/1680 train_time:139357ms step_avg:88.14ms +step:1582/1680 train_time:139446ms step_avg:88.15ms +step:1583/1680 train_time:139536ms step_avg:88.15ms +step:1584/1680 train_time:139625ms step_avg:88.15ms +step:1585/1680 train_time:139714ms step_avg:88.15ms +step:1586/1680 train_time:139803ms step_avg:88.15ms +step:1587/1680 train_time:139892ms step_avg:88.15ms +step:1588/1680 train_time:139981ms step_avg:88.15ms +step:1589/1680 train_time:140070ms step_avg:88.15ms +step:1590/1680 train_time:140159ms step_avg:88.15ms +step:1591/1680 train_time:140249ms step_avg:88.15ms +step:1592/1680 train_time:140338ms step_avg:88.15ms +step:1593/1680 train_time:140428ms step_avg:88.15ms +step:1594/1680 train_time:140517ms step_avg:88.15ms +step:1595/1680 train_time:140606ms step_avg:88.15ms +step:1596/1680 train_time:140695ms step_avg:88.15ms +step:1597/1680 train_time:140784ms step_avg:88.16ms +step:1598/1680 train_time:140873ms step_avg:88.16ms +step:1599/1680 train_time:140962ms step_avg:88.16ms +step:1600/1680 train_time:141052ms step_avg:88.16ms +step:1601/1680 train_time:141141ms step_avg:88.16ms +step:1602/1680 train_time:141230ms step_avg:88.16ms +step:1603/1680 train_time:141319ms step_avg:88.16ms +step:1604/1680 train_time:141408ms step_avg:88.16ms +step:1605/1680 train_time:141497ms step_avg:88.16ms +step:1606/1680 train_time:141586ms step_avg:88.16ms +step:1607/1680 train_time:141675ms step_avg:88.16ms +step:1608/1680 train_time:141764ms step_avg:88.16ms +step:1609/1680 train_time:141853ms step_avg:88.16ms +step:1610/1680 train_time:141943ms step_avg:88.16ms +step:1611/1680 train_time:142032ms step_avg:88.16ms +step:1612/1680 train_time:142121ms step_avg:88.16ms +step:1613/1680 train_time:142210ms step_avg:88.17ms 
+step:1614/1680 train_time:142300ms step_avg:88.17ms +step:1615/1680 train_time:142389ms step_avg:88.17ms +step:1616/1680 train_time:142478ms step_avg:88.17ms +step:1617/1680 train_time:142567ms step_avg:88.17ms +step:1618/1680 train_time:142656ms step_avg:88.17ms +step:1619/1680 train_time:142745ms step_avg:88.17ms +step:1620/1680 train_time:142834ms step_avg:88.17ms +step:1621/1680 train_time:142923ms step_avg:88.17ms +step:1622/1680 train_time:143011ms step_avg:88.17ms +step:1623/1680 train_time:143101ms step_avg:88.17ms +step:1624/1680 train_time:143189ms step_avg:88.17ms +step:1625/1680 train_time:143278ms step_avg:88.17ms +step:1625/1680 val_loss:3.2883 train_time:143368ms step_avg:88.23ms +step:1626/1680 train_time:143387ms step_avg:88.18ms +step:1627/1680 train_time:143461ms step_avg:88.17ms +step:1628/1680 train_time:143551ms step_avg:88.18ms +step:1629/1680 train_time:143641ms step_avg:88.18ms +step:1630/1680 train_time:143729ms step_avg:88.18ms +step:1631/1680 train_time:143818ms step_avg:88.18ms +step:1632/1680 train_time:143905ms step_avg:88.18ms +step:1633/1680 train_time:143993ms step_avg:88.18ms +step:1634/1680 train_time:144082ms step_avg:88.18ms +step:1635/1680 train_time:144170ms step_avg:88.18ms +step:1636/1680 train_time:144259ms step_avg:88.18ms +step:1637/1680 train_time:144350ms step_avg:88.18ms +step:1638/1680 train_time:144442ms step_avg:88.18ms +step:1639/1680 train_time:144532ms step_avg:88.18ms +step:1640/1680 train_time:144621ms step_avg:88.18ms +step:1641/1680 train_time:144711ms step_avg:88.18ms +step:1642/1680 train_time:144800ms step_avg:88.19ms +step:1643/1680 train_time:144889ms step_avg:88.19ms +step:1644/1680 train_time:144977ms step_avg:88.19ms +step:1645/1680 train_time:145066ms step_avg:88.19ms +step:1646/1680 train_time:145154ms step_avg:88.19ms +step:1647/1680 train_time:145242ms step_avg:88.19ms +step:1648/1680 train_time:145332ms step_avg:88.19ms +step:1649/1680 train_time:145422ms step_avg:88.19ms +step:1650/1680 train_time:145512ms step_avg:88.19ms +step:1651/1680 train_time:145602ms step_avg:88.19ms +step:1652/1680 train_time:145691ms step_avg:88.19ms +step:1653/1680 train_time:145780ms step_avg:88.19ms +step:1654/1680 train_time:145869ms step_avg:88.19ms +step:1655/1680 train_time:145957ms step_avg:88.19ms +step:1656/1680 train_time:146047ms step_avg:88.19ms +step:1657/1680 train_time:146135ms step_avg:88.19ms +step:1658/1680 train_time:146223ms step_avg:88.19ms +step:1659/1680 train_time:146312ms step_avg:88.19ms +step:1660/1680 train_time:146401ms step_avg:88.19ms +step:1661/1680 train_time:146491ms step_avg:88.19ms +step:1662/1680 train_time:146580ms step_avg:88.20ms +step:1663/1680 train_time:146670ms step_avg:88.20ms +step:1664/1680 train_time:146758ms step_avg:88.20ms +step:1665/1680 train_time:146847ms step_avg:88.20ms +step:1666/1680 train_time:146936ms step_avg:88.20ms +step:1667/1680 train_time:147024ms step_avg:88.20ms +step:1668/1680 train_time:147113ms step_avg:88.20ms +step:1669/1680 train_time:147202ms step_avg:88.20ms +step:1670/1680 train_time:147291ms step_avg:88.20ms +step:1671/1680 train_time:147381ms step_avg:88.20ms +step:1672/1680 train_time:147470ms step_avg:88.20ms +step:1673/1680 train_time:147559ms step_avg:88.20ms +step:1674/1680 train_time:147648ms step_avg:88.20ms +step:1675/1680 train_time:147738ms step_avg:88.20ms +step:1676/1680 train_time:147827ms step_avg:88.20ms +step:1677/1680 train_time:147916ms step_avg:88.20ms +step:1678/1680 train_time:148005ms step_avg:88.20ms +step:1679/1680 train_time:148094ms 
step_avg:88.20ms +step:1680/1680 train_time:148182ms step_avg:88.20ms +step:1680/1680 val_loss:3.2774 train_time:148273ms step_avg:88.26ms +peak memory allocated: 30760 MiB reserved: 45774 MiB diff --git a/records/092725_BF16CE/f713f5c8-a6e3-446a-9ec4-5014917cb254.txt b/records/092725_BF16CE/f713f5c8-a6e3-446a-9ec4-5014917cb254.txt new file mode 100644 index 000000000..8909a5d8e --- /dev/null +++ b/records/092725_BF16CE/f713f5c8-a6e3-446a-9ec4-5014917cb254.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
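+    # C = A @ A.T is symmetric, so each tile is computed once and written twice: + # at (m, n) here, and transposed at (n, m) in the mirrored store below.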
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
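+ + A minimal sketch of the orthogonalization in isolation (illustrative shapes only; not part of the training path): + + >>> G = torch.randn(1024, 768, device="cuda") + >>> U = newton_schulz_triton(G) # same shape as G, approximately semi-orthogonal + >>> # U.mT @ U is then close to the identity, up to the NS approximation error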
+ Though empirically, small 1D params perform efficiently here: + NS approximately performs a magnitude normalization of the grad, and + this hyper-optimized class has faster execution time than the current impl of Adam for small params. + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized to enable + (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param, 7 padding params) + 2. reduce scatter attn_gate (10 params, 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params, 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive attn params reshape them before NS + 8. wait on 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size()==8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for the resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1,10,16,16] + assert len(params)==sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx+size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal.
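+ # The three passes below overlap communication with compute: + # 1) launch an async reduce_scatter of the stacked grads for every group, + # 2) as each reduce completes, build this rank's updated shard (momentum, + # weight decay, batched NS) and launch an async all_gather of it, + # 3) wait on the gathers and copy the full updated params back in place.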
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
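+ # (Weight decay below acts on this buffer copy, not on p itself; p is only + # overwritten at the end of step, after the all_gather of updated shards.)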
+                updated_param_chunk[i].copy_(p)
+
+                # Apply weight decay directly to the buffer.
+                updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+            original_shape = batched_update_grads.shape
+            # Reshape attn params from [hdim, dim*4] to [4, hdim, dim] to apply NS independently to Q,K,V,O
+            module_idx = start_idx if start_idx < len(params) else 0
+            if getattr(params[module_idx], "module", None) == "attn":
+                batched_update_grads = batched_update_grads.view(-1, original_shape[-2], original_shape[-1] // 4)
+            v_chunk = zeropower_via_newtonschulz5(batched_update_grads)
+            v_chunk = v_chunk.view(original_shape)
+
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+                if param_idx >= len(params): # Skip padded entries again.
+                    continue
+
+                # Add the computed zeropower update to the parameter in the buffer.
+                updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val)
+
+            stacked_params = torch.empty(
+                (info["padded_num_params"], *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+            gather_future = dist.all_gather_into_tensor(
+                stacked_params, updated_param_chunk, async_op=True
+            ).get_future()
+
+            all_gather_infos.append(
+                {
+                    "gather_future": gather_future,
+                    "stacked_params": stacked_params,
+                    "orig_params": params,
+                }
+            )
+
+        # Final pass: wait for all_gather to complete and copy results back into original parameter tensors.
+        for info in all_gather_infos:
+            info["gather_future"].wait()
+            stacked_params = info["stacked_params"]
+            orig_params = info["orig_params"]
+
+            unstacked_params = torch.unbind(stacked_params)
+            for i, p in enumerate(orig_params):
+                p.copy_(unstacked_params[i], non_blocking=True)
+
+
+class DistAdam(torch.optim.Optimizer):
+    def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        params = list(params)
+        sizes = {p.shape for p in params}
+        # create one param group per unique parameter size
+        param_groups = []
+        for size in sizes:
+            group_params = [p for p in params if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        super().__init__(param_groups, defaults)
+        # DistributedAdam implementation by @vagrawal
+
+    @torch.compile
+    @torch.no_grad()
+    def step(self):
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        reduce_scatter_futures: list[torch.Future] = []
+        all_gather_futures: list[torch.Future] = []
+        grad_slices = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            for base_i in range(len(params)):
+                grad = params[base_i].grad
+                rank_size = grad.shape[0] // world_size
+                grad_slice = torch.empty_like(grad[:rank_size])
+                reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
+                grad_slices.append(grad_slice)
+
+        idx = 0
+        for group in self.param_groups:
+            beta1, beta2 = group['betas']
+            eps = group['eps']
+            wd = group['weight_decay']
+            params = group['params']
+            for base in range(len(params)):
+                reduce_scatter_futures[idx].wait()
+                p = params[base]
+                rank_size = p.shape[0] // world_size
+                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
+                lr = group['lr'] * getattr(p, "lr_mul", 1.0)
+                state = self.state[p]
+                g_slice = grad_slices[idx]
+                # State init
+                if not state:
+                    state["step"] = torch.tensor(
+                        0, dtype=torch.int64, device=p.device
+                    )
+                    state["exp_avg"] = torch.zeros(
+                        p_slice.shape,
+                        dtype=torch.bfloat16,
+                        device=p_slice.device,
+                    )
+                    state["exp_avg_sq"] = torch.zeros(
+                        p_slice.shape,
+                        dtype=torch.bfloat16,
+                        device=p_slice.device,
+                    )
+                exp_avg = state["exp_avg"]
+                exp_avg_sq = 
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1)
+    y1 = x1 * cos + x2 * sin
+    y2 = x1 * (-sin) + x2 * cos
+    return torch.cat((y1, y2), 3)
+
+@dataclass
+class AttnArgs:
+    ve: torch.Tensor
+    sa_lambdas: torch.Tensor
+    seqlens: torch.Tensor
+    bm_size: int
+    cos: torch.Tensor
+    sin: torch.Tensor
+    attn_scale: float
+
+class CausalSelfAttention(nn.Module):
+    def __init__(self, dim: int, head_dim: int, num_heads: int):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.dim = dim
+        self.hdim = num_heads * head_dim
+
+        assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim"
+        std = 0.5 * (self.dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
+        # https://x.com/hi_tysam/status/1879699187107033311
+        # make matrices the same shape as MLP to enable batched call in optimizer
+        self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim * 4))
+        # label module to enable custom optimizer sizing
+        self.qkvo_w.module = 'attn'
+        with torch.no_grad():
+            self.qkvo_w.view(4, self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights
+            self.qkvo_w.view(4, self.hdim, self.dim)[3].zero_() # init output weights to zero
+
+        # sparse gated attention to enable context-based no-op by @classiclarryd
+        self.attn_gate = CastedLinear(12, num_heads)
+        # label module to enable custom optimizer sizing
+        self.attn_gate.weight.module = 'attn_gate'
+        self.attn_gate.weight.detach().zero_()
+
+    def forward(self, x: Tensor, attn_args: AttnArgs):
+        B, T = x.size(0), x.size(1) # batch size, sequence length
+        assert B == 1, "varlen sequences require B == 1"
+        assert T % 16 == 0
+        # unpack attention args
+        cos, sin = attn_args.cos, attn_args.sin
+        ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas
+        seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size
+
+        q, k, v = F.linear(x, self.qkvo_w.view(4, self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
+        q, k = norm(q), norm(k) # QK norm @Grad62304977
+        q, k = rotary(q, cos, sin), rotary(k, cos, sin)
+        if ve is not None:
+            v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
+        else: # skip mid-layers token value embeddings by @YouJiacheng
+            v = sa_lambdas[0] * v
+
+        max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size))
+
+        # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng
+        y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len,
+                                   causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0))
+        y = y.view(B, T, self.num_heads, self.head_dim)
+        y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1)
+        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
+        y = F.linear(y, self.qkvo_w.view(4, self.hdim, self.dim)[3].type_as(y))
+        return y
+
+class MLP(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        hdim = 4 * dim
+        # make matrices the same shape to enable batched call in optimizer
+        self.c_fc = nn.Parameter(torch.empty(dim, hdim))
+        self.c_proj = nn.Parameter(torch.empty(dim, hdim))
+        # label modules to enable custom optimizer sizing
+        self.c_fc.module = 'mlp'
+        self.c_proj.module = 'mlp'
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        with torch.no_grad():
+            self.c_fc.uniform_(-bound, bound)
+            self.c_proj.zero_() # zero init suggested by @Grad62304977
+
+    def forward(self, x: Tensor):
+        x = F.linear(x, self.c_fc.T.type_as(x))
+        x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
+        x = F.linear(x, self.c_proj.type_as(x))
+        return x
+
+class Block(nn.Module):
+    def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int):
+        super().__init__()
+        # skip attention of blocks.0 and blocks.7 (the 1st and 8th layers) by @YouJiacheng
+        self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None
+        # skip the MLP block of the first layer by @EmelyanenkoK
+        self.mlp = MLP(dim) if layer_idx != 0 else None
+
+    def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs):
+        x = lambdas[0] * x + lambdas[1] * x0
+        if self.attn is not None:
+            x = x + self.attn(norm(x), attn_args)
+        if self.mlp is not None:
+            x = x + self.mlp(norm(x))
+        return x
+
+# -----------------------------------------------------------------------------
+# The main model
+
+def next_multiple_of_n(v: float | int, *, n: int):
+    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
+
+class GPT(nn.Module):
+    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int):
+        super().__init__()
+        vocab_size = next_multiple_of_n(vocab_size, n=128)
+        self.embed = nn.Embedding(vocab_size, model_dim)
+        self.smear_gate = CastedLinear(12, 1)
+        self.smear_gate.weight.detach().zero_()
+        # label modules to enable custom optimizer sizing
+        self.smear_gate.weight.module = 'smear_gate'
+        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
+        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
+        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
+        self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)])
+        self.yarn = Yarn(head_dim, max_seq_len)
+        # there are only 50257 unique GPT-2 tokens; we extend to the nearest multiple of 128 for efficiency.
+        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
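+        # worked example: next_multiple_of_n(50257, n=128) == 50304 (= 128 * 393),
+        # so lm_head below gains 47 padding rows; they are zero-initialized along
+        # with the rest of lm_head and only pad the matmul to a 128-aligned,
+        # tensor-core-friendly size.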
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the 
data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter==5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter+=1 + return starts, ends + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * 
grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1640 # number of iterations to run + iteration_extension = 40 # number of iterations to continue training at final cooldown and window size + cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 
0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = min(0.9999, step / args.num_iterations)
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+def get_ws(step: int):
+    if step == args.num_iterations + args.iteration_extension:
+        return args.ws_validate // 2, args.ws_validate
+    x = min(step / (1 + args.num_iterations), 0.9999)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws_long = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws_long > ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    elif new_ws_long < ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+
+# restore the initial state so the warmup steps don't count as training
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+model.yarn.reset()
+del initial_state
+
+########################################
+#        Training and validation       #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+train_steps = args.num_iterations + args.iteration_extension
+ws_short, ws_long = get_ws(0)
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    new_ws_short, new_ws_long = get_ws(step)
+    if new_ws_long != ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+    ws_short, ws_long = new_ws_short, new_ws_long
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        if last_step:
+            ws_long = args.ws_long_validate
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = torch.zeros((), device=device, dtype=torch.float32)
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        
inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Sep 27 12:33:47 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 30C P0 123W / 700W | 5856MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 27C P0 119W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 24C P0 117W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 29C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 30C P0 122W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 28C P0 115W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 30C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 27C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 159205 C /usr/bin/python 0MiB | +| 0 N/A N/A 159206 C /usr/bin/python 0MiB | +| 0 N/A N/A 159207 C /usr/bin/python 0MiB | +| 0 N/A N/A 159208 C /usr/bin/python 0MiB | +| 0 N/A N/A 159209 C /usr/bin/python 0MiB | +| 0 N/A N/A 159210 C /usr/bin/python 0MiB | +| 0 N/A N/A 159211 C /usr/bin/python 0MiB | +| 0 N/A N/A 159212 C /usr/bin/python 0MiB | +| 1 N/A N/A 159206 C /usr/bin/python 0MiB | +| 2 N/A N/A 159207 C /usr/bin/python 0MiB | +| 3 N/A N/A 159208 C /usr/bin/python 0MiB | +| 4 N/A N/A 159209 C /usr/bin/python 0MiB | +| 5 N/A N/A 159210 C /usr/bin/python 0MiB | +| 6 N/A N/A 159211 C /usr/bin/python 0MiB | +| 7 N/A N/A 159212 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1680 train_time:146ms step_avg:146.05ms +step:2/1680 train_time:166ms step_avg:82.84ms +step:3/1680 train_time:229ms step_avg:76.50ms +step:4/1680 train_time:315ms step_avg:78.70ms +step:5/1680 train_time:401ms step_avg:80.27ms +step:6/1680 train_time:487ms step_avg:81.20ms +step:7/1680 train_time:574ms step_avg:81.93ms +step:8/1680 train_time:660ms step_avg:82.47ms +step:9/1680 train_time:746ms step_avg:82.90ms +step:10/1680 train_time:833ms step_avg:83.26ms +step:11/1680 train_time:919ms step_avg:83.52ms +step:12/1680 train_time:1006ms step_avg:83.86ms +step:13/1680 train_time:1097ms step_avg:84.35ms +step:14/1680 train_time:1187ms step_avg:84.80ms +step:15/1680 train_time:1275ms step_avg:85.00ms +step:16/1680 train_time:1363ms step_avg:85.18ms +step:17/1680 train_time:1449ms step_avg:85.24ms +step:18/1680 train_time:1536ms step_avg:85.31ms +step:19/1680 train_time:1622ms step_avg:85.36ms +step:20/1680 train_time:1708ms step_avg:85.41ms +step:21/1680 train_time:1795ms step_avg:85.48ms +step:22/1680 train_time:1882ms step_avg:85.52ms +step:23/1680 train_time:1969ms step_avg:85.61ms +step:24/1680 train_time:2058ms step_avg:85.74ms +step:25/1680 train_time:2146ms step_avg:85.85ms +step:26/1680 train_time:2235ms step_avg:85.95ms +step:27/1680 train_time:2323ms step_avg:86.04ms +step:28/1680 train_time:2411ms step_avg:86.11ms +step:29/1680 train_time:2499ms step_avg:86.16ms +step:30/1680 train_time:2585ms step_avg:86.18ms +step:31/1680 train_time:2672ms step_avg:86.19ms +step:32/1680 train_time:2760ms step_avg:86.24ms +step:33/1680 train_time:2846ms step_avg:86.25ms +step:34/1680 train_time:2933ms step_avg:86.27ms +step:35/1680 train_time:3021ms step_avg:86.31ms +step:36/1680 train_time:3109ms step_avg:86.36ms +step:37/1680 train_time:3198ms step_avg:86.43ms +step:38/1680 train_time:3285ms step_avg:86.46ms +step:39/1680 train_time:3373ms step_avg:86.49ms +step:40/1680 train_time:3461ms step_avg:86.52ms +step:41/1680 train_time:3548ms step_avg:86.54ms +step:42/1680 train_time:3635ms step_avg:86.55ms +step:43/1680 train_time:3722ms step_avg:86.57ms +step:44/1680 train_time:3810ms step_avg:86.59ms +step:45/1680 train_time:3898ms step_avg:86.62ms +step:46/1680 train_time:3985ms step_avg:86.63ms +step:47/1680 
train_time:4072ms step_avg:86.64ms +step:48/1680 train_time:4160ms step_avg:86.68ms +step:49/1680 train_time:4249ms step_avg:86.70ms +step:50/1680 train_time:4337ms step_avg:86.73ms +step:51/1680 train_time:4424ms step_avg:86.74ms +step:52/1680 train_time:4511ms step_avg:86.75ms +step:53/1680 train_time:4598ms step_avg:86.76ms +step:54/1680 train_time:4686ms step_avg:86.79ms +step:55/1680 train_time:4773ms step_avg:86.79ms +step:56/1680 train_time:4861ms step_avg:86.80ms +step:57/1680 train_time:4948ms step_avg:86.81ms +step:58/1680 train_time:5036ms step_avg:86.83ms +step:59/1680 train_time:5124ms step_avg:86.85ms +step:60/1680 train_time:5211ms step_avg:86.85ms +step:61/1680 train_time:5299ms step_avg:86.87ms +step:62/1680 train_time:5387ms step_avg:86.88ms +step:63/1680 train_time:5474ms step_avg:86.88ms +step:64/1680 train_time:5561ms step_avg:86.88ms +step:65/1680 train_time:5647ms step_avg:86.88ms +step:66/1680 train_time:5734ms step_avg:86.88ms +step:67/1680 train_time:5821ms step_avg:86.88ms +step:68/1680 train_time:5908ms step_avg:86.88ms +step:69/1680 train_time:5995ms step_avg:86.89ms +step:70/1680 train_time:6082ms step_avg:86.89ms +step:71/1680 train_time:6170ms step_avg:86.90ms +step:72/1680 train_time:6258ms step_avg:86.92ms +step:73/1680 train_time:6346ms step_avg:86.93ms +step:74/1680 train_time:6433ms step_avg:86.93ms +step:75/1680 train_time:6520ms step_avg:86.93ms +step:76/1680 train_time:6608ms step_avg:86.94ms +step:77/1680 train_time:6695ms step_avg:86.95ms +step:78/1680 train_time:6782ms step_avg:86.95ms +step:79/1680 train_time:6870ms step_avg:86.96ms +step:80/1680 train_time:6958ms step_avg:86.97ms +step:81/1680 train_time:7045ms step_avg:86.97ms +step:82/1680 train_time:7132ms step_avg:86.98ms +step:83/1680 train_time:7220ms step_avg:86.99ms +step:84/1680 train_time:7307ms step_avg:86.99ms +step:85/1680 train_time:7395ms step_avg:87.00ms +step:86/1680 train_time:7482ms step_avg:87.00ms +step:87/1680 train_time:7569ms step_avg:87.00ms +step:88/1680 train_time:7657ms step_avg:87.01ms +step:89/1680 train_time:7744ms step_avg:87.01ms +step:90/1680 train_time:7831ms step_avg:87.01ms +step:91/1680 train_time:7919ms step_avg:87.02ms +step:92/1680 train_time:8006ms step_avg:87.02ms +step:93/1680 train_time:8094ms step_avg:87.03ms +step:94/1680 train_time:8181ms step_avg:87.03ms +step:95/1680 train_time:8268ms step_avg:87.03ms +step:96/1680 train_time:8356ms step_avg:87.04ms +step:97/1680 train_time:8443ms step_avg:87.04ms +step:98/1680 train_time:8530ms step_avg:87.04ms +step:99/1680 train_time:8617ms step_avg:87.04ms +step:100/1680 train_time:8705ms step_avg:87.05ms +step:101/1680 train_time:8792ms step_avg:87.05ms +step:102/1680 train_time:8880ms step_avg:87.06ms +step:103/1680 train_time:8967ms step_avg:87.06ms +step:104/1680 train_time:9054ms step_avg:87.06ms +step:105/1680 train_time:9142ms step_avg:87.06ms +step:106/1680 train_time:9229ms step_avg:87.06ms +step:107/1680 train_time:9316ms step_avg:87.06ms +step:108/1680 train_time:9404ms step_avg:87.07ms +step:109/1680 train_time:9491ms step_avg:87.07ms +step:110/1680 train_time:9579ms step_avg:87.08ms +step:111/1680 train_time:9666ms step_avg:87.08ms +step:112/1680 train_time:9753ms step_avg:87.08ms +step:113/1680 train_time:9841ms step_avg:87.09ms +step:114/1680 train_time:9928ms step_avg:87.09ms +step:115/1680 train_time:10015ms step_avg:87.09ms +step:116/1680 train_time:10102ms step_avg:87.09ms +step:117/1680 train_time:10189ms step_avg:87.09ms +step:118/1680 train_time:10276ms step_avg:87.09ms +step:119/1680 
train_time:10364ms step_avg:87.09ms +step:120/1680 train_time:10451ms step_avg:87.09ms +step:121/1680 train_time:10538ms step_avg:87.09ms +step:122/1680 train_time:10625ms step_avg:87.09ms +step:123/1680 train_time:10712ms step_avg:87.09ms +step:124/1680 train_time:10799ms step_avg:87.09ms +step:125/1680 train_time:10886ms step_avg:87.09ms +step:125/1680 val_loss:4.3356 train_time:10974ms step_avg:87.79ms +step:126/1680 train_time:10995ms step_avg:87.26ms +step:127/1680 train_time:11065ms step_avg:87.12ms +step:128/1680 train_time:11160ms step_avg:87.19ms +step:129/1680 train_time:11253ms step_avg:87.23ms +step:130/1680 train_time:11340ms step_avg:87.23ms +step:131/1680 train_time:11426ms step_avg:87.22ms +step:132/1680 train_time:11512ms step_avg:87.22ms +step:133/1680 train_time:11598ms step_avg:87.20ms +step:134/1680 train_time:11684ms step_avg:87.20ms +step:135/1680 train_time:11770ms step_avg:87.19ms +step:136/1680 train_time:11856ms step_avg:87.18ms +step:137/1680 train_time:11943ms step_avg:87.17ms +step:138/1680 train_time:12030ms step_avg:87.17ms +step:139/1680 train_time:12119ms step_avg:87.19ms +step:140/1680 train_time:12209ms step_avg:87.21ms +step:141/1680 train_time:12297ms step_avg:87.21ms +step:142/1680 train_time:12384ms step_avg:87.21ms +step:143/1680 train_time:12471ms step_avg:87.21ms +step:144/1680 train_time:12558ms step_avg:87.21ms +step:145/1680 train_time:12644ms step_avg:87.20ms +step:146/1680 train_time:12731ms step_avg:87.20ms +step:147/1680 train_time:12817ms step_avg:87.19ms +step:148/1680 train_time:12903ms step_avg:87.18ms +step:149/1680 train_time:12990ms step_avg:87.18ms +step:150/1680 train_time:13078ms step_avg:87.19ms +step:151/1680 train_time:13166ms step_avg:87.19ms +step:152/1680 train_time:13255ms step_avg:87.20ms +step:153/1680 train_time:13342ms step_avg:87.20ms +step:154/1680 train_time:13430ms step_avg:87.21ms +step:155/1680 train_time:13518ms step_avg:87.21ms +step:156/1680 train_time:13604ms step_avg:87.20ms +step:157/1680 train_time:13690ms step_avg:87.20ms +step:158/1680 train_time:13777ms step_avg:87.20ms +step:159/1680 train_time:13863ms step_avg:87.19ms +step:160/1680 train_time:13951ms step_avg:87.19ms +step:161/1680 train_time:14038ms step_avg:87.19ms +step:162/1680 train_time:14126ms step_avg:87.19ms +step:163/1680 train_time:14214ms step_avg:87.20ms +step:164/1680 train_time:14301ms step_avg:87.20ms +step:165/1680 train_time:14389ms step_avg:87.21ms +step:166/1680 train_time:14476ms step_avg:87.20ms +step:167/1680 train_time:14563ms step_avg:87.20ms +step:168/1680 train_time:14650ms step_avg:87.20ms +step:169/1680 train_time:14737ms step_avg:87.20ms +step:170/1680 train_time:14824ms step_avg:87.20ms +step:171/1680 train_time:14910ms step_avg:87.19ms +step:172/1680 train_time:14998ms step_avg:87.20ms +step:173/1680 train_time:15085ms step_avg:87.20ms +step:174/1680 train_time:15173ms step_avg:87.20ms +step:175/1680 train_time:15260ms step_avg:87.20ms +step:176/1680 train_time:15348ms step_avg:87.20ms +step:177/1680 train_time:15435ms step_avg:87.20ms +step:178/1680 train_time:15522ms step_avg:87.20ms +step:179/1680 train_time:15609ms step_avg:87.20ms +step:180/1680 train_time:15696ms step_avg:87.20ms +step:181/1680 train_time:15782ms step_avg:87.19ms +step:182/1680 train_time:15869ms step_avg:87.19ms +step:183/1680 train_time:15957ms step_avg:87.19ms +step:184/1680 train_time:16043ms step_avg:87.19ms +step:185/1680 train_time:16131ms step_avg:87.19ms +step:186/1680 train_time:16219ms step_avg:87.20ms +step:187/1680 train_time:16306ms 
step_avg:87.20ms +step:188/1680 train_time:16393ms step_avg:87.20ms +step:189/1680 train_time:16480ms step_avg:87.20ms +step:190/1680 train_time:16567ms step_avg:87.20ms +step:191/1680 train_time:16654ms step_avg:87.20ms +step:192/1680 train_time:16741ms step_avg:87.19ms +step:193/1680 train_time:16828ms step_avg:87.19ms +step:194/1680 train_time:16915ms step_avg:87.19ms +step:195/1680 train_time:17003ms step_avg:87.19ms +step:196/1680 train_time:17090ms step_avg:87.19ms +step:197/1680 train_time:17178ms step_avg:87.20ms +step:198/1680 train_time:17265ms step_avg:87.20ms +step:199/1680 train_time:17352ms step_avg:87.20ms +step:200/1680 train_time:17439ms step_avg:87.20ms +step:201/1680 train_time:17527ms step_avg:87.20ms +step:202/1680 train_time:17614ms step_avg:87.20ms +step:203/1680 train_time:17701ms step_avg:87.20ms +step:204/1680 train_time:17788ms step_avg:87.19ms +step:205/1680 train_time:17875ms step_avg:87.19ms +step:206/1680 train_time:17962ms step_avg:87.20ms +step:207/1680 train_time:18049ms step_avg:87.19ms +step:208/1680 train_time:18137ms step_avg:87.20ms +step:209/1680 train_time:18224ms step_avg:87.20ms +step:210/1680 train_time:18311ms step_avg:87.20ms +step:211/1680 train_time:18399ms step_avg:87.20ms +step:212/1680 train_time:18486ms step_avg:87.20ms +step:213/1680 train_time:18574ms step_avg:87.20ms +step:214/1680 train_time:18660ms step_avg:87.20ms +step:215/1680 train_time:18747ms step_avg:87.20ms +step:216/1680 train_time:18834ms step_avg:87.20ms +step:217/1680 train_time:18921ms step_avg:87.20ms +step:218/1680 train_time:19008ms step_avg:87.19ms +step:219/1680 train_time:19096ms step_avg:87.19ms +step:220/1680 train_time:19183ms step_avg:87.19ms +step:221/1680 train_time:19270ms step_avg:87.19ms +step:222/1680 train_time:19357ms step_avg:87.19ms +step:223/1680 train_time:19444ms step_avg:87.19ms +step:224/1680 train_time:19531ms step_avg:87.19ms +step:225/1680 train_time:19619ms step_avg:87.20ms +step:226/1680 train_time:19706ms step_avg:87.19ms +step:227/1680 train_time:19793ms step_avg:87.19ms +step:228/1680 train_time:19880ms step_avg:87.19ms +step:229/1680 train_time:19967ms step_avg:87.19ms +step:230/1680 train_time:20054ms step_avg:87.19ms +step:231/1680 train_time:20140ms step_avg:87.19ms +step:232/1680 train_time:20228ms step_avg:87.19ms +step:233/1680 train_time:20315ms step_avg:87.19ms +step:234/1680 train_time:20402ms step_avg:87.19ms +step:235/1680 train_time:20490ms step_avg:87.19ms +step:236/1680 train_time:20577ms step_avg:87.19ms +step:237/1680 train_time:20664ms step_avg:87.19ms +step:238/1680 train_time:20752ms step_avg:87.19ms +step:239/1680 train_time:20839ms step_avg:87.19ms +step:240/1680 train_time:20927ms step_avg:87.20ms +step:241/1680 train_time:21014ms step_avg:87.19ms +step:242/1680 train_time:21101ms step_avg:87.19ms +step:243/1680 train_time:21188ms step_avg:87.19ms +step:244/1680 train_time:21275ms step_avg:87.19ms +step:245/1680 train_time:21362ms step_avg:87.19ms +step:246/1680 train_time:21450ms step_avg:87.19ms +step:247/1680 train_time:21537ms step_avg:87.19ms +step:248/1680 train_time:21623ms step_avg:87.19ms +step:249/1680 train_time:21710ms step_avg:87.19ms +step:250/1680 train_time:21798ms step_avg:87.19ms +step:250/1680 val_loss:3.9780 train_time:21887ms step_avg:87.55ms +step:251/1680 train_time:21907ms step_avg:87.28ms +step:252/1680 train_time:21978ms step_avg:87.21ms +step:253/1680 train_time:22068ms step_avg:87.23ms +step:254/1680 train_time:22155ms step_avg:87.23ms +step:255/1680 train_time:22241ms step_avg:87.22ms 
+step:256/1680 train_time:22328ms step_avg:87.22ms +step:257/1680 train_time:22415ms step_avg:87.22ms +step:258/1680 train_time:22501ms step_avg:87.21ms +step:259/1680 train_time:22587ms step_avg:87.21ms +step:260/1680 train_time:22674ms step_avg:87.21ms +step:261/1680 train_time:22760ms step_avg:87.20ms +step:262/1680 train_time:22848ms step_avg:87.21ms +step:263/1680 train_time:22937ms step_avg:87.21ms +step:264/1680 train_time:23026ms step_avg:87.22ms +step:265/1680 train_time:23114ms step_avg:87.22ms +step:266/1680 train_time:23201ms step_avg:87.22ms +step:267/1680 train_time:23288ms step_avg:87.22ms +step:268/1680 train_time:23376ms step_avg:87.22ms +step:269/1680 train_time:23462ms step_avg:87.22ms +step:270/1680 train_time:23549ms step_avg:87.22ms +step:271/1680 train_time:23635ms step_avg:87.22ms +step:272/1680 train_time:23722ms step_avg:87.21ms +step:273/1680 train_time:23809ms step_avg:87.21ms +step:274/1680 train_time:23897ms step_avg:87.22ms +step:275/1680 train_time:23985ms step_avg:87.22ms +step:276/1680 train_time:24074ms step_avg:87.22ms +step:277/1680 train_time:24161ms step_avg:87.22ms +step:278/1680 train_time:24249ms step_avg:87.23ms +step:279/1680 train_time:24335ms step_avg:87.22ms +step:280/1680 train_time:24422ms step_avg:87.22ms +step:281/1680 train_time:24509ms step_avg:87.22ms +step:282/1680 train_time:24596ms step_avg:87.22ms +step:283/1680 train_time:24682ms step_avg:87.22ms +step:284/1680 train_time:24770ms step_avg:87.22ms +step:285/1680 train_time:24857ms step_avg:87.22ms +step:286/1680 train_time:24945ms step_avg:87.22ms +step:287/1680 train_time:25034ms step_avg:87.23ms +step:288/1680 train_time:25121ms step_avg:87.23ms +step:289/1680 train_time:25208ms step_avg:87.23ms +step:290/1680 train_time:25296ms step_avg:87.23ms +step:291/1680 train_time:25383ms step_avg:87.23ms +step:292/1680 train_time:25470ms step_avg:87.23ms +step:293/1680 train_time:25557ms step_avg:87.23ms +step:294/1680 train_time:25644ms step_avg:87.23ms +step:295/1680 train_time:25732ms step_avg:87.23ms +step:296/1680 train_time:25819ms step_avg:87.23ms +step:297/1680 train_time:25907ms step_avg:87.23ms +step:298/1680 train_time:25995ms step_avg:87.23ms +step:299/1680 train_time:26082ms step_avg:87.23ms +step:300/1680 train_time:26170ms step_avg:87.23ms +step:301/1680 train_time:26256ms step_avg:87.23ms +step:302/1680 train_time:26344ms step_avg:87.23ms +step:303/1680 train_time:26431ms step_avg:87.23ms +step:304/1680 train_time:26519ms step_avg:87.23ms +step:305/1680 train_time:26605ms step_avg:87.23ms +step:306/1680 train_time:26692ms step_avg:87.23ms +step:307/1680 train_time:26779ms step_avg:87.23ms +step:308/1680 train_time:26867ms step_avg:87.23ms +step:309/1680 train_time:26953ms step_avg:87.23ms +step:310/1680 train_time:27040ms step_avg:87.23ms +step:311/1680 train_time:27128ms step_avg:87.23ms +step:312/1680 train_time:27215ms step_avg:87.23ms +step:313/1680 train_time:27302ms step_avg:87.23ms +step:314/1680 train_time:27390ms step_avg:87.23ms +step:315/1680 train_time:27477ms step_avg:87.23ms +step:316/1680 train_time:27564ms step_avg:87.23ms +step:317/1680 train_time:27651ms step_avg:87.23ms +step:318/1680 train_time:27738ms step_avg:87.23ms +step:319/1680 train_time:27825ms step_avg:87.22ms +step:320/1680 train_time:27912ms step_avg:87.22ms +step:321/1680 train_time:27999ms step_avg:87.23ms +step:322/1680 train_time:28086ms step_avg:87.22ms +step:323/1680 train_time:28173ms step_avg:87.22ms +step:324/1680 train_time:28260ms step_avg:87.22ms +step:325/1680 train_time:28347ms 
step_avg:87.22ms +step:326/1680 train_time:28434ms step_avg:87.22ms +step:327/1680 train_time:28522ms step_avg:87.22ms +step:328/1680 train_time:28609ms step_avg:87.22ms +step:329/1680 train_time:28696ms step_avg:87.22ms +step:330/1680 train_time:28783ms step_avg:87.22ms +step:331/1680 train_time:28871ms step_avg:87.22ms +step:332/1680 train_time:28958ms step_avg:87.22ms +step:333/1680 train_time:29045ms step_avg:87.22ms +step:334/1680 train_time:29133ms step_avg:87.22ms +step:335/1680 train_time:29219ms step_avg:87.22ms +step:336/1680 train_time:29307ms step_avg:87.22ms +step:337/1680 train_time:29394ms step_avg:87.22ms +step:338/1680 train_time:29481ms step_avg:87.22ms +step:339/1680 train_time:29568ms step_avg:87.22ms +step:340/1680 train_time:29655ms step_avg:87.22ms +step:341/1680 train_time:29743ms step_avg:87.22ms +step:342/1680 train_time:29830ms step_avg:87.22ms +step:343/1680 train_time:29917ms step_avg:87.22ms +step:344/1680 train_time:30004ms step_avg:87.22ms +step:345/1680 train_time:30091ms step_avg:87.22ms +step:346/1680 train_time:30178ms step_avg:87.22ms +step:347/1680 train_time:30266ms step_avg:87.22ms +step:348/1680 train_time:30353ms step_avg:87.22ms +step:349/1680 train_time:30440ms step_avg:87.22ms +step:350/1680 train_time:30527ms step_avg:87.22ms +step:351/1680 train_time:30615ms step_avg:87.22ms +step:352/1680 train_time:30702ms step_avg:87.22ms +step:353/1680 train_time:30790ms step_avg:87.22ms +step:354/1680 train_time:30878ms step_avg:87.22ms +step:355/1680 train_time:30965ms step_avg:87.22ms +step:356/1680 train_time:31052ms step_avg:87.23ms +step:357/1680 train_time:31140ms step_avg:87.23ms +step:358/1680 train_time:31227ms step_avg:87.23ms +step:359/1680 train_time:31315ms step_avg:87.23ms +step:360/1680 train_time:31402ms step_avg:87.23ms +step:361/1680 train_time:31489ms step_avg:87.23ms +step:362/1680 train_time:31577ms step_avg:87.23ms +step:363/1680 train_time:31664ms step_avg:87.23ms +step:364/1680 train_time:31751ms step_avg:87.23ms +step:365/1680 train_time:31838ms step_avg:87.23ms +step:366/1680 train_time:31926ms step_avg:87.23ms +step:367/1680 train_time:32013ms step_avg:87.23ms +step:368/1680 train_time:32100ms step_avg:87.23ms +step:369/1680 train_time:32188ms step_avg:87.23ms +step:370/1680 train_time:32276ms step_avg:87.23ms +step:371/1680 train_time:32363ms step_avg:87.23ms +step:372/1680 train_time:32450ms step_avg:87.23ms +step:373/1680 train_time:32537ms step_avg:87.23ms +step:374/1680 train_time:32625ms step_avg:87.23ms +step:375/1680 train_time:32712ms step_avg:87.23ms +step:375/1680 val_loss:3.8205 train_time:32801ms step_avg:87.47ms +step:376/1680 train_time:32819ms step_avg:87.29ms +step:377/1680 train_time:32892ms step_avg:87.25ms +step:378/1680 train_time:32985ms step_avg:87.26ms +step:379/1680 train_time:33074ms step_avg:87.27ms +step:380/1680 train_time:33162ms step_avg:87.27ms +step:381/1680 train_time:33250ms step_avg:87.27ms +step:382/1680 train_time:33336ms step_avg:87.27ms +step:383/1680 train_time:33422ms step_avg:87.26ms +step:384/1680 train_time:33508ms step_avg:87.26ms +step:385/1680 train_time:33594ms step_avg:87.26ms +step:386/1680 train_time:33681ms step_avg:87.26ms +step:387/1680 train_time:33767ms step_avg:87.25ms +step:388/1680 train_time:33856ms step_avg:87.26ms +step:389/1680 train_time:33945ms step_avg:87.26ms +step:390/1680 train_time:34034ms step_avg:87.27ms +step:391/1680 train_time:34122ms step_avg:87.27ms +step:392/1680 train_time:34210ms step_avg:87.27ms +step:393/1680 train_time:34297ms step_avg:87.27ms 
+step:394/1680 train_time:34384ms step_avg:87.27ms +step:395/1680 train_time:34470ms step_avg:87.27ms +step:396/1680 train_time:34556ms step_avg:87.26ms +step:397/1680 train_time:34642ms step_avg:87.26ms +step:398/1680 train_time:34729ms step_avg:87.26ms +step:399/1680 train_time:34816ms step_avg:87.26ms +step:400/1680 train_time:34904ms step_avg:87.26ms +step:401/1680 train_time:34992ms step_avg:87.26ms +step:402/1680 train_time:35082ms step_avg:87.27ms +step:403/1680 train_time:35169ms step_avg:87.27ms +step:404/1680 train_time:35256ms step_avg:87.27ms +step:405/1680 train_time:35343ms step_avg:87.27ms +step:406/1680 train_time:35430ms step_avg:87.27ms +step:407/1680 train_time:35517ms step_avg:87.26ms +step:408/1680 train_time:35603ms step_avg:87.26ms +step:409/1680 train_time:35690ms step_avg:87.26ms +step:410/1680 train_time:35776ms step_avg:87.26ms +step:411/1680 train_time:35863ms step_avg:87.26ms +step:412/1680 train_time:35951ms step_avg:87.26ms +step:413/1680 train_time:36040ms step_avg:87.26ms +step:414/1680 train_time:36128ms step_avg:87.26ms +step:415/1680 train_time:36215ms step_avg:87.27ms +step:416/1680 train_time:36302ms step_avg:87.27ms +step:417/1680 train_time:36390ms step_avg:87.27ms +step:418/1680 train_time:36477ms step_avg:87.26ms +step:419/1680 train_time:36563ms step_avg:87.26ms +step:420/1680 train_time:36650ms step_avg:87.26ms +step:421/1680 train_time:36738ms step_avg:87.26ms +step:422/1680 train_time:36824ms step_avg:87.26ms +step:423/1680 train_time:36912ms step_avg:87.26ms +step:424/1680 train_time:37001ms step_avg:87.27ms +step:425/1680 train_time:37089ms step_avg:87.27ms +step:426/1680 train_time:37178ms step_avg:87.27ms +step:427/1680 train_time:37265ms step_avg:87.27ms +step:428/1680 train_time:37352ms step_avg:87.27ms +step:429/1680 train_time:37439ms step_avg:87.27ms +step:430/1680 train_time:37526ms step_avg:87.27ms +step:431/1680 train_time:37613ms step_avg:87.27ms +step:432/1680 train_time:37700ms step_avg:87.27ms +step:433/1680 train_time:37787ms step_avg:87.27ms +step:434/1680 train_time:37874ms step_avg:87.27ms +step:435/1680 train_time:37962ms step_avg:87.27ms +step:436/1680 train_time:38050ms step_avg:87.27ms +step:437/1680 train_time:38138ms step_avg:87.27ms +step:438/1680 train_time:38224ms step_avg:87.27ms +step:439/1680 train_time:38311ms step_avg:87.27ms +step:440/1680 train_time:38398ms step_avg:87.27ms +step:441/1680 train_time:38485ms step_avg:87.27ms +step:442/1680 train_time:38572ms step_avg:87.27ms +step:443/1680 train_time:38659ms step_avg:87.27ms +step:444/1680 train_time:38747ms step_avg:87.27ms +step:445/1680 train_time:38834ms step_avg:87.27ms +step:446/1680 train_time:38921ms step_avg:87.27ms +step:447/1680 train_time:39008ms step_avg:87.27ms +step:448/1680 train_time:39096ms step_avg:87.27ms +step:449/1680 train_time:39183ms step_avg:87.27ms +step:450/1680 train_time:39271ms step_avg:87.27ms +step:451/1680 train_time:39358ms step_avg:87.27ms +step:452/1680 train_time:39445ms step_avg:87.27ms +step:453/1680 train_time:39533ms step_avg:87.27ms +step:454/1680 train_time:39619ms step_avg:87.27ms +step:455/1680 train_time:39706ms step_avg:87.27ms +step:456/1680 train_time:39793ms step_avg:87.27ms +step:457/1680 train_time:39882ms step_avg:87.27ms +step:458/1680 train_time:39968ms step_avg:87.27ms +step:459/1680 train_time:40056ms step_avg:87.27ms +step:460/1680 train_time:40143ms step_avg:87.27ms +step:461/1680 train_time:40230ms step_avg:87.27ms +step:462/1680 train_time:40319ms step_avg:87.27ms +step:463/1680 train_time:40406ms 
step_avg:87.27ms +step:464/1680 train_time:40493ms step_avg:87.27ms +step:465/1680 train_time:40581ms step_avg:87.27ms +step:466/1680 train_time:40668ms step_avg:87.27ms +step:467/1680 train_time:40755ms step_avg:87.27ms +step:468/1680 train_time:40842ms step_avg:87.27ms +step:469/1680 train_time:40929ms step_avg:87.27ms +step:470/1680 train_time:41017ms step_avg:87.27ms +step:471/1680 train_time:41105ms step_avg:87.27ms +step:472/1680 train_time:41191ms step_avg:87.27ms +step:473/1680 train_time:41280ms step_avg:87.27ms +step:474/1680 train_time:41366ms step_avg:87.27ms +step:475/1680 train_time:41453ms step_avg:87.27ms +step:476/1680 train_time:41541ms step_avg:87.27ms +step:477/1680 train_time:41628ms step_avg:87.27ms +step:478/1680 train_time:41715ms step_avg:87.27ms +step:479/1680 train_time:41801ms step_avg:87.27ms +step:480/1680 train_time:41889ms step_avg:87.27ms +step:481/1680 train_time:41976ms step_avg:87.27ms +step:482/1680 train_time:42063ms step_avg:87.27ms +step:483/1680 train_time:42151ms step_avg:87.27ms +step:484/1680 train_time:42239ms step_avg:87.27ms +step:485/1680 train_time:42325ms step_avg:87.27ms +step:486/1680 train_time:42413ms step_avg:87.27ms +step:487/1680 train_time:42500ms step_avg:87.27ms +step:488/1680 train_time:42587ms step_avg:87.27ms +step:489/1680 train_time:42675ms step_avg:87.27ms +step:490/1680 train_time:42761ms step_avg:87.27ms +step:491/1680 train_time:42849ms step_avg:87.27ms +step:492/1680 train_time:42937ms step_avg:87.27ms +step:493/1680 train_time:43024ms step_avg:87.27ms +step:494/1680 train_time:43112ms step_avg:87.27ms +step:495/1680 train_time:43200ms step_avg:87.27ms +step:496/1680 train_time:43287ms step_avg:87.27ms +step:497/1680 train_time:43375ms step_avg:87.27ms +step:498/1680 train_time:43462ms step_avg:87.27ms +step:499/1680 train_time:43549ms step_avg:87.27ms +step:500/1680 train_time:43636ms step_avg:87.27ms +step:500/1680 val_loss:3.7191 train_time:43724ms step_avg:87.45ms +step:501/1680 train_time:43743ms step_avg:87.31ms +step:502/1680 train_time:43814ms step_avg:87.28ms +step:503/1680 train_time:43907ms step_avg:87.29ms +step:504/1680 train_time:43995ms step_avg:87.29ms +step:505/1680 train_time:44082ms step_avg:87.29ms +step:506/1680 train_time:44168ms step_avg:87.29ms +step:507/1680 train_time:44254ms step_avg:87.29ms +step:508/1680 train_time:44340ms step_avg:87.28ms +step:509/1680 train_time:44426ms step_avg:87.28ms +step:510/1680 train_time:44512ms step_avg:87.28ms +step:511/1680 train_time:44598ms step_avg:87.28ms +step:512/1680 train_time:44686ms step_avg:87.28ms +step:513/1680 train_time:44775ms step_avg:87.28ms +step:514/1680 train_time:44865ms step_avg:87.29ms +step:515/1680 train_time:44955ms step_avg:87.29ms +step:516/1680 train_time:45042ms step_avg:87.29ms +step:517/1680 train_time:45128ms step_avg:87.29ms +step:518/1680 train_time:45215ms step_avg:87.29ms +step:519/1680 train_time:45302ms step_avg:87.29ms +step:520/1680 train_time:45388ms step_avg:87.28ms +step:521/1680 train_time:45475ms step_avg:87.28ms +step:522/1680 train_time:45561ms step_avg:87.28ms +step:523/1680 train_time:45648ms step_avg:87.28ms +step:524/1680 train_time:45736ms step_avg:87.28ms +step:525/1680 train_time:45825ms step_avg:87.28ms +step:526/1680 train_time:45914ms step_avg:87.29ms +step:527/1680 train_time:46002ms step_avg:87.29ms +step:528/1680 train_time:46089ms step_avg:87.29ms +step:529/1680 train_time:46176ms step_avg:87.29ms +step:530/1680 train_time:46263ms step_avg:87.29ms +step:531/1680 train_time:46349ms step_avg:87.29ms 
+step:532/1680 train_time:46436ms step_avg:87.28ms +step:533/1680 train_time:46522ms step_avg:87.28ms +step:534/1680 train_time:46609ms step_avg:87.28ms +step:535/1680 train_time:46697ms step_avg:87.28ms +step:536/1680 train_time:46785ms step_avg:87.29ms +step:537/1680 train_time:46873ms step_avg:87.29ms +step:538/1680 train_time:46961ms step_avg:87.29ms +step:539/1680 train_time:47049ms step_avg:87.29ms +step:540/1680 train_time:47136ms step_avg:87.29ms +step:541/1680 train_time:47223ms step_avg:87.29ms +step:542/1680 train_time:47310ms step_avg:87.29ms +step:543/1680 train_time:47397ms step_avg:87.29ms +step:544/1680 train_time:47484ms step_avg:87.29ms +step:545/1680 train_time:47571ms step_avg:87.29ms +step:546/1680 train_time:47657ms step_avg:87.28ms +step:547/1680 train_time:47745ms step_avg:87.28ms +step:548/1680 train_time:47833ms step_avg:87.29ms +step:549/1680 train_time:47921ms step_avg:87.29ms +step:550/1680 train_time:48011ms step_avg:87.29ms +step:551/1680 train_time:48099ms step_avg:87.29ms +step:552/1680 train_time:48187ms step_avg:87.30ms +step:553/1680 train_time:48276ms step_avg:87.30ms +step:554/1680 train_time:48364ms step_avg:87.30ms +step:555/1680 train_time:48452ms step_avg:87.30ms +step:556/1680 train_time:48540ms step_avg:87.30ms +step:557/1680 train_time:48629ms step_avg:87.30ms +step:558/1680 train_time:48717ms step_avg:87.31ms +step:559/1680 train_time:48805ms step_avg:87.31ms +step:560/1680 train_time:48894ms step_avg:87.31ms +step:561/1680 train_time:48983ms step_avg:87.31ms +step:562/1680 train_time:49071ms step_avg:87.32ms +step:563/1680 train_time:49159ms step_avg:87.32ms +step:564/1680 train_time:49248ms step_avg:87.32ms +step:565/1680 train_time:49336ms step_avg:87.32ms +step:566/1680 train_time:49424ms step_avg:87.32ms +step:567/1680 train_time:49513ms step_avg:87.32ms +step:568/1680 train_time:49600ms step_avg:87.32ms +step:569/1680 train_time:49689ms step_avg:87.33ms +step:570/1680 train_time:49776ms step_avg:87.33ms +step:571/1680 train_time:49865ms step_avg:87.33ms +step:572/1680 train_time:49954ms step_avg:87.33ms +step:573/1680 train_time:50042ms step_avg:87.33ms +step:574/1680 train_time:50131ms step_avg:87.34ms +step:575/1680 train_time:50219ms step_avg:87.34ms +step:576/1680 train_time:50307ms step_avg:87.34ms +step:577/1680 train_time:50396ms step_avg:87.34ms +step:578/1680 train_time:50484ms step_avg:87.34ms +step:579/1680 train_time:50572ms step_avg:87.34ms +step:580/1680 train_time:50660ms step_avg:87.34ms +step:581/1680 train_time:50748ms step_avg:87.35ms +step:582/1680 train_time:50837ms step_avg:87.35ms +step:583/1680 train_time:50925ms step_avg:87.35ms +step:584/1680 train_time:51014ms step_avg:87.35ms +step:585/1680 train_time:51103ms step_avg:87.36ms +step:586/1680 train_time:51192ms step_avg:87.36ms +step:587/1680 train_time:51280ms step_avg:87.36ms +step:588/1680 train_time:51369ms step_avg:87.36ms +step:589/1680 train_time:51456ms step_avg:87.36ms +step:590/1680 train_time:51544ms step_avg:87.36ms +step:591/1680 train_time:51633ms step_avg:87.37ms +step:592/1680 train_time:51722ms step_avg:87.37ms +step:593/1680 train_time:51810ms step_avg:87.37ms +step:594/1680 train_time:51898ms step_avg:87.37ms +step:595/1680 train_time:51987ms step_avg:87.37ms +step:596/1680 train_time:52075ms step_avg:87.37ms +step:597/1680 train_time:52163ms step_avg:87.37ms +step:598/1680 train_time:52252ms step_avg:87.38ms +step:599/1680 train_time:52340ms step_avg:87.38ms +step:600/1680 train_time:52429ms step_avg:87.38ms +step:601/1680 train_time:52516ms 
step_avg:87.38ms +step:602/1680 train_time:52605ms step_avg:87.38ms +step:603/1680 train_time:52693ms step_avg:87.38ms +step:604/1680 train_time:52781ms step_avg:87.39ms +step:605/1680 train_time:52870ms step_avg:87.39ms +step:606/1680 train_time:52958ms step_avg:87.39ms +step:607/1680 train_time:53046ms step_avg:87.39ms +step:608/1680 train_time:53135ms step_avg:87.39ms +step:609/1680 train_time:53223ms step_avg:87.39ms +step:610/1680 train_time:53312ms step_avg:87.40ms +step:611/1680 train_time:53401ms step_avg:87.40ms +step:612/1680 train_time:53489ms step_avg:87.40ms +step:613/1680 train_time:53577ms step_avg:87.40ms +step:614/1680 train_time:53665ms step_avg:87.40ms +step:615/1680 train_time:53753ms step_avg:87.40ms +step:616/1680 train_time:53841ms step_avg:87.40ms +step:617/1680 train_time:53930ms step_avg:87.41ms +step:618/1680 train_time:54018ms step_avg:87.41ms +step:619/1680 train_time:54106ms step_avg:87.41ms +step:620/1680 train_time:54195ms step_avg:87.41ms +step:621/1680 train_time:54283ms step_avg:87.41ms +step:622/1680 train_time:54371ms step_avg:87.41ms +step:623/1680 train_time:54459ms step_avg:87.41ms +step:624/1680 train_time:54548ms step_avg:87.42ms +step:625/1680 train_time:54637ms step_avg:87.42ms +step:625/1680 val_loss:3.6179 train_time:54727ms step_avg:87.56ms +step:626/1680 train_time:54747ms step_avg:87.45ms +step:627/1680 train_time:54818ms step_avg:87.43ms +step:628/1680 train_time:54909ms step_avg:87.43ms +step:629/1680 train_time:54999ms step_avg:87.44ms +step:630/1680 train_time:55088ms step_avg:87.44ms +step:631/1680 train_time:55175ms step_avg:87.44ms +step:632/1680 train_time:55262ms step_avg:87.44ms +step:633/1680 train_time:55349ms step_avg:87.44ms +step:634/1680 train_time:55436ms step_avg:87.44ms +step:635/1680 train_time:55523ms step_avg:87.44ms +step:636/1680 train_time:55611ms step_avg:87.44ms +step:637/1680 train_time:55709ms step_avg:87.46ms +step:638/1680 train_time:55801ms step_avg:87.46ms +step:639/1680 train_time:55890ms step_avg:87.46ms +step:640/1680 train_time:55978ms step_avg:87.47ms +step:641/1680 train_time:56067ms step_avg:87.47ms +step:642/1680 train_time:56155ms step_avg:87.47ms +step:643/1680 train_time:56243ms step_avg:87.47ms +step:644/1680 train_time:56330ms step_avg:87.47ms +step:645/1680 train_time:56417ms step_avg:87.47ms +step:646/1680 train_time:56505ms step_avg:87.47ms +step:647/1680 train_time:56593ms step_avg:87.47ms +step:648/1680 train_time:56682ms step_avg:87.47ms +step:649/1680 train_time:56771ms step_avg:87.47ms +step:650/1680 train_time:56860ms step_avg:87.48ms +step:651/1680 train_time:56949ms step_avg:87.48ms +step:652/1680 train_time:57037ms step_avg:87.48ms +step:653/1680 train_time:57127ms step_avg:87.48ms +step:654/1680 train_time:57215ms step_avg:87.48ms +step:655/1680 train_time:57302ms step_avg:87.48ms +step:656/1680 train_time:57391ms step_avg:87.49ms +step:657/1680 train_time:57478ms step_avg:87.49ms +step:658/1680 train_time:57566ms step_avg:87.49ms +step:659/1680 train_time:57655ms step_avg:87.49ms +step:660/1680 train_time:57744ms step_avg:87.49ms +step:661/1680 train_time:57832ms step_avg:87.49ms +step:662/1680 train_time:57921ms step_avg:87.49ms +step:663/1680 train_time:58010ms step_avg:87.50ms +step:664/1680 train_time:58098ms step_avg:87.50ms +step:665/1680 train_time:58187ms step_avg:87.50ms +step:666/1680 train_time:58275ms step_avg:87.50ms +step:667/1680 train_time:58362ms step_avg:87.50ms +step:668/1680 train_time:58450ms step_avg:87.50ms +step:669/1680 train_time:58538ms step_avg:87.50ms 
+step:670/1680 train_time:58627ms step_avg:87.50ms +step:671/1680 train_time:58715ms step_avg:87.50ms +step:672/1680 train_time:58804ms step_avg:87.51ms +step:673/1680 train_time:58893ms step_avg:87.51ms +step:674/1680 train_time:58981ms step_avg:87.51ms +step:675/1680 train_time:59070ms step_avg:87.51ms +step:676/1680 train_time:59159ms step_avg:87.51ms +step:677/1680 train_time:59248ms step_avg:87.52ms +step:678/1680 train_time:59336ms step_avg:87.52ms +step:679/1680 train_time:59425ms step_avg:87.52ms +step:680/1680 train_time:59513ms step_avg:87.52ms +step:681/1680 train_time:59601ms step_avg:87.52ms +step:682/1680 train_time:59689ms step_avg:87.52ms +step:683/1680 train_time:59777ms step_avg:87.52ms +step:684/1680 train_time:59865ms step_avg:87.52ms +step:685/1680 train_time:59953ms step_avg:87.52ms +step:686/1680 train_time:60041ms step_avg:87.52ms +step:687/1680 train_time:60130ms step_avg:87.53ms +step:688/1680 train_time:60218ms step_avg:87.53ms +step:689/1680 train_time:60307ms step_avg:87.53ms +step:690/1680 train_time:60395ms step_avg:87.53ms +step:691/1680 train_time:60483ms step_avg:87.53ms +step:692/1680 train_time:60571ms step_avg:87.53ms +step:693/1680 train_time:60659ms step_avg:87.53ms +step:694/1680 train_time:60747ms step_avg:87.53ms +step:695/1680 train_time:60835ms step_avg:87.53ms +step:696/1680 train_time:60924ms step_avg:87.53ms +step:697/1680 train_time:61013ms step_avg:87.54ms +step:698/1680 train_time:61100ms step_avg:87.54ms +step:699/1680 train_time:61189ms step_avg:87.54ms +step:700/1680 train_time:61277ms step_avg:87.54ms +step:701/1680 train_time:61365ms step_avg:87.54ms +step:702/1680 train_time:61452ms step_avg:87.54ms +step:703/1680 train_time:61541ms step_avg:87.54ms +step:704/1680 train_time:61629ms step_avg:87.54ms +step:705/1680 train_time:61718ms step_avg:87.54ms +step:706/1680 train_time:61806ms step_avg:87.54ms +step:707/1680 train_time:61894ms step_avg:87.54ms +step:708/1680 train_time:61982ms step_avg:87.54ms +step:709/1680 train_time:62070ms step_avg:87.55ms +step:710/1680 train_time:62158ms step_avg:87.55ms +step:711/1680 train_time:62246ms step_avg:87.55ms +step:712/1680 train_time:62334ms step_avg:87.55ms +step:713/1680 train_time:62422ms step_avg:87.55ms +step:714/1680 train_time:62511ms step_avg:87.55ms +step:715/1680 train_time:62599ms step_avg:87.55ms +step:716/1680 train_time:62687ms step_avg:87.55ms +step:717/1680 train_time:62775ms step_avg:87.55ms +step:718/1680 train_time:62864ms step_avg:87.55ms +step:719/1680 train_time:62952ms step_avg:87.55ms +step:720/1680 train_time:63040ms step_avg:87.56ms +step:721/1680 train_time:63129ms step_avg:87.56ms +step:722/1680 train_time:63218ms step_avg:87.56ms +step:723/1680 train_time:63306ms step_avg:87.56ms +step:724/1680 train_time:63394ms step_avg:87.56ms +step:725/1680 train_time:63483ms step_avg:87.56ms +step:726/1680 train_time:63570ms step_avg:87.56ms +step:727/1680 train_time:63659ms step_avg:87.56ms +step:728/1680 train_time:63748ms step_avg:87.57ms +step:729/1680 train_time:63836ms step_avg:87.57ms +step:730/1680 train_time:63925ms step_avg:87.57ms +step:731/1680 train_time:64013ms step_avg:87.57ms +step:732/1680 train_time:64101ms step_avg:87.57ms +step:733/1680 train_time:64189ms step_avg:87.57ms +step:734/1680 train_time:64277ms step_avg:87.57ms +step:735/1680 train_time:64366ms step_avg:87.57ms +step:736/1680 train_time:64456ms step_avg:87.58ms +step:737/1680 train_time:64544ms step_avg:87.58ms +step:738/1680 train_time:64632ms step_avg:87.58ms +step:739/1680 train_time:64721ms 
step_avg:87.58ms +step:740/1680 train_time:64809ms step_avg:87.58ms +step:741/1680 train_time:64897ms step_avg:87.58ms +step:742/1680 train_time:64986ms step_avg:87.58ms +step:743/1680 train_time:65074ms step_avg:87.58ms +step:744/1680 train_time:65161ms step_avg:87.58ms +step:745/1680 train_time:65250ms step_avg:87.58ms +step:746/1680 train_time:65338ms step_avg:87.58ms +step:747/1680 train_time:65427ms step_avg:87.59ms +step:748/1680 train_time:65517ms step_avg:87.59ms +step:749/1680 train_time:65605ms step_avg:87.59ms +step:750/1680 train_time:65693ms step_avg:87.59ms +step:750/1680 val_loss:3.5674 train_time:65782ms step_avg:87.71ms +step:751/1680 train_time:65802ms step_avg:87.62ms +step:752/1680 train_time:65874ms step_avg:87.60ms +step:753/1680 train_time:65969ms step_avg:87.61ms +step:754/1680 train_time:66059ms step_avg:87.61ms +step:755/1680 train_time:66146ms step_avg:87.61ms +step:756/1680 train_time:66234ms step_avg:87.61ms +step:757/1680 train_time:66321ms step_avg:87.61ms +step:758/1680 train_time:66409ms step_avg:87.61ms +step:759/1680 train_time:66496ms step_avg:87.61ms +step:760/1680 train_time:66584ms step_avg:87.61ms +step:761/1680 train_time:66671ms step_avg:87.61ms +step:762/1680 train_time:66760ms step_avg:87.61ms +step:763/1680 train_time:66851ms step_avg:87.62ms +step:764/1680 train_time:66942ms step_avg:87.62ms +step:765/1680 train_time:67031ms step_avg:87.62ms +step:766/1680 train_time:67119ms step_avg:87.62ms +step:767/1680 train_time:67208ms step_avg:87.62ms +step:768/1680 train_time:67296ms step_avg:87.63ms +step:769/1680 train_time:67384ms step_avg:87.63ms +step:770/1680 train_time:67471ms step_avg:87.63ms +step:771/1680 train_time:67559ms step_avg:87.63ms +step:772/1680 train_time:67646ms step_avg:87.62ms +step:773/1680 train_time:67734ms step_avg:87.63ms +step:774/1680 train_time:67823ms step_avg:87.63ms +step:775/1680 train_time:67912ms step_avg:87.63ms +step:776/1680 train_time:68001ms step_avg:87.63ms +step:777/1680 train_time:68090ms step_avg:87.63ms +step:778/1680 train_time:68179ms step_avg:87.63ms +step:779/1680 train_time:68268ms step_avg:87.64ms +step:780/1680 train_time:68356ms step_avg:87.64ms +step:781/1680 train_time:68444ms step_avg:87.64ms +step:782/1680 train_time:68532ms step_avg:87.64ms +step:783/1680 train_time:68619ms step_avg:87.64ms +step:784/1680 train_time:68707ms step_avg:87.64ms +step:785/1680 train_time:68796ms step_avg:87.64ms +step:786/1680 train_time:68885ms step_avg:87.64ms +step:787/1680 train_time:68973ms step_avg:87.64ms +step:788/1680 train_time:69063ms step_avg:87.64ms +step:789/1680 train_time:69151ms step_avg:87.64ms +step:790/1680 train_time:69241ms step_avg:87.65ms +step:791/1680 train_time:69328ms step_avg:87.65ms +step:792/1680 train_time:69416ms step_avg:87.65ms +step:793/1680 train_time:69505ms step_avg:87.65ms +step:794/1680 train_time:69593ms step_avg:87.65ms +step:795/1680 train_time:69680ms step_avg:87.65ms +step:796/1680 train_time:69769ms step_avg:87.65ms +step:797/1680 train_time:69858ms step_avg:87.65ms +step:798/1680 train_time:69946ms step_avg:87.65ms +step:799/1680 train_time:70035ms step_avg:87.65ms +step:800/1680 train_time:70124ms step_avg:87.65ms +step:801/1680 train_time:70212ms step_avg:87.66ms +step:802/1680 train_time:70301ms step_avg:87.66ms +step:803/1680 train_time:70389ms step_avg:87.66ms +step:804/1680 train_time:70477ms step_avg:87.66ms +step:805/1680 train_time:70565ms step_avg:87.66ms +step:806/1680 train_time:70653ms step_avg:87.66ms +step:807/1680 train_time:70742ms step_avg:87.66ms 
+step:808/1680 train_time:70830ms step_avg:87.66ms +step:809/1680 train_time:70918ms step_avg:87.66ms +step:810/1680 train_time:71007ms step_avg:87.66ms +step:811/1680 train_time:71095ms step_avg:87.66ms +step:812/1680 train_time:71184ms step_avg:87.66ms +step:813/1680 train_time:71272ms step_avg:87.67ms +step:814/1680 train_time:71361ms step_avg:87.67ms +step:815/1680 train_time:71449ms step_avg:87.67ms +step:816/1680 train_time:71537ms step_avg:87.67ms +step:817/1680 train_time:71625ms step_avg:87.67ms +step:818/1680 train_time:71714ms step_avg:87.67ms +step:819/1680 train_time:71802ms step_avg:87.67ms +step:820/1680 train_time:71890ms step_avg:87.67ms +step:821/1680 train_time:71979ms step_avg:87.67ms +step:822/1680 train_time:72067ms step_avg:87.67ms +step:823/1680 train_time:72156ms step_avg:87.67ms +step:824/1680 train_time:72245ms step_avg:87.68ms +step:825/1680 train_time:72332ms step_avg:87.68ms +step:826/1680 train_time:72421ms step_avg:87.68ms +step:827/1680 train_time:72509ms step_avg:87.68ms +step:828/1680 train_time:72597ms step_avg:87.68ms +step:829/1680 train_time:72686ms step_avg:87.68ms +step:830/1680 train_time:72774ms step_avg:87.68ms +step:831/1680 train_time:72862ms step_avg:87.68ms +step:832/1680 train_time:72951ms step_avg:87.68ms +step:833/1680 train_time:73039ms step_avg:87.68ms +step:834/1680 train_time:73128ms step_avg:87.68ms +step:835/1680 train_time:73216ms step_avg:87.68ms +step:836/1680 train_time:73305ms step_avg:87.68ms +step:837/1680 train_time:73393ms step_avg:87.69ms +step:838/1680 train_time:73481ms step_avg:87.69ms +step:839/1680 train_time:73569ms step_avg:87.69ms +step:840/1680 train_time:73658ms step_avg:87.69ms +step:841/1680 train_time:73746ms step_avg:87.69ms +step:842/1680 train_time:73835ms step_avg:87.69ms +step:843/1680 train_time:73923ms step_avg:87.69ms +step:844/1680 train_time:74011ms step_avg:87.69ms +step:845/1680 train_time:74100ms step_avg:87.69ms +step:846/1680 train_time:74189ms step_avg:87.69ms +step:847/1680 train_time:74278ms step_avg:87.69ms +step:848/1680 train_time:74366ms step_avg:87.70ms +step:849/1680 train_time:74454ms step_avg:87.70ms +step:850/1680 train_time:74542ms step_avg:87.70ms +step:851/1680 train_time:74630ms step_avg:87.70ms +step:852/1680 train_time:74718ms step_avg:87.70ms +step:853/1680 train_time:74807ms step_avg:87.70ms +step:854/1680 train_time:74895ms step_avg:87.70ms +step:855/1680 train_time:74984ms step_avg:87.70ms +step:856/1680 train_time:75072ms step_avg:87.70ms +step:857/1680 train_time:75160ms step_avg:87.70ms +step:858/1680 train_time:75249ms step_avg:87.70ms +step:859/1680 train_time:75337ms step_avg:87.70ms +step:860/1680 train_time:75426ms step_avg:87.70ms +step:861/1680 train_time:75513ms step_avg:87.70ms +step:862/1680 train_time:75601ms step_avg:87.70ms +step:863/1680 train_time:75689ms step_avg:87.70ms +step:864/1680 train_time:75777ms step_avg:87.71ms +step:865/1680 train_time:75866ms step_avg:87.71ms +step:866/1680 train_time:75954ms step_avg:87.71ms +step:867/1680 train_time:76043ms step_avg:87.71ms +step:868/1680 train_time:76131ms step_avg:87.71ms +step:869/1680 train_time:76220ms step_avg:87.71ms +step:870/1680 train_time:76308ms step_avg:87.71ms +step:871/1680 train_time:76396ms step_avg:87.71ms +step:872/1680 train_time:76485ms step_avg:87.71ms +step:873/1680 train_time:76573ms step_avg:87.71ms +step:874/1680 train_time:76662ms step_avg:87.71ms +step:875/1680 train_time:76750ms step_avg:87.71ms +step:875/1680 val_loss:3.5193 train_time:76840ms step_avg:87.82ms +step:876/1680 
train_time:76859ms step_avg:87.74ms +step:877/1680 train_time:76933ms step_avg:87.72ms +step:878/1680 train_time:77026ms step_avg:87.73ms +step:879/1680 train_time:77117ms step_avg:87.73ms +step:880/1680 train_time:77205ms step_avg:87.73ms +step:881/1680 train_time:77292ms step_avg:87.73ms +step:882/1680 train_time:77380ms step_avg:87.73ms +step:883/1680 train_time:77467ms step_avg:87.73ms +step:884/1680 train_time:77554ms step_avg:87.73ms +step:885/1680 train_time:77641ms step_avg:87.73ms +step:886/1680 train_time:77728ms step_avg:87.73ms +step:887/1680 train_time:77816ms step_avg:87.73ms +step:888/1680 train_time:77907ms step_avg:87.73ms +step:889/1680 train_time:77998ms step_avg:87.74ms +step:890/1680 train_time:78087ms step_avg:87.74ms +step:891/1680 train_time:78177ms step_avg:87.74ms +step:892/1680 train_time:78267ms step_avg:87.74ms +step:893/1680 train_time:78354ms step_avg:87.74ms +step:894/1680 train_time:78441ms step_avg:87.74ms +step:895/1680 train_time:78529ms step_avg:87.74ms +step:896/1680 train_time:78617ms step_avg:87.74ms +step:897/1680 train_time:78704ms step_avg:87.74ms +step:898/1680 train_time:78792ms step_avg:87.74ms +step:899/1680 train_time:78881ms step_avg:87.74ms +step:900/1680 train_time:78972ms step_avg:87.75ms +step:901/1680 train_time:79062ms step_avg:87.75ms +step:902/1680 train_time:79151ms step_avg:87.75ms +step:903/1680 train_time:79241ms step_avg:87.75ms +step:904/1680 train_time:79329ms step_avg:87.75ms +step:905/1680 train_time:79416ms step_avg:87.75ms +step:906/1680 train_time:79504ms step_avg:87.75ms +step:907/1680 train_time:79592ms step_avg:87.75ms +step:908/1680 train_time:79680ms step_avg:87.75ms +step:909/1680 train_time:79768ms step_avg:87.75ms +step:910/1680 train_time:79857ms step_avg:87.75ms +step:911/1680 train_time:79945ms step_avg:87.76ms +step:912/1680 train_time:80034ms step_avg:87.76ms +step:913/1680 train_time:80123ms step_avg:87.76ms +step:914/1680 train_time:80212ms step_avg:87.76ms +step:915/1680 train_time:80301ms step_avg:87.76ms +step:916/1680 train_time:80388ms step_avg:87.76ms +step:917/1680 train_time:80476ms step_avg:87.76ms +step:918/1680 train_time:80564ms step_avg:87.76ms +step:919/1680 train_time:80652ms step_avg:87.76ms +step:920/1680 train_time:80740ms step_avg:87.76ms +step:921/1680 train_time:80828ms step_avg:87.76ms +step:922/1680 train_time:80917ms step_avg:87.76ms +step:923/1680 train_time:81006ms step_avg:87.76ms +step:924/1680 train_time:81094ms step_avg:87.76ms +step:925/1680 train_time:81183ms step_avg:87.77ms +step:926/1680 train_time:81272ms step_avg:87.77ms +step:927/1680 train_time:81361ms step_avg:87.77ms +step:928/1680 train_time:81449ms step_avg:87.77ms +step:929/1680 train_time:81537ms step_avg:87.77ms +step:930/1680 train_time:81625ms step_avg:87.77ms +step:931/1680 train_time:81712ms step_avg:87.77ms +step:932/1680 train_time:81800ms step_avg:87.77ms +step:933/1680 train_time:81889ms step_avg:87.77ms +step:934/1680 train_time:81977ms step_avg:87.77ms +step:935/1680 train_time:82066ms step_avg:87.77ms +step:936/1680 train_time:82155ms step_avg:87.77ms +step:937/1680 train_time:82243ms step_avg:87.77ms +step:938/1680 train_time:82331ms step_avg:87.77ms +step:939/1680 train_time:82420ms step_avg:87.77ms +step:940/1680 train_time:82508ms step_avg:87.77ms +step:941/1680 train_time:82596ms step_avg:87.77ms +step:942/1680 train_time:82684ms step_avg:87.77ms +step:943/1680 train_time:82772ms step_avg:87.78ms +step:944/1680 train_time:82860ms step_avg:87.78ms +step:945/1680 train_time:82949ms step_avg:87.78ms 
+step:946/1680 train_time:83037ms step_avg:87.78ms +step:947/1680 train_time:83126ms step_avg:87.78ms +step:948/1680 train_time:83215ms step_avg:87.78ms +step:949/1680 train_time:83304ms step_avg:87.78ms +step:950/1680 train_time:83391ms step_avg:87.78ms +step:951/1680 train_time:83480ms step_avg:87.78ms +step:952/1680 train_time:83568ms step_avg:87.78ms +step:953/1680 train_time:83657ms step_avg:87.78ms +step:954/1680 train_time:83745ms step_avg:87.78ms +step:955/1680 train_time:83833ms step_avg:87.78ms +step:956/1680 train_time:83921ms step_avg:87.78ms +step:957/1680 train_time:84010ms step_avg:87.79ms +step:958/1680 train_time:84099ms step_avg:87.79ms +step:959/1680 train_time:84187ms step_avg:87.79ms +step:960/1680 train_time:84276ms step_avg:87.79ms +step:961/1680 train_time:84364ms step_avg:87.79ms +step:962/1680 train_time:84452ms step_avg:87.79ms +step:963/1680 train_time:84540ms step_avg:87.79ms +step:964/1680 train_time:84629ms step_avg:87.79ms +step:965/1680 train_time:84718ms step_avg:87.79ms +step:966/1680 train_time:84806ms step_avg:87.79ms +step:967/1680 train_time:84894ms step_avg:87.79ms +step:968/1680 train_time:84982ms step_avg:87.79ms +step:969/1680 train_time:85070ms step_avg:87.79ms +step:970/1680 train_time:85159ms step_avg:87.79ms +step:971/1680 train_time:85248ms step_avg:87.79ms +step:972/1680 train_time:85336ms step_avg:87.79ms +step:973/1680 train_time:85424ms step_avg:87.79ms +step:974/1680 train_time:85512ms step_avg:87.80ms +step:975/1680 train_time:85601ms step_avg:87.80ms +step:976/1680 train_time:85689ms step_avg:87.80ms +step:977/1680 train_time:85777ms step_avg:87.80ms +step:978/1680 train_time:85864ms step_avg:87.80ms +step:979/1680 train_time:85953ms step_avg:87.80ms +step:980/1680 train_time:86041ms step_avg:87.80ms +step:981/1680 train_time:86131ms step_avg:87.80ms +step:982/1680 train_time:86219ms step_avg:87.80ms +step:983/1680 train_time:86308ms step_avg:87.80ms +step:984/1680 train_time:86396ms step_avg:87.80ms +step:985/1680 train_time:86484ms step_avg:87.80ms +step:986/1680 train_time:86572ms step_avg:87.80ms +step:987/1680 train_time:86661ms step_avg:87.80ms +step:988/1680 train_time:86750ms step_avg:87.80ms +step:989/1680 train_time:86838ms step_avg:87.80ms +step:990/1680 train_time:86926ms step_avg:87.80ms +step:991/1680 train_time:87015ms step_avg:87.80ms +step:992/1680 train_time:87103ms step_avg:87.81ms +step:993/1680 train_time:87192ms step_avg:87.81ms +step:994/1680 train_time:87280ms step_avg:87.81ms +step:995/1680 train_time:87368ms step_avg:87.81ms +step:996/1680 train_time:87457ms step_avg:87.81ms +step:997/1680 train_time:87546ms step_avg:87.81ms +step:998/1680 train_time:87634ms step_avg:87.81ms +step:999/1680 train_time:87722ms step_avg:87.81ms +step:1000/1680 train_time:87811ms step_avg:87.81ms +step:1000/1680 val_loss:3.4694 train_time:87900ms step_avg:87.90ms +step:1001/1680 train_time:87919ms step_avg:87.83ms +step:1002/1680 train_time:87992ms step_avg:87.82ms +step:1003/1680 train_time:88086ms step_avg:87.82ms +step:1004/1680 train_time:88175ms step_avg:87.82ms +step:1005/1680 train_time:88263ms step_avg:87.82ms +step:1006/1680 train_time:88350ms step_avg:87.82ms +step:1007/1680 train_time:88437ms step_avg:87.82ms +step:1008/1680 train_time:88524ms step_avg:87.82ms +step:1009/1680 train_time:88611ms step_avg:87.82ms +step:1010/1680 train_time:88699ms step_avg:87.82ms +step:1011/1680 train_time:88785ms step_avg:87.82ms +step:1012/1680 train_time:88874ms step_avg:87.82ms +step:1013/1680 train_time:88965ms step_avg:87.82ms 
+step:1014/1680 train_time:89056ms step_avg:87.83ms +step:1015/1680 train_time:89146ms step_avg:87.83ms +step:1016/1680 train_time:89234ms step_avg:87.83ms +step:1017/1680 train_time:89323ms step_avg:87.83ms +step:1018/1680 train_time:89410ms step_avg:87.83ms +step:1019/1680 train_time:89498ms step_avg:87.83ms +step:1020/1680 train_time:89585ms step_avg:87.83ms +step:1021/1680 train_time:89672ms step_avg:87.83ms +step:1022/1680 train_time:89760ms step_avg:87.83ms +step:1023/1680 train_time:89848ms step_avg:87.83ms +step:1024/1680 train_time:89937ms step_avg:87.83ms +step:1025/1680 train_time:90026ms step_avg:87.83ms +step:1026/1680 train_time:90115ms step_avg:87.83ms +step:1027/1680 train_time:90203ms step_avg:87.83ms +step:1028/1680 train_time:90292ms step_avg:87.83ms +step:1029/1680 train_time:90380ms step_avg:87.83ms +step:1030/1680 train_time:90468ms step_avg:87.83ms +step:1031/1680 train_time:90555ms step_avg:87.83ms +step:1032/1680 train_time:90643ms step_avg:87.83ms +step:1033/1680 train_time:90730ms step_avg:87.83ms +step:1034/1680 train_time:90818ms step_avg:87.83ms +step:1035/1680 train_time:90907ms step_avg:87.83ms +step:1036/1680 train_time:90996ms step_avg:87.83ms +step:1037/1680 train_time:91086ms step_avg:87.84ms +step:1038/1680 train_time:91175ms step_avg:87.84ms +step:1039/1680 train_time:91263ms step_avg:87.84ms +step:1040/1680 train_time:91351ms step_avg:87.84ms +step:1041/1680 train_time:91440ms step_avg:87.84ms +step:1042/1680 train_time:91528ms step_avg:87.84ms +step:1043/1680 train_time:91616ms step_avg:87.84ms +step:1044/1680 train_time:91703ms step_avg:87.84ms +step:1045/1680 train_time:91791ms step_avg:87.84ms +step:1046/1680 train_time:91880ms step_avg:87.84ms +step:1047/1680 train_time:91969ms step_avg:87.84ms +step:1048/1680 train_time:92058ms step_avg:87.84ms +step:1049/1680 train_time:92148ms step_avg:87.84ms +step:1050/1680 train_time:92236ms step_avg:87.84ms +step:1051/1680 train_time:92325ms step_avg:87.84ms +step:1052/1680 train_time:92413ms step_avg:87.85ms +step:1053/1680 train_time:92501ms step_avg:87.84ms +step:1054/1680 train_time:92589ms step_avg:87.85ms +step:1055/1680 train_time:92677ms step_avg:87.85ms +step:1056/1680 train_time:92765ms step_avg:87.85ms +step:1057/1680 train_time:92854ms step_avg:87.85ms +step:1058/1680 train_time:92942ms step_avg:87.85ms +step:1059/1680 train_time:93031ms step_avg:87.85ms +step:1060/1680 train_time:93120ms step_avg:87.85ms +step:1061/1680 train_time:93208ms step_avg:87.85ms +step:1062/1680 train_time:93297ms step_avg:87.85ms +step:1063/1680 train_time:93385ms step_avg:87.85ms +step:1064/1680 train_time:93474ms step_avg:87.85ms +step:1065/1680 train_time:93562ms step_avg:87.85ms +step:1066/1680 train_time:93649ms step_avg:87.85ms +step:1067/1680 train_time:93738ms step_avg:87.85ms +step:1068/1680 train_time:93826ms step_avg:87.85ms +step:1069/1680 train_time:93914ms step_avg:87.85ms +step:1070/1680 train_time:94002ms step_avg:87.85ms +step:1071/1680 train_time:94091ms step_avg:87.85ms +step:1072/1680 train_time:94180ms step_avg:87.85ms +step:1073/1680 train_time:94269ms step_avg:87.86ms +step:1074/1680 train_time:94357ms step_avg:87.86ms +step:1075/1680 train_time:94445ms step_avg:87.86ms +step:1076/1680 train_time:94533ms step_avg:87.86ms +step:1077/1680 train_time:94621ms step_avg:87.86ms +step:1078/1680 train_time:94710ms step_avg:87.86ms +step:1079/1680 train_time:94798ms step_avg:87.86ms +step:1080/1680 train_time:94886ms step_avg:87.86ms +step:1081/1680 train_time:94974ms step_avg:87.86ms +step:1082/1680 
train_time:95062ms step_avg:87.86ms +step:1083/1680 train_time:95151ms step_avg:87.86ms +step:1084/1680 train_time:95239ms step_avg:87.86ms +step:1085/1680 train_time:95327ms step_avg:87.86ms +step:1086/1680 train_time:95416ms step_avg:87.86ms +step:1087/1680 train_time:95504ms step_avg:87.86ms +step:1088/1680 train_time:95592ms step_avg:87.86ms +step:1089/1680 train_time:95680ms step_avg:87.86ms +step:1090/1680 train_time:95769ms step_avg:87.86ms +step:1091/1680 train_time:95857ms step_avg:87.86ms +step:1092/1680 train_time:95945ms step_avg:87.86ms +step:1093/1680 train_time:96033ms step_avg:87.86ms +step:1094/1680 train_time:96121ms step_avg:87.86ms +step:1095/1680 train_time:96210ms step_avg:87.86ms +step:1096/1680 train_time:96299ms step_avg:87.86ms +step:1097/1680 train_time:96389ms step_avg:87.87ms +step:1098/1680 train_time:96478ms step_avg:87.87ms +step:1099/1680 train_time:96568ms step_avg:87.87ms +step:1100/1680 train_time:96657ms step_avg:87.87ms +step:1101/1680 train_time:96746ms step_avg:87.87ms +step:1102/1680 train_time:96834ms step_avg:87.87ms +step:1103/1680 train_time:96923ms step_avg:87.87ms +step:1104/1680 train_time:97012ms step_avg:87.87ms +step:1105/1680 train_time:97101ms step_avg:87.87ms +step:1106/1680 train_time:97191ms step_avg:87.88ms +step:1107/1680 train_time:97281ms step_avg:87.88ms +step:1108/1680 train_time:97370ms step_avg:87.88ms +step:1109/1680 train_time:97459ms step_avg:87.88ms +step:1110/1680 train_time:97548ms step_avg:87.88ms +step:1111/1680 train_time:97637ms step_avg:87.88ms +step:1112/1680 train_time:97728ms step_avg:87.88ms +step:1113/1680 train_time:97817ms step_avg:87.89ms +step:1114/1680 train_time:97906ms step_avg:87.89ms +step:1115/1680 train_time:97995ms step_avg:87.89ms +step:1116/1680 train_time:98086ms step_avg:87.89ms +step:1117/1680 train_time:98175ms step_avg:87.89ms +step:1118/1680 train_time:98264ms step_avg:87.89ms +step:1119/1680 train_time:98354ms step_avg:87.89ms +step:1120/1680 train_time:98442ms step_avg:87.89ms +step:1121/1680 train_time:98531ms step_avg:87.90ms +step:1122/1680 train_time:98620ms step_avg:87.90ms +step:1123/1680 train_time:98710ms step_avg:87.90ms +step:1124/1680 train_time:98799ms step_avg:87.90ms +step:1125/1680 train_time:98888ms step_avg:87.90ms +step:1125/1680 val_loss:3.4160 train_time:98979ms step_avg:87.98ms +step:1126/1680 train_time:98998ms step_avg:87.92ms +step:1127/1680 train_time:99070ms step_avg:87.91ms +step:1128/1680 train_time:99160ms step_avg:87.91ms +step:1129/1680 train_time:99250ms step_avg:87.91ms +step:1130/1680 train_time:99339ms step_avg:87.91ms +step:1131/1680 train_time:99428ms step_avg:87.91ms +step:1132/1680 train_time:99516ms step_avg:87.91ms +step:1133/1680 train_time:99605ms step_avg:87.91ms +step:1134/1680 train_time:99692ms step_avg:87.91ms +step:1135/1680 train_time:99780ms step_avg:87.91ms +step:1136/1680 train_time:99870ms step_avg:87.91ms +step:1137/1680 train_time:99961ms step_avg:87.92ms +step:1138/1680 train_time:100052ms step_avg:87.92ms +step:1139/1680 train_time:100143ms step_avg:87.92ms +step:1140/1680 train_time:100234ms step_avg:87.92ms +step:1141/1680 train_time:100323ms step_avg:87.93ms +step:1142/1680 train_time:100411ms step_avg:87.93ms +step:1143/1680 train_time:100500ms step_avg:87.93ms +step:1144/1680 train_time:100589ms step_avg:87.93ms +step:1145/1680 train_time:100677ms step_avg:87.93ms +step:1146/1680 train_time:100765ms step_avg:87.93ms +step:1147/1680 train_time:100854ms step_avg:87.93ms +step:1148/1680 train_time:100943ms step_avg:87.93ms 
+step:1149/1680 train_time:101033ms step_avg:87.93ms +step:1150/1680 train_time:101122ms step_avg:87.93ms +step:1151/1680 train_time:101212ms step_avg:87.93ms +step:1152/1680 train_time:101302ms step_avg:87.94ms +step:1153/1680 train_time:101390ms step_avg:87.94ms +step:1154/1680 train_time:101479ms step_avg:87.94ms +step:1155/1680 train_time:101568ms step_avg:87.94ms +step:1156/1680 train_time:101656ms step_avg:87.94ms +step:1157/1680 train_time:101746ms step_avg:87.94ms +step:1158/1680 train_time:101834ms step_avg:87.94ms +step:1159/1680 train_time:101924ms step_avg:87.94ms +step:1160/1680 train_time:102014ms step_avg:87.94ms +step:1161/1680 train_time:102104ms step_avg:87.94ms +step:1162/1680 train_time:102193ms step_avg:87.95ms +step:1163/1680 train_time:102282ms step_avg:87.95ms +step:1164/1680 train_time:102372ms step_avg:87.95ms +step:1165/1680 train_time:102461ms step_avg:87.95ms +step:1166/1680 train_time:102550ms step_avg:87.95ms +step:1167/1680 train_time:102639ms step_avg:87.95ms +step:1168/1680 train_time:102729ms step_avg:87.95ms +step:1169/1680 train_time:102817ms step_avg:87.95ms +step:1170/1680 train_time:102906ms step_avg:87.95ms +step:1171/1680 train_time:102994ms step_avg:87.95ms +step:1172/1680 train_time:103084ms step_avg:87.96ms +step:1173/1680 train_time:103173ms step_avg:87.96ms +step:1174/1680 train_time:103262ms step_avg:87.96ms +step:1175/1680 train_time:103352ms step_avg:87.96ms +step:1176/1680 train_time:103441ms step_avg:87.96ms +step:1177/1680 train_time:103530ms step_avg:87.96ms +step:1178/1680 train_time:103618ms step_avg:87.96ms +step:1179/1680 train_time:103706ms step_avg:87.96ms +step:1180/1680 train_time:103795ms step_avg:87.96ms +step:1181/1680 train_time:103884ms step_avg:87.96ms +step:1182/1680 train_time:103973ms step_avg:87.96ms +step:1183/1680 train_time:104062ms step_avg:87.96ms +step:1184/1680 train_time:104151ms step_avg:87.97ms +step:1185/1680 train_time:104240ms step_avg:87.97ms +step:1186/1680 train_time:104329ms step_avg:87.97ms +step:1187/1680 train_time:104418ms step_avg:87.97ms +step:1188/1680 train_time:104507ms step_avg:87.97ms +step:1189/1680 train_time:104596ms step_avg:87.97ms +step:1190/1680 train_time:104685ms step_avg:87.97ms +step:1191/1680 train_time:104773ms step_avg:87.97ms +step:1192/1680 train_time:104863ms step_avg:87.97ms +step:1193/1680 train_time:104952ms step_avg:87.97ms +step:1194/1680 train_time:105042ms step_avg:87.98ms +step:1195/1680 train_time:105132ms step_avg:87.98ms +step:1196/1680 train_time:105221ms step_avg:87.98ms +step:1197/1680 train_time:105309ms step_avg:87.98ms +step:1198/1680 train_time:105398ms step_avg:87.98ms +step:1199/1680 train_time:105487ms step_avg:87.98ms +step:1200/1680 train_time:105576ms step_avg:87.98ms +step:1201/1680 train_time:105665ms step_avg:87.98ms +step:1202/1680 train_time:105754ms step_avg:87.98ms +step:1203/1680 train_time:105844ms step_avg:87.98ms +step:1204/1680 train_time:105934ms step_avg:87.99ms +step:1205/1680 train_time:106023ms step_avg:87.99ms +step:1206/1680 train_time:106112ms step_avg:87.99ms +step:1207/1680 train_time:106202ms step_avg:87.99ms +step:1208/1680 train_time:106290ms step_avg:87.99ms +step:1209/1680 train_time:106379ms step_avg:87.99ms +step:1210/1680 train_time:106468ms step_avg:87.99ms +step:1211/1680 train_time:106558ms step_avg:87.99ms +step:1212/1680 train_time:106646ms step_avg:87.99ms +step:1213/1680 train_time:106735ms step_avg:87.99ms +step:1214/1680 train_time:106825ms step_avg:87.99ms +step:1215/1680 train_time:106914ms step_avg:87.99ms 
+step:1216/1680 train_time:107003ms step_avg:88.00ms +step:1217/1680 train_time:107093ms step_avg:88.00ms +step:1218/1680 train_time:107182ms step_avg:88.00ms +step:1219/1680 train_time:107271ms step_avg:88.00ms +step:1220/1680 train_time:107361ms step_avg:88.00ms +step:1221/1680 train_time:107450ms step_avg:88.00ms +step:1222/1680 train_time:107539ms step_avg:88.00ms +step:1223/1680 train_time:107628ms step_avg:88.00ms +step:1224/1680 train_time:107717ms step_avg:88.00ms +step:1225/1680 train_time:107806ms step_avg:88.00ms +step:1226/1680 train_time:107895ms step_avg:88.01ms +step:1227/1680 train_time:107984ms step_avg:88.01ms +step:1228/1680 train_time:108074ms step_avg:88.01ms +step:1229/1680 train_time:108162ms step_avg:88.01ms +step:1230/1680 train_time:108252ms step_avg:88.01ms +step:1231/1680 train_time:108341ms step_avg:88.01ms +step:1232/1680 train_time:108431ms step_avg:88.01ms +step:1233/1680 train_time:108521ms step_avg:88.01ms +step:1234/1680 train_time:108610ms step_avg:88.01ms +step:1235/1680 train_time:108699ms step_avg:88.02ms +step:1236/1680 train_time:108789ms step_avg:88.02ms +step:1237/1680 train_time:108878ms step_avg:88.02ms +step:1238/1680 train_time:108967ms step_avg:88.02ms +step:1239/1680 train_time:109056ms step_avg:88.02ms +step:1240/1680 train_time:109145ms step_avg:88.02ms +step:1241/1680 train_time:109234ms step_avg:88.02ms +step:1242/1680 train_time:109323ms step_avg:88.02ms +step:1243/1680 train_time:109413ms step_avg:88.02ms +step:1244/1680 train_time:109504ms step_avg:88.03ms +step:1245/1680 train_time:109592ms step_avg:88.03ms +step:1246/1680 train_time:109681ms step_avg:88.03ms +step:1247/1680 train_time:109770ms step_avg:88.03ms +step:1248/1680 train_time:109859ms step_avg:88.03ms +step:1249/1680 train_time:109948ms step_avg:88.03ms +step:1250/1680 train_time:110037ms step_avg:88.03ms +step:1250/1680 val_loss:3.3779 train_time:110128ms step_avg:88.10ms +step:1251/1680 train_time:110147ms step_avg:88.05ms +step:1252/1680 train_time:110219ms step_avg:88.03ms +step:1253/1680 train_time:110314ms step_avg:88.04ms +step:1254/1680 train_time:110404ms step_avg:88.04ms +step:1255/1680 train_time:110492ms step_avg:88.04ms +step:1256/1680 train_time:110580ms step_avg:88.04ms +step:1257/1680 train_time:110669ms step_avg:88.04ms +step:1258/1680 train_time:110756ms step_avg:88.04ms +step:1259/1680 train_time:110844ms step_avg:88.04ms +step:1260/1680 train_time:110932ms step_avg:88.04ms +step:1261/1680 train_time:111020ms step_avg:88.04ms +step:1262/1680 train_time:111111ms step_avg:88.04ms +step:1263/1680 train_time:111202ms step_avg:88.05ms +step:1264/1680 train_time:111294ms step_avg:88.05ms +step:1265/1680 train_time:111385ms step_avg:88.05ms +step:1266/1680 train_time:111474ms step_avg:88.05ms +step:1267/1680 train_time:111562ms step_avg:88.05ms +step:1268/1680 train_time:111650ms step_avg:88.05ms +step:1269/1680 train_time:111739ms step_avg:88.05ms +step:1270/1680 train_time:111827ms step_avg:88.05ms +step:1271/1680 train_time:111915ms step_avg:88.05ms +step:1272/1680 train_time:112004ms step_avg:88.05ms +step:1273/1680 train_time:112093ms step_avg:88.05ms +step:1274/1680 train_time:112182ms step_avg:88.06ms +step:1275/1680 train_time:112273ms step_avg:88.06ms +step:1276/1680 train_time:112364ms step_avg:88.06ms +step:1277/1680 train_time:112453ms step_avg:88.06ms +step:1278/1680 train_time:112543ms step_avg:88.06ms +step:1279/1680 train_time:112632ms step_avg:88.06ms +step:1280/1680 train_time:112720ms step_avg:88.06ms +step:1281/1680 train_time:112808ms 
step_avg:88.06ms +step:1282/1680 train_time:112897ms step_avg:88.06ms +step:1283/1680 train_time:112986ms step_avg:88.06ms +step:1284/1680 train_time:113075ms step_avg:88.06ms +step:1285/1680 train_time:113165ms step_avg:88.07ms +step:1286/1680 train_time:113254ms step_avg:88.07ms +step:1287/1680 train_time:113344ms step_avg:88.07ms +step:1288/1680 train_time:113434ms step_avg:88.07ms +step:1289/1680 train_time:113524ms step_avg:88.07ms +step:1290/1680 train_time:113612ms step_avg:88.07ms +step:1291/1680 train_time:113702ms step_avg:88.07ms +step:1292/1680 train_time:113791ms step_avg:88.07ms +step:1293/1680 train_time:113879ms step_avg:88.07ms +step:1294/1680 train_time:113968ms step_avg:88.07ms +step:1295/1680 train_time:114057ms step_avg:88.07ms +step:1296/1680 train_time:114146ms step_avg:88.08ms +step:1297/1680 train_time:114235ms step_avg:88.08ms +step:1298/1680 train_time:114325ms step_avg:88.08ms +step:1299/1680 train_time:114415ms step_avg:88.08ms +step:1300/1680 train_time:114505ms step_avg:88.08ms +step:1301/1680 train_time:114594ms step_avg:88.08ms +step:1302/1680 train_time:114684ms step_avg:88.08ms +step:1303/1680 train_time:114773ms step_avg:88.08ms +step:1304/1680 train_time:114862ms step_avg:88.08ms +step:1305/1680 train_time:114951ms step_avg:88.09ms +step:1306/1680 train_time:115039ms step_avg:88.09ms +step:1307/1680 train_time:115128ms step_avg:88.09ms +step:1308/1680 train_time:115218ms step_avg:88.09ms +step:1309/1680 train_time:115307ms step_avg:88.09ms +step:1310/1680 train_time:115397ms step_avg:88.09ms +step:1311/1680 train_time:115486ms step_avg:88.09ms +step:1312/1680 train_time:115575ms step_avg:88.09ms +step:1313/1680 train_time:115663ms step_avg:88.09ms +step:1314/1680 train_time:115752ms step_avg:88.09ms +step:1315/1680 train_time:115841ms step_avg:88.09ms +step:1316/1680 train_time:115930ms step_avg:88.09ms +step:1317/1680 train_time:116018ms step_avg:88.09ms +step:1318/1680 train_time:116108ms step_avg:88.09ms +step:1319/1680 train_time:116196ms step_avg:88.09ms +step:1320/1680 train_time:116286ms step_avg:88.10ms +step:1321/1680 train_time:116376ms step_avg:88.10ms +step:1322/1680 train_time:116466ms step_avg:88.10ms +step:1323/1680 train_time:116556ms step_avg:88.10ms +step:1324/1680 train_time:116644ms step_avg:88.10ms +step:1325/1680 train_time:116734ms step_avg:88.10ms +step:1326/1680 train_time:116823ms step_avg:88.10ms +step:1327/1680 train_time:116911ms step_avg:88.10ms +step:1328/1680 train_time:117000ms step_avg:88.10ms +step:1329/1680 train_time:117091ms step_avg:88.10ms +step:1330/1680 train_time:117180ms step_avg:88.11ms +step:1331/1680 train_time:117270ms step_avg:88.11ms +step:1332/1680 train_time:117359ms step_avg:88.11ms +step:1333/1680 train_time:117448ms step_avg:88.11ms +step:1334/1680 train_time:117538ms step_avg:88.11ms +step:1335/1680 train_time:117627ms step_avg:88.11ms +step:1336/1680 train_time:117716ms step_avg:88.11ms +step:1337/1680 train_time:117805ms step_avg:88.11ms +step:1338/1680 train_time:117894ms step_avg:88.11ms +step:1339/1680 train_time:117984ms step_avg:88.11ms +step:1340/1680 train_time:118073ms step_avg:88.11ms +step:1341/1680 train_time:118162ms step_avg:88.11ms +step:1342/1680 train_time:118250ms step_avg:88.12ms +step:1343/1680 train_time:118339ms step_avg:88.12ms +step:1344/1680 train_time:118428ms step_avg:88.12ms +step:1345/1680 train_time:118517ms step_avg:88.12ms +step:1346/1680 train_time:118607ms step_avg:88.12ms +step:1347/1680 train_time:118696ms step_avg:88.12ms +step:1348/1680 train_time:118785ms 
step_avg:88.12ms +step:1349/1680 train_time:118875ms step_avg:88.12ms +step:1350/1680 train_time:118964ms step_avg:88.12ms +step:1351/1680 train_time:119053ms step_avg:88.12ms +step:1352/1680 train_time:119142ms step_avg:88.12ms +step:1353/1680 train_time:119231ms step_avg:88.12ms +step:1354/1680 train_time:119320ms step_avg:88.12ms +step:1355/1680 train_time:119410ms step_avg:88.13ms +step:1356/1680 train_time:119499ms step_avg:88.13ms +step:1357/1680 train_time:119589ms step_avg:88.13ms +step:1358/1680 train_time:119677ms step_avg:88.13ms +step:1359/1680 train_time:119767ms step_avg:88.13ms +step:1360/1680 train_time:119855ms step_avg:88.13ms +step:1361/1680 train_time:119944ms step_avg:88.13ms +step:1362/1680 train_time:120033ms step_avg:88.13ms +step:1363/1680 train_time:120122ms step_avg:88.13ms +step:1364/1680 train_time:120212ms step_avg:88.13ms +step:1365/1680 train_time:120301ms step_avg:88.13ms +step:1366/1680 train_time:120391ms step_avg:88.13ms +step:1367/1680 train_time:120480ms step_avg:88.13ms +step:1368/1680 train_time:120569ms step_avg:88.14ms +step:1369/1680 train_time:120659ms step_avg:88.14ms +step:1370/1680 train_time:120748ms step_avg:88.14ms +step:1371/1680 train_time:120837ms step_avg:88.14ms +step:1372/1680 train_time:120926ms step_avg:88.14ms +step:1373/1680 train_time:121016ms step_avg:88.14ms +step:1374/1680 train_time:121104ms step_avg:88.14ms +step:1375/1680 train_time:121194ms step_avg:88.14ms +step:1375/1680 val_loss:3.3434 train_time:121285ms step_avg:88.21ms +step:1376/1680 train_time:121303ms step_avg:88.16ms +step:1377/1680 train_time:121378ms step_avg:88.15ms +step:1378/1680 train_time:121473ms step_avg:88.15ms +step:1379/1680 train_time:121562ms step_avg:88.15ms +step:1380/1680 train_time:121650ms step_avg:88.15ms +step:1381/1680 train_time:121737ms step_avg:88.15ms +step:1382/1680 train_time:121825ms step_avg:88.15ms +step:1383/1680 train_time:121914ms step_avg:88.15ms +step:1384/1680 train_time:122001ms step_avg:88.15ms +step:1385/1680 train_time:122089ms step_avg:88.15ms +step:1386/1680 train_time:122177ms step_avg:88.15ms +step:1387/1680 train_time:122267ms step_avg:88.15ms +step:1388/1680 train_time:122358ms step_avg:88.15ms +step:1389/1680 train_time:122450ms step_avg:88.16ms +step:1390/1680 train_time:122542ms step_avg:88.16ms +step:1391/1680 train_time:122632ms step_avg:88.16ms +step:1392/1680 train_time:122720ms step_avg:88.16ms +step:1393/1680 train_time:122809ms step_avg:88.16ms +step:1394/1680 train_time:122897ms step_avg:88.16ms +step:1395/1680 train_time:122985ms step_avg:88.16ms +step:1396/1680 train_time:123073ms step_avg:88.16ms +step:1397/1680 train_time:123161ms step_avg:88.16ms +step:1398/1680 train_time:123250ms step_avg:88.16ms +step:1399/1680 train_time:123339ms step_avg:88.16ms +step:1400/1680 train_time:123430ms step_avg:88.16ms +step:1401/1680 train_time:123520ms step_avg:88.17ms +step:1402/1680 train_time:123610ms step_avg:88.17ms +step:1403/1680 train_time:123699ms step_avg:88.17ms +step:1404/1680 train_time:123788ms step_avg:88.17ms +step:1405/1680 train_time:123876ms step_avg:88.17ms +step:1406/1680 train_time:123965ms step_avg:88.17ms +step:1407/1680 train_time:124053ms step_avg:88.17ms +step:1408/1680 train_time:124141ms step_avg:88.17ms +step:1409/1680 train_time:124231ms step_avg:88.17ms +step:1410/1680 train_time:124320ms step_avg:88.17ms +step:1411/1680 train_time:124409ms step_avg:88.17ms +step:1412/1680 train_time:124499ms step_avg:88.17ms +step:1413/1680 train_time:124589ms step_avg:88.17ms +step:1414/1680 
train_time:124678ms step_avg:88.17ms +step:1415/1680 train_time:124768ms step_avg:88.18ms +step:1416/1680 train_time:124857ms step_avg:88.18ms +step:1417/1680 train_time:124945ms step_avg:88.18ms +step:1418/1680 train_time:125034ms step_avg:88.18ms +step:1419/1680 train_time:125123ms step_avg:88.18ms +step:1420/1680 train_time:125213ms step_avg:88.18ms +step:1421/1680 train_time:125301ms step_avg:88.18ms +step:1422/1680 train_time:125390ms step_avg:88.18ms +step:1423/1680 train_time:125480ms step_avg:88.18ms +step:1424/1680 train_time:125570ms step_avg:88.18ms +step:1425/1680 train_time:125659ms step_avg:88.18ms +step:1426/1680 train_time:125750ms step_avg:88.18ms +step:1427/1680 train_time:125839ms step_avg:88.18ms +step:1428/1680 train_time:125929ms step_avg:88.19ms +step:1429/1680 train_time:126018ms step_avg:88.19ms +step:1430/1680 train_time:126106ms step_avg:88.19ms +step:1431/1680 train_time:126195ms step_avg:88.19ms +step:1432/1680 train_time:126283ms step_avg:88.19ms +step:1433/1680 train_time:126373ms step_avg:88.19ms +step:1434/1680 train_time:126462ms step_avg:88.19ms +step:1435/1680 train_time:126553ms step_avg:88.19ms +step:1436/1680 train_time:126642ms step_avg:88.19ms +step:1437/1680 train_time:126731ms step_avg:88.19ms +step:1438/1680 train_time:126820ms step_avg:88.19ms +step:1439/1680 train_time:126910ms step_avg:88.19ms +step:1440/1680 train_time:126999ms step_avg:88.19ms +step:1441/1680 train_time:127088ms step_avg:88.19ms +step:1442/1680 train_time:127177ms step_avg:88.19ms +step:1443/1680 train_time:127266ms step_avg:88.20ms +step:1444/1680 train_time:127355ms step_avg:88.20ms +step:1445/1680 train_time:127445ms step_avg:88.20ms +step:1446/1680 train_time:127536ms step_avg:88.20ms +step:1447/1680 train_time:127626ms step_avg:88.20ms +step:1448/1680 train_time:127716ms step_avg:88.20ms +step:1449/1680 train_time:127805ms step_avg:88.20ms +step:1450/1680 train_time:127895ms step_avg:88.20ms +step:1451/1680 train_time:127984ms step_avg:88.20ms +step:1452/1680 train_time:128072ms step_avg:88.20ms +step:1453/1680 train_time:128161ms step_avg:88.20ms +step:1454/1680 train_time:128251ms step_avg:88.21ms +step:1455/1680 train_time:128340ms step_avg:88.21ms +step:1456/1680 train_time:128429ms step_avg:88.21ms +step:1457/1680 train_time:128518ms step_avg:88.21ms +step:1458/1680 train_time:128607ms step_avg:88.21ms +step:1459/1680 train_time:128697ms step_avg:88.21ms +step:1460/1680 train_time:128787ms step_avg:88.21ms +step:1461/1680 train_time:128876ms step_avg:88.21ms +step:1462/1680 train_time:128965ms step_avg:88.21ms +step:1463/1680 train_time:129053ms step_avg:88.21ms +step:1464/1680 train_time:129143ms step_avg:88.21ms +step:1465/1680 train_time:129232ms step_avg:88.21ms +step:1466/1680 train_time:129321ms step_avg:88.21ms +step:1467/1680 train_time:129411ms step_avg:88.21ms +step:1468/1680 train_time:129499ms step_avg:88.21ms +step:1469/1680 train_time:129589ms step_avg:88.22ms +step:1470/1680 train_time:129678ms step_avg:88.22ms +step:1471/1680 train_time:129767ms step_avg:88.22ms +step:1472/1680 train_time:129856ms step_avg:88.22ms +step:1473/1680 train_time:129946ms step_avg:88.22ms +step:1474/1680 train_time:130035ms step_avg:88.22ms +step:1475/1680 train_time:130125ms step_avg:88.22ms +step:1476/1680 train_time:130215ms step_avg:88.22ms +step:1477/1680 train_time:130305ms step_avg:88.22ms +step:1478/1680 train_time:130395ms step_avg:88.22ms +step:1479/1680 train_time:130485ms step_avg:88.22ms +step:1480/1680 train_time:130575ms step_avg:88.23ms +step:1481/1680 
train_time:130666ms step_avg:88.23ms +step:1482/1680 train_time:130755ms step_avg:88.23ms +step:1483/1680 train_time:130845ms step_avg:88.23ms +step:1484/1680 train_time:130934ms step_avg:88.23ms +step:1485/1680 train_time:131023ms step_avg:88.23ms +step:1486/1680 train_time:131113ms step_avg:88.23ms +step:1487/1680 train_time:131201ms step_avg:88.23ms +step:1488/1680 train_time:131290ms step_avg:88.23ms +step:1489/1680 train_time:131379ms step_avg:88.23ms +step:1490/1680 train_time:131468ms step_avg:88.23ms +step:1491/1680 train_time:131558ms step_avg:88.23ms +step:1492/1680 train_time:131647ms step_avg:88.24ms +step:1493/1680 train_time:131737ms step_avg:88.24ms +step:1494/1680 train_time:131827ms step_avg:88.24ms +step:1495/1680 train_time:131916ms step_avg:88.24ms +step:1496/1680 train_time:132006ms step_avg:88.24ms +step:1497/1680 train_time:132095ms step_avg:88.24ms +step:1498/1680 train_time:132184ms step_avg:88.24ms +step:1499/1680 train_time:132273ms step_avg:88.24ms +step:1500/1680 train_time:132361ms step_avg:88.24ms +step:1500/1680 val_loss:3.3132 train_time:132452ms step_avg:88.30ms +step:1501/1680 train_time:132472ms step_avg:88.26ms +step:1502/1680 train_time:132545ms step_avg:88.25ms +step:1503/1680 train_time:132637ms step_avg:88.25ms +step:1504/1680 train_time:132727ms step_avg:88.25ms +step:1505/1680 train_time:132815ms step_avg:88.25ms +step:1506/1680 train_time:132904ms step_avg:88.25ms +step:1507/1680 train_time:132991ms step_avg:88.25ms +step:1508/1680 train_time:133079ms step_avg:88.25ms +step:1509/1680 train_time:133167ms step_avg:88.25ms +step:1510/1680 train_time:133256ms step_avg:88.25ms +step:1511/1680 train_time:133345ms step_avg:88.25ms +step:1512/1680 train_time:133435ms step_avg:88.25ms +step:1513/1680 train_time:133526ms step_avg:88.25ms +step:1514/1680 train_time:133617ms step_avg:88.25ms +step:1515/1680 train_time:133707ms step_avg:88.26ms +step:1516/1680 train_time:133797ms step_avg:88.26ms +step:1517/1680 train_time:133885ms step_avg:88.26ms +step:1518/1680 train_time:133973ms step_avg:88.26ms +step:1519/1680 train_time:134062ms step_avg:88.26ms +step:1520/1680 train_time:134151ms step_avg:88.26ms +step:1521/1680 train_time:134239ms step_avg:88.26ms +step:1522/1680 train_time:134327ms step_avg:88.26ms +step:1523/1680 train_time:134417ms step_avg:88.26ms +step:1524/1680 train_time:134506ms step_avg:88.26ms +step:1525/1680 train_time:134597ms step_avg:88.26ms +step:1526/1680 train_time:134688ms step_avg:88.26ms +step:1527/1680 train_time:134777ms step_avg:88.26ms +step:1528/1680 train_time:134867ms step_avg:88.26ms +step:1529/1680 train_time:134957ms step_avg:88.26ms +step:1530/1680 train_time:135045ms step_avg:88.26ms +step:1531/1680 train_time:135133ms step_avg:88.26ms +step:1532/1680 train_time:135222ms step_avg:88.26ms +step:1533/1680 train_time:135310ms step_avg:88.26ms +step:1534/1680 train_time:135399ms step_avg:88.27ms +step:1535/1680 train_time:135489ms step_avg:88.27ms +step:1536/1680 train_time:135579ms step_avg:88.27ms +step:1537/1680 train_time:135669ms step_avg:88.27ms +step:1538/1680 train_time:135759ms step_avg:88.27ms +step:1539/1680 train_time:135848ms step_avg:88.27ms +step:1540/1680 train_time:135938ms step_avg:88.27ms +step:1541/1680 train_time:136027ms step_avg:88.27ms +step:1542/1680 train_time:136115ms step_avg:88.27ms +step:1543/1680 train_time:136204ms step_avg:88.27ms +step:1544/1680 train_time:136293ms step_avg:88.27ms +step:1545/1680 train_time:136382ms step_avg:88.27ms +step:1546/1680 train_time:136470ms step_avg:88.27ms 
+step:1547/1680 train_time:136560ms step_avg:88.27ms +step:1548/1680 train_time:136650ms step_avg:88.28ms +step:1549/1680 train_time:136739ms step_avg:88.28ms +step:1550/1680 train_time:136829ms step_avg:88.28ms +step:1551/1680 train_time:136917ms step_avg:88.28ms +step:1552/1680 train_time:137006ms step_avg:88.28ms +step:1553/1680 train_time:137095ms step_avg:88.28ms +step:1554/1680 train_time:137184ms step_avg:88.28ms +step:1555/1680 train_time:137274ms step_avg:88.28ms +step:1556/1680 train_time:137363ms step_avg:88.28ms +step:1557/1680 train_time:137452ms step_avg:88.28ms +step:1558/1680 train_time:137541ms step_avg:88.28ms +step:1559/1680 train_time:137631ms step_avg:88.28ms +step:1560/1680 train_time:137721ms step_avg:88.28ms +step:1561/1680 train_time:137811ms step_avg:88.28ms +step:1562/1680 train_time:137900ms step_avg:88.28ms +step:1563/1680 train_time:137989ms step_avg:88.28ms +step:1564/1680 train_time:138077ms step_avg:88.28ms +step:1565/1680 train_time:138166ms step_avg:88.28ms +step:1566/1680 train_time:138256ms step_avg:88.29ms +step:1567/1680 train_time:138345ms step_avg:88.29ms +step:1568/1680 train_time:138435ms step_avg:88.29ms +step:1569/1680 train_time:138525ms step_avg:88.29ms +step:1570/1680 train_time:138614ms step_avg:88.29ms +step:1571/1680 train_time:138703ms step_avg:88.29ms +step:1572/1680 train_time:138793ms step_avg:88.29ms +step:1573/1680 train_time:138882ms step_avg:88.29ms +step:1574/1680 train_time:138971ms step_avg:88.29ms +step:1575/1680 train_time:139060ms step_avg:88.29ms +step:1576/1680 train_time:139150ms step_avg:88.29ms +step:1577/1680 train_time:139238ms step_avg:88.29ms +step:1578/1680 train_time:139327ms step_avg:88.29ms +step:1579/1680 train_time:139416ms step_avg:88.29ms +step:1580/1680 train_time:139505ms step_avg:88.29ms +step:1581/1680 train_time:139595ms step_avg:88.30ms +step:1582/1680 train_time:139685ms step_avg:88.30ms +step:1583/1680 train_time:139774ms step_avg:88.30ms +step:1584/1680 train_time:139864ms step_avg:88.30ms +step:1585/1680 train_time:139953ms step_avg:88.30ms +step:1586/1680 train_time:140043ms step_avg:88.30ms +step:1587/1680 train_time:140131ms step_avg:88.30ms +step:1588/1680 train_time:140220ms step_avg:88.30ms +step:1589/1680 train_time:140310ms step_avg:88.30ms +step:1590/1680 train_time:140399ms step_avg:88.30ms +step:1591/1680 train_time:140488ms step_avg:88.30ms +step:1592/1680 train_time:140577ms step_avg:88.30ms +step:1593/1680 train_time:140666ms step_avg:88.30ms +step:1594/1680 train_time:140756ms step_avg:88.30ms +step:1595/1680 train_time:140845ms step_avg:88.30ms +step:1596/1680 train_time:140933ms step_avg:88.30ms +step:1597/1680 train_time:141022ms step_avg:88.30ms +step:1598/1680 train_time:141112ms step_avg:88.31ms +step:1599/1680 train_time:141200ms step_avg:88.31ms +step:1600/1680 train_time:141290ms step_avg:88.31ms +step:1601/1680 train_time:141379ms step_avg:88.31ms +step:1602/1680 train_time:141468ms step_avg:88.31ms +step:1603/1680 train_time:141557ms step_avg:88.31ms +step:1604/1680 train_time:141646ms step_avg:88.31ms +step:1605/1680 train_time:141735ms step_avg:88.31ms +step:1606/1680 train_time:141824ms step_avg:88.31ms +step:1607/1680 train_time:141914ms step_avg:88.31ms +step:1608/1680 train_time:142003ms step_avg:88.31ms +step:1609/1680 train_time:142092ms step_avg:88.31ms +step:1610/1680 train_time:142181ms step_avg:88.31ms +step:1611/1680 train_time:142271ms step_avg:88.31ms +step:1612/1680 train_time:142361ms step_avg:88.31ms +step:1613/1680 train_time:142451ms step_avg:88.31ms 
+step:1614/1680 train_time:142540ms step_avg:88.31ms +step:1615/1680 train_time:142629ms step_avg:88.32ms +step:1616/1680 train_time:142718ms step_avg:88.32ms +step:1617/1680 train_time:142808ms step_avg:88.32ms +step:1618/1680 train_time:142897ms step_avg:88.32ms +step:1619/1680 train_time:142986ms step_avg:88.32ms +step:1620/1680 train_time:143075ms step_avg:88.32ms +step:1621/1680 train_time:143166ms step_avg:88.32ms +step:1622/1680 train_time:143255ms step_avg:88.32ms +step:1623/1680 train_time:143344ms step_avg:88.32ms +step:1624/1680 train_time:143434ms step_avg:88.32ms +step:1625/1680 train_time:143522ms step_avg:88.32ms +step:1625/1680 val_loss:3.2899 train_time:143613ms step_avg:88.38ms +step:1626/1680 train_time:143632ms step_avg:88.33ms +step:1627/1680 train_time:143705ms step_avg:88.33ms +step:1628/1680 train_time:143799ms step_avg:88.33ms +step:1629/1680 train_time:143890ms step_avg:88.33ms +step:1630/1680 train_time:143978ms step_avg:88.33ms +step:1631/1680 train_time:144066ms step_avg:88.33ms +step:1632/1680 train_time:144154ms step_avg:88.33ms +step:1633/1680 train_time:144243ms step_avg:88.33ms +step:1634/1680 train_time:144332ms step_avg:88.33ms +step:1635/1680 train_time:144420ms step_avg:88.33ms +step:1636/1680 train_time:144509ms step_avg:88.33ms +step:1637/1680 train_time:144598ms step_avg:88.33ms +step:1638/1680 train_time:144689ms step_avg:88.33ms +step:1639/1680 train_time:144780ms step_avg:88.33ms +step:1640/1680 train_time:144870ms step_avg:88.34ms +step:1641/1680 train_time:144961ms step_avg:88.34ms +step:1642/1680 train_time:145050ms step_avg:88.34ms +step:1643/1680 train_time:145139ms step_avg:88.34ms +step:1644/1680 train_time:145227ms step_avg:88.34ms +step:1645/1680 train_time:145316ms step_avg:88.34ms +step:1646/1680 train_time:145405ms step_avg:88.34ms +step:1647/1680 train_time:145493ms step_avg:88.34ms +step:1648/1680 train_time:145582ms step_avg:88.34ms +step:1649/1680 train_time:145674ms step_avg:88.34ms +step:1650/1680 train_time:145764ms step_avg:88.34ms +step:1651/1680 train_time:145854ms step_avg:88.34ms +step:1652/1680 train_time:145943ms step_avg:88.34ms +step:1653/1680 train_time:146034ms step_avg:88.34ms +step:1654/1680 train_time:146122ms step_avg:88.34ms +step:1655/1680 train_time:146211ms step_avg:88.35ms +step:1656/1680 train_time:146300ms step_avg:88.35ms +step:1657/1680 train_time:146389ms step_avg:88.35ms +step:1658/1680 train_time:146477ms step_avg:88.35ms +step:1659/1680 train_time:146566ms step_avg:88.35ms +step:1660/1680 train_time:146656ms step_avg:88.35ms +step:1661/1680 train_time:146747ms step_avg:88.35ms +step:1662/1680 train_time:146837ms step_avg:88.35ms +step:1663/1680 train_time:146926ms step_avg:88.35ms +step:1664/1680 train_time:147016ms step_avg:88.35ms +step:1665/1680 train_time:147106ms step_avg:88.35ms +step:1666/1680 train_time:147197ms step_avg:88.35ms +step:1667/1680 train_time:147285ms step_avg:88.35ms +step:1668/1680 train_time:147374ms step_avg:88.35ms +step:1669/1680 train_time:147462ms step_avg:88.35ms +step:1670/1680 train_time:147551ms step_avg:88.35ms +step:1671/1680 train_time:147640ms step_avg:88.35ms +step:1672/1680 train_time:147730ms step_avg:88.36ms +step:1673/1680 train_time:147820ms step_avg:88.36ms +step:1674/1680 train_time:147910ms step_avg:88.36ms +step:1675/1680 train_time:148000ms step_avg:88.36ms +step:1676/1680 train_time:148090ms step_avg:88.36ms +step:1677/1680 train_time:148178ms step_avg:88.36ms +step:1678/1680 train_time:148267ms step_avg:88.36ms +step:1679/1680 train_time:148356ms 
step_avg:88.36ms +step:1680/1680 train_time:148445ms step_avg:88.36ms +step:1680/1680 val_loss:3.2789 train_time:148536ms step_avg:88.41ms +peak memory allocated: 30760 MiB reserved: 45914 MiB diff --git a/records/092725_BF16CE/f790419e-3027-441e-a5ab-11549e63fc1c.txt b/records/092725_BF16CE/f790419e-3027-441e-a5ab-11549e63fc1c.txt new file mode 100644 index 000000000..a76bbec7d --- /dev/null +++ b/records/092725_BF16CE/f790419e-3027-441e-a5ab-11549e63fc1c.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
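+ # (C = A @ A.T is symmetric: the skip logic above culled blocks on one side of the diagonal, + # so each computed block is stored twice below, once at (m, n) and once mirrored to (n, m), + # roughly halving the dot-product work.)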
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+ Though, empirically, small 1D params perform well here: + NS approximately performs a magnitude normalization of the grad, and + this hyper-optimized class runs faster than the current Adam impl for small params. + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized so that + (n_attn_layers + n_mlp_layers*2) % 4 == 0, for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param + 7 padding params) + 2. reduce scatter attn_gate (10 params + 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params + 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive attn params reshape them before NS + 8. wait on step 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size() == 8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for the resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1, 10, 16, 16] + assert len(params) == sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx + size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal.
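+ # Communication sketch (world_size=8): per group, stack grads to (padded_num_params, *shape), + # reduce_scatter so each rank averages and keeps a (padded_num_params/8, *shape) chunk, + # orthogonalize that chunk locally with Newton-Schulz, then all_gather the updated params + # back to every rank -- ZeRO-style sharding of the optimizer work. E.g. for the 16-param + # mlp group of (768, 3072) matrices: (16, 768, 3072) -> (2, 768, 3072) per rank -> NS -> gather.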
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
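+ # (copy, not alias: updated_param_chunk is the contiguous all_gather input, so the + # decayed-and-updated values are built in the buffer rather than applied to p in place.)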
+ updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + original_shape = batched_update_grads.shape + # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O + module_idx = start_idx if start_idx < len(params) else 0 + if getattr(params[module_idx], "module", None) == 'attn': + batched_update_grads = batched_update_grads.view(-1, p_example.size(-2), p_example.size(-1) // 4) + v_chunk = newton_schulz_triton(batched_update_grads) + v_chunk = v_chunk.view(original_shape) + + # Write the orthogonalized updates into this rank's chunk of the parameter buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq =
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module='attn' + with torch.no_grad(): + self.qkvo_w.view(4,self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4,self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4,self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w.view(4,self.hdim, self.dim)[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make matrices the same shape to enable batched call in optimizer + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + # label modules to enable custom optimizer sizing + self.c_fc.module='mlp' + self.c_proj.module='mlp' + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + # label modules to enable custom optimizer sizing + self.smear_gate.weight.module = 'smear_gate' + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
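+ # FP8 head: the nanogpt::mm op divides by these scales before casting, so x_s, w_s, and + # grad_s keep activations, weights, and grads inside float8 range (e4m3 max 448; e5m2 for grads).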
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + short_bm = ws_short * args.block_size + long_bm = ws_long * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + # skip layer zero + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the 
data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter==5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter+=1 + return starts, ends + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * 
grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for the unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted; load the next one in the next loop iteration. + tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was one token too long, to account for the _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local) + _targets = buf[1:].view(num_tokens_local) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1640 # number of iterations to run + iteration_extension: int = 40 # number of iterations to continue training at the final cooldown and window size + cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss?
0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd + ws_long_validate: int = 20 # extend long windows out even further + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +gate_params = [p for n, p in model.named_parameters() if "gate" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999, step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations + args.iteration_extension: + return args.ws_validate // 2, args.ws_validate + x = min(step / (1 + args.num_iterations), 0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws_long = args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params + if new_ws_long > ws_long: + model.yarn.apply(ws_long, new_ws_long) + ws_long = new_ws_long + elif new_ws_long < ws_long: + model.yarn.reset() # YaRN state is relative to the base window, so reset when the schedule wraps around + ws_long = new_ws_long + model(inputs, targets, cum_seqlens, ws_long // 2, ws_long).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +# restore the initial state so the warmup steps don't count as training +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del initial_state +model.yarn.reset() + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations + args.iteration_extension +ws_short, ws_long = get_ws(0) +for step in range(train_steps + 1): + last_step = (step == train_steps) + new_ws_short, new_ws_long = get_ws(step) + if new_ws_long != ws_long: # window schedule moved: update YaRN before running the new graph + model.yarn.apply(ws_long, new_ws_long) + ws_short, ws_long = new_ws_short, new_ws_long + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + if last_step: + ws_long = args.ws_long_validate + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = torch.zeros((), device=device, dtype=torch.float32) + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps):
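+ # gradient accumulation: grad_accum_steps (= 8 // world_size) micro-batches per optimizer step, + # so the total tokens per step match the 8-GPU configuration on smaller world sizes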
inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws_short, ws_long).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Sep 27 13:40:51 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 27C P0 121W / 700W | 5856MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 25C P0 119W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 22C P0 116W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 26C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 27C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 25C P0 116W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 28C P0 120W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 24C P0 121W / 700W | 1518MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+
+
++-----------------------------------------------------------------------------------------+
+| Processes:                                                                              |
+|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
+|        ID   ID                                                               Usage      |
+|=========================================================================================|
+|    0   N/A  N/A          178592      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          178593      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          178594      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          178595      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          178596      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          178597      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          178598      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          178599      C   /usr/bin/python                           0MiB |
+|    1   N/A  N/A          178593      C   /usr/bin/python                           0MiB |
+|    2   N/A  N/A          178594      C   /usr/bin/python                           0MiB |
+|    3   N/A  N/A          178595      C   /usr/bin/python                           0MiB |
+|    4   N/A  N/A          178596      C   /usr/bin/python                           0MiB |
+|    5   N/A  N/A          178597      C   /usr/bin/python                           0MiB |
+|    6   N/A  N/A          178598      C   /usr/bin/python                           0MiB |
+|    7   N/A  N/A          178599      C   /usr/bin/python                           0MiB |
++-----------------------------------------------------------------------------------------+
+
+====================================================================================================
+step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.03ms
+step:1/1680 train_time:152ms step_avg:152.00ms
+step:2/1680 train_time:172ms step_avg:86.24ms
+step:3/1680 train_time:236ms step_avg:78.55ms
+step:4/1680 train_time:321ms step_avg:80.15ms
+step:5/1680 train_time:406ms step_avg:81.30ms
+step:6/1680 train_time:493ms step_avg:82.16ms
+step:7/1680 train_time:579ms step_avg:82.72ms
+step:8/1680 train_time:665ms step_avg:83.13ms
+step:9/1680 train_time:752ms step_avg:83.52ms
+step:10/1680 train_time:838ms step_avg:83.76ms
+step:11/1680 train_time:924ms step_avg:84.02ms
+step:12/1680 train_time:1013ms step_avg:84.38ms
+step:13/1680 train_time:1102ms step_avg:84.80ms
+step:14/1680 train_time:1192ms step_avg:85.15ms
+step:15/1680 train_time:1279ms step_avg:85.30ms
+step:16/1680 train_time:1366ms step_avg:85.39ms
+step:17/1680 train_time:1453ms step_avg:85.48ms
+step:18/1680 train_time:1539ms step_avg:85.52ms
+step:19/1680 train_time:1627ms step_avg:85.62ms
+step:20/1680 train_time:1713ms step_avg:85.66ms
+step:21/1680 train_time:1799ms step_avg:85.68ms
+step:22/1680 train_time:1886ms step_avg:85.72ms
+step:23/1680 train_time:1973ms step_avg:85.79ms
+step:24/1680 train_time:2061ms step_avg:85.88ms
+step:25/1680 train_time:2150ms step_avg:86.00ms
+step:26/1680 train_time:2238ms step_avg:86.07ms
+step:27/1680 train_time:2326ms step_avg:86.14ms
+step:28/1680 train_time:2413ms step_avg:86.18ms
+step:29/1680 train_time:2500ms step_avg:86.21ms
+step:30/1680 train_time:2588ms step_avg:86.27ms
+step:31/1680 train_time:2675ms step_avg:86.29ms
+step:32/1680 train_time:2761ms step_avg:86.28ms
+step:33/1680 train_time:2848ms step_avg:86.30ms
+step:34/1680 train_time:2935ms step_avg:86.33ms
+step:35/1680 train_time:3023ms step_avg:86.36ms
+step:36/1680 train_time:3112ms step_avg:86.44ms
+step:37/1680 train_time:3199ms step_avg:86.47ms
+step:38/1680 train_time:3287ms step_avg:86.50ms
+step:39/1680 train_time:3374ms step_avg:86.52ms
+step:40/1680 train_time:3461ms step_avg:86.52ms
+step:41/1680 train_time:3548ms step_avg:86.53ms
+step:42/1680 train_time:3635ms step_avg:86.54ms
+step:43/1680 train_time:3722ms step_avg:86.55ms
+step:44/1680 train_time:3808ms step_avg:86.55ms
+step:45/1680 train_time:3895ms step_avg:86.55ms
+step:46/1680 train_time:3981ms step_avg:86.54ms
+step:47/1680
train_time:4068ms step_avg:86.55ms +step:48/1680 train_time:4156ms step_avg:86.59ms +step:49/1680 train_time:4244ms step_avg:86.61ms +step:50/1680 train_time:4331ms step_avg:86.62ms +step:51/1680 train_time:4418ms step_avg:86.63ms +step:52/1680 train_time:4506ms step_avg:86.65ms +step:53/1680 train_time:4593ms step_avg:86.66ms +step:54/1680 train_time:4680ms step_avg:86.66ms +step:55/1680 train_time:4767ms step_avg:86.67ms +step:56/1680 train_time:4854ms step_avg:86.67ms +step:57/1680 train_time:4940ms step_avg:86.67ms +step:58/1680 train_time:5028ms step_avg:86.68ms +step:59/1680 train_time:5115ms step_avg:86.69ms +step:60/1680 train_time:5204ms step_avg:86.73ms +step:61/1680 train_time:5291ms step_avg:86.73ms +step:62/1680 train_time:5377ms step_avg:86.73ms +step:63/1680 train_time:5464ms step_avg:86.73ms +step:64/1680 train_time:5552ms step_avg:86.75ms +step:65/1680 train_time:5639ms step_avg:86.75ms +step:66/1680 train_time:5726ms step_avg:86.75ms +step:67/1680 train_time:5813ms step_avg:86.77ms +step:68/1680 train_time:5902ms step_avg:86.79ms +step:69/1680 train_time:5988ms step_avg:86.79ms +step:70/1680 train_time:6075ms step_avg:86.79ms +step:71/1680 train_time:6163ms step_avg:86.81ms +step:72/1680 train_time:6251ms step_avg:86.81ms +step:73/1680 train_time:6338ms step_avg:86.82ms +step:74/1680 train_time:6425ms step_avg:86.83ms +step:75/1680 train_time:6513ms step_avg:86.84ms +step:76/1680 train_time:6600ms step_avg:86.85ms +step:77/1680 train_time:6687ms step_avg:86.84ms +step:78/1680 train_time:6775ms step_avg:86.87ms +step:79/1680 train_time:6862ms step_avg:86.85ms +step:80/1680 train_time:6949ms step_avg:86.86ms +step:81/1680 train_time:7036ms step_avg:86.87ms +step:82/1680 train_time:7124ms step_avg:86.88ms +step:83/1680 train_time:7211ms step_avg:86.88ms +step:84/1680 train_time:7298ms step_avg:86.88ms +step:85/1680 train_time:7386ms step_avg:86.89ms +step:86/1680 train_time:7473ms step_avg:86.89ms +step:87/1680 train_time:7560ms step_avg:86.90ms +step:88/1680 train_time:7648ms step_avg:86.91ms +step:89/1680 train_time:7735ms step_avg:86.91ms +step:90/1680 train_time:7822ms step_avg:86.91ms +step:91/1680 train_time:7909ms step_avg:86.92ms +step:92/1680 train_time:7996ms step_avg:86.91ms +step:93/1680 train_time:8083ms step_avg:86.91ms +step:94/1680 train_time:8170ms step_avg:86.91ms +step:95/1680 train_time:8257ms step_avg:86.92ms +step:96/1680 train_time:8345ms step_avg:86.92ms +step:97/1680 train_time:8432ms step_avg:86.93ms +step:98/1680 train_time:8519ms step_avg:86.92ms +step:99/1680 train_time:8606ms step_avg:86.92ms +step:100/1680 train_time:8693ms step_avg:86.93ms +step:101/1680 train_time:8780ms step_avg:86.93ms +step:102/1680 train_time:8867ms step_avg:86.93ms +step:103/1680 train_time:8954ms step_avg:86.93ms +step:104/1680 train_time:9041ms step_avg:86.93ms +step:105/1680 train_time:9128ms step_avg:86.93ms +step:106/1680 train_time:9215ms step_avg:86.93ms +step:107/1680 train_time:9302ms step_avg:86.94ms +step:108/1680 train_time:9389ms step_avg:86.94ms +step:109/1680 train_time:9477ms step_avg:86.94ms +step:110/1680 train_time:9563ms step_avg:86.94ms +step:111/1680 train_time:9651ms step_avg:86.94ms +step:112/1680 train_time:9738ms step_avg:86.95ms +step:113/1680 train_time:9826ms step_avg:86.95ms +step:114/1680 train_time:9913ms step_avg:86.96ms +step:115/1680 train_time:10000ms step_avg:86.96ms +step:116/1680 train_time:10087ms step_avg:86.96ms +step:117/1680 train_time:10174ms step_avg:86.96ms +step:118/1680 train_time:10261ms step_avg:86.96ms +step:119/1680 
train_time:10348ms step_avg:86.96ms +step:120/1680 train_time:10436ms step_avg:86.96ms +step:121/1680 train_time:10522ms step_avg:86.96ms +step:122/1680 train_time:10609ms step_avg:86.96ms +step:123/1680 train_time:10696ms step_avg:86.96ms +step:124/1680 train_time:10784ms step_avg:86.97ms +step:125/1680 train_time:10871ms step_avg:86.97ms +step:125/1680 val_loss:4.3006 train_time:10960ms step_avg:87.68ms +step:126/1680 train_time:10982ms step_avg:87.16ms +step:127/1680 train_time:11050ms step_avg:87.01ms +step:128/1680 train_time:11150ms step_avg:87.11ms +step:129/1680 train_time:11238ms step_avg:87.12ms +step:130/1680 train_time:11325ms step_avg:87.11ms +step:131/1680 train_time:11411ms step_avg:87.11ms +step:132/1680 train_time:11497ms step_avg:87.10ms +step:133/1680 train_time:11583ms step_avg:87.09ms +step:134/1680 train_time:11669ms step_avg:87.08ms +step:135/1680 train_time:11755ms step_avg:87.07ms +step:136/1680 train_time:11841ms step_avg:87.07ms +step:137/1680 train_time:11928ms step_avg:87.06ms +step:138/1680 train_time:12014ms step_avg:87.06ms +step:139/1680 train_time:12105ms step_avg:87.09ms +step:140/1680 train_time:12194ms step_avg:87.10ms +step:141/1680 train_time:12283ms step_avg:87.11ms +step:142/1680 train_time:12370ms step_avg:87.11ms +step:143/1680 train_time:12457ms step_avg:87.11ms +step:144/1680 train_time:12543ms step_avg:87.11ms +step:145/1680 train_time:12629ms step_avg:87.10ms +step:146/1680 train_time:12715ms step_avg:87.09ms +step:147/1680 train_time:12802ms step_avg:87.09ms +step:148/1680 train_time:12888ms step_avg:87.08ms +step:149/1680 train_time:12976ms step_avg:87.08ms +step:150/1680 train_time:13064ms step_avg:87.09ms +step:151/1680 train_time:13152ms step_avg:87.10ms +step:152/1680 train_time:13241ms step_avg:87.11ms +step:153/1680 train_time:13329ms step_avg:87.11ms +step:154/1680 train_time:13415ms step_avg:87.11ms +step:155/1680 train_time:13502ms step_avg:87.11ms +step:156/1680 train_time:13589ms step_avg:87.11ms +step:157/1680 train_time:13675ms step_avg:87.10ms +step:158/1680 train_time:13762ms step_avg:87.10ms +step:159/1680 train_time:13848ms step_avg:87.10ms +step:160/1680 train_time:13935ms step_avg:87.09ms +step:161/1680 train_time:14022ms step_avg:87.09ms +step:162/1680 train_time:14110ms step_avg:87.10ms +step:163/1680 train_time:14197ms step_avg:87.10ms +step:164/1680 train_time:14285ms step_avg:87.10ms +step:165/1680 train_time:14372ms step_avg:87.10ms +step:166/1680 train_time:14459ms step_avg:87.10ms +step:167/1680 train_time:14546ms step_avg:87.10ms +step:168/1680 train_time:14632ms step_avg:87.10ms +step:169/1680 train_time:14719ms step_avg:87.09ms +step:170/1680 train_time:14806ms step_avg:87.09ms +step:171/1680 train_time:14892ms step_avg:87.09ms +step:172/1680 train_time:14980ms step_avg:87.09ms +step:173/1680 train_time:15068ms step_avg:87.10ms +step:174/1680 train_time:15155ms step_avg:87.10ms +step:175/1680 train_time:15242ms step_avg:87.10ms +step:176/1680 train_time:15329ms step_avg:87.10ms +step:177/1680 train_time:15417ms step_avg:87.10ms +step:178/1680 train_time:15504ms step_avg:87.10ms +step:179/1680 train_time:15591ms step_avg:87.10ms +step:180/1680 train_time:15679ms step_avg:87.10ms +step:181/1680 train_time:15765ms step_avg:87.10ms +step:182/1680 train_time:15852ms step_avg:87.10ms +step:183/1680 train_time:15939ms step_avg:87.10ms +step:184/1680 train_time:16025ms step_avg:87.09ms +step:185/1680 train_time:16112ms step_avg:87.09ms +step:186/1680 train_time:16200ms step_avg:87.10ms +step:187/1680 train_time:16288ms 
step_avg:87.10ms +step:188/1680 train_time:16375ms step_avg:87.10ms +step:189/1680 train_time:16461ms step_avg:87.10ms +step:190/1680 train_time:16548ms step_avg:87.10ms +step:191/1680 train_time:16635ms step_avg:87.10ms +step:192/1680 train_time:16722ms step_avg:87.10ms +step:193/1680 train_time:16809ms step_avg:87.09ms +step:194/1680 train_time:16896ms step_avg:87.09ms +step:195/1680 train_time:16983ms step_avg:87.09ms +step:196/1680 train_time:17070ms step_avg:87.09ms +step:197/1680 train_time:17157ms step_avg:87.09ms +step:198/1680 train_time:17244ms step_avg:87.09ms +step:199/1680 train_time:17331ms step_avg:87.09ms +step:200/1680 train_time:17418ms step_avg:87.09ms +step:201/1680 train_time:17506ms step_avg:87.09ms +step:202/1680 train_time:17593ms step_avg:87.09ms +step:203/1680 train_time:17680ms step_avg:87.09ms +step:204/1680 train_time:17767ms step_avg:87.10ms +step:205/1680 train_time:17855ms step_avg:87.10ms +step:206/1680 train_time:17941ms step_avg:87.09ms +step:207/1680 train_time:18028ms step_avg:87.09ms +step:208/1680 train_time:18114ms step_avg:87.09ms +step:209/1680 train_time:18202ms step_avg:87.09ms +step:210/1680 train_time:18289ms step_avg:87.09ms +step:211/1680 train_time:18376ms step_avg:87.09ms +step:212/1680 train_time:18463ms step_avg:87.09ms +step:213/1680 train_time:18549ms step_avg:87.09ms +step:214/1680 train_time:18636ms step_avg:87.08ms +step:215/1680 train_time:18722ms step_avg:87.08ms +step:216/1680 train_time:18809ms step_avg:87.08ms +step:217/1680 train_time:18896ms step_avg:87.08ms +step:218/1680 train_time:18983ms step_avg:87.08ms +step:219/1680 train_time:19070ms step_avg:87.08ms +step:220/1680 train_time:19157ms step_avg:87.08ms +step:221/1680 train_time:19244ms step_avg:87.07ms +step:222/1680 train_time:19331ms step_avg:87.08ms +step:223/1680 train_time:19419ms step_avg:87.08ms +step:224/1680 train_time:19505ms step_avg:87.08ms +step:225/1680 train_time:19593ms step_avg:87.08ms +step:226/1680 train_time:19680ms step_avg:87.08ms +step:227/1680 train_time:19767ms step_avg:87.08ms +step:228/1680 train_time:19855ms step_avg:87.08ms +step:229/1680 train_time:19941ms step_avg:87.08ms +step:230/1680 train_time:20028ms step_avg:87.08ms +step:231/1680 train_time:20116ms step_avg:87.08ms +step:232/1680 train_time:20203ms step_avg:87.08ms +step:233/1680 train_time:20290ms step_avg:87.08ms +step:234/1680 train_time:20378ms step_avg:87.08ms +step:235/1680 train_time:20464ms step_avg:87.08ms +step:236/1680 train_time:20552ms step_avg:87.08ms +step:237/1680 train_time:20639ms step_avg:87.08ms +step:238/1680 train_time:20726ms step_avg:87.08ms +step:239/1680 train_time:20813ms step_avg:87.08ms +step:240/1680 train_time:20900ms step_avg:87.08ms +step:241/1680 train_time:20987ms step_avg:87.08ms +step:242/1680 train_time:21075ms step_avg:87.09ms +step:243/1680 train_time:21163ms step_avg:87.09ms +step:244/1680 train_time:21249ms step_avg:87.09ms +step:245/1680 train_time:21336ms step_avg:87.09ms +step:246/1680 train_time:21423ms step_avg:87.09ms +step:247/1680 train_time:21510ms step_avg:87.09ms +step:248/1680 train_time:21597ms step_avg:87.09ms +step:249/1680 train_time:21684ms step_avg:87.09ms +step:250/1680 train_time:21771ms step_avg:87.08ms +step:250/1680 val_loss:3.9658 train_time:21860ms step_avg:87.44ms +step:251/1680 train_time:21879ms step_avg:87.17ms +step:252/1680 train_time:21951ms step_avg:87.11ms +step:253/1680 train_time:22045ms step_avg:87.13ms +step:254/1680 train_time:22133ms step_avg:87.14ms +step:255/1680 train_time:22220ms step_avg:87.14ms 
+step:256/1680 train_time:22306ms step_avg:87.13ms +step:257/1680 train_time:22392ms step_avg:87.13ms +step:258/1680 train_time:22478ms step_avg:87.12ms +step:259/1680 train_time:22564ms step_avg:87.12ms +step:260/1680 train_time:22650ms step_avg:87.12ms +step:261/1680 train_time:22736ms step_avg:87.11ms +step:262/1680 train_time:22824ms step_avg:87.11ms +step:263/1680 train_time:22912ms step_avg:87.12ms +step:264/1680 train_time:23000ms step_avg:87.12ms +step:265/1680 train_time:23089ms step_avg:87.13ms +step:266/1680 train_time:23177ms step_avg:87.13ms +step:267/1680 train_time:23263ms step_avg:87.13ms +step:268/1680 train_time:23350ms step_avg:87.13ms +step:269/1680 train_time:23437ms step_avg:87.13ms +step:270/1680 train_time:23522ms step_avg:87.12ms +step:271/1680 train_time:23608ms step_avg:87.12ms +step:272/1680 train_time:23695ms step_avg:87.11ms +step:273/1680 train_time:23781ms step_avg:87.11ms +step:274/1680 train_time:23869ms step_avg:87.11ms +step:275/1680 train_time:23957ms step_avg:87.12ms +step:276/1680 train_time:24045ms step_avg:87.12ms +step:277/1680 train_time:24133ms step_avg:87.12ms +step:278/1680 train_time:24220ms step_avg:87.12ms +step:279/1680 train_time:24307ms step_avg:87.12ms +step:280/1680 train_time:24394ms step_avg:87.12ms +step:281/1680 train_time:24480ms step_avg:87.12ms +step:282/1680 train_time:24568ms step_avg:87.12ms +step:283/1680 train_time:24654ms step_avg:87.12ms +step:284/1680 train_time:24740ms step_avg:87.11ms +step:285/1680 train_time:24828ms step_avg:87.11ms +step:286/1680 train_time:24914ms step_avg:87.11ms +step:287/1680 train_time:25003ms step_avg:87.12ms +step:288/1680 train_time:25090ms step_avg:87.12ms +step:289/1680 train_time:25178ms step_avg:87.12ms +step:290/1680 train_time:25265ms step_avg:87.12ms +step:291/1680 train_time:25352ms step_avg:87.12ms +step:292/1680 train_time:25439ms step_avg:87.12ms +step:293/1680 train_time:25526ms step_avg:87.12ms +step:294/1680 train_time:25612ms step_avg:87.11ms +step:295/1680 train_time:25698ms step_avg:87.11ms +step:296/1680 train_time:25784ms step_avg:87.11ms +step:297/1680 train_time:25871ms step_avg:87.11ms +step:298/1680 train_time:25959ms step_avg:87.11ms +step:299/1680 train_time:26046ms step_avg:87.11ms +step:300/1680 train_time:26133ms step_avg:87.11ms +step:301/1680 train_time:26221ms step_avg:87.11ms +step:302/1680 train_time:26309ms step_avg:87.11ms +step:303/1680 train_time:26395ms step_avg:87.11ms +step:304/1680 train_time:26483ms step_avg:87.11ms +step:305/1680 train_time:26570ms step_avg:87.11ms +step:306/1680 train_time:26657ms step_avg:87.11ms +step:307/1680 train_time:26744ms step_avg:87.11ms +step:308/1680 train_time:26831ms step_avg:87.11ms +step:309/1680 train_time:26917ms step_avg:87.11ms +step:310/1680 train_time:27004ms step_avg:87.11ms +step:311/1680 train_time:27092ms step_avg:87.11ms +step:312/1680 train_time:27179ms step_avg:87.11ms +step:313/1680 train_time:27267ms step_avg:87.11ms +step:314/1680 train_time:27354ms step_avg:87.11ms +step:315/1680 train_time:27441ms step_avg:87.11ms +step:316/1680 train_time:27528ms step_avg:87.12ms +step:317/1680 train_time:27616ms step_avg:87.12ms +step:318/1680 train_time:27704ms step_avg:87.12ms +step:319/1680 train_time:27791ms step_avg:87.12ms +step:320/1680 train_time:27878ms step_avg:87.12ms +step:321/1680 train_time:27964ms step_avg:87.12ms +step:322/1680 train_time:28052ms step_avg:87.12ms +step:323/1680 train_time:28139ms step_avg:87.12ms +step:324/1680 train_time:28227ms step_avg:87.12ms +step:325/1680 train_time:28313ms 
step_avg:87.12ms +step:326/1680 train_time:28401ms step_avg:87.12ms +step:327/1680 train_time:28488ms step_avg:87.12ms +step:328/1680 train_time:28576ms step_avg:87.12ms +step:329/1680 train_time:28663ms step_avg:87.12ms +step:330/1680 train_time:28750ms step_avg:87.12ms +step:331/1680 train_time:28837ms step_avg:87.12ms +step:332/1680 train_time:28924ms step_avg:87.12ms +step:333/1680 train_time:29011ms step_avg:87.12ms +step:334/1680 train_time:29098ms step_avg:87.12ms +step:335/1680 train_time:29185ms step_avg:87.12ms +step:336/1680 train_time:29273ms step_avg:87.12ms +step:337/1680 train_time:29359ms step_avg:87.12ms +step:338/1680 train_time:29446ms step_avg:87.12ms +step:339/1680 train_time:29533ms step_avg:87.12ms +step:340/1680 train_time:29620ms step_avg:87.12ms +step:341/1680 train_time:29707ms step_avg:87.12ms +step:342/1680 train_time:29794ms step_avg:87.12ms +step:343/1680 train_time:29880ms step_avg:87.11ms +step:344/1680 train_time:29967ms step_avg:87.11ms +step:345/1680 train_time:30054ms step_avg:87.11ms +step:346/1680 train_time:30142ms step_avg:87.11ms +step:347/1680 train_time:30229ms step_avg:87.12ms +step:348/1680 train_time:30316ms step_avg:87.11ms +step:349/1680 train_time:30404ms step_avg:87.12ms +step:350/1680 train_time:30491ms step_avg:87.12ms +step:351/1680 train_time:30578ms step_avg:87.12ms +step:352/1680 train_time:30665ms step_avg:87.12ms +step:353/1680 train_time:30752ms step_avg:87.12ms +step:354/1680 train_time:30839ms step_avg:87.12ms +step:355/1680 train_time:30927ms step_avg:87.12ms +step:356/1680 train_time:31014ms step_avg:87.12ms +step:357/1680 train_time:31100ms step_avg:87.11ms +step:358/1680 train_time:31187ms step_avg:87.12ms +step:359/1680 train_time:31274ms step_avg:87.12ms +step:360/1680 train_time:31361ms step_avg:87.12ms +step:361/1680 train_time:31449ms step_avg:87.12ms +step:362/1680 train_time:31536ms step_avg:87.12ms +step:363/1680 train_time:31623ms step_avg:87.11ms +step:364/1680 train_time:31709ms step_avg:87.11ms +step:365/1680 train_time:31797ms step_avg:87.11ms +step:366/1680 train_time:31884ms step_avg:87.11ms +step:367/1680 train_time:31971ms step_avg:87.11ms +step:368/1680 train_time:32058ms step_avg:87.11ms +step:369/1680 train_time:32146ms step_avg:87.12ms +step:370/1680 train_time:32233ms step_avg:87.12ms +step:371/1680 train_time:32320ms step_avg:87.12ms +step:372/1680 train_time:32407ms step_avg:87.12ms +step:373/1680 train_time:32495ms step_avg:87.12ms +step:374/1680 train_time:32582ms step_avg:87.12ms +step:375/1680 train_time:32670ms step_avg:87.12ms +step:375/1680 val_loss:3.8196 train_time:32758ms step_avg:87.36ms +step:376/1680 train_time:32778ms step_avg:87.18ms +step:377/1680 train_time:32849ms step_avg:87.13ms +step:378/1680 train_time:32940ms step_avg:87.14ms +step:379/1680 train_time:33029ms step_avg:87.15ms +step:380/1680 train_time:33116ms step_avg:87.15ms +step:381/1680 train_time:33203ms step_avg:87.15ms +step:382/1680 train_time:33290ms step_avg:87.15ms +step:383/1680 train_time:33376ms step_avg:87.14ms +step:384/1680 train_time:33462ms step_avg:87.14ms +step:385/1680 train_time:33549ms step_avg:87.14ms +step:386/1680 train_time:33634ms step_avg:87.14ms +step:387/1680 train_time:33722ms step_avg:87.14ms +step:388/1680 train_time:33811ms step_avg:87.14ms +step:389/1680 train_time:33899ms step_avg:87.14ms +step:390/1680 train_time:33988ms step_avg:87.15ms +step:391/1680 train_time:34075ms step_avg:87.15ms +step:392/1680 train_time:34163ms step_avg:87.15ms +step:393/1680 train_time:34249ms step_avg:87.15ms 
+step:394/1680 train_time:34336ms step_avg:87.15ms +step:395/1680 train_time:34422ms step_avg:87.14ms +step:396/1680 train_time:34508ms step_avg:87.14ms +step:397/1680 train_time:34595ms step_avg:87.14ms +step:398/1680 train_time:34681ms step_avg:87.14ms +step:399/1680 train_time:34769ms step_avg:87.14ms +step:400/1680 train_time:34856ms step_avg:87.14ms +step:401/1680 train_time:34944ms step_avg:87.14ms +step:402/1680 train_time:35033ms step_avg:87.15ms +step:403/1680 train_time:35120ms step_avg:87.15ms +step:404/1680 train_time:35207ms step_avg:87.15ms +step:405/1680 train_time:35295ms step_avg:87.15ms +step:406/1680 train_time:35382ms step_avg:87.15ms +step:407/1680 train_time:35468ms step_avg:87.14ms +step:408/1680 train_time:35554ms step_avg:87.14ms +step:409/1680 train_time:35641ms step_avg:87.14ms +step:410/1680 train_time:35728ms step_avg:87.14ms +step:411/1680 train_time:35815ms step_avg:87.14ms +step:412/1680 train_time:35903ms step_avg:87.14ms +step:413/1680 train_time:35990ms step_avg:87.14ms +step:414/1680 train_time:36077ms step_avg:87.14ms +step:415/1680 train_time:36164ms step_avg:87.14ms +step:416/1680 train_time:36251ms step_avg:87.14ms +step:417/1680 train_time:36339ms step_avg:87.14ms +step:418/1680 train_time:36426ms step_avg:87.14ms +step:419/1680 train_time:36512ms step_avg:87.14ms +step:420/1680 train_time:36600ms step_avg:87.14ms +step:421/1680 train_time:36686ms step_avg:87.14ms +step:422/1680 train_time:36774ms step_avg:87.14ms +step:423/1680 train_time:36861ms step_avg:87.14ms +step:424/1680 train_time:36949ms step_avg:87.14ms +step:425/1680 train_time:37036ms step_avg:87.14ms +step:426/1680 train_time:37124ms step_avg:87.15ms +step:427/1680 train_time:37211ms step_avg:87.15ms +step:428/1680 train_time:37298ms step_avg:87.15ms +step:429/1680 train_time:37386ms step_avg:87.15ms +step:430/1680 train_time:37473ms step_avg:87.15ms +step:431/1680 train_time:37560ms step_avg:87.15ms +step:432/1680 train_time:37646ms step_avg:87.14ms +step:433/1680 train_time:37734ms step_avg:87.14ms +step:434/1680 train_time:37820ms step_avg:87.14ms +step:435/1680 train_time:37908ms step_avg:87.14ms +step:436/1680 train_time:37995ms step_avg:87.14ms +step:437/1680 train_time:38082ms step_avg:87.14ms +step:438/1680 train_time:38169ms step_avg:87.14ms +step:439/1680 train_time:38256ms step_avg:87.14ms +step:440/1680 train_time:38343ms step_avg:87.14ms +step:441/1680 train_time:38430ms step_avg:87.14ms +step:442/1680 train_time:38517ms step_avg:87.14ms +step:443/1680 train_time:38604ms step_avg:87.14ms +step:444/1680 train_time:38692ms step_avg:87.14ms +step:445/1680 train_time:38779ms step_avg:87.14ms +step:446/1680 train_time:38866ms step_avg:87.14ms +step:447/1680 train_time:38952ms step_avg:87.14ms +step:448/1680 train_time:39041ms step_avg:87.14ms +step:449/1680 train_time:39128ms step_avg:87.14ms +step:450/1680 train_time:39216ms step_avg:87.15ms +step:451/1680 train_time:39303ms step_avg:87.15ms +step:452/1680 train_time:39391ms step_avg:87.15ms +step:453/1680 train_time:39478ms step_avg:87.15ms +step:454/1680 train_time:39564ms step_avg:87.15ms +step:455/1680 train_time:39652ms step_avg:87.15ms +step:456/1680 train_time:39739ms step_avg:87.15ms +step:457/1680 train_time:39826ms step_avg:87.15ms +step:458/1680 train_time:39914ms step_avg:87.15ms +step:459/1680 train_time:40001ms step_avg:87.15ms +step:460/1680 train_time:40088ms step_avg:87.15ms +step:461/1680 train_time:40175ms step_avg:87.15ms +step:462/1680 train_time:40263ms step_avg:87.15ms +step:463/1680 train_time:40349ms 
step_avg:87.15ms +step:464/1680 train_time:40436ms step_avg:87.15ms +step:465/1680 train_time:40523ms step_avg:87.15ms +step:466/1680 train_time:40610ms step_avg:87.15ms +step:467/1680 train_time:40698ms step_avg:87.15ms +step:468/1680 train_time:40784ms step_avg:87.15ms +step:469/1680 train_time:40872ms step_avg:87.15ms +step:470/1680 train_time:40959ms step_avg:87.15ms +step:471/1680 train_time:41047ms step_avg:87.15ms +step:472/1680 train_time:41134ms step_avg:87.15ms +step:473/1680 train_time:41222ms step_avg:87.15ms +step:474/1680 train_time:41309ms step_avg:87.15ms +step:475/1680 train_time:41396ms step_avg:87.15ms +step:476/1680 train_time:41483ms step_avg:87.15ms +step:477/1680 train_time:41570ms step_avg:87.15ms +step:478/1680 train_time:41657ms step_avg:87.15ms +step:479/1680 train_time:41744ms step_avg:87.15ms +step:480/1680 train_time:41831ms step_avg:87.15ms +step:481/1680 train_time:41917ms step_avg:87.15ms +step:482/1680 train_time:42005ms step_avg:87.15ms +step:483/1680 train_time:42093ms step_avg:87.15ms +step:484/1680 train_time:42180ms step_avg:87.15ms +step:485/1680 train_time:42268ms step_avg:87.15ms +step:486/1680 train_time:42355ms step_avg:87.15ms +step:487/1680 train_time:42442ms step_avg:87.15ms +step:488/1680 train_time:42528ms step_avg:87.15ms +step:489/1680 train_time:42615ms step_avg:87.15ms +step:490/1680 train_time:42703ms step_avg:87.15ms +step:491/1680 train_time:42790ms step_avg:87.15ms +step:492/1680 train_time:42877ms step_avg:87.15ms +step:493/1680 train_time:42964ms step_avg:87.15ms +step:494/1680 train_time:43052ms step_avg:87.15ms +step:495/1680 train_time:43139ms step_avg:87.15ms +step:496/1680 train_time:43226ms step_avg:87.15ms +step:497/1680 train_time:43313ms step_avg:87.15ms +step:498/1680 train_time:43401ms step_avg:87.15ms +step:499/1680 train_time:43487ms step_avg:87.15ms +step:500/1680 train_time:43575ms step_avg:87.15ms +step:500/1680 val_loss:3.7183 train_time:43662ms step_avg:87.32ms +step:501/1680 train_time:43682ms step_avg:87.19ms +step:502/1680 train_time:43751ms step_avg:87.15ms +step:503/1680 train_time:43842ms step_avg:87.16ms +step:504/1680 train_time:43932ms step_avg:87.17ms +step:505/1680 train_time:44019ms step_avg:87.17ms +step:506/1680 train_time:44106ms step_avg:87.17ms +step:507/1680 train_time:44192ms step_avg:87.16ms +step:508/1680 train_time:44278ms step_avg:87.16ms +step:509/1680 train_time:44364ms step_avg:87.16ms +step:510/1680 train_time:44451ms step_avg:87.16ms +step:511/1680 train_time:44537ms step_avg:87.16ms +step:512/1680 train_time:44624ms step_avg:87.16ms +step:513/1680 train_time:44712ms step_avg:87.16ms +step:514/1680 train_time:44801ms step_avg:87.16ms +step:515/1680 train_time:44889ms step_avg:87.16ms +step:516/1680 train_time:44977ms step_avg:87.16ms +step:517/1680 train_time:45064ms step_avg:87.16ms +step:518/1680 train_time:45151ms step_avg:87.16ms +step:519/1680 train_time:45238ms step_avg:87.16ms +step:520/1680 train_time:45324ms step_avg:87.16ms +step:521/1680 train_time:45410ms step_avg:87.16ms +step:522/1680 train_time:45496ms step_avg:87.16ms +step:523/1680 train_time:45584ms step_avg:87.16ms +step:524/1680 train_time:45671ms step_avg:87.16ms +step:525/1680 train_time:45759ms step_avg:87.16ms +step:526/1680 train_time:45846ms step_avg:87.16ms +step:527/1680 train_time:45934ms step_avg:87.16ms +step:528/1680 train_time:46022ms step_avg:87.16ms +step:529/1680 train_time:46109ms step_avg:87.16ms +step:530/1680 train_time:46196ms step_avg:87.16ms +step:531/1680 train_time:46282ms step_avg:87.16ms 
+step:532/1680 train_time:46369ms step_avg:87.16ms +step:533/1680 train_time:46456ms step_avg:87.16ms +step:534/1680 train_time:46543ms step_avg:87.16ms +step:535/1680 train_time:46630ms step_avg:87.16ms +step:536/1680 train_time:46717ms step_avg:87.16ms +step:537/1680 train_time:46805ms step_avg:87.16ms +step:538/1680 train_time:46892ms step_avg:87.16ms +step:539/1680 train_time:46980ms step_avg:87.16ms +step:540/1680 train_time:47068ms step_avg:87.16ms +step:541/1680 train_time:47155ms step_avg:87.16ms +step:542/1680 train_time:47242ms step_avg:87.16ms +step:543/1680 train_time:47329ms step_avg:87.16ms +step:544/1680 train_time:47416ms step_avg:87.16ms +step:545/1680 train_time:47504ms step_avg:87.16ms +step:546/1680 train_time:47590ms step_avg:87.16ms +step:547/1680 train_time:47676ms step_avg:87.16ms +step:548/1680 train_time:47764ms step_avg:87.16ms +step:549/1680 train_time:47852ms step_avg:87.16ms +step:550/1680 train_time:47941ms step_avg:87.17ms +step:551/1680 train_time:48030ms step_avg:87.17ms +step:552/1680 train_time:48119ms step_avg:87.17ms +step:553/1680 train_time:48207ms step_avg:87.17ms +step:554/1680 train_time:48296ms step_avg:87.18ms +step:555/1680 train_time:48384ms step_avg:87.18ms +step:556/1680 train_time:48472ms step_avg:87.18ms +step:557/1680 train_time:48560ms step_avg:87.18ms +step:558/1680 train_time:48648ms step_avg:87.18ms +step:559/1680 train_time:48736ms step_avg:87.18ms +step:560/1680 train_time:48824ms step_avg:87.19ms +step:561/1680 train_time:48912ms step_avg:87.19ms +step:562/1680 train_time:49001ms step_avg:87.19ms +step:563/1680 train_time:49089ms step_avg:87.19ms +step:564/1680 train_time:49177ms step_avg:87.19ms +step:565/1680 train_time:49265ms step_avg:87.20ms +step:566/1680 train_time:49354ms step_avg:87.20ms +step:567/1680 train_time:49442ms step_avg:87.20ms +step:568/1680 train_time:49530ms step_avg:87.20ms +step:569/1680 train_time:49618ms step_avg:87.20ms +step:570/1680 train_time:49706ms step_avg:87.20ms +step:571/1680 train_time:49794ms step_avg:87.20ms +step:572/1680 train_time:49883ms step_avg:87.21ms +step:573/1680 train_time:49971ms step_avg:87.21ms +step:574/1680 train_time:50059ms step_avg:87.21ms +step:575/1680 train_time:50147ms step_avg:87.21ms +step:576/1680 train_time:50235ms step_avg:87.21ms +step:577/1680 train_time:50323ms step_avg:87.22ms +step:578/1680 train_time:50411ms step_avg:87.22ms +step:579/1680 train_time:50500ms step_avg:87.22ms +step:580/1680 train_time:50588ms step_avg:87.22ms +step:581/1680 train_time:50676ms step_avg:87.22ms +step:582/1680 train_time:50765ms step_avg:87.22ms +step:583/1680 train_time:50853ms step_avg:87.23ms +step:584/1680 train_time:50941ms step_avg:87.23ms +step:585/1680 train_time:51030ms step_avg:87.23ms +step:586/1680 train_time:51118ms step_avg:87.23ms +step:587/1680 train_time:51206ms step_avg:87.23ms +step:588/1680 train_time:51294ms step_avg:87.23ms +step:589/1680 train_time:51382ms step_avg:87.24ms +step:590/1680 train_time:51470ms step_avg:87.24ms +step:591/1680 train_time:51558ms step_avg:87.24ms +step:592/1680 train_time:51647ms step_avg:87.24ms +step:593/1680 train_time:51735ms step_avg:87.24ms +step:594/1680 train_time:51824ms step_avg:87.25ms +step:595/1680 train_time:51912ms step_avg:87.25ms +step:596/1680 train_time:52001ms step_avg:87.25ms +step:597/1680 train_time:52090ms step_avg:87.25ms +step:598/1680 train_time:52178ms step_avg:87.25ms +step:599/1680 train_time:52267ms step_avg:87.26ms +step:600/1680 train_time:52355ms step_avg:87.26ms +step:601/1680 train_time:52442ms 
step_avg:87.26ms +step:602/1680 train_time:52530ms step_avg:87.26ms +step:603/1680 train_time:52618ms step_avg:87.26ms +step:604/1680 train_time:52706ms step_avg:87.26ms +step:605/1680 train_time:52794ms step_avg:87.26ms +step:606/1680 train_time:52883ms step_avg:87.27ms +step:607/1680 train_time:52972ms step_avg:87.27ms +step:608/1680 train_time:53060ms step_avg:87.27ms +step:609/1680 train_time:53148ms step_avg:87.27ms +step:610/1680 train_time:53236ms step_avg:87.27ms +step:611/1680 train_time:53324ms step_avg:87.27ms +step:612/1680 train_time:53411ms step_avg:87.27ms +step:613/1680 train_time:53500ms step_avg:87.28ms +step:614/1680 train_time:53588ms step_avg:87.28ms +step:615/1680 train_time:53676ms step_avg:87.28ms +step:616/1680 train_time:53765ms step_avg:87.28ms +step:617/1680 train_time:53853ms step_avg:87.28ms +step:618/1680 train_time:53941ms step_avg:87.28ms +step:619/1680 train_time:54029ms step_avg:87.29ms +step:620/1680 train_time:54118ms step_avg:87.29ms +step:621/1680 train_time:54206ms step_avg:87.29ms +step:622/1680 train_time:54294ms step_avg:87.29ms +step:623/1680 train_time:54382ms step_avg:87.29ms +step:624/1680 train_time:54470ms step_avg:87.29ms +step:625/1680 train_time:54559ms step_avg:87.29ms +step:625/1680 val_loss:3.6198 train_time:54649ms step_avg:87.44ms +step:626/1680 train_time:54668ms step_avg:87.33ms +step:627/1680 train_time:54740ms step_avg:87.30ms +step:628/1680 train_time:54827ms step_avg:87.30ms +step:629/1680 train_time:54917ms step_avg:87.31ms +step:630/1680 train_time:55005ms step_avg:87.31ms +step:631/1680 train_time:55093ms step_avg:87.31ms +step:632/1680 train_time:55181ms step_avg:87.31ms +step:633/1680 train_time:55268ms step_avg:87.31ms +step:634/1680 train_time:55355ms step_avg:87.31ms +step:635/1680 train_time:55442ms step_avg:87.31ms +step:636/1680 train_time:55530ms step_avg:87.31ms +step:637/1680 train_time:55621ms step_avg:87.32ms +step:638/1680 train_time:55711ms step_avg:87.32ms +step:639/1680 train_time:55800ms step_avg:87.32ms +step:640/1680 train_time:55889ms step_avg:87.33ms +step:641/1680 train_time:55977ms step_avg:87.33ms +step:642/1680 train_time:56066ms step_avg:87.33ms +step:643/1680 train_time:56154ms step_avg:87.33ms +step:644/1680 train_time:56242ms step_avg:87.33ms +step:645/1680 train_time:56330ms step_avg:87.33ms +step:646/1680 train_time:56417ms step_avg:87.33ms +step:647/1680 train_time:56505ms step_avg:87.33ms +step:648/1680 train_time:56594ms step_avg:87.34ms +step:649/1680 train_time:56683ms step_avg:87.34ms +step:650/1680 train_time:56772ms step_avg:87.34ms +step:651/1680 train_time:56860ms step_avg:87.34ms +step:652/1680 train_time:56949ms step_avg:87.35ms +step:653/1680 train_time:57037ms step_avg:87.35ms +step:654/1680 train_time:57125ms step_avg:87.35ms +step:655/1680 train_time:57213ms step_avg:87.35ms +step:656/1680 train_time:57300ms step_avg:87.35ms +step:657/1680 train_time:57388ms step_avg:87.35ms +step:658/1680 train_time:57476ms step_avg:87.35ms +step:659/1680 train_time:57564ms step_avg:87.35ms +step:660/1680 train_time:57653ms step_avg:87.35ms +step:661/1680 train_time:57741ms step_avg:87.35ms +step:662/1680 train_time:57829ms step_avg:87.36ms +step:663/1680 train_time:57919ms step_avg:87.36ms +step:664/1680 train_time:58008ms step_avg:87.36ms +step:665/1680 train_time:58095ms step_avg:87.36ms +step:666/1680 train_time:58183ms step_avg:87.36ms +step:667/1680 train_time:58271ms step_avg:87.36ms +step:668/1680 train_time:58358ms step_avg:87.36ms +step:669/1680 train_time:58447ms step_avg:87.36ms 
+step:670/1680 train_time:58536ms step_avg:87.37ms +step:671/1680 train_time:58623ms step_avg:87.37ms +step:672/1680 train_time:58712ms step_avg:87.37ms +step:673/1680 train_time:58800ms step_avg:87.37ms +step:674/1680 train_time:58889ms step_avg:87.37ms +step:675/1680 train_time:58977ms step_avg:87.37ms +step:676/1680 train_time:59065ms step_avg:87.37ms +step:677/1680 train_time:59153ms step_avg:87.38ms +step:678/1680 train_time:59241ms step_avg:87.38ms +step:679/1680 train_time:59328ms step_avg:87.38ms +step:680/1680 train_time:59417ms step_avg:87.38ms +step:681/1680 train_time:59505ms step_avg:87.38ms +step:682/1680 train_time:59594ms step_avg:87.38ms +step:683/1680 train_time:59682ms step_avg:87.38ms +step:684/1680 train_time:59770ms step_avg:87.38ms +step:685/1680 train_time:59859ms step_avg:87.38ms +step:686/1680 train_time:59947ms step_avg:87.39ms +step:687/1680 train_time:60035ms step_avg:87.39ms +step:688/1680 train_time:60123ms step_avg:87.39ms +step:689/1680 train_time:60212ms step_avg:87.39ms +step:690/1680 train_time:60300ms step_avg:87.39ms +step:691/1680 train_time:60387ms step_avg:87.39ms +step:692/1680 train_time:60476ms step_avg:87.39ms +step:693/1680 train_time:60565ms step_avg:87.39ms +step:694/1680 train_time:60653ms step_avg:87.40ms +step:695/1680 train_time:60741ms step_avg:87.40ms +step:696/1680 train_time:60830ms step_avg:87.40ms +step:697/1680 train_time:60918ms step_avg:87.40ms +step:698/1680 train_time:61006ms step_avg:87.40ms +step:699/1680 train_time:61094ms step_avg:87.40ms +step:700/1680 train_time:61182ms step_avg:87.40ms +step:701/1680 train_time:61270ms step_avg:87.40ms +step:702/1680 train_time:61358ms step_avg:87.40ms +step:703/1680 train_time:61446ms step_avg:87.41ms +step:704/1680 train_time:61534ms step_avg:87.41ms +step:705/1680 train_time:61623ms step_avg:87.41ms +step:706/1680 train_time:61711ms step_avg:87.41ms +step:707/1680 train_time:61799ms step_avg:87.41ms +step:708/1680 train_time:61888ms step_avg:87.41ms +step:709/1680 train_time:61977ms step_avg:87.41ms +step:710/1680 train_time:62065ms step_avg:87.42ms +step:711/1680 train_time:62152ms step_avg:87.42ms +step:712/1680 train_time:62241ms step_avg:87.42ms +step:713/1680 train_time:62329ms step_avg:87.42ms +step:714/1680 train_time:62417ms step_avg:87.42ms +step:715/1680 train_time:62506ms step_avg:87.42ms +step:716/1680 train_time:62594ms step_avg:87.42ms +step:717/1680 train_time:62682ms step_avg:87.42ms +step:718/1680 train_time:62770ms step_avg:87.42ms +step:719/1680 train_time:62858ms step_avg:87.42ms +step:720/1680 train_time:62946ms step_avg:87.42ms +step:721/1680 train_time:63034ms step_avg:87.43ms +step:722/1680 train_time:63123ms step_avg:87.43ms +step:723/1680 train_time:63211ms step_avg:87.43ms +step:724/1680 train_time:63299ms step_avg:87.43ms +step:725/1680 train_time:63387ms step_avg:87.43ms +step:726/1680 train_time:63476ms step_avg:87.43ms +step:727/1680 train_time:63563ms step_avg:87.43ms +step:728/1680 train_time:63651ms step_avg:87.43ms +step:729/1680 train_time:63739ms step_avg:87.43ms +step:730/1680 train_time:63827ms step_avg:87.43ms +step:731/1680 train_time:63916ms step_avg:87.44ms +step:732/1680 train_time:64004ms step_avg:87.44ms +step:733/1680 train_time:64092ms step_avg:87.44ms +step:734/1680 train_time:64179ms step_avg:87.44ms +step:735/1680 train_time:64268ms step_avg:87.44ms +step:736/1680 train_time:64355ms step_avg:87.44ms +step:737/1680 train_time:64443ms step_avg:87.44ms +step:738/1680 train_time:64531ms step_avg:87.44ms +step:739/1680 train_time:64620ms 
step_avg:87.44ms +step:740/1680 train_time:64708ms step_avg:87.44ms +step:741/1680 train_time:64796ms step_avg:87.44ms +step:742/1680 train_time:64885ms step_avg:87.45ms +step:743/1680 train_time:64974ms step_avg:87.45ms +step:744/1680 train_time:65063ms step_avg:87.45ms +step:745/1680 train_time:65152ms step_avg:87.45ms +step:746/1680 train_time:65239ms step_avg:87.45ms +step:747/1680 train_time:65327ms step_avg:87.45ms +step:748/1680 train_time:65416ms step_avg:87.45ms +step:749/1680 train_time:65504ms step_avg:87.45ms +step:750/1680 train_time:65592ms step_avg:87.46ms +step:750/1680 val_loss:3.5683 train_time:65682ms step_avg:87.58ms +step:751/1680 train_time:65700ms step_avg:87.48ms +step:752/1680 train_time:65774ms step_avg:87.47ms +step:753/1680 train_time:65867ms step_avg:87.47ms +step:754/1680 train_time:65955ms step_avg:87.47ms +step:755/1680 train_time:66043ms step_avg:87.47ms +step:756/1680 train_time:66131ms step_avg:87.48ms +step:757/1680 train_time:66218ms step_avg:87.47ms +step:758/1680 train_time:66305ms step_avg:87.47ms +step:759/1680 train_time:66393ms step_avg:87.47ms +step:760/1680 train_time:66480ms step_avg:87.47ms +step:761/1680 train_time:66567ms step_avg:87.47ms +step:762/1680 train_time:66657ms step_avg:87.48ms +step:763/1680 train_time:66747ms step_avg:87.48ms +step:764/1680 train_time:66836ms step_avg:87.48ms +step:765/1680 train_time:66926ms step_avg:87.48ms +step:766/1680 train_time:67015ms step_avg:87.49ms +step:767/1680 train_time:67103ms step_avg:87.49ms +step:768/1680 train_time:67190ms step_avg:87.49ms +step:769/1680 train_time:67277ms step_avg:87.49ms +step:770/1680 train_time:67365ms step_avg:87.49ms +step:771/1680 train_time:67453ms step_avg:87.49ms +step:772/1680 train_time:67541ms step_avg:87.49ms +step:773/1680 train_time:67630ms step_avg:87.49ms +step:774/1680 train_time:67719ms step_avg:87.49ms +step:775/1680 train_time:67808ms step_avg:87.49ms +step:776/1680 train_time:67897ms step_avg:87.50ms +step:777/1680 train_time:67986ms step_avg:87.50ms +step:778/1680 train_time:68075ms step_avg:87.50ms +step:779/1680 train_time:68162ms step_avg:87.50ms +step:780/1680 train_time:68250ms step_avg:87.50ms +step:781/1680 train_time:68337ms step_avg:87.50ms +step:782/1680 train_time:68425ms step_avg:87.50ms +step:783/1680 train_time:68513ms step_avg:87.50ms +step:784/1680 train_time:68601ms step_avg:87.50ms +step:785/1680 train_time:68691ms step_avg:87.50ms +step:786/1680 train_time:68780ms step_avg:87.51ms +step:787/1680 train_time:68870ms step_avg:87.51ms +step:788/1680 train_time:68958ms step_avg:87.51ms +step:789/1680 train_time:69047ms step_avg:87.51ms +step:790/1680 train_time:69135ms step_avg:87.51ms +step:791/1680 train_time:69223ms step_avg:87.51ms +step:792/1680 train_time:69311ms step_avg:87.51ms +step:793/1680 train_time:69398ms step_avg:87.51ms +step:794/1680 train_time:69486ms step_avg:87.51ms +step:795/1680 train_time:69574ms step_avg:87.51ms +step:796/1680 train_time:69664ms step_avg:87.52ms +step:797/1680 train_time:69752ms step_avg:87.52ms +step:798/1680 train_time:69842ms step_avg:87.52ms +step:799/1680 train_time:69931ms step_avg:87.52ms +step:800/1680 train_time:70019ms step_avg:87.52ms +step:801/1680 train_time:70108ms step_avg:87.52ms +step:802/1680 train_time:70196ms step_avg:87.53ms +step:803/1680 train_time:70284ms step_avg:87.53ms +step:804/1680 train_time:70371ms step_avg:87.53ms +step:805/1680 train_time:70459ms step_avg:87.53ms +step:806/1680 train_time:70547ms step_avg:87.53ms +step:807/1680 train_time:70636ms step_avg:87.53ms 
+step:808/1680 train_time:70725ms step_avg:87.53ms +step:809/1680 train_time:70814ms step_avg:87.53ms +step:810/1680 train_time:70903ms step_avg:87.53ms +step:811/1680 train_time:70991ms step_avg:87.54ms +step:812/1680 train_time:71079ms step_avg:87.54ms +step:813/1680 train_time:71168ms step_avg:87.54ms +step:814/1680 train_time:71255ms step_avg:87.54ms +step:815/1680 train_time:71343ms step_avg:87.54ms +step:816/1680 train_time:71431ms step_avg:87.54ms +step:817/1680 train_time:71519ms step_avg:87.54ms +step:818/1680 train_time:71607ms step_avg:87.54ms +step:819/1680 train_time:71696ms step_avg:87.54ms +step:820/1680 train_time:71784ms step_avg:87.54ms +step:821/1680 train_time:71873ms step_avg:87.54ms +step:822/1680 train_time:71961ms step_avg:87.54ms +step:823/1680 train_time:72049ms step_avg:87.54ms +step:824/1680 train_time:72138ms step_avg:87.55ms +step:825/1680 train_time:72227ms step_avg:87.55ms +step:826/1680 train_time:72314ms step_avg:87.55ms +step:827/1680 train_time:72401ms step_avg:87.55ms +step:828/1680 train_time:72489ms step_avg:87.55ms +step:829/1680 train_time:72577ms step_avg:87.55ms +step:830/1680 train_time:72666ms step_avg:87.55ms +step:831/1680 train_time:72754ms step_avg:87.55ms +step:832/1680 train_time:72842ms step_avg:87.55ms +step:833/1680 train_time:72931ms step_avg:87.55ms +step:834/1680 train_time:73019ms step_avg:87.55ms +step:835/1680 train_time:73107ms step_avg:87.55ms +step:836/1680 train_time:73195ms step_avg:87.55ms +step:837/1680 train_time:73283ms step_avg:87.55ms +step:838/1680 train_time:73371ms step_avg:87.56ms +step:839/1680 train_time:73459ms step_avg:87.56ms +step:840/1680 train_time:73547ms step_avg:87.56ms +step:841/1680 train_time:73635ms step_avg:87.56ms +step:842/1680 train_time:73724ms step_avg:87.56ms +step:843/1680 train_time:73811ms step_avg:87.56ms +step:844/1680 train_time:73899ms step_avg:87.56ms +step:845/1680 train_time:73988ms step_avg:87.56ms +step:846/1680 train_time:74076ms step_avg:87.56ms +step:847/1680 train_time:74165ms step_avg:87.56ms +step:848/1680 train_time:74253ms step_avg:87.56ms +step:849/1680 train_time:74341ms step_avg:87.56ms +step:850/1680 train_time:74429ms step_avg:87.56ms +step:851/1680 train_time:74517ms step_avg:87.56ms +step:852/1680 train_time:74606ms step_avg:87.57ms +step:853/1680 train_time:74694ms step_avg:87.57ms +step:854/1680 train_time:74782ms step_avg:87.57ms +step:855/1680 train_time:74870ms step_avg:87.57ms +step:856/1680 train_time:74959ms step_avg:87.57ms +step:857/1680 train_time:75048ms step_avg:87.57ms +step:858/1680 train_time:75138ms step_avg:87.57ms +step:859/1680 train_time:75227ms step_avg:87.58ms +step:860/1680 train_time:75315ms step_avg:87.58ms +step:861/1680 train_time:75403ms step_avg:87.58ms +step:862/1680 train_time:75492ms step_avg:87.58ms +step:863/1680 train_time:75580ms step_avg:87.58ms +step:864/1680 train_time:75668ms step_avg:87.58ms +step:865/1680 train_time:75757ms step_avg:87.58ms +step:866/1680 train_time:75844ms step_avg:87.58ms +step:867/1680 train_time:75933ms step_avg:87.58ms +step:868/1680 train_time:76020ms step_avg:87.58ms +step:869/1680 train_time:76108ms step_avg:87.58ms +step:870/1680 train_time:76197ms step_avg:87.58ms +step:871/1680 train_time:76285ms step_avg:87.58ms +step:872/1680 train_time:76373ms step_avg:87.58ms +step:873/1680 train_time:76460ms step_avg:87.58ms +step:874/1680 train_time:76549ms step_avg:87.58ms +step:875/1680 train_time:76636ms step_avg:87.58ms +step:875/1680 val_loss:3.5233 train_time:76727ms step_avg:87.69ms +step:876/1680 
train_time:76746ms step_avg:87.61ms +step:877/1680 train_time:76819ms step_avg:87.59ms +step:878/1680 train_time:76911ms step_avg:87.60ms +step:879/1680 train_time:76999ms step_avg:87.60ms +step:880/1680 train_time:77087ms step_avg:87.60ms +step:881/1680 train_time:77174ms step_avg:87.60ms +step:882/1680 train_time:77261ms step_avg:87.60ms +step:883/1680 train_time:77348ms step_avg:87.60ms +step:884/1680 train_time:77435ms step_avg:87.60ms +step:885/1680 train_time:77523ms step_avg:87.60ms +step:886/1680 train_time:77610ms step_avg:87.60ms +step:887/1680 train_time:77700ms step_avg:87.60ms +step:888/1680 train_time:77790ms step_avg:87.60ms +step:889/1680 train_time:77880ms step_avg:87.60ms +step:890/1680 train_time:77968ms step_avg:87.61ms +step:891/1680 train_time:78057ms step_avg:87.61ms +step:892/1680 train_time:78144ms step_avg:87.61ms +step:893/1680 train_time:78232ms step_avg:87.61ms +step:894/1680 train_time:78319ms step_avg:87.61ms +step:895/1680 train_time:78407ms step_avg:87.61ms +step:896/1680 train_time:78495ms step_avg:87.61ms +step:897/1680 train_time:78582ms step_avg:87.61ms +step:898/1680 train_time:78671ms step_avg:87.61ms +step:899/1680 train_time:78760ms step_avg:87.61ms +step:900/1680 train_time:78850ms step_avg:87.61ms +step:901/1680 train_time:78938ms step_avg:87.61ms +step:902/1680 train_time:79028ms step_avg:87.61ms +step:903/1680 train_time:79116ms step_avg:87.61ms +step:904/1680 train_time:79204ms step_avg:87.62ms +step:905/1680 train_time:79291ms step_avg:87.61ms +step:906/1680 train_time:79379ms step_avg:87.61ms +step:907/1680 train_time:79467ms step_avg:87.62ms +step:908/1680 train_time:79555ms step_avg:87.62ms +step:909/1680 train_time:79643ms step_avg:87.62ms +step:910/1680 train_time:79732ms step_avg:87.62ms +step:911/1680 train_time:79821ms step_avg:87.62ms +step:912/1680 train_time:79910ms step_avg:87.62ms +step:913/1680 train_time:79998ms step_avg:87.62ms +step:914/1680 train_time:80087ms step_avg:87.62ms +step:915/1680 train_time:80176ms step_avg:87.62ms +step:916/1680 train_time:80263ms step_avg:87.62ms +step:917/1680 train_time:80352ms step_avg:87.62ms +step:918/1680 train_time:80440ms step_avg:87.62ms +step:919/1680 train_time:80527ms step_avg:87.62ms +step:920/1680 train_time:80616ms step_avg:87.63ms +step:921/1680 train_time:80704ms step_avg:87.63ms +step:922/1680 train_time:80793ms step_avg:87.63ms +step:923/1680 train_time:80881ms step_avg:87.63ms +step:924/1680 train_time:80969ms step_avg:87.63ms +step:925/1680 train_time:81058ms step_avg:87.63ms +step:926/1680 train_time:81147ms step_avg:87.63ms +step:927/1680 train_time:81235ms step_avg:87.63ms +step:928/1680 train_time:81324ms step_avg:87.63ms +step:929/1680 train_time:81411ms step_avg:87.63ms +step:930/1680 train_time:81500ms step_avg:87.63ms +step:931/1680 train_time:81588ms step_avg:87.63ms +step:932/1680 train_time:81677ms step_avg:87.64ms +step:933/1680 train_time:81765ms step_avg:87.64ms +step:934/1680 train_time:81854ms step_avg:87.64ms +step:935/1680 train_time:81942ms step_avg:87.64ms +step:936/1680 train_time:82031ms step_avg:87.64ms +step:937/1680 train_time:82119ms step_avg:87.64ms +step:938/1680 train_time:82207ms step_avg:87.64ms +step:939/1680 train_time:82295ms step_avg:87.64ms +step:940/1680 train_time:82383ms step_avg:87.64ms +step:941/1680 train_time:82471ms step_avg:87.64ms +step:942/1680 train_time:82559ms step_avg:87.64ms +step:943/1680 train_time:82647ms step_avg:87.64ms +step:944/1680 train_time:82735ms step_avg:87.64ms +step:945/1680 train_time:82824ms step_avg:87.64ms 
+step:946/1680 train_time:82911ms step_avg:87.64ms +step:947/1680 train_time:83000ms step_avg:87.64ms +step:948/1680 train_time:83088ms step_avg:87.65ms +step:949/1680 train_time:83176ms step_avg:87.65ms +step:950/1680 train_time:83264ms step_avg:87.65ms +step:951/1680 train_time:83352ms step_avg:87.65ms +step:952/1680 train_time:83440ms step_avg:87.65ms +step:953/1680 train_time:83528ms step_avg:87.65ms +step:954/1680 train_time:83617ms step_avg:87.65ms +step:955/1680 train_time:83705ms step_avg:87.65ms +step:956/1680 train_time:83793ms step_avg:87.65ms +step:957/1680 train_time:83882ms step_avg:87.65ms +step:958/1680 train_time:83970ms step_avg:87.65ms +step:959/1680 train_time:84059ms step_avg:87.65ms +step:960/1680 train_time:84147ms step_avg:87.65ms +step:961/1680 train_time:84235ms step_avg:87.65ms +step:962/1680 train_time:84323ms step_avg:87.65ms +step:963/1680 train_time:84411ms step_avg:87.65ms +step:964/1680 train_time:84500ms step_avg:87.66ms +step:965/1680 train_time:84589ms step_avg:87.66ms +step:966/1680 train_time:84677ms step_avg:87.66ms +step:967/1680 train_time:84765ms step_avg:87.66ms +step:968/1680 train_time:84853ms step_avg:87.66ms +step:969/1680 train_time:84941ms step_avg:87.66ms +step:970/1680 train_time:85029ms step_avg:87.66ms +step:971/1680 train_time:85118ms step_avg:87.66ms +step:972/1680 train_time:85206ms step_avg:87.66ms +step:973/1680 train_time:85294ms step_avg:87.66ms +step:974/1680 train_time:85382ms step_avg:87.66ms +step:975/1680 train_time:85470ms step_avg:87.66ms +step:976/1680 train_time:85558ms step_avg:87.66ms +step:977/1680 train_time:85646ms step_avg:87.66ms +step:978/1680 train_time:85735ms step_avg:87.66ms +step:979/1680 train_time:85824ms step_avg:87.66ms +step:980/1680 train_time:85912ms step_avg:87.66ms +step:981/1680 train_time:86000ms step_avg:87.67ms +step:982/1680 train_time:86088ms step_avg:87.67ms +step:983/1680 train_time:86176ms step_avg:87.67ms +step:984/1680 train_time:86264ms step_avg:87.67ms +step:985/1680 train_time:86353ms step_avg:87.67ms +step:986/1680 train_time:86441ms step_avg:87.67ms +step:987/1680 train_time:86528ms step_avg:87.67ms +step:988/1680 train_time:86617ms step_avg:87.67ms +step:989/1680 train_time:86706ms step_avg:87.67ms +step:990/1680 train_time:86794ms step_avg:87.67ms +step:991/1680 train_time:86883ms step_avg:87.67ms +step:992/1680 train_time:86972ms step_avg:87.67ms +step:993/1680 train_time:87060ms step_avg:87.67ms +step:994/1680 train_time:87148ms step_avg:87.67ms +step:995/1680 train_time:87237ms step_avg:87.68ms +step:996/1680 train_time:87325ms step_avg:87.68ms +step:997/1680 train_time:87413ms step_avg:87.68ms +step:998/1680 train_time:87501ms step_avg:87.68ms +step:999/1680 train_time:87589ms step_avg:87.68ms +step:1000/1680 train_time:87678ms step_avg:87.68ms +step:1000/1680 val_loss:3.4726 train_time:87768ms step_avg:87.77ms +step:1001/1680 train_time:87787ms step_avg:87.70ms +step:1002/1680 train_time:87858ms step_avg:87.68ms +step:1003/1680 train_time:87949ms step_avg:87.69ms +step:1004/1680 train_time:88037ms step_avg:87.69ms +step:1005/1680 train_time:88125ms step_avg:87.69ms +step:1006/1680 train_time:88212ms step_avg:87.69ms +step:1007/1680 train_time:88299ms step_avg:87.68ms +step:1008/1680 train_time:88386ms step_avg:87.68ms +step:1009/1680 train_time:88473ms step_avg:87.68ms +step:1010/1680 train_time:88562ms step_avg:87.68ms +step:1011/1680 train_time:88649ms step_avg:87.68ms +step:1012/1680 train_time:88738ms step_avg:87.69ms +step:1013/1680 train_time:88828ms step_avg:87.69ms 
+step:1014/1680 train_time:88918ms step_avg:87.69ms +step:1015/1680 train_time:89006ms step_avg:87.69ms +step:1016/1680 train_time:89095ms step_avg:87.69ms +step:1017/1680 train_time:89183ms step_avg:87.69ms +step:1018/1680 train_time:89270ms step_avg:87.69ms +step:1019/1680 train_time:89358ms step_avg:87.69ms +step:1020/1680 train_time:89446ms step_avg:87.69ms +step:1021/1680 train_time:89533ms step_avg:87.69ms +step:1022/1680 train_time:89621ms step_avg:87.69ms +step:1023/1680 train_time:89709ms step_avg:87.69ms +step:1024/1680 train_time:89798ms step_avg:87.69ms +step:1025/1680 train_time:89888ms step_avg:87.70ms +step:1026/1680 train_time:89977ms step_avg:87.70ms +step:1027/1680 train_time:90065ms step_avg:87.70ms +step:1028/1680 train_time:90153ms step_avg:87.70ms +step:1029/1680 train_time:90241ms step_avg:87.70ms +step:1030/1680 train_time:90329ms step_avg:87.70ms +step:1031/1680 train_time:90417ms step_avg:87.70ms +step:1032/1680 train_time:90505ms step_avg:87.70ms +step:1033/1680 train_time:90593ms step_avg:87.70ms +step:1034/1680 train_time:90682ms step_avg:87.70ms +step:1035/1680 train_time:90770ms step_avg:87.70ms +step:1036/1680 train_time:90859ms step_avg:87.70ms +step:1037/1680 train_time:90948ms step_avg:87.70ms +step:1038/1680 train_time:91036ms step_avg:87.70ms +step:1039/1680 train_time:91125ms step_avg:87.70ms +step:1040/1680 train_time:91213ms step_avg:87.70ms +step:1041/1680 train_time:91301ms step_avg:87.70ms +step:1042/1680 train_time:91389ms step_avg:87.70ms +step:1043/1680 train_time:91477ms step_avg:87.71ms +step:1044/1680 train_time:91566ms step_avg:87.71ms +step:1045/1680 train_time:91654ms step_avg:87.71ms +step:1046/1680 train_time:91743ms step_avg:87.71ms +step:1047/1680 train_time:91831ms step_avg:87.71ms +step:1048/1680 train_time:91919ms step_avg:87.71ms +step:1049/1680 train_time:92007ms step_avg:87.71ms +step:1050/1680 train_time:92096ms step_avg:87.71ms +step:1051/1680 train_time:92184ms step_avg:87.71ms +step:1052/1680 train_time:92273ms step_avg:87.71ms +step:1053/1680 train_time:92361ms step_avg:87.71ms +step:1054/1680 train_time:92449ms step_avg:87.71ms +step:1055/1680 train_time:92537ms step_avg:87.71ms +step:1056/1680 train_time:92626ms step_avg:87.71ms +step:1057/1680 train_time:92715ms step_avg:87.71ms +step:1058/1680 train_time:92803ms step_avg:87.72ms +step:1059/1680 train_time:92892ms step_avg:87.72ms +step:1060/1680 train_time:92979ms step_avg:87.72ms +step:1061/1680 train_time:93068ms step_avg:87.72ms +step:1062/1680 train_time:93157ms step_avg:87.72ms +step:1063/1680 train_time:93245ms step_avg:87.72ms +step:1064/1680 train_time:93332ms step_avg:87.72ms +step:1065/1680 train_time:93420ms step_avg:87.72ms +step:1066/1680 train_time:93509ms step_avg:87.72ms +step:1067/1680 train_time:93597ms step_avg:87.72ms +step:1068/1680 train_time:93686ms step_avg:87.72ms +step:1069/1680 train_time:93776ms step_avg:87.72ms +step:1070/1680 train_time:93865ms step_avg:87.72ms +step:1071/1680 train_time:93953ms step_avg:87.72ms +step:1072/1680 train_time:94042ms step_avg:87.73ms +step:1073/1680 train_time:94130ms step_avg:87.73ms +step:1074/1680 train_time:94218ms step_avg:87.73ms +step:1075/1680 train_time:94306ms step_avg:87.73ms +step:1076/1680 train_time:94393ms step_avg:87.73ms +step:1077/1680 train_time:94482ms step_avg:87.73ms +step:1078/1680 train_time:94569ms step_avg:87.73ms +step:1079/1680 train_time:94658ms step_avg:87.73ms +step:1080/1680 train_time:94748ms step_avg:87.73ms +step:1081/1680 train_time:94836ms step_avg:87.73ms +step:1082/1680 
train_time:94925ms step_avg:87.73ms +step:1083/1680 train_time:95013ms step_avg:87.73ms +step:1084/1680 train_time:95101ms step_avg:87.73ms +step:1085/1680 train_time:95189ms step_avg:87.73ms +step:1086/1680 train_time:95277ms step_avg:87.73ms +step:1087/1680 train_time:95365ms step_avg:87.73ms +step:1088/1680 train_time:95454ms step_avg:87.73ms +step:1089/1680 train_time:95542ms step_avg:87.73ms +step:1090/1680 train_time:95630ms step_avg:87.73ms +step:1091/1680 train_time:95719ms step_avg:87.73ms +step:1092/1680 train_time:95807ms step_avg:87.74ms +step:1093/1680 train_time:95896ms step_avg:87.74ms +step:1094/1680 train_time:95985ms step_avg:87.74ms +step:1095/1680 train_time:96074ms step_avg:87.74ms +step:1096/1680 train_time:96162ms step_avg:87.74ms +step:1097/1680 train_time:96252ms step_avg:87.74ms +step:1098/1680 train_time:96340ms step_avg:87.74ms +step:1099/1680 train_time:96429ms step_avg:87.74ms +step:1100/1680 train_time:96518ms step_avg:87.74ms +step:1101/1680 train_time:96607ms step_avg:87.75ms +step:1102/1680 train_time:96696ms step_avg:87.75ms +step:1103/1680 train_time:96785ms step_avg:87.75ms +step:1104/1680 train_time:96874ms step_avg:87.75ms +step:1105/1680 train_time:96964ms step_avg:87.75ms +step:1106/1680 train_time:97054ms step_avg:87.75ms +step:1107/1680 train_time:97142ms step_avg:87.75ms +step:1108/1680 train_time:97232ms step_avg:87.75ms +step:1109/1680 train_time:97321ms step_avg:87.76ms +step:1110/1680 train_time:97409ms step_avg:87.76ms +step:1111/1680 train_time:97498ms step_avg:87.76ms +step:1112/1680 train_time:97587ms step_avg:87.76ms +step:1113/1680 train_time:97676ms step_avg:87.76ms +step:1114/1680 train_time:97764ms step_avg:87.76ms +step:1115/1680 train_time:97853ms step_avg:87.76ms +step:1116/1680 train_time:97943ms step_avg:87.76ms +step:1117/1680 train_time:98032ms step_avg:87.76ms +step:1118/1680 train_time:98121ms step_avg:87.76ms +step:1119/1680 train_time:98210ms step_avg:87.77ms +step:1120/1680 train_time:98298ms step_avg:87.77ms +step:1121/1680 train_time:98387ms step_avg:87.77ms +step:1122/1680 train_time:98475ms step_avg:87.77ms +step:1123/1680 train_time:98565ms step_avg:87.77ms +step:1124/1680 train_time:98654ms step_avg:87.77ms +step:1125/1680 train_time:98744ms step_avg:87.77ms +step:1125/1680 val_loss:3.4192 train_time:98835ms step_avg:87.85ms +step:1126/1680 train_time:98854ms step_avg:87.79ms +step:1127/1680 train_time:98924ms step_avg:87.78ms +step:1128/1680 train_time:99015ms step_avg:87.78ms +step:1129/1680 train_time:99106ms step_avg:87.78ms +step:1130/1680 train_time:99195ms step_avg:87.78ms +step:1131/1680 train_time:99283ms step_avg:87.78ms +step:1132/1680 train_time:99371ms step_avg:87.78ms +step:1133/1680 train_time:99459ms step_avg:87.78ms +step:1134/1680 train_time:99547ms step_avg:87.78ms +step:1135/1680 train_time:99634ms step_avg:87.78ms +step:1136/1680 train_time:99724ms step_avg:87.79ms +step:1137/1680 train_time:99815ms step_avg:87.79ms +step:1138/1680 train_time:99905ms step_avg:87.79ms +step:1139/1680 train_time:99996ms step_avg:87.79ms +step:1140/1680 train_time:100086ms step_avg:87.79ms +step:1141/1680 train_time:100175ms step_avg:87.80ms +step:1142/1680 train_time:100263ms step_avg:87.80ms +step:1143/1680 train_time:100352ms step_avg:87.80ms +step:1144/1680 train_time:100440ms step_avg:87.80ms +step:1145/1680 train_time:100528ms step_avg:87.80ms +step:1146/1680 train_time:100615ms step_avg:87.80ms +step:1147/1680 train_time:100704ms step_avg:87.80ms +step:1148/1680 train_time:100794ms step_avg:87.80ms 
+step:1149/1680 train_time:100884ms step_avg:87.80ms +step:1150/1680 train_time:100975ms step_avg:87.80ms +step:1151/1680 train_time:101065ms step_avg:87.81ms +step:1152/1680 train_time:101154ms step_avg:87.81ms +step:1153/1680 train_time:101243ms step_avg:87.81ms +step:1154/1680 train_time:101331ms step_avg:87.81ms +step:1155/1680 train_time:101418ms step_avg:87.81ms +step:1156/1680 train_time:101506ms step_avg:87.81ms +step:1157/1680 train_time:101595ms step_avg:87.81ms +step:1158/1680 train_time:101683ms step_avg:87.81ms +step:1159/1680 train_time:101772ms step_avg:87.81ms +step:1160/1680 train_time:101861ms step_avg:87.81ms +step:1161/1680 train_time:101952ms step_avg:87.81ms +step:1162/1680 train_time:102041ms step_avg:87.82ms +step:1163/1680 train_time:102131ms step_avg:87.82ms +step:1164/1680 train_time:102219ms step_avg:87.82ms +step:1165/1680 train_time:102308ms step_avg:87.82ms +step:1166/1680 train_time:102396ms step_avg:87.82ms +step:1167/1680 train_time:102485ms step_avg:87.82ms +step:1168/1680 train_time:102573ms step_avg:87.82ms +step:1169/1680 train_time:102662ms step_avg:87.82ms +step:1170/1680 train_time:102751ms step_avg:87.82ms +step:1171/1680 train_time:102840ms step_avg:87.82ms +step:1172/1680 train_time:102928ms step_avg:87.82ms +step:1173/1680 train_time:103018ms step_avg:87.82ms +step:1174/1680 train_time:103106ms step_avg:87.82ms +step:1175/1680 train_time:103195ms step_avg:87.83ms +step:1176/1680 train_time:103284ms step_avg:87.83ms +step:1177/1680 train_time:103373ms step_avg:87.83ms +step:1178/1680 train_time:103463ms step_avg:87.83ms +step:1179/1680 train_time:103551ms step_avg:87.83ms +step:1180/1680 train_time:103639ms step_avg:87.83ms +step:1181/1680 train_time:103728ms step_avg:87.83ms +step:1182/1680 train_time:103817ms step_avg:87.83ms +step:1183/1680 train_time:103907ms step_avg:87.83ms +step:1184/1680 train_time:103996ms step_avg:87.83ms +step:1185/1680 train_time:104085ms step_avg:87.84ms +step:1186/1680 train_time:104173ms step_avg:87.84ms +step:1187/1680 train_time:104262ms step_avg:87.84ms +step:1188/1680 train_time:104351ms step_avg:87.84ms +step:1189/1680 train_time:104439ms step_avg:87.84ms +step:1190/1680 train_time:104528ms step_avg:87.84ms +step:1191/1680 train_time:104617ms step_avg:87.84ms +step:1192/1680 train_time:104706ms step_avg:87.84ms +step:1193/1680 train_time:104795ms step_avg:87.84ms +step:1194/1680 train_time:104885ms step_avg:87.84ms +step:1195/1680 train_time:104974ms step_avg:87.84ms +step:1196/1680 train_time:105062ms step_avg:87.84ms +step:1197/1680 train_time:105151ms step_avg:87.85ms +step:1198/1680 train_time:105240ms step_avg:87.85ms +step:1199/1680 train_time:105329ms step_avg:87.85ms +step:1200/1680 train_time:105418ms step_avg:87.85ms +step:1201/1680 train_time:105506ms step_avg:87.85ms +step:1202/1680 train_time:105595ms step_avg:87.85ms +step:1203/1680 train_time:105683ms step_avg:87.85ms +step:1204/1680 train_time:105772ms step_avg:87.85ms +step:1205/1680 train_time:105861ms step_avg:87.85ms +step:1206/1680 train_time:105950ms step_avg:87.85ms +step:1207/1680 train_time:106040ms step_avg:87.85ms +step:1208/1680 train_time:106128ms step_avg:87.85ms +step:1209/1680 train_time:106218ms step_avg:87.86ms +step:1210/1680 train_time:106307ms step_avg:87.86ms +step:1211/1680 train_time:106396ms step_avg:87.86ms +step:1212/1680 train_time:106486ms step_avg:87.86ms +step:1213/1680 train_time:106575ms step_avg:87.86ms +step:1214/1680 train_time:106663ms step_avg:87.86ms +step:1215/1680 train_time:106753ms step_avg:87.86ms 
+step:1216/1680 train_time:106842ms step_avg:87.86ms +step:1217/1680 train_time:106931ms step_avg:87.86ms +step:1218/1680 train_time:107020ms step_avg:87.87ms +step:1219/1680 train_time:107110ms step_avg:87.87ms +step:1220/1680 train_time:107200ms step_avg:87.87ms +step:1221/1680 train_time:107290ms step_avg:87.87ms +step:1222/1680 train_time:107379ms step_avg:87.87ms +step:1223/1680 train_time:107467ms step_avg:87.87ms +step:1224/1680 train_time:107556ms step_avg:87.87ms +step:1225/1680 train_time:107644ms step_avg:87.87ms +step:1226/1680 train_time:107733ms step_avg:87.87ms +step:1227/1680 train_time:107822ms step_avg:87.87ms +step:1228/1680 train_time:107911ms step_avg:87.88ms +step:1229/1680 train_time:108000ms step_avg:87.88ms +step:1230/1680 train_time:108089ms step_avg:87.88ms +step:1231/1680 train_time:108180ms step_avg:87.88ms +step:1232/1680 train_time:108270ms step_avg:87.88ms +step:1233/1680 train_time:108359ms step_avg:87.88ms +step:1234/1680 train_time:108448ms step_avg:87.88ms +step:1235/1680 train_time:108537ms step_avg:87.88ms +step:1236/1680 train_time:108626ms step_avg:87.88ms +step:1237/1680 train_time:108715ms step_avg:87.89ms +step:1238/1680 train_time:108805ms step_avg:87.89ms +step:1239/1680 train_time:108894ms step_avg:87.89ms +step:1240/1680 train_time:108983ms step_avg:87.89ms +step:1241/1680 train_time:109072ms step_avg:87.89ms +step:1242/1680 train_time:109161ms step_avg:87.89ms +step:1243/1680 train_time:109250ms step_avg:87.89ms +step:1244/1680 train_time:109340ms step_avg:87.89ms +step:1245/1680 train_time:109429ms step_avg:87.89ms +step:1246/1680 train_time:109518ms step_avg:87.90ms +step:1247/1680 train_time:109607ms step_avg:87.90ms +step:1248/1680 train_time:109695ms step_avg:87.90ms +step:1249/1680 train_time:109784ms step_avg:87.90ms +step:1250/1680 train_time:109873ms step_avg:87.90ms +step:1250/1680 val_loss:3.3811 train_time:109963ms step_avg:87.97ms +step:1251/1680 train_time:109982ms step_avg:87.91ms +step:1252/1680 train_time:110057ms step_avg:87.91ms +step:1253/1680 train_time:110149ms step_avg:87.91ms +step:1254/1680 train_time:110237ms step_avg:87.91ms +step:1255/1680 train_time:110325ms step_avg:87.91ms +step:1256/1680 train_time:110413ms step_avg:87.91ms +step:1257/1680 train_time:110500ms step_avg:87.91ms +step:1258/1680 train_time:110588ms step_avg:87.91ms +step:1259/1680 train_time:110677ms step_avg:87.91ms +step:1260/1680 train_time:110766ms step_avg:87.91ms +step:1261/1680 train_time:110854ms step_avg:87.91ms +step:1262/1680 train_time:110945ms step_avg:87.91ms +step:1263/1680 train_time:111038ms step_avg:87.92ms +step:1264/1680 train_time:111128ms step_avg:87.92ms +step:1265/1680 train_time:111218ms step_avg:87.92ms +step:1266/1680 train_time:111306ms step_avg:87.92ms +step:1267/1680 train_time:111394ms step_avg:87.92ms +step:1268/1680 train_time:111482ms step_avg:87.92ms +step:1269/1680 train_time:111570ms step_avg:87.92ms +step:1270/1680 train_time:111659ms step_avg:87.92ms +step:1271/1680 train_time:111747ms step_avg:87.92ms +step:1272/1680 train_time:111836ms step_avg:87.92ms +step:1273/1680 train_time:111927ms step_avg:87.92ms +step:1274/1680 train_time:112019ms step_avg:87.93ms +step:1275/1680 train_time:112108ms step_avg:87.93ms +step:1276/1680 train_time:112198ms step_avg:87.93ms +step:1277/1680 train_time:112288ms step_avg:87.93ms +step:1278/1680 train_time:112377ms step_avg:87.93ms +step:1279/1680 train_time:112466ms step_avg:87.93ms +step:1280/1680 train_time:112554ms step_avg:87.93ms +step:1281/1680 train_time:112642ms 
step_avg:87.93ms +step:1282/1680 train_time:112730ms step_avg:87.93ms +step:1283/1680 train_time:112820ms step_avg:87.93ms +step:1284/1680 train_time:112909ms step_avg:87.94ms +step:1285/1680 train_time:113000ms step_avg:87.94ms +step:1286/1680 train_time:113090ms step_avg:87.94ms +step:1287/1680 train_time:113180ms step_avg:87.94ms +step:1288/1680 train_time:113269ms step_avg:87.94ms +step:1289/1680 train_time:113358ms step_avg:87.94ms +step:1290/1680 train_time:113447ms step_avg:87.94ms +step:1291/1680 train_time:113537ms step_avg:87.94ms +step:1292/1680 train_time:113625ms step_avg:87.95ms +step:1293/1680 train_time:113713ms step_avg:87.95ms +step:1294/1680 train_time:113802ms step_avg:87.95ms +step:1295/1680 train_time:113891ms step_avg:87.95ms +step:1296/1680 train_time:113980ms step_avg:87.95ms +step:1297/1680 train_time:114069ms step_avg:87.95ms +step:1298/1680 train_time:114159ms step_avg:87.95ms +step:1299/1680 train_time:114248ms step_avg:87.95ms +step:1300/1680 train_time:114337ms step_avg:87.95ms +step:1301/1680 train_time:114427ms step_avg:87.95ms +step:1302/1680 train_time:114515ms step_avg:87.95ms +step:1303/1680 train_time:114603ms step_avg:87.95ms +step:1304/1680 train_time:114692ms step_avg:87.95ms +step:1305/1680 train_time:114781ms step_avg:87.95ms +step:1306/1680 train_time:114869ms step_avg:87.96ms +step:1307/1680 train_time:114958ms step_avg:87.96ms +step:1308/1680 train_time:115047ms step_avg:87.96ms +step:1309/1680 train_time:115137ms step_avg:87.96ms +step:1310/1680 train_time:115227ms step_avg:87.96ms +step:1311/1680 train_time:115316ms step_avg:87.96ms +step:1312/1680 train_time:115405ms step_avg:87.96ms +step:1313/1680 train_time:115494ms step_avg:87.96ms +step:1314/1680 train_time:115583ms step_avg:87.96ms +step:1315/1680 train_time:115672ms step_avg:87.96ms +step:1316/1680 train_time:115761ms step_avg:87.96ms +step:1317/1680 train_time:115850ms step_avg:87.97ms +step:1318/1680 train_time:115939ms step_avg:87.97ms +step:1319/1680 train_time:116028ms step_avg:87.97ms +step:1320/1680 train_time:116119ms step_avg:87.97ms +step:1321/1680 train_time:116209ms step_avg:87.97ms +step:1322/1680 train_time:116299ms step_avg:87.97ms +step:1323/1680 train_time:116388ms step_avg:87.97ms +step:1324/1680 train_time:116477ms step_avg:87.97ms +step:1325/1680 train_time:116566ms step_avg:87.97ms +step:1326/1680 train_time:116655ms step_avg:87.97ms +step:1327/1680 train_time:116744ms step_avg:87.98ms +step:1328/1680 train_time:116832ms step_avg:87.98ms +step:1329/1680 train_time:116922ms step_avg:87.98ms +step:1330/1680 train_time:117011ms step_avg:87.98ms +step:1331/1680 train_time:117100ms step_avg:87.98ms +step:1332/1680 train_time:117190ms step_avg:87.98ms +step:1333/1680 train_time:117279ms step_avg:87.98ms +step:1334/1680 train_time:117368ms step_avg:87.98ms +step:1335/1680 train_time:117457ms step_avg:87.98ms +step:1336/1680 train_time:117546ms step_avg:87.98ms +step:1337/1680 train_time:117636ms step_avg:87.99ms +step:1338/1680 train_time:117725ms step_avg:87.99ms +step:1339/1680 train_time:117813ms step_avg:87.99ms +step:1340/1680 train_time:117902ms step_avg:87.99ms +step:1341/1680 train_time:117991ms step_avg:87.99ms +step:1342/1680 train_time:118079ms step_avg:87.99ms +step:1343/1680 train_time:118168ms step_avg:87.99ms +step:1344/1680 train_time:118258ms step_avg:87.99ms +step:1345/1680 train_time:118347ms step_avg:87.99ms +step:1346/1680 train_time:118436ms step_avg:87.99ms +step:1347/1680 train_time:118526ms step_avg:87.99ms +step:1348/1680 train_time:118615ms 
step_avg:87.99ms +step:1349/1680 train_time:118704ms step_avg:87.99ms +step:1350/1680 train_time:118793ms step_avg:87.99ms +step:1351/1680 train_time:118882ms step_avg:88.00ms +step:1352/1680 train_time:118971ms step_avg:88.00ms +step:1353/1680 train_time:119060ms step_avg:88.00ms +step:1354/1680 train_time:119150ms step_avg:88.00ms +step:1355/1680 train_time:119239ms step_avg:88.00ms +step:1356/1680 train_time:119329ms step_avg:88.00ms +step:1357/1680 train_time:119419ms step_avg:88.00ms +step:1358/1680 train_time:119507ms step_avg:88.00ms +step:1359/1680 train_time:119595ms step_avg:88.00ms +step:1360/1680 train_time:119684ms step_avg:88.00ms +step:1361/1680 train_time:119772ms step_avg:88.00ms +step:1362/1680 train_time:119862ms step_avg:88.00ms +step:1363/1680 train_time:119951ms step_avg:88.00ms +step:1364/1680 train_time:120039ms step_avg:88.01ms +step:1365/1680 train_time:120128ms step_avg:88.01ms +step:1366/1680 train_time:120219ms step_avg:88.01ms +step:1367/1680 train_time:120308ms step_avg:88.01ms +step:1368/1680 train_time:120397ms step_avg:88.01ms +step:1369/1680 train_time:120487ms step_avg:88.01ms +step:1370/1680 train_time:120575ms step_avg:88.01ms +step:1371/1680 train_time:120664ms step_avg:88.01ms +step:1372/1680 train_time:120753ms step_avg:88.01ms +step:1373/1680 train_time:120842ms step_avg:88.01ms +step:1374/1680 train_time:120930ms step_avg:88.01ms +step:1375/1680 train_time:121019ms step_avg:88.01ms +step:1375/1680 val_loss:3.3468 train_time:121109ms step_avg:88.08ms +step:1376/1680 train_time:121127ms step_avg:88.03ms +step:1377/1680 train_time:121203ms step_avg:88.02ms +step:1378/1680 train_time:121295ms step_avg:88.02ms +step:1379/1680 train_time:121386ms step_avg:88.02ms +step:1380/1680 train_time:121474ms step_avg:88.02ms +step:1381/1680 train_time:121564ms step_avg:88.03ms +step:1382/1680 train_time:121652ms step_avg:88.03ms +step:1383/1680 train_time:121739ms step_avg:88.02ms +step:1384/1680 train_time:121827ms step_avg:88.03ms +step:1385/1680 train_time:121915ms step_avg:88.03ms +step:1386/1680 train_time:122003ms step_avg:88.03ms +step:1387/1680 train_time:122093ms step_avg:88.03ms +step:1388/1680 train_time:122185ms step_avg:88.03ms +step:1389/1680 train_time:122276ms step_avg:88.03ms +step:1390/1680 train_time:122367ms step_avg:88.03ms +step:1391/1680 train_time:122456ms step_avg:88.03ms +step:1392/1680 train_time:122545ms step_avg:88.04ms +step:1393/1680 train_time:122633ms step_avg:88.04ms +step:1394/1680 train_time:122721ms step_avg:88.04ms +step:1395/1680 train_time:122810ms step_avg:88.04ms +step:1396/1680 train_time:122898ms step_avg:88.04ms +step:1397/1680 train_time:122987ms step_avg:88.04ms +step:1398/1680 train_time:123076ms step_avg:88.04ms +step:1399/1680 train_time:123167ms step_avg:88.04ms +step:1400/1680 train_time:123257ms step_avg:88.04ms +step:1401/1680 train_time:123347ms step_avg:88.04ms +step:1402/1680 train_time:123437ms step_avg:88.04ms +step:1403/1680 train_time:123527ms step_avg:88.05ms +step:1404/1680 train_time:123617ms step_avg:88.05ms +step:1405/1680 train_time:123706ms step_avg:88.05ms +step:1406/1680 train_time:123794ms step_avg:88.05ms +step:1407/1680 train_time:123883ms step_avg:88.05ms +step:1408/1680 train_time:123972ms step_avg:88.05ms +step:1409/1680 train_time:124061ms step_avg:88.05ms +step:1410/1680 train_time:124150ms step_avg:88.05ms +step:1411/1680 train_time:124239ms step_avg:88.05ms +step:1412/1680 train_time:124328ms step_avg:88.05ms +step:1413/1680 train_time:124417ms step_avg:88.05ms +step:1414/1680 
train_time:124508ms step_avg:88.05ms +step:1415/1680 train_time:124597ms step_avg:88.05ms +step:1416/1680 train_time:124686ms step_avg:88.06ms +step:1417/1680 train_time:124775ms step_avg:88.06ms +step:1418/1680 train_time:124864ms step_avg:88.06ms +step:1419/1680 train_time:124952ms step_avg:88.06ms +step:1420/1680 train_time:125041ms step_avg:88.06ms +step:1421/1680 train_time:125130ms step_avg:88.06ms +step:1422/1680 train_time:125220ms step_avg:88.06ms +step:1423/1680 train_time:125309ms step_avg:88.06ms +step:1424/1680 train_time:125398ms step_avg:88.06ms +step:1425/1680 train_time:125488ms step_avg:88.06ms +step:1426/1680 train_time:125577ms step_avg:88.06ms +step:1427/1680 train_time:125666ms step_avg:88.06ms +step:1428/1680 train_time:125755ms step_avg:88.06ms +step:1429/1680 train_time:125844ms step_avg:88.06ms +step:1430/1680 train_time:125933ms step_avg:88.06ms +step:1431/1680 train_time:126022ms step_avg:88.07ms +step:1432/1680 train_time:126110ms step_avg:88.07ms +step:1433/1680 train_time:126199ms step_avg:88.07ms +step:1434/1680 train_time:126288ms step_avg:88.07ms +step:1435/1680 train_time:126378ms step_avg:88.07ms +step:1436/1680 train_time:126468ms step_avg:88.07ms +step:1437/1680 train_time:126557ms step_avg:88.07ms +step:1438/1680 train_time:126646ms step_avg:88.07ms +step:1439/1680 train_time:126734ms step_avg:88.07ms +step:1440/1680 train_time:126823ms step_avg:88.07ms +step:1441/1680 train_time:126912ms step_avg:88.07ms +step:1442/1680 train_time:127001ms step_avg:88.07ms +step:1443/1680 train_time:127090ms step_avg:88.07ms +step:1444/1680 train_time:127180ms step_avg:88.07ms +step:1445/1680 train_time:127268ms step_avg:88.07ms +step:1446/1680 train_time:127357ms step_avg:88.08ms +step:1447/1680 train_time:127447ms step_avg:88.08ms +step:1448/1680 train_time:127536ms step_avg:88.08ms +step:1449/1680 train_time:127624ms step_avg:88.08ms +step:1450/1680 train_time:127714ms step_avg:88.08ms +step:1451/1680 train_time:127803ms step_avg:88.08ms +step:1452/1680 train_time:127891ms step_avg:88.08ms +step:1453/1680 train_time:127979ms step_avg:88.08ms +step:1454/1680 train_time:128069ms step_avg:88.08ms +step:1455/1680 train_time:128158ms step_avg:88.08ms +step:1456/1680 train_time:128247ms step_avg:88.08ms +step:1457/1680 train_time:128336ms step_avg:88.08ms +step:1458/1680 train_time:128426ms step_avg:88.08ms +step:1459/1680 train_time:128515ms step_avg:88.08ms +step:1460/1680 train_time:128605ms step_avg:88.09ms +step:1461/1680 train_time:128694ms step_avg:88.09ms +step:1462/1680 train_time:128783ms step_avg:88.09ms +step:1463/1680 train_time:128872ms step_avg:88.09ms +step:1464/1680 train_time:128961ms step_avg:88.09ms +step:1465/1680 train_time:129050ms step_avg:88.09ms +step:1466/1680 train_time:129138ms step_avg:88.09ms +step:1467/1680 train_time:129227ms step_avg:88.09ms +step:1468/1680 train_time:129316ms step_avg:88.09ms +step:1469/1680 train_time:129406ms step_avg:88.09ms +step:1470/1680 train_time:129495ms step_avg:88.09ms +step:1471/1680 train_time:129584ms step_avg:88.09ms +step:1472/1680 train_time:129674ms step_avg:88.09ms +step:1473/1680 train_time:129763ms step_avg:88.09ms +step:1474/1680 train_time:129852ms step_avg:88.10ms +step:1475/1680 train_time:129942ms step_avg:88.10ms +step:1476/1680 train_time:130030ms step_avg:88.10ms +step:1477/1680 train_time:130119ms step_avg:88.10ms +step:1478/1680 train_time:130208ms step_avg:88.10ms +step:1479/1680 train_time:130297ms step_avg:88.10ms +step:1480/1680 train_time:130387ms step_avg:88.10ms +step:1481/1680 
train_time:130476ms step_avg:88.10ms +step:1482/1680 train_time:130566ms step_avg:88.10ms +step:1483/1680 train_time:130655ms step_avg:88.10ms +step:1484/1680 train_time:130744ms step_avg:88.10ms +step:1485/1680 train_time:130833ms step_avg:88.10ms +step:1486/1680 train_time:130922ms step_avg:88.10ms +step:1487/1680 train_time:131011ms step_avg:88.10ms +step:1488/1680 train_time:131100ms step_avg:88.11ms +step:1489/1680 train_time:131190ms step_avg:88.11ms +step:1490/1680 train_time:131278ms step_avg:88.11ms +step:1491/1680 train_time:131368ms step_avg:88.11ms +step:1492/1680 train_time:131457ms step_avg:88.11ms +step:1493/1680 train_time:131545ms step_avg:88.11ms +step:1494/1680 train_time:131635ms step_avg:88.11ms +step:1495/1680 train_time:131724ms step_avg:88.11ms +step:1496/1680 train_time:131812ms step_avg:88.11ms +step:1497/1680 train_time:131902ms step_avg:88.11ms +step:1498/1680 train_time:131992ms step_avg:88.11ms +step:1499/1680 train_time:132081ms step_avg:88.11ms +step:1500/1680 train_time:132170ms step_avg:88.11ms +step:1500/1680 val_loss:3.3167 train_time:132260ms step_avg:88.17ms +step:1501/1680 train_time:132279ms step_avg:88.13ms +step:1502/1680 train_time:132351ms step_avg:88.12ms +step:1503/1680 train_time:132443ms step_avg:88.12ms +step:1504/1680 train_time:132533ms step_avg:88.12ms +step:1505/1680 train_time:132621ms step_avg:88.12ms +step:1506/1680 train_time:132709ms step_avg:88.12ms +step:1507/1680 train_time:132796ms step_avg:88.12ms +step:1508/1680 train_time:132885ms step_avg:88.12ms +step:1509/1680 train_time:132972ms step_avg:88.12ms +step:1510/1680 train_time:133061ms step_avg:88.12ms +step:1511/1680 train_time:133150ms step_avg:88.12ms +step:1512/1680 train_time:133240ms step_avg:88.12ms +step:1513/1680 train_time:133330ms step_avg:88.12ms +step:1514/1680 train_time:133424ms step_avg:88.13ms +step:1515/1680 train_time:133514ms step_avg:88.13ms +step:1516/1680 train_time:133603ms step_avg:88.13ms +step:1517/1680 train_time:133691ms step_avg:88.13ms +step:1518/1680 train_time:133780ms step_avg:88.13ms +step:1519/1680 train_time:133868ms step_avg:88.13ms +step:1520/1680 train_time:133956ms step_avg:88.13ms +step:1521/1680 train_time:134044ms step_avg:88.13ms +step:1522/1680 train_time:134133ms step_avg:88.13ms +step:1523/1680 train_time:134222ms step_avg:88.13ms +step:1524/1680 train_time:134312ms step_avg:88.13ms +step:1525/1680 train_time:134402ms step_avg:88.13ms +step:1526/1680 train_time:134493ms step_avg:88.13ms +step:1527/1680 train_time:134583ms step_avg:88.14ms +step:1528/1680 train_time:134672ms step_avg:88.14ms +step:1529/1680 train_time:134761ms step_avg:88.14ms +step:1530/1680 train_time:134850ms step_avg:88.14ms +step:1531/1680 train_time:134938ms step_avg:88.14ms +step:1532/1680 train_time:135027ms step_avg:88.14ms +step:1533/1680 train_time:135115ms step_avg:88.14ms +step:1534/1680 train_time:135204ms step_avg:88.14ms +step:1535/1680 train_time:135293ms step_avg:88.14ms +step:1536/1680 train_time:135382ms step_avg:88.14ms +step:1537/1680 train_time:135472ms step_avg:88.14ms +step:1538/1680 train_time:135561ms step_avg:88.14ms +step:1539/1680 train_time:135650ms step_avg:88.14ms +step:1540/1680 train_time:135739ms step_avg:88.14ms +step:1541/1680 train_time:135828ms step_avg:88.14ms +step:1542/1680 train_time:135917ms step_avg:88.14ms +step:1543/1680 train_time:136006ms step_avg:88.14ms +step:1544/1680 train_time:136095ms step_avg:88.14ms +step:1545/1680 train_time:136184ms step_avg:88.15ms +step:1546/1680 train_time:136272ms step_avg:88.15ms 
+step:1547/1680 train_time:136362ms step_avg:88.15ms +step:1548/1680 train_time:136451ms step_avg:88.15ms +step:1549/1680 train_time:136539ms step_avg:88.15ms +step:1550/1680 train_time:136629ms step_avg:88.15ms +step:1551/1680 train_time:136718ms step_avg:88.15ms +step:1552/1680 train_time:136808ms step_avg:88.15ms +step:1553/1680 train_time:136896ms step_avg:88.15ms +step:1554/1680 train_time:136985ms step_avg:88.15ms +step:1555/1680 train_time:137073ms step_avg:88.15ms +step:1556/1680 train_time:137163ms step_avg:88.15ms +step:1557/1680 train_time:137252ms step_avg:88.15ms +step:1558/1680 train_time:137341ms step_avg:88.15ms +step:1559/1680 train_time:137430ms step_avg:88.15ms +step:1560/1680 train_time:137519ms step_avg:88.15ms +step:1561/1680 train_time:137609ms step_avg:88.15ms +step:1562/1680 train_time:137698ms step_avg:88.15ms +step:1563/1680 train_time:137788ms step_avg:88.16ms +step:1564/1680 train_time:137876ms step_avg:88.16ms +step:1565/1680 train_time:137965ms step_avg:88.16ms +step:1566/1680 train_time:138054ms step_avg:88.16ms +step:1567/1680 train_time:138143ms step_avg:88.16ms +step:1568/1680 train_time:138232ms step_avg:88.16ms +step:1569/1680 train_time:138321ms step_avg:88.16ms +step:1570/1680 train_time:138411ms step_avg:88.16ms +step:1571/1680 train_time:138501ms step_avg:88.16ms +step:1572/1680 train_time:138589ms step_avg:88.16ms +step:1573/1680 train_time:138678ms step_avg:88.16ms +step:1574/1680 train_time:138769ms step_avg:88.16ms +step:1575/1680 train_time:138857ms step_avg:88.16ms +step:1576/1680 train_time:138946ms step_avg:88.16ms +step:1577/1680 train_time:139035ms step_avg:88.16ms +step:1578/1680 train_time:139124ms step_avg:88.16ms +step:1579/1680 train_time:139213ms step_avg:88.16ms +step:1580/1680 train_time:139301ms step_avg:88.17ms +step:1581/1680 train_time:139391ms step_avg:88.17ms +step:1582/1680 train_time:139480ms step_avg:88.17ms +step:1583/1680 train_time:139569ms step_avg:88.17ms +step:1584/1680 train_time:139658ms step_avg:88.17ms +step:1585/1680 train_time:139748ms step_avg:88.17ms +step:1586/1680 train_time:139837ms step_avg:88.17ms +step:1587/1680 train_time:139926ms step_avg:88.17ms +step:1588/1680 train_time:140015ms step_avg:88.17ms +step:1589/1680 train_time:140105ms step_avg:88.17ms +step:1590/1680 train_time:140193ms step_avg:88.17ms +step:1591/1680 train_time:140282ms step_avg:88.17ms +step:1592/1680 train_time:140370ms step_avg:88.17ms +step:1593/1680 train_time:140459ms step_avg:88.17ms +step:1594/1680 train_time:140548ms step_avg:88.17ms +step:1595/1680 train_time:140637ms step_avg:88.17ms +step:1596/1680 train_time:140726ms step_avg:88.17ms +step:1597/1680 train_time:140816ms step_avg:88.18ms +step:1598/1680 train_time:140905ms step_avg:88.18ms +step:1599/1680 train_time:140994ms step_avg:88.18ms +step:1600/1680 train_time:141083ms step_avg:88.18ms +step:1601/1680 train_time:141172ms step_avg:88.18ms +step:1602/1680 train_time:141262ms step_avg:88.18ms +step:1603/1680 train_time:141351ms step_avg:88.18ms +step:1604/1680 train_time:141439ms step_avg:88.18ms +step:1605/1680 train_time:141528ms step_avg:88.18ms +step:1606/1680 train_time:141617ms step_avg:88.18ms +step:1607/1680 train_time:141706ms step_avg:88.18ms +step:1608/1680 train_time:141795ms step_avg:88.18ms +step:1609/1680 train_time:141884ms step_avg:88.18ms +step:1610/1680 train_time:141974ms step_avg:88.18ms +step:1611/1680 train_time:142062ms step_avg:88.18ms +step:1612/1680 train_time:142151ms step_avg:88.18ms +step:1613/1680 train_time:142240ms step_avg:88.18ms 
+step:1614/1680 train_time:142329ms step_avg:88.18ms +step:1615/1680 train_time:142418ms step_avg:88.18ms +step:1616/1680 train_time:142509ms step_avg:88.19ms +step:1617/1680 train_time:142598ms step_avg:88.19ms +step:1618/1680 train_time:142687ms step_avg:88.19ms +step:1619/1680 train_time:142776ms step_avg:88.19ms +step:1620/1680 train_time:142865ms step_avg:88.19ms +step:1621/1680 train_time:142955ms step_avg:88.19ms +step:1622/1680 train_time:143044ms step_avg:88.19ms +step:1623/1680 train_time:143133ms step_avg:88.19ms +step:1624/1680 train_time:143222ms step_avg:88.19ms +step:1625/1680 train_time:143311ms step_avg:88.19ms +step:1625/1680 val_loss:3.2929 train_time:143401ms step_avg:88.25ms +step:1626/1680 train_time:143420ms step_avg:88.20ms +step:1627/1680 train_time:143493ms step_avg:88.19ms +step:1628/1680 train_time:143584ms step_avg:88.20ms +step:1629/1680 train_time:143673ms step_avg:88.20ms +step:1630/1680 train_time:143762ms step_avg:88.20ms +step:1631/1680 train_time:143850ms step_avg:88.20ms +step:1632/1680 train_time:143938ms step_avg:88.20ms +step:1633/1680 train_time:144026ms step_avg:88.20ms +step:1634/1680 train_time:144114ms step_avg:88.20ms +step:1635/1680 train_time:144203ms step_avg:88.20ms +step:1636/1680 train_time:144293ms step_avg:88.20ms +step:1637/1680 train_time:144383ms step_avg:88.20ms +step:1638/1680 train_time:144474ms step_avg:88.20ms +step:1639/1680 train_time:144564ms step_avg:88.20ms +step:1640/1680 train_time:144654ms step_avg:88.20ms +step:1641/1680 train_time:144743ms step_avg:88.20ms +step:1642/1680 train_time:144832ms step_avg:88.20ms +step:1643/1680 train_time:144921ms step_avg:88.21ms +step:1644/1680 train_time:145009ms step_avg:88.20ms +step:1645/1680 train_time:145097ms step_avg:88.20ms +step:1646/1680 train_time:145186ms step_avg:88.21ms +step:1647/1680 train_time:145274ms step_avg:88.21ms +step:1648/1680 train_time:145366ms step_avg:88.21ms +step:1649/1680 train_time:145455ms step_avg:88.21ms +step:1650/1680 train_time:145544ms step_avg:88.21ms +step:1651/1680 train_time:145634ms step_avg:88.21ms +step:1652/1680 train_time:145723ms step_avg:88.21ms +step:1653/1680 train_time:145812ms step_avg:88.21ms +step:1654/1680 train_time:145901ms step_avg:88.21ms +step:1655/1680 train_time:145991ms step_avg:88.21ms +step:1656/1680 train_time:146079ms step_avg:88.21ms +step:1657/1680 train_time:146168ms step_avg:88.21ms +step:1658/1680 train_time:146257ms step_avg:88.21ms +step:1659/1680 train_time:146346ms step_avg:88.21ms +step:1660/1680 train_time:146436ms step_avg:88.21ms +step:1661/1680 train_time:146526ms step_avg:88.22ms +step:1662/1680 train_time:146615ms step_avg:88.22ms +step:1663/1680 train_time:146704ms step_avg:88.22ms +step:1664/1680 train_time:146793ms step_avg:88.22ms +step:1665/1680 train_time:146883ms step_avg:88.22ms +step:1666/1680 train_time:146973ms step_avg:88.22ms +step:1667/1680 train_time:147061ms step_avg:88.22ms +step:1668/1680 train_time:147150ms step_avg:88.22ms +step:1669/1680 train_time:147239ms step_avg:88.22ms +step:1670/1680 train_time:147328ms step_avg:88.22ms +step:1671/1680 train_time:147418ms step_avg:88.22ms +step:1672/1680 train_time:147507ms step_avg:88.22ms +step:1673/1680 train_time:147596ms step_avg:88.22ms +step:1674/1680 train_time:147686ms step_avg:88.22ms +step:1675/1680 train_time:147775ms step_avg:88.22ms +step:1676/1680 train_time:147864ms step_avg:88.22ms +step:1677/1680 train_time:147953ms step_avg:88.22ms +step:1678/1680 train_time:148043ms step_avg:88.23ms +step:1679/1680 train_time:148132ms 
step_avg:88.23ms +step:1680/1680 train_time:148220ms step_avg:88.23ms +step:1680/1680 val_loss:3.2823 train_time:148311ms step_avg:88.28ms +peak memory allocated: 30760 MiB reserved: 45914 MiB diff --git a/records/092725_BF16CE/f7c90ea9-95b0-4652-b933-a73edab09583.txt b/records/092725_BF16CE/f7c90ea9-95b0-4652-b933-a73edab09583.txt new file mode 100644 index 000000000..944011537 --- /dev/null +++ b/records/092725_BF16CE/f7c90ea9-95b0-4652-b933-a73edab09583.txt @@ -0,0 +1,3206 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, 
grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C 
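+    # (Blocks lying entirely on the skipped side of the diagonal returned early above;
+    # each computed block is stored twice, at (m, n) and transposed at (n, m), which
+    # fills in the full symmetric product C = A @ A.T.)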
+ offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) 
& (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+ Though empirically, small 1D params perform efficiently here: + NS approximately performs a magnitude normalization of the grad + This hyper-optimized class has faster execution time than the current impl of Adam for small params + + Custom distributed sizing: + The model stores all attn and mlp weights in the same shape, and then updates the view as + needed on the forward pass. This enables attn and mlp weights to be contained within the same + dist.reduce_scatter_tensor() call. The model architecture has been customized to enable + (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn. + The scheduling is: + 1. reduce scatter smear_gate (1 param 7 padding params) + 2. reduce scatter attn_gate (10 params 6 padding params) + 3. reduce scatter attn/mlp round 1 (10 attn params 6 mlp params) + 4. reduce scatter attn/mlp round 2 (16 mlp params) + 5. wait on step 1, then compute NS of 1 and schedule all gather + 6. wait on step 2, then compute NS of 2 and schedule all gather + 7. wait on step 3, then compute NS of 3 and schedule all gather + GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] + GPUs that receive params of type attn reshape before NS + 8. wait on 4, then compute NS of 4 and schedule all gather + 9. wait for each all gather to complete and update params + Empirically, leading with small params provides an additional 0.2s improvement. + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, custom_sizing=True): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + # custom sizing requires 8 GPUs + if custom_sizing and dist.get_world_size()==8: + param_groups = self.generate_custom_param_groups(params) + else: + param_groups = self.generate_standard_param_groups(params) + super().__init__(param_groups, defaults) + + def generate_standard_param_groups(self, params): + """ + Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules. + Creates one param group per size, while giving attn its own param group for resize op. + """ + params = list(params) + param_groups = [] + attn_subset = [p for p in params if p.module == 'attn'] + non_attn_subset = [p for p in params if p.module != 'attn'] + param_groups.append(dict(params=attn_subset)) + + sizes = {p.shape for p in non_attn_subset} + for size in sizes: + group_params = [p for p in non_attn_subset if p.shape == size] + param_groups.append(dict(params=group_params)) + return param_groups + + def generate_custom_param_groups(self, params): + """ + Implementation requires that a single GPU does not receive both attn + and mlp params when a param group is split across GPUs. + """ + module_ranks = { + 'smear_gate': 1, # 1 param + 'attn_gate': 2, # 10 params + 'attn': 3, # 10 params + 'mlp': 4, # 22 params + } + params = list(params) + params.sort(key=lambda x: module_ranks.get(x.module)) + idx = 0 + group_sizes = [1,10,16,16] + assert len(params)==sum(group_sizes) + param_groups = [] + for size in group_sizes: + group_params = params[idx:idx+size] + param_groups.append(dict(params=group_params)) + idx += size + return param_groups + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
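+        # Worked example of the sharding arithmetic below, using the docstring's group
+        # sizes (illustrative values, not taken from this run's logs): with
+        # world_size == 8, the 10-param attn_gate group pads to
+        # padded_num_params == (10 + 8 - 1) // 8 * 8 == 16, so chunk_size == 2: each
+        # rank reduce-scatters into a (2, *param.shape) buffer and 6 of the 16 stacked
+        # gradient slots are zero padding.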
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *p_example.shape), + dtype=p_example.dtype, + device=p_example.device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(p_example.grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. 
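+                # (updated_param_chunk doubles as the input to dist.all_gather_into_tensor
+                # below, so weight decay and the Newton-Schulz update are applied to this
+                # buffer rather than to p directly.)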
+ updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + original_shape = batched_update_grads.shape + # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] to apply NS independently to Q,K,V,O + module_idx = start_idx if start_idx < len(params) else 0 + if params[module_idx].module == 'attn': + batched_update_grads = batched_update_grads.view(-1, original_shape[-2], original_shape[-1] // 4) + v_chunk = newton_schulz_triton(batched_update_grads).view(original_shape) + + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = 
state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = 
x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (self.dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # make matrices the same shape as MLP to enable batched call in optimizer + self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4)) + # label module to enable custom optimizer sizing + self.qkvo_w.module='attn' + with torch.no_grad(): + self.qkvo_w.view(4,self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w.view(4,self.hdim, self.dim)[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + # label module to enable custom optimizer sizing + self.attn_gate.weight.module = 'attn_gate' + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w.view(4,self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
+        # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng
+        y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len,
+                                   causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0))
+        y = y.view(B, T, self.num_heads, self.head_dim)
+        y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1)
+        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
+        y = F.linear(y, self.qkvo_w.view(4, self.hdim, self.dim)[3].type_as(y))
+        return y
+
+class MLP(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        hdim = 4 * dim
+        # make matrices the same shape to enable batched call in optimizer
+        self.c_fc = nn.Parameter(torch.empty(dim, hdim))
+        self.c_proj = nn.Parameter(torch.empty(dim, hdim))
+        # label modules to enable custom optimizer sizing
+        self.c_fc.module = 'mlp'
+        self.c_proj.module = 'mlp'
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        with torch.no_grad():
+            self.c_fc.uniform_(-bound, bound)
+            self.c_proj.zero_() # zero init suggested by @Grad62304977
+
+    def forward(self, x: Tensor):
+        x = F.linear(x, self.c_fc.T.type_as(x))
+        x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
+        x = F.linear(x, self.c_proj.type_as(x))
+        return x
+
+class Block(nn.Module):
+    def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int):
+        super().__init__()
+        # skip attention of blocks.0 and blocks.7 (the 8th layer) by @YouJiacheng
+        self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None
+        # skip MLP blocks for first MLP layer by @EmelyanenkoK
+        self.mlp = MLP(dim) if layer_idx != 0 else None
+
+    def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs):
+        x = lambdas[0] * x + lambdas[1] * x0
+        if self.attn is not None:
+            x = x + self.attn(norm(x), attn_args)
+        if self.mlp is not None:
+            x = x + self.mlp(norm(x))
+        return x
+
+# -----------------------------------------------------------------------------
+# The main model
+
+def next_multiple_of_n(v: float | int, *, n: int):
+    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
+
+class GPT(nn.Module):
+    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int):
+        super().__init__()
+        vocab_size = next_multiple_of_n(vocab_size, n=128)
+        self.embed = nn.Embedding(vocab_size, model_dim)
+        self.smear_gate = CastedLinear(12, 1)
+        self.smear_gate.weight.detach().zero_()
+        # label modules to enable custom optimizer sizing
+        self.smear_gate.weight.module = 'smear_gate'
+        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
+        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
+        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
+        self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)])
+        self.yarn = Yarn(head_dim, max_seq_len)
+        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
+        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
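+        # Worked example (comment only): next_multiple_of_n(50257, n=128) == 50304
+        # = 393 * 128, i.e. 47 padded vocab rows that no real GPT-2 token ever
+        # indexes; the rounding just keeps the matmul dimensions multiple-of-128
+        # friendly.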
+        use_fp8 = not os.environ.get("DISABLE_FP8", False)
+        self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448)
+        self.lm_head.weight.detach().zero_() # @Grad62304977
+        # Add learnable skip connection weights for decoder layers
+        assert num_layers % 2 == 0
+        pad = (-num_layers * 6) % dist.get_world_size()
+        self.scalars = nn.Parameter(
+            torch.cat(
+                [
+                    -1.5 * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18
+                    *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas
+                    *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas
+                    torch.zeros(num_layers), # extra zeros params for smear_lambda
+                    torch.ones(pad),
+                ]
+            )
+        )
+        # set learning rates
+        for param in self.embed.parameters():
+            param.lr_mul = 75.
+        for param in self.value_embeds.parameters():
+            param.lr_mul = 75.
+        self.lm_head.weight.lr_mul = 1.0
+        self.scalars.lr_mul = 5.0
+
+    def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int):
+        assert input_seq.ndim == 1
+
+        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
+        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        short_bm = ws_short * args.block_size
+        long_bm = ws_long * args.block_size
+        bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = self.embed(input_seq)
+
+        # smear token embed forward 1 position @classiclarryd
+        smear_lambda = self.scalars[5 * len(self.blocks)]
+        smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)]))
+        x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]])
+        x = x0 = norm(x[None])
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        # skip layer zero
+        for i in range(1, len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                cos=self.yarn.cos,
+                sin=self.yarn.sin,
+                attn_scale=self.yarn.attn_scale
+            )
+            if i >= n and i < 11:
+                gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1)
+                x = x + gate * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x)
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0)
+        logits_for_loss = logits.float() if not self.training else logits
+        loss = F.cross_entropy(
+            logits_for_loss.view(-1, logits_for_loss.size(-1)),
+            target_seq,
+            reduction="sum" if self.training else "mean",
+        )
+        return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan first 4 million tokens, then kickoff async thread to scan rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading next shard and indexing bos tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps # tokens yielded per micro-batch across all ranks
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document was too long to account for _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40 # number of iterations to continue training at final cooldown and window size
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd
+    ws_long_validate: int = 20 # extend long windows out even further
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, (nn.Embedding, nn.Linear)):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+gate_params = [p for n, p in model.named_parameters() if "gate" in n]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = min(0.9999, step / args.num_iterations)
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+def get_ws(step: int):
+    if step == args.num_iterations + args.iteration_extension:
+        return args.ws_validate//2, args.ws_validate
+    x = min(step / (1 + args.num_iterations), 0.9999)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx]//2, args.ws_schedule[ws_idx]
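+
+# Schedule shape at the defaults (comment only): get_lr holds 1.0 for the first
+# half of the 1640 iterations, then decays linearly toward 0.1; get_ws walks
+# ws_schedule=(3, 7, 11) in thirds of training, e.g. get_ws(0) == (1, 3),
+# get_ws(600) == (3, 7), get_ws(1200) == (5, 11), and the final extension step
+# (1680) returns (6, 13) via ws_validate.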
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws_long = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws_long = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws_long > ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    elif new_ws_long 0 and step % args.val_loss_every == 0):
+        if last_step:
+            ws_long = args.ws_long_validate
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = torch.zeros((), device=device, dtype=torch.float32)
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws_short, ws_long).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1) # momentum warmup for muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+
+====================================================================================================
+Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]
+Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Sat Sep 27 13:36:42 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
+|-----------------------------------------+------------------------+----------------------+
+| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
+| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
+|                                         |                        |               MIG M. |
+|=========================================+========================+======================|
+|   0  NVIDIA H100 80GB HBM3          On  |   00000000:19:00.0 Off |                    0 |
+| N/A   27C    P0            120W /  700W |    5856MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   1  NVIDIA H100 80GB HBM3          On  |   00000000:3B:00.0 Off |                    0 |
+| N/A   24C    P0            118W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   2  NVIDIA H100 80GB HBM3          On  |   00000000:4C:00.0 Off |                    0 |
+| N/A   22C    P0            116W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   3  NVIDIA H100 80GB HBM3          On  |   00000000:5D:00.0 Off |                    0 |
+| N/A   26C    P0            120W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   4  NVIDIA H100 80GB HBM3          On  |   00000000:9B:00.0 Off |                    0 |
+| N/A   26C    P0            119W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   5  NVIDIA H100 80GB HBM3          On  |   00000000:BB:00.0 Off |                    0 |
+| N/A   25C    P0            114W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   6  NVIDIA H100 80GB HBM3          On  |   00000000:CB:00.0 Off |                    0 |
+| N/A   28C    P0            120W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   7  NVIDIA H100 80GB HBM3          On  |   00000000:DB:00.0 Off |                    0 |
+| N/A   24C    P0            120W /  700W |    1518MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+
++-----------------------------------------------------------------------------------------+
+| Processes:                                                                              |
+|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
+|        ID   ID                                                               Usage      |
+|=========================================================================================|
+|    0   N/A  N/A          177392      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          177393      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          177394      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          177395      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          177396      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          177397      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          177398      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A          177399      C   /usr/bin/python                           0MiB |
+|    1   N/A  N/A          177393      C   /usr/bin/python                           0MiB |
+|    2   N/A  N/A          177394      C   /usr/bin/python                           0MiB |
+|    3   N/A  N/A          177395      C   /usr/bin/python                           0MiB |
+|    4   N/A  N/A          177396      C   /usr/bin/python                           0MiB |
+|    5   N/A  N/A          177397      C   /usr/bin/python                           0MiB |
+|    6   N/A  N/A          177398      C   /usr/bin/python                           0MiB |
+|    7   N/A  N/A          177399      C   /usr/bin/python                           0MiB |
++-----------------------------------------------------------------------------------------+
+
+====================================================================================================
+step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.03ms
+step:1/1680 train_time:139ms step_avg:139.07ms
+step:2/1680 train_time:161ms step_avg:80.25ms
+step:3/1680 train_time:224ms step_avg:74.73ms
+step:4/1680 train_time:309ms step_avg:77.27ms
+step:5/1680 train_time:395ms step_avg:78.96ms
+step:6/1680 train_time:480ms step_avg:80.07ms
+step:7/1680 train_time:566ms step_avg:80.92ms
+step:8/1680 train_time:653ms step_avg:81.59ms
+step:9/1680 train_time:739ms step_avg:82.09ms
+step:10/1680 train_time:825ms step_avg:82.45ms
+step:11/1680 train_time:911ms step_avg:82.78ms
+step:12/1680 train_time:999ms step_avg:83.22ms
+step:13/1680 train_time:1091ms step_avg:83.89ms
+step:14/1680 train_time:1180ms step_avg:84.29ms
+step:15/1680 train_time:1270ms step_avg:84.67ms
+step:16/1680 train_time:1357ms step_avg:84.81ms
+step:17/1680 train_time:1444ms step_avg:84.91ms
+step:18/1680 train_time:1530ms step_avg:85.00ms
+step:19/1680 train_time:1616ms step_avg:85.06ms
+step:20/1680 train_time:1702ms step_avg:85.11ms
+step:21/1680 train_time:1789ms step_avg:85.20ms
+step:22/1680 train_time:1876ms step_avg:85.26ms
+step:23/1680 train_time:1963ms step_avg:85.34ms
+step:24/1680 train_time:2051ms step_avg:85.45ms
+step:25/1680 train_time:2140ms step_avg:85.61ms
+step:26/1680 train_time:2230ms step_avg:85.76ms
+step:27/1680 train_time:2319ms step_avg:85.88ms
+step:28/1680 train_time:2406ms step_avg:85.93ms
+step:29/1680 train_time:2493ms step_avg:85.96ms
+step:30/1680 train_time:2579ms step_avg:85.98ms
+step:31/1680 train_time:2666ms step_avg:85.99ms
+step:32/1680 train_time:2753ms step_avg:86.04ms
+step:33/1680 train_time:2840ms step_avg:86.06ms
+step:34/1680 train_time:2927ms step_avg:86.09ms
+step:35/1680 train_time:3015ms step_avg:86.14ms
+step:36/1680 train_time:3105ms step_avg:86.24ms
+step:37/1680 train_time:3193ms step_avg:86.30ms
+step:38/1680 train_time:3282ms step_avg:86.36ms
+step:39/1680 train_time:3369ms step_avg:86.39ms
+step:40/1680 train_time:3456ms step_avg:86.41ms
+step:41/1680 train_time:3543ms step_avg:86.42ms
+step:42/1680 train_time:3630ms step_avg:86.42ms
+step:43/1680 train_time:3718ms step_avg:86.45ms
+step:44/1680 train_time:3804ms step_avg:86.46ms
+step:45/1680 train_time:3891ms step_avg:86.48ms
+step:46/1680 train_time:3979ms step_avg:86.49ms
+step:47/1680 train_time:4066ms step_avg:86.50ms
+step:48/1680 train_time:4154ms step_avg:86.54ms
+step:49/1680 train_time:4241ms step_avg:86.55ms
+step:50/1680 train_time:4329ms step_avg:86.58ms
+step:51/1680 train_time:4416ms step_avg:86.59ms
+step:52/1680 train_time:4504ms step_avg:86.61ms
+step:53/1680 train_time:4591ms step_avg:86.62ms
+step:54/1680 train_time:4678ms step_avg:86.63ms
+step:55/1680 train_time:4765ms step_avg:86.65ms
+step:56/1680 train_time:4853ms step_avg:86.66ms
+step:57/1680 train_time:4940ms step_avg:86.67ms
+step:58/1680 train_time:5026ms step_avg:86.66ms
+step:59/1680 train_time:5114ms step_avg:86.68ms
+step:60/1680 train_time:5201ms step_avg:86.69ms
+step:61/1680 train_time:5290ms step_avg:86.72ms
+step:62/1680 train_time:5377ms step_avg:86.72ms
+step:63/1680 train_time:5464ms step_avg:86.74ms
+step:64/1680 train_time:5551ms step_avg:86.73ms
+step:65/1680 train_time:5638ms step_avg:86.74ms
+step:66/1680 train_time:5725ms step_avg:86.75ms
+step:67/1680 train_time:5812ms step_avg:86.75ms
+step:68/1680 train_time:5899ms step_avg:86.75ms
+step:69/1680 train_time:5987ms step_avg:86.76ms
+step:70/1680 train_time:6074ms step_avg:86.77ms
+step:71/1680 train_time:6161ms step_avg:86.78ms
+step:72/1680 train_time:6248ms step_avg:86.78ms
+step:73/1680 train_time:6335ms step_avg:86.78ms
+step:74/1680 train_time:6422ms step_avg:86.79ms
+step:75/1680 train_time:6510ms step_avg:86.80ms
+step:76/1680 train_time:6597ms step_avg:86.80ms
+step:77/1680 train_time:6684ms step_avg:86.81ms
+step:78/1680 train_time:6771ms step_avg:86.81ms
+step:79/1680 train_time:6859ms step_avg:86.82ms
+step:80/1680 train_time:6946ms step_avg:86.82ms
+step:81/1680 train_time:7032ms step_avg:86.82ms
+step:82/1680 train_time:7120ms step_avg:86.83ms
+step:83/1680 train_time:7207ms step_avg:86.83ms
+step:84/1680 train_time:7295ms step_avg:86.84ms
+step:85/1680 train_time:7382ms step_avg:86.85ms
+step:86/1680 train_time:7469ms step_avg:86.85ms
+step:87/1680 train_time:7556ms step_avg:86.85ms
+step:88/1680 train_time:7643ms step_avg:86.85ms
+step:89/1680 train_time:7729ms step_avg:86.85ms
+step:90/1680 train_time:7817ms step_avg:86.85ms
+step:91/1680 train_time:7904ms step_avg:86.86ms
+step:92/1680 train_time:7992ms step_avg:86.87ms
+step:93/1680 train_time:8079ms step_avg:86.87ms
+step:94/1680 train_time:8166ms step_avg:86.87ms
+step:95/1680 train_time:8253ms step_avg:86.88ms
+step:96/1680 train_time:8340ms step_avg:86.88ms
+step:97/1680 train_time:8427ms step_avg:86.88ms
+step:98/1680 train_time:8514ms step_avg:86.88ms
+step:99/1680 train_time:8601ms step_avg:86.88ms
+step:100/1680 train_time:8689ms step_avg:86.89ms
+step:101/1680 train_time:8776ms step_avg:86.89ms
+step:102/1680 train_time:8863ms step_avg:86.89ms
+step:103/1680 train_time:8950ms step_avg:86.89ms
+step:104/1680 train_time:9037ms step_avg:86.89ms
+step:105/1680 train_time:9124ms step_avg:86.90ms
+step:106/1680 train_time:9212ms step_avg:86.91ms
+step:107/1680 train_time:9300ms step_avg:86.91ms
+step:108/1680 train_time:9387ms step_avg:86.92ms
+step:109/1680 train_time:9474ms step_avg:86.92ms
+step:110/1680 train_time:9562ms step_avg:86.92ms
+step:111/1680 train_time:9649ms step_avg:86.93ms
+step:112/1680 train_time:9736ms step_avg:86.93ms
+step:113/1680 train_time:9823ms step_avg:86.93ms
+step:114/1680 train_time:9909ms step_avg:86.92ms
+step:115/1680 train_time:9997ms step_avg:86.93ms
+step:116/1680 train_time:10084ms step_avg:86.93ms
+step:117/1680 train_time:10171ms step_avg:86.94ms
+step:118/1680 train_time:10258ms step_avg:86.93ms
+step:119/1680 train_time:10346ms step_avg:86.94ms
+step:120/1680 train_time:10433ms step_avg:86.94ms
+step:121/1680 train_time:10519ms step_avg:86.94ms
+step:122/1680 train_time:10606ms step_avg:86.94ms
+step:123/1680 train_time:10694ms step_avg:86.94ms
+step:124/1680 train_time:10781ms step_avg:86.94ms
+step:125/1680 train_time:10868ms step_avg:86.94ms
+step:125/1680 val_loss:4.3077 train_time:10956ms step_avg:87.65ms
+step:126/1680 train_time:10975ms step_avg:87.11ms
+step:127/1680 train_time:11047ms step_avg:86.98ms
+step:128/1680 train_time:11142ms step_avg:87.05ms
+step:129/1680 train_time:11231ms step_avg:87.06ms
+step:130/1680 train_time:11318ms step_avg:87.07ms
+step:131/1680 train_time:11405ms step_avg:87.06ms
+step:132/1680 train_time:11492ms step_avg:87.06ms
+step:133/1680 train_time:11577ms step_avg:87.05ms
+step:134/1680 train_time:11663ms step_avg:87.04ms
+step:135/1680 train_time:11749ms step_avg:87.03ms
+step:136/1680 train_time:11835ms step_avg:87.02ms
+step:137/1680 train_time:11922ms step_avg:87.02ms
+step:138/1680 train_time:12009ms step_avg:87.02ms
+step:139/1680 train_time:12098ms step_avg:87.03ms
+step:140/1680 train_time:12186ms step_avg:87.04ms
+step:141/1680 train_time:12274ms step_avg:87.05ms
+step:142/1680 train_time:12361ms step_avg:87.05ms
+step:143/1680 train_time:12449ms step_avg:87.05ms
+step:144/1680 train_time:12535ms step_avg:87.05ms
+step:145/1680 train_time:12621ms step_avg:87.04ms
+step:146/1680 train_time:12708ms step_avg:87.04ms
+step:147/1680 train_time:12794ms step_avg:87.04ms
+step:148/1680 train_time:12880ms step_avg:87.03ms
+step:149/1680 train_time:12968ms step_avg:87.03ms
+step:150/1680 train_time:13055ms step_avg:87.04ms
+step:151/1680 train_time:13143ms step_avg:87.04ms
+step:152/1680 train_time:13230ms step_avg:87.04ms
+step:153/1680 train_time:13318ms step_avg:87.04ms
+step:154/1680 train_time:13405ms step_avg:87.04ms
+step:155/1680 train_time:13492ms step_avg:87.05ms
+step:156/1680 train_time:13579ms step_avg:87.04ms
+step:157/1680 train_time:13665ms step_avg:87.04ms
+step:158/1680 train_time:13752ms step_avg:87.04ms
+step:159/1680 train_time:13838ms step_avg:87.03ms
+step:160/1680 train_time:13925ms step_avg:87.03ms
+step:161/1680 train_time:14012ms step_avg:87.03ms
+step:162/1680 train_time:14099ms step_avg:87.03ms
+step:163/1680 train_time:14187ms step_avg:87.04ms
+step:164/1680 train_time:14275ms step_avg:87.04ms
+step:165/1680 train_time:14362ms step_avg:87.05ms
+step:166/1680 train_time:14449ms step_avg:87.04ms
+step:167/1680 train_time:14536ms step_avg:87.04ms
+step:168/1680 train_time:14623ms step_avg:87.04ms
+step:169/1680 train_time:14709ms step_avg:87.04ms
+step:170/1680 train_time:14797ms step_avg:87.04ms
+step:171/1680 train_time:14884ms step_avg:87.04ms
+step:172/1680 train_time:14971ms step_avg:87.04ms
+step:173/1680 train_time:15058ms step_avg:87.04ms
+step:174/1680 train_time:15145ms step_avg:87.04ms
+step:175/1680 train_time:15233ms step_avg:87.04ms
+step:176/1680 train_time:15321ms step_avg:87.05ms
+step:177/1680 train_time:15408ms step_avg:87.05ms
+step:178/1680 train_time:15495ms step_avg:87.05ms
+step:179/1680 train_time:15581ms step_avg:87.05ms
+step:180/1680 train_time:15668ms step_avg:87.04ms
+step:181/1680 train_time:15754ms step_avg:87.04ms
+step:182/1680 train_time:15841ms step_avg:87.04ms
+step:183/1680 train_time:15928ms step_avg:87.04ms
+step:184/1680 train_time:16016ms step_avg:87.04ms
+step:185/1680 train_time:16102ms step_avg:87.04ms
+step:186/1680 train_time:16190ms step_avg:87.05ms
+step:187/1680 train_time:16278ms step_avg:87.05ms
+step:188/1680 train_time:16365ms step_avg:87.05ms
+step:189/1680 train_time:16452ms step_avg:87.05ms
+step:190/1680 train_time:16539ms step_avg:87.05ms
+step:191/1680 train_time:16627ms step_avg:87.05ms
+step:192/1680 train_time:16714ms step_avg:87.05ms
+step:193/1680 train_time:16801ms step_avg:87.05ms
+step:194/1680 train_time:16888ms step_avg:87.05ms
+step:195/1680 train_time:16975ms step_avg:87.05ms
+step:196/1680 train_time:17062ms step_avg:87.05ms
+step:197/1680 train_time:17149ms step_avg:87.05ms
+step:198/1680 train_time:17236ms step_avg:87.05ms
+step:199/1680 train_time:17323ms step_avg:87.05ms
+step:200/1680 train_time:17410ms step_avg:87.05ms
+step:201/1680 train_time:17497ms step_avg:87.05ms
+step:202/1680 train_time:17584ms step_avg:87.05ms
+step:203/1680 train_time:17672ms step_avg:87.05ms
+step:204/1680 train_time:17759ms step_avg:87.05ms
+step:205/1680 train_time:17846ms step_avg:87.05ms
+step:206/1680 train_time:17933ms step_avg:87.05ms
+step:207/1680 train_time:18020ms step_avg:87.05ms
+step:208/1680 train_time:18107ms step_avg:87.05ms
+step:209/1680 train_time:18194ms step_avg:87.05ms
+step:210/1680 train_time:18280ms step_avg:87.05ms
+step:211/1680 train_time:18367ms step_avg:87.05ms
+step:212/1680 train_time:18455ms step_avg:87.05ms
+step:213/1680 train_time:18542ms step_avg:87.05ms
+step:214/1680 train_time:18628ms step_avg:87.05ms
+step:215/1680 train_time:18715ms step_avg:87.05ms
+step:216/1680 train_time:18803ms step_avg:87.05ms
+step:217/1680 train_time:18890ms step_avg:87.05ms
+step:218/1680 train_time:18977ms step_avg:87.05ms
+step:219/1680 train_time:19064ms step_avg:87.05ms
+step:220/1680 train_time:19152ms step_avg:87.05ms
+step:221/1680 train_time:19238ms step_avg:87.05ms
+step:222/1680 train_time:19325ms step_avg:87.05ms
+step:223/1680 train_time:19412ms step_avg:87.05ms
+step:224/1680 train_time:19499ms step_avg:87.05ms
+step:225/1680 train_time:19586ms step_avg:87.05ms
+step:226/1680 train_time:19673ms step_avg:87.05ms
+step:227/1680 train_time:19761ms step_avg:87.05ms
+step:228/1680 train_time:19848ms step_avg:87.05ms
+step:229/1680 train_time:19935ms step_avg:87.05ms
+step:230/1680 train_time:20022ms step_avg:87.05ms
+step:231/1680 train_time:20109ms step_avg:87.05ms
+step:232/1680 train_time:20196ms step_avg:87.05ms
+step:233/1680 train_time:20283ms step_avg:87.05ms
+step:234/1680 train_time:20370ms step_avg:87.05ms
+step:235/1680 train_time:20457ms step_avg:87.05ms
+step:236/1680 train_time:20544ms step_avg:87.05ms
+step:237/1680 train_time:20631ms step_avg:87.05ms
+step:238/1680 train_time:20717ms step_avg:87.05ms
+step:239/1680 train_time:20805ms step_avg:87.05ms
+step:240/1680 train_time:20893ms step_avg:87.05ms
+step:241/1680 train_time:20979ms step_avg:87.05ms
+step:242/1680 train_time:21066ms step_avg:87.05ms
+step:243/1680 train_time:21153ms step_avg:87.05ms
+step:244/1680 train_time:21240ms step_avg:87.05ms
+step:245/1680 train_time:21327ms step_avg:87.05ms
+step:246/1680 train_time:21414ms step_avg:87.05ms
+step:247/1680 train_time:21501ms step_avg:87.05ms
+step:248/1680 train_time:21588ms step_avg:87.05ms
+step:249/1680 train_time:21675ms step_avg:87.05ms
+step:250/1680 train_time:21762ms step_avg:87.05ms
+step:250/1680 val_loss:3.9735 train_time:21851ms step_avg:87.40ms
+step:251/1680 train_time:21870ms step_avg:87.13ms
+step:252/1680 train_time:21939ms step_avg:87.06ms
+step:253/1680 train_time:22034ms step_avg:87.09ms
+step:254/1680 train_time:22122ms step_avg:87.09ms
+step:255/1680 train_time:22209ms step_avg:87.09ms
+step:256/1680 train_time:22294ms step_avg:87.09ms
+step:257/1680 train_time:22381ms step_avg:87.08ms
+step:258/1680 train_time:22466ms step_avg:87.08ms
+step:259/1680 train_time:22552ms step_avg:87.07ms
+step:260/1680 train_time:22639ms step_avg:87.07ms
+step:261/1680 train_time:22725ms step_avg:87.07ms
+step:262/1680 train_time:22813ms step_avg:87.07ms
+step:263/1680 train_time:22900ms step_avg:87.07ms
+step:264/1680 train_time:22989ms step_avg:87.08ms
+step:265/1680 train_time:23078ms step_avg:87.09ms
+step:266/1680 train_time:23165ms step_avg:87.09ms
+step:267/1680 train_time:23252ms step_avg:87.09ms
+step:268/1680 train_time:23339ms step_avg:87.08ms
+step:269/1680 train_time:23425ms step_avg:87.08ms
+step:270/1680 train_time:23511ms step_avg:87.08ms
+step:271/1680 train_time:23597ms step_avg:87.07ms
+step:272/1680 train_time:23684ms step_avg:87.07ms
+step:273/1680 train_time:23770ms step_avg:87.07ms
+step:274/1680 train_time:23857ms step_avg:87.07ms
+step:275/1680 train_time:23945ms step_avg:87.07ms
+step:276/1680 train_time:24033ms step_avg:87.08ms
+step:277/1680 train_time:24120ms step_avg:87.08ms
+step:278/1680 train_time:24207ms step_avg:87.08ms
+step:279/1680 train_time:24295ms step_avg:87.08ms
+step:280/1680 train_time:24381ms step_avg:87.08ms
+step:281/1680 train_time:24469ms step_avg:87.08ms
+step:282/1680 train_time:24555ms step_avg:87.07ms
+step:283/1680 train_time:24641ms step_avg:87.07ms
+step:284/1680 train_time:24728ms step_avg:87.07ms
+step:285/1680 train_time:24815ms step_avg:87.07ms
+step:286/1680 train_time:24902ms step_avg:87.07ms
+step:287/1680 train_time:24989ms step_avg:87.07ms
+step:288/1680 train_time:25078ms step_avg:87.08ms
+step:289/1680 train_time:25165ms step_avg:87.08ms
+step:290/1680 train_time:25252ms step_avg:87.08ms
+step:291/1680 train_time:25339ms step_avg:87.08ms
+step:292/1680 train_time:25425ms step_avg:87.07ms
+step:293/1680 train_time:25512ms step_avg:87.07ms
+step:294/1680 train_time:25599ms step_avg:87.07ms
+step:295/1680 train_time:25685ms step_avg:87.07ms
+step:296/1680 train_time:25772ms step_avg:87.07ms
+step:297/1680 train_time:25859ms step_avg:87.07ms
+step:298/1680 train_time:25945ms step_avg:87.06ms
+step:299/1680 train_time:26032ms step_avg:87.06ms
+step:300/1680 train_time:26120ms step_avg:87.07ms
+step:301/1680 train_time:26207ms step_avg:87.07ms
+step:302/1680 train_time:26294ms step_avg:87.07ms
+step:303/1680 train_time:26381ms step_avg:87.07ms
+step:304/1680 train_time:26468ms step_avg:87.07ms
+step:305/1680 train_time:26555ms step_avg:87.07ms
+step:306/1680 train_time:26642ms step_avg:87.06ms
+step:307/1680 train_time:26729ms step_avg:87.06ms
+step:308/1680 train_time:26816ms step_avg:87.06ms
+step:309/1680 train_time:26903ms step_avg:87.06ms
+step:310/1680 train_time:26990ms step_avg:87.07ms
+step:311/1680 train_time:27077ms step_avg:87.07ms
+step:312/1680 train_time:27165ms step_avg:87.07ms
+step:313/1680 train_time:27251ms step_avg:87.07ms
+step:314/1680 train_time:27339ms step_avg:87.07ms
+step:315/1680 train_time:27426ms step_avg:87.07ms
+step:316/1680 train_time:27513ms step_avg:87.07ms
+step:317/1680 train_time:27599ms step_avg:87.06ms
+step:318/1680 train_time:27686ms step_avg:87.06ms
+step:319/1680 train_time:27773ms step_avg:87.06ms
+step:320/1680 train_time:27859ms step_avg:87.06ms
+step:321/1680 train_time:27946ms step_avg:87.06ms
+step:322/1680 train_time:28033ms step_avg:87.06ms
+step:323/1680 train_time:28120ms step_avg:87.06ms
+step:324/1680 train_time:28207ms step_avg:87.06ms
+step:325/1680 train_time:28294ms step_avg:87.06ms
+step:326/1680 train_time:28382ms step_avg:87.06ms
+step:327/1680 train_time:28469ms step_avg:87.06ms
+step:328/1680 train_time:28556ms step_avg:87.06ms
+step:329/1680 train_time:28642ms step_avg:87.06ms
+step:330/1680 train_time:28729ms step_avg:87.06ms
+step:331/1680 train_time:28817ms step_avg:87.06ms
+step:332/1680 train_time:28903ms step_avg:87.06ms
+step:333/1680 train_time:28990ms step_avg:87.06ms
+step:334/1680 train_time:29077ms step_avg:87.06ms
+step:335/1680 train_time:29164ms step_avg:87.06ms
+step:336/1680 train_time:29252ms step_avg:87.06ms
+step:337/1680 train_time:29339ms step_avg:87.06ms
+step:338/1680 train_time:29426ms step_avg:87.06ms
+step:339/1680 train_time:29513ms step_avg:87.06ms
+step:340/1680 train_time:29600ms step_avg:87.06ms
+step:341/1680 train_time:29687ms step_avg:87.06ms
+step:342/1680 train_time:29773ms step_avg:87.06ms
+step:343/1680 train_time:29860ms step_avg:87.06ms
+step:344/1680 train_time:29947ms step_avg:87.06ms
+step:345/1680 train_time:30034ms step_avg:87.06ms
+step:346/1680 train_time:30121ms step_avg:87.06ms
+step:347/1680 train_time:30209ms step_avg:87.06ms
+step:348/1680 train_time:30295ms step_avg:87.06ms
+step:349/1680 train_time:30383ms step_avg:87.06ms
+step:350/1680 train_time:30470ms step_avg:87.06ms
+step:351/1680 train_time:30557ms step_avg:87.06ms
+step:352/1680 train_time:30644ms step_avg:87.06ms
+step:353/1680 train_time:30731ms step_avg:87.06ms
+step:354/1680 train_time:30817ms step_avg:87.05ms
+step:355/1680 train_time:30904ms step_avg:87.05ms
+step:356/1680 train_time:30990ms step_avg:87.05ms
+step:357/1680 train_time:31078ms step_avg:87.05ms
+step:358/1680 train_time:31165ms step_avg:87.05ms
+step:359/1680 train_time:31252ms step_avg:87.05ms
+step:360/1680 train_time:31339ms step_avg:87.05ms
+step:361/1680 train_time:31426ms step_avg:87.05ms
+step:362/1680 train_time:31512ms step_avg:87.05ms
+step:363/1680 train_time:31599ms step_avg:87.05ms
+step:364/1680 train_time:31686ms step_avg:87.05ms
+step:365/1680 train_time:31774ms step_avg:87.05ms
+step:366/1680 train_time:31862ms step_avg:87.05ms
+step:367/1680 train_time:31949ms step_avg:87.05ms
+step:368/1680 train_time:32036ms step_avg:87.05ms
+step:369/1680 train_time:32123ms step_avg:87.05ms
+step:370/1680 train_time:32209ms step_avg:87.05ms
+step:371/1680 train_time:32297ms step_avg:87.05ms
+step:372/1680 train_time:32383ms step_avg:87.05ms
+step:373/1680 train_time:32470ms step_avg:87.05ms
+step:374/1680 train_time:32558ms step_avg:87.05ms
+step:375/1680 train_time:32645ms step_avg:87.05ms
+step:375/1680 val_loss:3.8179 train_time:32733ms step_avg:87.29ms
+step:376/1680 train_time:32753ms step_avg:87.11ms
+step:377/1680 train_time:32824ms step_avg:87.07ms
+step:378/1680 train_time:32915ms step_avg:87.08ms
+step:379/1680 train_time:33004ms step_avg:87.08ms
+step:380/1680 train_time:33090ms step_avg:87.08ms
+step:381/1680 train_time:33176ms step_avg:87.08ms
+step:382/1680 train_time:33262ms step_avg:87.07ms
+step:383/1680 train_time:33348ms step_avg:87.07ms
+step:384/1680 train_time:33434ms step_avg:87.07ms
+step:385/1680 train_time:33520ms step_avg:87.07ms
+step:386/1680 train_time:33606ms step_avg:87.06ms
+step:387/1680 train_time:33694ms step_avg:87.06ms
+step:388/1680 train_time:33781ms step_avg:87.07ms
+step:389/1680 train_time:33870ms step_avg:87.07ms
+step:390/1680 train_time:33960ms step_avg:87.08ms
+step:391/1680 train_time:34047ms step_avg:87.08ms
+step:392/1680 train_time:34133ms step_avg:87.08ms
+step:393/1680 train_time:34220ms step_avg:87.07ms
+step:394/1680 train_time:34306ms step_avg:87.07ms
+step:395/1680 train_time:34392ms step_avg:87.07ms
+step:396/1680 train_time:34479ms step_avg:87.07ms
+step:397/1680 train_time:34565ms step_avg:87.07ms
+step:398/1680 train_time:34652ms step_avg:87.06ms
+step:399/1680 train_time:34739ms step_avg:87.06ms
+step:400/1680 train_time:34827ms step_avg:87.07ms
+step:401/1680 train_time:34914ms step_avg:87.07ms
+step:402/1680 train_time:35002ms step_avg:87.07ms
+step:403/1680 train_time:35089ms step_avg:87.07ms
+step:404/1680 train_time:35176ms step_avg:87.07ms
+step:405/1680 train_time:35262ms step_avg:87.07ms
+step:406/1680 train_time:35349ms step_avg:87.07ms
+step:407/1680 train_time:35436ms step_avg:87.07ms
+step:408/1680 train_time:35523ms step_avg:87.07ms
+step:409/1680 train_time:35609ms step_avg:87.06ms
+step:410/1680 train_time:35696ms step_avg:87.06ms
+step:411/1680 train_time:35783ms step_avg:87.06ms
+step:412/1680 train_time:35870ms step_avg:87.06ms
+step:413/1680 train_time:35958ms step_avg:87.07ms
+step:414/1680 train_time:36045ms step_avg:87.07ms
+step:415/1680 train_time:36133ms step_avg:87.07ms
+step:416/1680 train_time:36220ms step_avg:87.07ms
+step:417/1680 train_time:36306ms step_avg:87.07ms
+step:418/1680 train_time:36393ms step_avg:87.06ms
+step:419/1680 train_time:36479ms step_avg:87.06ms
+step:420/1680 train_time:36567ms step_avg:87.06ms
+step:421/1680 train_time:36653ms step_avg:87.06ms
+step:422/1680 train_time:36740ms step_avg:87.06ms
+step:423/1680 train_time:36828ms step_avg:87.06ms
+step:424/1680 train_time:36915ms step_avg:87.06ms
+step:425/1680 train_time:37002ms step_avg:87.06ms
+step:426/1680 train_time:37089ms step_avg:87.06ms
+step:427/1680 train_time:37176ms step_avg:87.06ms
+step:428/1680 train_time:37264ms step_avg:87.07ms
+step:429/1680 train_time:37351ms step_avg:87.06ms
+step:430/1680 train_time:37438ms step_avg:87.07ms
+step:431/1680 train_time:37525ms step_avg:87.07ms
+step:432/1680 train_time:37612ms step_avg:87.06ms
+step:433/1680 train_time:37698ms step_avg:87.06ms
+step:434/1680 train_time:37785ms step_avg:87.06ms
+step:435/1680 train_time:37872ms step_avg:87.06ms
+step:436/1680 train_time:37960ms step_avg:87.07ms
+step:437/1680 train_time:38048ms step_avg:87.07ms
+step:438/1680 train_time:38134ms step_avg:87.06ms
+step:439/1680 train_time:38222ms step_avg:87.07ms
+step:440/1680 train_time:38309ms step_avg:87.07ms
+step:441/1680 train_time:38396ms step_avg:87.07ms
+step:442/1680 train_time:38483ms step_avg:87.07ms
+step:443/1680 train_time:38569ms step_avg:87.06ms
+step:444/1680 train_time:38656ms step_avg:87.06ms
+step:445/1680 train_time:38743ms step_avg:87.06ms
+step:446/1680 train_time:38830ms step_avg:87.06ms
+step:447/1680 train_time:38917ms step_avg:87.06ms
+step:448/1680 train_time:39004ms step_avg:87.06ms
+step:449/1680 train_time:39091ms step_avg:87.06ms
+step:450/1680 train_time:39178ms step_avg:87.06ms
+step:451/1680 train_time:39265ms step_avg:87.06ms
+step:452/1680 train_time:39353ms step_avg:87.06ms
+step:453/1680 train_time:39440ms step_avg:87.06ms
+step:454/1680 train_time:39527ms step_avg:87.06ms
+step:455/1680 train_time:39613ms step_avg:87.06ms
+step:456/1680 train_time:39700ms step_avg:87.06ms
+step:457/1680 train_time:39786ms step_avg:87.06ms
+step:458/1680 train_time:39873ms step_avg:87.06ms
+step:459/1680 train_time:39961ms step_avg:87.06ms
+step:460/1680 train_time:40049ms step_avg:87.06ms
+step:461/1680 train_time:40135ms step_avg:87.06ms
+step:462/1680 train_time:40223ms step_avg:87.06ms
+step:463/1680 train_time:40310ms step_avg:87.06ms
+step:464/1680 train_time:40397ms step_avg:87.06ms
+step:465/1680 train_time:40484ms step_avg:87.06ms
+step:466/1680 train_time:40571ms step_avg:87.06ms
+step:467/1680 train_time:40658ms step_avg:87.06ms
+step:468/1680 train_time:40744ms step_avg:87.06ms
+step:469/1680 train_time:40832ms step_avg:87.06ms
+step:470/1680 train_time:40919ms step_avg:87.06ms
+step:471/1680 train_time:41005ms step_avg:87.06ms
+step:472/1680 train_time:41093ms step_avg:87.06ms
+step:473/1680 train_time:41180ms step_avg:87.06ms
+step:474/1680 train_time:41267ms step_avg:87.06ms
+step:475/1680 train_time:41354ms step_avg:87.06ms
+step:476/1680 train_time:41442ms step_avg:87.06ms
+step:477/1680 train_time:41528ms step_avg:87.06ms
+step:478/1680 train_time:41615ms step_avg:87.06ms
+step:479/1680 train_time:41702ms step_avg:87.06ms
+step:480/1680 train_time:41789ms step_avg:87.06ms
+step:481/1680 train_time:41877ms step_avg:87.06ms
+step:482/1680 train_time:41964ms step_avg:87.06ms
+step:483/1680 train_time:42052ms step_avg:87.06ms
+step:484/1680 train_time:42139ms step_avg:87.06ms
+step:485/1680 train_time:42225ms step_avg:87.06ms
+step:486/1680 train_time:42313ms step_avg:87.06ms
+step:487/1680 train_time:42400ms step_avg:87.06ms
+step:488/1680 train_time:42487ms step_avg:87.06ms
+step:489/1680 train_time:42573ms step_avg:87.06ms
+step:490/1680 train_time:42660ms step_avg:87.06ms
+step:491/1680 train_time:42747ms step_avg:87.06ms
+step:492/1680 train_time:42834ms step_avg:87.06ms
+step:493/1680 train_time:42923ms step_avg:87.07ms
+step:494/1680 train_time:43009ms step_avg:87.06ms
+step:495/1680 train_time:43097ms step_avg:87.06ms
+step:496/1680 train_time:43184ms step_avg:87.06ms
+step:497/1680 train_time:43271ms step_avg:87.06ms
+step:498/1680 train_time:43358ms step_avg:87.06ms
+step:499/1680 train_time:43446ms step_avg:87.07ms
+step:500/1680 train_time:43533ms step_avg:87.07ms
+step:500/1680 val_loss:3.7175 train_time:43621ms step_avg:87.24ms
+step:501/1680 train_time:43640ms step_avg:87.11ms
+step:502/1680 train_time:43710ms step_avg:87.07ms
+step:503/1680 train_time:43800ms step_avg:87.08ms
+step:504/1680 train_time:43888ms step_avg:87.08ms
+step:505/1680 train_time:43975ms step_avg:87.08ms
+step:506/1680 train_time:44061ms step_avg:87.08ms
+step:507/1680 train_time:44146ms step_avg:87.07ms
+step:508/1680 train_time:44233ms step_avg:87.07ms
+step:509/1680 train_time:44319ms step_avg:87.07ms
+step:510/1680 train_time:44406ms step_avg:87.07ms
+step:511/1680 train_time:44493ms step_avg:87.07ms
+step:512/1680 train_time:44580ms step_avg:87.07ms
+step:513/1680 train_time:44668ms step_avg:87.07ms
+step:514/1680 train_time:44757ms step_avg:87.08ms
+step:515/1680 train_time:44845ms step_avg:87.08ms
+step:516/1680 train_time:44932ms step_avg:87.08ms
+step:517/1680 train_time:45019ms step_avg:87.08ms
+step:518/1680 train_time:45106ms step_avg:87.08ms
+step:519/1680 train_time:45192ms step_avg:87.07ms
+step:520/1680 train_time:45278ms step_avg:87.07ms
+step:521/1680 train_time:45365ms step_avg:87.07ms
+step:522/1680 train_time:45451ms step_avg:87.07ms
+step:523/1680 train_time:45539ms step_avg:87.07ms
+step:524/1680 train_time:45626ms step_avg:87.07ms
+step:525/1680 train_time:45715ms step_avg:87.08ms
+step:526/1680 train_time:45802ms step_avg:87.08ms
+step:527/1680 train_time:45889ms step_avg:87.08ms
+step:528/1680 train_time:45977ms step_avg:87.08ms
+step:529/1680 train_time:46063ms step_avg:87.08ms
+step:530/1680 train_time:46150ms step_avg:87.08ms
+step:531/1680 train_time:46237ms step_avg:87.08ms
+step:532/1680 train_time:46324ms step_avg:87.07ms
+step:533/1680 train_time:46410ms step_avg:87.07ms
+step:534/1680 train_time:46497ms step_avg:87.07ms
+step:535/1680 train_time:46585ms step_avg:87.07ms
+step:536/1680 train_time:46673ms step_avg:87.08ms
+step:537/1680 train_time:46760ms step_avg:87.08ms
+step:538/1680 train_time:46847ms step_avg:87.08ms
+step:539/1680 train_time:46935ms step_avg:87.08ms
+step:540/1680 train_time:47023ms step_avg:87.08ms
+step:541/1680 train_time:47109ms step_avg:87.08ms
+step:542/1680 train_time:47196ms step_avg:87.08ms
+step:543/1680 train_time:47283ms step_avg:87.08ms
+step:544/1680 train_time:47370ms step_avg:87.08ms
+step:545/1680 train_time:47456ms step_avg:87.08ms
+step:546/1680 train_time:47543ms step_avg:87.08ms
+step:547/1680 train_time:47630ms step_avg:87.08ms
+step:548/1680 train_time:47718ms step_avg:87.08ms
+step:549/1680 train_time:47806ms step_avg:87.08ms
+step:550/1680 train_time:47894ms step_avg:87.08ms
+step:551/1680 train_time:47983ms step_avg:87.08ms
+step:552/1680 train_time:48071ms step_avg:87.09ms
+step:553/1680 train_time:48159ms step_avg:87.09ms
+step:554/1680 train_time:48246ms step_avg:87.09ms
+step:555/1680 train_time:48334ms step_avg:87.09ms
+step:556/1680 train_time:48423ms step_avg:87.09ms
+step:557/1680 train_time:48510ms step_avg:87.09ms
+step:558/1680 train_time:48598ms step_avg:87.09ms
+step:559/1680 train_time:48686ms step_avg:87.10ms
+step:560/1680 train_time:48774ms step_avg:87.10ms
+step:561/1680 train_time:48862ms step_avg:87.10ms
+step:562/1680 train_time:48950ms step_avg:87.10ms
+step:563/1680 train_time:49038ms step_avg:87.10ms
+step:564/1680 train_time:49127ms step_avg:87.10ms
+step:565/1680 train_time:49215ms step_avg:87.11ms
+step:566/1680 train_time:49303ms step_avg:87.11ms
+step:567/1680 train_time:49390ms step_avg:87.11ms
+step:568/1680 train_time:49478ms step_avg:87.11ms
+step:569/1680 train_time:49566ms step_avg:87.11ms
+step:570/1680 train_time:49654ms step_avg:87.11ms
+step:571/1680 train_time:49742ms step_avg:87.11ms
+step:572/1680 train_time:49830ms step_avg:87.12ms
+step:573/1680 train_time:49918ms step_avg:87.12ms
+step:574/1680 train_time:50006ms step_avg:87.12ms
+step:575/1680 train_time:50094ms step_avg:87.12ms
+step:576/1680 train_time:50182ms step_avg:87.12ms
+step:577/1680 train_time:50270ms step_avg:87.12ms
+step:578/1680 train_time:50358ms step_avg:87.13ms
+step:579/1680 train_time:50446ms step_avg:87.13ms
+step:580/1680 train_time:50534ms step_avg:87.13ms
+step:581/1680 train_time:50622ms step_avg:87.13ms
+step:582/1680 train_time:50710ms step_avg:87.13ms
+step:583/1680 train_time:50799ms step_avg:87.13ms
+step:584/1680 train_time:50886ms step_avg:87.13ms
+step:585/1680 train_time:50974ms step_avg:87.14ms
+step:586/1680 train_time:51062ms step_avg:87.14ms
+step:587/1680 train_time:51150ms step_avg:87.14ms
+step:588/1680 train_time:51238ms step_avg:87.14ms
+step:589/1680 train_time:51326ms step_avg:87.14ms
+step:590/1680 train_time:51414ms step_avg:87.14ms
+step:591/1680 train_time:51502ms step_avg:87.14ms
+step:592/1680 train_time:51591ms step_avg:87.15ms
+step:593/1680 train_time:51678ms step_avg:87.15ms
+step:594/1680 train_time:51766ms step_avg:87.15ms
+step:595/1680 train_time:51854ms step_avg:87.15ms
+step:596/1680 train_time:51942ms step_avg:87.15ms
+step:597/1680 train_time:52029ms step_avg:87.15ms
+step:598/1680 train_time:52118ms step_avg:87.15ms
+step:599/1680 train_time:52205ms step_avg:87.15ms
+step:600/1680 train_time:52294ms step_avg:87.16ms
+step:601/1680 train_time:52382ms step_avg:87.16ms
+step:602/1680 train_time:52469ms step_avg:87.16ms
+step:603/1680 train_time:52557ms step_avg:87.16ms
+step:604/1680 train_time:52645ms step_avg:87.16ms
+step:605/1680 train_time:52733ms step_avg:87.16ms
+step:606/1680 train_time:52821ms step_avg:87.16ms
+step:607/1680 train_time:52909ms step_avg:87.16ms
+step:608/1680 train_time:52997ms step_avg:87.17ms
+step:609/1680 train_time:53086ms step_avg:87.17ms
+step:610/1680 train_time:53173ms step_avg:87.17ms
+step:611/1680 train_time:53262ms step_avg:87.17ms
+step:612/1680 train_time:53349ms step_avg:87.17ms
+step:613/1680 train_time:53437ms step_avg:87.17ms
+step:614/1680 train_time:53526ms step_avg:87.18ms
+step:615/1680 train_time:53614ms step_avg:87.18ms
+step:616/1680 train_time:53702ms step_avg:87.18ms
+step:617/1680 train_time:53790ms step_avg:87.18ms
+step:618/1680 train_time:53878ms step_avg:87.18ms
+step:619/1680 train_time:53966ms step_avg:87.18ms
+step:620/1680 train_time:54054ms step_avg:87.18ms
+step:621/1680 train_time:54142ms step_avg:87.19ms
+step:622/1680 train_time:54230ms step_avg:87.19ms
+step:623/1680 train_time:54318ms step_avg:87.19ms
+step:624/1680 train_time:54406ms step_avg:87.19ms
+step:625/1680 train_time:54494ms step_avg:87.19ms
+step:625/1680 val_loss:3.6157 train_time:54584ms step_avg:87.33ms
+step:626/1680 train_time:54604ms step_avg:87.23ms
+step:627/1680 train_time:54674ms step_avg:87.20ms
+step:628/1680 train_time:54763ms step_avg:87.20ms
+step:629/1680 train_time:54854ms step_avg:87.21ms
+step:630/1680 train_time:54943ms step_avg:87.21ms
+step:631/1680 train_time:55029ms step_avg:87.21ms
+step:632/1680 train_time:55116ms step_avg:87.21ms
+step:633/1680 train_time:55203ms step_avg:87.21ms
+step:634/1680 train_time:55290ms step_avg:87.21ms
+step:635/1680 train_time:55377ms step_avg:87.21ms
+step:636/1680 train_time:55464ms step_avg:87.21ms
+step:637/1680 train_time:55557ms step_avg:87.22ms
+step:638/1680 train_time:55647ms step_avg:87.22ms
+step:639/1680 train_time:55736ms step_avg:87.22ms
+step:640/1680 train_time:55824ms step_avg:87.23ms
+step:641/1680 train_time:55913ms step_avg:87.23ms
+step:642/1680 train_time:56001ms step_avg:87.23ms
+step:643/1680 train_time:56089ms step_avg:87.23ms
+step:644/1680 train_time:56177ms step_avg:87.23ms
+step:645/1680 train_time:56264ms step_avg:87.23ms
+step:646/1680 train_time:56352ms step_avg:87.23ms
+step:647/1680 train_time:56439ms step_avg:87.23ms
+step:648/1680 train_time:56528ms step_avg:87.24ms
+step:649/1680 train_time:56617ms step_avg:87.24ms
+step:650/1680 train_time:56706ms step_avg:87.24ms
+step:651/1680 train_time:56795ms step_avg:87.24ms
+step:652/1680 train_time:56883ms step_avg:87.24ms
+step:653/1680 train_time:56972ms step_avg:87.25ms
+step:654/1680 train_time:57059ms step_avg:87.25ms
+step:655/1680 train_time:57147ms step_avg:87.25ms
+step:656/1680 train_time:57234ms step_avg:87.25ms
+step:657/1680 train_time:57322ms step_avg:87.25ms
+step:658/1680 train_time:57409ms step_avg:87.25ms
+step:659/1680 train_time:57498ms step_avg:87.25ms
+step:660/1680 train_time:57586ms step_avg:87.25ms
+step:661/1680 train_time:57674ms step_avg:87.25ms
+step:662/1680 train_time:57763ms step_avg:87.25ms
+step:663/1680 train_time:57851ms step_avg:87.26ms
+step:664/1680 train_time:57939ms step_avg:87.26ms
+step:665/1680 train_time:58027ms step_avg:87.26ms
+step:666/1680 train_time:58115ms step_avg:87.26ms
+step:667/1680 train_time:58202ms step_avg:87.26ms
+step:668/1680 train_time:58290ms step_avg:87.26ms
+step:669/1680 train_time:58378ms step_avg:87.26ms
+step:670/1680 train_time:58466ms step_avg:87.26ms +step:671/1680 train_time:58553ms step_avg:87.26ms +step:672/1680 train_time:58642ms step_avg:87.26ms +step:673/1680 train_time:58730ms step_avg:87.27ms +step:674/1680 train_time:58819ms step_avg:87.27ms +step:675/1680 train_time:58907ms step_avg:87.27ms +step:676/1680 train_time:58994ms step_avg:87.27ms +step:677/1680 train_time:59082ms step_avg:87.27ms +step:678/1680 train_time:59170ms step_avg:87.27ms +step:679/1680 train_time:59259ms step_avg:87.27ms +step:680/1680 train_time:59346ms step_avg:87.27ms +step:681/1680 train_time:59434ms step_avg:87.27ms +step:682/1680 train_time:59521ms step_avg:87.27ms +step:683/1680 train_time:59610ms step_avg:87.28ms +step:684/1680 train_time:59699ms step_avg:87.28ms +step:685/1680 train_time:59787ms step_avg:87.28ms +step:686/1680 train_time:59874ms step_avg:87.28ms +step:687/1680 train_time:59962ms step_avg:87.28ms +step:688/1680 train_time:60050ms step_avg:87.28ms +step:689/1680 train_time:60138ms step_avg:87.28ms +step:690/1680 train_time:60226ms step_avg:87.28ms +step:691/1680 train_time:60314ms step_avg:87.28ms +step:692/1680 train_time:60401ms step_avg:87.29ms +step:693/1680 train_time:60490ms step_avg:87.29ms +step:694/1680 train_time:60578ms step_avg:87.29ms +step:695/1680 train_time:60665ms step_avg:87.29ms +step:696/1680 train_time:60754ms step_avg:87.29ms +step:697/1680 train_time:60843ms step_avg:87.29ms +step:698/1680 train_time:60931ms step_avg:87.29ms +step:699/1680 train_time:61019ms step_avg:87.29ms +step:700/1680 train_time:61107ms step_avg:87.30ms +step:701/1680 train_time:61195ms step_avg:87.30ms +step:702/1680 train_time:61283ms step_avg:87.30ms +step:703/1680 train_time:61370ms step_avg:87.30ms +step:704/1680 train_time:61458ms step_avg:87.30ms +step:705/1680 train_time:61546ms step_avg:87.30ms +step:706/1680 train_time:61634ms step_avg:87.30ms +step:707/1680 train_time:61722ms step_avg:87.30ms +step:708/1680 train_time:61810ms step_avg:87.30ms +step:709/1680 train_time:61898ms step_avg:87.30ms +step:710/1680 train_time:61985ms step_avg:87.30ms +step:711/1680 train_time:62073ms step_avg:87.30ms +step:712/1680 train_time:62161ms step_avg:87.30ms +step:713/1680 train_time:62249ms step_avg:87.31ms +step:714/1680 train_time:62338ms step_avg:87.31ms +step:715/1680 train_time:62425ms step_avg:87.31ms +step:716/1680 train_time:62513ms step_avg:87.31ms +step:717/1680 train_time:62601ms step_avg:87.31ms +step:718/1680 train_time:62689ms step_avg:87.31ms +step:719/1680 train_time:62777ms step_avg:87.31ms +step:720/1680 train_time:62864ms step_avg:87.31ms +step:721/1680 train_time:62952ms step_avg:87.31ms +step:722/1680 train_time:63040ms step_avg:87.31ms +step:723/1680 train_time:63128ms step_avg:87.31ms +step:724/1680 train_time:63216ms step_avg:87.31ms +step:725/1680 train_time:63304ms step_avg:87.32ms +step:726/1680 train_time:63392ms step_avg:87.32ms +step:727/1680 train_time:63479ms step_avg:87.32ms +step:728/1680 train_time:63567ms step_avg:87.32ms +step:729/1680 train_time:63655ms step_avg:87.32ms +step:730/1680 train_time:63743ms step_avg:87.32ms +step:731/1680 train_time:63831ms step_avg:87.32ms +step:732/1680 train_time:63919ms step_avg:87.32ms +step:733/1680 train_time:64007ms step_avg:87.32ms +step:734/1680 train_time:64095ms step_avg:87.32ms +step:735/1680 train_time:64183ms step_avg:87.32ms +step:736/1680 train_time:64272ms step_avg:87.33ms +step:737/1680 train_time:64360ms step_avg:87.33ms +step:738/1680 train_time:64449ms step_avg:87.33ms +step:739/1680 train_time:64537ms 
step_avg:87.33ms +step:740/1680 train_time:64625ms step_avg:87.33ms +step:741/1680 train_time:64713ms step_avg:87.33ms +step:742/1680 train_time:64801ms step_avg:87.33ms +step:743/1680 train_time:64889ms step_avg:87.33ms +step:744/1680 train_time:64976ms step_avg:87.33ms +step:745/1680 train_time:65065ms step_avg:87.34ms +step:746/1680 train_time:65152ms step_avg:87.34ms +step:747/1680 train_time:65240ms step_avg:87.34ms +step:748/1680 train_time:65329ms step_avg:87.34ms +step:749/1680 train_time:65417ms step_avg:87.34ms +step:750/1680 train_time:65504ms step_avg:87.34ms +step:750/1680 val_loss:3.5643 train_time:65594ms step_avg:87.46ms +step:751/1680 train_time:65613ms step_avg:87.37ms +step:752/1680 train_time:65684ms step_avg:87.35ms +step:753/1680 train_time:65775ms step_avg:87.35ms +step:754/1680 train_time:65865ms step_avg:87.35ms +step:755/1680 train_time:65953ms step_avg:87.36ms +step:756/1680 train_time:66040ms step_avg:87.35ms +step:757/1680 train_time:66127ms step_avg:87.35ms +step:758/1680 train_time:66214ms step_avg:87.35ms +step:759/1680 train_time:66301ms step_avg:87.35ms +step:760/1680 train_time:66388ms step_avg:87.35ms +step:761/1680 train_time:66475ms step_avg:87.35ms +step:762/1680 train_time:66564ms step_avg:87.35ms +step:763/1680 train_time:66653ms step_avg:87.36ms +step:764/1680 train_time:66743ms step_avg:87.36ms +step:765/1680 train_time:66832ms step_avg:87.36ms +step:766/1680 train_time:66921ms step_avg:87.36ms +step:767/1680 train_time:67009ms step_avg:87.36ms +step:768/1680 train_time:67096ms step_avg:87.36ms +step:769/1680 train_time:67183ms step_avg:87.36ms +step:770/1680 train_time:67270ms step_avg:87.36ms +step:771/1680 train_time:67357ms step_avg:87.36ms +step:772/1680 train_time:67445ms step_avg:87.36ms +step:773/1680 train_time:67533ms step_avg:87.36ms +step:774/1680 train_time:67621ms step_avg:87.37ms +step:775/1680 train_time:67710ms step_avg:87.37ms +step:776/1680 train_time:67799ms step_avg:87.37ms +step:777/1680 train_time:67888ms step_avg:87.37ms +step:778/1680 train_time:67976ms step_avg:87.37ms +step:779/1680 train_time:68065ms step_avg:87.37ms +step:780/1680 train_time:68153ms step_avg:87.38ms +step:781/1680 train_time:68240ms step_avg:87.38ms +step:782/1680 train_time:68327ms step_avg:87.38ms +step:783/1680 train_time:68415ms step_avg:87.38ms +step:784/1680 train_time:68503ms step_avg:87.38ms +step:785/1680 train_time:68591ms step_avg:87.38ms +step:786/1680 train_time:68680ms step_avg:87.38ms +step:787/1680 train_time:68769ms step_avg:87.38ms +step:788/1680 train_time:68857ms step_avg:87.38ms +step:789/1680 train_time:68945ms step_avg:87.38ms +step:790/1680 train_time:69034ms step_avg:87.38ms +step:791/1680 train_time:69121ms step_avg:87.38ms +step:792/1680 train_time:69208ms step_avg:87.38ms +step:793/1680 train_time:69296ms step_avg:87.38ms +step:794/1680 train_time:69383ms step_avg:87.38ms +step:795/1680 train_time:69471ms step_avg:87.39ms +step:796/1680 train_time:69559ms step_avg:87.39ms +step:797/1680 train_time:69648ms step_avg:87.39ms +step:798/1680 train_time:69736ms step_avg:87.39ms +step:799/1680 train_time:69824ms step_avg:87.39ms +step:800/1680 train_time:69912ms step_avg:87.39ms +step:801/1680 train_time:70000ms step_avg:87.39ms +step:802/1680 train_time:70087ms step_avg:87.39ms +step:803/1680 train_time:70175ms step_avg:87.39ms +step:804/1680 train_time:70263ms step_avg:87.39ms +step:805/1680 train_time:70351ms step_avg:87.39ms +step:806/1680 train_time:70438ms step_avg:87.39ms +step:807/1680 train_time:70526ms step_avg:87.39ms 
+step:808/1680 train_time:70614ms step_avg:87.39ms +step:809/1680 train_time:70702ms step_avg:87.39ms +step:810/1680 train_time:70791ms step_avg:87.40ms +step:811/1680 train_time:70879ms step_avg:87.40ms +step:812/1680 train_time:70967ms step_avg:87.40ms +step:813/1680 train_time:71056ms step_avg:87.40ms +step:814/1680 train_time:71144ms step_avg:87.40ms +step:815/1680 train_time:71231ms step_avg:87.40ms +step:816/1680 train_time:71319ms step_avg:87.40ms +step:817/1680 train_time:71406ms step_avg:87.40ms +step:818/1680 train_time:71494ms step_avg:87.40ms +step:819/1680 train_time:71582ms step_avg:87.40ms +step:820/1680 train_time:71671ms step_avg:87.40ms +step:821/1680 train_time:71759ms step_avg:87.40ms +step:822/1680 train_time:71848ms step_avg:87.41ms +step:823/1680 train_time:71936ms step_avg:87.41ms +step:824/1680 train_time:72024ms step_avg:87.41ms +step:825/1680 train_time:72112ms step_avg:87.41ms +step:826/1680 train_time:72200ms step_avg:87.41ms +step:827/1680 train_time:72288ms step_avg:87.41ms +step:828/1680 train_time:72376ms step_avg:87.41ms +step:829/1680 train_time:72464ms step_avg:87.41ms +step:830/1680 train_time:72552ms step_avg:87.41ms +step:831/1680 train_time:72639ms step_avg:87.41ms +step:832/1680 train_time:72727ms step_avg:87.41ms +step:833/1680 train_time:72815ms step_avg:87.41ms +step:834/1680 train_time:72903ms step_avg:87.41ms +step:835/1680 train_time:72991ms step_avg:87.41ms +step:836/1680 train_time:73079ms step_avg:87.42ms +step:837/1680 train_time:73167ms step_avg:87.42ms +step:838/1680 train_time:73255ms step_avg:87.42ms +step:839/1680 train_time:73343ms step_avg:87.42ms +step:840/1680 train_time:73432ms step_avg:87.42ms +step:841/1680 train_time:73520ms step_avg:87.42ms +step:842/1680 train_time:73608ms step_avg:87.42ms +step:843/1680 train_time:73696ms step_avg:87.42ms +step:844/1680 train_time:73784ms step_avg:87.42ms +step:845/1680 train_time:73873ms step_avg:87.42ms +step:846/1680 train_time:73960ms step_avg:87.42ms +step:847/1680 train_time:74049ms step_avg:87.43ms +step:848/1680 train_time:74137ms step_avg:87.43ms +step:849/1680 train_time:74225ms step_avg:87.43ms +step:850/1680 train_time:74312ms step_avg:87.43ms +step:851/1680 train_time:74401ms step_avg:87.43ms +step:852/1680 train_time:74489ms step_avg:87.43ms +step:853/1680 train_time:74577ms step_avg:87.43ms +step:854/1680 train_time:74665ms step_avg:87.43ms +step:855/1680 train_time:74753ms step_avg:87.43ms +step:856/1680 train_time:74841ms step_avg:87.43ms +step:857/1680 train_time:74928ms step_avg:87.43ms +step:858/1680 train_time:75017ms step_avg:87.43ms +step:859/1680 train_time:75105ms step_avg:87.43ms +step:860/1680 train_time:75194ms step_avg:87.43ms +step:861/1680 train_time:75282ms step_avg:87.44ms +step:862/1680 train_time:75371ms step_avg:87.44ms +step:863/1680 train_time:75459ms step_avg:87.44ms +step:864/1680 train_time:75546ms step_avg:87.44ms +step:865/1680 train_time:75634ms step_avg:87.44ms +step:866/1680 train_time:75722ms step_avg:87.44ms +step:867/1680 train_time:75810ms step_avg:87.44ms +step:868/1680 train_time:75898ms step_avg:87.44ms +step:869/1680 train_time:75987ms step_avg:87.44ms +step:870/1680 train_time:76075ms step_avg:87.44ms +step:871/1680 train_time:76163ms step_avg:87.44ms +step:872/1680 train_time:76252ms step_avg:87.44ms +step:873/1680 train_time:76339ms step_avg:87.44ms +step:874/1680 train_time:76427ms step_avg:87.45ms +step:875/1680 train_time:76515ms step_avg:87.45ms +step:875/1680 val_loss:3.5204 train_time:76605ms step_avg:87.55ms +step:876/1680 
train_time:76623ms step_avg:87.47ms +step:877/1680 train_time:76696ms step_avg:87.45ms +step:878/1680 train_time:76789ms step_avg:87.46ms +step:879/1680 train_time:76878ms step_avg:87.46ms +step:880/1680 train_time:76965ms step_avg:87.46ms +step:881/1680 train_time:77052ms step_avg:87.46ms +step:882/1680 train_time:77139ms step_avg:87.46ms +step:883/1680 train_time:77226ms step_avg:87.46ms +step:884/1680 train_time:77313ms step_avg:87.46ms +step:885/1680 train_time:77401ms step_avg:87.46ms +step:886/1680 train_time:77488ms step_avg:87.46ms +step:887/1680 train_time:77576ms step_avg:87.46ms +step:888/1680 train_time:77667ms step_avg:87.46ms +step:889/1680 train_time:77758ms step_avg:87.47ms +step:890/1680 train_time:77847ms step_avg:87.47ms +step:891/1680 train_time:77935ms step_avg:87.47ms +step:892/1680 train_time:78023ms step_avg:87.47ms +step:893/1680 train_time:78111ms step_avg:87.47ms +step:894/1680 train_time:78198ms step_avg:87.47ms +step:895/1680 train_time:78286ms step_avg:87.47ms +step:896/1680 train_time:78373ms step_avg:87.47ms +step:897/1680 train_time:78460ms step_avg:87.47ms +step:898/1680 train_time:78548ms step_avg:87.47ms +step:899/1680 train_time:78637ms step_avg:87.47ms +step:900/1680 train_time:78727ms step_avg:87.47ms +step:901/1680 train_time:78816ms step_avg:87.48ms +step:902/1680 train_time:78906ms step_avg:87.48ms +step:903/1680 train_time:78995ms step_avg:87.48ms +step:904/1680 train_time:79083ms step_avg:87.48ms +step:905/1680 train_time:79170ms step_avg:87.48ms +step:906/1680 train_time:79258ms step_avg:87.48ms +step:907/1680 train_time:79346ms step_avg:87.48ms +step:908/1680 train_time:79433ms step_avg:87.48ms +step:909/1680 train_time:79521ms step_avg:87.48ms +step:910/1680 train_time:79609ms step_avg:87.48ms +step:911/1680 train_time:79699ms step_avg:87.48ms +step:912/1680 train_time:79788ms step_avg:87.49ms +step:913/1680 train_time:79876ms step_avg:87.49ms +step:914/1680 train_time:79964ms step_avg:87.49ms +step:915/1680 train_time:80052ms step_avg:87.49ms +step:916/1680 train_time:80140ms step_avg:87.49ms +step:917/1680 train_time:80228ms step_avg:87.49ms +step:918/1680 train_time:80315ms step_avg:87.49ms +step:919/1680 train_time:80402ms step_avg:87.49ms +step:920/1680 train_time:80490ms step_avg:87.49ms +step:921/1680 train_time:80578ms step_avg:87.49ms +step:922/1680 train_time:80667ms step_avg:87.49ms +step:923/1680 train_time:80755ms step_avg:87.49ms +step:924/1680 train_time:80844ms step_avg:87.49ms +step:925/1680 train_time:80932ms step_avg:87.49ms +step:926/1680 train_time:81020ms step_avg:87.49ms +step:927/1680 train_time:81108ms step_avg:87.50ms +step:928/1680 train_time:81197ms step_avg:87.50ms +step:929/1680 train_time:81285ms step_avg:87.50ms +step:930/1680 train_time:81372ms step_avg:87.50ms +step:931/1680 train_time:81460ms step_avg:87.50ms +step:932/1680 train_time:81548ms step_avg:87.50ms +step:933/1680 train_time:81636ms step_avg:87.50ms +step:934/1680 train_time:81725ms step_avg:87.50ms +step:935/1680 train_time:81813ms step_avg:87.50ms +step:936/1680 train_time:81902ms step_avg:87.50ms +step:937/1680 train_time:81990ms step_avg:87.50ms +step:938/1680 train_time:82078ms step_avg:87.50ms +step:939/1680 train_time:82166ms step_avg:87.50ms +step:940/1680 train_time:82255ms step_avg:87.51ms +step:941/1680 train_time:82343ms step_avg:87.51ms +step:942/1680 train_time:82430ms step_avg:87.51ms +step:943/1680 train_time:82519ms step_avg:87.51ms +step:944/1680 train_time:82607ms step_avg:87.51ms +step:945/1680 train_time:82695ms step_avg:87.51ms 
+step:946/1680 train_time:82783ms step_avg:87.51ms +step:947/1680 train_time:82872ms step_avg:87.51ms +step:948/1680 train_time:82960ms step_avg:87.51ms +step:949/1680 train_time:83048ms step_avg:87.51ms +step:950/1680 train_time:83135ms step_avg:87.51ms +step:951/1680 train_time:83223ms step_avg:87.51ms +step:952/1680 train_time:83311ms step_avg:87.51ms +step:953/1680 train_time:83400ms step_avg:87.51ms +step:954/1680 train_time:83488ms step_avg:87.51ms +step:955/1680 train_time:83575ms step_avg:87.51ms +step:956/1680 train_time:83663ms step_avg:87.51ms +step:957/1680 train_time:83751ms step_avg:87.51ms +step:958/1680 train_time:83839ms step_avg:87.51ms +step:959/1680 train_time:83927ms step_avg:87.52ms +step:960/1680 train_time:84016ms step_avg:87.52ms +step:961/1680 train_time:84104ms step_avg:87.52ms +step:962/1680 train_time:84192ms step_avg:87.52ms +step:963/1680 train_time:84280ms step_avg:87.52ms +step:964/1680 train_time:84368ms step_avg:87.52ms +step:965/1680 train_time:84456ms step_avg:87.52ms +step:966/1680 train_time:84544ms step_avg:87.52ms +step:967/1680 train_time:84632ms step_avg:87.52ms +step:968/1680 train_time:84720ms step_avg:87.52ms +step:969/1680 train_time:84808ms step_avg:87.52ms +step:970/1680 train_time:84896ms step_avg:87.52ms +step:971/1680 train_time:84984ms step_avg:87.52ms +step:972/1680 train_time:85072ms step_avg:87.52ms +step:973/1680 train_time:85160ms step_avg:87.52ms +step:974/1680 train_time:85248ms step_avg:87.52ms +step:975/1680 train_time:85336ms step_avg:87.52ms +step:976/1680 train_time:85424ms step_avg:87.52ms +step:977/1680 train_time:85512ms step_avg:87.53ms +step:978/1680 train_time:85601ms step_avg:87.53ms +step:979/1680 train_time:85689ms step_avg:87.53ms +step:980/1680 train_time:85777ms step_avg:87.53ms +step:981/1680 train_time:85865ms step_avg:87.53ms +step:982/1680 train_time:85953ms step_avg:87.53ms +step:983/1680 train_time:86042ms step_avg:87.53ms +step:984/1680 train_time:86129ms step_avg:87.53ms +step:985/1680 train_time:86218ms step_avg:87.53ms +step:986/1680 train_time:86306ms step_avg:87.53ms +step:987/1680 train_time:86394ms step_avg:87.53ms +step:988/1680 train_time:86482ms step_avg:87.53ms +step:989/1680 train_time:86570ms step_avg:87.53ms +step:990/1680 train_time:86657ms step_avg:87.53ms +step:991/1680 train_time:86747ms step_avg:87.53ms +step:992/1680 train_time:86835ms step_avg:87.54ms +step:993/1680 train_time:86923ms step_avg:87.54ms +step:994/1680 train_time:87011ms step_avg:87.54ms +step:995/1680 train_time:87100ms step_avg:87.54ms +step:996/1680 train_time:87187ms step_avg:87.54ms +step:997/1680 train_time:87275ms step_avg:87.54ms +step:998/1680 train_time:87363ms step_avg:87.54ms +step:999/1680 train_time:87451ms step_avg:87.54ms +step:1000/1680 train_time:87539ms step_avg:87.54ms +step:1000/1680 val_loss:3.4694 train_time:87628ms step_avg:87.63ms +step:1001/1680 train_time:87647ms step_avg:87.56ms +step:1002/1680 train_time:87720ms step_avg:87.54ms +step:1003/1680 train_time:87810ms step_avg:87.55ms +step:1004/1680 train_time:87899ms step_avg:87.55ms +step:1005/1680 train_time:87987ms step_avg:87.55ms +step:1006/1680 train_time:88074ms step_avg:87.55ms +step:1007/1680 train_time:88161ms step_avg:87.55ms +step:1008/1680 train_time:88248ms step_avg:87.55ms +step:1009/1680 train_time:88335ms step_avg:87.55ms +step:1010/1680 train_time:88422ms step_avg:87.55ms +step:1011/1680 train_time:88510ms step_avg:87.55ms +step:1012/1680 train_time:88598ms step_avg:87.55ms +step:1013/1680 train_time:88689ms step_avg:87.55ms 
+step:1014/1680 train_time:88779ms step_avg:87.55ms +step:1015/1680 train_time:88870ms step_avg:87.56ms +step:1016/1680 train_time:88958ms step_avg:87.56ms +step:1017/1680 train_time:89047ms step_avg:87.56ms +step:1018/1680 train_time:89134ms step_avg:87.56ms +step:1019/1680 train_time:89221ms step_avg:87.56ms +step:1020/1680 train_time:89308ms step_avg:87.56ms +step:1021/1680 train_time:89395ms step_avg:87.56ms +step:1022/1680 train_time:89482ms step_avg:87.56ms +step:1023/1680 train_time:89570ms step_avg:87.56ms +step:1024/1680 train_time:89659ms step_avg:87.56ms +step:1025/1680 train_time:89748ms step_avg:87.56ms +step:1026/1680 train_time:89837ms step_avg:87.56ms +step:1027/1680 train_time:89926ms step_avg:87.56ms +step:1028/1680 train_time:90014ms step_avg:87.56ms +step:1029/1680 train_time:90102ms step_avg:87.56ms +step:1030/1680 train_time:90189ms step_avg:87.56ms +step:1031/1680 train_time:90277ms step_avg:87.56ms +step:1032/1680 train_time:90364ms step_avg:87.56ms +step:1033/1680 train_time:90452ms step_avg:87.56ms +step:1034/1680 train_time:90540ms step_avg:87.56ms +step:1035/1680 train_time:90629ms step_avg:87.56ms +step:1036/1680 train_time:90717ms step_avg:87.56ms +step:1037/1680 train_time:90806ms step_avg:87.57ms +step:1038/1680 train_time:90895ms step_avg:87.57ms +step:1039/1680 train_time:90984ms step_avg:87.57ms +step:1040/1680 train_time:91072ms step_avg:87.57ms +step:1041/1680 train_time:91160ms step_avg:87.57ms +step:1042/1680 train_time:91247ms step_avg:87.57ms +step:1043/1680 train_time:91334ms step_avg:87.57ms +step:1044/1680 train_time:91422ms step_avg:87.57ms +step:1045/1680 train_time:91510ms step_avg:87.57ms +step:1046/1680 train_time:91598ms step_avg:87.57ms +step:1047/1680 train_time:91687ms step_avg:87.57ms +step:1048/1680 train_time:91775ms step_avg:87.57ms +step:1049/1680 train_time:91864ms step_avg:87.57ms +step:1050/1680 train_time:91953ms step_avg:87.57ms +step:1051/1680 train_time:92041ms step_avg:87.57ms +step:1052/1680 train_time:92130ms step_avg:87.58ms +step:1053/1680 train_time:92217ms step_avg:87.58ms +step:1054/1680 train_time:92305ms step_avg:87.58ms +step:1055/1680 train_time:92393ms step_avg:87.58ms +step:1056/1680 train_time:92480ms step_avg:87.58ms +step:1057/1680 train_time:92568ms step_avg:87.58ms +step:1058/1680 train_time:92656ms step_avg:87.58ms +step:1059/1680 train_time:92744ms step_avg:87.58ms +step:1060/1680 train_time:92832ms step_avg:87.58ms +step:1061/1680 train_time:92920ms step_avg:87.58ms +step:1062/1680 train_time:93008ms step_avg:87.58ms +step:1063/1680 train_time:93097ms step_avg:87.58ms +step:1064/1680 train_time:93185ms step_avg:87.58ms +step:1065/1680 train_time:93273ms step_avg:87.58ms +step:1066/1680 train_time:93361ms step_avg:87.58ms +step:1067/1680 train_time:93449ms step_avg:87.58ms +step:1068/1680 train_time:93536ms step_avg:87.58ms +step:1069/1680 train_time:93625ms step_avg:87.58ms +step:1070/1680 train_time:93712ms step_avg:87.58ms +step:1071/1680 train_time:93800ms step_avg:87.58ms +step:1072/1680 train_time:93889ms step_avg:87.58ms +step:1073/1680 train_time:93978ms step_avg:87.58ms +step:1074/1680 train_time:94067ms step_avg:87.59ms +step:1075/1680 train_time:94155ms step_avg:87.59ms +step:1076/1680 train_time:94243ms step_avg:87.59ms +step:1077/1680 train_time:94331ms step_avg:87.59ms +step:1078/1680 train_time:94418ms step_avg:87.59ms +step:1079/1680 train_time:94506ms step_avg:87.59ms +step:1080/1680 train_time:94594ms step_avg:87.59ms +step:1081/1680 train_time:94682ms step_avg:87.59ms +step:1082/1680 
train_time:94770ms step_avg:87.59ms +step:1083/1680 train_time:94859ms step_avg:87.59ms +step:1084/1680 train_time:94947ms step_avg:87.59ms +step:1085/1680 train_time:95037ms step_avg:87.59ms +step:1086/1680 train_time:95125ms step_avg:87.59ms +step:1087/1680 train_time:95213ms step_avg:87.59ms +step:1088/1680 train_time:95300ms step_avg:87.59ms +step:1089/1680 train_time:95388ms step_avg:87.59ms +step:1090/1680 train_time:95476ms step_avg:87.59ms +step:1091/1680 train_time:95564ms step_avg:87.59ms +step:1092/1680 train_time:95653ms step_avg:87.59ms +step:1093/1680 train_time:95740ms step_avg:87.59ms +step:1094/1680 train_time:95828ms step_avg:87.59ms +step:1095/1680 train_time:95916ms step_avg:87.59ms +step:1096/1680 train_time:96005ms step_avg:87.60ms +step:1097/1680 train_time:96094ms step_avg:87.60ms +step:1098/1680 train_time:96183ms step_avg:87.60ms +step:1099/1680 train_time:96272ms step_avg:87.60ms +step:1100/1680 train_time:96361ms step_avg:87.60ms +step:1101/1680 train_time:96449ms step_avg:87.60ms +step:1102/1680 train_time:96537ms step_avg:87.60ms +step:1103/1680 train_time:96625ms step_avg:87.60ms +step:1104/1680 train_time:96714ms step_avg:87.60ms +step:1105/1680 train_time:96803ms step_avg:87.60ms +step:1106/1680 train_time:96893ms step_avg:87.61ms +step:1107/1680 train_time:96982ms step_avg:87.61ms +step:1108/1680 train_time:97072ms step_avg:87.61ms +step:1109/1680 train_time:97160ms step_avg:87.61ms +step:1110/1680 train_time:97250ms step_avg:87.61ms +step:1111/1680 train_time:97340ms step_avg:87.61ms +step:1112/1680 train_time:97428ms step_avg:87.62ms +step:1113/1680 train_time:97517ms step_avg:87.62ms +step:1114/1680 train_time:97606ms step_avg:87.62ms +step:1115/1680 train_time:97694ms step_avg:87.62ms +step:1116/1680 train_time:97783ms step_avg:87.62ms +step:1117/1680 train_time:97873ms step_avg:87.62ms +step:1118/1680 train_time:97962ms step_avg:87.62ms +step:1119/1680 train_time:98051ms step_avg:87.62ms +step:1120/1680 train_time:98139ms step_avg:87.62ms +step:1121/1680 train_time:98228ms step_avg:87.63ms +step:1122/1680 train_time:98317ms step_avg:87.63ms +step:1123/1680 train_time:98406ms step_avg:87.63ms +step:1124/1680 train_time:98495ms step_avg:87.63ms +step:1125/1680 train_time:98585ms step_avg:87.63ms +step:1125/1680 val_loss:3.4151 train_time:98675ms step_avg:87.71ms +step:1126/1680 train_time:98695ms step_avg:87.65ms +step:1127/1680 train_time:98765ms step_avg:87.64ms +step:1128/1680 train_time:98857ms step_avg:87.64ms +step:1129/1680 train_time:98950ms step_avg:87.64ms +step:1130/1680 train_time:99040ms step_avg:87.65ms +step:1131/1680 train_time:99128ms step_avg:87.65ms +step:1132/1680 train_time:99216ms step_avg:87.65ms +step:1133/1680 train_time:99304ms step_avg:87.65ms +step:1134/1680 train_time:99391ms step_avg:87.65ms +step:1135/1680 train_time:99479ms step_avg:87.65ms +step:1136/1680 train_time:99567ms step_avg:87.65ms +step:1137/1680 train_time:99657ms step_avg:87.65ms +step:1138/1680 train_time:99748ms step_avg:87.65ms +step:1139/1680 train_time:99839ms step_avg:87.65ms +step:1140/1680 train_time:99930ms step_avg:87.66ms +step:1141/1680 train_time:100018ms step_avg:87.66ms +step:1142/1680 train_time:100107ms step_avg:87.66ms +step:1143/1680 train_time:100195ms step_avg:87.66ms +step:1144/1680 train_time:100284ms step_avg:87.66ms +step:1145/1680 train_time:100372ms step_avg:87.66ms +step:1146/1680 train_time:100460ms step_avg:87.66ms +step:1147/1680 train_time:100548ms step_avg:87.66ms +step:1148/1680 train_time:100637ms step_avg:87.66ms 
+step:1149/1680 train_time:100727ms step_avg:87.66ms +step:1150/1680 train_time:100817ms step_avg:87.67ms +step:1151/1680 train_time:100908ms step_avg:87.67ms +step:1152/1680 train_time:100997ms step_avg:87.67ms +step:1153/1680 train_time:101086ms step_avg:87.67ms +step:1154/1680 train_time:101174ms step_avg:87.67ms +step:1155/1680 train_time:101264ms step_avg:87.67ms +step:1156/1680 train_time:101352ms step_avg:87.68ms +step:1157/1680 train_time:101441ms step_avg:87.68ms +step:1158/1680 train_time:101529ms step_avg:87.68ms +step:1159/1680 train_time:101618ms step_avg:87.68ms +step:1160/1680 train_time:101707ms step_avg:87.68ms +step:1161/1680 train_time:101797ms step_avg:87.68ms +step:1162/1680 train_time:101887ms step_avg:87.68ms +step:1163/1680 train_time:101976ms step_avg:87.68ms +step:1164/1680 train_time:102066ms step_avg:87.69ms +step:1165/1680 train_time:102155ms step_avg:87.69ms +step:1166/1680 train_time:102244ms step_avg:87.69ms +step:1167/1680 train_time:102333ms step_avg:87.69ms +step:1168/1680 train_time:102422ms step_avg:87.69ms +step:1169/1680 train_time:102511ms step_avg:87.69ms +step:1170/1680 train_time:102599ms step_avg:87.69ms +step:1171/1680 train_time:102688ms step_avg:87.69ms +step:1172/1680 train_time:102777ms step_avg:87.69ms +step:1173/1680 train_time:102866ms step_avg:87.69ms +step:1174/1680 train_time:102955ms step_avg:87.70ms +step:1175/1680 train_time:103045ms step_avg:87.70ms +step:1176/1680 train_time:103133ms step_avg:87.70ms +step:1177/1680 train_time:103222ms step_avg:87.70ms +step:1178/1680 train_time:103311ms step_avg:87.70ms +step:1179/1680 train_time:103399ms step_avg:87.70ms +step:1180/1680 train_time:103488ms step_avg:87.70ms +step:1181/1680 train_time:103577ms step_avg:87.70ms +step:1182/1680 train_time:103666ms step_avg:87.70ms +step:1183/1680 train_time:103755ms step_avg:87.71ms +step:1184/1680 train_time:103844ms step_avg:87.71ms +step:1185/1680 train_time:103933ms step_avg:87.71ms +step:1186/1680 train_time:104022ms step_avg:87.71ms +step:1187/1680 train_time:104112ms step_avg:87.71ms +step:1188/1680 train_time:104200ms step_avg:87.71ms +step:1189/1680 train_time:104290ms step_avg:87.71ms +step:1190/1680 train_time:104379ms step_avg:87.71ms +step:1191/1680 train_time:104467ms step_avg:87.71ms +step:1192/1680 train_time:104556ms step_avg:87.71ms +step:1193/1680 train_time:104645ms step_avg:87.72ms +step:1194/1680 train_time:104734ms step_avg:87.72ms +step:1195/1680 train_time:104823ms step_avg:87.72ms +step:1196/1680 train_time:104912ms step_avg:87.72ms +step:1197/1680 train_time:105001ms step_avg:87.72ms +step:1198/1680 train_time:105090ms step_avg:87.72ms +step:1199/1680 train_time:105178ms step_avg:87.72ms +step:1200/1680 train_time:105267ms step_avg:87.72ms +step:1201/1680 train_time:105356ms step_avg:87.72ms +step:1202/1680 train_time:105445ms step_avg:87.72ms +step:1203/1680 train_time:105534ms step_avg:87.73ms +step:1204/1680 train_time:105623ms step_avg:87.73ms +step:1205/1680 train_time:105713ms step_avg:87.73ms +step:1206/1680 train_time:105801ms step_avg:87.73ms +step:1207/1680 train_time:105891ms step_avg:87.73ms +step:1208/1680 train_time:105980ms step_avg:87.73ms +step:1209/1680 train_time:106068ms step_avg:87.73ms +step:1210/1680 train_time:106156ms step_avg:87.73ms +step:1211/1680 train_time:106245ms step_avg:87.73ms +step:1212/1680 train_time:106335ms step_avg:87.73ms +step:1213/1680 train_time:106424ms step_avg:87.74ms +step:1214/1680 train_time:106512ms step_avg:87.74ms +step:1215/1680 train_time:106601ms step_avg:87.74ms 
+step:1216/1680 train_time:106690ms step_avg:87.74ms +step:1217/1680 train_time:106779ms step_avg:87.74ms +step:1218/1680 train_time:106868ms step_avg:87.74ms +step:1219/1680 train_time:106957ms step_avg:87.74ms +step:1220/1680 train_time:107047ms step_avg:87.74ms +step:1221/1680 train_time:107136ms step_avg:87.74ms +step:1222/1680 train_time:107225ms step_avg:87.75ms +step:1223/1680 train_time:107315ms step_avg:87.75ms +step:1224/1680 train_time:107404ms step_avg:87.75ms +step:1225/1680 train_time:107492ms step_avg:87.75ms +step:1226/1680 train_time:107581ms step_avg:87.75ms +step:1227/1680 train_time:107670ms step_avg:87.75ms +step:1228/1680 train_time:107759ms step_avg:87.75ms +step:1229/1680 train_time:107848ms step_avg:87.75ms +step:1230/1680 train_time:107937ms step_avg:87.75ms +step:1231/1680 train_time:108026ms step_avg:87.75ms +step:1232/1680 train_time:108115ms step_avg:87.76ms +step:1233/1680 train_time:108204ms step_avg:87.76ms +step:1234/1680 train_time:108292ms step_avg:87.76ms +step:1235/1680 train_time:108381ms step_avg:87.76ms +step:1236/1680 train_time:108469ms step_avg:87.76ms +step:1237/1680 train_time:108558ms step_avg:87.76ms +step:1238/1680 train_time:108647ms step_avg:87.76ms +step:1239/1680 train_time:108736ms step_avg:87.76ms +step:1240/1680 train_time:108825ms step_avg:87.76ms +step:1241/1680 train_time:108915ms step_avg:87.76ms +step:1242/1680 train_time:109004ms step_avg:87.77ms +step:1243/1680 train_time:109094ms step_avg:87.77ms +step:1244/1680 train_time:109184ms step_avg:87.77ms +step:1245/1680 train_time:109273ms step_avg:87.77ms +step:1246/1680 train_time:109362ms step_avg:87.77ms +step:1247/1680 train_time:109451ms step_avg:87.77ms +step:1248/1680 train_time:109539ms step_avg:87.77ms +step:1249/1680 train_time:109628ms step_avg:87.77ms +step:1250/1680 train_time:109717ms step_avg:87.77ms +step:1250/1680 val_loss:3.3778 train_time:109807ms step_avg:87.85ms +step:1251/1680 train_time:109826ms step_avg:87.79ms +step:1252/1680 train_time:109899ms step_avg:87.78ms +step:1253/1680 train_time:109992ms step_avg:87.78ms +step:1254/1680 train_time:110082ms step_avg:87.78ms +step:1255/1680 train_time:110170ms step_avg:87.79ms +step:1256/1680 train_time:110259ms step_avg:87.79ms +step:1257/1680 train_time:110347ms step_avg:87.79ms +step:1258/1680 train_time:110435ms step_avg:87.79ms +step:1259/1680 train_time:110523ms step_avg:87.79ms +step:1260/1680 train_time:110610ms step_avg:87.79ms +step:1261/1680 train_time:110697ms step_avg:87.79ms +step:1262/1680 train_time:110788ms step_avg:87.79ms +step:1263/1680 train_time:110879ms step_avg:87.79ms +step:1264/1680 train_time:110969ms step_avg:87.79ms +step:1265/1680 train_time:111058ms step_avg:87.79ms +step:1266/1680 train_time:111147ms step_avg:87.79ms +step:1267/1680 train_time:111235ms step_avg:87.79ms +step:1268/1680 train_time:111324ms step_avg:87.79ms +step:1269/1680 train_time:111412ms step_avg:87.79ms +step:1270/1680 train_time:111500ms step_avg:87.80ms +step:1271/1680 train_time:111588ms step_avg:87.80ms +step:1272/1680 train_time:111676ms step_avg:87.80ms +step:1273/1680 train_time:111766ms step_avg:87.80ms +step:1274/1680 train_time:111857ms step_avg:87.80ms +step:1275/1680 train_time:111947ms step_avg:87.80ms +step:1276/1680 train_time:112037ms step_avg:87.80ms +step:1277/1680 train_time:112126ms step_avg:87.80ms +step:1278/1680 train_time:112215ms step_avg:87.81ms +step:1279/1680 train_time:112303ms step_avg:87.81ms +step:1280/1680 train_time:112392ms step_avg:87.81ms +step:1281/1680 train_time:112480ms 
step_avg:87.81ms +step:1282/1680 train_time:112568ms step_avg:87.81ms +step:1283/1680 train_time:112657ms step_avg:87.81ms +step:1284/1680 train_time:112746ms step_avg:87.81ms +step:1285/1680 train_time:112836ms step_avg:87.81ms +step:1286/1680 train_time:112925ms step_avg:87.81ms +step:1287/1680 train_time:113015ms step_avg:87.81ms +step:1288/1680 train_time:113105ms step_avg:87.81ms +step:1289/1680 train_time:113193ms step_avg:87.81ms +step:1290/1680 train_time:113282ms step_avg:87.82ms +step:1291/1680 train_time:113372ms step_avg:87.82ms +step:1292/1680 train_time:113460ms step_avg:87.82ms +step:1293/1680 train_time:113549ms step_avg:87.82ms +step:1294/1680 train_time:113638ms step_avg:87.82ms +step:1295/1680 train_time:113727ms step_avg:87.82ms +step:1296/1680 train_time:113816ms step_avg:87.82ms +step:1297/1680 train_time:113906ms step_avg:87.82ms +step:1298/1680 train_time:113996ms step_avg:87.82ms +step:1299/1680 train_time:114085ms step_avg:87.83ms +step:1300/1680 train_time:114174ms step_avg:87.83ms +step:1301/1680 train_time:114263ms step_avg:87.83ms +step:1302/1680 train_time:114352ms step_avg:87.83ms +step:1303/1680 train_time:114441ms step_avg:87.83ms +step:1304/1680 train_time:114530ms step_avg:87.83ms +step:1305/1680 train_time:114618ms step_avg:87.83ms +step:1306/1680 train_time:114708ms step_avg:87.83ms +step:1307/1680 train_time:114796ms step_avg:87.83ms +step:1308/1680 train_time:114885ms step_avg:87.83ms +step:1309/1680 train_time:114976ms step_avg:87.83ms +step:1310/1680 train_time:115065ms step_avg:87.84ms +step:1311/1680 train_time:115154ms step_avg:87.84ms +step:1312/1680 train_time:115244ms step_avg:87.84ms +step:1313/1680 train_time:115333ms step_avg:87.84ms +step:1314/1680 train_time:115421ms step_avg:87.84ms +step:1315/1680 train_time:115510ms step_avg:87.84ms +step:1316/1680 train_time:115598ms step_avg:87.84ms +step:1317/1680 train_time:115687ms step_avg:87.84ms +step:1318/1680 train_time:115776ms step_avg:87.84ms +step:1319/1680 train_time:115866ms step_avg:87.84ms +step:1320/1680 train_time:115956ms step_avg:87.85ms +step:1321/1680 train_time:116045ms step_avg:87.85ms +step:1322/1680 train_time:116134ms step_avg:87.85ms +step:1323/1680 train_time:116224ms step_avg:87.85ms +step:1324/1680 train_time:116312ms step_avg:87.85ms +step:1325/1680 train_time:116401ms step_avg:87.85ms +step:1326/1680 train_time:116490ms step_avg:87.85ms +step:1327/1680 train_time:116578ms step_avg:87.85ms +step:1328/1680 train_time:116667ms step_avg:87.85ms +step:1329/1680 train_time:116757ms step_avg:87.85ms +step:1330/1680 train_time:116846ms step_avg:87.85ms +step:1331/1680 train_time:116935ms step_avg:87.85ms +step:1332/1680 train_time:117024ms step_avg:87.86ms +step:1333/1680 train_time:117112ms step_avg:87.86ms +step:1334/1680 train_time:117201ms step_avg:87.86ms +step:1335/1680 train_time:117290ms step_avg:87.86ms +step:1336/1680 train_time:117378ms step_avg:87.86ms +step:1337/1680 train_time:117467ms step_avg:87.86ms +step:1338/1680 train_time:117556ms step_avg:87.86ms +step:1339/1680 train_time:117645ms step_avg:87.86ms +step:1340/1680 train_time:117734ms step_avg:87.86ms +step:1341/1680 train_time:117822ms step_avg:87.86ms +step:1342/1680 train_time:117912ms step_avg:87.86ms +step:1343/1680 train_time:118000ms step_avg:87.86ms +step:1344/1680 train_time:118090ms step_avg:87.86ms +step:1345/1680 train_time:118179ms step_avg:87.87ms +step:1346/1680 train_time:118268ms step_avg:87.87ms +step:1347/1680 train_time:118357ms step_avg:87.87ms +step:1348/1680 train_time:118446ms 
step_avg:87.87ms +step:1349/1680 train_time:118536ms step_avg:87.87ms +step:1350/1680 train_time:118625ms step_avg:87.87ms +step:1351/1680 train_time:118713ms step_avg:87.87ms +step:1352/1680 train_time:118802ms step_avg:87.87ms +step:1353/1680 train_time:118891ms step_avg:87.87ms +step:1354/1680 train_time:118981ms step_avg:87.87ms +step:1355/1680 train_time:119070ms step_avg:87.87ms +step:1356/1680 train_time:119160ms step_avg:87.88ms +step:1357/1680 train_time:119249ms step_avg:87.88ms +step:1358/1680 train_time:119337ms step_avg:87.88ms +step:1359/1680 train_time:119426ms step_avg:87.88ms +step:1360/1680 train_time:119515ms step_avg:87.88ms +step:1361/1680 train_time:119604ms step_avg:87.88ms +step:1362/1680 train_time:119693ms step_avg:87.88ms +step:1363/1680 train_time:119781ms step_avg:87.88ms +step:1364/1680 train_time:119870ms step_avg:87.88ms +step:1365/1680 train_time:119959ms step_avg:87.88ms +step:1366/1680 train_time:120049ms step_avg:87.88ms +step:1367/1680 train_time:120138ms step_avg:87.88ms +step:1368/1680 train_time:120228ms step_avg:87.89ms +step:1369/1680 train_time:120316ms step_avg:87.89ms +step:1370/1680 train_time:120405ms step_avg:87.89ms +step:1371/1680 train_time:120495ms step_avg:87.89ms +step:1372/1680 train_time:120584ms step_avg:87.89ms +step:1373/1680 train_time:120674ms step_avg:87.89ms +step:1374/1680 train_time:120763ms step_avg:87.89ms +step:1375/1680 train_time:120853ms step_avg:87.89ms +step:1375/1680 val_loss:3.3427 train_time:120943ms step_avg:87.96ms +step:1376/1680 train_time:120961ms step_avg:87.91ms +step:1377/1680 train_time:121035ms step_avg:87.90ms +step:1378/1680 train_time:121127ms step_avg:87.90ms +step:1379/1680 train_time:121216ms step_avg:87.90ms +step:1380/1680 train_time:121304ms step_avg:87.90ms +step:1381/1680 train_time:121393ms step_avg:87.90ms +step:1382/1680 train_time:121480ms step_avg:87.90ms +step:1383/1680 train_time:121569ms step_avg:87.90ms +step:1384/1680 train_time:121656ms step_avg:87.90ms +step:1385/1680 train_time:121745ms step_avg:87.90ms +step:1386/1680 train_time:121833ms step_avg:87.90ms +step:1387/1680 train_time:121923ms step_avg:87.90ms +step:1388/1680 train_time:122016ms step_avg:87.91ms +step:1389/1680 train_time:122106ms step_avg:87.91ms +step:1390/1680 train_time:122196ms step_avg:87.91ms +step:1391/1680 train_time:122286ms step_avg:87.91ms +step:1392/1680 train_time:122375ms step_avg:87.91ms +step:1393/1680 train_time:122462ms step_avg:87.91ms +step:1394/1680 train_time:122551ms step_avg:87.91ms +step:1395/1680 train_time:122639ms step_avg:87.91ms +step:1396/1680 train_time:122727ms step_avg:87.91ms +step:1397/1680 train_time:122816ms step_avg:87.91ms +step:1398/1680 train_time:122906ms step_avg:87.92ms +step:1399/1680 train_time:122997ms step_avg:87.92ms +step:1400/1680 train_time:123087ms step_avg:87.92ms +step:1401/1680 train_time:123177ms step_avg:87.92ms +step:1402/1680 train_time:123267ms step_avg:87.92ms +step:1403/1680 train_time:123356ms step_avg:87.92ms +step:1404/1680 train_time:123444ms step_avg:87.92ms +step:1405/1680 train_time:123533ms step_avg:87.92ms +step:1406/1680 train_time:123621ms step_avg:87.92ms +step:1407/1680 train_time:123709ms step_avg:87.92ms +step:1408/1680 train_time:123798ms step_avg:87.92ms +step:1409/1680 train_time:123887ms step_avg:87.93ms +step:1410/1680 train_time:123978ms step_avg:87.93ms +step:1411/1680 train_time:124067ms step_avg:87.93ms +step:1412/1680 train_time:124157ms step_avg:87.93ms +step:1413/1680 train_time:124246ms step_avg:87.93ms +step:1414/1680 
train_time:124335ms step_avg:87.93ms +step:1415/1680 train_time:124423ms step_avg:87.93ms +step:1416/1680 train_time:124512ms step_avg:87.93ms +step:1417/1680 train_time:124601ms step_avg:87.93ms +step:1418/1680 train_time:124690ms step_avg:87.93ms +step:1419/1680 train_time:124778ms step_avg:87.93ms +step:1420/1680 train_time:124867ms step_avg:87.93ms +step:1421/1680 train_time:124956ms step_avg:87.94ms +step:1422/1680 train_time:125045ms step_avg:87.94ms +step:1423/1680 train_time:125136ms step_avg:87.94ms +step:1424/1680 train_time:125225ms step_avg:87.94ms +step:1425/1680 train_time:125314ms step_avg:87.94ms +step:1426/1680 train_time:125403ms step_avg:87.94ms +step:1427/1680 train_time:125493ms step_avg:87.94ms +step:1428/1680 train_time:125581ms step_avg:87.94ms +step:1429/1680 train_time:125670ms step_avg:87.94ms +step:1430/1680 train_time:125758ms step_avg:87.94ms +step:1431/1680 train_time:125848ms step_avg:87.94ms +step:1432/1680 train_time:125938ms step_avg:87.95ms +step:1433/1680 train_time:126027ms step_avg:87.95ms +step:1434/1680 train_time:126116ms step_avg:87.95ms +step:1435/1680 train_time:126205ms step_avg:87.95ms +step:1436/1680 train_time:126296ms step_avg:87.95ms +step:1437/1680 train_time:126385ms step_avg:87.95ms +step:1438/1680 train_time:126474ms step_avg:87.95ms +step:1439/1680 train_time:126562ms step_avg:87.95ms +step:1440/1680 train_time:126650ms step_avg:87.95ms +step:1441/1680 train_time:126739ms step_avg:87.95ms +step:1442/1680 train_time:126828ms step_avg:87.95ms +step:1443/1680 train_time:126917ms step_avg:87.95ms +step:1444/1680 train_time:127006ms step_avg:87.95ms +step:1445/1680 train_time:127096ms step_avg:87.96ms +step:1446/1680 train_time:127185ms step_avg:87.96ms +step:1447/1680 train_time:127275ms step_avg:87.96ms +step:1448/1680 train_time:127363ms step_avg:87.96ms +step:1449/1680 train_time:127453ms step_avg:87.96ms +step:1450/1680 train_time:127541ms step_avg:87.96ms +step:1451/1680 train_time:127630ms step_avg:87.96ms +step:1452/1680 train_time:127718ms step_avg:87.96ms +step:1453/1680 train_time:127806ms step_avg:87.96ms +step:1454/1680 train_time:127896ms step_avg:87.96ms +step:1455/1680 train_time:127985ms step_avg:87.96ms +step:1456/1680 train_time:128074ms step_avg:87.96ms +step:1457/1680 train_time:128163ms step_avg:87.96ms +step:1458/1680 train_time:128251ms step_avg:87.96ms +step:1459/1680 train_time:128341ms step_avg:87.96ms +step:1460/1680 train_time:128430ms step_avg:87.97ms +step:1461/1680 train_time:128519ms step_avg:87.97ms +step:1462/1680 train_time:128609ms step_avg:87.97ms +step:1463/1680 train_time:128697ms step_avg:87.97ms +step:1464/1680 train_time:128786ms step_avg:87.97ms +step:1465/1680 train_time:128874ms step_avg:87.97ms +step:1466/1680 train_time:128962ms step_avg:87.97ms +step:1467/1680 train_time:129052ms step_avg:87.97ms +step:1468/1680 train_time:129141ms step_avg:87.97ms +step:1469/1680 train_time:129231ms step_avg:87.97ms +step:1470/1680 train_time:129319ms step_avg:87.97ms +step:1471/1680 train_time:129408ms step_avg:87.97ms +step:1472/1680 train_time:129498ms step_avg:87.97ms +step:1473/1680 train_time:129588ms step_avg:87.98ms +step:1474/1680 train_time:129676ms step_avg:87.98ms +step:1475/1680 train_time:129764ms step_avg:87.98ms +step:1476/1680 train_time:129853ms step_avg:87.98ms +step:1477/1680 train_time:129942ms step_avg:87.98ms +step:1478/1680 train_time:130030ms step_avg:87.98ms +step:1479/1680 train_time:130119ms step_avg:87.98ms +step:1480/1680 train_time:130210ms step_avg:87.98ms +step:1481/1680 
train_time:130300ms step_avg:87.98ms +step:1482/1680 train_time:130389ms step_avg:87.98ms +step:1483/1680 train_time:130478ms step_avg:87.98ms +step:1484/1680 train_time:130568ms step_avg:87.98ms +step:1485/1680 train_time:130656ms step_avg:87.98ms +step:1486/1680 train_time:130745ms step_avg:87.98ms +step:1487/1680 train_time:130834ms step_avg:87.99ms +step:1488/1680 train_time:130923ms step_avg:87.99ms +step:1489/1680 train_time:131011ms step_avg:87.99ms +step:1490/1680 train_time:131100ms step_avg:87.99ms +step:1491/1680 train_time:131189ms step_avg:87.99ms +step:1492/1680 train_time:131278ms step_avg:87.99ms +step:1493/1680 train_time:131367ms step_avg:87.99ms +step:1494/1680 train_time:131456ms step_avg:87.99ms +step:1495/1680 train_time:131545ms step_avg:87.99ms +step:1496/1680 train_time:131635ms step_avg:87.99ms +step:1497/1680 train_time:131723ms step_avg:87.99ms +step:1498/1680 train_time:131812ms step_avg:87.99ms +step:1499/1680 train_time:131901ms step_avg:87.99ms +step:1500/1680 train_time:131989ms step_avg:87.99ms +step:1500/1680 val_loss:3.3134 train_time:132079ms step_avg:88.05ms +step:1501/1680 train_time:132097ms step_avg:88.01ms +step:1502/1680 train_time:132171ms step_avg:88.00ms +step:1503/1680 train_time:132263ms step_avg:88.00ms +step:1504/1680 train_time:132352ms step_avg:88.00ms +step:1505/1680 train_time:132440ms step_avg:88.00ms +step:1506/1680 train_time:132528ms step_avg:88.00ms +step:1507/1680 train_time:132616ms step_avg:88.00ms +step:1508/1680 train_time:132704ms step_avg:88.00ms +step:1509/1680 train_time:132791ms step_avg:88.00ms +step:1510/1680 train_time:132880ms step_avg:88.00ms +step:1511/1680 train_time:132968ms step_avg:88.00ms +step:1512/1680 train_time:133060ms step_avg:88.00ms +step:1513/1680 train_time:133150ms step_avg:88.00ms +step:1514/1680 train_time:133241ms step_avg:88.01ms +step:1515/1680 train_time:133330ms step_avg:88.01ms +step:1516/1680 train_time:133419ms step_avg:88.01ms +step:1517/1680 train_time:133508ms step_avg:88.01ms +step:1518/1680 train_time:133597ms step_avg:88.01ms +step:1519/1680 train_time:133685ms step_avg:88.01ms +step:1520/1680 train_time:133774ms step_avg:88.01ms +step:1521/1680 train_time:133862ms step_avg:88.01ms +step:1522/1680 train_time:133951ms step_avg:88.01ms +step:1523/1680 train_time:134041ms step_avg:88.01ms +step:1524/1680 train_time:134130ms step_avg:88.01ms +step:1525/1680 train_time:134220ms step_avg:88.01ms +step:1526/1680 train_time:134310ms step_avg:88.01ms +step:1527/1680 train_time:134399ms step_avg:88.02ms +step:1528/1680 train_time:134488ms step_avg:88.02ms +step:1529/1680 train_time:134577ms step_avg:88.02ms +step:1530/1680 train_time:134665ms step_avg:88.02ms +step:1531/1680 train_time:134754ms step_avg:88.02ms +step:1532/1680 train_time:134843ms step_avg:88.02ms +step:1533/1680 train_time:134932ms step_avg:88.02ms +step:1534/1680 train_time:135021ms step_avg:88.02ms +step:1535/1680 train_time:135110ms step_avg:88.02ms +step:1536/1680 train_time:135199ms step_avg:88.02ms +step:1537/1680 train_time:135289ms step_avg:88.02ms +step:1538/1680 train_time:135378ms step_avg:88.02ms +step:1539/1680 train_time:135467ms step_avg:88.02ms +step:1540/1680 train_time:135556ms step_avg:88.02ms +step:1541/1680 train_time:135644ms step_avg:88.02ms +step:1542/1680 train_time:135734ms step_avg:88.02ms +step:1543/1680 train_time:135822ms step_avg:88.02ms +step:1544/1680 train_time:135910ms step_avg:88.02ms +step:1545/1680 train_time:135999ms step_avg:88.03ms +step:1546/1680 train_time:136088ms step_avg:88.03ms 
+step:1547/1680 train_time:136178ms step_avg:88.03ms +step:1548/1680 train_time:136268ms step_avg:88.03ms +step:1549/1680 train_time:136357ms step_avg:88.03ms +step:1550/1680 train_time:136446ms step_avg:88.03ms +step:1551/1680 train_time:136535ms step_avg:88.03ms +step:1552/1680 train_time:136623ms step_avg:88.03ms +step:1553/1680 train_time:136711ms step_avg:88.03ms +step:1554/1680 train_time:136800ms step_avg:88.03ms +step:1555/1680 train_time:136888ms step_avg:88.03ms +step:1556/1680 train_time:136977ms step_avg:88.03ms +step:1557/1680 train_time:137066ms step_avg:88.03ms +step:1558/1680 train_time:137157ms step_avg:88.03ms +step:1559/1680 train_time:137246ms step_avg:88.03ms +step:1560/1680 train_time:137337ms step_avg:88.04ms +step:1561/1680 train_time:137426ms step_avg:88.04ms +step:1562/1680 train_time:137516ms step_avg:88.04ms +step:1563/1680 train_time:137604ms step_avg:88.04ms +step:1564/1680 train_time:137693ms step_avg:88.04ms +step:1565/1680 train_time:137781ms step_avg:88.04ms +step:1566/1680 train_time:137870ms step_avg:88.04ms +step:1567/1680 train_time:137959ms step_avg:88.04ms +step:1568/1680 train_time:138048ms step_avg:88.04ms +step:1569/1680 train_time:138137ms step_avg:88.04ms +step:1570/1680 train_time:138226ms step_avg:88.04ms +step:1571/1680 train_time:138315ms step_avg:88.04ms +step:1572/1680 train_time:138404ms step_avg:88.04ms +step:1573/1680 train_time:138494ms step_avg:88.04ms +step:1574/1680 train_time:138583ms step_avg:88.05ms +step:1575/1680 train_time:138673ms step_avg:88.05ms +step:1576/1680 train_time:138761ms step_avg:88.05ms +step:1577/1680 train_time:138849ms step_avg:88.05ms +step:1578/1680 train_time:138939ms step_avg:88.05ms +step:1579/1680 train_time:139028ms step_avg:88.05ms +step:1580/1680 train_time:139117ms step_avg:88.05ms +step:1581/1680 train_time:139206ms step_avg:88.05ms +step:1582/1680 train_time:139295ms step_avg:88.05ms +step:1583/1680 train_time:139383ms step_avg:88.05ms +step:1584/1680 train_time:139473ms step_avg:88.05ms +step:1585/1680 train_time:139562ms step_avg:88.05ms +step:1586/1680 train_time:139652ms step_avg:88.05ms +step:1587/1680 train_time:139740ms step_avg:88.05ms +step:1588/1680 train_time:139829ms step_avg:88.05ms +step:1589/1680 train_time:139918ms step_avg:88.05ms +step:1590/1680 train_time:140007ms step_avg:88.05ms +step:1591/1680 train_time:140096ms step_avg:88.06ms +step:1592/1680 train_time:140185ms step_avg:88.06ms +step:1593/1680 train_time:140273ms step_avg:88.06ms +step:1594/1680 train_time:140362ms step_avg:88.06ms +step:1595/1680 train_time:140451ms step_avg:88.06ms +step:1596/1680 train_time:140540ms step_avg:88.06ms +step:1597/1680 train_time:140631ms step_avg:88.06ms +step:1598/1680 train_time:140719ms step_avg:88.06ms +step:1599/1680 train_time:140808ms step_avg:88.06ms +step:1600/1680 train_time:140896ms step_avg:88.06ms +step:1601/1680 train_time:140985ms step_avg:88.06ms +step:1602/1680 train_time:141075ms step_avg:88.06ms +step:1603/1680 train_time:141163ms step_avg:88.06ms +step:1604/1680 train_time:141252ms step_avg:88.06ms +step:1605/1680 train_time:141341ms step_avg:88.06ms +step:1606/1680 train_time:141430ms step_avg:88.06ms +step:1607/1680 train_time:141518ms step_avg:88.06ms +step:1608/1680 train_time:141608ms step_avg:88.06ms +step:1609/1680 train_time:141697ms step_avg:88.07ms +step:1610/1680 train_time:141786ms step_avg:88.07ms +step:1611/1680 train_time:141875ms step_avg:88.07ms +step:1612/1680 train_time:141964ms step_avg:88.07ms +step:1613/1680 train_time:142054ms step_avg:88.07ms 
+step:1614/1680 train_time:142142ms step_avg:88.07ms +step:1615/1680 train_time:142231ms step_avg:88.07ms +step:1616/1680 train_time:142320ms step_avg:88.07ms +step:1617/1680 train_time:142408ms step_avg:88.07ms +step:1618/1680 train_time:142498ms step_avg:88.07ms +step:1619/1680 train_time:142587ms step_avg:88.07ms +step:1620/1680 train_time:142676ms step_avg:88.07ms +step:1621/1680 train_time:142766ms step_avg:88.07ms +step:1622/1680 train_time:142855ms step_avg:88.07ms +step:1623/1680 train_time:142944ms step_avg:88.07ms +step:1624/1680 train_time:143032ms step_avg:88.07ms +step:1625/1680 train_time:143121ms step_avg:88.07ms +step:1625/1680 val_loss:3.2897 train_time:143212ms step_avg:88.13ms +step:1626/1680 train_time:143230ms step_avg:88.09ms +step:1627/1680 train_time:143305ms step_avg:88.08ms +step:1628/1680 train_time:143400ms step_avg:88.08ms +step:1629/1680 train_time:143489ms step_avg:88.08ms +step:1630/1680 train_time:143578ms step_avg:88.08ms +step:1631/1680 train_time:143665ms step_avg:88.08ms +step:1632/1680 train_time:143753ms step_avg:88.08ms +step:1633/1680 train_time:143841ms step_avg:88.08ms +step:1634/1680 train_time:143930ms step_avg:88.08ms +step:1635/1680 train_time:144018ms step_avg:88.08ms +step:1636/1680 train_time:144106ms step_avg:88.08ms +step:1637/1680 train_time:144196ms step_avg:88.09ms +step:1638/1680 train_time:144287ms step_avg:88.09ms +step:1639/1680 train_time:144379ms step_avg:88.09ms +step:1640/1680 train_time:144469ms step_avg:88.09ms +step:1641/1680 train_time:144559ms step_avg:88.09ms +step:1642/1680 train_time:144647ms step_avg:88.09ms +step:1643/1680 train_time:144735ms step_avg:88.09ms +step:1644/1680 train_time:144823ms step_avg:88.09ms +step:1645/1680 train_time:144912ms step_avg:88.09ms +step:1646/1680 train_time:145000ms step_avg:88.09ms +step:1647/1680 train_time:145088ms step_avg:88.09ms +step:1648/1680 train_time:145179ms step_avg:88.09ms +step:1649/1680 train_time:145268ms step_avg:88.09ms +step:1650/1680 train_time:145359ms step_avg:88.10ms +step:1651/1680 train_time:145448ms step_avg:88.10ms +step:1652/1680 train_time:145538ms step_avg:88.10ms +step:1653/1680 train_time:145626ms step_avg:88.10ms +step:1654/1680 train_time:145715ms step_avg:88.10ms +step:1655/1680 train_time:145804ms step_avg:88.10ms +step:1656/1680 train_time:145892ms step_avg:88.10ms +step:1657/1680 train_time:145981ms step_avg:88.10ms +step:1658/1680 train_time:146069ms step_avg:88.10ms +step:1659/1680 train_time:146159ms step_avg:88.10ms +step:1660/1680 train_time:146247ms step_avg:88.10ms +step:1661/1680 train_time:146337ms step_avg:88.10ms +step:1662/1680 train_time:146427ms step_avg:88.10ms +step:1663/1680 train_time:146517ms step_avg:88.10ms +step:1664/1680 train_time:146606ms step_avg:88.10ms +step:1665/1680 train_time:146695ms step_avg:88.11ms +step:1666/1680 train_time:146783ms step_avg:88.11ms +step:1667/1680 train_time:146873ms step_avg:88.11ms +step:1668/1680 train_time:146961ms step_avg:88.11ms +step:1669/1680 train_time:147050ms step_avg:88.11ms +step:1670/1680 train_time:147138ms step_avg:88.11ms +step:1671/1680 train_time:147227ms step_avg:88.11ms +step:1672/1680 train_time:147317ms step_avg:88.11ms +step:1673/1680 train_time:147406ms step_avg:88.11ms +step:1674/1680 train_time:147496ms step_avg:88.11ms +step:1675/1680 train_time:147585ms step_avg:88.11ms +step:1676/1680 train_time:147674ms step_avg:88.11ms +step:1677/1680 train_time:147762ms step_avg:88.11ms +step:1678/1680 train_time:147851ms step_avg:88.11ms +step:1679/1680 train_time:147940ms 
step_avg:88.11ms
+step:1680/1680 train_time:148028ms step_avg:88.11ms
+step:1680/1680 val_loss:3.2790 train_time:148119ms step_avg:88.17ms
+peak memory allocated: 30760 MiB reserved: 46054 MiB
diff --git a/train_gpt.py b/train_gpt.py
index f345750dc..95d474470 100644
--- a/train_gpt.py
+++ b/train_gpt.py
@@ -413,10 +413,10 @@ class Muon(torch.optim.Optimizer):
     This hyper-optimized class has faster execution time than the current impl of Adam for small params
     Custom distributed sizing:
-    The model stores all attn and mlp weights in the same shape, and then updates the view as
-    needed on the forward pass. This enables attn and mlp weights to be contained within the same
-    dist.reduce_scatter_tensor() call. The model architecture has been customized to enable
-    (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn.
+    The model stores all attn and mlp weights in the same shape, and then updates the view as
+    needed on the forward pass. This enables attn and mlp weights to be contained within the same
+    dist.reduce_scatter_tensor() call. The model architecture has been customized to enable
+    (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding on mlp and attn.
     The scheduling is:
     1. reduce scatter smear_gate (1 param 7 padding params)
     2. reduce scatter attn_gate (10 params 6 padding params)
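The "padding params" in the schedule above are zero tensors appended so that each batched collective divides evenly across the 8 GPUs: 1 real + 7 pads = 8 for smear_gate, 10 real + 6 pads = 16 for attn_gate. A minimal sketch of that zero-padded batching follows; the helper name and the AVG reduction are assumptions here, and the record's actual Muon step is considerably more involved.

import torch
import torch.distributed as dist

def reduce_scatter_padded(params: list[torch.Tensor], world_size: int) -> torch.Tensor:
    # Hypothetical helper: batch same-shaped grads into one collective.
    # Pad with zeros so the batch splits evenly across ranks, e.g.
    # 1 grad + 7 pads on 8 GPUs (smear_gate), 10 grads + 6 pads (attn_gate).
    pad = -len(params) % world_size
    grads = [p.grad for p in params] + [torch.zeros_like(params[0])] * pad
    flat = torch.stack(grads)  # (n + pad, *shape); all grads share one shape
    out = torch.empty_like(flat[: flat.size(0) // world_size])
    # A single reduce_scatter_tensor covers the whole batch; each rank
    # receives a contiguous chunk of the reduced gradients.
    dist.reduce_scatter_tensor(out, flat, op=dist.ReduceOp.AVG)
    return out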
""" module_ranks = { @@ -614,7 +614,7 @@ def step(self): for p in params[module_idx:module_idx+chunk_size]: assert getattr(params[module_idx],'module','none')=='attn' batch = 4 * original_shape[0] - d1 = original_shape[1] + d1 = original_shape[1] d2 = original_shape[2] // 4 batched = batched_update_grads.view(batch, d1, d2) v_chunk = newton_schulz_triton(batched) @@ -777,7 +777,7 @@ def __init__(self, head_dim, max_seq_len): self.head_dim = head_dim self.max_seq_len = max_seq_len self.reset() - + def reset(self): angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) @@ -1020,10 +1020,15 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_sho skip_connections.append(x) x = norm(x) - logits = self.lm_head(x).float() + logits = self.lm_head(x) # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) - logits = 30 * torch.sigmoid(logits / 7.5) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0) + logits_for_loss = logits.float() if not self.training else logits + loss = F.cross_entropy( + logits_for_loss.view(-1, logits_for_loss.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) return loss # ----------------------------------------------------------------------------- @@ -1065,12 +1070,12 @@ def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False) def _load(self): self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() self.ready.set() - + def start(self): self.ready.clear() self.thread = threading.Thread(target=self._load) self.thread.start() - + def get(self): if self.thread: self.ready.wait() @@ -1113,17 +1118,17 @@ def __init__(self, file_iter, world_size: int = 1): self.thread = None self.data = None self.ready = threading.Event() - + def _load(self): tokens = _load_data_shard(next(self.file_iter)) self.data = (tokens, BOSFinder(tokens, self.world_size)) self.ready.set() - + def start(self): self.ready.clear() self.thread = threading.Thread(target=self._load) self.thread.start() - + def get(self): if self.thread: self.ready.wait() @@ -1390,7 +1395,7 @@ def get_ws(step: int): assert args.val_tokens % args.val_batch_size == 0 val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) - val_loss = 0 + val_loss = torch.zeros((), device=device, dtype=torch.float32) with torch.no_grad(): for _ in range(val_steps): inputs, targets, cum_seqlens = next(val_loader)